Machine Learning/Paper

[논문 리뷰] Deep Residual Learning for Image Recognition

devson 2024. 7. 1. 22:43

ResNet으로 유명한 Deep Residual Learning for Image Recognition의 논문을 리뷰한다.

ResNet은 CNN을 아키텍처이지만, 이 논문에서 시사하는 바는 딥러닝에 일반적으로 통용되기 때문에 꼭 읽어봐야할 논문이다.

(기계 번역을 위한 모델인 Transformer에서도 skip(shortcut) connection을 사용하는 것을 확인할 수 있다)

 

이번 글을 통해 논문에서 제기하는 문제와 이를 해결하기 위한 residual learning에 대해 이해할 수 있도록 하자.

(최대한 논문의 내용을 베이스로 하되 좀 더 쉽게 이해하도록 풀어쓰기 위해 노력하였다)

 


 

* 이 논문에서 얘기하는 "plain" network는 이 논문에서 제안하는 shortcut connection을 적용하지 않은 network를 얘기한다. (Figure 3 참고)

Figure 3.

 

Degradation problem in "plain" Deep convolution neural network

(in 1. Introduction)

"Deep" convolutional neural network는 이미지 분류에서 혁신을 가져왔다.

그렇다면 레이어를 더 쌓는다면 network가 학습을 더 잘할 수 있을까?

(Is learning better networks as easy as stacking more layers?)

  • 단순히 레이어를 깊게 쌓기만 한다면 vanishing/exploding gradients가 발생하여 수렴을 방해하여 학습이 원활하게 일어나지 않게된다.
    • 하지만 vanishing/exploding gradientsnormalized initialization, intermediate normalization layers를 통해 대부분 해결되었다.
      (즉, 이 논문에서 풀고자 하는 문제는 vanishing/exploding gradients와 관련된 것이 아니다)
  • 깊은 네트워크가 학습 시 수렴이 가능하다면 degradation 문제가 나타난다.
    (degradation 문제가 이 논문에서 다루는 가장 핵심 문제이다)
    • 모델이 깊어지면 accuracy가 포화되고(saturated) 성능이 낮아진다(degrade).

 

"plain" network에서 레이어의 수를 20개, 56개로 하여 각각 CIFAR-10 데이터셋으학습한 결과가 아래 Figure 1과 같다.

(설명은 없지만 아마도 vanishing/exploding gradients를 막기 위한 조치는 했을 것이다)

Figure 1. Shallow "plain" network VS Deep "plain" network

 

여기서 training error에 대한 왼쪽 그래프를 보면, 더 깊은 layer에서의 error가 더 높은 것을 확인할 수 있다.

  • 이 degradation 문제는 overfitting 문제는 아니다.
    • 만약 overfitting이었다면 training error는 더 깊은 네트워크에서 더 낮게 나왔을 것이다.
      (또한 vanishing/exploding gradients가 발생했다면 학습조차 제대로 이뤄지지 않았을 것이다)

그리고 test error도 동일하게 더 깊은 모델이 더 높게 나왔다.

 

 

(in 4.1 ImageNet Classification)

관련하여 다른 실험 결과를 살펴보자.

깊이가 다른 "plain" network들을 ImageNet 데이터셋으로 학습시킨 결과가 Figure 4의 왼쪽 그래프이다.

(아래 이미지는 Figure 4의 왼쪽 그래프만 잘라놓은 이미지이다)

이때 plain network에는 Batch Normalization을 적용하였다.

Figure 4. (left): plain networks' training error(thin) and validation error(bold) on ImageNet

 

이때도 동일하게 깊은 plain network에 대해 degradation 문제를 볼 수 있었다.

  • network에 Batch Normalization을 적용하였고 역방향으로 전파되는 gradient가 healthy norms을 보이기 때문에
    깊은 plain network의 optimization이 잘 안되는 이유는 vanishing gradients 때문이 아니라고 주장한다.
    • 또한 깊은 plain network도 충분히 경쟁력있는 성능을 갖기 때문에, vanishing gradients가 발생했다면 이러한 결과를 얻을 수 없었을 것이다.
  • 심지어 training iteration을 3배로 하여도 여전히 degradation 문제를 볼 수 있었다고 한다.

 

Deep Residual Learning

(in 1. Introduction)

앞서 plain network를 깊게 쌓았을 때 degradation 문제가 발생하는 것을 확인할 수 있었다.

 

하지만 생각했을 때는 깊은 모델은 얕은 모델의 성능을 그대로 유지할 수 있을 것 같다.

  • 예를 들어, Model A와 Model B가 있다고 하자. Model B는 Model A에 레이어를 더 쌓아 올린 모델이다.
    • 만약 Model B에서 추가로 쌓아 올린 레이어가 입력과 출력이 동일한 identity mapping이라면, Model B의 성능은 Model A와 동일할 것이다.

 

이러한 사고로는 깊은 모델은 얕은 모델보다 성능이 좋지는 않아도 최소한 유사하기라도 할텐데,
실제로는 그렇지 않다는 것을 앞서 실험 결과를 통해 확인할 수 있었다.

 

이어 저자들은 deep residual learning이라는 개념을 소개하면서 degradation 문제를 어떻게 해결할 수 있는지에 대해 가정을 한다.

  • 먼저 residual learning을 위한 구조는 아래 Figure 2와 같다.
    • block의 입력과 출력값을 서로 더해주는 shortcut connection을 추가한다.

Figure 2

  • $\mathcal{x}$: block의 입력
  • $\mathcal{F(x)}$: block의 출력
  • $\mathcal{H(x)}$: block의 최적의 출력 (desired underlying mapping)

 

shortcut connection을 사용하여 입력과 출력을 이어줌(residual mapping)으로써 더 쉽게 최적화 될 수 있다고 한다.

  • shortcut connection이 없는 block(original mapping)의 출력을 $\mathcal{F(x)}$라고 하면
  • shortcut connection이 있는 block(residual mapping)의 출력은 $\mathcal{F(x) + x}$가 된다.

저자는 shortcut connection이 있는 경우가 더 쉽게 최적화 된다고 가정한다.

극단적으로 만약 정말 identity mapping이 최적($\mathcal{H(x)} = x$)이라면

  • shortcut connection 없이 $\mathcal{F(x)}$이 x가 되는 것보다는
  • shortcut connection 이 있으면서 $\mathcal{F(x)}$이 0이 되는 것이 더 쉬울 것이라고 예측한다.
    (weight는 0 근처의 작은 값으로 initialization 되기 때문에 weight를 0으로 만들어 출력을 0으로 만드는건 어렵지 않을 것이다)

즉, shortcut connection을 추가함으로써 block은 $\mathcal{H(x)}$를 출력하기 위해 입력과의 차이인 잔차(residual)만 학습하면 되기 때문에 학습을 하면서 최적화하는 과정이 더 쉬워진다는 것이다.

 

(in 3.1 Residual Learning - 의 내용을 다르게 풀어 정리)

하지만 이 가설은 가정이 있는데, $\mathcal{H(x)}$가 0(zero mapping)보다는 identity mapping에 가까워야 한다는 것이다.

  • $\mathcal{H(x)} = 0$ 이라면 shortcut connection이 없으면 F(x)는 0이 되어야한다.
  • $\mathcal{H(x)} = 0$ 이라면 shortcut connection이 있으면 F(x)는 -x가 되어야한다.

즉, 앞서 논리의 반대가 되버리게되어 residual learning이 무용지물이 되버린다.

 

물론 현실에서는 identity mapping이 최적은 아닐지라도, 최적의 함수(optimal fnuction)가 (출력이 0인)zero mapping보다 identity mapping에 가까우면 shortcut connection이 있는 편이 잔차를 학습하는게 더 쉬울 것이다.

 

(in 4.2. CIFAR-10 and Analysis - Analysis of Layer Responses.)

이와 관련해서 plain network와 residual network에서 각 layer의 출력(Conv -> BN 출력)에 대한 표준편차(std)를 Figure 7.에 정리하였다.

Figure 7.

Figure 7.에서 아래 그래프는 layer를 무시하고 표준편차 크기 내림차순으로 정렬하였는데, residual network에서 각 layer의 출력값에 대한 표준편차가 대부분 낮은 것을 확인할 수 있다.

이는 (앞서 3.1에서의) 잔차로 학습해야하는 것이 0에 더 가깝다(identity mapping)는 가정과 부합하는 결과이다.

 

Experiments

앞서 깊은 plain network에서 degradation 문제와 이를 해결하기 위한 residual learning, 그리고 residual learning을 통해 네트워크 학습이 왜 더 쉬워질 수 있는지에 대해 설명하였다.

관련된 실험은 그 결과에 대한 Figure와 Table을 나열하고 간단하게 설명을 하고자한다.

 

  • Figure 4. (좌) plain network에 대해 layer를 깊게하였을 때 error와 (우) plain network에 대해 layer를 깊게하였을 때 error를 비교한다.
  • (좌) plain network의 경우, 깊은 network의 error가 더 높은 것을 확인할 수 있지만,
    (우) residual network의 경우, 깊은 network의 error가 더 낮은 것을 확인할 수 있다.
    • degradation 문제를 해결하였다!
  • 또한 layer depth가 동일하더라도 같은 iteration에서의 residual network error가 더 낮은 것을 볼 수 있다.
    • residual learning이 네트워크 학습이 더 쉽다는 것을 확인할 수 있다!

Figure 4. Training on ImageNet

 

  • Table 3., 4.는 이전 ImageNet 대회에서 좋은 성적을 거둔 모델들과 ResNet을 비교한 결과이다.
    • 다른 모델과 비교했을 때 ResNet의 error가 낮을 뿐만 아니라, 더 깊은 네트워크의 ResNet(ResNet-152)의 error가 가장 낮은 것을 확인할 수 있다.

Table 3., Table 4.

 

  • Table 6.와 Figure 6.는 CIFAR-10 데이터셋에서 ResNet을 트레이닝한 결과이다.
  • 여기서 다른 실험 결과와 다르게 추가적으로 확인할 수 있는 정보는 훨씬 깊은 1202-layer ResNet을 실험한 결과이다.
    • Figure 6.의 가장 오른쪽 그래프를 보면 1202-layer의 test error가 그보다 얕은 110-layer 보다 높은 것을 볼 수 있다.
      하지만 (바닥에 깔려있는) training error를 보면 매우 낮게 나왔다.
      저자는 이를 매우 큰(19.4M) 네트워크에 비해 데이터가 적어서 overfitting 되었다고 판단한다.
      (이는 plain network에서의 degradation 문제와는 다르다!)
      • residual learning을 실험하기 위해 적극적으로 regularization을 하지 않았는데,
        regularization을 제대로 하면 아마도 결과가 개선될 것이라고 한다.

Table 6., Figure 6.