LSTM公式详解&推导
《Surpervised Sequence Labelling with Recurrent Neural Network》(《用循环神经网络进行序列标记》),RNN(Recurrent Neural Network,循环神经网络)经典教材,由多伦多大学Alexander Graves所著,详细叙述了各种RNN模型及其推导。本文介绍该书的LSTM部分。对于该书,想深入了解的朋友点这里获取资源。
LSTM理解
LSTM(Long Short-Term Memory Networks,长短时记忆网络),由Hochreiter和Schmidhuber于1997年提出,目的是解决一般循环神经网络中存在的梯度爆炸(输入信息激活后权重过小)及梯度消失(例如sigmoid、tanh的激活值在输入很大时其梯度趋于零)问题,主要通过引入门和Cell状态的概念来实现梯度的调整,已被大量应用于时间序列预测等深度学习领域。
下面的描述主要侧重公式推导,对LSTM来由更详细的讨论请见《Step-by-step to LSTM: 解析LSTM神经网络设计原理》。
LSTM采用了门控输出的方式,即三门(输入门、遗忘门、输出门)两态(Cell State长时、Hidden State短时)。其核心即Cell State,指用于信息传播的Cell的状态,在结构示意图(图1,图源Understanding LSTMs,略改动)中是最上面的直链(从到)。
图1
和本时刻的输入值,由这两个参数 先进入遗忘门,得到决定要舍弃的信息 (即权重较小的信息)后,再进入输入门,得到决定要更新的信息 (即与上一Cell相比权重较大的信息)以及当前时刻的Cell状态 (候选向量,可理解为中间变量,存储当前 Cell State 信息),最后由这两个门(遗忘门,输入门)的输出值(即 )进行组合(上一Cell状态要遗忘信息的激活值 与 当前时刻Cell状态需要记忆信息的激活值进行叠加,从图中可以更直观得到),得到分别的长时()和短时()信息,最后进行存储操作及对下一个神经元的输入。下图2介绍了LSTM在网络中是如何工作的。
图2
根据图1,可依次得到三个门的形式方程如下(符号与图中保持一致):
遗忘门:输入门:
以及时刻的Cell 状态(长时)方程:
输出门:算法及公式根据上面的描述及图1,首先定义如下符号(符号为方便理解,与书中保持一致):
一些函数:门的激活函数:Cell输入的激活函数:Cell输出的激活函数: 训练模型时的损失函数:Sigmoid激活函数
:tanh激活函数
一些符号:输入层信息的数量:输出层信息的数量:隐层Cell状态的数量(注意这里的Cell与下面的Cell不同,代表短时记忆Cell),指图1中最下面的一条直链,即从到,处理短时记忆:Cell状态信息(长时记忆状态)的数量:总时间数(网络层总数),即
:下标,指一个LSTM单元的遗忘门:下标,指一个LSTM单元的输入门:下标,指一个LSTM单元的输出门:下标,指神经元中某一个记忆元胞(Cell)
:从单元到单元的权重:时刻第个单元的激活值,在时初始化为:时刻第个单元的带权输入,可作抽象定义如下 :时刻记忆元胞 的状态(State),在时初始化为 :时刻第个单元的误差,在时初始化为。一般化的定义为前向传播
由上述的形式方程,很容易得到下面的前向传播公式:
遗忘门。由图1可知,遗忘门的输出依赖三个变量(图1中表示为左下角的两个输入和左上角的一个输入),分别是:上一时刻神经元的短时记忆输出,本时刻神经元的输入以及上一时刻神经元的长时记忆输出Cell状态,乘以权重因子后对层数求和即可得到遗忘门的输入值及激活值如下:输入门。其输出所依赖的变量与遗忘门相同,故同理可得Cell状态。由输入门的时刻的Cell 状态(长时)方程立即可得。一一对应形式方程即可得到表达式如下
重头戏来了!建议不熟悉反向传播的朋友看一下我的另一篇文章nndl学习笔记(二)反向传播公式推导,帮助你快速理解&回顾反向传播。
同样地,为了与前向传播对应,这里也采用五个部分进行证明。反向传播,其目的就是通过计算损失函数关于权重和偏置的偏导数(本例中不对偏置进行分析),从而得到每一个神经元上出现的误差(误差定义为损失函数对神经元输入的偏导数),最后均摊给每个神经元,以此逐步减小误差。因为需要反向传播,所以顺序与前向传播正好相反(从后往前计算)。
关于误差的定义Cell输出的误差(短时记忆)Cell状态的误差(长时记忆):时刻第个单元的误差,在时初始化为。定义为公式推导这些公式的核心,都是根据链式法则求偏导数,需要注意损失函数与哪些变量有关,找准变量,再应用求导法则,即可轻松计算出表达式。
Cell输出(短时记忆)。首先找Cell输出与哪些量有关,从图1可以得知其只与隐层(Cell短时记忆状态)和输出层两个部分的信息有关,再根据误差定义,可以得到:
注意到这里层时间状态取而层取,是为了与前向传播式子的意义保持一致,即:隐层Cell状态前向传播需要前一时刻的隐层Cell状态,而输出只需与本时刻输入的时刻一致即可,而反向传播正好相反(具体可见图1)。
再根据带权输入的一般定义(同上,需要根据情况构造定义式,即:层时刻变化而层时刻保持不变)
代入得到(注意这里有一步化简,去掉求和号,具体原因可见nndl学习笔记(二)反向传播公式推导公式一的推导部分):
输出门。
这里只需用到误差定义式及前向传播的式,最后一步求和是指针对所有神经元输出门激活值误差的叠加。
Cell状态(长时记忆)。最长的一个式子,但是把握好变量之间的关系就可以轻松得出(直接寻找前向传播众多公式中哪个含有变量 ,这样再进行链式法则处理,会更加直观,由于五个式子都含有,故下面第四个等号后的式子有五项)。
推导过程与Cell输出(短时记忆)部分类似,要用到误差的一般定义,并注意到本时刻Cell状态(长时记忆)是由上一时刻遗忘门和输入门的输出共同决定的(反映在图上就是图1中上面直链的加号);在反向传播中,除了需要将Cell状态(长时记忆)的时间取反,还要考虑三个门误差的积累(第二个等号后式子第一项),注意这里计算输出门误差时没有取后一时刻,是因为遗忘门和输入门的误差在前向传播时会传递给下一时刻的带权输入,故反向传播需要后一时刻来计算误差;而输出门误差在本时刻即可计算。反映到方程上为第二个等号后的方程。
Cell输出(短时记忆)。
只需应用前向传播的式,即可得到:
遗忘门。方法同输出门推导,只需应用前向传播的式,可立即得到:
输入门。方法同输出门,只需应用前向传播的式,即可得到:
总结
本文介绍了这本书LSTM部分(第四章)的流程详解及公式推导,其中难免会有些许错误,望大家指出。得到公式后,下一步就是编程实现了,这里可以参考另一篇文章零基础入门深度学习(6) - 长短时记忆网络(LSTM),有非常细致的讲解。第一篇万字长文(其实主要是公式多),如果有用就点个赞吧!
P.S. PPT绘图大法是真的香,对我这种小白十分友好,有兴趣的朋友可以玩玩
版权声明
本文仅代表作者观点,不代表博信信息网立场。