torchvision中的数据集使用

Ctrl+P可以查看参数

下载数据集可以在下载链接出来以后放在迅雷当中

#下载数据集
import torchvision

train_set = torchvision.datasets.CIFAR10(root='./dataset',train=True,download=True)
test_set = torchvision.datasets.CIFAR10(root='./dataset',train=False,download=True)

print(test_set[0])#查看第一个数据集

结合使用

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform=torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

train_set = torchvision.datasets.CIFAR10(root='./dataset',train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root='./dataset',train=False,transform=dataset_transform,download=True)

writer=SummaryWriter("p10")
for i in range(10):
    img,target=train_set[i]
    writer.add_image('test_set',img,i)

writer.close()
#然后在终端输入:tensorboard --logdir=p10

Dataloder

一、核心概念铺垫

在使用 DataLoader 前,首先要理解两个核心组件:

  1. Dataset:负责定义数据的读取逻辑(告诉程序数据在哪、怎么读)
  2. DataLoader:负责对 Dataset 封装,实现批量加载、打乱、多线程读取等功能

二、完整使用流程(从基础到进阶)

步骤1:基础使用(自定义数据集 + DataLoader)

2.1 定义自定义 Dataset 类

Dataset 是抽象类,必须实现 __len__(返回数据总数)和 __getitem__(返回指定索引的单条数据)两个方法。

import torch
from torch.utils.data import 0.
............................................................................................................................................................................................................................................................................................................................................
Dataset, DataLoader
import numpy as np

# 自定义数据集类(以简单的数值数据集为例)
class MyDataset(Dataset):
    def __init__(self, data, labels):
        """
        初始化函数:加载数据和标签
        :param data: 特征数据,如 numpy 数组
        :param labels: 标签数据,如 numpy 数组
        """
        self.data = torch.from_numpy(data).float()  # 转为 torch 张量(float 类型)
        self.labels = torch.from_numpy(labels).long()  # 标签一般用 long 类型

    def __len__(self):
        """返回数据集的总长度"""
        return len(self.data)

    def __getitem__(self, idx):
        """根据索引返回单条数据(特征+标签)"""
        return self.data[idx], self.labels[idx]

# 生成测试数据(100条样本,每条5个特征,标签为0/1)
np_data = np.random.randn(100, 5)  # 特征:100x5
np_labels = np.random.randint(0, 2, size=100)  # 标签:100个0/1

# 实例化数据集
my_dataset = MyDataset(np_data, np_labels)

2.2 初始化 DataLoader 并使用

# 初始化 DataLoader
dataloader = DataLoader(
    dataset=my_dataset,       # 传入自定义的 Dataset 实例
    batch_size=16,            # 每个批次的样本数(核心参数)
    shuffle=True,             # 每个 epoch 是否打乱数据(训练时建议True)
    num_workers=2,            # 多线程读取数据(windows下建议设为0,避免报错)
    drop_last=True            # 是否丢弃最后一个不完整的批次(如100条数据,batch_size=16,最后剩4条则丢弃)
)

# 遍历 DataLoader(核心使用方式)
for epoch in range(2):  # 模拟2个训练轮次
    print(f"===== Epoch {epoch+1} =====")
    for batch_idx, (batch_data, batch_labels) in enumerate(dataloader):
        # batch_data: 形状 [16,5](16个样本,每个5个特征)
        # batch_labels: 形状 [16](16个标签)
        print(f"Batch {batch_idx+1}:")
        print(f"  数据形状: {batch_data.shape}")
        print(f"  标签形状: {batch_labels.shape}")
        print(f"  前3个标签: {batch_labels[:3]}")
        # 这里可添加模型训练逻辑(如前向传播、计算损失等)
        break  # 仅演示第一个批次,实际训练时去掉break

步骤3:进阶用法

3.1 处理图像数据集(结合 torchvision)

from torchvision import datasets, transforms

# 定义图像预处理(如归一化、缩放、转张量)
transform = transforms.Compose([
    transforms.Resize((28, 28)),  # 缩放为28x28
    transforms.ToTensor(),        # 转为张量(0-1范围)
    transforms.Normalize((0.5,), (0.5,))  # 归一化(均值0.5,标准差0.5)
])

# 加载MNIST手写数字数据集(内置Dataset)
mnist_dataset = datasets.MNIST(
    root='./data',  # 数据保存路径
    train=True,     # 训练集
    download=True,  # 自动下载
    transform=transform  # 应用预处理
)

# 初始化DataLoader
mnist_dataloader = DataLoader(
    dataset=mnist_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0  # windows下设为0
)

# 遍历图像数据集
for images, labels in mnist_dataloader:
    print(f"图像批次形状: {images.shape}")  # [32,1,28,28](32张图,1通道,28x28)
    print(f"标签批次形状: {labels.shape}")  # [32]
    break

3.2 自定义批次拼接(collate_fn)

默认情况下 DataLoader 会自动拼接张量,但如果数据是不规则长度(如文本序列),需要自定义 collate_fn

# 自定义collate_fn:处理变长数据
def custom_collate_fn(batch):
    """
    batch: 是一个列表,每个元素是 __getitem__ 返回的结果
    此处示例:处理文本序列(假设每条数据是 (序列, 标签),序列长度不同)
    """
    # 分离数据和标签
    sequences = [item[0] for item in batch]
    labels = [item[1] for item in batch]

    # 对序列做padding(补0到最长序列长度)
    max_len = max(len(seq) for seq in sequences)
    padded_sequences = []
    for seq in sequences:
        pad_len = max_len - len(seq)
        padded_seq = seq + [0]*pad_len  # 补0
        padded_sequences.append(padded_seq)

    # 转为张量
    return torch.tensor(padded_sequences), torch.tensor(labels)

# 模拟变长文本数据集
class TextDataset(Dataset):
    def __init__(self):
        # 模拟5条文本序列(长度分别为3,5,2,4,6)
        self.data = [
            ([1,2,3], 0),
            ([4,5,6,7,8], 1),
            ([9,10], 0),
            ([11,12,13,14], 1),
            ([15,16,17,18,19,20], 0)
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# 初始化DataLoader并传入自定义collate_fn
text_dataloader = DataLoader(
    dataset=TextDataset(),
    batch_size=3,
    shuffle=True,
    collate_fn=custom_collate_fn  # 自定义批次拼接函数
)

# 遍历验证
for seqs, labels in text_dataloader:
    print(f"padding后的序列形状: {seqs.shape}")  # [3,6](3个样本,最长序列长度6)
    print(f"序列内容:\\\\n{seqs}")
    print(f"标签: {labels}")
    break

3.3 多数据集加载(ConcatDataset)

from torch.utils.data import ConcatDataset

# 拼接两个数据集
dataset1 = MyDataset(np_data[:50], np_labels[:50])
dataset2 = MyDataset(np_data[50:], np_labels[50:])
concat_dataset = ConcatDataset([dataset1, dataset2])

# 加载拼接后的数据集
concat_dataloader = DataLoader(concat_dataset, batch_size=20, shuffle=True)
print(f"拼接后数据集总长度: {len(concat_dataset)}")  # 输出100

三、常见问题与注意事项

  1. num_workers 报错:Windows系统下建议设为0,Linux/Mac可根据CPU核心数设置(如4、8),避免多线程冲突。
  2. drop_last 选择:训练时建议设为True(保证每个批次大小一致),测试时可设为False(不浪费数据)。
  3. shuffle 选择:训练集设为True(避免模型过拟合),测试/验证集设为False(保证结果可复现)。
  4. 内存占用batch_size 不宜过大(否则显存不足),可根据显卡显存调整(如12GB显存可设64/128)。

总结

  1. DataLoader 是基于 Dataset 的封装,核心作用是批量加载、打乱、多线程读取数据,必须配合 Dataset 使用。
  2. 核心参数:batch_size(批次大小)、
  3. (是否打乱)、num_workers(多线程)、drop_last(是否丢弃不完整批次)。
  4. 进阶场景:处理变长数据需自定义 collate_fn,处理图像需结合 torchvision.transforms,多数据集可使用 ConcatDataset

Ctrl+\多行注释

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# 新增:导入make_grid,用于拼接批量图片成网格(解决add_image批量显示问题)
from torchvision.utils import make_grid

# 1. 加载CIFAR10测试集(新增download=True,本地无数据时自动下载)
test_data = torchvision.datasets.CIFAR10(
    root='./dataset',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True  # 避免本地无数据集导致的路径错误
)

# 2. 封装DataLoader(参数保持不变)
test_loader = DataLoader(
    dataset=test_data,
    batch_size=4,
    shuffle=True,
    num_workers=0,
    drop_last=False
)

# 3. 查看第一张图片信息(修正print赋值错误)
img, target = test_data[0]
print("第一张图片的形状:", img.shape)  # 正确打印形状,输出torch.Size([3, 32, 32])
print("第一张图片的标签:", target)    # 正确打印标签,输出0-9的整数

# 4. TensorBoard可视化批量图片(修正核心错误)
writer = SummaryWriter("dataloader")  # 日志保存到当前目录的dataloader文件夹
step = 0
for data in test_loader:
    imgs, targets = data
    # 关键修正:用make_grid拼接4张图片为网格(nrow=2表示每行2张)
    img_grid = make_grid(imgs, nrow=2)
    # 写入拼接后的网格图片,step控制可视化的步数
    writer.add_image("test_data_batch", img_grid, step)
    step += 1  # 简化写法,等价于step=step+1
    # 移除循环内的writer.close() → 循环结束后再关闭

# 循环结束后关闭writer(确保所有图片都写入)
writer.close()

# 运行后在终端执行 tensorboard --logdir=dataloader,打开浏览器访问http://localhost:6006查看可视化结果

代码整体结构梳理

整个代码分为4个核心模块:

  1. 导入依赖库(基础工具+可视化工具)
  2. 加载CIFAR10测试数据集(定义数据来源和格式)
  3. 用DataLoader封装数据集(实现批量加载)
  4. 查看单张图片信息(验证数据加载正确性)
  5. TensorBoard可视化批量图片(直观展示DataLoader的批量数据)

二、逐行详细解释

模块1:导入依赖库

import torchvision

  • 作用:导入PyTorch官方的计算机视觉工具库torchvision,这个库包含了:
    • 常用的公开数据集(如CIFAR10、MNIST);
    • 图像预处理工具(如转张量、归一化);
    • 视觉相关的工具函数(如make_grid);
    • 预训练模型(如ResNet、VGG)。
  • 这里主要用它来加载CIFAR10数据集和使用make_grid函数。
from torch.utils.data import DataLoader

  • 作用:从PyTorch的核心数据工具模块中导入DataLoader类——这是PyTorch中批量加载数据的核心工具,能把数据集封装成“可迭代的批量数据生成器”,支持批量读取、打乱数据、多线程加载等功能。
from torch.utils.tensorboard import SummaryWriter

  • 作用:导入TensorBoard的PyTorch封装类SummaryWriter,用于将数据(图片、损失值、参数等)写入日志文件,后续可通过TensorBoard可视化这些数据。
from torchvision.utils import make_grid

  • 作用:从torchvision的工具模块中导入make_grid函数——核心作用是将多张图片张量拼接成一张网格状的图片(比如把4张图片拼成2行2列的网格),解决SummaryWriter.add_image只能显示单张图片的问题。

模块2:加载CIFAR10测试数据集

test_data = torchvision.datasets.CIFAR10(
    root='./dataset',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

  • 这行代码是实例化CIFAR10数据集类,本质是告诉程序“去哪里找数据、加载哪部分数据、数据要怎么处理”,逐个参数解释:
    1. root='./dataset':数据集的保存路径(相对路径,保存在当前代码所在文件夹的dataset子文件夹中);
    2. train=False:指定加载测试集(如果设为True则加载训练集);
    3. transform=torchvision.transforms.ToTensor():对图片做预处理——将原本的PIL格式图片(像素值0-255)转为PyTorch张量(Tensor),同时把像素值归一化到0-1范围,张量形状变为[通道数, 高度, 宽度](CIFAR10图片是3通道,所以形状是[3,32,32]);
    4. download=True:如果./dataset文件夹中没有CIFAR10数据集,自动从官网下载(避免本地无数据导致报错)。
  • 最终test_data是一个Dataset类型的对象,支持通过索引访问单条数据(如test_data[0]取第一张图片)。

模块3:用DataLoader封装数据集

test_loader = DataLoader(
    dataset=test_data,
    batch_size=4,
    shuffle=True,
    num_workers=0,
    drop_last=False
)

  • 这行代码是将Dataset封装为DataLoader,本质是把“单条数据的集合”变成“批量数据的生成器”,逐个参数解释:
    1. dataset=test_data:传入要封装的数据集(即上面加载的CIFAR10测试集);
    2. batch_size=4:每个批次包含4张图片(每次读取4条数据);
    3. shuffle=True:遍历数据时,先打乱数据顺序再分批次(测试集设为True仅为演示,实际测试时建议设为False,保证结果可复现);
    4. num_workers=0:数据加载的线程数(Windows系统建议设为0,避免多线程冲突报错;Linux/Mac可设为4/8,加快加载速度);
    5. drop_last=False:如果最后一个批次的样本数不足batch_size(比如总数据数不是4的倍数),不丢弃这个批次(设为True则丢弃)。
  • 最终test_loader是一个可迭代的对象,遍历它会返回(批量图片张量, 批量标签张量),比如for data in test_loader每次取4张图片+4个标签。

模块4:查看单张图片信息(验证数据)

img, target = test_data[0]

  • 作用:通过索引0读取test_data中的第一张图片Dataset对象的索引访问会返回(图片张量, 标签)的元组,这里用img接收图片张量,target接收标签。
  • CIFAR10的标签是0-9的整数,对应10类物体:0=飞机、1=汽车、2=鸟、3=猫、4=鹿、5=狗、6=青蛙、7=马、8=船、9=卡车。
print("第一张图片的形状:", img.shape)

  • 作用:打印第一张图片张量的形状,输出是torch.Size([3, 32, 32]),含义:
    • 3:图片的通道数(RGB彩色图片,3个通道);
    • 32:图片的高度(像素);
    • 32:图片的宽度(像素)——这是CIFAR10图片的固定尺寸。
print("第一张图片的标签:", target)

  • 作用:打印第一张图片对应的标签(0-9的整数),比如输出3就代表这张图片是“猫”。

模块5:TensorBoard可视化批量图片(核心可视化逻辑)

writer = SummaryWriter("dataloader")

  • 作用:创建SummaryWriter对象,指定日志文件保存到当前文件夹的dataloader子文件夹中——后续所有可视化数据都会写入这个文件夹,供TensorBoard读取。
step = 0

  • 作用:定义一个步数计数器step,用于TensorBoard中区分不同批次的图片(每显示一个批次,步数+1,避免图片覆盖)。
for data in test_loader:
    imgs, targets = data

  • 作用:遍历test_loader,每次循环取一个批次的数:
    • data是一个元组,包含(批量图片张量, 批量标签张量)
    • imgs接收批量图片张量,形状是[4, 3, 32, 32](4张图片,每张3通道、32x32);
    • targets接收批量标签张量,形状是[4](4个标签,每个是0-9的整数)。
img_grid = make_grid(imgs, nrow=2)

  • 核心作用:用make_grid将4张图片拼接成一张网格图:
    • imgs:传入批量图片张量([4,3,32,32]);
    • nrow=2:指定网格的行数/列数——每行显示2张图片,所以最终是2行2列的网格;
    • 输出img_grid是一个单张图片张量,形状为[3, 32*2 + 间隔, 32*2 + 间隔](拼接后的网格图,可直接用add_image显示)。
  • 为什么需要这一步?因为writer.add_image默认只支持显示单张图片(形状[C,H,W]),直接传入批量图片([4,3,32,32])会报维度错误,make_grid是解决这个问题的关键。
writer.add_image("test_data_batch", img_grid, step)

  • 作用:将拼接后的网格图写入TensorBoard日志:
    • "test_data_batch":可视化图片的名称(在TensorBoard中会显示这个名称);
    • img_grid:要显示的图片张量(拼接后的单张网格图);
    • step:步数,用于区分不同批次的图片(比如step=0是第1批,step=1是第2批,TensorBoard中可通过滑动条切换)。
step += 1

  • 作用:每遍历一个批次,步数+1,保证下一个批次的图片有唯一的步数标识。
writer.close()

  • 作用:遍历结束后关闭SummaryWriter,确保所有日志数据都写入文件(必须放在循环外,如果放在循环内,第一次循环就会关闭写入器,后续批次的图片无法写入)。
# 额外提示:运行后在终端执行 tensorboard --logdir=dataloader,打开浏览器访问http://localhost:6006查看可视化结果

  • 作用:告诉用户如何查看TensorBoard可视化结果:
    1. 运行代码后,在终端(命令行)执行tensorboard --logdir=dataloader(指定日志文件夹);
    2. 终端会输出一个网址(通常是http://localhost:6006);
    3. 打开浏览器访问这个网址,在「Images」标签下就能看到拼接后的批量图片网格。

三、代码运行流程总结

  1. 程序先导入所需的库,准备好工具;
  2. 加载CIFAR10测试集,将图片转为张量格式(如果本地没有数据则自动下载);
  3. 用DataLoader把数据集封装成“每次返回4张图片”的批量数据生成器;
  4. 读取第一张图片,打印它的形状和标签,验证数据加载正确;
  5. 创建TensorBoard写入器,遍历DataLoader的每个批次:
    • 把4张图片拼接成2x2的网格图;
    • 将网格图写入TensorBoard日志,用step区分不同批次;
  6. 遍历结束后关闭写入器,用户可通过TensorBoard查看所有批次的图片。

四、关键细节补充

  1. 张量形状的变化
    • 单张图片:[3, 32, 32](C, H, W);
    • 批量图片:[4, 3, 32, 32](batch_size, C, H, W);
    • 拼接后网格图:[3, 68, 68](3通道,高度=322+4(间隔),宽度=322+4(间隔))。
  2. shuffle=True的作用:每次运行代码,遍历DataLoader得到的图片顺序都不同(打乱数据),但单张图片的内容不变。
  3. download=True的注意事项:第一次运行代码会下载CIFAR10数据集(约17MB),后续运行会直接读取本地数据,无需重复下载。
  4. TensorBoard的使用前提:需提前安装TensorBoard(pip install tensorboard),否则无法运行。

总结

  1. 这份代码的核心是DataLoader的使用TensorBoard可视化批量图片,前者实现数据的批量加载,后者通过make_grid拼接图片解决批量可视化的问题;
  2. 关键修正点:避免覆盖print函数、循环外关闭SummaryWriter、用make_grid拼接批量图片;
  3. 代码的最终效果:既能验证数据加载的正确性,又能直观看到DataLoader输出的批量图片,是深度学习中“数据加载+可视化”的基础模板。
  4. 训练查看不同数据时记得把

writer.add_image("test_data_batch", img_grid, step)这一句里的标签给改了,然后把tensorboard更新,以便查看不同训练的显示

1.什么是神经网络 Containers?

1. 本质定义

Containers 是 PyTorch torch.nn 模块下的容器类,核心作用是:

  • 把多个神经网络层(如 nn.Conv2dnn.Linearnn.ReLU)“打包”成一个整体;
  • 统一管理层的参数(如权重、偏置)、设备(CPU/GPU)、训练/评估模式;
  • 支持自动前向传播、参数优化、模型保存/加载等核心功能。

2. 为什么需要 Containers?

如果没有容器,你需要手动写每一层的前向传播(比如 x = nn.Conv2d()(x); x = nn.ReLU()(x)),且无法统一管理参数——Containers 让网络构建从“零散代码”变成“模块化结构”,大幅降低复杂度。

3. 核心基类:nn.Module

所有 Containers 都继承自 nn.Module(这是 PyTorch 中所有神经网络模块的基类),因此具备以下核心能力:

  • parameters():返回所有可训练参数(供优化器更新);
  • to(device):把整个模型移到 CPU/GPU;
  • train()/eval():切换训练/评估模式(影响 Dropout、BN 等层);
  • state_dict()/load_state_dict():保存/加载模型参数。

搭建我的第一个简单的神经网络

import torch
from torch import nn

class Tudui(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
            output = input+1
            return output
        #以上神经网络模板定义完了

#创建神经网络
tudui = Tudui()
x=torch.tensor(1.0)
#把X输入放在神经网络tudui当中
output=tudui(x)
print(output)
#输出tensor(2.)

2、PyTorch 中最常用的 4 种 Containers

1. 基础容器:nn.Sequential(顺序容器)

核心特点

  • 顺序堆叠神经网络层,前一层的输出直接作为后一层的输入;
  • 最简单、最常用的容器,适合构建“线性流程”的网络(如简单 CNN、MLP)。

实战用法(以 CIFAR10 分类的简单 CNN 为例)

import torch
import torch.nn as nn

# 用 Sequential 构建简单 CNN
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()#调用初始化函数
        # 卷积层:顺序堆叠(Conv2d → ReLU → MaxPool2d → Flatten → Linear)
        self.features = nn.Sequential(
            # 第一层卷积:3通道→16通道,3x3卷积,步长1,填充1
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),  # 激活函数
            nn.MaxPool2d(kernel_size=2, stride=2),  # 池化:32x32→16x16

            # 第二层卷积:16通道→32通道
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 16x16→8x8

            # 展平:32*8*8 → 2048
            nn.Flatten()
        )

        # 全连接层:顺序堆叠(Linear → ReLU → Linear)
        self.classifier = nn.Sequential(
            nn.Linear(32 * 8 * 8, 512),  # 2048 → 512
            nn.ReLU(),
            nn.Linear(512, 10)  # 512 → 10(CIFAR10 10类)
        )

    def forward(self, x):#x是你自己定义的
        # 前向传播:按 Sequential 顺序执行
        x = self.features(x)
        x = self.classifier(x)
        return x

# 测试 Sequential 容器
model = SimpleCNN()
# 模拟输入:batch_size=4,3通道,32x32(CIFAR10 尺寸)
input_tensor = torch.randn(4, 3, 32, 32)
output = model(input_tensor)
print("输出形状:", output.shape)  # 输出 (4, 10) → 4个样本,10类预测值
print("模型可训练参数总数:", sum(p.numel() for p in model.parameters() if p.requires_grad))

关键

  • nn.Sequential 接收层的列表/字典,按传入顺序执行;
  • 无需手动写前向传播的每一步(比如不用写 x = self.conv1(x)),只需调用容器即可;
  • 适合结构简单、无分支的网络(如 LeNet、简单 MLP)。
  • 输入x 卷积conv1(x) 非线性F.relu 卷积 非线性 输出 比如
def forward(self.x)
x=F.relu(self.conv1(x))
return F.relu(self.conv2(x)

#forward(*input)计算图-在每个子类中都应该重写

2. 灵活容器:nn.Modulelist(模块列表)

核心特点

  • 把多个层/模块存成列表,支持索引访问、动态添加/删除层;
  • 无自动前向传播,需手动写前向逻辑,适合动态调整层数的场景(如可变层数的 CNN)。

实战用法(动态构建多层卷积)

class DynamicCNN(nn.Module):
    def __init__(self, in_channels=3, out_channels_list=[16, 32, 64]):
        super(DynamicCNN, self).__init__()
        # 用 ModuleList 存储多个卷积层(动态层数)
        self.conv_layers = nn.ModuleList()
        for out_ch in out_channels_list:
            # 逐个添加卷积层(Conv2d + ReLU)
            self.conv_layers.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_ch, 3, 1, 1),
                    nn.ReLU()
                )
            )
            in_channels = out_ch  # 更新输入通道数

        # 池化层
        self.pool = nn.MaxPool2d(2, 2)
        # 全连接层
        self.fc = nn.Linear(64 * 4 * 4, 10)  # 64通道,4x4特征图(32→16→8→4)

    def forward(self, x):
        # 手动遍历 ModuleList 执行前向传播
        for conv in self.conv_layers:
            x = conv(x)
            x = self.pool(x)

        x = torch.flatten(x, 1)  # 展平
        x = self.fc(x)
        return x

# 测试 ModuleList
dynamic_model = DynamicCNN(out_channels_list=[16, 32, 64])
input_tensor = torch.randn(4, 3, 32, 32)
output = dynamic_model(input_tensor)
print("DynamicCNN 输出形状:", output.shape)  # (4, 10)
# 访问 ModuleList 中的第二层
print("第二层卷积:", dynamic_model.conv_layers[1])

关键

  • nn.ModuleList 像普通 Python 列表,但会被 PyTorch 识别为“模型的一部分”(参数会被管理);
  • 必须手动遍历 ModuleList 执行前向传播(无自动顺序);
  • 适合需要动态调整层数的场景(如根据配置文件设置卷积层数)。

3. 键值对容器:nn.ModuleDict(模块字典)

核心特点

  • 把多个层/模块存成键值对(字典),按名称访问层;
  • 无自动前向传播,适合可选分支的网络(如多路径特征提取)。

实战用法(多分支特征提取)

class BranchCNN(nn.Module):
    def __init__(self):
        super(BranchCNN, self).__init__()
        # 用 ModuleDict 存储不同分支的卷积层(按名称访问)
        self.branch_layers = nn.ModuleDict({
            "small_kernel": nn.Conv2d(3, 16, kernel_size=3, padding=1),
            "large_kernel": nn.Conv2d(3, 16, kernel_size=5, padding=2),
            "relu": nn.ReLU()
        })
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(16 * 16 * 16, 10)

    def forward(self, x, branch_name="small_kernel"):
        # 按名称选择分支(灵活切换)
        x = self.branch_layers[branch_name](x)
        x = self.branch_layers["relu"](x)  # 共用ReLU层
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

# 测试 ModuleDict
branch_model = BranchCNN()
input_tensor = torch.randn(4, 3, 32, 32)
# 使用小核分支
output_small = branch_model(input_tensor, branch_name="small_kernel")
# 使用大核分支
output_large = branch_model(input_tensor, branch_name="large_kernel")
print("小核分支输出形状:", output_small.shape)  # (4, 10)
print("大核分支输出形状:", output_large.shape)  # (4, 10)

关键

  • nn.ModuleDict 用字符串键映射层,适合需要“动态选择分支”的场景;
  • 可动态添加层:branch_model.branch_layers["new_branch"] = nn.Conv2d(...)
  • 常用于构建多尺度特征提取的网络(如 ResNet 的分支、YOLO 的多尺度检测)。

4. 自定义容器:继承 nn.Module(最灵活)

核心特点

  • 所有容器的“底层”,Sequential/ModuleList 都继承自 nn.Module
  • 完全自定义前向传播逻辑,适合复杂结构的网络(如 ResNet 的残差块、Transformer 的注意力层)。

实战用法(残差块,ResNet 核心)

class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResBlock, self).__init__()
        # 主分支:Conv → BN → ReLU → Conv → BN
        self.main_branch = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels)
        )
        # 捷径分支(维度匹配)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1, stride)

    def forward(self, x):
        # 自定义前向:主分支 + 捷径分支(残差连接)
        out = self.main_branch(x)
        out += self.shortcut(x)  # 残差相加(核心逻辑)
        out = nn.ReLU()(out)
        return out

# 构建带残差块的网络
class SimpleResNet(nn.Module):
    def __init__(self):
        super(SimpleResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1, 1)
        self.res_block1 = ResBlock(16, 16)  # 自定义容器作为子模块
        self.res_block2 = ResBlock(16, 32, stride=2)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.res_block1(x)  # 调用自定义残差块
        x = self.res_block2(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

# 测试自定义容器
res_model = SimpleResNet()
input_tensor = torch.randn(4, 3, 32, 32)
output = res_model(input_tensor)
print("ResNet 输出形状:", output.shape)  # (4, 10)

关键

  • 自定义容器是构建复杂网络的核心(如 ResNet、Transformer、GPT 等);
  • 只需继承 nn.Module,在 __init__ 中定义子模块,在 forward 中写自定义逻辑;
  • 子模块(如 res_block1)会被自动管理(参数、设备、训练模式)。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐