2023-4-27 更新
另外一种使用 torchdata 库的解决方法 https://blog.csdn.net/ONE_SIX_MIX/article/details/130405330


pytorch 的 dataloader 默认使用 python 自带的多进程库 multiprocessing ,它又使用 pickle 作为序列化库。
pickle 库只能储存一些简单类型。如果 dataset 中使用 lambda 函数对象,将会导致出现这样的错误 AttributeError: Can’t pickle local object

multiprocess 的 pip 安装方法

pip install -U multiprocess

第三方库的多进程库 multiprocess ,使用 dill 库,它可以序列化复杂的类型,例如 lambda 函数。

我们将用它替代掉 dataloader 中默认的多进程库。

!注意!使用 multiprocess 库替代后,可能会占用更多的内存。速度没差别。

以下为替代示例
以下变量 use_multiprocess 是开关
将其 use_multiprocess 设为 False ,即为使用默认的多进程库,执行以下代码会报错。
将 use_multiprocess 设为 True,即为使用第三方多进程库,以下代码可以正常执行

import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset

import multiprocessing
multiprocessing.set_start_method('spawn', force=True)

# Flag
use_multiprocess = True

if use_multiprocess:
    import multiprocess
    multiprocess.set_start_method('spawn', force=True)

    torch.utils.data.dataloader.python_multiprocessing = multiprocess
    new_multiprocess_ctx = multiprocess.get_context()

else:
    new_multiprocess_ctx = None


class SimpleDataset(Dataset):
    def __init__(self):
        # lambda function here
        self.func = lambda x: x+1

    def __len__(self):
        return 1000

    def __getitem__(self, i):
        return self.func(i)


if __name__ == '__main__':
    ds = SimpleDataset()

    dl = DataLoader(ds, 1, multiprocessing_context=new_multiprocess_ctx, num_workers=2)

    for x in dl:
        print(x)

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐