Fixing the Flash Attention warning "model not initialized on GPU"
Recently, while experimenting with SmolVLM, I ran into a problem using the official loading code below:
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct")
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM-256M-Instruct",
    torch_dtype=torch.bfloat16,
    _attn_implementation="flash_attention_2" if DEVICE == "cuda" else "eager",
).to(DEVICE)
Flash Attention printed the following warning:
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
At first I suspected the Flash Attention 2 installation itself, so I ran the test script from the Zhihu post "Flash-attention 安装指南". The test passed, which showed the installation was fine:
# Test script
import torch
from flash_attn import flash_attn_func
import time

def test_flash_attention():
    # Fix the random seed for reproducibility
    torch.manual_seed(0)

    # Shapes for the random test data
    batch_size = 2
    seq_len = 1024
    num_heads = 8
    head_dim = 64

    # Random query, key and value tensors (fp16 on the GPU)
    q = torch.randn(batch_size, seq_len, num_heads, head_dim, device='cuda', dtype=torch.float16)
    k = torch.randn(batch_size, seq_len, num_heads, head_dim, device='cuda', dtype=torch.float16)
    v = torch.randn(batch_size, seq_len, num_heads, head_dim, device='cuda', dtype=torch.float16)

    try:
        # Run Flash Attention and time it
        start_time = time.time()
        output = flash_attn_func(q, k, v, causal=True)
        flash_time = time.time() - start_time

        print("Flash Attention test passed!")
        print(f"Output tensor shape: {output.shape}")
        print(f"Elapsed time: {flash_time:.4f} s")
        print("\nOutput device:", output.device)
        print("Output dtype:", output.dtype)
        return True
    except Exception as e:
        print("Flash Attention test failed")
        print("Error:", str(e))
        return False

if __name__ == "__main__":
    if torch.cuda.is_available():
        print("CUDA is available")
        print("GPU:", torch.cuda.get_device_name(0))
        test_flash_attention()
    else:
        print("Error: CUDA is required to run Flash Attention")
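As an extra sanity check (my own addition, not part of the original test script), you can also compare flash_attn_func against PyTorch's built-in scaled_dot_product_attention; the two outputs should agree to within fp16 tolerance:

# Optional sanity check (not from the original guide): compare against PyTorch SDPA
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out_flash = flash_attn_func(q, k, v, causal=True)
# SDPA expects (batch, heads, seq, head_dim), so transpose in and out
out_ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)

print("max abs diff:", (out_flash - out_ref).abs().max().item())  # expect roughly 1e-3 in fp16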
I then tried the following change to the official loading code, but the warning persisted:
# Change the trailing
.to(DEVICE)
# explicitly to
.to("cuda")
The final fix: force the model onto the GPU inside the from_pretrained call itself, by passing
device_map="cuda"
This works because device_map places the weights on the GPU while they are being loaded, whereas .to(DEVICE) only moves them afterwards, after the Flash Attention 2 device check inside from_pretrained has already run. The complete loading code after the change:
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct")
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM-256M-Instruct",
    torch_dtype=torch.bfloat16,
    _attn_implementation="flash_attention_2" if DEVICE == "cuda" else "eager",
    device_map="cuda",
).to(DEVICE)
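A quick way to confirm the weights really are on the GPU right after loading (my own check, not from the original post):

# Quick check (my own addition): the parameters should already be on the GPU
# when from_pretrained returns, so the later .to(DEVICE) is effectively a no-op.
print(next(model.parameters()).device)   # expected: cuda:0
print(model.dtype)                       # expected: torch.bfloat16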
With this change, Flash Attention no longer prints the warning.
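For completeness, here is a minimal inference sketch following the standard SmolVLM usage pattern; the image path and prompt are placeholders of my own, not from the original post:

from PIL import Image

# Placeholder image path -- replace with your own file
image = Image.open("example.jpg")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=[image], return_tensors="pt").to(DEVICE)

generated_ids = model.generate(**inputs, max_new_tokens=256)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])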