AI、人工智能基础:构造自定义数据集教程(简单明了,看了包会!)
`MyDataset`这货,堪称PyTorch的“文本翻译机”——专把人类话儿转成模型能啃的张量。出身自带光环:继承`Dataset`名门,会三招绝活:- `__init__`:先把文本对(比如英→法)打包存好,数清楚有多少份;- `__len__`:报个数,让`DataLoader`知道得跑几趟;- `__getitem__`:核心操作!给单词发“数字身份证”(用词表转索引),贴个“结束贴”(E
构建自定义数据集类(MyDataset)的过程解析
大家好, 我是诗人啊_最近在整理seq2seq案例, 后续会把完整的案例实现过程都一一发布,各位观众老爷可以点点关注不咯~ (简单实用, 注释清晰, 看了包会的)
目录
- 源代码实现
- 构建过程解析
- 继承 Dataset 基类
- 初始化方法(init)
- 定义数据集长度(len)
- 提取单个样本(getitem)
- 核心作用总结
以下是基于源代码,对 MyDataset 类构建过程的详细整理,包含代码实现和核心逻辑说明
源代码实现
# 3. 构建数据集类
class MyDataset(Dataset):
def __init__(self, my_pairs):
super(MyDataset, self).__init__()
# 样本
self.my_pairs = my_pairs
# 样本数量(长度)
self.amples = len(my_pairs)
def __len__(self):
return self.amples
# 获取第几条 样本数据
# 数据源对象---> __getitem__ ---> DataLoader ---> 每次迭代获取一个样本数据
def __getitem__(self, index):
# todo: 对index进行修正, 确保index >= 0, index <= 最大下标(样本长度-1)
# index要么小于0---> max取0; index要么大于样本长度---> min取--->index
index = min(max(index, 0), self.sample_len - 1)
# 按索引获取,样本数据 x, y ---> x代表英文, y代表法文
x = self.my_pairs[index][0]
y = self.my_pairs[index][1]
# 样本x 文本数值化.
'''
todo: 这里提一下, 单词转为向量的过程:
1. 构建词表 {单词:索引} (word2index) 或者 {索引:单词} (index2word). '2'(to) 可以理解为字典中的冒号,所以再品一下(word:index)的意思, 是不是一样就懂了~
2. 根据词表, 将单词转为索引, 索引就是向量.
因为词表是字典,只能通过key来获取value,
我们的目的是把单词转为索引, 也就是找到那个单词, 取出对应的索引.
所以要把单词转为索引,遍历单词所在的容器. 联系上文讲的字典是通过key来获取value,
并且 现在的词表是{单词:索引}形式, 就得到了我代码:
x = [english_word2index[word] for word in x.split(' ')]
'''
x = [english_word2index[word] for word in x.split(' ')] # 列表推导式写法
x.append(EOS_token) # EOS_token代表结束标志, 其实可以不写这句话, 因为在初始化时, 我们已经把EOS_token加到字典中了,并且索引为1
# 将x列表转为tensor, 并且指定数据类型为long, 指定设备为cpu
# device的意思是, 如果你选择cpu运行, 那么tensor就在cpu上; 如果你选择gpu运行, 那么tensor就在gpu上.
tensor_x = torch.tensor(data=x, dtype=torch.long, device=device) # device最开始初始化了, 值是cpu
# 样本y 文本数值化---> y代表法文, 处理方法和x一样(看上面)
y = [french_word2index[word] for word in y.split(' ')] # 列表推导式写法
y.append(EOS_token)
tensor_y = torch.tensor(data=y, dtype=torch.long, device=device)
# 返回结果
return tensor_x, tensor_y
构建过程解析
1. 继承 Dataset 基类
自定义数据集类必须继承 PyTorch 的 Dataset 基类,以遵循 PyTorch 数据加载的规范,确保后续可通过 DataLoader 调用。
class MyDataset(Dataset): # 继承 Dataset 基类,获得数据加载的标准接口
2. 初始化方法(__init__)
作用:接收原始数据并初始化类的核心属性,存储样本信息。
- 输入参数:
my_pairs(原始样本对列表,格式如[(英文句子1, 法文句子1), ...])。 - 核心操作:
- 调用父类初始化方法
super(MyDataset, self).__init__()。 - 存储原始样本对到
self.my_pairs(后续提取样本的数据源)。 - 记录样本总数到
self.amples(即数据集长度,用于__len__方法返回)。
- 调用父类初始化方法
3. 定义数据集长度(__len__)
作用:返回数据集的样本总数,供 DataLoader 计算迭代次数(如批量大小为 32 时,总迭代次数 = 样本总数 / 32)。
def __len__(self):
return self.amples # 返回样本数量
4. 提取单个样本(__getitem__)
作用:根据索引 index 提取单个样本,并将文本转换为模型可处理的张量(核心逻辑)。
步骤拆解:
(1)索引修正
确保输入的 index 在有效范围内(0 到 样本总数-1),避免因索引越界导致错误。
index = min(max(index, 0), self.sample_len - 1) # 限制索引在 [0, 样本总数-1] 之间
(2)提取原始文本
从 self.my_pairs 中按索引提取对应的英文句子 x 和法文句子 y。
x = self.my_pairs[index][0] # 英文句子(如 "i am student")
y = self.my_pairs[index][1] # 法文句子(如 "je suis étudiant")
(3)文本数值化(单词→索引)
通过提前构建的词表(english_word2index 和 french_word2index),将文本转换为整数索引列表:
- 按空格分割句子为单词(如
"i am"→["i", "am"])。 - 用词表映射每个单词为索引(如
english_word2index["i"] = 3)。 - 拼接句子结束标志
EOS_token(表示句子终止,方便模型识别句子边界)。
# 英文句子转换
x = [english_word2index[word] for word in x.split(' ')] # 单词→索引列表
x.append(EOS_token) # 添加结束标志的索引
# 法文句子转换(逻辑同上)
y = [french_word2index[word] for word in y.split(' ')]
y.append(EOS_token)
(4)转换为张量(Tensor)
用 torch.tensor() 将索引列表转换为 PyTorch 张量,指定:
- 数据类型
dtype=torch.long(整数类型,符合词嵌入层的输入要求)。 - 设备
device(如 CPU 或 GPU,确保张量与模型在同一设备上)。
tensor_x = torch.tensor(data=x, dtype=torch.long, device=device) # 英文索引→张量
tensor_y = torch.tensor(data=y, dtype=torch.long, device=device) # 法文索引→张量
(5)返回样本
返回处理后的英文张量和法文张量,作为模型的输入(如 seq2seq 模型的编码器输入和解码器目标)。
return tensor_x, tensor_y
核心作用总结
MyDataset 类的核心功能是打通“原始文本”到“模型可处理数据”的转换链路:
原始文本对(英文-法文)→ 单词分割 → 词表映射为索引 → 转换为张量 → 输出给模型。
通过继承 Dataset 基类,使其可与 DataLoader 配合,实现批量加载、打乱顺序、多线程预处理等功能,为模型训练提供高效的数据输入管道。
更多推荐


所有评论(0)