如何计算 LSTM 的参数量
理论上的参数量之前翻译了 Christopher Olah 的那篇著名的 Understanding LSTM Networks,这篇文章对于整体理解 LSTM 很有帮助,但是在理解 LSTM 的参数数量这种细节方面,略有不足。本文就来补充一下,讲讲如何计算 LSTM 的参数数量。建议阅读本文前先阅读 Understanding LSTM Networks 的原文或我的译文。首先来回顾下 LSTM
理论上的参数量
之前翻译了 Christopher Olah 的那篇著名的 Understanding LSTM Networks,这篇文章对于整体理解 LSTM 很有帮助,但是在理解 LSTM 的参数数量这种细节方面,略有不足。本文就来补充一下,讲讲如何计算 LSTM 的参数数量。
首先来回顾下 LSTM。一层 LSTM 如下:
这里的 x t x_t xt 实际上是一个句子的 embedding(不考虑 batch 维度),shape 一般为 [seq_length, embedding_size]。图中的 A A A 就是 cell, x t x_t xt 中的词依次进入这个 cell 中进行处理。可以看到其实只有这么一个 cell,所以每次词进去处理的时候,权重是共享的,将这个过程平铺展开,就是下面这张图了:
实际上我觉得这里 x t x_t xt 并不准确,第一个 x t x_t xt 应该指的是整句话,而第二个 x t x_t xt 应该指的是这句话中最后一个词,所以为了避免歧义,我认为可以将第一个 x t x_t xt 重命名为 x x x,第二个仍然保留,即现在 x x x 表示一句话,该句话有 t + 1 t+1 t+1 个词, x t x_t xt 表示该句话的第 t + 1 t+1 t+1 个词, t ∈ [ 0 , t ] t \in [0, t] t∈[0,t]。
始终要记住这么多 A A A 都是一样的,权重是一样的, x 0 x_0 x0 到 x t x_t xt 是一个个词,每一次的处理都依赖前一个词的处理结果,这也是 RNN 系的网络难以像 CNN 一样并行加速的原因。同时, 这就像一个递归过程,如果把求 h t h_t ht 的公式展开写, A A A 里的权重记为 W W W,那么就会发现需要 t t t 个 W W W 相乘,即 W t W^t Wt,这是非常恐怖的:
0. 9 100 = 2.6561398887587544 × 1 0 − 5 0.9^{100} = 2.6561398887587544 \times 10^{-5} 0.9100=2.6561398887587544×10−5
1. 1 100 = 13780.61233982238 1.1^{100} = 13780.61233982238 1.1100=13780.61233982238
一个不那么小的数被多次相乘之后会变得很小,一个不那么大的数被多次相乘之后会变得很大。所以,这也是普通 RNN 容易出现梯度消失/爆炸的问题的原因。
扯远了点。
那么 LSTM 的参数很明显了,就是这个 A A A 中的参数。这个 A A A 内部具体是这样的:
从这张图来理解参数的数量你可能有点懵逼,一步一步来看,实际上这里面有 4 个非线性变换(3 个 门 + 1 个 tanh),每一个非线性变换说白了就是一个两层的全连接网络。重点来了,第一层是 x i x_i xi 和 h i h_i hi 的结合,维度就是 embedding_size + hidden_size,第二层就是输出层,维度为 hidden_size,所以该网络的参数量就是:
(embedding_size + hidden_size) * hidden_size + hidden_size
一个 cell 有 4 个这样结构相同的网络,那么一个 cell 的总参数量就是直接 × 4:
((embedding_size + hidden_size) * hidden_size + hidden_size) * 4
注意这 4 个权重可不是共享的,都是独立的网络。
所以,一般来说,一层 LSTM 的参数量计算公式是:
4 [ d h ( d h + d x ) + d h ] 4[d_h(d_h + d_x) + d_h] 4[dh(dh+dx)+dh]
其中 4 表示有 4 个非线性映射层, d _ h + d _ x d\_h + d\_x d_h+d_x 即 Understanding LSTM Networks 中的 [ h _ t − 1 , x _ t ] [h\_{t-1}, x\_t] [h_t−1,x_t] 的维度,后面的 d h d_h dh 表示 bias 的数量。所以,LSTM 层的参数数量只与输入维度 d _ x d\_x d_x 和输出维度 d _ h d\_h d_h 相关,和普通全连接层相同。
那么显而易见,一层双向 LSTM 的参数量就是上述公式 × 2。
TensorFlow 中的实现
在 TensorFlow 中,这些 d x d_x dx、 d h d_h dh 如何与代码对应上呢?
我们可以如下实现一个简单的以 LSTM 为核心的网络:
import tensorflow as tf
model = tf.keras.model.Sequential(
tf.keras.layers.Embedding(1000, 128),
tf.keras.layers.LSTM(units=64),
tf.keras.layers.Dense(10)
)
model.summary()
输入如下:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding_1 (Embedding) (None, None, 128) 128000
_________________________________________________________________
lstm_1 (LSTM) (None, 64) 49408
_________________________________________________________________
dense_1 (Dense) (None, 10) 650
=================================================================
Total params: 178,058
Trainable params: 178,058
Non-trainable params: 0
_________________________________________________________________
可以看到 TF 给出的 LSTM 层参数量是 49408。我们来根据上面的公式验证下。
- d x d_x dx:输入维度,在这里就对应于 128,就是词向量维度。
- d h d_h dh:输出维度,在这里就是 LSTM 的参数 64,在 TF 这里叫
units。
所以,参数量就是 4 × [ 64 × ( 64 + 128 ) + 64 ] = 49408 4 \times \left[64 \times \left(64+128\right) + 64 \right] = 49408 4×[64×(64+128)+64]=49408,和 TF 给出的一样。
另外,tf.keras.layers.LSTM() 的默认输出大小为 [batch_size, units],就是只使用最后一个 time step 的输出。假如我们想要得到每个 time step 的输出( h 0 , ⋯ , h t h_0,\cdots,h_t h0,⋯,ht)和最终的 cell state( C t C_t Ct),那么我们可以指定另外两个参数 return_sequences=True 和 return_state=True:
inputs = tf.random.normal([64, 100, 128]) # [batch_size, seq_length, embedding_size]
whole_seq_output, final_memory_state, final_carry_state = tf.keras.layers.LSTM(64, return_sequences=True, return_state=True)(inputs)
print(f"{whole_seq_output.shape=}")
print(f"{final_memory_state.shape=}")
print(f"{final_carry_state.shape=}")
输出:
whole_seq_output.shape=TensorShape([32, 100, 64]) # 100 表示有 100 个词,即 100 个 time step
final_memory_state.shape=TensorShape([32, 64])
final_carry_state.shape=TensorShape([32, 64])
OK,LSTM 的参数量应该挺清晰了,欢迎在评论区留下你的想法。😋
Reference
- Counting No. of Parameters in Deep Learning Models by Hand
- deep learning - Number of parameters in an LSTM model - Data Science Stack Exchange
- machine learning - How to calculate the number of parameters of an LSTM network? - Stack Overflow
- tensorflow - In Keras, what exactly am I configuring when I create a stateful
LSTMlayer with Nunits? - Stack Overflow - 理解 LSTM 网络 · Alan Lee
- Recurrent Neural Networks (RNN) with Keras | TensorFlow Core
- LSTM is dead. Long Live Transformers! - YouTube
END
更多推荐


所有评论(0)