본문 바로가기
Machine Learning/Deep Learning

CNN Filter Visualization

by devson 2024. 6. 16.

DNN(Deep Neural Network)은 성능면에서는 이미 검증이 된 머신러닝 기법이다.

하지만 DNN이 어떠한 결론을 내렸을 때 어떠한 과정을 통해 그러한 결론을 내렸는지를 이해하는 것도 중요하다.

 

이번 포스팅에서는 CNN이 이미지에서 어떻게 물체를 인식하는지를 확인하기 위해 CNN 필터가 반응하는 패턴을 시각화하는 방법에 대해 알아보도록 하겠다.

 

(예제 코드는 여기에서 확인할 수 있다)

 


Concept

먼저 필터가 어떤 패턴에 반응하는지를 확인하려면 어떻게 해야할까?

필터의 출력인 활성화 값을 최대화시키는 이미지 데이터를 찾으면 된다.

그 이미지는 필터의 패턴에 가장 많이 반응하는 것이기 때문에 해당 필터의 패턴을 나타내는 이미지라고 볼 수 있을 것이다.

 

그러면 활성화 값을 최대화시키는 이미지 데이터를 어떻게 찾을 수 있을까?

Loss를 최소화하는 weight를 gradient descent를 통해 찾는 것과 동일하게,

활성화 값을 최대화하는 이미지를 gradient ascent를 통해 찾을 수 있을 것이다.



컨셉 자체는 Neural Network가 학습하는 것과 별반 다르지 않다.

이제 코드를 통해서 CNN 필터를 시각화하는 방법에 대해 알아보자.

 

코드

library setup

먼저 필요한 library import를 하고 행렬 shape 변환을 위해 einops를 설치한다.

import tensorflow as tf
from tensorflow import keras
import numpy as np

!pip install einops -q

 

Feature extractor

CNN 모델의 필터는 입력에 가까울수록 '가로선', '세로선' 같은 저차원 패턴을 학습하고,

출력에 가까울수록 좀 더 특정 물체의 특징에 대한 추상적인 패턴을 학습한다.

 

이러한 CNN 모델의 특성을 확인하기 위해 입력에 가까운 layer와 출력에 가까운 layer를 사용해서 feature extractor들을 만들도록 하겠다.

(여기서는 pretrained Xception Net을 사용하도록 하겠다)

from tensorflow.keras.applications import xception

model = xception.Xception(weights="imagenet",
                          include_top=False) # CNN의 필터만 필요
feature_extractors = []

for i in (2, 4, 12, 14):
    layer = model.get_layer(f"block{i}_sepconv1")
    feature_extractor = keras.Model(inputs=model.input, outputs=layer.output)
    feature_extractors.append(feature_extractor)

 

Image 생성 함수

이미지 데이터는 필터의 패턴에 반응하는 패턴을 찾기 위함이기 때문에 단순하게 random 노이즈 이미지를 생성하도록 한다.

from einops import rearrange

def generate_single_image_batch(height, width):
    image = tf.random.uniform((height, width, 3))
    return rearrange(image, "h w c -> 1 h w c") # add batch dimension

 

Loss 및 Gradient Ascent 정의

다음으로 필터의 출력인 활성화 값을 loss로 계산하는 compute_loss 함수를 정의한다.

def compute_loss(image_batch, feature_extractor, filter_index):
    activation = feature_extractor(image_batch)
    filter_activation = activation[:, 2:-2, 2:-2, filter_index] # 2:-2 -> remove border effect
    return tf.reduce_mean(filter_activation) # filter의 activation 평균값

 

그리고 gradient ascent를 통해 필터의 활성화 값인 loss를 증가시키는 이미지를 찾는 함수를 정의한다.

7번째 라인의 코드에서 볼 수 있듯이 이미지 데이터에 대한 loss gradient를 구하도록 한다.

@tf.function
def gradient_ascent_step(image_batch, feature_extractor, filter_index, learning_rate):
    with tf.GradientTape() as tape:
        tape.watch(image_batch)
        loss = compute_loss(image_batch, feature_extractor, filter_index)

    grads = tape.gradient(loss, image_batch)
    grads = tf.math.l2_normalize(grads)
    image_batch += learning_rate * grads # gradient ascent
    return image_batch

def generate_filter_pattern_image(image_height, image_width,
                      feature_extractor, filter_index, learning_rate):
    image_batch = generate_single_image_batch(height=image_height, width=image_width)
    for _ in range(10):
        image_batch = gradient_ascent_step(image_batch, feature_extractor, filter_index, learning_rate)
    return image_batch[0].numpy()

 

필터 패턴 이미지 생성 및 시각화

필터의 패턴 이미지를 생성하기 위한 준비는 다 되었다.

이제 필터 패턴 이미지를 생성한 뒤 시각화하는 코드를 짜면 된다.

 

먼저 위 generate_filter_pattern_image를 통해 구한 이미지 데이터를 이미지로 사용할 수 있도록 후처리하는 함수를 정의한다.

def deprocess_image(image):
    image -= image.mean()
    image /= image.std()
    image *= 64
    image += 128
    image = np.clip(image, 0, 255).astype("uint8") # [0, 255] 범위로 값 정규화
    image = image[25:-25, 25:-25, :] # 25:-25 -> remove border effect
    return image

 

그리고 각 feature extractor의 필터에 대해서 패턴 이미지를 생성하도록 한다.

여기서는 각 feature extractor 당 20개의 필터에 대한 패턴 이미지를 생성하도록 하겠다.

filter_pattern_images = []

for feature_extractor in feature_extractors[:5]:
    pattern_images = []
    for i in range(20):
        pattern_image = generate_filter_pattern_image(200, 200, feature_extractor, i, 10.)
        pattern_image = deprocess_image(pattern_image)
        pattern_images.append(pattern_image)

    filter_pattern_images.append(pattern_images)

 

그리고 결과를 출력해보자.

import matplotlib.pyplot as plt

def plot_filter_pattern_images(pattern_images, layer_name):
    plt.figure(figsize=(15, 10))
    plt.suptitle(layer_name)
    plot_count = 1

    for pattern_image in pattern_images:
        plt.subplot(4, 5, plot_count)
        plot_count += 1

        plt.imshow(pattern_image)

    plt.show()

plot_filter_pattern_images(filter_pattern_images[0], "block2_sepconv1")
plot_filter_pattern_images(filter_pattern_images[1], "block4_sepconv1")
plot_filter_pattern_images(filter_pattern_images[2], "block12_sepconv1")
plot_filter_pattern_images(filter_pattern_images[3], "block14_sepconv1")

 

이미지 자체가 많기 때문에

  • 입력에 가까운 block2_sepconv1 layer에 대한 필터 패턴 이미지
  • 출력에 가까운 block14_sepconv1 layer에 대한 필터 패턴 이미지

만을 첨부하였다.

 

 

 

이미지를 확인하였을 때 block2_sepconv1 layer의 필터는 가로, 세로와 같은 저차원의 패턴에 반응하고,

block14_sepconv1 layer의 필터는 명확하진 않지만 어떤 실루엣이 보이는 것을 봐서 물체의 형태에 대해서 반응하는 것을 확인할 수 있다.

 

이렇게 필터가 이미지에서 어떤 패턴에 반응하는지를 확인할 수 있다.

 

Gradient ascent가 아니라 Gradient descent를 사용한다면

지금까지는 필터가 반응하는 패턴을 확인하기 위하여 이미지 데이터에 대해 필터의 활성화 값을 증가시키도록 gradient ascent를 사용하였다.

 

반대로 필터의 활성화 값을 감소시키는 식으로 변경하면 어떤 이미지를 얻을 수 있을까?

위 코드에서 image_batch를 update하는 코드의 부호만 반대로 바꿔준 뒤 이미지를 비교해보자.

@tf.function
def gradient_ascent_step(image_batch, feature_extractor, filter_index, learning_rate):
    with tf.GradientTape() as tape:
        tape.watch(image_batch)
        loss = compute_loss(image_batch, feature_extractor, filter_index)

    grads = tape.gradient(loss, image_batch)
    grads = tf.math.l2_normalize(grads)

    # 변경한 코드
    image_batch -= learning_rate * grads # gradient descent
    return image_batch

(단순하게 확인하기 위해 함수명은 gradient_ascent_step 그대로 두고 내부 코드만 수정하였다)

 

그리고 동일하게 필터의 패턴 이미지를 생성해보자.

filter_pattern_images = []

for feature_extractor in feature_extractors[:5]:
    pattern_images = []
    for i in range(20):
        pattern_image = generate_filter_pattern_image(200, 200, feature_extractor, i, 10.)
        pattern_image = deprocess_image(pattern_image)
        pattern_images.append(pattern_image)

    filter_pattern_images.append(pattern_images)

plot_filter_pattern_images(filter_pattern_images[0], "block2_sepconv1")
plot_filter_pattern_images(filter_pattern_images[1], "block4_sepconv1")
plot_filter_pattern_images(filter_pattern_images[2], "block12_sepconv1")
plot_filter_pattern_images(filter_pattern_images[3], "block14_sepconv1")

 

명확하게 비교를 하기위해 저수준의 패턴을 학습하는 입력 쪽에 가까운 필터를 비교해보도록 하겠다.

  • 가로 패턴에 반응하는 필터의 경우(좌) 세로 형태의 패턴이 두드러지는 이미지가 나왔다. (우)
  • 반대로 세로 패턴에 반응하는 필터의 경우(좌) 가로 형태의 패턴이 두드러지는 이미지가 나왔다. (우)

즉, 필터가 반응하는 패턴이 제거된 이미지를 생성하는 확인할 수 있게된다.

댓글