1. 클래스 상속 (Class inheritance)
클래스는 다른 클래스의 메서드등을 상속받아 그대로 사용할 수 있다.
Parent class가 상속을 해주는 클래스, Child class가 상속을 받는 클래스이다.
class Person:
def __init__(self, fname, lname): # init으로 fname, lanme을 받아 firstname, lastname에 각각 저장함.
self.firstname = fname
self.lastname = lname
def printname(self): # 저장한 firstname, lastname을 출력함.
print(self.firstname, self.lastname)
x = Person("John", "Doe")
x.printname()
output:
John Doe
Parent class 상속 받기
위와 같은 Parent class를 선언한 후 Child class에서 상속받아 보자.
class Student(Person):
pass
Parent class인 Person에서 모든 것을 상속받았고 그 외의 기능은 없으므로(pass) Person과 동일하게 사용할 수 있다. 가령,
y = Student("HH", "jj")
y.printname()
output:
HH jj
__init__ 함수 대체하기 (override)
Parent class의 __init__ 함수가 마음에 안든다면 대체할 수 있다. 그냥 원래 __init__ 함수를 선언하듯 하면 된다.
class Student(Person):
def __init__(self, fname, lname):
print(fname, lname)
y = Student("hello", 100)
output:
hello 100
__init__ 함수만 바뀐 것이기 때문에 Parent class의 다른 기능들은 사용할 수 있다.
y.printname() # 원래는 사용 가능하다! 다만 이 경우 Parent class에서 self.firstname, self.lastname으로 저장을 했었는데 그러지 않았으므로 에러를 출력할 것이다.
super() 함수로 상속하기
Parent class의 __init__을 유지하고 싶다면 아래처럼 "Parent_class_name.__init__()"의 문법을 따르면 된다.
class Student(Person):
def __init__(self, fname, lname):
Person.__init__(self, fname, lname)
하지만 굳이 이름을 넣지 않고도 super()를 통해서도 가능하다.
class Student(Person):
def __init__(self, fname, lname):
super().__init__(fname, lname) # __init__에 'self' 인자가 들어가지 않는다.
다만 super()를 사용 시 __init__에 'self' 인자를 넣으면 안된다.
처음에는 "어 근데 유지를 하고싶은데 굳이 선언을 해줘야 하나? 위에 pass처럼 그냥 두면 되는거 아닌가?" 라는 생각이 들었는데 아래를 보고 의문이 풀렸다.
class Student(Person):
def __init__(self, fname, lname, year):
super().__init__(fname, lname)
self.graduationyear = year
x = Student("Mike", "Olsen", 2019)
이렇게 코드를 짜면 Person의 __init__을 받으면서 self.graduationyear까지 init에 추가시킬 수 있다. 말하자면 Person.__init__ + My __init__
Person의 __init__에서 self.firstname, self.lastname이 저장되고 나의 __init__에서 self.graduationyear가 저장되었다.
2. Pytorch Neural Network Model
Pytorch로 설계하는 뉴럴네트워크는 기본적으로 아래와 같은 코드구조를 갖는다. Pytorch 내장 모델뿐만 아니라 사용자 설정 모델도 아래의 구조를 따라야 한다고 한다. 아래는 예시 LeNet이다.
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 3)
self.conv2 = nn.Conv2d(6, 16, 3)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = Model_Name(...)
이게 무슨 말인가 하면,
1. nn.Module을 상속해야 한다.
→ nn.Module을 상속받아서 기본적인 기능들을 사용할 수 있게 만들어야 한다. 깊게 들어가면 복잡하다. 이 블로그를 참고해보자.
쉽게 설명하자면, nn.Module의 기능을 빌리려고 상속한다. 너무 당연한 말인가?
예를 들면, forward() 함수를 호출하지 않아도 파라미터가 전달되면 바로 forward function이 시작된다. 이것은 nn.Module을 상속받기 때문이다.
2. __init__ 함수를 override 해야한다.
3. forward 함수를 override 해야한다.
중간에 super(LeNet, self).__init__() 은 뭘까?:
super().__init__() vs super(Class, self).__init__()
super() 함수에 parameter로 class와 self를 받는다. 단순 super().__init__()과 어떻게 다를까?
위 LeNet의 경우 두 경우 기능적으로 차이가 없고 파생클래스와 self를 넣어서 현재 클래스가 어떤 클래스인지 명확하게 표시해줄 수 있다고 한다.
그럼 처음의 Student, Person 예시에서는 아래 두 코드가 같다.
class Student(Person):
def __init__(self, fname, lname):
super().__init__(fname, lname) # __init__에 'self' 인자가 들어가지 않는다.
class Student(Person):
def __init__(self, fname, lname):
super(Student, self).__init__(fname, lname) # __init__에 'self' 인자가 들어가지 않는다.
그리고,
super().__init__() → python 3 에서만 작동함
super(Class, self).__init__() → python 2,3 모두 작동함
이런 조그만 차이가 있다고 한다.
그럼 어디서 큰 차이가 나느냐, 이 블로그에서 설명을 잘 해주었다.
아래와 같은 세 class들이 있다. class B는 A를 상속받고, class C는 B를 상속받는다. 쉽게 말해 A→B→C 순으로 상속.
class A():
def __init__(self):
self.a = 10
def get_a(self):
return self.a
class B(A):
def __init__(self):
super(B,self).__init__()
self.b = 20
def get_b(self):
return self.b
class C(B):
def __init__(self):
super(C,self).__init__()
self.c = 30
def get_c(self):
return self.c
여기서 class B와 C에 있는 super(SelfClass, self).__init__()은 super().__init__()과 같은 기능을 한다.
new_c = C()
print(new_c.get_a())
print(new_c.get_c())
print(new_c.get_b())
output:
10
20
30
하지만 만약 class C에 있는 super(C, self).__init__()을 super(B, self).__init__()으로 바꾼다면?
이렇게 바꾼다면 "class C(B)"에 의해서 class B의 다른 것들은 상속을 받겠지만 __init__단계에서는 class B의 __init__을 상속 받으므로 결국 self.b는 상속받지 못한다.
class A:
def __init__(self):
self.a = 10
def get_a(self):
return self.a
class B(A):
def __init__(self):
super(B,self).__init__()
self.b = 20
def get_b(self):
return self.b
class C(B):
def __init__(self):
super(B,self).__init__() # super(C,self)가 바뀜
self.c = 30
def get_c(self):
return self.c
new_c = C()
print(new_c.get_a())
print(new_c.get_c())
output:
10
30
위는 가능하지만, 아래는 에러가 난다.
print(new_c.get_b())
AttributeError: 'C' object has no attribute 'b'
고민 해결!
출처:
https://www.w3schools.com/python/python_inheritance.asp
[DL, PyTorch] 신경망 모델 정의하기 -- Class, nn.Module
PyTorch로 신경망 모델을 설계할 때, 크게 다음과 같은 세 가지 스텝을 따르면 된다. Design your model using class with Variables Construct loss and optim Train cycle (forward, backward, update) 이 포스팅에선 첫번째 단계
anweh.tistory.com
https://daebaq27.tistory.com/60
[Pytorch] nn.Module & super().__init__()
우리는 pytorch에서 각자 레이어 혹은 모델을 구성할 때, nn.Module을 상속받는다. 왜 상속을 받을까? 또 상속받을 때, super().__init__()은 왜 해주는 것일까? 해당 코드를 작성함으로써 어떤 속성을 갖게
daebaq27.tistory.com
'Coding' 카테고리의 다른 글
[Python] HackerRank - Capitalize! (0) | 2023.07.23 |
---|---|
[Python] torch.nn.CrossEntropyLoss 에서 ignore_index (0) | 2023.04.18 |
[Python] torch.scatter_ 이해하기 (0) | 2023.04.18 |
[Python] string split과 rsplit method 차이 (0) | 2023.04.16 |
torch.nn.utils.prune에서 L1Unstructured와 l1_unstructured의 차이 (0) | 2023.01.25 |