PyTorch中的模型保存:一键保存、两种选择/保存整个模型和保存模型参数
当我们使用一键保存功能时,PyTorch会把整个模型连同它的结构和参数一起保存下来。与保存整个模型相比,有时我们只需要保存模型的参数而不是结构。这种方式会生成更小的文件,更适用于共享参数或迁移学习等场景。通过这种转换的方式,我们可以随心所欲地在保存模型整体结构和仅保存参数之间切换,让模型保存变得更加灵活便捷。有时候,我们需要在保存整个模型和保存模型参数之间自由转换。通过这种方式,我们一举保存了模型
·
探索PyTorch中的模型保存:一键保存、两种选择
目录
一键保存整个模型:保留全貌
当我们使用一键保存功能时,PyTorch会把整个模型连同它的结构和参数一起保存下来。这意味着我们可以完整地保存模型的状态,随时随地加载它并开始预测或继续训练。
python
import torch
import torchvision.models as models
# 创建模型并保存整个模型
model = models.resnet18(pretrained=True)
torch.save(model, 'whole_model.pth')
# 加载整个模型
loaded_model = torch.load('whole_model.pth')
通过这种方式,我们一举保存了模型的全貌,文件通常以.pth或.pt为后缀。
只保存模型参数:轻装上阵
与保存整个模型相比,有时我们只需要保存模型的参数而不是结构。这种方式会生成更小的文件,更适用于共享参数或迁移学习等场景。
import torch
import torchvision.models as models
# 创建模型并加载预训练参数
model = models.resnet18()
model.load_state_dict(torch.load('model_params.pth'))
# 保存模型参数
torch.save(model.state_dict(), 'model_params.pth')
通过这种方式,我们轻装上阵,只携带了模型的参数而不是整个结构。
转换的奇妙之处
有时候,我们需要在保存整个模型和保存模型参数之间自由转换。这时候,我们只需一行代码就可以实现。
保存整个模型转换为保存模型参数:
import torch
# 加载整个模型
loaded_model = torch.load('whole_model.pth')
# 保存模型参数
torch.save(loaded_model.state_dict(), 'model_params.pth')
保存整个模型转换为保存模型参数:
import torch
import torchvision.models as models
# 创建模型并加载模型参数
model = models.resnet18()
model.load_state_dict(torch.load('model_params.pth'))
# 保存整个模型
torch.save(model, 'whole_model_from_params.pth')
通过这种转换的方式,我们可以随心所欲地在保存模型整体结构和仅保存参数之间切换,让模型保存变得更加灵活便捷。
更多推荐


所有评论(0)