⚠️ 关键修复点总结

  | 你的代码              | 问题           | 正确写法                              |
  |-------------------|--------------|-----------------------------------|
  | pd.txt.read()     | pandas没有这个方法 | open(file).read()                 |
  | os.jion.path()    | 拼写错误         | os.path.join()                    |
  | train_text.lines  | 文件对象没有此属性    | f.read().splitlines()             |
  | 取i中的./images/图像id | 伪代码          | parts = line.split() + parts[0]   |
  | "D:\github\..."   | 转义问题         | r"D:\github\..." 或 "D:\\github\\" |

错误示例:

import os

import pandas as pd

# 初始化列表,存放图片和掩码路径

image_list = []

mask_list = []

# 定义base路径

base_dir = "D:\github\pytorch-deeplab-xception-master\data\hanfeng"

# 读取图像和掩码的id

# image_train_test = pd.txt.read(os.jion.path(base_dir, "hanfen","trainval.txt"))

# mask_train_text = pd.txt.read(os.jion.path(base_dir, "hanfen","trainval.txt"))

train_text = pd.txt.read(os.jion.path(base_dir, "Hui_txt","trainval.txt"))

# 将所有的图像和掩码路径存入列表

for i in train_text.lines:

        image = 取i中的./images/图像id

        mask = 取i中的./mask/图像id_mask

        image_list.append(image)

        mask_list.append(mask)

 正确示例:

import os

# 初始化列表,存放图片和掩码路径

image_list = []

mask_list = []

1.# 定义base路径(注意:使用原始字符串或正斜杠)

base_dir = r"D:\github\pytorch-deeplab-xception-master\data\hanfeng"

2.# 构建txt文件的完整路径

txt_path = os.path.join(base_dir, "Hui_txt", "trainval.txt")

3.# 读取txt文件

with open(txt_path, "r", encoding="utf-8") as f:

    train_text = f.readlines()

4.# 遍历每一行,提取图像和掩码路径

for line in train_text:

    # 去除首尾空白字符和换行符

    line = line.strip()

    # 分割字符串(按空格分割)

    paths = line.split()

    if len(paths) == 2:  # 确保每行有两个路径

        image_path = paths[0]  # ./images/0.png

        mask_path = paths[1]   # ./annotations_Hui/0.png

        # 将相对路径转为完整路径

        full_image_path = os.path.join(base_dir, image_path)

        full_mask_path = os.path.join(base_dir, mask_path)

        # 添加到列表

        image_list.append(full_image_path)

        mask_list.append(full_mask_path)

结合dataset类进行撰写:

from __future__ import print_function, division

import os

from PIL import Image

import numpy as np

from torch.utils.data import Dataset

from mypath import Path

from torchvision import transforms

import custom_transforms as tr

class HanfengSegmentation(Dataset):

    """

    Hanfeng 焊缝分割数据集

    """

    NUM_CLASSES = 2  # 根据你的实际类别数调整(背景+焊缝)

    def __init__(self,

                 args,

                 base_dir=r"D:\github\pytorch-deeplab-xception-master\data\hanfeng",

                 split='trainval',

                 ):

        """

        初始化Hanfeng数据集

        :param args: 参数对象(包含base_size, crop_size等)

        :param base_dir: 数据集根目录

        :param split: 数据集划分 ('train', 'val', 'trainval', 'test')

        """

        super().__init__()

        self._base_dir = base_dir

        self.args = args

        if isinstance(split, str):

            self.split = [split]

        else:

            split.sort()

            self.split = split

        # 存储图像和掩码的路径

        self.images = []

        self.masks = []

        # 读取数据集划分文件,给数据集路径txt的路径

        for splt in self.split:

            split_file = os.path.join(self._base_dir, 'Hui_txt', splt + '.txt')

            # 检查文件是否存在

            if not os.path.isfile(split_file):

                raise FileNotFoundError(f"Split file not found: {split_file}")

            # 读取文件内容,splitlines()自动将每行分开并存入列表中

        # f.read():一次性把整个文件内容读成一个大字符串

        # .splitlines():在这个字符串上按换行符切分,返回一个“每行一条”的列表(不包含行尾的 \n)

            with open(split_file, 'r') as f:

                lines = f.read().splitlines()

            # 解析每一行,提取图像和掩码路径

            for line in lines:

                # 每行格式: ./images/0.png ./annotations_Hui/0.png

                # strip() 去掉换行符,split() 按空格切开,["./images/0.png", "./annotations_Hui/0.png"]

                parts = line.strip().split()

                if len(parts) >= 2:

                    # 提取相对路径

                    image_rel_path = parts[0]  # 例如: ./images/0.png

                    mask_rel_path = parts[1]   # 例如: ./annotations_Hui/0.png

                    # 构建完整路径

                    image_path = os.path.join(self._base_dir, image_rel_path.lstrip('./'))  #去掉左侧所有 '.' 和 '/'

                    mask_filename = os.path.basename(mask_rel_path)  # 只取文件名部分 "./annotations_Hui/10.png" → "10.png"

                    mask_path = os.path.join(self._base_dir, 'masks', mask_filename) # 因为实际目录结构中掩码在masks文件夹下

                    # 验证文件是否存在

                    if os.path.isfile(image_path) and os.path.isfile(mask_path):

                        self.images.append(image_path)

                        self.masks.append(mask_path)

                    else:

                        print(f"Warning: File not found - Image: {image_path}, Mask: {mask_path}")

        # 验证数据完整性

        assert len(self.images) == len(self.masks), \

            f"图像数量({len(self.images)})和掩码数量({len(self.masks)})不匹配!"

        # 显示统计信息

        print(f'Hanfeng数据集 - {split} 划分: 共 {len(self.images)} 张图像')

    def __len__(self):

        """返回数据集大小"""

        return len(self.images)

    def __getitem__(self, index):

        """

        获取单个样本

        :param index: 样本索引

        :return: 包含image和label的字典

        """

        _img, _target = self._make_img_gt_pair(index)

        sample = {'image': _img, 'label': _target}

        # 根据数据集划分应用不同的数据变换

        for split in self.split:

            if split in ["train", "trainval"]:

                return self.transform_tr(sample)

            elif split in ['val', 'test']:

                return self.transform_val(sample)

    def _make_img_gt_pair(self, index):

        """

        加载图像和掩码

        :param index: 样本索引

        :return: PIL图像对象

        """

        _img = Image.open(self.images[index]).convert('RGB')

        _target = Image.open(self.masks[index])

        return _img, _target

    def transform_tr(self, sample):

        """训练集数据增强"""

        composed_transforms = transforms.Compose([

            tr.RandomHorizontalFlip(),

            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),

            tr.RandomGaussianBlur(),

            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),

            tr.ToTensor()])

        return composed_transforms(sample)

    def transform_val(self, sample):

        """验证/测试集数据变换"""

        composed_transforms = transforms.Compose([

            tr.FixScaleCrop(crop_size=self.args.crop_size),

            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),

            tr.ToTensor()])

        return composed_transforms(sample)

    def __str__(self):

        return 'Hanfeng(split=' + str(self.split) + ')'

# ==================== 简单测试脚本 ====================

# if __name__ == '__main__':

#     from torch.utils.data import DataLoader

#     import matplotlib.pyplot as plt

#     import argparse

#     # 创建参数对象

#     parser = argparse.ArgumentParser()

#     args = parser.parse_args()

#     args.base_size = 513

#     args.crop_size = 513

#     # 加载训练集

#     hanfeng_train = HanfengSegmentation(args, split='trainval')

#     print(f"数据集大小: {len(hanfeng_train)}")

#     print(f"第一个样本的路径:")

#     print(f"  图像: {hanfeng_train.images[0]}")

#     print(f"  掩码: {hanfeng_train.masks[0]}")

#     # 创建数据加载器

#     dataloader = DataLoader(hanfeng_train, batch_size=2, shuffle=True, num_workers=0)

#     # 可视化第一个batch,从 dataloader 里取 batch,一边给每个 batch 自动编号(0,1,2,3…)

#     for ii, sample in enumerate(dataloader):

#         print(f"Batch {ii}: Image shape = {sample['image'].shape}, Label shape = {sample['label'].shape}")

#         # 可视化

#         for jj in range(sample["image"].size()[0]):

#             img = sample['image'].numpy()

#             gt = sample['label'].numpy()

#             # 反归一化图像

#             img_tmp = np.transpose(img[jj], axes=[1, 2, 0])

#             img_tmp *= (0.229, 0.224, 0.225)

#             img_tmp += (0.485, 0.456, 0.406)

#             img_tmp *= 255.0

#             img_tmp = img_tmp.astype(np.uint8)

#             # 掩码

#             mask_tmp = np.array(gt[jj]).astype(np.uint8)

#             # 显示

#             plt.figure(figsize=(10, 5))

#             plt.subplot(121)

#             plt.imshow(img_tmp)

#             plt.title('Image')

#             plt.axis('off')

#             plt.subplot(122)

#             plt.imshow(mask_tmp, cmap='gray')

#             plt.title('Mask')

#             plt.axis('off')

#             plt.tight_layout()

#         if ii == 0:  # 只显示第一个batch

#             break

#     plt.show()

总体流程:

🎯 PyTorch Dataset 完整解释

我用三个层次帮你理解这个 hanfengDatasetSegmentation 类:


📚 层次1:Dataset 在训练中的角色

PyTorch 训练流程三大组件:

训练流程:

1. Dataset(数据集)
   ↓ "我有N个样本,第i个样本的图片和标签在这里"
   
2. DataLoader(数据加载器)
   ↓ "我每次取batch_size个样本,打包成一批"
   
3. Model + Loss + Optimizer(模型训练)
   ↓ "用这批数据:前向→计算损失→反向传播→更新参数"

你写的 hanfengDatasetSegmentation 就是第1块

  • 告诉PyTorch数据在哪里
  • 如何读取图像和掩码
  • 如何做数据增强

🏗️ 层次2:如何构建这个Dataset(逐步拆解)

第0步:确认数据结构

你的数据目录:
/workspace/pytorch-deeplab-xception-master/data/hanfeng/
├── images/              ← 原始焊缝图像
│   ├── 0.png
│   ├── 1.png
│   └── ...
├── masks/               ← 掩码(分割标注)
│   ├── 0.png
│   ├── 1.png
│   └── ...
└── Hui_txt/
    └── trainval.txt     ← 记录image-mask对应关系

trainval.txt 内容:

./images/0.png ./annotations_Hui/0.png
./images/1.png ./annotations_Hui/1.png
./images/10.png ./annotations_Hui/10.png
...

注意陷阱

  • txt里写的是 ./annotations_Hui/0.png
  • 但实际文件在 masks/0.png
  • 需要只取文件名,重新拼接到masks目录

第1步:导入必要的库

from __future__ import print_function, division
import os                          # 操作路径
from PIL import Image              # 读取图片
import numpy as np                 # 数组操作
from torch.utils.data import Dataset  # PyTorch数据集基类
from mypath import Path
from torchvision import transforms
from dataloaders import custom_transforms as tr  # 自定义数据增强

第2步:定义类并继承Dataset

class hanfengDatasetSegmentation(Dataset):
    NUM_CLASSES = 2  # 二分类:背景(0) + 焊缝(1)

继承Dataset后必须实现3个方法

  • __init__() - 初始化,读取所有数据路径
  • __len__() - 返回数据集大小
  • __getitem__() - 返回第i个样本

第3步:__init__ - 初始化(最核心)

3.1 保存参数
def __init__(self, args, base_dir, split='trainval'):
    self._base_dir = base_dir  # 数据根目录
    self.args = args           # 训练参数(如crop_size)
    self.split = [split]       # 使用哪个数据集划分
3.2 准备两个列表装路径
    self.images = []  # 存放所有图像路径
    self.masks = []   # 存放所有掩码路径
3.3 读取txt文件
    for splt in self.split:
        # 构建txt路径
        txt_path = os.path.join(base_dir, 'Hui_txt', splt + '.txt')
        # 例:/workspace/.../hanfeng/Hui_txt/trainval.txt
        
        with open(txt_path, 'r') as f:
            lines = f.read().splitlines()
        # lines = ["./images/0.png ./annotations_Hui/0.png", ...]
3.4 解析每一行,构建完整路径
    for line in lines:
        parts = line.strip().split()  # 按空格分割
        # parts = ["./images/0.png", "./annotations_Hui/0.png"]
        
        if len(parts) >= 2:
            image_rel = parts[0]  # ./images/0.png
            mask_rel = parts[1]   # ./annotations_Hui/0.png
            
            # 构建图像完整路径
            image_path = os.path.join(
                base_dir,
                image_rel.lstrip('./')  # 去掉开头的 ./
            )
            # 结果:/workspace/.../hanfeng/images/0.png
            
            # 构建掩码完整路径(关键:只取文件名)
            mask_filename = os.path.basename(mask_rel)  # 0.png
            mask_path = os.path.join(base_dir, 'masks', mask_filename)
            # 结果:/workspace/.../hanfeng/masks/0.png
            
            # 验证文件存在后添加到列表
            if os.path.isfile(image_path) and os.path.isfile(mask_path):
                self.images.append(image_path)
                self.masks.append(mask_path)
3.5 最终结果
    assert len(self.images) == len(self.masks)
    print(f'数据集共 {len(self.images)} 张图像')
    
# 此时:
# self.images = [路径1, 路径2, ..., 路径N]
# self.masks  = [路径1, 路径2, ..., 路径N]
# 两者一一对应

第4步:__len__ - 返回数据集大小

def __len__(self):
    return len(self.images)

作用:告诉DataLoader有多少个样本

dataset = hanfengDatasetSegmentation(...)
print(len(dataset))  # 假设输出 500

第5步:_make_img_gt_pair - 读取图像和掩码

def _make_img_gt_pair(self, index):
    _img = Image.open(self.images[index]).convert('RGB')
    _target = Image.open(self.masks[index])
    return _img, _target

解释

# 从路径列表取出第index个路径
image_path = self.images[index]  # /workspace/.../images/0.png
mask_path = self.masks[index]    # /workspace/.../masks/0.png

# 用PIL打开
_img = PIL.Image (RGB格式)
_target = PIL.Image (单通道灰度图)

第6步:__getitem__ - 返回一个样本

def __getitem__(self, index):
    # 1. 读取图像和掩码
    _img, _target = self._make_img_gt_pair(index)
    
    # 2. 打包成字典
    sample = {'image': _img, 'label': _target}
    
    # 3. 根据split选择数据增强
    for split in self.split:
        if split in ["train", "trainval"]:
            return self.transform_tr(sample)  # 训练用增强
        elif split in ['val', 'test']:
            return self.transform_val(sample) # 验证用增强

DataLoader调用过程

# 当你写:
for i, sample in enumerate(dataloader):
    images = sample['image']
    labels = sample['label']

# 内部发生:
sample = dataset.__getitem__(i)  # 调用这个方法
# 返回:{'image': Tensor, 'label': Tensor}

第7步:transform_tr / transform_val - 数据增强

def transform_tr(self, sample):
    """训练时的数据增强"""
    composed_transforms = transforms.Compose([
        tr.RandomHorizontalFlip(),      # 随机水平翻转
        tr.RandomScaleCrop(...),        # 随机缩放裁剪
        tr.RandomGaussianBlur(),        # 随机高斯模糊
        tr.Normalize(mean=..., std=...), # 归一化
        tr.ToTensor()                   # 转为Tensor
    ])
    return composed_transforms(sample)

def transform_val(self, sample):
    """验证时的数据增强(通常较少)"""
    composed_transforms = transforms.Compose([
        tr.FixScaleCrop(crop_size=...),  # 固定尺寸裁剪
        tr.Normalize(mean=..., std=...),
        tr.ToTensor()
    ])
    return composed_transforms(sample)

流程

输入:{'image': PIL.Image, 'label': PIL.Image}
  ↓
各种增强操作(翻转、裁剪、模糊...)
  ↓
归一化(图像减均值除方差)
  ↓
转Tensor
  ↓
输出:{'image': Tensor[3,H,W], 'label': Tensor[H,W]}

⚙️ 层次3:代码运行时的执行顺序

完整流程:

# 1. train.py - 创建Dataset
dataset = hanfengDatasetSegmentation(
    args, 
    base_dir="/workspace/.../hanfeng",
    split='trainval'
)
# → 触发 __init__(),读取所有路径

# 2. train.py - 创建DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True
)

# 3. train.py - 训练循环
for epoch in range(epochs):
    for i, sample in enumerate(dataloader):
        # → DataLoader 内部调用 dataset.__getitem__(i)
        
        images = sample['image']  # [4, 3, 512, 512]
        labels = sample['label']  # [4, 512, 512]
        
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # 反向传播
        loss.backward()
        optimizer.step()

详细执行序列:

用户代码: for i, sample in enumerate(dataloader)
    ↓
DataLoader: 需要第i个样本
    ↓
调用: dataset.__getitem__(i)
    ↓
Dataset: _make_img_gt_pair(i)
    ↓ 读取 self.images[i] 和 self.masks[i]
    ↓ PIL.open() → PIL.Image对象
    ↓
Dataset: transform_tr(sample)
    ↓ 随机翻转、裁剪、模糊...
    ↓ 归一化、ToTensor
    ↓
返回: {'image': Tensor, 'label': Tensor}
    ↓
DataLoader: 收集batch_size个样本,打包成批
    ↓
返回给用户: {'image': [B,3,H,W], 'label': [B,H,W]}

🎯 核心总结

Dataset的三大职责:

方法 职责 何时调用
__init__() 读取所有数据路径到列表 创建Dataset对象时
__len__() 返回数据集大小 DataLoader初始化时
__getitem__(i) 返回第i个样本(增强后) 训练循环的每次迭代

关键理解:

Dataset不是直接返回数据,而是:
"告诉DataLoader如何获取第i个样本"

就像图书馆的索引卡:
- __init__: 建立索引(哪本书在哪个书架)
- __len__: 告诉你总共多少本书
- __getitem__(i): 根据索引去书架拿第i本书

现在理解为什么要写这个Dataset类了吗? 🎉

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐