gradient 为nan可能原因:

  1. 梯度爆炸
  2. 学习率太大
  3. 数据本身有问题
  4. backward时,某些方法造成0在分母上, 如:使用方法sqrt()

定位造成nan的代码:

import torch
# 异常检测开启
torch.autograd.set_detect_anomaly(True)
# 反向传播时检测是否有异常值,定位code
with torch.autograd.detect_anomaly():
	loss.backward()
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐