Dataset撰写,数据加载代码
mask_path = os.path.join(self._base_dir, 'masks', mask_filename) # 因为实际目录结构中掩码在masks文件夹下。(mask_rel_path)# 只取文件名部分 "./annotations_Hui/10.png" → "10.png"# ==================== 简单测试脚本 ===================
⚠️ 关键修复点总结
| 你的代码 | 问题 | 正确写法 |
|-------------------|--------------|-----------------------------------|
| pd.txt.read() | pandas没有这个方法 | open(file).read() |
| os.jion.path() | 拼写错误 | os.path.join() |
| train_text.lines | 文件对象没有此属性 | f.read().splitlines() |
| 取i中的./images/图像id | 伪代码 | parts = line.split() + parts[0] |
| "D:\github\..." | 转义问题 | r"D:\github\..." 或 "D:\\github\\" |
错误示例:
import os
import pandas as pd
# 初始化列表,存放图片和掩码路径
image_list = []
mask_list = []
# 定义base路径
base_dir = "D:\github\pytorch-deeplab-xception-master\data\hanfeng"
# 读取图像和掩码的id
# image_train_test = pd.txt.read(os.jion.path(base_dir, "hanfen","trainval.txt"))
# mask_train_text = pd.txt.read(os.jion.path(base_dir, "hanfen","trainval.txt"))
train_text = pd.txt.read(os.jion.path(base_dir, "Hui_txt","trainval.txt"))
# 将所有的图像和掩码路径存入列表
for i in train_text.lines:
image = 取i中的./images/图像id
mask = 取i中的./mask/图像id_mask
image_list.append(image)
mask_list.append(mask)
正确示例:
import os
# 初始化列表,存放图片和掩码路径
image_list = []
mask_list = []
1.# 定义base路径(注意:使用原始字符串或正斜杠)
base_dir = r"D:\github\pytorch-deeplab-xception-master\data\hanfeng"
2.# 构建txt文件的完整路径
txt_path = os.path.join(base_dir, "Hui_txt", "trainval.txt")
3.# 读取txt文件
with open(txt_path, "r", encoding="utf-8") as f:
train_text = f.readlines()
4.# 遍历每一行,提取图像和掩码路径
for line in train_text:
# 去除首尾空白字符和换行符
line = line.strip()
# 分割字符串(按空格分割)
paths = line.split()
if len(paths) == 2: # 确保每行有两个路径
image_path = paths[0] # ./images/0.png
mask_path = paths[1] # ./annotations_Hui/0.png
# 将相对路径转为完整路径
full_image_path = os.path.join(base_dir, image_path)
full_mask_path = os.path.join(base_dir, mask_path)
# 添加到列表
image_list.append(full_image_path)
mask_list.append(full_mask_path)
结合dataset类进行撰写:
from __future__ import print_function, division
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
from mypath import Path
from torchvision import transforms
import custom_transforms as tr
class HanfengSegmentation(Dataset):
"""
Hanfeng 焊缝分割数据集
"""
NUM_CLASSES = 2 # 根据你的实际类别数调整(背景+焊缝)
def __init__(self,
args,
base_dir=r"D:\github\pytorch-deeplab-xception-master\data\hanfeng",
split='trainval',
):
"""
初始化Hanfeng数据集
:param args: 参数对象(包含base_size, crop_size等)
:param base_dir: 数据集根目录
:param split: 数据集划分 ('train', 'val', 'trainval', 'test')
"""
super().__init__()
self._base_dir = base_dir
self.args = args
if isinstance(split, str):
self.split = [split]
else:
split.sort()
self.split = split
# 存储图像和掩码的路径
self.images = []
self.masks = []
# 读取数据集划分文件,给数据集路径txt的路径
for splt in self.split:
split_file = os.path.join(self._base_dir, 'Hui_txt', splt + '.txt')
# 检查文件是否存在
if not os.path.isfile(split_file):
raise FileNotFoundError(f"Split file not found: {split_file}")
# 读取文件内容,splitlines()自动将每行分开并存入列表中
# f.read():一次性把整个文件内容读成一个大字符串
# .splitlines():在这个字符串上按换行符切分,返回一个“每行一条”的列表(不包含行尾的 \n)
with open(split_file, 'r') as f:
lines = f.read().splitlines()
# 解析每一行,提取图像和掩码路径
for line in lines:
# 每行格式: ./images/0.png ./annotations_Hui/0.png
# strip() 去掉换行符,split() 按空格切开,["./images/0.png", "./annotations_Hui/0.png"]
parts = line.strip().split()
if len(parts) >= 2:
# 提取相对路径
image_rel_path = parts[0] # 例如: ./images/0.png
mask_rel_path = parts[1] # 例如: ./annotations_Hui/0.png
# 构建完整路径
image_path = os.path.join(self._base_dir, image_rel_path.lstrip('./')) #去掉左侧所有 '.' 和 '/'
mask_filename = os.path.basename(mask_rel_path) # 只取文件名部分 "./annotations_Hui/10.png" → "10.png"
mask_path = os.path.join(self._base_dir, 'masks', mask_filename) # 因为实际目录结构中掩码在masks文件夹下
# 验证文件是否存在
if os.path.isfile(image_path) and os.path.isfile(mask_path):
self.images.append(image_path)
self.masks.append(mask_path)
else:
print(f"Warning: File not found - Image: {image_path}, Mask: {mask_path}")
# 验证数据完整性
assert len(self.images) == len(self.masks), \
f"图像数量({len(self.images)})和掩码数量({len(self.masks)})不匹配!"
# 显示统计信息
print(f'Hanfeng数据集 - {split} 划分: 共 {len(self.images)} 张图像')
def __len__(self):
"""返回数据集大小"""
return len(self.images)
def __getitem__(self, index):
"""
获取单个样本
:param index: 样本索引
:return: 包含image和label的字典
"""
_img, _target = self._make_img_gt_pair(index)
sample = {'image': _img, 'label': _target}
# 根据数据集划分应用不同的数据变换
for split in self.split:
if split in ["train", "trainval"]:
return self.transform_tr(sample)
elif split in ['val', 'test']:
return self.transform_val(sample)
def _make_img_gt_pair(self, index):
"""
加载图像和掩码
:param index: 样本索引
:return: PIL图像对象
"""
_img = Image.open(self.images[index]).convert('RGB')
_target = Image.open(self.masks[index])
return _img, _target
def transform_tr(self, sample):
"""训练集数据增强"""
composed_transforms = transforms.Compose([
tr.RandomHorizontalFlip(),
tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
tr.RandomGaussianBlur(),
tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
tr.ToTensor()])
return composed_transforms(sample)
def transform_val(self, sample):
"""验证/测试集数据变换"""
composed_transforms = transforms.Compose([
tr.FixScaleCrop(crop_size=self.args.crop_size),
tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
tr.ToTensor()])
return composed_transforms(sample)
def __str__(self):
return 'Hanfeng(split=' + str(self.split) + ')'
# ==================== 简单测试脚本 ====================
# if __name__ == '__main__':
# from torch.utils.data import DataLoader
# import matplotlib.pyplot as plt
# import argparse
# # 创建参数对象
# parser = argparse.ArgumentParser()
# args = parser.parse_args()
# args.base_size = 513
# args.crop_size = 513
# # 加载训练集
# hanfeng_train = HanfengSegmentation(args, split='trainval')
# print(f"数据集大小: {len(hanfeng_train)}")
# print(f"第一个样本的路径:")
# print(f" 图像: {hanfeng_train.images[0]}")
# print(f" 掩码: {hanfeng_train.masks[0]}")
# # 创建数据加载器
# dataloader = DataLoader(hanfeng_train, batch_size=2, shuffle=True, num_workers=0)
# # 可视化第一个batch,从 dataloader 里取 batch,一边给每个 batch 自动编号(0,1,2,3…)
# for ii, sample in enumerate(dataloader):
# print(f"Batch {ii}: Image shape = {sample['image'].shape}, Label shape = {sample['label'].shape}")
# # 可视化
# for jj in range(sample["image"].size()[0]):
# img = sample['image'].numpy()
# gt = sample['label'].numpy()
# # 反归一化图像
# img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
# img_tmp *= (0.229, 0.224, 0.225)
# img_tmp += (0.485, 0.456, 0.406)
# img_tmp *= 255.0
# img_tmp = img_tmp.astype(np.uint8)
# # 掩码
# mask_tmp = np.array(gt[jj]).astype(np.uint8)
# # 显示
# plt.figure(figsize=(10, 5))
# plt.subplot(121)
# plt.imshow(img_tmp)
# plt.title('Image')
# plt.axis('off')
# plt.subplot(122)
# plt.imshow(mask_tmp, cmap='gray')
# plt.title('Mask')
# plt.axis('off')
# plt.tight_layout()
# if ii == 0: # 只显示第一个batch
# break
# plt.show()
总体流程:
🎯 PyTorch Dataset 完整解释
我用三个层次帮你理解这个 hanfengDatasetSegmentation 类:
📚 层次1:Dataset 在训练中的角色
PyTorch 训练流程三大组件:
训练流程:
1. Dataset(数据集)
↓ "我有N个样本,第i个样本的图片和标签在这里"
2. DataLoader(数据加载器)
↓ "我每次取batch_size个样本,打包成一批"
3. Model + Loss + Optimizer(模型训练)
↓ "用这批数据:前向→计算损失→反向传播→更新参数"
你写的 hanfengDatasetSegmentation 就是第1块:
- 告诉PyTorch数据在哪里
- 如何读取图像和掩码
- 如何做数据增强
🏗️ 层次2:如何构建这个Dataset(逐步拆解)
第0步:确认数据结构
你的数据目录:
/workspace/pytorch-deeplab-xception-master/data/hanfeng/
├── images/ ← 原始焊缝图像
│ ├── 0.png
│ ├── 1.png
│ └── ...
├── masks/ ← 掩码(分割标注)
│ ├── 0.png
│ ├── 1.png
│ └── ...
└── Hui_txt/
└── trainval.txt ← 记录image-mask对应关系
trainval.txt 内容:
./images/0.png ./annotations_Hui/0.png
./images/1.png ./annotations_Hui/1.png
./images/10.png ./annotations_Hui/10.png
...
注意陷阱:
- txt里写的是
./annotations_Hui/0.png - 但实际文件在
masks/0.png - 需要只取文件名,重新拼接到masks目录
第1步:导入必要的库
from __future__ import print_function, division
import os # 操作路径
from PIL import Image # 读取图片
import numpy as np # 数组操作
from torch.utils.data import Dataset # PyTorch数据集基类
from mypath import Path
from torchvision import transforms
from dataloaders import custom_transforms as tr # 自定义数据增强
第2步:定义类并继承Dataset
class hanfengDatasetSegmentation(Dataset):
NUM_CLASSES = 2 # 二分类:背景(0) + 焊缝(1)
继承Dataset后必须实现3个方法:
__init__()- 初始化,读取所有数据路径__len__()- 返回数据集大小__getitem__()- 返回第i个样本
第3步:__init__ - 初始化(最核心)
3.1 保存参数
def __init__(self, args, base_dir, split='trainval'):
self._base_dir = base_dir # 数据根目录
self.args = args # 训练参数(如crop_size)
self.split = [split] # 使用哪个数据集划分
3.2 准备两个列表装路径
self.images = [] # 存放所有图像路径
self.masks = [] # 存放所有掩码路径
3.3 读取txt文件
for splt in self.split:
# 构建txt路径
txt_path = os.path.join(base_dir, 'Hui_txt', splt + '.txt')
# 例:/workspace/.../hanfeng/Hui_txt/trainval.txt
with open(txt_path, 'r') as f:
lines = f.read().splitlines()
# lines = ["./images/0.png ./annotations_Hui/0.png", ...]
3.4 解析每一行,构建完整路径
for line in lines:
parts = line.strip().split() # 按空格分割
# parts = ["./images/0.png", "./annotations_Hui/0.png"]
if len(parts) >= 2:
image_rel = parts[0] # ./images/0.png
mask_rel = parts[1] # ./annotations_Hui/0.png
# 构建图像完整路径
image_path = os.path.join(
base_dir,
image_rel.lstrip('./') # 去掉开头的 ./
)
# 结果:/workspace/.../hanfeng/images/0.png
# 构建掩码完整路径(关键:只取文件名)
mask_filename = os.path.basename(mask_rel) # 0.png
mask_path = os.path.join(base_dir, 'masks', mask_filename)
# 结果:/workspace/.../hanfeng/masks/0.png
# 验证文件存在后添加到列表
if os.path.isfile(image_path) and os.path.isfile(mask_path):
self.images.append(image_path)
self.masks.append(mask_path)
3.5 最终结果
assert len(self.images) == len(self.masks)
print(f'数据集共 {len(self.images)} 张图像')
# 此时:
# self.images = [路径1, 路径2, ..., 路径N]
# self.masks = [路径1, 路径2, ..., 路径N]
# 两者一一对应
第4步:__len__ - 返回数据集大小
def __len__(self):
return len(self.images)
作用:告诉DataLoader有多少个样本
dataset = hanfengDatasetSegmentation(...)
print(len(dataset)) # 假设输出 500
第5步:_make_img_gt_pair - 读取图像和掩码
def _make_img_gt_pair(self, index):
_img = Image.open(self.images[index]).convert('RGB')
_target = Image.open(self.masks[index])
return _img, _target
解释:
# 从路径列表取出第index个路径
image_path = self.images[index] # /workspace/.../images/0.png
mask_path = self.masks[index] # /workspace/.../masks/0.png
# 用PIL打开
_img = PIL.Image (RGB格式)
_target = PIL.Image (单通道灰度图)
第6步:__getitem__ - 返回一个样本
def __getitem__(self, index):
# 1. 读取图像和掩码
_img, _target = self._make_img_gt_pair(index)
# 2. 打包成字典
sample = {'image': _img, 'label': _target}
# 3. 根据split选择数据增强
for split in self.split:
if split in ["train", "trainval"]:
return self.transform_tr(sample) # 训练用增强
elif split in ['val', 'test']:
return self.transform_val(sample) # 验证用增强
DataLoader调用过程:
# 当你写:
for i, sample in enumerate(dataloader):
images = sample['image']
labels = sample['label']
# 内部发生:
sample = dataset.__getitem__(i) # 调用这个方法
# 返回:{'image': Tensor, 'label': Tensor}
第7步:transform_tr / transform_val - 数据增强
def transform_tr(self, sample):
"""训练时的数据增强"""
composed_transforms = transforms.Compose([
tr.RandomHorizontalFlip(), # 随机水平翻转
tr.RandomScaleCrop(...), # 随机缩放裁剪
tr.RandomGaussianBlur(), # 随机高斯模糊
tr.Normalize(mean=..., std=...), # 归一化
tr.ToTensor() # 转为Tensor
])
return composed_transforms(sample)
def transform_val(self, sample):
"""验证时的数据增强(通常较少)"""
composed_transforms = transforms.Compose([
tr.FixScaleCrop(crop_size=...), # 固定尺寸裁剪
tr.Normalize(mean=..., std=...),
tr.ToTensor()
])
return composed_transforms(sample)
流程:
输入:{'image': PIL.Image, 'label': PIL.Image}
↓
各种增强操作(翻转、裁剪、模糊...)
↓
归一化(图像减均值除方差)
↓
转Tensor
↓
输出:{'image': Tensor[3,H,W], 'label': Tensor[H,W]}
⚙️ 层次3:代码运行时的执行顺序
完整流程:
# 1. train.py - 创建Dataset
dataset = hanfengDatasetSegmentation(
args,
base_dir="/workspace/.../hanfeng",
split='trainval'
)
# → 触发 __init__(),读取所有路径
# 2. train.py - 创建DataLoader
dataloader = DataLoader(
dataset,
batch_size=4,
shuffle=True
)
# 3. train.py - 训练循环
for epoch in range(epochs):
for i, sample in enumerate(dataloader):
# → DataLoader 内部调用 dataset.__getitem__(i)
images = sample['image'] # [4, 3, 512, 512]
labels = sample['label'] # [4, 512, 512]
# 前向传播
outputs = model(images)
loss = criterion(outputs, labels)
# 反向传播
loss.backward()
optimizer.step()
详细执行序列:
用户代码: for i, sample in enumerate(dataloader)
↓
DataLoader: 需要第i个样本
↓
调用: dataset.__getitem__(i)
↓
Dataset: _make_img_gt_pair(i)
↓ 读取 self.images[i] 和 self.masks[i]
↓ PIL.open() → PIL.Image对象
↓
Dataset: transform_tr(sample)
↓ 随机翻转、裁剪、模糊...
↓ 归一化、ToTensor
↓
返回: {'image': Tensor, 'label': Tensor}
↓
DataLoader: 收集batch_size个样本,打包成批
↓
返回给用户: {'image': [B,3,H,W], 'label': [B,H,W]}
🎯 核心总结
Dataset的三大职责:
| 方法 | 职责 | 何时调用 |
|---|---|---|
__init__() |
读取所有数据路径到列表 | 创建Dataset对象时 |
__len__() |
返回数据集大小 | DataLoader初始化时 |
__getitem__(i) |
返回第i个样本(增强后) | 训练循环的每次迭代 |
关键理解:
Dataset不是直接返回数据,而是:
"告诉DataLoader如何获取第i个样本"
就像图书馆的索引卡:
- __init__: 建立索引(哪本书在哪个书架)
- __len__: 告诉你总共多少本书
- __getitem__(i): 根据索引去书架拿第i本书
现在理解为什么要写这个Dataset类了吗? 🎉
更多推荐


所有评论(0)