当前位置: 代码迷 >> 综合 >> RNN 训练算法 —— 反向传播(Backpropagation Through Time)
  详细解决方案

RNN 训练算法 —— 反向传播(Backpropagation Through Time)

热度:29   发布时间:2024-03-07 15:30:35.0

在这里插入图片描述

参见基本框架:https://goodgoodstudy.blog.csdn.net/article/details/109245095

问题描述

考虑模型循环网络模型:
x(k)=f[Wx(k?1)](1)x(k) = f[Wx(k-1)] \tag1{} x(k)=f[Wx(k?1)](1)
其中 x(k)∈RNx(k) \in R^Nx(k)RN表示网络节点状态,W∈RN×NW\in R^{N\times N}WRN×N表示网络结点之间相互连接的权重,网络的输出节点为 {xi(k)∣i∈O}\{x_i(k)| i\in O\}{ xi?(k)iO}OOO为所有输出(或称“观测”)单元的下标集合

在这里插入图片描述
训练的目标是为了减少观测状态和预期值之间误差,即最小化损失函数:
E=12∑k=1K∑i∈O[xi(k)?di(k)]2(2)E = \frac{1}{2}\sum_{k=1}^K \sum_{i\in O} [x_i(k) - d_i(k)]^2 \tag{2} E=21?k=1K?iO?[xi?(k)?di?(k)]2(2)
其中 di(k)d_i(k)di?(k) 表示 kkk 时刻第 iii 个节点的预期值

采用梯度下降法更新 WWW:
W+=W?ηdEdWW_+ = W - \eta \frac{dE}{dW} W+?=W?ηdWdE?

符号约定

W≡[—–w1T—–?—–wNT—–]N×NW \equiv \begin{bmatrix} \text{-----} w_1^T \text{-----} \\ \vdots \\ \text{-----} w_N^T \text{-----} \end{bmatrix}_{N\times N} W????—–w1T?—–?—–wNT?—–?????N×N?
将矩阵 WWW 拉成列向量,记为 www
w=[w1T,?,wNT]T∈RN2w = [w_1^T, \cdots, w_N^T]^T \in R^{N^2} w=[w1T?,?,wNT?]TRN2
把所有时间的状态拼成列向量,记为 xxx
x=[xT(1),?,xT(K)]T∈RNKx = [x^T(1), \cdots, x^T(K)]^T \in R^{NK} x=[xT(1),?,xT(K)]TRNK
将RNN 的训练视为约束优化问题,(1)式转化成约束条件:
g(k)≡f[Wx(k?1)]?x(k)=0,k=1,…,K(3)g(k) \equiv f[Wx(k-1)] - x(k) =0, \quad k=1,\ldots ,K \tag{3} g(k)f[Wx(k?1)]?x(k)=0,k=1,,K(3)

g=[gT(1),…,gT(K)]T∈RNKg = [g^T(1), \ldots, g^T(K)]^T \in R^{NK} g=[gT(1),,gT(K)]TRNK


0=dg(x(w),w)dw=?g(x(w),w)?x?x(w)?w+?g(x(w),w)?w(4)0 = \frac{dg(x(w),w)}{dw} = \frac{\partial g(x(w),w)}{\partial x}\frac{\partial x(w)}{\partial w} + \frac{\partial g(x(w),w)}{\partial w} \tag{4} 0=dwdg(x(w),w)?=?x?g(x(w),w)??w?x(w)?+?w?g(x(w),w)?(4)
dEdw=?E?x(?g?x)?1?g?w(5)\frac{dE}{dw} = \frac{\partial E}{\partial x} \left(\frac{\partial g}{\partial x}\right)^{-1} \frac{\partial g}{\partial w} \tag{5} dwdE?=?x?E?(?x?g?)?1?w?g?(5)
(5)中三项如下:
1.
?E?x=[e(1),…,e(K)]ei(k)={xi(k)?di(k),if i∈O,0,otherwise. k∈1,…,K.\begin{aligned} \frac{\partial E}{\partial x} &= [e(1), \ldots, e(K)] \\\\ e_i(k)&= \begin{cases} x_i(k) - d_i(k), &\text{if } i\in O, \\ 0, &\text{otherwise. } \end{cases} k \in 1,\ldots,K. \end{aligned} ?x?E?ei?(k)?=[e(1),,e(K)]={ xi?(k)?di?(k),0,?if iO,otherwise. ?k1,,K.?
2.
?g?x=[?I00…0D(1)W?I0…00D(2)W?I…0?????000D(K?1)W?I]NK×NK\frac{\partial g}{\partial x} = \begin{bmatrix} -I & 0& 0 &\ldots & 0\\ D(1)W & -I & 0 &\ldots & 0 \\ 0 & D(2)W & -I & \ldots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & D(K-1)W& -I \end{bmatrix}_{NK\times NK} ?x?g?=?????????ID(1)W0?0?0?ID(2)W?0?00?I?0??D(K?1)W?000??I?????????NK×NK?
其中
D(j)=[f′(w1Tx(j))0?0f′(wNTx(j))]D(j)= \begin{bmatrix} f'(w_1^Tx(j)) & &0\\ & \ddots & \\ 0& & f'(w_N^Tx(j)) \end{bmatrix} D(j)=???f(w1T?x(j))0???0f(wNT?x(j))????

?g?w=[D(0)X(0)D(1)X(1)?D(K?1)X(K?1)]\frac{\partial g}{\partial w} = \begin{bmatrix} D(0)X(0)\\ D(1)X(1) \\ \vdots \\ D(K-1)X(K-1) \end{bmatrix} ?w?g?=??????D(0)X(0)D(1)X(1)?D(K?1)X(K?1)???????
其中
X(k)?[xT(k)xT(k)?xT(k)]N×N2X(k) \triangleq \begin{bmatrix} x^T(k) &&& \\ & x^T(k)&& \\ && \ddots & \\ &&& x^T(k) \end{bmatrix}_{N\times N^2} X(k)??????xT(k)?xT(k)???xT(k)??????N×N2?
在这里插入图片描述

反向传播


δ=?E?x(?g?x)?1∈R1×NK(6)\delta = \frac{\partial E}{\partial x} \left(\frac{\partial g}{\partial x}\right)^{-1} \in R^{1\times NK}\tag{6} δ=?x?E?(?x?g?)?1R1×NK(6)
然后计算
dEdw=?δ?g?w\frac{dE}{dw} =- \delta \frac{\partial g}{\partial w} dwdE?=?δ?w?g?


(6)式变形为:
δ?g?x=?E?x\delta \frac{\partial g}{\partial x} = \frac{\partial E}{\partial x} δ?x?g?=?x?E?

δ=[δ(1)…δ(K)]1×NK,δ(k)∈R1×N\delta = \begin{bmatrix} \delta(1) & \ldots &\delta(K) \end{bmatrix}_{1\times NK}, \quad \delta(k) \in R^{1\times N} δ=[δ(1)??δ(K)?]1×NK?,δ(k)R1×N
则有
[δ(1)…δ(K)][?I00…0D(1)W?I0…00D(2)W?I…0?????000D(K?1)W?I]=[e(1),…,e(K)]\begin{bmatrix} \delta(1) & \ldots &\delta(K) \end{bmatrix} \begin{bmatrix} -I & 0& 0 &\ldots & 0\\ D(1)W & -I & 0 &\ldots & 0 \\ 0 & D(2)W & -I & \ldots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & D(K-1)W& -I \end{bmatrix}= [e(1), \ldots, e(K)] [δ(1)??δ(K)?]?????????ID(1)W0?0?0?ID(2)W?0?00?I?0??D(K?1)W?000??I?????????=[e(1),,e(K)]
解得:
δ(K)=?e(K)δ(k)=δ(k+1)D(k)W?e(k),k=1,…,K?1\begin{aligned} \delta(K) &= - e(K) \\ \delta(k) &= \delta(k+1)D(k)W - e(k), \\ k&=1,\ldots,K-1 \end{aligned} δ(K)δ(k)k?=?e(K)=δ(k+1)D(k)W?e(k),=1,,K?1?
所以
dEdw=?δ?g?w=?[δ(1)…δ(K)][D(0)X(0)D(1)X(1)?D(K?1)X(K?1)]=?∑k=1Kδ(k)D(k?1)X(k?1)\begin{aligned} \frac{dE}{dw} &= - \delta \frac{\partial g}{\partial w} \\ &= - \begin{bmatrix} \delta(1) & \ldots &\delta(K) \end{bmatrix}\begin{bmatrix} D(0)X(0)\\ D(1)X(1) \\ \vdots \\ D(K-1)X(K-1) \end{bmatrix}\\ &= -\sum_{k=1}^K \delta(k)D(k-1)X(k-1) \end{aligned} dwdE??=?δ?w?g?=?[δ(1)??δ(K)?]??????D(0)X(0)D(1)X(1)?D(K?1)X(K?1)???????=?k=1K?δ(k)D(k?1)X(k?1)?
其中
δ(k)D(k?1)X(k?1)=[δ1(k)…δN(k)]1×N[f′(w1Tx(k?1))0?0f′(wNTx(k?1))]N×N[xT(k?1)?xT(k?1)]N×N2=[δ1(k)f′(w1Tx(k?1))xT(k?1)…δN(k)f′(wNTx(k?1))xT(k?1)]1×N2\begin{aligned} & \delta(k)D(k-1)X(k-1) \\ &= \begin{bmatrix} \delta_1(k) & \ldots &\delta_N(k) \end{bmatrix}_{1\times N} \begin{bmatrix} f'(w_1^Tx(k-1)) & &0\\ & \ddots & \\ 0& & f'(w_N^Tx(k-1)) \end{bmatrix}_{N\times N} \begin{bmatrix} x^T(k-1) && \\ & \ddots & \\ && x^T(k-1) \end{bmatrix}_{N\times N^2} \\ &= \begin{bmatrix} \delta_1(k) f'(w_1^Tx(k-1))x^T(k-1) & \ldots &\delta_N(k) f'(w_N^Tx(k-1))x^T(k-1) \end{bmatrix}_{1\times N^2} \end{aligned} ?δ(k)D(k?1)X(k?1)=[δ1?(k)??δN?(k)?]1×N????f(w1T?x(k?1))0???0f(wNT?x(k?1))????N×N????xT(k?1)???xT(k?1)????N×N2?=[δ1?(k)f(w1T?x(k?1))xT(k?1)??δN?(k)f(wNT?x(k?1))xT(k?1)?]1×N2??
所以矩阵形式的梯度 dEdW∈RN×N\frac{dE}{dW} \in R^{N\times N}dWdE?RN×N
dEdW=?∑k=1K[δ1(k)f′(w1Tx(k?1))xT(k?1)?δN(k)f′(wNTx(k?1))xT(k?1)]N×N=?∑k=1K[f′(w1Tx(k?1))0?0f′(wNTx(k?1))]N×N[δ1(k)?δN(k)]N×1xT(k?1)=?∑k=1KD(k?1)δT(k)xT(k?1)\begin{aligned} \frac{dE}{dW} &= -\sum_{k=1}^K \begin{bmatrix} \delta_1(k) f'(w_1^Tx(k-1))x^T(k-1) \\ \vdots \\ \delta_N(k) f'(w_N^Tx(k-1))x^T(k-1) \end{bmatrix}_{N\times N} \\ &= -\sum_{k=1}^K \begin{bmatrix} f'(w_1^Tx(k-1)) & &0\\ & \ddots & \\ 0& & f'(w_N^Tx(k-1)) \end{bmatrix}_{N\times N} \begin{bmatrix} \delta_1(k) \\ \vdots \\ \delta_N(k) \end{bmatrix}_{N\times 1} x^T(k-1) \\ &= -\sum_{k=1}^K D(k-1)\delta^T(k)x^T(k-1) \end{aligned} dWdE??=?k=1K?????δ1?(k)f(w1T?x(k?1))xT(k?1)?δN?(k)f(wNT?x(k?1))xT(k?1)?????N×N?=?k=1K????f(w1T?x(k?1))0???0f(wNT?x(k?1))????N×N?????δ1?(k)?δN?(k)?????N×1?xT(k?1)=?k=1K?D(k?1)δT(k)xT(k?1)?
在这里插入图片描述

  相关解决方案