在PyTorch中,可以使用torchstat这个库来查看网络模型的一些信息,包括总的参数量params、MAdd、显卡内存占用量和FLOPs等。

使用前需要先安装torchstat包,如下:

pip install torchstat

示例代码如下:

from torchstat import stat
from torchvision.models import resnet50, resnet101, resnet152, resnext101_32x8d

model = resnet50()
stat(model, (3, 224, 224))

如果只是想看模型的总参数量,可以通过如下方式:

total = sum([param.nelement() for param in model.parameters()])
print("Number of parameters: %.2fM" % (total/1e6))

stat打印完整信息如下:

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐