长短期记忆 (LSTM) 网络是一种广泛用于序列预测任务的递归神经网络 (RNN)。在 PyTorch 中,nn.LSTM 模块是实现这些网络的强大工具。然而,了解 LSTM 的 “隐藏” 和 “输出” 状态之间的区别可能会让许多人感到困惑。本文旨在阐明这些概念,提供详细的解释和示例,以帮助您了解 LSTM 在 PyTorch 中的工作原理。


目录

PyTorch LSTM:隐藏状态与输出
隐藏状态和输出之间的差异
示例代码:访问 Hidden 状态和输出:
在 Hidden State 和 Output 之间进行选择
性能影响

PyTorch LSTM:隐藏状态与输出


1. 隐藏状态 (h_n)


LSTM 中的隐藏状态表示网络的短期记忆。它包含有关到目前为止已处理的序列的信息,并在每个时间步长更新。隐藏状态对于跨时间步长和层维护信息至关重要。

形状:h_n的隐藏状态具有形状 (num_layers * num_directions、batch hidden_size)。此形状表示 LSTM 中的每个层和方向都保持隐藏状态。

2. 输出 (output)

LSTM 的输出是每个时间步的最后一层的隐藏状态序列。与 hidden 状态不同,hidden 状态只是每个序列的最后一个隐藏状态,而 output 包括 sequence 中每个时间步的 hidden 状态。

形状:输出的形状为 (seq_len, batch, num_directions * hidden_size),其中 seq_len 是输入序列的长度。

隐藏状态和输出之间的差异

  • 范围:隐藏状态 (h_n) 是批处理中每个元素的最终隐藏状态,而输出包含序列中所有时间步的隐藏状态。
  • 用法:隐藏状态通常用于需要整个序列摘要的任务,例如分类,而输出用于需要在每个时间步进行预测的任务,例如序列生成。

示例代码:访问 Hidden 状态和输出:

以下是如何在 PyTorch 中实现 LSTM 并访问隐藏状态和输出的示例:

Output shape: torch.Size([5, 3, 20])

Hidden state shape: torch.Size([2, 3, 20])

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐