🤗 Transformers - CLIPModel fine-tuning
앞서 image classification을 하는 예시를 통해 🤗 Transformers의 CLIPModel의 사용법을 알아보았다. - 참고
이번에는 CLIPModel을 custom dataset에 맞춰 fine-tuning하는 예제에 대해 알아보도록 하겠다.
(전체 코드는 여기에서 살펴볼 수 있다)
개요
데이터
모델을 훈련시키기 위해서 image-text pair 데이터가 필요하다.
여기서는 CIFAR10 데이터셋을 사용할 것이고, image의 label을 통해 text를 생성하여 pair 데이터를 만들어 사용할 것이다.
이 포스팅에서는 CIFAR10 데이터셋을 사용하였지만 이미지 사이즈도 작기도 하고 class도 10개로 매우 적기 때문에 실제로는 fine-tuning을 적용하기엔 부적합하다.
예시를 보여주기 위함이니 어떤 식으로 fine-tuning이 진행되는지만 참고하도록 하자.
학습 과정
전체적인 학습 과정에 대한 overview는 아래와 같다.
- CLIPProcessor, CLIPModel을 통해 image-text pair의 embedding을 구한다.
- 동일한 image-text pair의 embedding을 끼리는 서로 가까워야하기 때문에 dot product 값이 커야한다.
- 동일한 image-text pair의 내적값이 최대화되도록 학습을 진행시킨다.
코드 구현
이제 본격적으로 fine-tuning 코드를 살펴보자.
Config
개발 환경과 hyperparameter에 대한 config 설정이다.
각자 상황에 맞게 바꾸면 된다.
# env
DEVICE = "cuda"
DATA_PATH = "data"
# train
BATCH_SIZE = 512
EPOCHS = 5
LR = 5e-5
Load pretrained model
pretrained CLIP 모델을 가져온다.
from transformers import CLIPProcessor, CLIPModel
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model.to(DEVICE)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", clean_up_tokenization_spaces=False)
Setup Dataset
데이터를 CLIPModel에서 사용할 수 있도록 가공하기 위해 custom Dataset을 구현하였다.
CIFAR10 데이터를 조회하여 image-text pair 데이터를 만들고 CLIPProcessor를 통해 전처리하는 Dataset이다.
import torch, torchvision
class CIFAR10Dataset(torch.utils.data.Dataset):
def __init__(self, clip_processor, is_train):
self.clip_processor = clip_processor
self.is_train = is_train
self.dataset = torchvision.datasets.CIFAR10(DATA_PATH, train=is_train)
self.class_texts = [
f"A photo of a {class_}."
for class_ in self.dataset.class_to_idx.keys()
]
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
image, label = self.dataset[idx]
text = self.class_texts[label]
return {
"image": image,
"label": label,
"text": text,
}
# batch로 데이터 전처리
def preprocess(self, batch):
images = [data["image"] for data in batch]
labels = [data["label"] for data in batch]
texts = [data["text"] for data in batch]
inputs = self.clip_processor(
text=texts,
images=images,
return_tensors="pt",
padding=True
)
return {
"text": texts,
"label": labels,
**inputs,
}
train_dataset = CIFAR10Dataset(processor, is_train=True)
test_dataset = CIFAR10Dataset(processor, is_train=False)
train_dataloader = torch.utils.data.DataLoader(train_dataset, collate_fn=train_dataset.preprocess, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, collate_fn=train_dataset.preprocess, batch_size=BATCH_SIZE)
CLIPProcessor의 출력값이 batch 데이터를 기준으로 나오기 때문에 Dataset.__getitem__ 에서 전처리를 하면 차원을 바꿔줘야 해서 불필요하게 코드가 장황해진다.
그렇기 때문에 Dataloader의 collate_fn에서 CLIPProcessor를 통해 batch 데이터를 전처리하도록 하여 코드를 간결하게 만들었다.
Fine-tune model
모델과 데이터가 준비되었으니 학습을 위해 loss를 정하면 학습을 진행할 수 있다.
loss는 논문에서와 동일하게 CrossEntropy를 사용하였다.
- logits는 image와 text를 내적한 결과이다.
- labels는 우리가 pair 데이터를 갖고 contrastive learning을 하기 때문에 arange로 만들어주면 된다.
여기서 loss_i와 loss_t의 경우 어차피 동일한 형태로 볼 수 있을텐데 왜 나눠서 구하는지 이해가 안될 수도 있다.
이는 각각 데이터 입장에서 loss를 구한 것이라고 볼 수 있는데,
loss_i의 경우 주어진 image에서 어떤 text가 가장 적합한지를 측정하기 위한 loss이고,
loss_t의 경우 주어진 text에 대해서 어떤 image가 가장 적합한지를 측정하기 위한 loss이다.
(관련해서는 여기를 참고하길 바란다)
이를 코드로 나타내면 아래와 같다.
logits_per_image와 logits_per_text는 서로 transpose 관계이다.
import torch.nn.functional as F
def loss_fn(logits_per_image, logits_per_text):
assert logits_per_image.shape[0] == logits_per_image.shape[0] # logits' shape should be (nxn)
assert logits_per_image.shape == logits_per_text.shape
labels = torch.arange(logits_per_image.shape[0], device=DEVICE)
loss_i = F.cross_entropy(logits_per_image, labels)
loss_t = F.cross_entropy(logits_per_image, labels)
loss = (loss_i + loss_t) / 2
return loss
이제 데이터와 loss가 모두 주어졌기 때문에 학습을 시작할 수 있다.
최종적으로 fine-tuning 코드는 아래와 같다.
from torch.optim import AdamW
from tqdm import tqdm
optimizer = AdamW(model.parameters(), lr=LR)
model.train()
for epoch in tqdm(range(1, EPOCHS+1), position=0, desc="epoch"):
for batch in tqdm(train_dataloader, position=0, desc="batch", leave=False):
optimizer.zero_grad()
outputs = model(
pixel_values=batch["pixel_values"].to(DEVICE),
input_ids=batch["input_ids"].to(DEVICE),
attention_mask=batch["attention_mask"].to(DEVICE),
)
logits_per_image = outputs.logits_per_image
logits_per_text = outputs.logits_per_text # logits_per_text == logits_per_image.T
loss = loss_fn(logits_per_image, logits_per_text)
loss.backward()
optimizer.step()
print(f"Train loss: {loss}")
Evaluate
fine-tuning을 마친 후에 test 데이터셋으로 성능을 평가해보자.
CIFAR10의 10개의 class에 대해 image classification을 진행하여 이를 통해 모델을 평가하였다.
import torch.nn.functional as F
all_class_texts = processor.tokenizer(test_dataset.class_texts)
all_class_texts = {k: torch.tensor(v, device=DEVICE) for k, v in all_class_texts.items()}
model.eval()
correct_count = 0
ce_loss_sum = 0
with torch.no_grad():
for batch in tqdm(test_dataloader):
outputs = model(
pixel_values=batch["pixel_values"].to(DEVICE),
**all_class_texts,
)
probs = outputs.logits_per_image.cpu().softmax(dim=1)
pred = probs.argmax(dim=1)
label = batch["label"]
correct_count += (pred == label).sum().item()
ce_loss_sum += F.cross_entropy(probs, label).item()
accuracy = correct_count / len(test_dataloader.dataset)
ce_loss = ce_loss_sum / len(test_dataloader)
print(f"Test CE loss: {ce_loss:.4}, Test accuracy: {accuracy:.4}")