상세 컨텐츠

본문 제목

PointNet 코드 분석 - PointNetfeat

AI/PyTorch

by 쑤야. 2024. 8. 24. 15:25

본문

PointNetfeat이 논문의 Figure에서 지칭하는 부분을 미리 알아봤었다.

 

아래 그림의 빨간 영역이다. 

 

 

클래스 전체 코드는 아래와 같다. 

 

class PointNetfeat(nn.Module):
    def __init__(self, global_feat = True, feature_transform = False):
        super(PointNetfeat, self).__init__()
        self.stn = STN3d()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        n_pts = x.size()[2]
        trans = self.stn(x)
        x = x.transpose(2, 1)
        x = torch.bmm(x, trans)
        x = x.transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))

        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = x.transpose(2,1)
            x = torch.bmm(x, trans_feat)
            x = x.transpose(2,1)
        else:
            trans_feat = None

        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
        if self.global_feat:
            return x, trans, trans_feat
        else:
            x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
            return torch.cat([x, pointfeat], 1), trans, trans_feat

1. pytorch 기능 정리

 

torch.Tensor.transpose

  • 입력값으로 넣은 2 개의 차원을 서로 맞교환

torch.bmm

  • 배치 행렬 곱셈을 수행

torch.max

  • 텐서의 최댓값을 찾는 함수

torch.Tensor.view

  • 텐서의 크기 변경, 텐서의 데이터를 그대로 유지하면서 원하는 모양으로 재구성

torch.Tensor.repeat

  • 기존의 텐서를 주어진 횟수만큼 복제하여 새로운 크기의 텐서를 생성
# 첫 번재 차원에서 2번, 두 번째 차원에서 3번 반복
x.repeat(2, 3)

torch.cat

  • 텐서들을 특정 차원에 따라 결합하여 새로운 텐서를 생성하는 데 사용


2. 단계별로 나눠서 분석하기

1. 입력 데이터의 크기 및 변수 초기화

  • x는 3D 포인트 클라우드 데이러를 나타내는데, n_pts 변수에 포인트 클라우드의 점 개수를 저장
n_pts = x.size()[2]

 

2. 입력 데이터에 대한 공간 변환 적용

  • 위 그림을 살펴보면 입력값 n x 3과 T-Net을 통해 반환된 3 x 3 크기의 데이터가 행렬 곱셈 연산이 수행되어 n x 3 크기의 데이터가 반환되어야 한다
  • 이때 행렬 곱셈 연산을 위해서 n x 3 데이터는 3 x n으로 전치시킨 후, 곱셈 연산을 수행한다
trans = self.stn(x) # 3 x 3
x = x.transpose(2, 1) # n x 3 → 3 x n
x = torch.bmm(x, trans) # 행렬 곱셈 연산 후, 3 x n 
x = x.transpose(2, 1) # 3 x n → n x 3

 

3. 첫 번째 합성곱 연산과 배치 정규화, 활성화 함수 적용

x = F.relu(self.bn1(self.conv1(x)))

 

4. 특징 변환 조건부 적용

if self.feature_transform:
    trans_feat = self.fstn(x)
    x = x.transpose(2, 1)
    x = torch.bmm(x, trans_feat)
    x = x.transpose(2, 1)
else:
    trans_feat = None

 

5. 포인트 특징 저장

pointfeat = x

 

6. 두 번째 합성곱 연산과 활성화 함수, 배치 정규화 적용

x = F.relu(self.bn2(self.conv2(x)))
x = self.bn3(self.conv3(x))

 

7. 특징 맥스 풀링 및 변환

x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 1024)
  • 두 번째 차원(포인트 수)에 대해 max pooling을 수행하여 각 특징 채널의 최댓값을 선택
    → 각 점에 대해 가장 중요한 특징만을 남기는 역할
  • view 함수를 통해 2D 텐서로 변환하여 (B, 1024) 크기의 텐서로 만든다

 

8. 글로벌 및 지역 특징 결합 여부에 따른 출력 결정

if self.global_feat:
    return x, trans, trans_feat
else:
    x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
    return torch.cat([x, pointfeat], 1), trans, trans_feat
  • 전역 특징이 아닌 경우, 확장된 x와 지역 포인트 특징 pointfeat을 연결하여 최종 특징을 생성
  • 반환된 값은 글로벌 및 지역 특징을 모두 포함하게 된다

1편

https://developer-ssooya.tistory.com/entry/PointNet-코드-분석을-시작하기-전

 

PointNet 코드 분석을 시작하기 전

파이토치를 공부할 겸 포인트 클라우드 딥러닝 모델 중 유명한 PointNet 코드를 분석해 보도록 하겠다.  공부용으로 선택한 코드는 papers with code라는 사이트에서 PointNet을 검색하여 파이토치로 구

developer-ssooya.tistory.com

 

2편

https://developer-ssooya.tistory.com/entry/PointNet-코드-분석-PointNetCls

 

PointNet 코드 분석 - PointNetCls

먼저 PointNetCls 클래스의 전체 코드를 살펴보며 PointNetCls가 Figure 상에서 어떤 부분을 담당하는지 파악해 보았다.  클래스 전체 코드는 아래와 같다. class PointNetCls(nn.Module): def __init__(self, k=2, featu

developer-ssooya.tistory.com

 

'AI > PyTorch' 카테고리의 다른 글

PointNet 코드 분석 - PointNetCls  (0) 2024.08.24
PointNet 코드 분석을 시작하기 전  (1) 2024.08.24
[PyTorch Tutorial] LeNet-5 코드 분석  (0) 2024.08.20

관련글 더보기