判断是否有 nan

torch.any(torch.isnan(a))

利用 torch.where() 函数替换所有 nan

where() 函数有三个输入值,第一个是判断条件,第二个是符合条件的设置值,第三个是不符合条件的设置值。

a = torch.Tensor([[1, 2, np.nan], [2, np.nan, 4], [3, 4, 5]])

a = torch.where(torch.isnan(a), torch.full_like(a, 0), a)
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐