第八章:PyTorch生态简介 — 深入浅出PyTorch   datawhale ai 共学

datasets / transforms / model-zoo / audio & text


8.1 生态概览

领域 官方包 用途
CV torchvision 图像/视频数据集、数据增强、预训练模型
NLP torchtext 文本数据处理、词表、常见数据集
Audio torchaudio 音频 I/O、特征、datasets、pipelines
Video pytorchvideo SOTA 视频模型 & pipeline (Meta)
Graphs torch_geometric GNN 旗舰库 (生态外)

8.2 torchvision

8.2.1 datasets 

from torchvision import datasets, transforms
train_ds = datasets.CIFAR10(root='./data',
                            train=True,
                            download=True,
                            transform=transforms.ToTensor())
    1. root:本地缓存目录。

    2. download=True:如不存在本地文件则从 mirror 自动拉取

    3. transform:对 单张 样本调用,不对 batch 起效

症状 原因 修复
下载卡死 0B/s 默认站点被墙 TORCHVISION_DOWNLOAD_MIRRORS=http://download.pytorch.org 或换清华镜像
爆内存 FakeData 默认 1000×3×224×224 指定 size= & image_size=

8.2.2 transforms 

from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

img = Image.open("./figures/lenna.jpg")      # ① 读入 PIL.Image
trans = transforms.Compose([                 # ② pipeline
    transforms.Resize(256),                  # 等比 ⇒ 短边 256
    transforms.CenterCrop(224),              # 中心裁 224²
    transforms.ColorJitter(0.4,0.4,0.4,0.1), # 随机亮度/对比度/饱和/色调
    transforms.ToTensor(),                   # [0,255] PIL ⇒ [0,1] Tensor
    transforms.Normalize(                    # 标准化(Imagenet 统计值)
        mean=[0.485,0.456,0.406],
        std =[0.229,0.224,0.225])
])

t_img = trans(img)                           # ③ 调用
报错 含义 解决
TypeError: pic should be PIL Image ToTensor 前用了 plt.imread 保持 PIL 或改用 transforms.ToPILImage()
图像全黑 ColorJitter 参数太大 控制在 0–0.5 范围

拓展

  • RandAug / TrivialAug 已内置于 torchvision.transforms.v2 (≥0.15)

  • 批量增强torchvision.transforms.functional + vmap / for,Compose 仅对 sample


8.2.3 models 

import torchvision.models as models

net = models.resnet18(weights='IMAGENET1K_V1')   # ① 加载权重
net.fc = torch.nn.Linear(net.fc.in_features, 4)  # ② 微调到 4 类
  • weights 新接口 (0.13+);旧 pretrained=Truedeprecated

  • 微调时记得 for p in net.parameters(): p.requires_grad = False 再解冻 fc


8.3 PyTorchVideo

1. Hub 调用

import torch
model = torch.hub.load('facebookresearch/pytorchvideo',
                       model='slowfast_r50',
                       pretrained=True)
model.eval()

2. 输入尺寸

(B, C=3, T, H, W),默认 32 x 3 clip

clip = torch.randn(1, 3, 32, 224, 224)
with torch.no_grad(): logits = model(clip)

常见错误
RuntimeError: Given groups=1, weight ... → 维度顺序,把 (B,C,H,W,T) 写反


8.4 torchtext

8.4.1 快速流水线

from torchtext.data import Field, BucketIterator
from torchtext.datasets import IMDB

token = lambda x: x.split()
TEXT  = Field(tokenize=token, lower=True, batch_first=True)
LABEL = Field(sequential=False, unk_token=None)

train_ds, test_ds = IMDB.splits(TEXT, LABEL)
TEXT.build_vocab(train_ds, max_size=20000, vectors='glove.6B.100d')
LABEL.build_vocab(train_ds)

train_iter, test_iter = BucketIterator.splits(
        (train_ds, test_ds), batch_size=32, sort_key=lambda x: len(x.text))
  • 踩坑:新版 torchtext (0.12+) 移除了上面老接口,需要用 torchtext.legacy 或新 API (torchtext.data.functional, torchtext.datasets.IMDB(split='train'))


8.5 torchaudio

8.5.1 基础 I/O

import torchaudio
wave, sr = torchaudio.load('speech.wav')        # (channel, time)

resample  = torchaudio.transforms.Resample(sr, 16_000)
mel_spect = torchaudio.transforms.MelSpectrogram(16_000)

feat = mel_spect(resample(wave))                # (C, n_mels, frames)

8.5.2 预训练 ASR

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
asr    = bundle.get_model().eval()
tokens = asr(wave.unsqueeze(0))             # logits
需求 方案
移动端 on-device 推理 torchaudio.models.emformer_rnnt + TorchScript
多 说话人 Asteroid, SpeechBrain 社区项目

常见跨包报错速查表

Error 触发场景 修正
libsox not found torchaudio I/O conda install -c conda-forge sox
No module named 'pathlib' 老 Python >=3.7
TypeError: expected scalar type Double but found Float torchaudio + GPU .to(dtype=torch.float32)

Checklist 

  1. 版本对齐torch == torchvision == torchaudio major/minor 相同

  2. 数据 Aug:影像 transforms.v2,音频 torchaudio.sox_effects.apply_effects_tensor,文本 nlpaug

  3. 模型 Zoo:Torch-Hub(CV/Video)+ torchaudio.pipelines(Audio)+ transformers(NLP)

  4. 可视化torchinfo(结构)+ TensorBoard / wandb(指标)

  5. 部署:移动端 → PyTorchVideo accelerator; 服务器 → torch.compile ≥2.0。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐