Loading [MathJax]/jax/output/CommonHTML/jax.js
본문 바로가기
Machine Learning/Deep Learning

🤗 Transformers - CLIPModel을 사용한 Image Classification

by devson 2024. 8. 14.

 

CLIP(Contrastive Language-Image Pretraining) 모델은 multi-modal(text, image) embedding 모델로

Contrastive Learning를 사용하여 text와 image를 같은 공간에 임베딩할 수 있게하는 모델이다. 

 

CLIP 모델의 흥미로운 점은 Contrastive Learning를 통해 학습한 모델이,

image classification task에 있어서 zero-shot 성능이 탁월하다는 점이다.

https://arxiv.org/pdf/2103.00020

 

이번 포스팅에서는 image classification 예제를 통해 🤗 Transformers CLIPModel의 사용법을 익혀보도록 하겠다.

 

코드는 여기에서 확인할 수 있다.


 

개요

먼저 어떻게 image classification을 진행할지를 살펴보자.

4개의 class {cat, dog, horse, bear} 에 대해 분류를 하는 task로 각 class에 대한 text와 이미지의 embedding 값을 사용하여 이미지 분류를 진행한다.

각 component 위에 있는 초록색 글씨는 이를 처리하는 class를 의미한다.

위 이미지 예시에서 '고양이'와 '말' 이미지를 Image Encoder의 입력으로 넣기 때문에,

I1에 대해서는 T1I1의 값이 가장 높게 나와야하고, I2에 대해서는 T3I2의 값이 가장 높게 나와야한다.

 

이제 코드로 Image classification을 진행해보자.

 

Load pretrained model

pretrained CLIPModelCLIPProcessor를 불러온다.

from transformers import CLIPModel, CLIPProcessor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

 

  • CLIPModel
    • 앞서 이미지에서처럼 Text EncoderImage Encoder를 갖는 embedding 모델이다.
    • 이에 대응하는 text_model, vision_model 필드를 갖고 있다.
  • CLIPProcessor
    • CLIPImageProcessorCLIPTokenizer를 wrapping한 class로
    • CLIPImageProcessor는 이미지 처리(e.g. resize, normalization)를 담당한다.
    • CLIPTokenizer은 text tokenizing을 담당한다.
    • 이에 대응하는 image_processor, tokenizer 필드를 갖고있다.

 

Load Image

다음으로 분류할 이미지를 가져온다.

from PIL import Image
import requests
def get_image(url):
return Image.open(requests.get(url, stream=True).raw)
cat_image = get_image("http://images.cocodataset.org/val2017/000000039769.jpg")
horse_image = get_image("https://farm6.staticflickr.com/5465/8929343165_e34cf36bce_z.jpg")

 

예제 코드의 이미지는 COCO 데이터셋에서 가져온 고양이, 말 이미지이다.

 

Get model inputs

다음으로 CLIP 모델의 입력을 만들기 위해 CLIPProcessor를 사용한다.

4개의 class에 대한 text와 앞서 가져온 이미지를 CLIPProcessor의 입력으로써 사용한다.

text = ["a photo of a cat", "a photo of a dog", "a photo of a horse", "a photo of a bear"]
inputs = processor(text=text,
images=[cat_image, horse_image],
return_tensors="pt",
padding=True)
print("inputs.keys()")
print(inputs.keys())
print('\ninputs["input_ids"]')
print(inputs["input_ids"])
print('\ninputs["attention_mask"]')
print(inputs["attention_mask"])
print('\ninputs["pixel_values"].shape')
print(inputs["pixel_values"].shape) # 원본 이미지 크기는 (640, 480)

 

CLIPProcessor의 출력값은 CLIPModel의 입력값이 된다.

이 값은 dict로 다음과 같은 값이 들어있다.

  • input_ids: tokenized text
  • attention_mask
  • pixel_values: resized & normalized image

 

Inference

모델의 입력값을 사용하여 classification을 진행한다.

출력된 값의 logit을 통해 해당 이미지의 class를 예측할 수 있다.

import torch
with torch.no_grad():
outputs = model(**inputs)
print("\nimage-text similarity score:")
logits_per_image = outputs.logits_per_image # image-text similarity score
print(logits_per_image.numpy())
probs = logits_per_image.softmax(dim=1) # take the softmax to get the label probabilities
print("\nlabel probability(softmax):")
print(probs.numpy())
print("\npred labels:")
print(probs.argmax(dim=1).numpy())

 

2개의 이미지에 대해 4개의 class를 분류하는 task이기 때문에 (2x4) 차원을 갖는 것을 확인할 수 있다.

또한 높은 정확도로 class를 예측하는 것을 확인할 수 있다.

 


 

간단한 image classification 예제를 통해 🤗 Transformers의 CLIPModel을 사용하는 방법에 대해 알아보았다.