安装

pip下安装:

pip install torch-summary

conda下安装:

conda install torch-summary

 注:torchsummary与torch-summary是两个不同库!后者是前者的升级版,添加更多功能且解决了部分bug,因此推荐使用torch-summary!

使用

通过nn.Module构建一个模型(以一个简单的LSTM为例):

import torch.nn as nn
class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=16, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, 1)
        )

    def forward(self, x):
        x = self.lstm(x)[0][:, -1, :]
        res = self.mlp(x)
        return res

if __name__ == '__main__':
    model = LSTM()
    input = torch.randn(8, 32, 1)
    output = model(input)

使用torch-summary可视化网络: 

import torchsummary

model = LSTM()
torchsummary.summary(model, input_size=(32, 1), batch_size=8)

结果输出:

        图中可看出模型的层次结构以及各层的参数统计:包括LSTM和Sequential层,分别包含1216、145个参数,其中Sequential层的两个Linear层分别包含136、9个参数。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐