示例:

import torch
a = torch.randn(2,3,4)
print(a.numel())    # 24

统计model中所有可训练参数量:

num_params = sum(p.numel() for p in model.parameters())

注:numel() 是pytorch的函数,只适用于 tensor,不能用于统计 list、tuple、dict 等的元素数量。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐