본문 바로가기

SW/DeepLearning

[Pytorch] Pytorch로 ResNet bottleneck 만들기

0. 개요

 

처음 PyTorch 를 쓰다보면 대부분 torch.nn.Sequential으로 네트워크를 구성하게 되는데,

다들 아는 AlexNet, VGGNet 다음의 ResNet을 구현하려고 하면 난관에 마주치게 된다.

 

Bottleneck layer 을 어떻게 구현하지?

 

본 포스팅에서는 ResNet의 bottleneck을 만드는 법을 긁어와 정리해보았다.

사실 내부적으로는 다 똑같은데, 코딩 스타일의 차이라고 느껴지지만... 아무튼.

 

1. ResNet 모듈 따로 만들기

 

import torch.nn as nn

class ResNetModule(nn.Module):
    def __init__(self):
    	super().__init__()
        self.inner = nn.Sequential(
        	..
        )
        self.outer = nn.Sequential(
        	..
        )
    def forward(self, x):
    	return self.outer(self.inner(x) + x)
        
class LargeModel(nn.Module):
    def __init__(self):
    	super().__init__()
        self.layer1 = ...
        self.layer2 = nn.Sequential(
        	ResNetModule(),
            ResNetModule(),
            ResNetModule(),
        )
        self.layer3 = ...
        
    def forward(self, x):
    	out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        return out

 

 

2. Bottleneck만 따로 구현하기

 

(1) 함수형

 

def bottleneck(x, non_skip):
	return non_skip(x)+x
    
class LargeModel(nn.Module):
    def __init__(self):
    	super().__init__()
        self.non_skip = nn.Sequential(
        	...
        )
        self.layer = nn.Sequential(
        	...,
        )
        
    def forward(self, x):
    	out = bottleneck(x, self.layer1)
        out = layer(out)
        return out

 

(2) 클래스형

 

import torch.nn as nn

class bottleneck(nn.Module):
	def __init__(self, non_skip):
    	super().__init__()
        self.non_skip = non_skip

    def forward(self, x):
    	return self.non_skip(x) + x

class LargeModel(nn.Module):
	def __init__(self):
    	super().__init__()
        self.layer1 = nn.Sequential(
        	...,
            ### nn.Sequential 중간에 존재 ###
            bottleneck(
            	nn.Sequential(
                	...,
                )
            ),
            ###############################
            ...,
        )
        
        self.layer2 = nn.Sequential(
        	...,
        )
        
    def forward(self, x):
    	out = self.layer1(x)
        out = self.layer2(out)
        return out

 

개인적으로는 2-(2)의 클래스형을 선호하는데,

큰 네트워크를 구성할 때에도 nn.Sequential을 사용할 수 있으며,

skip부분과 non_skip 부분을 나눠서 보기 편기하게 정리할 수 있기 때문이다.

 

3. Skipped Connection의 경우

 

skipped connection도 유사하게 만들 수 있는데

단순히 ResNet 모듈의 덧셈 연산을 concatenate으로 바꿔주면 된다.

 

import torch
import torch.nn as nn

class skipped(nn.Module):
    def __init__(self, non_skip):
    	super().__init__()
        self.non_skip = non_skip

    def forward(self, x):
    	return torch.cat((self.non_skip(x), x), 1)
        # return self.non_skip(x) + x