0. 개요
처음 PyTorch 를 쓰다보면 대부분 torch.nn.Sequential으로 네트워크를 구성하게 되는데,
다들 아는 AlexNet, VGGNet 다음의 ResNet을 구현하려고 하면 난관에 마주치게 된다.
Bottleneck layer 을 어떻게 구현하지?
본 포스팅에서는 ResNet의 bottleneck을 만드는 법을 긁어와 정리해보았다.
사실 내부적으로는 다 똑같은데, 코딩 스타일의 차이라고 느껴지지만... 아무튼.
1. ResNet 모듈 따로 만들기
import torch.nn as nn
class ResNetModule(nn.Module):
def __init__(self):
super().__init__()
self.inner = nn.Sequential(
..
)
self.outer = nn.Sequential(
..
)
def forward(self, x):
return self.outer(self.inner(x) + x)
class LargeModel(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = ...
self.layer2 = nn.Sequential(
ResNetModule(),
ResNetModule(),
ResNetModule(),
)
self.layer3 = ...
def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
out = self.layer3(out)
return out
2. Bottleneck만 따로 구현하기
(1) 함수형
def bottleneck(x, non_skip):
return non_skip(x)+x
class LargeModel(nn.Module):
def __init__(self):
super().__init__()
self.non_skip = nn.Sequential(
...
)
self.layer = nn.Sequential(
...,
)
def forward(self, x):
out = bottleneck(x, self.layer1)
out = layer(out)
return out
(2) 클래스형
import torch.nn as nn
class bottleneck(nn.Module):
def __init__(self, non_skip):
super().__init__()
self.non_skip = non_skip
def forward(self, x):
return self.non_skip(x) + x
class LargeModel(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Sequential(
...,
### nn.Sequential 중간에 존재 ###
bottleneck(
nn.Sequential(
...,
)
),
###############################
...,
)
self.layer2 = nn.Sequential(
...,
)
def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
return out
개인적으로는 2-(2)의 클래스형을 선호하는데,
큰 네트워크를 구성할 때에도 nn.Sequential을 사용할 수 있으며,
skip부분과 non_skip 부분을 나눠서 보기 편기하게 정리할 수 있기 때문이다.
3. Skipped Connection의 경우
skipped connection도 유사하게 만들 수 있는데
단순히 ResNet 모듈의 덧셈 연산을 concatenate으로 바꿔주면 된다.
import torch
import torch.nn as nn
class skipped(nn.Module):
def __init__(self, non_skip):
super().__init__()
self.non_skip = non_skip
def forward(self, x):
return torch.cat((self.non_skip(x), x), 1)
# return self.non_skip(x) + x
'SW > DeepLearning' 카테고리의 다른 글
[Deep Learning] L1 Regularization의 Gradient에 관한 소고 (0) | 2020.07.02 |
---|---|
[PyTorch] PyTorch Basic - 기본 틀 (0) | 2020.06.29 |
[PyTorch] PyTorch Basic - Data Handling (0) | 2020.06.27 |