Machine Learning/Paper

[논문 구현] Auto-Encoding Variational Bayes (Varitional Auto Encoder)

devson 2024. 8. 5. 14:43

앞서 VAE에 대해 이론에 대해 살펴보았다. (논문 리뷰)

이제 이를 어떻게 PyTorch 코드로 구현할 수 있는지에 대해 알아보도록 하겠다.

 

논문에서는 흑백 이미지를 사용했지만 여기서는 컬러 이미지인 CelebA 데이터셋을 사용하도록 하겠다.

 

해당 코드는 여기에서 살펴볼 수 있다.

 


 

Base

import base libraries

import torch, torchvision
from torch import nn
import numpy as np
import torch.nn.functional as F

 

Configuration parameters

설정과 관련된 값들과 hyperparameter 등을 정한다.

  • 이미지 크기는 빠른 학습 결과를 보기위해 100x100으로 해두었다.
    • 원래 사이즈는 178×218 이다.
  • latent space의 차원은 500으로 두었다.
# Data
IMAGE_SIZE = (100, 100)

# ENV
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LATENT_DIM = 500

# TRAIN
BATCH_SIZE = 512
EPOCHS = 50
LR = 1e-3

 

 

Data

custom Dataset을 정의하고 DataLoader를 생성한다.

나의 경우 CelebA 데이터셋이 방대하기 때문에 디렉토리를 나눠서 저장해두었기 때문에 아래와 같은 코드를 작성하였는데 각자 데이터셋 사용 방법에 따라 바꾸면 될 것이다.

import torchvision.transforms as T
import os
from PIL import Image

transform = T.Compose([
    T.Resize(IMAGE_SIZE),
    T.ToTensor(),
])

class CelebADataset(torch.utils.data.Dataset):
    def __init__(self, paths, transform):
        self.file_paths = []
        
        for path in paths:
            for file_name in os.listdir(path):
                self.file_paths.append(f"{path}/{file_name}")

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        image = Image.open(self.file_paths[idx])
        return transform(image)

celeb_a_dataset = CelebADataset(paths=["../data/celeb_a/00",
                                       "../data/celeb_a/01"],
                                transform=transform)

celeb_a_dataloader = torch.utils.data.DataLoader(celeb_a_dataset,
                                                 shuffle=True,
                                                 batch_size=BATCH_SIZE)

 

데이터셋에 있는 이미지들을 몇 개 확인해보자.

import matplotlib.pyplot as plt

def tensor_to_pil_image(tensor_image):
    tensor_image = tensor_image * 255
    np_image = tensor_image.detach().numpy().transpose(1,2,0).astype(np.uint8)
    return Image.fromarray(np_image)

# plot
plt.figure(figsize=(10,10))
for i, image in enumerate(iter(celeb_a_dataset)):
    image = tensor_to_pil_image(image)
    
    plt.subplot(5,5,i+1)
    plt.imshow(image)
    plt.axis("off")
    if i == 24:
        break

 

다음과 같이 사람의 얼굴이 중심에 있는 이미지임을 확인할 수 있다.

 

 

Model

다음으로 모델을 코딩해보자.

나는 각 부분을 작게 나눠서 구현하였는데 AntixK/PyTorch-VAE 코드를 참고하였다.

 

Encoder

Encoder의 경우 이미지를 입력으로 받아 latent space의 평균과 분산을 출력하도록 한다.

 

구현에서는 분산이 아니라 log 분산(log_var)를 출력하도록 하였는데,

그 이유는 분산 값 자체는 매우 작기 때문에 log transform을 시킴으로써 이 값을 더 큰 수로 만들어주어 모델 훈련을 보다 원활하게 만들어주기 위함이다. (참고)

from torchsummary import summary

class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        self.conv_layer = nn.Sequential(
            self.__encoder_module(3, 32),
            self.__encoder_module(32, 64),
            self.__encoder_module(64, 128),
            self.__encoder_module(128, 256),
            self.__encoder_module(256, 512),
        )
        self.layer_flatten = nn.Flatten(start_dim=1) # batch 제외
        self.layer_mean = nn.Linear(512*4*4, latent_dim) # 3x150x150 -> 512x4x4
        self.layer_log_var = nn.Linear(512*4*4, latent_dim)
        
    def __encoder_module(self, in_channel, out_channel):
        return nn.Sequential(
            nn.Conv2d(in_channel, out_channel,
                      kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        assert tuple(x.shape[-2:]) == IMAGE_SIZE

        x = self.conv_layer(x)
        x = self.layer_flatten(x)
        mean = self.layer_mean(x)
        log_var = self.layer_log_var(x)
        return mean, log_var

encoder = Encoder(latent_dim=LATENT_DIM)
summary(encoder, input_size=(3,100,100), device="cpu")

 

Decoder

Decoder는 latent space에서 샘플링 된 $z$를 입력으로 받아 이미지를 출력한다.

class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        self.inputs = nn.Linear(latent_dim, 512*4*4)
        self.linear_to_conv = lambda x: x.view(-1, 512, 4, 4)
        self.conv_layer = nn.Sequential(
            self.__decoder_modules(512, 256, output_padding=0),
            self.__decoder_modules(256, 128, output_padding=0),
            self.__decoder_modules(128, 64, output_padding=0),
            self.__decoder_modules(64, 32),
        )
        self.output_layer = nn.ConvTranspose2d(32, 3,
                                               kernel_size=3, stride=2, 
                                               padding=1, output_padding=1)
    
    def __decoder_modules(self, in_channel, out_channel, output_padding=1):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channel, out_channel,
                               kernel_size=3, stride=2, 
                               padding=1, output_padding=output_padding),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(),
        )

    def forward(self, z):
        assert z.shape[-1] == self.latent_dim

        outputs = self.inputs(z)
        outputs = self.linear_to_conv(outputs)
        outputs = self.conv_layer(outputs)
        outputs = self.output_layer(outputs)
        return outputs

decoder = Decoder(latent_dim=LATENT_DIM)
summary(decoder, input_size=(LATENT_DIM,), device="cpu")

 

VAE

VAE 클래스는 Encoder와 Decoder를 이어주는 역할을 하도록 만들었다.

여기서 reparameterize 메서드를 통해 Reparameterization trick을 구현한 것을 확인할 수 있다.

class VAE(nn.Module):
    @staticmethod
    def create(latent_dim):
        encoder = Encoder(latent_dim)
        decoder = Decoder(latent_dim)
        return VAE(encoder, decoder)

    def __init__(self, encoder, decoder):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        mean, log_var = self.encoder(x)
        z = self.reparameterize(mean, log_var)
        x_prime = self.decoder(z)
        return x_prime, mean, log_var
    
    def reparameterize(self, mean, log_var):
        std = torch.exp(0.5 * log_var)
        epsilon = torch.randn_like(std)
        return epsilon*std + mean

    # utils
    def save(self, path):
        torch.save(self.state_dict(), path)

    @staticmethod
    def load(latent_dim, path):
        vae = VAE.create(latent_dim)
        vae.load_state_dict(torch.load(path))
        return vae

vae = VAE.create(latent_dim=LATENT_DIM)
summary(vae, input_size=(3,100,100), device="cpu")

 

Train

Loss function

loss function은 ELBO를 최대화하기 위한 수식이라고 보면된다. (참고)

loss로 사용하기 위해 -ELBO를 loss로 사용한다.

def loss_fn(x, x_prime, mean, log_var):
    reconstruction_loss = F.mse_loss(x_prime, x, reduction="sum")
    kld = -0.5 * torch.sum(1 + log_var - mean**2 - log_var.exp())
    return reconstruction_loss + kld

 

Model Training

Model과 Loss를 정해두었으니 이제 모델을 훈련시킨다.

from tqdm import tqdm

optimizer = torch.optim.Adam(vae.parameters(), lr=LR)

vae.to(DEVICE)
vae.train()

for epoch in tqdm(range(EPOCHS), desc="Epoch", position=0):
    for x in tqdm(celeb_a_dataloader, desc="Batch", position=0, leave=False):
        x = x.to(DEVICE)
        optimizer.zero_grad()
        
        x_prime, mean, log_var = vae(x)
        loss = loss_fn(x, x_prime, mean, log_var)
        loss.backward()
        optimizer.step()
        
    print(f"loss: {loss.item()}")

 

Inference

모델을 훈련시킨 뒤 성능을 확인해보도록 하겠다.

따로 metric을 사용하진 않고 정성적으로 성능이 어떻게 되는지 확인해보도록 하겠다.

 

Reconstruction

먼저 원본 이미지를 얼마나 그대로 출력하는지를 살펴보았다.

sample_dataloader = torch.utils.data.DataLoader(celeb_a_dataset,
                                                shuffle=True,
                                                batch_size=10)
sample_images = next(iter(sample_dataloader)).to(DEVICE)
x_prime, mean, log_var = vae(sample_images)

plt.figure(figsize=(10,10))
for i, (origin_image, reconstructed_image) in enumerate(zip(sample_images, x_prime)):
    reconstructed_image = reconstructed_image.to("cpu")
    reconstructed_image = tensor_to_pil_image(reconstructed_image)
    
    origin_image = origin_image.to("cpu")
    origin_image = tensor_to_pil_image(origin_image)

    plt.subplot(5,4,2*i+1)
    plt.imshow(origin_image)
    plt.title("origin")
    plt.axis("off")

    plt.subplot(5,4,2*i+2)
    plt.imshow(reconstructed_image)
    plt.title("reconstructed")
    plt.axis("off")

 

latent space 차원에 따른 비교를 해보기 위해 차원을 500과 1,000으로 두었을 때를 비교해보았다.

 

Generate sample images

그 다음으로 latent sapce 내에서 $z$를 랜덤으로 샘플링하여 어떤 결과가 나오는지를 확인해보았다.

sample_z = torch.randn((25, LATENT_DIM)).to(DEVICE)

plt.figure(figsize=(10,10))
for i, generated_image in enumerate(vae.decoder(sample_z)):
    generated_image = generated_image.to("cpu")
    generated_image = tensor_to_pil_image(generated_image)
    
    plt.subplot(5,5,i+1)
    plt.imshow(generated_image)
    plt.axis("off")

 

 

대체적으로 결과가 좋지 않게 나왔고, 특히 이미지가 전체적으로 blurry하게 나오게되는데

차원 수를 2배로 늘려도 눈에 띄는 개선이 보이지 않는걸 보면 단순한 vanilla VAE 구조로는 현실의 이미지와 비슷한 이미지를 생성하기에는 어느정도 무리가 있다고 본다.

 

관련해서는 아래 글을 참고해보면 좋을 것이다.