SolartheNomad 2023. 4. 1. 02:36

RNN

- 시계열 데이터 처리 신경망 

- ‘기억(과거부터 현재까지의 입력 데이터를 요약한 정보)’을 가지는 신경망 

- 새로운 입력이 네트워크로 들어올 때마다 기억은 조금씩 수정되며, 결국 최종적으로 남겨진 기억은 모든 입력 전체를 요약한 정보가 된다. 

-  자연어 처리, 음성인식, 단어의 의미 판단 및 대화, 손글씨, 센서 데이터등의 시계열 데이터 ㅓ리에 활용

 

RNN의 구조 

 

-  첫 번째 입력(x1)이 들어오면 첫 번째 기억(h1)이 만들어지고, 두 번째 입력(x2)이 들어오면 기존 기억(h1)과 새로운 입력을 참고하여 새 기억(h2)을 만들어진다. 

 

RNN 유형

 

일대일

- 순환이 없어서 RNN이라고 보기 힘듦

- 순방향 네트워크

 

일대다 

- 입력이 하나이고, 출력이 다수인 구조

- 이미지 캡션 : 이미지를 입력해서 

 

다대일

 입력이 다수이고 출력이 하나인 구조로, 문장을 입력해서 긍정/부정을 출력하는 감성 분석기에서 사용

 

다대일 모델 

 

다대다 

- 입력과 출력이 다수이다

- 자동 번역기에서 주로 사용됨 

- seq-2-seq2(시퀀스-투-시퀀스)를 이용하는 방식으로 사용된다. 

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(7855, 256)
    (rnn): LSTM(256, 512, num_layers=2, dropout=0.5)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (embedding): Embedding(5893, 256)
    (rnn): LSTM(256, 512, num_layers=2, dropout=0.5)
    (fc_out): Linear(in_features=512, out_features=5893, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

 

다대다 모델 

 

 

 


RNN 모델 한눈에 비교하기

 

 

 

 

RNN 구성 요소

 

RNN 계층

- RNN은 내장된 계층뿐만아니라 셀 레벨의 API 역시 제공함

- . RNN 계층이 입력된 배치 순서대로 모두 처리하는 것과 다르게 RNN 셀은 오직 하나의 단계(time step)만 처리한다.

- RNN 계층의 for loop 구문을 갖는 구조 

 

- 단일 입력과 과거 상태(state)를 가져와서 출력과 새로운 상태를 생성

 

셀 유형

 

 nn.RNNCell: SimpleRNN 계층에 대응되는 RNN 셀

 nn.GRUCell: GRU 계층에 대응되는 GRU 셀

 nn.LSTMCell: LSTM 계층에 대응되는 LSTM 셀

 

 

RNN 구조

- 은닉층 노드들이 연결되어 있어 이전 단계 정보들을 은닉층 노드에 저장할 수 있도록 구성함

-  xt-1에서 ht-1을 얻고 다음 단계에서 ht-1 xt를 사용하여 과거 정보와 현재 정보를 모두 반영함

 ht xt+1의 정보를 이용하여 과거와 현재 정보를 반복해서 반영함

입력층, 은닉층, 출력층, 가중치 세개( Wxh, Whh, Why)

 

Wxh :  입력층에서 은닉층으로 전달되는 가중치

Whh : t 시점의 은닉층에서 t+1 시점의 은닉층으로 전달되는 가중치

Why : 은닉층에서 출력층으로 전달되는 가중치

 

- Wxh, Whh, Why는 모든 시점에 동일하다

 

은닉층

- 계산을 위해 xt ht-1이 필요하다. 

- 이전 은닉층×은닉층 → 은닉층 가중치 + 입력층 → 은닉층 가중치×(현재) 입력 값

- 일반적으로 하이퍼볼릭 탄젠트 활성화 함수를 통하여 구현한다. 

 

출력층

- 심층 신경망과 계산방법이 동일함

- 은닉층 → 출력층 가중치×현재 은닉층

- 소프트맥스 함수를 적용함

 

 

오차 측정

- 심층 신경망에서 순방향 학습과 달리 각 단계(t) 마다 오차를 측정하게 됨 

- 각 단계마다 실제 값(yt)과 예측 값(yt^)으로 평균 제곱 오차를 이용하여 측정하게 됨 

- BPTT :  각 단계(t)마다 오차를 측정하고 이전 단계로 전달되는 것을 의미함