Pytorch|零基础入门(五)Dataloder+常用Containers
本文介绍了PyTorch中torchvision数据集的使用方法以及DataLoader的核心功能。主要内容包括:1. 使用torchvision加载CIFAR10数据集,通过transform参数进行数据预处理;2. DataLoader的配置参数详解,包括batch_size、shuffle、num_workers等;3. 数据可视化方法,使用TensorBoard和make_grid展示批量
torchvision中的数据集使用
Ctrl+P可以查看参数
下载数据集可以在下载链接出来以后放在迅雷当中
#下载数据集
import torchvision
train_set = torchvision.datasets.CIFAR10(root='./dataset',train=True,download=True)
test_set = torchvision.datasets.CIFAR10(root='./dataset',train=False,download=True)
print(test_set[0])#查看第一个数据集
结合使用
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
])
train_set = torchvision.datasets.CIFAR10(root='./dataset',train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root='./dataset',train=False,transform=dataset_transform,download=True)
writer=SummaryWriter("p10")
for i in range(10):
img,target=train_set[i]
writer.add_image('test_set',img,i)
writer.close()
#然后在终端输入:tensorboard --logdir=p10
Dataloder
一、核心概念铺垫
在使用 DataLoader 前,首先要理解两个核心组件:
Dataset:负责定义数据的读取逻辑(告诉程序数据在哪、怎么读)DataLoader:负责对Dataset封装,实现批量加载、打乱、多线程读取等功能
二、完整使用流程(从基础到进阶)
步骤1:基础使用(自定义数据集 + DataLoader)
2.1 定义自定义 Dataset 类
Dataset 是抽象类,必须实现 __len__(返回数据总数)和 __getitem__(返回指定索引的单条数据)两个方法。
import torch
from torch.utils.data import 0.
............................................................................................................................................................................................................................................................................................................................................
Dataset, DataLoader
import numpy as np
# 自定义数据集类(以简单的数值数据集为例)
class MyDataset(Dataset):
def __init__(self, data, labels):
"""
初始化函数:加载数据和标签
:param data: 特征数据,如 numpy 数组
:param labels: 标签数据,如 numpy 数组
"""
self.data = torch.from_numpy(data).float() # 转为 torch 张量(float 类型)
self.labels = torch.from_numpy(labels).long() # 标签一般用 long 类型
def __len__(self):
"""返回数据集的总长度"""
return len(self.data)
def __getitem__(self, idx):
"""根据索引返回单条数据(特征+标签)"""
return self.data[idx], self.labels[idx]
# 生成测试数据(100条样本,每条5个特征,标签为0/1)
np_data = np.random.randn(100, 5) # 特征:100x5
np_labels = np.random.randint(0, 2, size=100) # 标签:100个0/1
# 实例化数据集
my_dataset = MyDataset(np_data, np_labels)
2.2 初始化 DataLoader 并使用
# 初始化 DataLoader
dataloader = DataLoader(
dataset=my_dataset, # 传入自定义的 Dataset 实例
batch_size=16, # 每个批次的样本数(核心参数)
shuffle=True, # 每个 epoch 是否打乱数据(训练时建议True)
num_workers=2, # 多线程读取数据(windows下建议设为0,避免报错)
drop_last=True # 是否丢弃最后一个不完整的批次(如100条数据,batch_size=16,最后剩4条则丢弃)
)
# 遍历 DataLoader(核心使用方式)
for epoch in range(2): # 模拟2个训练轮次
print(f"===== Epoch {epoch+1} =====")
for batch_idx, (batch_data, batch_labels) in enumerate(dataloader):
# batch_data: 形状 [16,5](16个样本,每个5个特征)
# batch_labels: 形状 [16](16个标签)
print(f"Batch {batch_idx+1}:")
print(f" 数据形状: {batch_data.shape}")
print(f" 标签形状: {batch_labels.shape}")
print(f" 前3个标签: {batch_labels[:3]}")
# 这里可添加模型训练逻辑(如前向传播、计算损失等)
break # 仅演示第一个批次,实际训练时去掉break
步骤3:进阶用法
3.1 处理图像数据集(结合 torchvision)
from torchvision import datasets, transforms
# 定义图像预处理(如归一化、缩放、转张量)
transform = transforms.Compose([
transforms.Resize((28, 28)), # 缩放为28x28
transforms.ToTensor(), # 转为张量(0-1范围)
transforms.Normalize((0.5,), (0.5,)) # 归一化(均值0.5,标准差0.5)
])
# 加载MNIST手写数字数据集(内置Dataset)
mnist_dataset = datasets.MNIST(
root='./data', # 数据保存路径
train=True, # 训练集
download=True, # 自动下载
transform=transform # 应用预处理
)
# 初始化DataLoader
mnist_dataloader = DataLoader(
dataset=mnist_dataset,
batch_size=32,
shuffle=True,
num_workers=0 # windows下设为0
)
# 遍历图像数据集
for images, labels in mnist_dataloader:
print(f"图像批次形状: {images.shape}") # [32,1,28,28](32张图,1通道,28x28)
print(f"标签批次形状: {labels.shape}") # [32]
break
3.2 自定义批次拼接(collate_fn)
默认情况下 DataLoader 会自动拼接张量,但如果数据是不规则长度(如文本序列),需要自定义 collate_fn:
# 自定义collate_fn:处理变长数据
def custom_collate_fn(batch):
"""
batch: 是一个列表,每个元素是 __getitem__ 返回的结果
此处示例:处理文本序列(假设每条数据是 (序列, 标签),序列长度不同)
"""
# 分离数据和标签
sequences = [item[0] for item in batch]
labels = [item[1] for item in batch]
# 对序列做padding(补0到最长序列长度)
max_len = max(len(seq) for seq in sequences)
padded_sequences = []
for seq in sequences:
pad_len = max_len - len(seq)
padded_seq = seq + [0]*pad_len # 补0
padded_sequences.append(padded_seq)
# 转为张量
return torch.tensor(padded_sequences), torch.tensor(labels)
# 模拟变长文本数据集
class TextDataset(Dataset):
def __init__(self):
# 模拟5条文本序列(长度分别为3,5,2,4,6)
self.data = [
([1,2,3], 0),
([4,5,6,7,8], 1),
([9,10], 0),
([11,12,13,14], 1),
([15,16,17,18,19,20], 0)
]
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
# 初始化DataLoader并传入自定义collate_fn
text_dataloader = DataLoader(
dataset=TextDataset(),
batch_size=3,
shuffle=True,
collate_fn=custom_collate_fn # 自定义批次拼接函数
)
# 遍历验证
for seqs, labels in text_dataloader:
print(f"padding后的序列形状: {seqs.shape}") # [3,6](3个样本,最长序列长度6)
print(f"序列内容:\\\\n{seqs}")
print(f"标签: {labels}")
break
3.3 多数据集加载(ConcatDataset)
from torch.utils.data import ConcatDataset
# 拼接两个数据集
dataset1 = MyDataset(np_data[:50], np_labels[:50])
dataset2 = MyDataset(np_data[50:], np_labels[50:])
concat_dataset = ConcatDataset([dataset1, dataset2])
# 加载拼接后的数据集
concat_dataloader = DataLoader(concat_dataset, batch_size=20, shuffle=True)
print(f"拼接后数据集总长度: {len(concat_dataset)}") # 输出100
三、常见问题与注意事项
num_workers报错:Windows系统下建议设为0,Linux/Mac可根据CPU核心数设置(如4、8),避免多线程冲突。drop_last选择:训练时建议设为True(保证每个批次大小一致),测试时可设为False(不浪费数据)。shuffle选择:训练集设为True(避免模型过拟合),测试/验证集设为False(保证结果可复现)。- 内存占用:
batch_size不宜过大(否则显存不足),可根据显卡显存调整(如12GB显存可设64/128)。
总结
DataLoader是基于Dataset的封装,核心作用是批量加载、打乱、多线程读取数据,必须配合Dataset使用。- 核心参数:
batch_size(批次大小)、 - (是否打乱)、
num_workers(多线程)、drop_last(是否丢弃不完整批次)。 - 进阶场景:处理变长数据需自定义
collate_fn,处理图像需结合torchvision.transforms,多数据集可使用ConcatDataset。
Ctrl+\多行注释
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# 新增:导入make_grid,用于拼接批量图片成网格(解决add_image批量显示问题)
from torchvision.utils import make_grid
# 1. 加载CIFAR10测试集(新增download=True,本地无数据时自动下载)
test_data = torchvision.datasets.CIFAR10(
root='./dataset',
train=False,
transform=torchvision.transforms.ToTensor(),
download=True # 避免本地无数据集导致的路径错误
)
# 2. 封装DataLoader(参数保持不变)
test_loader = DataLoader(
dataset=test_data,
batch_size=4,
shuffle=True,
num_workers=0,
drop_last=False
)
# 3. 查看第一张图片信息(修正print赋值错误)
img, target = test_data[0]
print("第一张图片的形状:", img.shape) # 正确打印形状,输出torch.Size([3, 32, 32])
print("第一张图片的标签:", target) # 正确打印标签,输出0-9的整数
# 4. TensorBoard可视化批量图片(修正核心错误)
writer = SummaryWriter("dataloader") # 日志保存到当前目录的dataloader文件夹
step = 0
for data in test_loader:
imgs, targets = data
# 关键修正:用make_grid拼接4张图片为网格(nrow=2表示每行2张)
img_grid = make_grid(imgs, nrow=2)
# 写入拼接后的网格图片,step控制可视化的步数
writer.add_image("test_data_batch", img_grid, step)
step += 1 # 简化写法,等价于step=step+1
# 移除循环内的writer.close() → 循环结束后再关闭
# 循环结束后关闭writer(确保所有图片都写入)
writer.close()
# 运行后在终端执行 tensorboard --logdir=dataloader,打开浏览器访问http://localhost:6006查看可视化结果
代码整体结构梳理
整个代码分为4个核心模块:
- 导入依赖库(基础工具+可视化工具)
- 加载CIFAR10测试数据集(定义数据来源和格式)
- 用DataLoader封装数据集(实现批量加载)
- 查看单张图片信息(验证数据加载正确性)
- TensorBoard可视化批量图片(直观展示DataLoader的批量数据)
二、逐行详细解释
模块1:导入依赖库
import torchvision
- 作用:导入PyTorch官方的计算机视觉工具库
torchvision,这个库包含了:- 常用的公开数据集(如CIFAR10、MNIST);
- 图像预处理工具(如转张量、归一化);
- 视觉相关的工具函数(如
make_grid); - 预训练模型(如ResNet、VGG)。
- 这里主要用它来加载CIFAR10数据集和使用
make_grid函数。
from torch.utils.data import DataLoader
- 作用:从PyTorch的核心数据工具模块中导入
DataLoader类——这是PyTorch中批量加载数据的核心工具,能把数据集封装成“可迭代的批量数据生成器”,支持批量读取、打乱数据、多线程加载等功能。
from torch.utils.tensorboard import SummaryWriter
- 作用:导入TensorBoard的PyTorch封装类
SummaryWriter,用于将数据(图片、损失值、参数等)写入日志文件,后续可通过TensorBoard可视化这些数据。
from torchvision.utils import make_grid
- 作用:从
torchvision的工具模块中导入make_grid函数——核心作用是将多张图片张量拼接成一张网格状的图片(比如把4张图片拼成2行2列的网格),解决SummaryWriter.add_image只能显示单张图片的问题。
模块2:加载CIFAR10测试数据集
test_data = torchvision.datasets.CIFAR10(
root='./dataset',
train=False,
transform=torchvision.transforms.ToTensor(),
download=True
)
- 这行代码是实例化CIFAR10数据集类,本质是告诉程序“去哪里找数据、加载哪部分数据、数据要怎么处理”,逐个参数解释:
root='./dataset':数据集的保存路径(相对路径,保存在当前代码所在文件夹的dataset子文件夹中);train=False:指定加载测试集(如果设为True则加载训练集);transform=torchvision.transforms.ToTensor():对图片做预处理——将原本的PIL格式图片(像素值0-255)转为PyTorch张量(Tensor),同时把像素值归一化到0-1范围,张量形状变为[通道数, 高度, 宽度](CIFAR10图片是3通道,所以形状是[3,32,32]);download=True:如果./dataset文件夹中没有CIFAR10数据集,自动从官网下载(避免本地无数据导致报错)。
- 最终
test_data是一个Dataset类型的对象,支持通过索引访问单条数据(如test_data[0]取第一张图片)。
模块3:用DataLoader封装数据集
test_loader = DataLoader(
dataset=test_data,
batch_size=4,
shuffle=True,
num_workers=0,
drop_last=False
)
- 这行代码是将Dataset封装为DataLoader,本质是把“单条数据的集合”变成“批量数据的生成器”,逐个参数解释:
dataset=test_data:传入要封装的数据集(即上面加载的CIFAR10测试集);batch_size=4:每个批次包含4张图片(每次读取4条数据);shuffle=True:遍历数据时,先打乱数据顺序再分批次(测试集设为True仅为演示,实际测试时建议设为False,保证结果可复现);num_workers=0:数据加载的线程数(Windows系统建议设为0,避免多线程冲突报错;Linux/Mac可设为4/8,加快加载速度);drop_last=False:如果最后一个批次的样本数不足batch_size(比如总数据数不是4的倍数),不丢弃这个批次(设为True则丢弃)。
- 最终
test_loader是一个可迭代的对象,遍历它会返回(批量图片张量, 批量标签张量),比如for data in test_loader每次取4张图片+4个标签。
模块4:查看单张图片信息(验证数据)
img, target = test_data[0]
- 作用:通过索引
0读取test_data中的第一张图片,Dataset对象的索引访问会返回(图片张量, 标签)的元组,这里用img接收图片张量,target接收标签。 - CIFAR10的标签是0-9的整数,对应10类物体:0=飞机、1=汽车、2=鸟、3=猫、4=鹿、5=狗、6=青蛙、7=马、8=船、9=卡车。
print("第一张图片的形状:", img.shape)
- 作用:打印第一张图片张量的形状,输出是
torch.Size([3, 32, 32]),含义:3:图片的通道数(RGB彩色图片,3个通道);32:图片的高度(像素);32:图片的宽度(像素)——这是CIFAR10图片的固定尺寸。
print("第一张图片的标签:", target)
- 作用:打印第一张图片对应的标签(0-9的整数),比如输出
3就代表这张图片是“猫”。
模块5:TensorBoard可视化批量图片(核心可视化逻辑)
writer = SummaryWriter("dataloader")
- 作用:创建
SummaryWriter对象,指定日志文件保存到当前文件夹的dataloader子文件夹中——后续所有可视化数据都会写入这个文件夹,供TensorBoard读取。
step = 0
- 作用:定义一个步数计数器
step,用于TensorBoard中区分不同批次的图片(每显示一个批次,步数+1,避免图片覆盖)。
for data in test_loader:
imgs, targets = data
- 作用:遍历
test_loader,每次循环取一个批次的数:data是一个元组,包含(批量图片张量, 批量标签张量);imgs接收批量图片张量,形状是[4, 3, 32, 32](4张图片,每张3通道、32x32);targets接收批量标签张量,形状是[4](4个标签,每个是0-9的整数)。
img_grid = make_grid(imgs, nrow=2)
- 核心作用:用
make_grid将4张图片拼接成一张网格图:imgs:传入批量图片张量([4,3,32,32]);nrow=2:指定网格的行数/列数——每行显示2张图片,所以最终是2行2列的网格;- 输出
img_grid是一个单张图片张量,形状为[3, 32*2 + 间隔, 32*2 + 间隔](拼接后的网格图,可直接用add_image显示)。
- 为什么需要这一步?因为
writer.add_image默认只支持显示单张图片(形状[C,H,W]),直接传入批量图片([4,3,32,32])会报维度错误,make_grid是解决这个问题的关键。
writer.add_image("test_data_batch", img_grid, step)
- 作用:将拼接后的网格图写入TensorBoard日志:
"test_data_batch":可视化图片的名称(在TensorBoard中会显示这个名称);img_grid:要显示的图片张量(拼接后的单张网格图);step:步数,用于区分不同批次的图片(比如step=0是第1批,step=1是第2批,TensorBoard中可通过滑动条切换)。
step += 1
- 作用:每遍历一个批次,步数+1,保证下一个批次的图片有唯一的步数标识。
writer.close()
- 作用:遍历结束后关闭
SummaryWriter,确保所有日志数据都写入文件(必须放在循环外,如果放在循环内,第一次循环就会关闭写入器,后续批次的图片无法写入)。
# 额外提示:运行后在终端执行 tensorboard --logdir=dataloader,打开浏览器访问http://localhost:6006查看可视化结果
- 作用:告诉用户如何查看TensorBoard可视化结果:
- 运行代码后,在终端(命令行)执行
tensorboard --logdir=dataloader(指定日志文件夹); - 终端会输出一个网址(通常是
http://localhost:6006); - 打开浏览器访问这个网址,在「Images」标签下就能看到拼接后的批量图片网格。
- 运行代码后,在终端(命令行)执行
三、代码运行流程总结
- 程序先导入所需的库,准备好工具;
- 加载CIFAR10测试集,将图片转为张量格式(如果本地没有数据则自动下载);
- 用DataLoader把数据集封装成“每次返回4张图片”的批量数据生成器;
- 读取第一张图片,打印它的形状和标签,验证数据加载正确;
- 创建TensorBoard写入器,遍历DataLoader的每个批次:
- 把4张图片拼接成2x2的网格图;
- 将网格图写入TensorBoard日志,用step区分不同批次;
- 遍历结束后关闭写入器,用户可通过TensorBoard查看所有批次的图片。
四、关键细节补充
- 张量形状的变化:
- 单张图片:
[3, 32, 32](C, H, W); - 批量图片:
[4, 3, 32, 32](batch_size, C, H, W); - 拼接后网格图:
[3, 68, 68](3通道,高度=322+4(间隔),宽度=322+4(间隔))。
- 单张图片:
- shuffle=True的作用:每次运行代码,遍历DataLoader得到的图片顺序都不同(打乱数据),但单张图片的内容不变。
- download=True的注意事项:第一次运行代码会下载CIFAR10数据集(约17MB),后续运行会直接读取本地数据,无需重复下载。
- TensorBoard的使用前提:需提前安装TensorBoard(
pip install tensorboard),否则无法运行。
总结
- 这份代码的核心是DataLoader的使用和TensorBoard可视化批量图片,前者实现数据的批量加载,后者通过
make_grid拼接图片解决批量可视化的问题; - 关键修正点:避免覆盖
print函数、循环外关闭SummaryWriter、用make_grid拼接批量图片; - 代码的最终效果:既能验证数据加载的正确性,又能直观看到DataLoader输出的批量图片,是深度学习中“数据加载+可视化”的基础模板。
- 训练查看不同数据时记得把
writer.add_image("test_data_batch", img_grid, step)这一句里的标签给改了,然后把tensorboard更新,以便查看不同训练的显示
1.什么是神经网络 Containers?
1. 本质定义
Containers 是 PyTorch torch.nn 模块下的容器类,核心作用是:
- 把多个神经网络层(如
nn.Conv2d、nn.Linear、nn.ReLU)“打包”成一个整体; - 统一管理层的参数(如权重、偏置)、设备(CPU/GPU)、训练/评估模式;
- 支持自动前向传播、参数优化、模型保存/加载等核心功能。
2. 为什么需要 Containers?
如果没有容器,你需要手动写每一层的前向传播(比如 x = nn.Conv2d()(x); x = nn.ReLU()(x)),且无法统一管理参数——Containers 让网络构建从“零散代码”变成“模块化结构”,大幅降低复杂度。
3. 核心基类:nn.Module
所有 Containers 都继承自 nn.Module(这是 PyTorch 中所有神经网络模块的基类),因此具备以下核心能力:
parameters():返回所有可训练参数(供优化器更新);to(device):把整个模型移到 CPU/GPU;train()/eval():切换训练/评估模式(影响 Dropout、BN 等层);state_dict()/load_state_dict():保存/加载模型参数。
搭建我的第一个简单的神经网络
import torch
from torch import nn
class Tudui(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input):
output = input+1
return output
#以上神经网络模板定义完了
#创建神经网络
tudui = Tudui()
x=torch.tensor(1.0)
#把X输入放在神经网络tudui当中
output=tudui(x)
print(output)
#输出tensor(2.)
2、PyTorch 中最常用的 4 种 Containers
1. 基础容器:nn.Sequential(顺序容器)
核心特点
- 按顺序堆叠神经网络层,前一层的输出直接作为后一层的输入;
- 最简单、最常用的容器,适合构建“线性流程”的网络(如简单 CNN、MLP)。
实战用法(以 CIFAR10 分类的简单 CNN 为例)
import torch
import torch.nn as nn
# 用 Sequential 构建简单 CNN
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()#调用初始化函数
# 卷积层:顺序堆叠(Conv2d → ReLU → MaxPool2d → Flatten → Linear)
self.features = nn.Sequential(
# 第一层卷积:3通道→16通道,3x3卷积,步长1,填充1
nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1),
nn.ReLU(), # 激活函数
nn.MaxPool2d(kernel_size=2, stride=2), # 池化:32x32→16x16
# 第二层卷积:16通道→32通道
nn.Conv2d(16, 32, 3, 1, 1),
nn.ReLU(),
nn.MaxPool2d(2, 2), # 16x16→8x8
# 展平:32*8*8 → 2048
nn.Flatten()
)
# 全连接层:顺序堆叠(Linear → ReLU → Linear)
self.classifier = nn.Sequential(
nn.Linear(32 * 8 * 8, 512), # 2048 → 512
nn.ReLU(),
nn.Linear(512, 10) # 512 → 10(CIFAR10 10类)
)
def forward(self, x):#x是你自己定义的
# 前向传播:按 Sequential 顺序执行
x = self.features(x)
x = self.classifier(x)
return x
# 测试 Sequential 容器
model = SimpleCNN()
# 模拟输入:batch_size=4,3通道,32x32(CIFAR10 尺寸)
input_tensor = torch.randn(4, 3, 32, 32)
output = model(input_tensor)
print("输出形状:", output.shape) # 输出 (4, 10) → 4个样本,10类预测值
print("模型可训练参数总数:", sum(p.numel() for p in model.parameters() if p.requires_grad))
关键
nn.Sequential接收层的列表/字典,按传入顺序执行;- 无需手动写前向传播的每一步(比如不用写
x = self.conv1(x)),只需调用容器即可; - 适合结构简单、无分支的网络(如 LeNet、简单 MLP)。
- 输入x 卷积conv1(x) 非线性F.relu 卷积 非线性 输出 比如
def forward(self.x)
x=F.relu(self.conv1(x))
return F.relu(self.conv2(x)
#forward(*input)计算图-在每个子类中都应该重写
2. 灵活容器:nn.Modulelist(模块列表)
核心特点
- 把多个层/模块存成列表,支持索引访问、动态添加/删除层;
- 无自动前向传播,需手动写前向逻辑,适合动态调整层数的场景(如可变层数的 CNN)。
实战用法(动态构建多层卷积)
class DynamicCNN(nn.Module):
def __init__(self, in_channels=3, out_channels_list=[16, 32, 64]):
super(DynamicCNN, self).__init__()
# 用 ModuleList 存储多个卷积层(动态层数)
self.conv_layers = nn.ModuleList()
for out_ch in out_channels_list:
# 逐个添加卷积层(Conv2d + ReLU)
self.conv_layers.append(
nn.Sequential(
nn.Conv2d(in_channels, out_ch, 3, 1, 1),
nn.ReLU()
)
)
in_channels = out_ch # 更新输入通道数
# 池化层
self.pool = nn.MaxPool2d(2, 2)
# 全连接层
self.fc = nn.Linear(64 * 4 * 4, 10) # 64通道,4x4特征图(32→16→8→4)
def forward(self, x):
# 手动遍历 ModuleList 执行前向传播
for conv in self.conv_layers:
x = conv(x)
x = self.pool(x)
x = torch.flatten(x, 1) # 展平
x = self.fc(x)
return x
# 测试 ModuleList
dynamic_model = DynamicCNN(out_channels_list=[16, 32, 64])
input_tensor = torch.randn(4, 3, 32, 32)
output = dynamic_model(input_tensor)
print("DynamicCNN 输出形状:", output.shape) # (4, 10)
# 访问 ModuleList 中的第二层
print("第二层卷积:", dynamic_model.conv_layers[1])
关键
nn.ModuleList像普通 Python 列表,但会被 PyTorch 识别为“模型的一部分”(参数会被管理);- 必须手动遍历
ModuleList执行前向传播(无自动顺序); - 适合需要动态调整层数的场景(如根据配置文件设置卷积层数)。
3. 键值对容器:nn.ModuleDict(模块字典)
核心特点
- 把多个层/模块存成键值对(字典),按名称访问层;
- 无自动前向传播,适合可选分支的网络(如多路径特征提取)。
实战用法(多分支特征提取)
class BranchCNN(nn.Module):
def __init__(self):
super(BranchCNN, self).__init__()
# 用 ModuleDict 存储不同分支的卷积层(按名称访问)
self.branch_layers = nn.ModuleDict({
"small_kernel": nn.Conv2d(3, 16, kernel_size=3, padding=1),
"large_kernel": nn.Conv2d(3, 16, kernel_size=5, padding=2),
"relu": nn.ReLU()
})
self.pool = nn.MaxPool2d(2, 2)
self.fc = nn.Linear(16 * 16 * 16, 10)
def forward(self, x, branch_name="small_kernel"):
# 按名称选择分支(灵活切换)
x = self.branch_layers[branch_name](x)
x = self.branch_layers["relu"](x) # 共用ReLU层
x = self.pool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
# 测试 ModuleDict
branch_model = BranchCNN()
input_tensor = torch.randn(4, 3, 32, 32)
# 使用小核分支
output_small = branch_model(input_tensor, branch_name="small_kernel")
# 使用大核分支
output_large = branch_model(input_tensor, branch_name="large_kernel")
print("小核分支输出形状:", output_small.shape) # (4, 10)
print("大核分支输出形状:", output_large.shape) # (4, 10)
关键
nn.ModuleDict用字符串键映射层,适合需要“动态选择分支”的场景;- 可动态添加层:
branch_model.branch_layers["new_branch"] = nn.Conv2d(...); - 常用于构建多尺度特征提取的网络(如 ResNet 的分支、YOLO 的多尺度检测)。
4. 自定义容器:继承 nn.Module(最灵活)
核心特点
- 所有容器的“底层”,
Sequential/ModuleList都继承自nn.Module; - 完全自定义前向传播逻辑,适合复杂结构的网络(如 ResNet 的残差块、Transformer 的注意力层)。
实战用法(残差块,ResNet 核心)
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super(ResBlock, self).__init__()
# 主分支:Conv → BN → ReLU → Conv → BN
self.main_branch = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, stride, 1),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
nn.BatchNorm2d(out_channels)
)
# 捷径分支(维度匹配)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Conv2d(in_channels, out_channels, 1, stride)
def forward(self, x):
# 自定义前向:主分支 + 捷径分支(残差连接)
out = self.main_branch(x)
out += self.shortcut(x) # 残差相加(核心逻辑)
out = nn.ReLU()(out)
return out
# 构建带残差块的网络
class SimpleResNet(nn.Module):
def __init__(self):
super(SimpleResNet, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 3, 1, 1)
self.res_block1 = ResBlock(16, 16) # 自定义容器作为子模块
self.res_block2 = ResBlock(16, 32, stride=2)
self.pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(32, 10)
def forward(self, x):
x = self.conv1(x)
x = self.res_block1(x) # 调用自定义残差块
x = self.res_block2(x)
x = self.pool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
# 测试自定义容器
res_model = SimpleResNet()
input_tensor = torch.randn(4, 3, 32, 32)
output = res_model(input_tensor)
print("ResNet 输出形状:", output.shape) # (4, 10)
关键
- 自定义容器是构建复杂网络的核心(如 ResNet、Transformer、GPT 等);
- 只需继承
nn.Module,在__init__中定义子模块,在forward中写自定义逻辑; - 子模块(如
res_block1)会被自动管理(参数、设备、训练模式)。
更多推荐

所有评论(0)