使用flash-attention推理
虽然transformers库中可以实现flash attention,但是默认情况下是不使用的,需要在加载模型时使用一个参数:attn_implementation="flash_attention_2"。不仅如此,还需要在本地install flash-attn;如果安装失败,可以下载。这个文件,下载到本地之后pip install 它就可以。
·
虽然transformers库中可以实现flash attention,但是默认情况下是不使用的,需要在加载模型时使用一个参数:attn_implementation="flash_attention_2"。
不仅如此,还需要在本地install flash-attn;如果安装失败,可以去这个地方下载对应版本的whl文件: https://github.com/Dao-AILab/flash-attention/releases/,下载到本地之后pip install 它就可以。
此外,flash-attn还要求model和input都必须在cuda上,所以在还需要将它们加载后转移到cuda上,代码如下:
# 定义输入,并转移到cuda上
input_text = "Hello, how are you?"
inputs = tokenizer(input_text, return_tensors="pt")
inputs = {key: value.to('cuda') for key, value in inputs.items()}
model_name = "meta-llama/llama-2-7b-chat-hf/"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name,
trust_remote_code=True,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16).to('cuda') # model也转移到cuda上
更多推荐


所有评论(0)