自己的数据集微调SAM模型时报错:

RuntimeError: The size of tensor a (64) must match the size of tensor b (8) at non-singleton dimension 0

解决办法:

找到SAM源码中的 mask_decoder.py 文件,找到以下代码块:

src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)  # 直接扩展,无条件判断

将其修改为:

if image_embeddings.shape[0] != tokens.shape[0]:  # 动态判断批次大小是否相等
 src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
else:
 src = image_embeddings  # 如果大小相等,直接使用image_embeddings

原因:在原始代码中,无论 image_embeddingstokens 的大小是否匹配,都会对 image_embeddings 进行扩展操作,这不仅可能导致张量维度不匹配的错误,还可能在 batch_size 较大时显著增加显存占用,从而引发显存不足的问题。修改后的代码通过动态判断批次大小,仅在必要时进行扩展,从而避免了不必要的显存浪费,支持更大的 batch_size,并提升了显存使用效率。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐