(三)长短期记忆-LSTM为什么比RNN更好

LSTM 是在 1997 年被提出来的。

LSTM 结构

LSTM 的结构比 RNN 要复杂,其中包含 4 个参数矩阵(相比于 RNN 只有一个参数矩阵),可从训练数据中反向传播而得到更新学习。

传送带

LSTM 中有一个传送带,可以将过去的信息 Ct1C_{t-1} 直接传递给未来的 CtC_t

四个参数矩阵

遗忘门

遗忘门(forget gate)WfW_f 决定什么数据能通过,什么不能通过,或者通过百分之多少,σ\sigma 是 Sigmoid function

遗忘门计算出来的 ftf_tc\bold c 向量逐元素相乘,得到结果,在传送带上面继续往下走。

输入门

输入门(input gate)WiW_i 决定传送带上要更新的值

ht1\bold h_{t-1}xt\bold x_t 拼接在一起,点乘参数矩阵 Wi\bold W_i ,其结果通过 sigmoid\text sigmoid 函数得到输入门结果 it\bold i_t

新值

新值(new value)c~t\tilde c _ t 代表要加入传送带上的新值

ht1\bold h_{t-1}xt\bold x_t 拼接在一起,点乘参数矩阵 Wc\bold W_c ,其结果通过 tanh\tanh 函数得到输入门结果 c~t\tilde c_t

然后 **遗忘门的结果 **和

输入门新值逐元素相乘的结果)相加得到新的结果作为 ct\bold c_t,送到传送带上。

输出门

输出门(output gate)WoW_o 决定如何从 ct1\bold c_{t-1} 得到 ht\bold h_t

ht1\bold h_{t-1}xt\bold x_t 拼接在一起,点乘参数矩阵 Wo\bold W_o ,其结果通过 sigmoid\text sigmoid 函数得到输入门结果 ot\bold o_t

ot\bold o_t 和刚刚计算出来放在传送带上的 ct\bold c_t (经过双曲正切函数激活)逐元素相乘得到最终的 ht\bold h_t

ht\bold h_t 分两份,一份作为本单元(unit)的输出,一份传递到下一个 LSTM 单元(unit)。

LSTM 参数

4 个参数矩阵,每个的大小都和一个 SimpleRNN 一样。所以参数数量是 RNN 的 4 倍。

  • #rows: shape(h)

  • #cols: shape(h) + shape(x)

  • total #parameters: 4 × shape(h) × [shape(h)+shape(x)]

LSTM 实践

在使用的时候,通过指定 return_sequences=False 可以只用最后一个状态 hth_t,设置 stateful=True可以使下一个 batch 训练时使用上一个 batch 训练的 ht\bold h_t 结果作为此次 batch 的输出 h0\bold h_0

model = tf.keras.Sequential([
tf.keras.layers.Embedding(vocab_size, embedding_size, input_length=seq_length),
tf.keras.layers.LSTM(state_dim, return_sequences=False, stateful=True),
tf.keras.layers.Dense(1, activation="sigmoid")
], name="LSTM")
model.summary()
Model: "LSTM"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) (None, 25, 512) 5120000
_________________________________________________________________
lstm (LSTM) (None, 1024) 6295552
_________________________________________________________________
dense (Dense) (None, 1) 1025
=================================================================
Total params: 11,416,577
Trainable params: 11,416,577
Non-trainable params: 0
_________________________________________________________________

当然,LSTM 也可以做成双向的(Bi-directional LSTM),因为 LSTM 也会有遗忘问题,也可以把 LSTM 叠起来(Stacked LSTM)使用,增加参数矩阵的数量。


   转载规则


《(三)长短期记忆-LSTM为什么比RNN更好》 Harbor Zeng 采用 知识共享署名 4.0 国际许可协议 进行许可。
 上一篇
(四)使用LSTM做文本生成 (四)使用LSTM做文本生成
这里使用一部 90 多万字小说《琉璃美人煞》为例,使用 LSTM 方法做一次文本生成。 从一句预测下一句 Input data: '玑庸懒外表下的,是一颗琉璃般清澈冰冷的心,前世种种因果,让她今世不懂情感。对修仙'Target data: '庸懒外表下的,是一颗琉璃般清澈冰冷的心,前世种种因果,让她今世不懂情感。对修仙的'Input data: &#x
2021-01-15
下一篇 
(二)SimpleRNN 更适合时序数据的模型 (二)SimpleRNN 更适合时序数据的模型
为什么要 RNN 全连接的逻辑回归有什么局限性? 将整段文字一起处理(one to one) 输入输出是固定的形状 RNN(Recurrent Neural Networks 循环神经网络)更适合序列数据(many to one)。 RNN 内部详解 循环神经网络,顾名思义,单词一个一个的进行训练,x0\bold x_0x0​ 和初始 0 向量拼接在一起,与 A\bold AA 矩阵,
2021-01-14
  目录