一、回顾上节课:你已经认识了“食材”

在上一篇文章中,我们学会了:

✅ 张量就是“装数字的箱子”

  • 0维(一个数)、1维(一串数)、2维(表格)、3维+(多维数据)

  • 就像是米、面、油、菜等各种基础食材

✅ 创建和操作张量

  • torch.tensor().shape.view()、广播机制等

  • 就像认识各种食材的特性,知道怎么存放它们

一句话总结上节:你认识了AI厨房里的所有基础食材!


二、今日目标:学习“切菜”——数据加载与预处理

🎯 学习路线图更新

还记得我们的“做菜式”学习路径吗?今天来到第二步:

今日核心:学习如何把“买回来的菜”(原始数据)变成“可以直接下锅的食材”(模型可用的张量)。


三、为什么需要数据处理?

现实世界的“菜”长什么样?

想象一下:如果你要教AI认识猫和狗

原始数据(刚买回来的“菜”)

  • 图片大小不一:有的1920×1080,有的640×480

  • 格式混乱:.jpg、.png、.bmp都有

  • 方向不对:有的横着拍,有的竖着拍

  • 亮度差异:有的在阳光下,有的在阴影里

模型需要的“标准食材”

  • 统一尺寸:比如都变成224×224

  • 统一格式:都转为RGB三通道张量

  • 数值标准化:像素值从0-255变成0-1

  • 批量整齐:一次处理32张,形状都是[32, 3, 224, 224]

这就是数据处理要做的事情!


四、数据处理三大步:洗、切、配

第一步:洗菜——数据加载(Dataset)

把数据从硬盘“拿”到内存,并进行初步清理。

python

import torch
from torch.utils.data import Dataset
from PIL import Image
import os

# 自定义数据集类(就像定制菜篮)
class 猫狗数据集(Dataset):
    def __init__(self, 图片文件夹路径, 变换=None):
        """
        参数:
            图片文件夹路径:存放猫狗图片的文件夹
            变换:要对图片做的处理(比如调整大小)
        """
        self.所有图片路径 = []
        self.标签 = []  # 0=猫,1=狗
        self.变换 = 变换
        
        # 遍历文件夹,收集所有图片路径和标签
        for 文件名 in os.listdir(图片文件夹路径):
            文件路径 = os.path.join(图片文件夹路径, 文件名)
            
            if 文件名.startswith('cat'):
                self.所有图片路径.append(文件路径)
                self.标签.append(0)  # 猫标签为0
            elif 文件名.startswith('dog'):
                self.所有图片路径.append(文件路径)
                self.标签.append(1)  # 狗标签为1
    
    def __len__(self):
        """返回数据集大小(有多少张图片)"""
        return len(self.所有图片路径)
    
    def __getitem__(self, 索引):
        """
        获取第n个样本
        返回:(图片张量, 标签)
        """
        # 1. 加载图片(从硬盘读到内存)
        图片路径 = self.所有图片路径[索引]
        图片 = Image.open(图片路径).convert('RGB')  # 确保是RGB格式
        
        # 2. 应用变换(如果有的话)
        if self.变换:
            图片 = self.变换(图片)
        
        # 3. 获取标签
        标签 = self.标签[索引]
        
        return 图片, 标签

# 使用示例
print("📂 创建数据集对象...")
我的数据集 = 猫狗数据集(图片文件夹路径='./data/train')

print(f"数据集大小: {len(我的数据集)} 张图片")
print(f"第一张图片: {我的数据集[0]}")

第二步:切菜——数据变换(Transform)

把图片变成模型能吃懂的“标准形状”。

python

import torchvision.transforms as transforms

# 定义一系列“切菜刀法”
数据预处理 = transforms.Compose([
    # 1. 调整大小(把大菜叶切成合适大小)
    transforms.Resize((256, 256)),
    
    # 2. 随机裁剪(从中间切出224×224的正方形)
    transforms.RandomCrop((224, 224)),
    
    # 3. 随机水平翻转(增加数据多样性)
    transforms.RandomHorizontalFlip(p=0.5),  # 50%概率翻转
    
    # 4. 转为张量(从PIL图片变成PyTorch张量)
    transforms.ToTensor(),
    
    # 5. 标准化(让数值分布更稳定)
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet数据集的均值
        std=[0.229, 0.224, 0.225]    # ImageNet数据集的标准差
    )
])

# 更简单的变换(初学者先用这个)
简单变换 = transforms.Compose([
    transforms.Resize((224, 224)),  # 统一大小
    transforms.ToTensor(),          # 转为张量
    # 标准化:把0-255的像素值变成0-1,再归一化
])

# 应用到数据集
数据集带变换 = 猫狗数据集(
    图片文件夹路径='./data/train',
    变换=简单变换
)

# 查看处理后的图片
图片张量, 标签 = 数据集带变换[0]
print(f"\n处理后图片形状: {图片张量.shape}")
print(f"像素值范围: [{图片张量.min():.3f}, {图片张量.max():.3f}]")
print(f"标签: {'猫' if 标签==0 else '狗'}")

# 解释shape的含义:
# [3, 224, 224] = [颜色通道数, 高度, 宽度]
# 3个通道:红(Red)、绿(Green)、蓝(Blue)

第三步:配菜——批量加载(DataLoader)

把处理好的数据分成小份,方便“下锅炒菜”。

python

from torch.utils.data import DataLoader

# 创建数据加载器(就像准备多个小菜篮)
训练数据加载器 = DataLoader(
    dataset=数据集带变换,  # 使用哪个数据集
    batch_size=32,         # 每个小菜篮放32张图片
    shuffle=True,          # 每次打乱顺序(像洗牌)
    num_workers=2          # 用2个"帮手"并行加载数据
)

# 使用示例:遍历一个批次的数据
print("\n🍽️ 准备上菜(批量加载数据)...")
for 批次索引, (图片批次, 标签批次) in enumerate(训练数据加载器):
    print(f"\n第{批次索引+1}批菜:")
    print(f"  图片形状: {图片批次.shape}")  # [32, 3, 224, 224]
    print(f"  标签形状: {标签批次.shape}")  # [32]
    print(f"  标签内容: {标签批次[:5]}")    # 前5个标签
    
    # 通常只查看前几个批次
    if 批次索引 == 2:
        break

# 高级技巧:不同阶段用不同数据
训练集 = 猫狗数据集('./data/train', 变换=训练变换)
测试集 = 猫狗数据集('./data/test', 变换=测试变换)

训练加载器 = DataLoader(训练集, batch_size=32, shuffle=True)
测试加载器 = DataLoader(测试集, batch_size=32, shuffle=False)  # 测试时不打乱

五、完整实战:猫狗识别数据准备

让我们把所有步骤组合起来:

python

import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os

# ========== 1. 定义数据变换 ==========
训练变换 = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], 
                        [0.229, 0.224, 0.225])
])

测试变换 = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),  # 测试时从中心裁剪
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                        [0.229, 0.224, 0.225])
])

# ========== 2. 创建数据集 ==========
class 猫狗识别数据集(Dataset):
    def __init__(self, 根目录, 模式='train', 变换=None):
        self.文件列表 = []
        self.标签列表 = []
        self.变换 = 变换
        
        文件夹路径 = os.path.join(根目录, 模式)
        
        # 假设目录结构:
        # data/
        #   train/
        #       cat.001.jpg
        #       dog.001.jpg
        #   test/
        #       cat.002.jpg
        #       dog.002.jpg
        
        for 文件名 in os.listdir(文件夹路径):
            if 文件名.endswith(('.jpg', '.png', '.jpeg')):
                完整路径 = os.path.join(文件夹路径, 文件名)
                self.文件列表.append(完整路径)
                
                # 从文件名判断是猫还是狗
                if 'cat' in 文件名.lower():
                    self.标签列表.append(0)
                else:
                    self.标签列表.append(1)
    
    def __len__(self):
        return len(self.文件列表)
    
    def __getitem__(self, 索引):
        图片 = Image.open(self.文件列表[索引]).convert('RGB')
        标签 = self.标签列表[索引]
        
        if self.变换:
            图片 = self.变换(图片)
            
        return 图片, 标签

# ========== 3. 创建数据加载器 ==========
# 假设你的数据放在当前目录的data文件夹中
数据集 = 猫狗识别数据集(根目录='./data', 模式='train', 变换=训练变换)
数据加载器 = DataLoader(数据集, batch_size=16, shuffle=True, num_workers=4)

# ========== 4. 验证数据准备是否成功 ==========
print("=" * 50)
print("猫狗识别数据准备系统")
print("=" * 50)

print(f"\n📊 数据集统计:")
print(f"  总图片数: {len(数据集)}")
print(f"  批次大小: 16")
print(f"  总批次数: {len(数据加载器)}")

# 获取一个批次看看
图片批次, 标签批次 = next(iter(数据加载器))
print(f"\n🔍 第一个批次详情:")
print(f"  图片形状: {图片批次.shape}")
print(f"  标签形状: {标签批次.shape}")

# 统计猫狗数量
猫数量 = (标签批次 == 0).sum().item()
狗数量 = (标签批次 == 1).sum().item()
print(f"  本批次猫: {猫数量}张, 狗: {狗数量}张")

print("\n✅ 数据准备完成!可以开始训练模型了!")

六、常见问题与技巧

Q1:我的数据不在本地,怎么加载?

python

# 从网上下载标准数据集(最简单的方式)
from torchvision import datasets

# 下载MNIST手写数字数据集
mnist数据集 = datasets.MNIST(
    root='./data',          # 下载到哪里
    train=True,             # 训练集
    download=True,          # 如果本地没有就下载
    transform=transforms.ToTensor()
)

# 下载CIFAR-10数据集(10类物体)
cifar数据集 = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

print(f"MNIST大小: {len(mnist数据集)}")
print(f"CIFAR-10大小: {len(cifar数据集)}")

Q2:标签不是0/1,而是文字怎么办?

python

# 方法1:使用标签映射字典
类别到编号 = {'cat': 0, 'dog': 1, 'bird': 2}
编号到类别 = {0: '猫', 1: '狗', 2: '鸟'}

# 方法2:自动生成映射
类别列表 = ['cat', 'dog', 'bird', 'fish']
类别到编号 = {类别: 编号 for 编号, 类别 in enumerate(类别列表)}

Q3:数据太大,内存放不下怎么办?

python

# 使用迭代器,一次只加载一部分
class 大数据集(Dataset):
    def __init__(self, 文件列表路径):
        # 只保存文件路径,不保存图片数据
        with open(文件列表路径, 'r') as f:
            self.文件路径列表 = f.read().splitlines()
    
    def __getitem__(self, 索引):
        # 每次需要时才从硬盘读取
        图片路径 = self.文件路径列表[索引]
        return Image.open(图片路径)  # 现用现读

Q4:如何查看处理后的图片?

python

import matplotlib.pyplot as plt
import numpy as np

def 显示图片(图片张量, 标题=''):
    """
    将张量格式的图片显示出来
    图片张量形状:[C, H, W] 或 [B, C, H, W]
    """
    # 如果是批量数据,取第一张
    if len(图片张量.shape) == 4:
        图片张量 = 图片张量[0]
    
    # 将张量转为numpy,并调整通道顺序
    图片数组 = 图片张量.numpy().transpose(1, 2, 0)
    
    # 反标准化(如果需要)
    图片数组 = 图片数组 * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    图片数组 = np.clip(图片数组, 0, 1)  # 限制在0-1范围内
    
    plt.imshow(图片数组)
    plt.title(标题)
    plt.axis('off')
    plt.show()

# 使用示例
图片, 标签 = 数据集带变换[0]
显示图片(图片, f'标签: {"猫" if 标签==0 else "狗"}')

七、今日总结与任务

✅ 今天你学会了:

1. Dataset类:自定义数据容器

  • 像定制菜篮,告诉PyTorch你的数据在哪里、怎么读

  • 必须实现__len____getitem__两个方法

2. Transform变换:数据预处理流水线

  • Resize:调整大小

  • ToTensor:转为PyTorch张量

  • Normalize:标准化(重要!)

  • 还可以随机裁剪、翻转等增加数据多样性

3. DataLoader:批量数据加载器

  • batch_size:每批多少数据

  • shuffle:是否打乱顺序

  • num_workers:用几个线程加载

🎯 一句话记住数据处理:

Dataset定义数据在哪,Transform定义怎么处理,DataLoader定义怎么批量取用。

📝 你的实践任务:

python

"""
任务:创建一个表情识别数据管道
假设你有3类表情图片:happy(高兴)、sad(悲伤)、angry(生气)

目录结构:
emotion_data/
    train/
        happy_001.jpg
        sad_001.jpg
        angry_001.jpg
        ...
    test/
        ...

要求:
1. 创建自定义Dataset类
2. 训练集使用随机裁剪和翻转
3. 测试集只做中心裁剪
4. 批量大小设为16
5. 显示第一批数据的前3张图片

提示:
- 可以用文件名判断类别(如'happy'开头)
- 图片统一缩放到256×256,再裁剪224×224
"""

# 先自己尝试,参考答案见文末

🚀 下一步学习什么?

学会了“切菜”(数据处理),下次我们将进入:
“学习炒菜”——构建第一个神经网络模型

你将学会:

  • 如何定义网络层(像准备锅碗瓢盆)

  • 前向传播(像下锅翻炒)

  • 损失函数和优化器(像控制火候和调味)


八、数据处理参考答案

python

# 表情识别数据管道参考答案
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import matplotlib.pyplot as plt

class 表情数据集(Dataset):
    def __init__(self, 根目录, 模式='train', 变换=None):
        self.文件路径 = []
        self.标签 = []
        self.类别映射 = {'happy': 0, 'sad': 1, 'angry': 2}
        self.变换 = 变换
        
        文件夹 = os.path.join(根目录, 模式)
        
        for 文件 in os.listdir(文件夹):
            if 文件.endswith(('.jpg', '.png')):
                路径 = os.path.join(文件夹, 文件)
                self.文件路径.append(路径)
                
                # 从文件名提取标签
                for 类别 in self.类别映射:
                    if 类别 in 文件.lower():
                        self.标签.append(self.类别映射[类别])
                        break
    
    def __len__(self):
        return len(self.文件路径)
    
    def __getitem__(self, 索引):
        图片 = Image.open(self.文件路径[索引]).convert('RGB')
        标签 = self.标签[索引]
        
        if self.变换:
            图片 = self.变换(图片)
            
        return 图片, 标签

# 定义变换
训练变换 = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

测试变换 = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# 创建数据集和加载器
数据集 = 表情数据集('./emotion_data', 模式='train', 变换=训练变换)
加载器 = DataLoader(数据集, batch_size=16, shuffle=True)

# 获取并显示第一批数据
图片批次, 标签批次 = next(iter(加载器))
标签名称 = ['高兴', '悲伤', '生气']

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for i in range(3):
    图片 = 图片批次[i].numpy().transpose(1, 2, 0)
    图片 = 图片 * 0.5 + 0.5  # 反标准化
    图片 = np.clip(图片, 0, 1)
    
    axes[i].imshow(图片)
    axes[i].set_title(f'表情: {标签名称[标签批次[i].item()]}')
    axes[i].axis('off')

plt.tight_layout()
plt.show()

print(f"✅ 数据管道创建成功!")
print(f"   批次形状: {图片批次.shape}")
print(f"   标签: {标签批次[:3].tolist()}")

记住:数据处理是AI项目中最耗时但最重要的环节。一个好的数据管道,相当于备好了新鲜、干净的食材,后续烹饪(模型训练)才会顺利!

现在,你的“AI厨房”里已经备好了处理好的食材,下次我们就可以开始真正的“烹饪”了——构建和训练神经网络!

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐