load_pretrained_model

load_pretrained_model(
    model_path, 
    model_base=None, 
    model_name, 
    load_8bit=False, 
    load_4bit=False, 
    device_map="auto", 
    torch_dtype="float16",
    attn_implementation="flash_attention_2", 
    customized_config=None, 
    overwrite_config=None, 
    **kwargs
)

加载预训练的 LLaVA 模型或语言模型。

该函数支持加载多种模型架构,包括 LLaVA 模型(Llama、Mistral、Mixtral、Qwen、Gemma)和标准语言模型。它能够处理不同的加载场景,如完整模型加载、LoRA 权重加载和量化(8-bit/4-bit)。

参数

参数名 类型 默认值 说明
model_path str - 模型检查点路径或 HuggingFace 模型标识符。对于 LoRA 模型,应指向 LoRA 权重目录。
model_base str, optional None 基础模型检查点路径。加载 LoRA 权重或仅加载投影器检查点时必须提供。
model_name str - 模型名称或标识符,用于确定模型架构。应包含关键词如 “llava”、“lora”、“mixtral”、“mistral”、“qwen”、“gemma” 等。
load_8bit bool False 是否以 8-bit 量化加载模型。与 load_4bit 互斥。
load_4bit bool False 是否使用 BitsAndBytes 以 4-bit 量化加载模型。与 load_8bit 互斥。
device_map str "auto" 模型加载的设备映射策略。选项包括 “auto”、“cpu”、“cuda” 或特定的设备映射字典。
torch_dtype str "float16" 模型权重的数据类型。选项为 “float16” 或 “bfloat16”。
attn_implementation str "flash_attention_2" 使用的注意力实现。选项包括 “flash_attention_2”、“sdpa” 等。
customized_config dict, optional None 自定义模型配置字典,用于覆盖从 model_path 加载的默认配置。
overwrite_config dict, optional None 配置属性字典,用于在加载基础配置后覆盖。键应为配置属性名,值为要设置的新值。
**kwargs - - 传递给底层 from_pretrained() 方法的其他关键字参数(如 trust_remote_codelow_cpu_mem_usage)。

返回值

返回一个包含以下四个元素的元组:

  • tokenizer (AutoTokenizer): 加载模型的 tokenizer 实例。
  • model (PreTrainedModel): 加载的模型实例。可能是以下类型之一:
    • LlavaLlamaForCausalLM
    • LlavaMistralForCausalLM
    • LlavaMixtralForCausalLM
    • LlavaQwenForCausalLM
    • LlavaGemmaForCausalLM
    • AutoModelForCausalLM(用于纯语言模型)
  • image_processor (ImageProcessor or None): 视觉语言模型的图像处理器。对于纯语言模型返回 None
  • context_len (int): 模型的最大上下文长度,从配置属性(如 max_sequence_lengthmax_position_embeddings)确定,或默认为 2048。

功能特性

支持的模型架构

该函数自动识别并加载以下模型架构:

  • LLaVA 系列: Llama、Mistral、Mixtral、Qwen、Gemma 等基础架构的 LLaVA 变体
  • LoRA 模型: 支持加载和合并 LoRA 权重
  • 纯语言模型: 支持标准 HuggingFace 语言模型

量化支持

  • 8-bit 量化: 通过 load_8bit=True 启用
  • 4-bit 量化: 通过 load_4bit=True 启用,使用 BitsAndBytes 库,配置为 NF4 量化类型

特殊处理

  • LLaVA v1.5 模型: 自动设置 delay_load=True 作为正确加载的解决方案
  • 多模态模型: 自动为 tokenizer 添加特殊 token(图像 patch token、开始/结束 token)
  • LoRA 权重合并: 自动加载并合并 LoRA 权重到基础模型

使用示例

示例 1: 加载完整的 LLaVA 模型

from llava.model.builder import load_pretrained_model

# 从 HuggingFace 加载 LLaVA v1.5 模型
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="liuhaotian/llava-v1.5-7b",
    model_base=None,
    model_name="llava",
    torch_dtype="float16"
)

print(f"Model loaded: {model.__class__.__name__}")
print(f"Context length: {context_len}")

示例 2: 加载 LoRA 模型

# 加载 LoRA 权重(需要提供基础模型)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="./checkpoints/llava-lora",
    model_base="liuhaotian/llava-v1.5-7b",
    model_name="llava_lora",
    torch_dtype="float16"
)

示例 3: 使用 4-bit 量化加载模型

# 使用 4-bit 量化以节省显存
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="liuhaotian/llava-v1.5-7b",
    model_base=None,
    model_name="llava",
    load_4bit=True,
    device_map="auto"
)

示例 4: 加载自定义配置的模型

# 使用自定义配置覆盖默认设置
custom_config = {
    "mm_vision_select_layer": -2,
    "mm_projector_type": "mlp2x_gelu"
}

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="liuhaotian/llava-v1.5-7b",
    model_base=None,
    model_name="llava",
    customized_config=custom_config,
    torch_dtype="float16"
)

示例 5: 加载纯语言模型

# 加载标准语言模型(非多模态)
tokenizer, model, _, context_len = load_pretrained_model(
    model_path="meta-llama/Llama-2-7b-hf",
    model_base=None,
    model_name="llama",
    torch_dtype="float16"
)
# 注意:image_processor 为 None

注意事项

  1. LoRA 模型加载: 加载 LoRA 模型时,必须提供 model_base 参数。函数会先加载基础模型,然后应用 LoRA 权重并自动合并。

  2. 量化互斥性: load_8bitload_4bit 参数互斥,不能同时设置为 True

  3. 设备映射: 使用 device_map="auto" 时,函数会自动将模型分配到可用设备。对于多 GPU 环境,模型会被分片到不同 GPU。

  4. 配置覆盖顺序:

    • 首先应用 customized_config(如果提供)
    • 然后应用 overwrite_config(如果提供)
    • 最后应用从 model_path 加载的配置
  5. 特殊 Token: 对于多模态模型,函数会根据模型配置自动添加图像相关的特殊 token 到 tokenizer。

警告

  • 如果 model_name 包含 “lora” 但 model_baseNone,函数会发出警告,因为 LoRA 模型需要基础模型。

  • 如果指定的 model_name 不被支持,函数会抛出 ValueError 异常。

相关链接

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐