본문 바로가기
Machine Learning/Deep Learning

🤗 Transformers - CLIPModel fine-tuning

by devson 2024. 8. 15.

앞서 image classification을 하는 예시를 통해 🤗 Transformers의 CLIPModel의 사용법을 알아보았다. - 참고

 

이번에는 CLIPModel을 custom dataset에 맞춰 fine-tuning하는 예제에 대해 알아보도록 하겠다.

(전체 코드는 여기에서 살펴볼 수 있다)


 

개요

데이터

모델을 훈련시키기 위해서 image-text pair 데이터가 필요하다.

여기서는 CIFAR10 데이터셋을 사용할 것이고, image의 label을 통해 text를 생성하여 pair 데이터를 만들어 사용할 것이다.

 

이 포스팅에서는 CIFAR10 데이터셋을 사용하였지만 이미지 사이즈도 작기도 하고 class도 10개로 매우 적기 때문에 실제로는 fine-tuning을 적용하기엔 부적합하다.
예시를 보여주기 위함이니 어떤 식으로 fine-tuning이 진행되는지만 참고하도록 하자.

 

학습 과정

전체적인 학습 과정에 대한 overview는 아래와 같다.

  • CLIPProcessor, CLIPModel을 통해 image-text pair의 embedding을 구한다.
  • 동일한 image-text pair의 embedding을 끼리는 서로 가까워야하기 때문에 dot product 값이 커야한다.
    • 동일한 image-text pair의 내적값이 최대화되도록 학습을 진행시킨다.

 

코드 구현

이제 본격적으로 fine-tuning 코드를 살펴보자.

 

Config

개발 환경과 hyperparameter에 대한 config 설정이다.

각자 상황에 맞게 바꾸면 된다.

# env
DEVICE = "cuda"
DATA_PATH = "data"

# train
BATCH_SIZE = 512
EPOCHS = 5
LR = 5e-5

 

Load pretrained model

pretrained CLIP 모델을 가져온다.

from transformers import CLIPProcessor, CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model.to(DEVICE)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", clean_up_tokenization_spaces=False)

 

Setup Dataset

데이터를 CLIPModel에서 사용할 수 있도록 가공하기 위해 custom Dataset을 구현하였다.

CIFAR10 데이터를 조회하여 image-text pair 데이터를 만들고 CLIPProcessor를 통해 전처리하는 Dataset이다.

import torch, torchvision

class CIFAR10Dataset(torch.utils.data.Dataset):
    def __init__(self, clip_processor, is_train):
        self.clip_processor = clip_processor
        self.is_train = is_train
        self.dataset = torchvision.datasets.CIFAR10(DATA_PATH, train=is_train)
        self.class_texts = [
            f"A photo of a {class_}."
            for class_ in self.dataset.class_to_idx.keys()
        ]
        
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        text = self.class_texts[label]
        return {
            "image": image,
            "label": label,
            "text": text,
        }
    
    # batch로 데이터 전처리
    def preprocess(self, batch):
        images = [data["image"] for data in batch]
        labels = [data["label"] for data in batch]
        texts = [data["text"] for data in batch]

        inputs = self.clip_processor(
            text=texts,
            images=images,
            return_tensors="pt", 
            padding=True
        )
        
        return {
            "text": texts,
            "label": labels,
            **inputs,
        }

train_dataset = CIFAR10Dataset(processor, is_train=True)
test_dataset = CIFAR10Dataset(processor, is_train=False)

train_dataloader = torch.utils.data.DataLoader(train_dataset, collate_fn=train_dataset.preprocess, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, collate_fn=train_dataset.preprocess, batch_size=BATCH_SIZE)

 

 

CLIPProcessor의 출력값이 batch 데이터를 기준으로 나오기 때문에 Dataset.__getitem__ 에서 전처리를 하면 차원을 바꿔줘야 해서 불필요하게 코드가 장황해진다.

그렇기 때문에 Dataloader의 collate_fn에서 CLIPProcessor를 통해 batch 데이터를 전처리하도록 하여 코드를 간결하게 만들었다.

 

Fine-tune model

모델과 데이터가 준비되었으니 학습을 위해 loss를 정하면 학습을 진행할 수 있다.

loss는 논문에서와 동일하게 CrossEntropy를 사용하였다.

https://arxiv.org/pdf/2103.00020

  • logits는 image와 text를 내적한 결과이다.
  • labels는 우리가 pair 데이터를 갖고 contrastive learning을 하기 때문에 arange로 만들어주면 된다.

 

여기서 loss_i와 loss_t의 경우 어차피 동일한 형태로 볼 수 있을텐데 왜 나눠서 구하는지 이해가 안될 수도 있다.

 

이는 각각 데이터 입장에서 loss를 구한 것이라고 볼 수 있는데,

loss_i의 경우 주어진 image에서 어떤 text가 가장 적합한지를 측정하기 위한 loss이고,

loss_t의 경우 주어진 text에 대해서 어떤 image가 가장 적합한지를 측정하기 위한 loss이다.

(관련해서는 여기를 참고하길 바란다)

 

이를 코드로 나타내면 아래와 같다.

logits_per_imagelogits_per_text는 서로 transpose 관계이다.

import torch.nn.functional as F

def loss_fn(logits_per_image, logits_per_text):
    assert logits_per_image.shape[0] == logits_per_image.shape[0] # logits' shape should be (nxn)
    assert logits_per_image.shape == logits_per_text.shape
    
    labels = torch.arange(logits_per_image.shape[0], device=DEVICE)
    loss_i = F.cross_entropy(logits_per_image, labels)
    loss_t = F.cross_entropy(logits_per_image, labels)
    loss = (loss_i + loss_t) / 2
    
    return loss

 

이제 데이터와 loss가 모두 주어졌기 때문에 학습을 시작할 수 있다.

최종적으로 fine-tuning 코드는 아래와 같다.

from torch.optim import AdamW
from tqdm import tqdm

optimizer = AdamW(model.parameters(), lr=LR)
model.train()

for epoch in tqdm(range(1, EPOCHS+1), position=0, desc="epoch"):
    for batch in tqdm(train_dataloader, position=0, desc="batch", leave=False):
        optimizer.zero_grad()

        outputs = model(
            pixel_values=batch["pixel_values"].to(DEVICE),
            input_ids=batch["input_ids"].to(DEVICE),
            attention_mask=batch["attention_mask"].to(DEVICE),
        )

        logits_per_image = outputs.logits_per_image
        logits_per_text = outputs.logits_per_text # logits_per_text == logits_per_image.T
        loss = loss_fn(logits_per_image, logits_per_text)
        loss.backward()
                
        optimizer.step()

    print(f"Train loss: {loss}")

 

Evaluate

fine-tuning을 마친 후에 test 데이터셋으로 성능을 평가해보자.

CIFAR10의 10개의 class에 대해 image classification을 진행하여 이를 통해 모델을 평가하였다.

import torch.nn.functional as F

all_class_texts = processor.tokenizer(test_dataset.class_texts)
all_class_texts = {k: torch.tensor(v, device=DEVICE) for k, v in all_class_texts.items()}

model.eval()
correct_count = 0
ce_loss_sum = 0

with torch.no_grad():
    for batch in tqdm(test_dataloader):
        outputs = model(
            pixel_values=batch["pixel_values"].to(DEVICE),
            **all_class_texts,
        )
        
        probs = outputs.logits_per_image.cpu().softmax(dim=1)
        pred = probs.argmax(dim=1)
        label = batch["label"]

        correct_count += (pred == label).sum().item()
        ce_loss_sum += F.cross_entropy(probs, label).item()
    
accuracy = correct_count / len(test_dataloader.dataset)
ce_loss = ce_loss_sum / len(test_dataloader)
print(f"Test CE loss: {ce_loss:.4}, Test accuracy: {accuracy:.4}")

 

댓글