파이토치에서 제공하는 튜토리얼이 있는 것을 알게 되었고, 이를 활용하면 좋을 것 같다고 판단해 하나씩 공부해보려고 한다.
https://pytorch.org/tutorials/beginner/introyt/introyt1_tutorial.html
아래는 LeNet-5이라는 모델이다.
모델을 구현할 때는 Figure만 보고도 얼추 가능할 정도로 Figure의 역할이 중요하다.
Figure를 살펴보면 convolution 레이어가 2개, subsampling 레이어가 2개, full connection이 3개가 있는 것을 확인할 수 있다.
먼저 모델을 구현한 전체 코드는 아래와 같다
학부 시절 부전공으로 AI를 전공하긴 했지만 강의를 들은 지 오래되기도 했고, 전공 지식이 좀 많이 증발한 상태라서 코드 내에서 사용된 관련 개념들도 하나씩 짚고 넘어가 보도록 하겠다.
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 1 input image channel (black & white), 6 output channels, 5x5 square convolution
# kernel
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5*5 from image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
# Max pooling over a (2, 2) window
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
# If the size is a square you can only specify a single number
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, self.num_flat_features(x))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def num_flat_features(self, x):
size = x.size()[1:] # all dimensions except the batch dimension
num_features = 1
for s in size:
num_features *= s
return num_features
모델을 구현하기 위해서 먼저 클래스 선언이 필요하며, init과 forward 함수 2가지를 선언해줘야 한다.
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
def forward(self, x):
return
self.conv1 = nn.Conv2d(1, 6, 5) #in_channels, out_channels, kernel_size
* 컨볼루션(convolution): 입력 데이터에서 중요한 패턴을 추출하기 위해 필터를 사용하여 가중 합을 계산하는 연산
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5*5 from image dimension
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
Subsampling
Pooling
컨볼루션 이후 활성화 함수(relu)를 적용하는 이유
x.view(-1, self.num_flat_features(x))
이 코드에서 view(-1, self.num_flat_features(x))는 입력 텐서를 배치 크기 N을 유지하면서 2D 텐서 (N, C * H * W) 형태로 변환하며, 컨볼루션 레이어에서 추출된 특징을 완전 연결 레이어에 입력할 수 있도록 텐서를 평평하게 만드는 데 사용한다
PointNet 코드 분석 - PointNetfeat (0) | 2024.08.24 |
---|---|
PointNet 코드 분석 - PointNetCls (0) | 2024.08.24 |
PointNet 코드 분석을 시작하기 전 (1) | 2024.08.24 |