관리 메뉴

Bull

[DL] MobileNet 요약 본문

Artificial Intelligence/Deep Learning

[DL] MobileNet 요약

Bull_ 2024. 8. 5. 09:49
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.utils import load_state_dict_from_url

# Depthwise Separable Convolution 블록 정의
class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(DepthwiseSeparableConv, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = F.relu(self.depthwise(x))
        x = F.relu(self.pointwise(x))
        return x

# MobileNetV1 모델 정의
class MobileNetV1(nn.Module):
    def __init__(self, num_classes=1):
        super(MobileNetV1, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = DepthwiseSeparableConv(32, 64, stride=1)
        self.conv3 = DepthwiseSeparableConv(64, 128, stride=2)
        self.conv4 = DepthwiseSeparableConv(128, 128, stride=1)
        self.conv5 = DepthwiseSeparableConv(128, 256, stride=2)
        self.conv6 = DepthwiseSeparableConv(256, 256, stride=1)
        self.conv7 = DepthwiseSeparableConv(256, 512, stride=2)

        # 5번 반복되는 블록
        self.conv8 = DepthwiseSeparableConv(512, 512, stride=1)
        self.conv9 = DepthwiseSeparableConv(512, 512, stride=1)
        self.conv10 = DepthwiseSeparableConv(512, 512, stride=1)
        self.conv11 = DepthwiseSeparableConv(512, 512, stride=1)
        self.conv12 = DepthwiseSeparableConv(512, 512, stride=1)

        self.conv13 = DepthwiseSeparableConv(512, 1024, stride=2)
        self.conv14 = DepthwiseSeparableConv(1024, 1024, stride=1)
        
        # 글로벌 평균 풀링
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # 이진 분류를 위한 마지막 레이어
        self.fc = nn.Linear(1024, num_classes)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        x = self.conv8(x)
        x = self.conv9(x)
        x = self.conv10(x)
        x = self.conv11(x)
        x = self.conv12(x)
        x = self.conv13(x)
        x = self.conv14(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # Flatten
        x = self.fc(x)
        x = self.sigmoid(x)
        return x

# 모델 인스턴스 생성
model = MobileNetV1()

# 모델 출력 확인
print(model)

# 모델 손실 함수와 옵티마이저 설정
criterion = nn.BCELoss()  # 이진 크로스엔트로피 손실 함수
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 예시 입력 데이터
inputs = torch.randn(8, 3, 224, 224)  # 배치 크기 8, 3채널(RGB), 224x224 크기 이미지
labels = torch.randint(0, 2, (8, 1)).float()  # 배치 크기 8, 이진 라벨 (0 또는 1)

# Forward pass
outputs = model(inputs)

# 손실 계산
loss = criterion(outputs, labels)

# Backward pass 및 옵티마이저 스텝
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"Loss: {loss.item()}")

MobileNet

MobileNet은 모바일 및 임베디드 장치에서 효율적으로 작동하도록 설계된 경량화된 신경망 아키텍처다. MobileNet의 핵심 개념은 Depthwise Separable Convolution와 Pointwise Convolution을 사용하여 연산량과 파라미터 수를 크게 줄이는 것이다.

1. Depthwise Separable Convolution

Depthwise Separable Convokution은 일반적인 CNN과 다르게 채널을 독립적으로 분리하여 컨볼루션을 진행한다.
여기서 채널은 첫 번째 레이어를 기준으로 했을 때 R, G, B 3개로 구분할 수 있다.

그래서 기본 CNN은 같은 경우 필터 수와 출력 채널 수가 같은 반면에 Depthwise는 필터수가 아닌 입력 채널 수와 출력 채널 수가 같아진다.

2. Pointwise Convolution

Pointwise Convolution은 1칸에 대해서만 컨볼루션을 진행한다. 만약 스트라이드가 1이라면 원본크기(w,h)를 그대로 유지할 수 있고 커널 사이즈는 $1 \times 1 \times C_{in}$ 라고 할 수 있다.

MobileNetV1 레이어 구조

간단하게 MobileNetV1 논문에 나와있는 레이어 구조를 살펴 보겠다.

표준 Convolution

  • 입력 이미지: $(224 \times 224 \times 3)$
  • 필터 크기: $(3 \times 3 \times 3)$
  • 스트라이드: 2
  • 출력 채널 수: 32
  • 출력 크기: $(112 \times 112 \times 32)$

MobileNetV1 레이어 구조 표

단계 레이어 종류 필터 크기 스트라이드 출력 채널 수 (필터 수) 출력 크기
1 표준 Convolution $(3 \times 3 \times 3)$ 2 32 $(112 \times 112 \times 32)$
2 Depthwise Convolution $(3 \times 3 \times 1)$ 1 32 $(112 \times 112 \times 32)$
3 Pointwise Convolution $(1 \times 1 \times 32)$ 1 64 $(112 \times 112 \times 64)$
4 Depthwise Convolution $(3 \times 3 \times 1)$ 2 64 $(56 \times 56 \times 64)$
5 Pointwise Convolution $(1 \times 1 \times 64)$ 1 128 $(56 \times 56 \times 128)$
6 Depthwise Convolution $(3 \times 3 \times 1)$ 1 128 $(56 \times 56 \times 128)$
7 Pointwise Convolution $(1 \times 1 \times 128)$ 1 128 $(56 \times 56 \times 128)$
8 Depthwise Convolution $(3 \times 3 \times 1)$ 2 128 $(28 \times 28 \times 128)$
9 Pointwise Convolution $(1 \times 1 \times 128)$ 1 256 $(28 \times 28 \times 256)$
10 Depthwise Convolution $(3 \times 3 \times 1)$ 1 256 $(28 \times 28 \times 256)$
11 Pointwise Convolution $(1 \times 1 \times 256)$ 1 256 $(28 \times 28 \times 256)$
12 Depthwise Convolution $(3 \times 3 \times 1)$ 2 256 $(14 \times 14 \times 256)$
13 Pointwise Convolution $(1 \times 1 \times 256)$ 1 512 $(14 \times 14 \times 512)$
14 Depthwise Convolution $(3 \times 3 \times 1)$ 1 512 $(14 \times 14 \times 512)$
15 Pointwise Convolution $(1 \times 1 \times 512)$ 1 512 $(14 \times 14 \times 512)$
16 Depthwise Convolution $(3 \times 3 \times 1)$ 1 512 $(14 \times 14 \times 512)$
17 Pointwise Convolution $(1 \times 1 \times 512)$ 1 512 $(14 \times 14 \times 512)$
18 Depthwise Convolution $(3 \times 3 \times 1)$ 1 512 $(14 \times 14 \times 512)$
19 Pointwise Convolution $(1 \times 1 \times 512)$ 1 512 $(14 \times 14 \times 512)$
20 Depthwise Convolution $(3 \times 3 \times 1)$ 1 512 $(14 \times 14 \times 512)$
21 Pointwise Convolution $(1 \times 1 \times 512)$ 1 512 $(14 \times 14 \times 512)$
22 Depthwise Convolution $(3 \times 3 \times 1)$ 2 512 $(7 \times 7 \times 512)$
23 Pointwise Convolution $(1 \times 1 \times 512)$ 1 1024 $(7 \times 7 \times 1024)$
24 글로벌 평균 풀링 - - 1024 $(1 \times 1 \times 1024)$
25 Dense 레이어 - - 1000 (클래스 수) 1000 (클래스 수)

어설프지만 쉽게 플면 이렇다.

컨볼루션 스트라이드 필터
표준 2 32
뎁스 1 -
포인트 1 필터증가 64
뎁스 2 -
포인트 1 필터증가 128
뎁스 1 -
포인트 1 -
뎁스 2 -
포인트 1 필터증가 256
뎁스 1 -
포인트 1 -
뎁스 2 -
포인트 1 필터증가 512
뎁스 1 -
포인트 1 -
뎁스 1 -
포인트 1 -
뎁스 1 -
포인트 1 -
뎁스 1 -
포인트 1 -
뎁스 2 -
포인트 1 필터증가 1024

 

Pointwise에서 필터를 증가시킨다. 필터 증가 이전에 포인트를 Depthwise s2를 적용한다. 초반에 필터를 2단계 증가시키는 부분은 준규칙성을 띄우면서 빠르게 증가시킨다. 이후 Depthwise s1 -> Pointwise s1 -> Depthwise s2 를 적용한 후 Pointwise s1 으로 필터 2배 증가 규칙성을 띈다.

512 필터 수 이후 5번 반복한다. 여기서 네트워크 깊이를 늘리고 더 많은 특징을 추출한다.

마지막으로 남은 $(7 \times 7 \times 1024)$를 글로벌 평균 풀링을 통해 $(1 \times 1 \times 1024)$로 만든다. 그리고 Fully Connected를 통해 1000개의 클래스로 출력한다.

 

CODE

 

 참고자료

https://ctkim.tistory.com/entry/%EB%AA%A8%EB%B0%94%EC%9D%BC-%EB%84%B7

 

모바일 넷 (MobileNet) 정리 및 구현

1. 모바일 넷 (MobileNet)이란? 모바일 넷은 스마트폰 및 기타 모바일 장치와 같이 리소스가 제한된 환경에서 효율적인 계산을 위해 설계된 경량 심층신경망으로 2017년 구글에서 개발했다. 모바일

ctkim.tistory.com