앞서 image classification을 하는 예시를 통해 🤗 Transformers의 CLIPModel의 사용법을 알아보았다. - 참고
이번에는 CLIPModel을 custom dataset에 맞춰 fine-tuning하는 예제에 대해 알아보도록 하겠다.
(전체 코드는 여기에서 살펴볼 수 있다)
개요
데이터
모델을 훈련시키기 위해서 image-text pair 데이터가 필요하다.
여기서는 CIFAR10 데이터셋을 사용할 것이고, image의 label을 통해 text를 생성하여 pair 데이터를 만들어 사용할 것이다.

이 포스팅에서는 CIFAR10 데이터셋을 사용하였지만 이미지 사이즈도 작기도 하고 class도 10개로 매우 적기 때문에 실제로는 fine-tuning을 적용하기엔 부적합하다.
예시를 보여주기 위함이니 어떤 식으로 fine-tuning이 진행되는지만 참고하도록 하자.
학습 과정
전체적인 학습 과정에 대한 overview는 아래와 같다.
- CLIPProcessor, CLIPModel을 통해 image-text pair의 embedding을 구한다.
- 동일한 image-text pair의 embedding을 끼리는 서로 가까워야하기 때문에 dot product 값이 커야한다.
- 동일한 image-text pair의 내적값이 최대화되도록 학습을 진행시킨다.

코드 구현
이제 본격적으로 fine-tuning 코드를 살펴보자.
Config
개발 환경과 hyperparameter에 대한 config 설정이다.
각자 상황에 맞게 바꾸면 된다.
# env DEVICE = "cuda" DATA_PATH = "data" # train BATCH_SIZE = 512 EPOCHS = 5 LR = 5e-5
Load pretrained model
pretrained CLIP 모델을 가져온다.
from transformers import CLIPProcessor, CLIPModel model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") model.to(DEVICE) processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", clean_up_tokenization_spaces=False)
Setup Dataset
데이터를 CLIPModel에서 사용할 수 있도록 가공하기 위해 custom Dataset을 구현하였다.
CIFAR10 데이터를 조회하여 image-text pair 데이터를 만들고 CLIPProcessor를 통해 전처리하는 Dataset이다.
import torch, torchvision class CIFAR10Dataset(torch.utils.data.Dataset): def __init__(self, clip_processor, is_train): self.clip_processor = clip_processor self.is_train = is_train self.dataset = torchvision.datasets.CIFAR10(DATA_PATH, train=is_train) self.class_texts = [ f"A photo of a {class_}." for class_ in self.dataset.class_to_idx.keys() ] def __len__(self): return len(self.dataset) def __getitem__(self, idx): image, label = self.dataset[idx] text = self.class_texts[label] return { "image": image, "label": label, "text": text, } # batch로 데이터 전처리 def preprocess(self, batch): images = [data["image"] for data in batch] labels = [data["label"] for data in batch] texts = [data["text"] for data in batch] inputs = self.clip_processor( text=texts, images=images, return_tensors="pt", padding=True ) return { "text": texts, "label": labels, **inputs, } train_dataset = CIFAR10Dataset(processor, is_train=True) test_dataset = CIFAR10Dataset(processor, is_train=False) train_dataloader = torch.utils.data.DataLoader(train_dataset, collate_fn=train_dataset.preprocess, batch_size=BATCH_SIZE, shuffle=True) test_dataloader = torch.utils.data.DataLoader(test_dataset, collate_fn=train_dataset.preprocess, batch_size=BATCH_SIZE)
CLIPProcessor의 출력값이 batch 데이터를 기준으로 나오기 때문에 Dataset.__getitem__ 에서 전처리를 하면 차원을 바꿔줘야 해서 불필요하게 코드가 장황해진다.
그렇기 때문에 Dataloader의 collate_fn에서 CLIPProcessor를 통해 batch 데이터를 전처리하도록 하여 코드를 간결하게 만들었다.
Fine-tune model
모델과 데이터가 준비되었으니 학습을 위해 loss를 정하면 학습을 진행할 수 있다.
loss는 논문에서와 동일하게 CrossEntropy를 사용하였다.

- logits는 image와 text를 내적한 결과이다.
- labels는 우리가 pair 데이터를 갖고 contrastive learning을 하기 때문에 arange로 만들어주면 된다.

여기서 loss_i와 loss_t의 경우 어차피 동일한 형태로 볼 수 있을텐데 왜 나눠서 구하는지 이해가 안될 수도 있다.

이는 각각 데이터 입장에서 loss를 구한 것이라고 볼 수 있는데,
loss_i의 경우 주어진 image에서 어떤 text가 가장 적합한지를 측정하기 위한 loss이고,
loss_t의 경우 주어진 text에 대해서 어떤 image가 가장 적합한지를 측정하기 위한 loss이다.
(관련해서는 여기를 참고하길 바란다)

이를 코드로 나타내면 아래와 같다.
logits_per_image와 logits_per_text는 서로 transpose 관계이다.
import torch.nn.functional as F def loss_fn(logits_per_image, logits_per_text): assert logits_per_image.shape[0] == logits_per_image.shape[0] # logits' shape should be (nxn) assert logits_per_image.shape == logits_per_text.shape labels = torch.arange(logits_per_image.shape[0], device=DEVICE) loss_i = F.cross_entropy(logits_per_image, labels) loss_t = F.cross_entropy(logits_per_text, labels) loss = (loss_i + loss_t) / 2 return loss
이제 데이터와 loss가 모두 주어졌기 때문에 학습을 시작할 수 있다.
최종적으로 fine-tuning 코드는 아래와 같다.
from torch.optim import AdamW from tqdm import tqdm optimizer = AdamW(model.parameters(), lr=LR) model.train() for epoch in tqdm(range(1, EPOCHS+1), position=0, desc="epoch"): for batch in tqdm(train_dataloader, position=0, desc="batch", leave=False): optimizer.zero_grad() outputs = model( pixel_values=batch["pixel_values"].to(DEVICE), input_ids=batch["input_ids"].to(DEVICE), attention_mask=batch["attention_mask"].to(DEVICE), ) logits_per_image = outputs.logits_per_image logits_per_text = outputs.logits_per_text # logits_per_text == logits_per_image.T loss = loss_fn(logits_per_image, logits_per_text) loss.backward() optimizer.step() print(f"Train loss: {loss}")
Evaluate
fine-tuning을 마친 후에 test 데이터셋으로 성능을 평가해보자.
CIFAR10의 10개의 class에 대해 image classification을 진행하여 이를 통해 모델을 평가하였다.
import torch.nn.functional as F all_class_texts = processor.tokenizer(test_dataset.class_texts) all_class_texts = {k: torch.tensor(v, device=DEVICE) for k, v in all_class_texts.items()} model.eval() correct_count = 0 ce_loss_sum = 0 with torch.no_grad(): for batch in tqdm(test_dataloader): outputs = model( pixel_values=batch["pixel_values"].to(DEVICE), **all_class_texts, ) probs = outputs.logits_per_image.cpu().softmax(dim=1) pred = probs.argmax(dim=1) label = batch["label"] correct_count += (pred == label).sum().item() ce_loss_sum += F.cross_entropy(probs, label).item() accuracy = correct_count / len(test_dataloader.dataset) ce_loss = ce_loss_sum / len(test_dataloader) print(f"Test CE loss: {ce_loss:.4}, Test accuracy: {accuracy:.4}")
'Machine Learning > Deep Learning' 카테고리의 다른 글
🤗 Transformers - CLIPModel을 사용한 Image Classification (0) | 2024.08.14 |
---|---|
Various Normalizations on CNN (0) | 2024.08.07 |
CNN Filter Visualization (0) | 2024.06.16 |
CNN에서 layer가 깊어질 수록 channel size를 키우는 이유 (0) | 2024.05.31 |
Dropout vs Inverted Dropout (0) | 2024.05.21 |
댓글