5개월 전에 모델들을 이해해보려고 하나씩 하던 중 RNN을
https://www.youtube.com/watch?v=OkTyY28XMuQ&t=602s
이 영상을 보면서 공부를 해보았다.
pytorch에 rnn모델이 있지만 모델을 직접 numpy를 이용해서 구현해보았다.
일단 영상에 코드가 전부 나와있으니 보면서 이해하는게 어렵지 않았다.
rnn의 forward는 이해가 쉽다 단일 cell이라면 h0을 시작으로 순차적인 인풋을 임베딩해서 함께 모델에 넣어서 계산하면 다음 hidden state가 나오고 다음 input을 임베딩으로 바꿔서 넣어주고 계속 반복하는 것.
forward는 이해 된다. 문제는 여기부터 backpropagation 역전파에서 막혀버렸다.
사실 지금도 정확히 이해는 되지 않지만, 과거보단 이해에 가까워 져서 현재의 상태를 남기기 위해 작성한다.
softmax를 통해 예측값을 만들고 실제값과 cross entropy를 통해 loss를 만든다. 이 loss는 스칼라값.
이제 이 스칼라값을 각 가중치에 대해 미분하여 gradient를 만들어야하는데..
일단 차근차근 해보자.
https://velog.io/@hjk1996/Cross-Entropy%EC%99%80-Softmax%EC%9D%98-%EB%AF%B8%EB%B6%84
이 글을 보면 softmax와 cross entropy의 미분이 y^ - y인 것을 알 수 있다.
def backward(ps, hs, xs):
# Backward propagation through time (BPTT)
# 처음에 모든 가중치들은 0으로 설정
dV = np.zeros(V.shape)
dW = np.zeros(W.shape)
dU = np.zeros(U.shape)
for i in range(seq_len)[::-1]:
output = np.zeros((vocab_size, 1))
output[targets[i]] = 1
ps[i] = ps[i] - output.reshape(-1, 1)
# 매번 i스텝에서 dL/dVi를 구하기
dV_step_i = ps[i] @ (hs[i]).T # (y_hat - y) @ hs.T - for each step
dV = dV + dV_step_i # dL/dVi를 다 더하기
# 각i별로 V와 W를 구하기 위해서는
# 먼저 공통적으로 계산되는 부분을 delta로 해서 계산해두고
# 그리고 시간을 거슬러 dL/dWij와 dL/dUij를 구한 뒤
# 각각을 합하여 dL/dW와 dL/dU를 구하고
# 다시 공통적으로 계산되는 delta를 업데이트
# i번째 스텝에서 공통적으로 사용될 delta
delta_recent = (V.T @ ps[i]) * (1 - hs[i] ** 2)
# 시간을 거슬러 올라가서 dL/dW와 dL/dU를 구하
for j in range(i + 1)[::-1]:
dW_ij = delta_recent @ hs[j - 1].T
dW = dW + dW_ij
dU_ij = delta_recent @ xs[j].reshape(1, -1)
dU = dU + dU_ij
# 그리고 다음번 j번째 타임에서 공통적으로 계산할 delta를 업데이트
delta_recent = (W.T @ delta_recent) * (1 - hs[j - 1] ** 2)
for d in [dU, dW, dV]:
np.clip(d, -1, 1, out=d)
return dU, dW, dV, hs[len(inputs) - 1]
신박ai님의 코드를 가져온 것인데
내가 막힌 부분은
dV_step_i = ps[i] @ (hs[i]).T # (y_hat - y) @ hs.T - for each step
이 부분 dL/dV를 구하는데
dL/dV = dL/dy * dy/dV
=(y^ -y)@h(1)가 아니라 왜 (y^-y)@h(1)T인지..
그리고 다 양보해서 그렇다 칠 수 있는데
dL/dW를 구할때는 확실히 어지러웠다.
dL/dW = dL/dy * dy/dh(1)*dh(1)/dW
=(y^-y) @ V * ( 1 - h(1)^2 ) @ h(0) 여야만 하는데
왜!!
=VT @ (y^-y) *(1-h(1)^2) @ h(0)T 인지 왜 transpose하는지, 한건 그렇다 치더라고 왜 마음대로 dL/dy * dy/dh(1)이거를 자리를 바꾸는건지 도대체 이해가 되지 않았다..
내가 말한 대로 하면 행렬의 사이즈가 안맞아서 계산 자체가 안된다 근데 왜 저렇게 하는걸까 도대체 이해가 안된다. 그냥 계산 되게 transpose하고 이리저리 왔다갔다 교환법칙을 하는건가? 해서 google에 chain rule 교환법칙 검색해보고 이리저리 해도 마땅한 글을 못 찾았다 검색력 부족.
그래서 막혀서 진행 못한 지 어언 5개월 생각 날때마다 이해하려고 노력했는데.. 식을 써보면서 한번 이해해보자 싶어서 도전했다.
결과적으로는 체인룰 규칙을 내가 잘 못 배운건지 나만의 규칙을 찾았고 대충은 이해했다.
결론 부터 쓰자면 Y=AX 가 있을때,
1. 일단 미분을 하면 무조건 계수의 T가 붙어서 나온다.
dY/dA = xT, dY/dx = AT이다. (Y=스칼라)
그리고 체인룰에서는 Z=BY 라는 식이 추가 될때 (Z=스칼라 ,Y는 스칼라아님)
dZ/dA= dZ/dY * dY/dA 인데 처음의 미분에 대해서는 1의 법칙이 성립하여,
dZ/dY = BT가 되고
두번째 미분에 대해서는 또 다른 규칙이 추가되는데,
2.두번째 미분부터 미분해서 뒤에것이 나오면 그냥 계수에 T를 붙여서 곱해준다.
즉 dY/dA = XT, dZ/dA = BT @ XT가 된다.
(뒤의것이 나온다는 것은 Y=AX 에서 A)
마지막 규칙으로는
3. 두번째 미분부터 미분해서 앞에것이 나오면 계수에 T를 붙인 걸 이때까지 순서대로 미분한 것의 맨 앞에 곱해준다.
즉, dZ/dX = dZ/dY * dY/dX
(앞의것이 나온다는 것은 Y=AX 에서 A)
dZ/dY = BT
dZ/dX = BT * dY/dX = AT * BT 인 것이다.
예시를 만들어서 테스트 (A=스칼라) 예제 1.
A=BC ,B=DE D=FG
여기서 dA/dG를 구해보자
dA/dG = dA/dB * dB/dD * dD/dG
니까 앞에서부터 차근차근
dA/dB = CT
dA/dB * dB/dD = CT @ ET ( 뒤에것이 나와서)
dA/dB * dB/dD * dD/dG = FT@CT@ET ( 앞에것이 나와서)
다른 예시(E와 D 위치를 바꿈) 예제 2
A=BC ,B=ED D=FG
여기서 dA/dG를 구해보자
dA/dG = dA/dB * dB/dD * dD/dG
니까 앞에서부터 차근차근
dA/dB = CT
dA/dB * dB/dD =ET@CT ( 앞에것이 나와서)
dA/dB * dB/dD * dD/dG = FT@ET@CT ( 앞에것이 나와서)
중요한거 하나 dA/dG가 있을때 dG의 1단위 변화하면 dA가 얼마나 변하는지의 행렬이라서 dA/dG의 사이즈는 G행렬 크기와 같아야한다.
G의 요소마다 A스칼라값을 변화시키는 양이 기록되는거니까..
이런 법칙이 나오게 하는 방법은 직접 행렬을 만들어서 계산해 보면 저것과 계산 값이 똑같다.
임의의 행렬로 미분을 계산하고 저 방식으로 계산한 행렬이 똑같게 나온다.
유도하지는 못했지만 같은걸 확인해서 저게 맞는거 같다.
근데 이런 체인룰을 어디서 배우는걸까? 어떤 과목인지 궁금하다 아시는분 있으면 알려주세요 공부하고싶어요.
미적분학이나 선형대수일까,,
법칙을 찾는데 소요된 시간 5개월 나처럼 길 잃은 자들에게 도움이 됐으면 좋겠다..
2시간 정도 더 고민해보니까 알겠다
전개해서 L에 대해 미분하고 싶은 행렬의 원소 계수를 써주는 걸로 구한거랑 체인룰 규칙을 적용한거랑 식이 같다는 걸 손으로 쓰면서 풀다가 알았다.
미분하고 싶은 행렬만 빼고 양 옆의 행렬을 한 뭉텅이씩으로 보고 T를 취해주면 그게 gradient가 나온다
예제 1번에서 A=BC, B=DE, D=FG -> A=DEC, D=FG -> A = FGEC 로 쓸 수 있는데 dA/dG = FT @ (E@C)T = FT@CT@ET 니까 나온다..
예제 2번도 마찬가지로 A = EFGC로 쓸 수 있는데 dA/dG = (E@F)T @ CT = FT @ ET @ CT
다른 예로 또 들면
A= BCDEF 일경우
dA/dD = (B@C)T@(E@F)T = CT@BT@FT@ET
결국 스칼라로만드는 연산도 선형변환으로 취급해서 한줄로 표현하게 해서 저런식으로 하면 되는구나..
그래도 한줄로 표현하지 않아도 체인룰떄 쓸 수 있는 규칙을 알아내고 어느정도 이해했으니까 감사하다
'일상' 카테고리의 다른 글
다이어그램 그리는 사이트 (0) | 2024.06.05 |
---|---|
무료로 고퀄리티 chatbot 쓰는 방법 (0) | 2024.05.23 |
파이토치로 행렬 공부하기 (0) | 2024.05.22 |
파이토치 함수 알아보기 (0) | 2024.05.21 |
인공지능 공부하기 (0) | 2024.05.21 |