관리 메뉴

Bull

[DL::GAN::Metric] FID(Frechet Inception Distance) 본문

Artificial Intelligence/Deep Learning

[DL::GAN::Metric] FID(Frechet Inception Distance)

Bull_ 2024. 11. 25. 11:11
  • FID(Frechet Inception Distance)은 Generative Adversarial Networks(GAN) 같은 생성 모델에서 생성된 이미지의 품질과 다양성을 평가하는 지표.
  • 생성된 이미지와 실제 데이터 간의 분포 차이를 측정하여 얼마나 사실적인지를 평가

1. FID의 핵심 개념

  1. FID는 두 개의 이미지 데이터 분포(실제 이미지와 생성된 이미지) 간의 Frechet Distance를 계산
  2. Inception 네트워크(사전 훈련된 Inception v3 모델)를 사용하여 이미지의 특징 벡터를 추출한 후, 특징 벡터의 평균과 공분산을 비교
  • 평균(mean): 각 특징 차원의 평균 값
  • 공분산(covariance): 각 특징 차원의 상관 관계를 나타내는 행렬

2. 계산 과정

$$FID = ||\mu_r - \mu_g||^2 + Tr(\Sigma_r + \Sigma_g - 2 \cdot (\Sigma_r \cdot \Sigma_g)^{1/2})$$

  • $\Sigma_r$: 실제 데이터 분포의 평균과 공분산
  • $\Sigma_g$: 생성된 데이터 분포의 평균과 공분산
  • $Tr$: 행렬의 대각합(trace)을 계산하는 연산
  1. 특징 추출:
    • Inception v3 모델의 중간 레이어(일반적으로 Pooling 레이어)에서 실제 이미지와 생성 이미지의 특징 벡터를 추출
  2. 통계 계산:
    • 실제 이미지와 생성된 이미지 각각의 평균 $\mu$와 공분산 행렬 $\Sigma$를 계산
  3. Frechet Distance 계산:
    • 위의 수식을 적용하여 분포 간의 거리를 계산

3. 특징

-낮은 FID 점수

  • 생성된 이미지 분포가 실제 이미지 분포와 가까움을 의미
    • 두 분포 간 차이가 크다는 것을 나타내며, 생성된 이미지의 품질이 낮거나 다양성이 부족함
  • 높은 FID 점수

4. 장점

이미지 품질 평가

  • FID는 단순히 픽셀 기반이 아닌, 이미지의 고수준 특성을 평가하여 더 신뢰성 있는 평가
  • 다양성 고려*
    • 단일 이미지가 아닌 전체 데이터 분포를 기반으로 평가하기 때문에 다양성을 반영할 수 있다

5. 단점

Inception 모델 의존성

  • Inception v3 모델을 기반으로 하므로, 특정 도메인에 적합하지 않을 수 있다.

훈련된 데이터에 민감

  • 실제 데이터 분포의 품질이 낮다면 FID 점수의 신뢰성이 떨어질 수 있다.
    • 입력 이미지의 크기나 전처리에 따라 점수가 변동될 수 있다.
  • 스케일 민감성

FID 계산 코드

import torch
from torch.nn.functional import adaptive_avg_pool2d
from torchvision.models import inception_v3
from scipy.linalg import sqrtm
import numpy as np
def calculate_fid(real_features, fake_features):
    """
    FID를 계산하는 함수
    :param real_features: 실제 데이터 특징 벡터 (NumPy 배열)
    :param fake_features: 생성 데이터 특징 벡터 (NumPy 배열)
    :return: FID 값
    """
    mu_r, sigma_r = real_features.mean(axis=0), np.cov(real_features, rowvar=False)
    mu_f, sigma_f = fake_features.mean(axis=0), np.cov(fake_features, rowvar=False)

    # Frechet Distance 계산
    mean_diff = mu_r - mu_f
    cov_sqrt = sqrtm(sigma_r @ sigma_f)

    # 복소수 처리
    if np.iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real

    fid = np.sum(mean_diff**2) + np.trace(sigma_r + sigma_f - 2 * cov_sqrt)
    return fid
def extract_features(model, images):
    """
    이미지 데이터를 Inception v3 모델에서 특징 추출
    :param model: Inception v3 모델
    :param images: 입력 이미지 (Tensor)
    :return: 특징 벡터 (NumPy 배열)
    """
    with torch.no_grad():
        model.eval()
        # Inception v3는 299x299 크기에서 동작
        images = adaptive_avg_pool2d(images, (299, 299))
        features = model(images)
    return features.cpu().numpy()
# Inception v3 모델 로드
inception = inception_v3(pretrained=True).eval()

# 샘플 데이터 (이미지 텐서를 준비해야 함)
real_images = torch.randn(32, 3, 299, 299)  # 실제 데이터 (32장 샘플)
fake_images = torch.randn(32, 3, 299, 299)  # 생성 데이터 (32장 샘플)

# 특징 추출
real_features = extract_features(inception, real_images)
fake_features = extract_features(inception, fake_images)

# FID 계산
fid_score = calculate_fid(real_features, fake_features)
print(f"FID Score: {fid_score}")

출력

FID Score: 20.881609066726284

이론

  • 10 이하: 매우 뛰어난 성능 (실제 데이터와 거의 동일한 수준).
  • 10~30: 연구/논문에서 자주 인용되는 성능. 대부분의 GAN 모델이 목표로 삼는 범위.
  • 30 이상: 품질이 부족하거나 분포 차이가 큼.

실제

  • 초기 GAN(예: DCGAN)의 경우 FID 값이 50~100 정도로 높았다.
  • StyleGAN 또는 BigGAN 같은 최신 GAN 모델은 FID 값을 10~20 이하로 낮추는 데 성공했다.