Background:

While converting the model to a TorchScript trace, it failed with RuntimeError: unsupported output type: int, from operator: xformers::efficient_attention_forward. The root cause turned out to be that the model internally calls xformers.ops.memory_efficient_attention(). Switching between several xformers versions did not help, so in the end I worked around the problem by replacing that call with a pure-PyTorch implementation. Memory usage increases somewhat (the full attention matrix is materialized), but at least the model can now be traced successfully.

Solution:

Replace xformers.ops.memory_efficient_attention() with the following function:

import torch.nn.functional as F

def memory_efficient_attention_pytorch(query, key, value, attn_bias=None, p=0., scale=None):
    # query     [batch, seq_len, n_head, head_dim]
    # key       [batch, seq_len, n_head, head_dim]
    # value     [batch, seq_len, n_head, head_dim]
    # attn_bias [batch, n_head, seq_len, seq_len]

    if scale is None:
        scale = 1 / query.shape[-1] ** 0.5
    
    # BLHC -> BHLC
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    query = query * scale
    # BHLC @ BHCL -> BHLL
    attn = query @ key.transpose(-2, -1)
    if attn_bias is not None:
        attn = attn + attn_bias
    attn = attn.softmax(-1)
    attn = F.dropout(attn, p)
    # BHLL @ BHLC -> BHLC
    out = attn @ value
    # BHLC -> BLHC
    out = out.transpose(1, 2)
    return out
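To sanity-check the workaround, the sketch below (my own example, not from the original post) inlines the same pure-PyTorch attention into a small nn.Module, confirms that torch.jit.trace now succeeds, and cross-checks the output against PyTorch's built-in F.scaled_dot_product_attention, which uses the same default scale of 1/sqrt(head_dim). Shapes and layer names here are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

class Attention(torch.nn.Module):
    # Minimal traceable module; the pure-PyTorch replacement is
    # inlined here (no dropout, no bias) so the example is self-contained.
    def forward(self, q, k, v):
        scale = 1 / q.shape[-1] ** 0.5
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # BLHC -> BHLC
        attn = (q * scale) @ k.transpose(-2, -1)          # BHLC @ BHCL -> BHLL
        attn = attn.softmax(-1)
        return (attn @ v).transpose(1, 2)                 # BHLC -> BLHC

q = torch.randn(2, 16, 4, 8)  # [batch, seq_len, n_head, head_dim]
k = torch.randn(2, 16, 4, 8)
v = torch.randn(2, 16, 4, 8)

# Tracing succeeds because no xformers custom op is involved.
traced = torch.jit.trace(Attention().eval(), (q, k, v))

# Cross-check against the built-in SDPA, which expects BHLC layout.
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)
print(torch.allclose(traced(q, k, v), ref, atol=1e-5))
```

If the traced output matches the reference within tolerance, the replacement is numerically equivalent to standard scaled-dot-product attention for the p=0, no-bias case.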


Reference links:

memory_efficient_attention_pytorch
Official bug report link
