pytorch 保存和加载模型的方法有两种:

1.保存网络的参数

import torch
#导入模块

net=Net()
#创建网络,当然还需要损失函数梯度等省略


PATH='state_dict_model.pth'
#先建立路径
torch.save(net.state_dict(),PATH)
#保存:可以是pth文件或者pt文件

model=Net()
model.load_state_dict(torch.load(PATH))
#载入保存的模型参数
model.eval()
#不启用 BatchNormalization 和 Dropout

2.保存整个网络

import torch

PATH = "entire_model.pt"
# Save
torch.save(net, PATH)

# Load
model = torch.load(PATH)
model.eval()

Remember too, that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐