RuntimeError: The size of tensor a (512) must match the size of tensor b (128) at non-singleton dime
GAN网络改了个生成器遇到问题。
·
项目场景:
GAN网络改了个生成器遇到问题
问题描述
RuntimeError: The size of tensor a (512) must match the size of tensor b (128) at non-singleton dime
意思就是张量维度不对
原因分析:
因为生成器输出的是512*512的但是损失函数和判别器都是128*128的
解决方案:
方法一:重新设计生成器,使之有正确的输出;
还是建议更改网络结构,不然就算能跑了效果也不好
方法二:使用resize操作,使tensor维度对应;
在pytorch中,输入网络的图像的shape=[B,C,H,W].
有时我们需要在网络中对图像张量进行resize操作,这时就要用到transforms.Resize([H,W]) 操作。示例如下:
import cv2
import numpy as np
import torch
from torchvision.transforms import Resize
im1 = cv2.imread("./datasets/frame_0001.png").transpose([2,0,1]) # shape=[C,H,W]
im1_torch = torch.from_numpy(im1.astype(np.float32)).unsqueeze(0) # shape=[B,C,H,W]
# im1_torch可以看作是输入torch神经网络的tensor.
torch_resize = Resize([256,256]) # 定义Resize类对象
im1_resize = torch_resize(im1_torch)
# torchvision.transforms.Resize([H,W])的作用是把最后两个维度resize成[H,W].
# 所以,这对图像的通道顺序有要求。
im1_resize_np = im1_resize.data.cpu().numpy()[0].transpose(1, 2, 0) # shape=[H,W,C]
print(im1.shape)
print(im1_resize.shape)
print(im1_resize_np.shape)
参考: Pytorch transforms.Resize()的简单用法_xiongxyowo的博客-CSDN博客_transforms.resize
更多推荐


所有评论(0)