DataLoader

torch.utils.data.DataLoader

参数worker_init_fn

创建DataLoader需要传入Dataset对象,如果在Dataset中实现了worker_init_fn成员函数,则把这个函数也一并传给DataLoader。

不管传给DataLoader的num_workers等于几,Dataset的构造函数都只会被创建一次,即不同的worker是使用同一个Dataset;但是worker_init_fn会被调用num_workers次,用于初始化每个worker自己独有的数据,避免了和其他worker使用公用的数据,进而加快数据加载速度。

测试代码如下,如果把worker_init_fn的self.data = [i for i in range(200)]注释掉,则不同worker打印出来的self.data的id是相同的,如果不注释,则每个worker拥有自己独有的self.data。对于简单的数据结构,worker_init_fn的作用不大,但是对于复杂的数据加载方式则能显著加快数据读取速度。

import torch
import numpy as np

class Dataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data = [i for i in range(200)]
        print(111)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        print(self.worker_id, id(self.data))
        tensor = torch.FloatTensor([self.data[idx]])
        return tensor

    def worker_init_fn(self, worker_id):
        print(worker_id)
        self.worker_id = worker_id
        self.data = [i for i in range(200)]

if __name__ == '__main__':
    dataset = Dataset()
    dataloader = torch.utils.data.DataLoader(
            dataset=dataset,
            shuffle=True,
            batch_size=20,
            num_workers=4,
            worker_init_fn=dataset.worker_init_fn)

    for idx, d in enumerate(dataloader):
        print(idx, d.shape)

 

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐