torch.cat 数据类型的坑
all_img = torch.tensor([])# 采用下面的语句读取图片img = torch.from_numpy(cv2.imread('{}.JPEG'.format(5))).unsqueeze(0)# 此时 img 的维度 (1,224,224,3)# 将img合并入all_img 中all_img = all_img.cat((all_img,img)) # 报错 Runtime
·
all_img = torch.tensor([])
# 采用下面的语句读取图片
img = torch.from_numpy(cv2.imread('{}.JPEG'.format(5))).unsqueeze(0)
# 此时 img 的维度 (1,224,224,3)
# 将img合并入all_img 中
all_img = all_img.cat((all_img,img)) # 报错 RuntimeError: Expected object of scalar type Byte but got scalar type Float for sequence element 1 in sequence argument at position #1 'tensors'
发现img 是 int 8 类型,所以转化一下img 的数据类型
all_img = torch.tensor([])
# 采用下面的语句读取图片
img = torch.from_numpy(cv2.imread('{}.JPEG'.format(5))).unsqueeze(0).type(torch.float32) #### 注意此处
# 此时 img 的维度 (1,224,224,3)
# 将img合并入all_img 中
all_img = all_img.cat((all_img,img)) # 报错 RuntimeError: Expected object of scalar type Byte but got scalar type Float for sequence element 1 in sequence argument at position #1 'tensors'
这样才可以
更多推荐
所有评论(0)