ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [순환신경망] RNN의 문제점 (기울기 소실, 기울기 폭주 = Gradient Vanishing & Exploding)
    IT/AI 2022. 10. 19. 01:25

    본 포스팅은 "밑바닥부터 시작하는 딥러닝2" 도서로 공부한 내용을 요약하기 위한 포스팅입니다.


    RNN이란 ?

      RNN(Recurrent Neural Network)은 시계열 데이터를 처리하기 위해 고안된 모델로 아래 그림에서와 같이 이전 시각(계층)의 출력 값(은닉 값)이 다음 시각(계층)으로 전파되어 즉, 과거 정보를 계승하여 시계열 데이터에 대응하는 신경망입니다.

     

    RNN의 순환 구조


    RNN의 문제점

      RNN은 이전 맥락의 정보를 기억하는 일명 순환 메커니즘을 통해 데이터가 한 방향으로만 흐르는 "Feed forward" 방식의 신경망보다 시계열 데이터 학습에 대한 성능이 향상 되었습니다. 하지만, RNN은 시계열 데이터의 장기 의존 관계를 학습하기에는 어려운 문제가 있습니다. 즉, 장기 기억에는 취약하다는 단점이 존재합니다.

     

    장기 기억 필요한 문제의 예시

     

      RNN 모델이 위의 예시의 정답을 맞추기 위해서는 "to"가 입력으로 주어지는 시각(계층)까지 그 이전의 맥락들을 기억하고 있어야 합니다. 즉, 이전의 모든 맥락 정보를 Output 값 (h)에 보관해둬야 합니다.

      하지만 이러한 상황에서 왜 RNN이 장기 기억에 취약한지 알아보기 위해 정답 레이블로 "TOM"이라는 단어가 주어졌을 때, 학습의 관점에서 생각해봅시다.

      아래 그림은 정답 레이블이 주어졌을 때 기울기가 역 전파(Back propagation)되는 과정을 나타낸 것이며, 빨간 선이 기울기의 전파 과정입니다.

     

    학습 과정에서 정답 레이블이 주어졌을 때 기울기의 흐름

     

      위의 그림처럼 기울기는 RNN 계층의 과거 방향으로 전달되고, 이 때 "의미 있는 기울기"가 전달되어야 시간 방향의 의존 관계를 제대로 학습할 수 있습니다. 즉, 과거로 전파되는 은닉 값(Hidden state)에 대한 기울기에 학습해야 할 의미 있는 정보가 들어 있고, 그것이 과거로 잘 전달되어야 장기 의존 관계를 학습할 수 있다는 말과 동일합니다.

     

      만약 이 기울기가 중간에 소실되거나 과도하게 커진다면 모델은 장기의존 관계를 학습할 수 없습니다.

    그리고 단순한 RNN 계층에서는 시간을 거슬러 올라갈수록 기울기가 작아지거나(기울기 소실) 혹은 커지는 문제(기울기 폭주)가 발생합니다.


    RNN의 기울기 소실 문제 (Gradient Vanishing)

      RNN에서 기울기 소실이 왜 일어나는 지 알아보기에 앞서, 은닉 값(Hidden state)의 계산 그래프를 그려보면 아래와 같습니다. 빨간 선은 Back propagation (기울기 전파) 과정을 나타냈습니다.

    RNN의 Computation graph

     

      위의 그림에서 알 수 있듯이 기울기 전파 시 'tanh' 연산과 'MatMul(행렬곱)' 연산을 통과하는 것을 확인할 수 있습니다.

    여기서 'tanh' 함수에 주목하여 살펴보겠습니다.

     

      'tanh' 함수의 미분은 다음과 같고 각 결과를 그래프로 그려보면 아래의 그래프가 나옵니다.

     

    tanh의 미분

     

    tanh(x)와 tanh(x) 미분의 그래프

     

      위의 그래프에서 살펴보면 dy/dx의 그래프는 값이 0~1이고, x가 0으로부터 멀어질수록 작아집니다.

    즉, 역전파에서는 기울기가 tanh 노드를 지날 때 마다 값은 계속해서 작아진다는 의미입니다. 따라서 계층이 많을 수록 tanh 함수를 많이 통과하기 때문에 기울기도 계속해서 작아지게 되고 소실되는 문제가 발생합니다.

     

    # 미분 그래프 코드
    
    x = np.arange(-4, 4, 0.1)
    y = np.tanh(x) # y = tanh(x)
    z = 1 - y**2 # dy/dx = 1 - y^2
    
    plt.figure(figsize=(12,8))
    
    plt.plot(x, y, color='red', alpha=0.5, linewidth=3, label='tanh(x)')
    plt.plot(x, z, color='black', linestyle='--', alpha=0.5, linewidth=3, label='dy/dx')
    
    plt.axhline(y=0, color='b', alpha=0.5, linestyle=':', linewidth=1)
    plt.axvline(x=0, color='b', alpha=0.5, linestyle=':', linewidth=1)
    
    plt.title("Tanh Backpropagation", fontdict = {'fontsize' : 20})
    plt.xlabel("x")
    plt.ylabel("y")
    plt.legend(loc=4, prop={'size' : 20})
    plt.show()

    RNN의 기울기 폭주 문제 (Gradient Exploding)

      기울기 폭주 문제는 'MatMul(행렬곱)' 연산에 의해 발생합니다.

    기울기 폭주 문제를 살펴보기 위해 행렬 곱 연산에만 주목하여 역전파 시 기울기 변화를 살펴봅니다.

    (실제 RNN의 case가 아닌 MatMul에 의한 기울기 폭주를 나타내기 위한 예시입니다.)

     

    MatMul의 역전파

     

      상류로부터 dh라는 기울기가 흘러온다고 가정했을 때, Matmul 노드에서의 기울기, 역전파는 dhWht 라는 행렬 곱으로 계산됩니다. 그리고 해당 계산을 시계열 데이터의 시간 크기(계층 크기)만큼 반복합니다.

     

      여기서 주목할 점은 이 행렬 곱셈에서는 매번 똑같은 가중치가 사용된다는 것이고, 해당 가중치의 값에 따라 기울기가 지수적으로 증가할 수도 감소할 수도 있습니다.

     

    기울기 폭주 문제의 대책

      기울기 폭주의 전통적인 기법으로는 기울기 클리핑(Gradient clipping)이라는 기법이 있습니다.

    말 그대로 기울기 값을 자르는 것을 의미하며, 기울기 폭주를 막기 위해 단순히 특정 임계값을 넘는 기울기 값을 잘라내거나 특정 비율만큼 곱하여 기울기 값을 감소시키는 방법입니다. 


    RNN의 기울기 소실 문제의 해결

      기존의 RNN에서 발생하는 기울기 소실 문제를 해결하기 위해 등장한 것이 '게이트가 추가된 RNN'이며 대표적으로 LSTM(Long Short-Term Memory model)과 GRU(Gated Recurrent Unit)이 있으며, 해당 모델에 관한 포스팅은 다음에 마저 진행하도록 하겠습니다.

Designed by Tistory.