pytorch notes
DataLoadertorch.utils.data.DataLoader参数worker_init_fn创建DataLoader需要传入Dataset对象,如果在Dataset中实现了worker_init_fn成员函数,则把这个函数也一并传给DataLoader。不管传给DataLoader的num_workers等于几,Dataset的构造函数都只会被创建一次,即不同的worker是使用同一
·
DataLoader
torch.utils.data.DataLoader
参数worker_init_fn
创建DataLoader需要传入Dataset对象,如果在Dataset中实现了worker_init_fn成员函数,则把这个函数也一并传给DataLoader。
不管传给DataLoader的num_workers等于几,Dataset的构造函数都只会被创建一次,即不同的worker是使用同一个Dataset;但是worker_init_fn会被调用num_workers次,用于初始化每个worker自己独有的数据,避免了和其他worker使用公用的数据,进而加快数据加载速度。
测试代码如下,如果把worker_init_fn的self.data = [i for i in range(200)]注释掉,则不同worker打印出来的self.data的id是相同的,如果不注释,则每个worker拥有自己独有的self.data。对于简单的数据结构,worker_init_fn的作用不大,但是对于复杂的数据加载方式则能显著加快数据读取速度。
import torch
import numpy as np
class Dataset(torch.utils.data.Dataset):
def __init__(self):
self.data = [i for i in range(200)]
print(111)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
print(self.worker_id, id(self.data))
tensor = torch.FloatTensor([self.data[idx]])
return tensor
def worker_init_fn(self, worker_id):
print(worker_id)
self.worker_id = worker_id
self.data = [i for i in range(200)]
if __name__ == '__main__':
dataset = Dataset()
dataloader = torch.utils.data.DataLoader(
dataset=dataset,
shuffle=True,
batch_size=20,
num_workers=4,
worker_init_fn=dataset.worker_init_fn)
for idx, d in enumerate(dataloader):
print(idx, d.shape)
更多推荐
所有评论(0)