PointNetfeat이 논문의 Figure에서 지칭하는 부분을 미리 알아봤었다.
아래 그림의 빨간 영역이다.
클래스 전체 코드는 아래와 같다.
class PointNetfeat(nn.Module):
def __init__(self, global_feat = True, feature_transform = False):
super(PointNetfeat, self).__init__()
self.stn = STN3d()
self.conv1 = torch.nn.Conv1d(3, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(128)
self.bn3 = nn.BatchNorm1d(1024)
self.global_feat = global_feat
self.feature_transform = feature_transform
if self.feature_transform:
self.fstn = STNkd(k=64)
def forward(self, x):
n_pts = x.size()[2]
trans = self.stn(x)
x = x.transpose(2, 1)
x = torch.bmm(x, trans)
x = x.transpose(2, 1)
x = F.relu(self.bn1(self.conv1(x)))
if self.feature_transform:
trans_feat = self.fstn(x)
x = x.transpose(2,1)
x = torch.bmm(x, trans_feat)
x = x.transpose(2,1)
else:
trans_feat = None
pointfeat = x
x = F.relu(self.bn2(self.conv2(x)))
x = self.bn3(self.conv3(x))
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 1024)
if self.global_feat:
return x, trans, trans_feat
else:
x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
return torch.cat([x, pointfeat], 1), trans, trans_feat
# 첫 번재 차원에서 2번, 두 번째 차원에서 3번 반복
x.repeat(2, 3)
1. 입력 데이터의 크기 및 변수 초기화
n_pts = x.size()[2]
2. 입력 데이터에 대한 공간 변환 적용
trans = self.stn(x) # 3 x 3
x = x.transpose(2, 1) # n x 3 → 3 x n
x = torch.bmm(x, trans) # 행렬 곱셈 연산 후, 3 x n
x = x.transpose(2, 1) # 3 x n → n x 3
3. 첫 번째 합성곱 연산과 배치 정규화, 활성화 함수 적용
x = F.relu(self.bn1(self.conv1(x)))
4. 특징 변환 조건부 적용
if self.feature_transform:
trans_feat = self.fstn(x)
x = x.transpose(2, 1)
x = torch.bmm(x, trans_feat)
x = x.transpose(2, 1)
else:
trans_feat = None
5. 포인트 특징 저장
pointfeat = x
6. 두 번째 합성곱 연산과 활성화 함수, 배치 정규화 적용
x = F.relu(self.bn2(self.conv2(x)))
x = self.bn3(self.conv3(x))
7. 특징 맥스 풀링 및 변환
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 1024)
8. 글로벌 및 지역 특징 결합 여부에 따른 출력 결정
if self.global_feat:
return x, trans, trans_feat
else:
x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
return torch.cat([x, pointfeat], 1), trans, trans_feat
1편
https://developer-ssooya.tistory.com/entry/PointNet-코드-분석을-시작하기-전
PointNet 코드 분석을 시작하기 전
파이토치를 공부할 겸 포인트 클라우드 딥러닝 모델 중 유명한 PointNet 코드를 분석해 보도록 하겠다. 공부용으로 선택한 코드는 papers with code라는 사이트에서 PointNet을 검색하여 파이토치로 구
developer-ssooya.tistory.com
2편
https://developer-ssooya.tistory.com/entry/PointNet-코드-분석-PointNetCls
PointNet 코드 분석 - PointNetCls
먼저 PointNetCls 클래스의 전체 코드를 살펴보며 PointNetCls가 Figure 상에서 어떤 부분을 담당하는지 파악해 보았다. 클래스 전체 코드는 아래와 같다. class PointNetCls(nn.Module): def __init__(self, k=2, featu
developer-ssooya.tistory.com
PointNet 코드 분석 - PointNetCls (0) | 2024.08.24 |
---|---|
PointNet 코드 분석을 시작하기 전 (1) | 2024.08.24 |
[PyTorch Tutorial] LeNet-5 코드 분석 (0) | 2024.08.20 |