虽然transformers库中可以实现flash attention,但是默认情况下是不使用的,需要在加载模型时使用一个参数:attn_implementation="flash_attention_2"。

不仅如此,还需要在本地install flash-attn;如果安装失败,可以去这个地方下载对应版本的whl文件: https://github.com/Dao-AILab/flash-attention/releases/,下载到本地之后pip install 它就可以。

此外,flash-attn还要求model和input都必须在cuda上,所以在还需要将它们加载后转移到cuda上,代码如下:

# 定义输入,并转移到cuda上
input_text = "Hello, how are you?"
inputs = tokenizer(input_text, return_tensors="pt")
inputs = {key: value.to('cuda') for key, value in inputs.items()}
​
model_name = "meta-llama/llama-2-7b-chat-hf/"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             trust_remote_code=True,
                                             attn_implementation="flash_attention_2",
                                             torch_dtype=torch.bfloat16).to('cuda')  # model也转移到cuda上

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐