第八章:PyTorch生态简介 学习笔记
datawhale ai 共学。
第八章:PyTorch生态简介 — 深入浅出PyTorch datawhale ai 共学
datasets / transforms / model-zoo / audio & text
8.1 生态概览
领域 | 官方包 | 用途 |
---|---|---|
CV | torchvision |
图像/视频数据集、数据增强、预训练模型 |
NLP | torchtext |
文本数据处理、词表、常见数据集 |
Audio | torchaudio |
音频 I/O、特征、datasets、pipelines |
Video | pytorchvideo |
SOTA 视频模型 & pipeline (Meta) |
Graphs | torch_geometric |
GNN 旗舰库 (生态外) |
8.2 torchvision
8.2.1 datasets
from torchvision import datasets, transforms
train_ds = datasets.CIFAR10(root='./data',
train=True,
download=True,
transform=transforms.ToTensor())
-
-
root
:本地缓存目录。 -
download=True
:如不存在本地文件则从 mirror 自动拉取 -
transform
:对 单张 样本调用,不对 batch 起效
-
症状 | 原因 | 修复 |
---|---|---|
下载卡死 0B/s | 默认站点被墙 | TORCHVISION_DOWNLOAD_MIRRORS=http://download.pytorch.org 或换清华镜像 |
爆内存 | FakeData 默认 1000×3×224×224 | 指定 size= & image_size= |
8.2.2 transforms
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
img = Image.open("./figures/lenna.jpg") # ① 读入 PIL.Image
trans = transforms.Compose([ # ② pipeline
transforms.Resize(256), # 等比 ⇒ 短边 256
transforms.CenterCrop(224), # 中心裁 224²
transforms.ColorJitter(0.4,0.4,0.4,0.1), # 随机亮度/对比度/饱和/色调
transforms.ToTensor(), # [0,255] PIL ⇒ [0,1] Tensor
transforms.Normalize( # 标准化(Imagenet 统计值)
mean=[0.485,0.456,0.406],
std =[0.229,0.224,0.225])
])
t_img = trans(img) # ③ 调用
报错 | 含义 | 解决 |
---|---|---|
TypeError: pic should be PIL Image |
在 ToTensor 前用了 plt.imread |
保持 PIL 或改用 transforms.ToPILImage() |
图像全黑 | ColorJitter 参数太大 |
控制在 0–0.5 范围 |
拓展
-
RandAug / TrivialAug 已内置于
torchvision.transforms.v2
(≥0.15) -
批量增强 用
torchvision.transforms.functional
+vmap
/for
,Compose 仅对 sample
8.2.3 models
import torchvision.models as models
net = models.resnet18(weights='IMAGENET1K_V1') # ① 加载权重
net.fc = torch.nn.Linear(net.fc.in_features, 4) # ② 微调到 4 类
-
weights
新接口 (0.13+);旧pretrained=True
已 deprecated -
微调时记得
for p in net.parameters(): p.requires_grad = False
再解冻fc
8.3 PyTorchVideo
1. Hub 调用
import torch
model = torch.hub.load('facebookresearch/pytorchvideo',
model='slowfast_r50',
pretrained=True)
model.eval()
2. 输入尺寸
(B, C=3, T, H, W)
,默认 32 x 3 clip
clip = torch.randn(1, 3, 32, 224, 224)
with torch.no_grad(): logits = model(clip)
常见错误RuntimeError: Given groups=1, weight ...
→ 维度顺序,把 (B,C,H,W,T)
写反
8.4 torchtext
8.4.1 快速流水线
from torchtext.data import Field, BucketIterator
from torchtext.datasets import IMDB
token = lambda x: x.split()
TEXT = Field(tokenize=token, lower=True, batch_first=True)
LABEL = Field(sequential=False, unk_token=None)
train_ds, test_ds = IMDB.splits(TEXT, LABEL)
TEXT.build_vocab(train_ds, max_size=20000, vectors='glove.6B.100d')
LABEL.build_vocab(train_ds)
train_iter, test_iter = BucketIterator.splits(
(train_ds, test_ds), batch_size=32, sort_key=lambda x: len(x.text))
-
踩坑:新版 torchtext (0.12+) 移除了上面老接口,需要用
torchtext.legacy
或新 API (torchtext.data.functional
,torchtext.datasets.IMDB(split='train')
)
8.5 torchaudio
8.5.1 基础 I/O
import torchaudio
wave, sr = torchaudio.load('speech.wav') # (channel, time)
resample = torchaudio.transforms.Resample(sr, 16_000)
mel_spect = torchaudio.transforms.MelSpectrogram(16_000)
feat = mel_spect(resample(wave)) # (C, n_mels, frames)
8.5.2 预训练 ASR
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
asr = bundle.get_model().eval()
tokens = asr(wave.unsqueeze(0)) # logits
需求 | 方案 |
---|---|
移动端 on-device 推理 | torchaudio.models.emformer_rnnt + TorchScript |
多 说话人 | Asteroid , SpeechBrain 社区项目 |
常见跨包报错速查表
Error | 触发场景 | 修正 |
---|---|---|
libsox not found |
torchaudio I/O |
conda install -c conda-forge sox |
No module named 'pathlib' |
老 Python | >=3.7 |
TypeError: expected scalar type Double but found Float |
torchaudio + GPU |
.to(dtype=torch.float32) |
Checklist
-
版本对齐:
torch == torchvision == torchaudio
major/minor 相同 -
数据 Aug:影像
transforms.v2
,音频torchaudio.sox_effects.apply_effects_tensor
,文本nlpaug
。 -
模型 Zoo:Torch-Hub(CV/Video)+
torchaudio.pipelines
(Audio)+transformers
(NLP) -
可视化:
torchinfo
(结构)+ TensorBoard / wandb(指标) -
部署:移动端 → PyTorchVideo
accelerator
; 服务器 →torch.compile
≥2.0。
更多推荐
所有评论(0)