torch.nn.RNN(input_size, hidden_size, num_layers)

pytorch官方文档链接:https://pytorch.org/docs/stable/generated/torch.nn.RNN.html#torch.nn.RNN

input_size:每个token作为输入时的向量长度
hidden_size:中间的隐层向量长度
num_layers:RNN模型的层数

以下对于batch_size=1举例
rnn = nn.RNN(10, 20, 2)
input = torch.randn(3, 10)
h0 = torch.randn(2, 20)
output, hn = rnn(input, h0)
# output.shape应该是(3,20);hn.shape应该是(2,20)

计算过程可根据下图理解。
在这里插入图片描述
官方文档中计算h_t的公式可根据手绘图中的“框1”理解。
在这里插入图片描述

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐