4 분 소요

torch.nn.RNN torch.nn.LSTM torch.nn.GRUforward의 결과물로 두 벡터를 돌려준다. pytorch의 공식 doc에 의하면 outputh_n인데, 만드는 모델, 과제에 따라 사용해야 하는 벡터가 다르다. 지금까지 주로 huggingface 라이브러리를 썼기 때문에 torch.nn의 RNN/LSTM/GRU를 쓸 일은 별로 없었지만 아직도 양뱡향 LSTM 모델등은 일부 사용하는 경우가 있기에 이번 기회에 항상 헷갈렸던 output'h_n의 차이점에 대해 알아봤다.

Note: nn.LSTMoutput과 hidden_state, cell state가 담긴 튜플 (h_n, c_n)를 리턴한다.

아래 설명은 input이 torch.nn.utils.rnn.PackedSequence 객체로 되어 있다고 가정한다. 편이상 RNN/LSTM/GRU는 RNN으로 통칭한다.

output과 h_n의 차이

  • output

shape: (seq_len, Batch_size, D * hidden_size) (Batch_first=False일 시)

output모든 time stepfinal layer의 hidden state들이다. 즉 주어진 문장 내 모든 토큰들의 마지막 hidden state이다. 만약 양방향 RNN이면 각 토큰별 순방향과 역방향시 h_t시 concat되어 리턴되므로 마지막 차원의 크기는 2 * hidden_size이다.

  • h_n

shape: (D * n_layers, Batch_size, hidden_size)

h_n마지막 time step (token)모든 layer의 hidden state이다. D=2 즉 양뱡향 RNN 모델일 때 h_n의 짝수번째 벡터는 순방향, 홀수번째 벡터는 역방향 layer의 최종 결과물이다.

  output h_n
layer final only all
timestep 1, 2, …, t t

output과 h_n의 관계

레이어 수가 하나인 단방향 RNN의 경우에만 h_noutput의 부분집합이다. 정확히 말해 마지막 토큰의 벡터인 output[-1]h_n이다. 그러나 레이어 수가 2 이상이고, 양뱡향인 RNN의 경우 outputh_n의 정확한 관계를 파악하기가 어렵다.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# pad_token의 id: 0
X = [[12, 23, 1, 16, 59, 6, 37],
     [6,  3,  2, 76, 31, 0, 0],
     [12, 6,  1,  0,  0, 0, 0],
     [11, 21, 18, 1,  0, 0, 0]]

X = torch.tensor(X, dtype=torch.long)
X.shape # (batch_size, seq_len)
torch.Size([4, 7])
# pad_token을 제외한 실제 token들의 len
input_lens = torch.tensor([torch.max(row.nonzero())+1 for row in X])
input_lens, indices = input_lens.sort(descending=True)
X_sorted = X[indices]

print(f'정렬 후: {X_sorted}')
print(f'패딩 토큰을 제외한 실토큰 개수: {input_lens}')

정렬 후: tensor([[12, 23,  1, 16, 59,  6, 37],
        [ 6,  3,  2, 76, 31,  0,  0],
        [11, 21, 18,  1,  0,  0,  0],
        [12,  6,  1,  0,  0,  0,  0]])
패딩 토큰을 제외한 실토큰 개수: tensor([7, 5, 4, 3])
hidden_size=256

class RNNModel(nn.Module):
  def __init__(self, rnn_type):
    super(RNNModel, self).__init__()

    self.embedding = nn.Embedding(100, 128)
    self.num_layers = 2
    self.hidden_size = hidden_size

    self.rnn = getattr(nn, rnn_type)(
        input_size=128, 
        hidden_size=self.hidden_size,
        num_layers=2,
        bidirectional=True,
    )

  def forward(self, Batch, Batch_lens):  # Batch: (batch_size, seq_len), Batch_lens: (batch_size)
    # d_w: word emBedding size
    batch_emb = self.embedding(Batch)  # (batch_size, seq_len, d_w)
    batch_emb = batch_emb.transpose(0, 1)  # (seq_len, batch_size, d_w)

    packed_input = pack_padded_sequence(batch_emb, Batch_lens)
    h_0 = torch.zeros((self.num_layers * 2, Batch.shape[0], self.hidden_size))  # (num_layers*num_dirs, batch_size, d_h) = (4, batch_size, d_h)
    packed_output, h_n = self.rnn(packed_input, h_0)  # h_n: (4, batch_size, d_h)
    output = pad_packed_sequence(packed_output)[0]  # outputs: (seq_len, batch_size, 2*d_h)

    return output, h_n, packed_output
# 양방향 RNN이고, batch_size = False일 때

rnn = RNNModel(rnn_type='GRU')
output, h_n, packed_outputs = rnn(X_sorted, input_lens)
print(f'output shape (seq_len, batch_size, 2*hidden_size) == {output.shape}')
print(f'h_n shape (2*num_layers, batch_size, hidden_size) == {h_n.shape}')

output shape (seq_len, batch_size, 2*hidden_size) == torch.Size([7, 4, 512])
h_n shape (2*num_layers, batch_size, hidden_size) == torch.Size([4, 4, 256])
  • 순방향

순방향시 마지막 time-step은 문장 맨 마지막 토큰이며, 그 값은 output 마지막 차원의 첫 hidden_size개의 원소와 같다 (순방향+역방향을 붙인 벡터이기 때문에).

  • 역방향

역방향시에는 마지막 time-step이 문장 맨 처음 토큰이며, 그 값은 output 마지막 차원내 뒤에서부터 hidden_size개의 원소와 같다.

# forward: last time-step (i.e. last token), last-layer
print(torch.eq(output[-1, 0, :hidden_size], h_n[2, 0, :]).all())

# backward: last time-step (i.e. first token), last-layer
print(torch.eq(output[0, 0, hidden_size:], h_n[3, 0, :]).all())
tensor(True)
tensor(True)

outputh_n의 관계를 일반화 해보자.

# backward 
batch_size = output.shape[1]
for batch_idx in range(batch_size):
    if torch.eq(output[0, batch_idx, hidden_size:], h_n[3, batch_idx, :]).all().item():
        print(f'At {batch_idx}th batch output & h_n --> equal')
At 0th batch output & h_n --> equal
At 1th batch output & h_n --> equal
At 2th batch output & h_n --> equal
At 3th batch output & h_n --> equal

역방향 rnn layer의 경우 output의 첫번째 토큰(0)은 모든 배치에서 [hidden_size : ] 원소들과 h_n[3]의 원소들은 일치한다 (앞서 말했듯 h_n은 모든 레이어의 마지막 time-step의 hidden_state를 리턴하는데, 첫 레이어의 순방향, 첫 레이어의 역방향, 두번째 레이어의 순방향, 두번째 레이어의 역방향시 마지막 hidden state들을 차례차례 리턴한다).

순방향의 rnn layer는 어떨까?

# forward
for batch_idx in range(batch_size):
    if torch.eq(output[-1, batch_idx, :hidden_size], h_n[2, batch_idx, :]).all().item():
        print(f'At {batch_idx}th batch output & h_n --> equal')
    else:
        print(f'At {batch_idx}th batch output & h_n --> NOT equal')

At 0th batch output & h_n --> equal
At 1th batch output & h_n --> NOT equal
At 2th batch output & h_n --> NOT equal
At 3th batch output & h_n --> NOT equal

마지막 토큰의 hidden state 즉 output[-1, batch_idx, :hidden_size]를 했음에도 몇몇 배치에서는 예상과 다르게 나왔다. 문제를 알아보기 위해 다음 코드를 돌려봤다.

# forward
for batch_idx in range(batch_size):
    for token_idx in range(output.shape[0]):
        if torch.eq(output[token_idx, batch_idx, :hidden_size], h_n[2, batch_idx, :]).all().item():
            print(f'For {token_idx}th token at {batch_idx}th batch output & h_n --> equal')

For 6th token at 0th batch output & h_n --> equal
For 4th token at 1th batch output & h_n --> equal
For 3th token at 2th batch output & h_n --> equal
For 2th token at 3th batch output & h_n --> equal

각 배치마다 outputh_n이 겹치는 토큰 인덱스가 다른데, 정렬한 인풋의 (패딩 토큰을 제외한) 실토큰의 개수와 일치함을 알 수 있다. 이는 마지막 time step의 hidden state라고 하는 h_n의 경우 (패딩 토큰이 없는) 실제 문장의 마지막 토큰의 hidden state을 담기 때문이다.

# 위에서 만든 정렬한 인풋
print(f'정렬 후: {X_sorted}')
print(f'패딩 토큰을 제외한 실토큰 개수: {input_lens}')
정렬 후: tensor([[12, 23,  1, 16, 59,  6, 37],
        [ 6,  3,  2, 76, 31,  0,  0],
        [11, 21, 18,  1,  0,  0,  0],
        [12,  6,  1,  0,  0,  0,  0]])
패딩 토큰을 제외한 실토큰 개수: tensor([7, 5, 4, 3])

업데이트: