RuntimeError: unsupported output type: int, from operator: xformers::efficient_attention_forward
Background:
When converting the model to a TorchScript trace, it fails with RuntimeError: unsupported output type: int, from operator: xformers::efficient_attention_forward. The root cause is that the model calls xformers.ops.memory_efficient_attention() internally. Switching between several versions of xformers did not fix it, so I worked around the problem by replacing that call with a plain-PyTorch implementation. GPU memory usage goes up somewhat, but at least the model can now be traced successfully.
Solution:
Replace the call to xformers.ops.memory_efficient_attention() with the following function:
import torch.nn.functional as F

def memory_efficient_attention_pytorch(query, key, value, attn_bias=None, p=0., scale=None):
    # query, key, value: [batch, seq_len, n_head, head_dim]
    # attn_bias:         [batch, n_head, seq_len, seq_len]
    if scale is None:
        scale = 1 / query.shape[-1] ** 0.5
    # BLHC -> BHLC
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)
    query = query * scale
    # BHLC @ BHCL -> BHLL
    attn = query @ key.transpose(-2, -1)
    if attn_bias is not None:
        attn = attn + attn_bias
    attn = attn.softmax(-1)
    attn = F.dropout(attn, p)
    # BHLL @ BHLC -> BHLC
    out = attn @ value
    # BHLC -> BLHC
    out = out.transpose(1, 2)
    return out
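To confirm the workaround actually traces, here is a minimal self-contained check. The toy module, tensor shapes, and hand-rolled reference below are illustrative and not from the original post; the replacement function is restated so the snippet runs on its own:

```python
import torch
import torch.nn.functional as F

def memory_efficient_attention_pytorch(query, key, value, attn_bias=None, p=0., scale=None):
    # Same replacement function as above, restated for a self-contained demo.
    if scale is None:
        scale = 1 / query.shape[-1] ** 0.5
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)
    query = query * scale
    attn = query @ key.transpose(-2, -1)
    if attn_bias is not None:
        attn = attn + attn_bias
    attn = attn.softmax(-1)
    attn = F.dropout(attn, p)
    out = attn @ value
    return out.transpose(1, 2)

class ToyAttention(torch.nn.Module):
    # Hypothetical minimal module standing in for the real model's attention layer.
    def forward(self, q, k, v):
        return memory_efficient_attention_pytorch(q, k, v)

q = torch.randn(2, 16, 4, 32)  # [batch=2, seq_len=16, n_head=4, head_dim=32]
# torch.jit.trace succeeds here, unlike with the xformers op.
traced = torch.jit.trace(ToyAttention(), (q, q, q))
out = traced(q, q, q)
assert out.shape == (2, 16, 4, 32)

# Sanity-check the math against a direct softmax(QK^T * scale) @ V computation.
ref = torch.softmax((q.transpose(1, 2) / 32 ** 0.5) @ q.transpose(1, 2).transpose(-2, -1), -1) @ q.transpose(1, 2)
ref = ref.transpose(1, 2)
assert torch.allclose(out, ref, atol=1e-5)
```

On PyTorch 2.0 and later, F.scaled_dot_product_attention offers a built-in fused alternative that may also avoid the extra memory cost; it expects [batch, n_head, seq_len, head_dim] layout, so the transposes are still needed.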