lstm은 rnn의 장기 의존성 문제를 해결하기 위해 등장했고,
import torch
import torch.nn as nn
import torch.optim as optim
# LSTM 모델 정의
class LSTMModel(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size):
super(LSTMModel, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
out, _ = self.lstm(x, (h0, c0))
out = self.fc(out[:, -1, :])
return out
# 하이퍼파라미터 설정
input_size = 10 # 입력 차원
hidden_size = 20 # LSTM의 hidden state 차원
num_layers = 2 # LSTM 레이어 수
output_size = 1 # 출력 차원
num_epochs = 100 # 학습 에폭 수
learning_rate = 0.001 # 학습률
# 데이터셋 생성 (여기서는 임의의 데이터)
x_train = torch.randn(100, 10, input_size)
y_train = torch.randn(100, output_size)
# 모델, 손실 함수, 최적화 알고리즘 정의
model = LSTMModel(input_size, hidden_size, num_layers, output_size)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 모델 학습
for epoch in range(num_epochs):
model.train()
outputs = model(x_train)
optimizer.zero_grad()
loss = criterion(outputs, y_train)
loss.backward()
optimizer.step()
if (epoch+1) % 10 == 0:
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
print("학습 완료")
파이토치를 이용하면 이렇게 구현할 수 있는데
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
forward에서 out, _ = self.lstm(x, (h0, c0))를 하면 return으로 out과 (h,c)가 나오는데 코드에서는 (h,c)는 _로 안받지만, 받는다면 out에는 x시퀀스에 따른 결과 h 시퀀스, h와 c에는 마지막 벡터에 대한 hidden state와 cell state이다.
이 out을 linear layer에 넣어줘서 최종 output을 구하고 이걸 target과 loss를 구해서 최적화하면 된다
파이토치가 아닌 numpy만으로 구현한다면, lstm그림에 나온 것처럼 각 게이트(forget, input, candidate, output, final) 계산들을 forward에 구현해주고 backpropagation을 chain rule로 계산해 주어 backward를 구현해주면 된다. 이 역시 rnn때와 마찬가지로 과거의 어느정도까지 학습시킬지 truncation을 정해서 구현해야한다.여기서 lstm은 rnn보다 더 긴 truncation을 가지고 있어도 학습이 잘 되는 경향이 있는데 이는 lstm이 long term dependency를 더 잘 학습할 수 있는 구조적 특징을 가지고 있기 때문이다.