pytorch-模型的加载,迁移学习,保存
代码的主要目的是加载预训练的 VGG16 模型,并修改其分类器部分以适配 CIFAR-10 数据集。修改后的模型可以直接用于 CIFAR-10 的分类任务,而无需重新训练整个 VGG16 模型。这段代码展示了两种不同的模型加载方式,并打印了模型的结构。
这段代码的主要目的是加载和修改 VGG16 模型,并打印模型的结构。以下是代码的详细解释:
1. 导入必要的库
Python复制
import torchvision
import torch
-
torchvision
是 PyTorch 的计算机视觉扩展库,提供了常用的预训练模型和数据集。 -
torch
是 PyTorch 核心库,用于构建和训练深度学习模型。
2. 加载预训练和未预训练的 VGG16 模型
Python复制
vgg16_true = torchvision.models.vgg16(pretrained=True)
vgg16_false = torchvision.models.vgg16(pretrained=False)
-
vgg16_true
:加载带有预训练权重的 VGG16 模型。预训练权重通常是在大规模数据集(如 ImageNet)上训练得到的。 -
vgg16_false
:加载不带预训练权重的 VGG16 模型,初始权重是随机初始化的。
3. 打印预训练的 VGG16 模型结构
Python复制
print(vgg16_true)
-
这行代码会输出 VGG16 模型的结构,包括卷积层(
features
)和全连接层(classifier
)的详细信息。
4. 加载 CIFAR-10 数据集
Python复制
train_data = torchvision.datasets.CIFAR10("../data", train=True, transform=torchvision.transforms.ToTensor(), download=True)
-
torchvision.datasets.CIFAR10
:加载 CIFAR-10 数据集。-
"../data"
:数据集存储的路径。 -
train=True
:加载训练集。 -
transform=torchvision.transforms.ToTensor()
:将图像数据转换为 PyTorch 张量。 -
download=True
:如果数据集不存在,会自动下载。
-
5. 修改 VGG16 模型的分类器部分
Python复制
vgg16_true.classifier.add_module('add_linear', torch.nn.Linear(1000, 10))
-
vgg16_true.classifier
:VGG16 模型的全连接层部分。 -
add_module('add_linear', torch.nn.Linear(1000, 10))
:向分类器部分添加一个新的线性层。-
'add_linear'
:新模块的名称。 -
torch.nn.Linear(1000, 10)
:定义一个线性层,输入维度为 1000(对应原分类器的输出维度),输出维度为 10(对应 CIFAR-10 的 10 个类别)。
-
6. 再次打印修改后的 VGG16 模型结构
Python复制
print(vgg16_true)
-
这行代码会输出修改后的 VGG16 模型的结构,可以看到新增的
add_linear
模块。
总结
-
代码的主要目的是加载预训练的 VGG16 模型,并修改其分类器部分以适配 CIFAR-10 数据集。
-
修改后的模型可以直接用于 CIFAR-10 的分类任务,而无需重新训练整个 VGG16 模型。
这段代码展示了两种不同的模型加载方式,并打印了模型的结构。以下是代码的详细解释:
1. 加载模型的方式1
Python复制
model = torch.load("vgg16_method1.pth")
-
使用
torch.load
加载保存的模型文件"vgg16_method1.pth"
。 -
这种方式加载的模型是完整的模型对象,包括模型结构和权重。
-
通常,这种保存方式是通过
torch.save(model, path)
保存的,保存了整个模型。
2. 加载模型的方式2
Python复制
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
-
先实例化一个未预训练的 VGG16 模型。
-
使用
load_state_dict
方法加载保存的模型权重文件"vgg16_method2.pth"
。 -
这种方式加载的模型是通过
torch.save(model.state_dict(), path)
保存的模型权重。
3. 打印模型结构
Python复制
print(model)
print(vgg16)
-
分别打印两种加载方式加载的模型结构。
-
预期输出是两个模型的详细结构,包括各层的参数和连接方式。
区别与注意事项
-
保存方式不同:
-
torch.save(model, path)
保存的是完整的模型对象(包括结构和权重)。 -
torch.save(model.state_dict(), path)
保存的是模型的权重字典。
-
-
加载方式不同:
-
方式1:直接加载完整的模型对象。
-
方式2:需要先实例化模型,再加载权重。
-
-
适用场景:
-
如果需要迁移学习或微调模型,通常使用方式2(加载权重),因为这样可以灵活修改模型结构。
-
如果模型结构固定,可以直接使用方式1加载整个模型。
-
运行结果示例
-
print(model)
:输出通过方式1加载的完整模型结构。 -
print(vgg16)
:输出通过方式2加载的模型结构,预期与方式1的结构一致(假设权重一致)。
总结
-
方式1 适用于快速加载和使用完整的模型对象。
-
方式2 适用于需要灵活修改模型结构或进行迁移学习的场景。
这段代码展示了如何使用 PyTorch 的 torch.save — PyTorch 2.6 documentation 函数以两种不同的方式保存模型。以下是代码的详细解释:
主线:代码逻辑
-
导入必要的库
Python复制
import torch import torchvision
-
torch
: PyTorch 核心库,用于构建和训练深度学习模型。 -
torchvision
: PyTorch 的计算机视觉扩展库,提供了常用的预训练模型。
-
-
加载未预训练的 VGG16 模型
Python复制
vgg16 = torchvision.models.vgg16(pretrained=False)
-
torchvision.models.vgg16
: 加载 VGG16 模型。 -
pretrained=False
: 表示不加载预训练权重,模型的权重是随机初始化的。
-
-
保存方式 1:模型结构 + 参数
Python复制
torch.save(vgg16, "vgg16_method1.pth")
-
torch.save(vgg16, "vgg16_method1.pth")
: 直接保存整个模型对象。-
包括模型的结构(例如,层的设计)和权重(参数)。
-
这种方式适合保存完整的模型,便于直接加载和使用。
-
-
-
保存方式 2:只保存模型参数(权重)
Python复制
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
-
vgg16.state_dict()
: 获得模型的权重字典(仅保存参数)。 -
torch.save(...)
: 保存权重字典到文件。 -
这种方式更适合迁移学习或模型微调时使用,因为可以灵活地将权重加载到不同的模型结构中。
-
补充知识:两种保存方式的区别
特性 | 方式 1 | 方式 2 |
---|---|---|
保存内容 | 模型结构 + 参数 | 仅参数(权重) |
文件大小 | 较大(保存结构信息) | 较小(仅保存参数) |
适用场景 | 快速导出和加载完整模型 | 迁移学习、模型微调 |
局限性 | 模型结构依赖原环境 | 加载时需预先定义模型结构 |
载入模型
-
方式 1:
Python复制
model = torch.load("vgg16_method1.pth")
-
直接加载整个模型对象。
-
-
方式 2:
Python复制
vgg16 = torchvision.models.vgg16(pretrained=False) vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
-
需要先定义模型结构,再加载权重。
-
实际应用场景
-
方式 1:适用于快速保存和分享完整的模型,不涉及模型结构调整。
-
方式 2:适用于迁移学习,可以在保留部分预训练权重的同时调整模型结构。
总结
-
方式 1 是直接保存模型的完整状态。
-
方式 2 仅保存模型的权重,便于灵活应用。
-
根据实际需求(是否需要调整模型结构)选择合适的保存方式。
torch.load以及torch.save的讲解
torch.load
和 torch.save
是 PyTorch 中用于保存和加载模型、张量和其他对象的重要函数。它们在深度学习项目中经常用于保存训练好的模型权重、模型结构、训练状态等。以下是它们的详细原理和使用方法:
一、torch.save
函数原理
torch.save
用于将 Python 对象(如模型、张量、字典等)保存到文件中。它利用 Python 的序列化机制将对象的状态信息转化为可存储的字节流格式,并写入到文件中。
1. 功能与原理
-
功能:保存 PyTorch 对象(如模型、张量、字典等)到文件中。
-
原理:
-
使用 PyTorch 的自定义序列化机制,将对象的状态信息(如张量的数据、模型的权重、优化器的状态等)序列化为字节流。
-
支持多种存储后端,包括 PyTorch 自己的二进制格式、
pickle
序列化、zip 文件等。 -
保存的对象可以是张量、模型的整个对象、或者模型的权重(
state_dict
)。
-
2. 常用的保存方式
-
保存整个模型:
Python复制
torch.save(model, path)
-
保存模型的结构和权重。
-
优点:保存的内容完整,可以直接加载整个模型。
-
缺点:模型结构依赖于保存时的模型定义,如果模型定义发生变化(如代码改动),加载时可能出错。
-
-
保存模型权重(
state_dict
):Python复制
torch.save(model.state_dict(), path)
-
仅保存模型的权重(参数)。
-
这样可以在加载时更灵活地定义模型结构(例如,用于迁移学习或模型迁移)。
-
优点:模型结构和权重分离,更适合迁移学习和模型微调。
-
缺点:需要手动加载到模型结构中。
-
3. 示例
以下是一个保存模型的示例:
Python复制
import torch
import torchvision
# 创建一个未预训练的 VGG16 模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存整个模型
torch.save(vgg16, 'vgg16_model.pth')
# 保存模型权重
torch.save(vgg16.state_dict(), 'vgg16_weights.pth')
二、torch.load
函数原理
torch.load
用于从文件中加载之前保存的 PyTorch 对象。它通过解析文件中的字节流,恢复对象的状态信息。
1. 功能与原理
-
功能:从文件中加载 PyTorch 对象。
-
原理:
-
使用与
torch.save
相同的序列化机制,解析文件中的字节流,将其反序列化为 Python 对象。 -
自动检测文件的存储后端(如
pickle
、zip 文件等)并选择合适的解析方式。
-
2. 常用的加载方式
-
加载整个模型:
Python复制
model = torch.load(path)
-
直接加载整个模型对象,包括模型结构和权重。
-
注意:需要确保在当前环境中定义了相同的模型类(如
torchvision.models.vgg16
),否则会报错。
-
-
加载模型权重(
state_dict
):Python复制
model = torchvision.models.vgg16(pretrained=False) model.load_state_dict(torch.load(path))
-
需要先定义模型结构,然后加载权重到模型中。
-
这样可以更灵活地加载权重到不同的模型结构中(例如,迁移学习)。
-
3. 示例
以下是一个加载模型的示例:
Python复制
import torch
import torchvision
# 加载整个模型
model = torch.load('vgg16_model.pth')
# 加载模型权重
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load('vgg16_weights.pth'))
三、注意事项
-
文件路径和权限:
-
确保保存和加载文件时路径正确,且有权限读写文件。
-
-
兼容性:
-
如果使用
torch.save
保存整个模型,加载时需要保证 PyTorch 版本和模型定义兼容。 -
使用
state_dict
保存和加载可以更好地兼容不同版本的 PyTorch 和模型定义。
-
-
使用 GPU 和 CPU 的转换:
-
如果模型在 GPU 上保存,但在 CPU 上加载(反之亦然),可以在加载时指定设备:
Python复制
# 将模型加载到 CPU model = torch.load(path, map_location=torch.device('cpu')) # 将模型加载到 GPU model = torch.load(path, map_location=torch.device('cuda:0'))
-
-
分布式训练的保存和加载:
-
在分布式训练中,
state_dict
会包含module.
前缀。加载时需要移除或保留前缀,以匹配模型结构。
-
四、总结
-
torch.save
和torch.load
是 PyTorch 中用于保存和加载模型的核心函数。 -
保存和加载的方式取决于应用场景(如迁移学习、微调、快速加载等)。
-
推荐使用
state_dict
方式保存和加载模型权重,以获得更好的灵活性和兼容性。
更多推荐
所有评论(0)