构建自定义数据集类(MyDataset)的过程解析

大家好, 我是诗人啊_最近在整理seq2seq案例, 后续会把完整的案例实现过程都一一发布,各位观众老爷可以点点关注不咯~ (简单实用, 注释清晰, 看了包会的)在这里插入图片描述

目录

  • 源代码实现
  • 构建过程解析
    1. 继承 Dataset 基类
    2. 初始化方法(init
    3. 定义数据集长度(len
    4. 提取单个样本(getitem
  • 核心作用总结

以下是基于源代码,对 MyDataset 类构建过程的详细整理,包含代码实现和核心逻辑说明

源代码实现

# 3. 构建数据集类
class MyDataset(Dataset):
    def __init__(self, my_pairs):
        super(MyDataset, self).__init__()
        # 样本
        self.my_pairs = my_pairs

        # 样本数量(长度)
        self.amples = len(my_pairs)

    def __len__(self):
        return self.amples

    # 获取第几条 样本数据
    # 数据源对象---> __getitem__ ---> DataLoader ---> 每次迭代获取一个样本数据
    def __getitem__(self, index):
        # todo: 对index进行修正, 确保index >= 0, index <= 最大下标(样本长度-1)
        # index要么小于0---> max取0; index要么大于样本长度---> min取--->index
        index = min(max(index, 0), self.sample_len - 1)

        # 按索引获取,样本数据 x, y  ---> x代表英文, y代表法文
        x = self.my_pairs[index][0]
        y = self.my_pairs[index][1]

        # 样本x 文本数值化.
        '''
        todo: 这里提一下, 单词转为向量的过程:
        
            1. 构建词表 {单词:索引} (word2index) 或者 {索引:单词} (index2word). '2'(to) 					可以理解为字典中的冒号,所以再品一下(word:index)的意思, 是不是一样就懂了~
            
            2. 根据词表, 将单词转为索引, 索引就是向量. 
                因为词表是字典,只能通过key来获取value,
                我们的目的是把单词转为索引, 也就是找到那个单词, 取出对应的索引. 
                所以要把单词转为索引,遍历单词所在的容器. 联系上文讲的字典是通过key来获取value, 
                并且 现在的词表是{单词:索引}形式, 就得到了我代码:
                x = [english_word2index[word] for word in x.split(' ')]
        '''
        x = [english_word2index[word] for word in x.split(' ')]  # 列表推导式写法
        x.append(EOS_token)  # EOS_token代表结束标志, 其实可以不写这句话, 因为在初始化时, 我们已经把EOS_token加到字典中了,并且索引为1
        # 将x列表转为tensor, 并且指定数据类型为long, 指定设备为cpu
        # device的意思是, 如果你选择cpu运行, 那么tensor就在cpu上; 如果你选择gpu运行, 那么tensor就在gpu上.
        tensor_x = torch.tensor(data=x, dtype=torch.long, device=device)  # device最开始初始化了, 值是cpu

        # 样本y 文本数值化---> y代表法文, 处理方法和x一样(看上面)
        y = [french_word2index[word] for word in y.split(' ')]  # 列表推导式写法
        y.append(EOS_token)
        tensor_y = torch.tensor(data=y, dtype=torch.long, device=device)
        # 返回结果
        return tensor_x, tensor_y

构建过程解析

1. 继承 Dataset 基类

自定义数据集类必须继承 PyTorch 的 Dataset 基类,以遵循 PyTorch 数据加载的规范,确保后续可通过 DataLoader 调用。

class MyDataset(Dataset):  # 继承 Dataset 基类,获得数据加载的标准接口

2. 初始化方法(__init__

作用:接收原始数据并初始化类的核心属性,存储样本信息。

  • 输入参数my_pairs(原始样本对列表,格式如 [(英文句子1, 法文句子1), ...])。
  • 核心操作
    • 调用父类初始化方法 super(MyDataset, self).__init__()
    • 存储原始样本对到 self.my_pairs(后续提取样本的数据源)。
    • 记录样本总数到 self.amples(即数据集长度,用于 __len__ 方法返回)。

3. 定义数据集长度(__len__

作用:返回数据集的样本总数,供 DataLoader 计算迭代次数(如批量大小为 32 时,总迭代次数 = 样本总数 / 32)。

def __len__(self):
    return self.amples  # 返回样本数量

4. 提取单个样本(__getitem__

作用:根据索引 index 提取单个样本,并将文本转换为模型可处理的张量(核心逻辑)。
步骤拆解:

(1)索引修正

确保输入的 index 在有效范围内(0 到 样本总数-1),避免因索引越界导致错误。

index = min(max(index, 0), self.sample_len - 1)  # 限制索引在 [0, 样本总数-1] 之间
(2)提取原始文本

self.my_pairs 中按索引提取对应的英文句子 x 和法文句子 y

x = self.my_pairs[index][0]  # 英文句子(如 "i am student")
y = self.my_pairs[index][1]  # 法文句子(如 "je suis étudiant")
(3)文本数值化(单词→索引)

通过提前构建的词表(english_word2indexfrench_word2index),将文本转换为整数索引列表:

  • 按空格分割句子为单词(如 "i am"["i", "am"])。
  • 用词表映射每个单词为索引(如 english_word2index["i"] = 3)。
  • 拼接句子结束标志 EOS_token(表示句子终止,方便模型识别句子边界)。
# 英文句子转换
x = [english_word2index[word] for word in x.split(' ')]  # 单词→索引列表
x.append(EOS_token)  # 添加结束标志的索引

# 法文句子转换(逻辑同上)
y = [french_word2index[word] for word in y.split(' ')]
y.append(EOS_token)
(4)转换为张量(Tensor)

torch.tensor() 将索引列表转换为 PyTorch 张量,指定:

  • 数据类型 dtype=torch.long(整数类型,符合词嵌入层的输入要求)。
  • 设备 device(如 CPU 或 GPU,确保张量与模型在同一设备上)。
tensor_x = torch.tensor(data=x, dtype=torch.long, device=device)  # 英文索引→张量
tensor_y = torch.tensor(data=y, dtype=torch.long, device=device)  # 法文索引→张量
(5)返回样本

返回处理后的英文张量和法文张量,作为模型的输入(如 seq2seq 模型的编码器输入和解码器目标)。

return tensor_x, tensor_y

核心作用总结

MyDataset 类的核心功能是打通“原始文本”到“模型可处理数据”的转换链路
原始文本对(英文-法文)→ 单词分割 → 词表映射为索引 → 转换为张量 → 输出给模型。

通过继承 Dataset 基类,使其可与 DataLoader 配合,实现批量加载、打乱顺序、多线程预处理等功能,为模型训练提供高效的数据输入管道。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐