먼저 PointNetCls 클래스의 전체 코드를 살펴보며 PointNetCls가 Figure 상에서 어떤 부분을 담당하는지 파악해 보았다.
클래스 전체 코드는 아래와 같다.
class PointNetCls(nn.Module):
def __init__(self, k=2, feature_transform=False):
super(PointNetCls, self).__init__()
self.feature_transform = feature_transform
self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, k)
self.dropout = nn.Dropout(p=0.3)
self.bn1 = nn.BatchNorm1d(512)
self.bn2 = nn.BatchNorm1d(256)
self.relu = nn.ReLU()
def forward(self, x):
x, trans, trans_feat = self.feat(x)
x = F.relu(self.bn1(self.fc1(x)))
x = F.relu(self.bn2(self.dropout(self.fc2(x))))
x = self.fc3(x)
return F.log_softmax(x, dim=1), trans, trans_feat
Figure에서 PointNetCls가 나타내는 부분은 Classification Network 전체로 볼 수 있다.
하지만 PointNetfeat이라는 커스텀 클래스가 호출되어 사용되기 때문에 구분이 필요하다고 보았다. forward 함수를 기준으로 보았을 때 코드 첫 줄인 feat 함수를 호출한 코드가 파란색 영역이며 이후 두 번째 줄부터는 빨간색 영역이라고 파악했다.
이 포스팅에서는 PointNetCls만 다룰 것이므로 feat 함수를 나타내는 PointNetfeat 클래스는 다음 포스팅에서 정리하겠다.
Dropout
BatchNorm1d
log_softmax
배치 정규화(Batch Normalization)
Softmax
F.relu(self.bn1(self.fc1(x)))
PointNet 코드 분석 - PointNetfeat (0) | 2024.08.24 |
---|---|
PointNet 코드 분석을 시작하기 전 (1) | 2024.08.24 |
[PyTorch Tutorial] LeNet-5 코드 분석 (0) | 2024.08.20 |