<>一、RNN的前向传播结构
t时刻输入: XtX_{t}Xt 、St−1S_{t-1}St−1
t时刻输出: hth_{t}ht
t时刻中间状态: StS_{t}St
上图是一个RNN神经网络的时序展开模型,中间t时刻的网络模型揭示了RNN的结构。可以看到,原始的RNN网络的内部结构非常简单。神经元A在t时刻的状态仅仅是(t-1)时刻神经元状态
St−1S_{t-1}St−1,与(t)时刻网络输入XtX_tXt
的双曲正切函数的值;这个值不仅仅作为该时刻网络的输出,也作为该时刻网络的状态被传入到下一个时刻的网络状态中,这个过程叫做RNN的正向传播(forward
propagation)
传播中的数学公式(含参数)
上图表示为RNN网络的完整的拓扑结构,以及RNN网络中相应的参数情况。我们通过对t
时刻网络的行为进行数学的推导。在如下的内容中,会出现线性状态和激活状态两种表达,线性状态将用∗*∗号进行标注。
t时刻神经元状态 :
St=ϕ(St∗)S_t= {\phi}{(S{_t^*})}St=ϕ(St∗)
St∗=(UXt+WSt−1)S{_t^*}=(UX_t+WS_{t-1})St∗=(UXt+WSt−1)
t时刻的输出状态:
Ot=ψ(Ot∗)O_t=\psi{(O{_t^*})}Ot=ψ(Ot∗)
Ot∗=VStO{_t^*} = VS_tOt∗=VSt
我们该如何得到RNN模型中的U、V、W三个全局共享参数的具体值呢?在之后的RNN逆向传播中可以得出具体的情况。
<>二、BPTT(随时间变化的反向传播算法)
1、 损失函数的选取,在RNN中一般选取交叉熵(Cross Entropy),表达式如下:
Loss=−∑i=0nyilnyi∗Loss = -{\sum_{i=0}^{n}y_ilny_i^*}Loss=−i=0∑nyilnyi∗
上式为交叉熵的标量的形式,yiy_iyi是真实的标签纸,yi∗y_i^*yi∗
是模型给出的预测值,在多维输出值的时,则可以通过累加得出n维损失值。交叉熵在应用于RNN需进行微调:首先,RNN的输出是向量的形式
,没有必要将所有的维度进行累加一起,直接把损失值用向量进行表达即可;其次,由于RNN模型是序列问题,因此其模型损失不能只是一个时刻的损失,应该包含全部N个时刻的损失。
因此RNN模型在t时刻的损失函数如下:
Losst=−[ytln(Ot)+(yt−1)ln(1−Ot)]{Loss}_t = -[y_tln(O_t) + (y_t-1)ln(1-O_t)]Loss
t=−[ytln(Ot)+(yt−1)ln(1−Ot)]
全部N个时刻的损失函数(全局损失)表达为如下形式:
Loss=−∑t=1NLosst=−∑t=1N[ytln(Ot)+(yt−1)ln(1−Ot)]Loss = -{\sum_{t=1}^NLoss_t}=
-{\sum_{t=1}^N[y_tln(O_t) + (y_t-1)ln(1-O_t)]}Loss=−t=1∑NLosst=−t=1∑N[ytln(O
t)+(yt−1)ln(1−Ot)]
2、 softmax函数的求导公式为(下文用ψ表示\psi 表示ψ表示)
ψ′(x)=ψ(x)(1−ψ(x))\psi'(x)=\psi(x)(1-\psi(x))ψ′(x)=ψ(x)(1−ψ(x))
3、 激活函数的求导公式为(选取tanh(x)作为激活函数)
ϕ(x)=tanh(x)\phi(x) = tanh(x)ϕ(x)=tanh(x)
ϕ′(x)=(1−ϕ2(x))\phi'(x)=(1-{\phi^2(x)})ϕ′(x)=(1−ϕ2(x))
4、 BPTT算法
注: 由于RNN模型与时间序列有关,所以使用Back Propagation Through
Time(随时间变化反向传播的算法),但依旧遵循链式求导法则。在损失函数中,虽然RNN的额全局损失是与N个时刻有关的,但下面的推导仅涉及某个t时刻。
(1)求出t时刻下的损失函数关于Ot∗O_t^*Ot∗的微分:
∂Lt∂Ot∗=∂Lt∂Ot∗∂Ot∂Ot∗=∂Lt∂Ot∗∂ψ(Ot∗)∂Ot∗=∂Lt∂Ot∗ψ′(Ot∗)
\frac{\partial{L_t}}{\partial{O_t^*}} =\frac{\partial{L_t}}{\partial{O_t}} *
\frac{\partial{O_t}} {\partial{O_t^*}}=\frac{\partial{L_t}}{\partial{O_t}} *
\frac{\partial{\psi{(O_t^*)}}}
{\partial{O_t^*}}=\frac{\partial{L_t}}{\partial{O_t}} * \psi'(O_t^*)∂Ot∗∂Lt=∂
Ot∂Lt∗∂Ot∗∂Ot=∂Ot∂Lt∗∂Ot∗∂ψ(Ot∗)=∂Ot∂Lt∗ψ′(Ot∗)
(2)求出损失函数关于参数V的微分(需要(1)中的结论):
∂Lt∂V=∂Lt∂(VSt)∗∂(VSt)∂V=∂Lt∂Ot∗∗St=∂Lt∂Ot∗ψ′(Ot∗)∗St
\frac{\partial{L_t}}{\partial{V}} = \frac{\partial{L_t}}{\partial{(VS_t)}} *
\frac{\partial{(VS_t)}} {\partial{V}}=\frac{\partial{L_t}}{\partial{O_t^*}} *
S_t=\frac{\partial{L_t}}{\partial{O_t}} * \psi'(O_t^*)* S_t∂V∂Lt=∂(VSt)∂Lt∗
∂V∂(VSt)=∂Ot∗∂Lt∗St=∂Ot∂Lt∗ψ′(Ot∗)∗St
因此,全局关于参数V的微分为:
∂L∂V=∑t=1N∂Lt∂V=∑t=1N∂Lt∂Ot∗ψ′(Ot∗)∗St
\frac{\partial{L}}{\partial{V}}={\sum_{t=1}^{N}}\frac{\partial{L_t}}{\partial{V}}={\sum_{t=1}^{N}}\frac{\partial{L_t}}{\partial{O_t}}
* \psi'(O_t^*)* S_t∂V∂L=t=1∑N∂V∂Lt=t=1∑N∂Ot∂Lt∗ψ′(Ot∗)∗St
(3)求出t时刻的损失函数关于St∗S_t^*St∗的微分:
∂Lt∂St∗=∂Lt∂(VSt)∗∂(VSt)∂St∗∂St∂St∗=∂Lt∂Ot∗∗V∗ϕ′(St∗)=∂Lt∂Ot∗ψ′(Ot∗)∗V∗ϕ′(St∗)
\frac{\partial{L_t}}{\partial{S_t^*}} = \frac{\partial{L_t}}{\partial{(VS_t)}}
* \frac{\partial{(VS_t)}} {\partial{S_t}} * \frac{\partial{S_t}}
{\partial{S_t^*}}=\frac{\partial{L_t}}{\partial{O_t^*}}*V*\phi'(S_t^*)=\frac{\partial{L_t}}{\partial{O_t}}*\psi'(O_t^*)*V*\phi'(S_t^*)
∂St∗∂Lt=∂(VSt)∂Lt∗∂St∂(VSt)∗∂St∗∂St=∂Ot∗∂Lt∗V∗ϕ′(St∗)=∂Ot∂Lt∗
ψ′(Ot∗)∗V∗ϕ′(St∗)
(4)求出t时刻的损失函数关于St−1S_{t-1}St−1的微分
∂Lt∂St−1∗=∂Lt∂St∗∗∂St∗∂St−1∗=∂Lt∂St∗∗∂[Wϕ(St−1∗)+UXt]∂St−1∗=∂Lt∂St∗∗Wϕ′(St−1∗)
\frac{\partial{L_t}}{\partial{S_{t-1}^*}}=\frac{\partial{L_t}}{\partial{S_t^*}}
*\frac{\partial{S_t^*}}{\partial{S_{t-1}^*}}=
\frac{\partial{L_t}}{\partial{S_t^*}}
*\frac{\partial{[W\phi(S_{t-1}^*)}+UX_t]}{\partial{S_{t-1}^*}} =
\frac{\partial{L_t}}{\partial{S_t^*}} *W\phi'(S_{t-1}^*)∂St−1∗∂Lt=∂St∗∂Lt∗
∂St−1∗∂St∗=∂St∗∂Lt∗∂St−1∗∂[Wϕ(St−1∗)+UXt]=∂St∗∂Lt∗Wϕ′(St−1∗)
(5)求出t时刻关于参数U的偏微分
注:因为是时间序列模型,因此t时刻关于U
的微分与前(t-1)个时刻都相关,在具体计算时可以限定最远回溯到前n个时刻,但在推导时需将(t-1)个时刻全部代入计算
∂Lt∂U=∑k=1t∂Lt∂Sk∗∂Sk∗∂U=∑k=1t∂Lt∂Sk∗∂(WSk−1+UXk)∂U=∑k=1t∂Lt∂Sk∗∗Xk
\frac{\partial L_t}{\partial U}=\sum_{k=1}^{t}\frac{\partial L_t}{\partial
S_k^*}\frac{\partial S_k^*}{\partial U}=\sum_{k=1}^{t}\frac{\partial
L_t}{\partial S_k^*}\frac{\partial ({WS_{k-1}}+UX_k)}{\partial
U}=\sum_{k=1}^{t}\frac{\partial L_t}{\partial S_k^*}*X_k∂U∂Lt=k=1∑t∂Sk∗∂Lt
∂U∂Sk∗=k=1∑t∂Sk∗∂Lt∂U∂(WSk−1+UXk)=k=1∑t∂Sk∗∂Lt∗Xk
因此,全局关于U的损失偏微分为:
∂L∂U=∑t=1N∂Lt∂U=∑t=1N∑k=1t∂Lt∂Sk∗∂Sk∗∂U=∑t=1N∑k=1t∂Lt∂Sk∗∗Xk\frac{\partial
L}{\partial U}=\sum_{t=1}^{N}\frac{\partial L_t}{\partial
U}=\sum_{t=1}^{N}\sum_{k=1}^{t}\frac{\partial L_t}{\partial
S_k^*}\frac{\partial S_k^*}{\partial
U}=\sum_{t=1}^{N}\sum_{k=1}^{t}\frac{\partial L_t}{\partial S_k^*}*X_k∂U∂L=t=1∑
N∂U∂Lt=t=1∑Nk=1∑t∂Sk∗∂Lt∂U∂Sk∗=t=1∑Nk=1∑t∂Sk∗∂Lt∗Xk
(6)求出t时刻关于参数W的偏微分(同上)
∂Lt∂W=∑k=1t∂Lt∂Sk∗∂Sk∗∂W=∑k=1t∂Lt∂Sk∗∂(WSk−1+UXk)∂W=∑k=1t∂Lt∂Sk∗∗Sk−1
\frac{\partial L_t}{\partial W}=\sum_{k=1}^{t}\frac{\partial L_t}{\partial
S_k^*}\frac{\partial S_k^*}{\partial W}=\sum_{k=1}^{t}\frac{\partial
L_t}{\partial S_k^*}\frac{\partial ({WS_{k-1}}+UX_k)}{\partial
W}=\sum_{k=1}^{t}\frac{\partial L_t}{\partial S_k^*}*S_{k-1}∂W∂Lt=k=1∑t∂Sk∗∂
Lt∂W∂Sk∗=k=1∑t∂Sk∗∂Lt∂W∂(WSk−1+UXk)=k=1∑t∂Sk∗∂Lt∗Sk−1
因此,全局关于U的损失偏微分为:
∂L∂W=∑t=1N∂Lt∂W=∑t=1N∑k=1t∂Lt∂Sk∗∂Sk∗∂W=∑t=1N∑k=1t∂Lt∂Sk∗∗Sk−1\frac{\partial
L}{\partial W}=\sum_{t=1}^{N}\frac{\partial L_t}{\partial
W}=\sum_{t=1}^{N}\sum_{k=1}^{t}\frac{\partial L_t}{\partial
S_k^*}\frac{\partial S_k^*}{\partial
W}=\sum_{t=1}^{N}\sum_{k=1}^{t}\frac{\partial L_t}{\partial S_k^*}*S_{k-1}∂W∂L=
t=1∑N∂W∂Lt=t=1∑Nk=1∑t∂Sk∗∂Lt∂W∂Sk∗=t=1∑Nk=1∑t∂Sk∗∂Lt∗Sk−1
(7)由于大多数的输出为softmax函数,我们在对Ot∗O_t^*Ot∗进行softmax运算后求导可得
ψ′(Ot∗)=Ot(1−Ot)\psi'(O_t^*)=O_t(1-O_t)ψ′(Ot∗)=Ot(1−Ot)
所以在OtO_tOt进行微分求偏导可得(采用交叉熵作为损失函数)
∂Lt∂Ot=−∂[∑t=1N[ytln(Ot)+(yt−1)ln(1−Ot)]∂Ot=−(ytOt+yt−Ot1−Ot)=−yt−OtOt(1−Ot)
\frac{\partial L_t }{\partial O_t}=\frac{-\partial [\sum_{t=1}^N[y_tln(O_t) +
(y_t-1)ln(1-O_t)]}{\partial O_t}=-(\frac
{y_t}{O_t}+\frac{y_t-O_t}{1-O_t})=-{\frac{y_t-O_t}{O_t(1-O_t)}}∂Ot∂Lt=∂Ot−∂[
∑t=1N[ytln(Ot)+(yt−1)ln(1−Ot)]=−(Otyt+1−Otyt−Ot)=−Ot(1−Ot)yt−Ot
∂Lt∂Ot∗ψ′(Ot∗)=−yt−OtOt(1−Ot)∗Ot(1−Ot)=Ot−yt\frac{\partial L_t }{\partial
O_t}*\psi'(O_t^*)=-{\frac{y_t-O_t}{O_t(1-O_t)}}*O_t(1-O_t)=O_t-y_t∂Ot∂Lt∗ψ′(O
t∗)=−Ot(1−Ot)yt−Ot∗Ot(1−Ot)=Ot−yt
∂Lt∂St∗=∂Lt∂Ot∗ψ′(Ot∗)∗V∗ϕ′(St∗)=[V∗(Ot−yt)]∗[1−ϕ2(st∗)]=[V∗(Ot−yt)]∗[1−St2]
\frac{\partial{L_t}}{\partial{S_t^*}}
=\frac{\partial{L_t}}{\partial{O_t}}*\psi'(O_t^*)*V*\phi'(S_t^*)=
[V*(O_t-y_t)]*[1-{\phi^2(s_t^*)}]= [V*(O_t-y_t)]*[1-S_t^2]∂St∗∂Lt=∂Ot∂Lt∗ψ
′(Ot∗)∗V∗ϕ′(St∗)=[V∗(Ot−yt)]∗[1−ϕ2(st∗)]=[V∗(Ot−yt)]∗[1−St2]
∂Lt∂St−1∗=∂Lt∂St∗∗Wϕ′(St−1∗)=∂Lt∂St∗∗W∗[1−St−12]
\frac{\partial{L_t}}{\partial{S_{t-1}^*}}=
\frac{\partial{L_t}}{\partial{S_t^*}} *W\phi'(S_{t-1}^*)=
\frac{\partial{L_t}}{\partial{S_t^*}}*W*[1-S_{t-1}^2]∂St−1∗∂Lt=∂St∗∂Lt∗Wϕ′
(St−1∗)=∂St∗∂Lt∗W∗[1−St−12]
综上:
∂L∂V=∑t=1N∂Lt∂V=∑t=1N(Ot−yt)∗St
\frac{\partial{L}}{\partial{V}}={\sum_{t=1}^{N}}\frac{\partial{L_t}}{\partial{V}}={\sum_{t=1}^{N}}(O_t-y_t)
*S_t∂V∂L=t=1∑N∂V∂Lt=t=1∑N(Ot−yt)∗St
其余得类似
(8)我们逐步更新V,U,W三者得参数,直至它们收敛为之
V:=V−η∗∂L∂VV:=V-\eta*\frac{\partial L}{\partial V}V:=V−η∗∂V∂L
U:=U−η∗∂L∂UU:=U-\eta*\frac{\partial L}{\partial U}U:=U−η∗∂U∂L
W:=W−η∗∂L∂WW:=W-\eta*\frac{\partial L}{\partial W}W:=W−η∗∂W∂L