My Profile Photo

DongChanS's blog


수학과 학생의 개발일지


Attention is all you need (transformer) 파악하기

예전에 attention에 대해서 공부했었는데 요즘 들어서 까먹은 부분이 존재하여 다시 복습하는 시간을 가져보았습니다.

참고링크 : http://mlexplained.com/2017/12/29/attention-is-all-you-need-explained/

1) “Attention is All You Need” 의 개요 이해하기

Attention is All You Need라는 논문은 2017년도에 나온 논문입니다.

이를 이해하기 위해서는 이전에 개발되었던 여러 네트워크에 대해서 이해하여야 하는데,

대표적으로 seq2seqalignment model 입니다.

1-1. Seq2Seq 모델

14

  • encoder-decoder 구조

    => 이를 통해 가변적인 길이의 sentence를 학습하고 도출할 수 있도록 함.

  • encoder : sequence of input words -> intermediate representation

  • decoder : intermediate representation + generated output words before -> next word 예측

    • MLE 기반의 loss function을 사용함

      15

1-2. RNN 과 encoder&decoder의 결합

  • recurrent 라는 컨셉이 언어의 컨셉과 잘 매칭됨 (사람이 단어를 왼쪽에서 오른쪽으로 읽듯이 RNN도 그러하니까)
  • encoder & decoder의 단점
    • 한 문장에 대한 전체적인 정보를 벡터 하나로 요약하기에는 너무 적다 (bottleneck)
    • 무조건 한 문장 전체로 뭉쳐진 정보를 써야한다.
      • 지금 도출해야 할 단어에 맞게 선택적으로 정보를 사용하고 싶다!
  • RNN의 단점
    • sequential nature of RNN => 병렬화 힘듬
    • long-range dependencies => LSTM도 한계가 있음.
    • 몇몇 단어는 multiple meaning을 가지는데 RNN으로 이를 커버하지 못한다고 함.

1-3. 전통적인 attention (alignment model)

16

alignment model은 기계번역에서 시작된 말인데요, source sentence의 토큰들과 output sentence의 토큰들의 의미론적인 맵핑을 뜻합니다.

이런 관계를 파악할 수 있다면, a라는 단어를 도출하기 위해서 has에 더 큰 가중치를 줄 수 있을 것입니다.

  • look back at the entire sentence & selectively extract

  • decoder 가 encoder의 hidden states를 확인함.

    • 각 hidden states에 가중치(attention weight)를 곱하여 attention vector 도출함.
  • attention weight는 feed-forward network를 통해서 구함.

1-4. 우리가 모델링해야 할 3가지의 dependencies

  • input sentence & output sentence
    • 1-3에서는 attention을 통해서 이 부분을 해결하려고 하였습니다.
  • input sentence
  • output sentence

기존에는 RNN(혹은 CNN)을 이용해서 input & output sentence의 dependency를 묘사했는데, transformer 논문에서는 attention을 이용해서 이 두개의 dependency를 모델링했습니다.

왜냐하면 RNN에는 1-2에서 소개한 것과 같은 단점이 있기 때문입니다.

1-5. Multi-head attention

transformer 모델에서는 하나의 attention 벡터를 만드는게 아니라 여러 개의 attention matrix (Multi-head attention) 을 만드려고 합니다.

  • multi-head ?

    12

    attention마다 문장을 보는 여러가지 관점들을 학습시키기 위해서 여러개의 attention을 활용합니다.

    ex) 문법적으로 유사, 의미적으로 유사….

    => 여러개의 block을 병렬적으로 학습할 수 있음.

  • Multi-head attention의 핵심 구조 : values, keys, queries

    • values : encoder hidden states (사진은 단어 그자체로 보이지만 hidden state 임)
    • query : processing the decoder state
    • keys : encoder hidden states (사진은 단어 그자체로 보이지만 hidden state 임)

    11

  • mechanism

    1. 각 query와 key들간의 유사도(attention score)를 측정한다.

    2. 그 attention score와 value의 weighted sum을 도출한다.

      => 도출하기 위해서 Scaled Dot Product 을 사용합니다.

  • Scaled Dot Product Attention

    별건 아니고 그냥 행렬곱입니다.

    13

    각 query마다, 각 value에 대한 attention score은 query와 key간의 dot product로 계산할 수 있습니다.

그 외에 positional embedding이나 residual network, layer normalization과 같은 몇몇 다양한 특징들이 있습니다!

comments powered by Disqus