Pytorch如何保存和加载模型参数
pytorch 保存和加载模型的方法有两种:1.保存网络的参数import torch#导入模块net=Net()#创建网络,当然还需要损失函数梯度等省略PATH='state_dict_model.pth'#先建立路径torch.save(net.state_dict(),PATH)#保存:可以是pth文件或者pt文件model=Net()model.load_state_dict(torch.
·
pytorch 保存和加载模型的方法有两种:
1.保存网络的参数
import torch
#导入模块
net=Net()
#创建网络,当然还需要损失函数梯度等省略
PATH='state_dict_model.pth'
#先建立路径
torch.save(net.state_dict(),PATH)
#保存:可以是pth文件或者pt文件
model=Net()
model.load_state_dict(torch.load(PATH))
#载入保存的模型参数
model.eval()
#不启用 BatchNormalization 和 Dropout
2.保存整个网络
import torch
PATH = "entire_model.pt"
# Save
torch.save(net, PATH)
# Load
model = torch.load(PATH)
model.eval()
Remember too, that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.
更多推荐


所有评论(0)