大模型微调完全指南:从原理到工业级实践

本文深度剖析大模型微调的核心原理、数学推导、主流方法对比、完整代码实现及工业级最佳实践。无论你是初学者还是从业者,都能从中获得系统性的知识体系。


目录

  1. 引言:为什么微调如此重要
  2. 大模型训练全景图
  3. 全参微调:原理与挑战
  4. 参数高效微调(PEFT)方法论
  5. LoRA:深入原理与数学推导
  6. QLoRA:量化与低秩的完美结合
  7. 其他PEFT方法详解
  8. 工业级代码实战
  9. 数据工程:微调成功的关键
  10. 评估与调优策略
  11. 部署与推理优化
  12. 常见问题与解决方案
  13. 总结与展望

一、引言:为什么微调如此重要

1.1 大模型的能力与局限

大型语言模型(LLM)如GPT-4、Claude、LLaMA等,通过在海量文本数据上进行预训练,习得了丰富的语言知识和推理能力。然而,预训练模型存在以下局限:

┌─────────────────────────────────────────────────────────────────────────┐
│                     预训练模型的能力边界                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   ✅ 预训练模型擅长:                    ❌ 预训练模型的局限:             │
│   ┌─────────────────────────┐          ┌─────────────────────────┐     │
│   │ • 通用语言理解          │          │ • 缺乏领域专业知识      │     │
│   │ • 基础推理能力          │          │ • 不了解私有数据        │     │
│   │ • 常识知识              │          │ • 输出格式不可控        │     │
│   │ • 多语言能力            │          │ • 可能产生幻觉          │     │
│   │ • 上下文理解            │          │ • 风格固定难以定制      │     │
│   └─────────────────────────┘          └─────────────────────────┘     │
│                                                                         │
│   微调的核心价值: 让通用模型变成"专才",在特定领域/任务上表现卓越        │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

1.2 微调的典型应用场景

应用场景 具体需求 微调策略 数据规模
垂直领域问答 医疗、法律、金融等专业知识 SFT + 领域语料 1万-10万条
企业智能客服 理解产品知识,风格一致 SFT + 对话数据 5千-5万条
代码助手 特定语言/框架的代码生成 SFT + 代码数据 1万-50万条
内容创作 特定风格的文案、文章 SFT + 风格样本 1千-1万条
数据抽取 结构化信息提取 SFT + 标注数据 1千-5千条
安全对齐 拒绝有害请求 RLHF/DPO 1万-10万条

1.3 微调的投入产出分析

投资回报率分析 (以7B模型为例)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

投入成本:
┌────────────────┬─────────────────┬─────────────────┬─────────────────┐
│     方法       │    硬件需求     │   训练时间      │    云成本/次    │
├────────────────┼─────────────────┼─────────────────┼─────────────────┤
│   全参微调     │  4×A100 80GB    │    4-8小时      │    $50-100      │
│   LoRA        │  1×A100 40GB    │    2-4小时      │    $10-20       │
│   QLoRA       │  1×RTX 4090     │    3-6小时      │    $5-15        │
└────────────────┴─────────────────┴─────────────────┴─────────────────┘

产出效果:
┌────────────────┬──────────────────────────────────────────────────────┐
│    指标        │                       提升幅度                        │
├────────────────┼──────────────────────────────────────────────────────┤
│  任务准确率    │  基础模型 60% → 微调后 85-95% (提升 25-35%)           │
│  响应相关性    │  通用回答 → 领域精准回答                              │
│  格式遵循率    │  50% → 95%+ (JSON/表格等结构化输出)                   │
│  幻觉率        │  降低 30-50% (在特定领域)                             │
└────────────────┴──────────────────────────────────────────────────────┘

二、大模型训练全景图

2.1 完整训练流程

一个完整的大模型从"出生"到"上岗",通常经历以下阶段:

┌─────────────────────────────────────────────────────────────────────────┐
│                        大模型完整训练流程                                │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  阶段1: 预训练 (Pre-training)                                           │
│  ════════════════════════════                                           │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │  目标: 学习语言的统计规律和世界知识                              │   │
│  │  数据: 数万亿token的互联网文本 (网页、书籍、代码、论文...)       │   │
│  │  任务: 下一个token预测 (Causal LM) 或 掩码预测 (MLM)             │   │
│  │  成本: 数百万至数千万美元                                        │   │
│  │  产出: 基座模型 (Base Model)                                     │   │
│  │        - 具备语言理解和生成能力                                  │   │
│  │        - 但可能"话痨"、不遵循指令、输出不可控                    │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                              │                                          │
│                              ▼                                          │
│  阶段2: 监督微调 (Supervised Fine-Tuning, SFT)                         │
│  ══════════════════════════════════════════════                        │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │  目标: 学习理解指令并按要求输出                                  │   │
│  │  数据: 高质量的 (指令, 输出) 配对数据,通常1万-100万条           │   │
│  │  任务: 条件语言建模,最大化 P(输出|指令)                         │   │
│  │  成本: 数百至数千美元                                            │   │
│  │  产出: SFT模型 (Instruct Model)                                  │   │
│  │        - 能理解并执行指令                                        │   │
│  │        - 输出格式可控                                            │   │
│  │        - 但可能仍会输出有害内容                                  │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                              │                                          │
│                              ▼                                          │
│  阶段3: 人类偏好对齐 (Alignment)                                       │
│  ═══════════════════════════════                                        │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │  目标: 让模型输出符合人类价值观和偏好                            │   │
│  │  方法:                                                           │   │
│  │    • RLHF: 训练奖励模型 + PPO强化学习                            │   │
│  │    • DPO: 直接偏好优化,无需单独的奖励模型                       │   │
│  │    • ORPO/KTO: 更简化的对齐方法                                  │   │
│  │  数据: 偏好对比数据 (好的回答 vs 差的回答)                       │   │
│  │  产出: 对齐模型 (Chat/Assistant Model)                           │   │
│  │        - 安全、有帮助、诚实                                      │   │
│  │        - 拒绝有害请求                                            │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

2.2 微调在训练流程中的位置

                    我们通常所说的"微调"发生在这里
                              │
                              ▼
┌──────────┐    ┌──────────┐    ┌──────────┐    ┌──────────┐
│  预训练  │ -> │   SFT    │ -> │  RLHF    │ -> │   部署   │
│ (厂商做) │    │ (可定制) │    │ (可选)   │    │          │
└──────────┘    └──────────┘    └──────────┘    └──────────┘
     │               │               │
     │               │               │
     ▼               ▼               ▼
 LLaMA-Base     LLaMA-Instruct  LLaMA-Chat
 Qwen-Base      Qwen-Instruct   Qwen-Chat
 Mistral        Mistral-Instruct

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

常见微调起点选择:

┌─────────────────┬────────────────────────────────────────────────────┐
│   起点模型      │                    适用场景                         │
├─────────────────┼────────────────────────────────────────────────────┤
│   Base模型      │ • 需要大量领域预训练 (继续预训练)                   │
│                 │ • 对话能力不重要的任务 (分类、抽取)                 │
│                 │ • 从头构建对话能力                                  │
├─────────────────┼────────────────────────────────────────────────────┤
│  Instruct模型   │ • 保留通用对话能力                                  │
│  (推荐)         │ • 快速适配特定领域                                  │
│                 │ • 大多数业务场景的最佳起点                          │
├─────────────────┼────────────────────────────────────────────────────┤
│   Chat模型      │ • 需要保留安全对齐能力                              │
│                 │ • 客服、助手类应用                                  │
│                 │ • 注意可能降低任务执行能力                          │
└─────────────────┴────────────────────────────────────────────────────┘

2.3 不同训练阶段的对比

维度 预训练 SFT RLHF/DPO
目标函数 下一token预测 条件语言建模 奖励最大化/偏好对齐
数据规模 万亿token 万-百万样本 万-十万偏好对
数据来源 互联网爬取 人工标注/合成 人工标注偏好
计算成本 极高 中等 中等
学习内容 语言知识+世界知识 任务执行能力 价值观对齐
参数更新 全部参数 全部/部分参数 全部/部分参数

三、全参微调:原理与挑战

3.1 全参微调的数学原理

全参微调(Full Fine-Tuning)是最直接的微调方式,即在下游任务数据上更新模型的所有参数。

优化目标:

给定预训练模型参数 θ 0 \theta_0 θ0,和微调数据集 D = { ( x i , y i ) } i = 1 N D = \{(x_i, y_i)\}_{i=1}^{N} D={(xi,yi)}i=1N,全参微调的目标是:

θ ∗ = arg ⁡ min ⁡ θ ∑ i = 1 N L ( f θ ( x i ) , y i ) \theta^* = \arg\min_{\theta} \sum_{i=1}^{N} \mathcal{L}(f_\theta(x_i), y_i) θ=argθmini=1NL(fθ(xi),yi)

其中 L \mathcal{L} L 通常是交叉熵损失。对于因果语言模型:

L = − ∑ t = 1 T log ⁡ P θ ( y t ∣ x , y < t ) \mathcal{L} = -\sum_{t=1}^{T} \log P_\theta(y_t | x, y_{<t}) L=t=1TlogPθ(ytx,y<t)

梯度更新:

┌─────────────────────────────────────────────────────────────────────────┐
│                        全参微调的参数更新过程                            │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   对于模型中的每一个参数 θ:                                             │
│                                                                         │
│   1. 前向传播: 计算预测值 ŷ = f_θ(x)                                    │
│                                                                         │
│   2. 计算损失: L = CrossEntropy(ŷ, y)                                   │
│                                                                         │
│   3. 反向传播: 计算梯度 g = ∂L/∂θ                                       │
│                                                                         │
│   4. 参数更新 (以AdamW为例):                                            │
│      ┌──────────────────────────────────────────────────────────────┐  │
│      │  m_t = β₁ · m_{t-1} + (1 - β₁) · g_t        # 一阶矩估计     │  │
│      │  v_t = β₂ · v_{t-1} + (1 - β₂) · g_t²       # 二阶矩估计     │  │
│      │  m̂_t = m_t / (1 - β₁ᵗ)                      # 偏差修正       │  │
│      │  v̂_t = v_t / (1 - β₂ᵗ)                      # 偏差修正       │  │
│      │  θ_t = θ_{t-1} - η · (m̂_t / (√v̂_t + ε) + λ · θ_{t-1})       │  │
│      │                                             # 权重衰减       │  │
│      └──────────────────────────────────────────────────────────────┘  │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

3.2 显存占用深度分析

全参微调的显存占用是限制其应用的主要瓶颈。让我们详细分析7B模型的显存需求:

┌─────────────────────────────────────────────────────────────────────────┐
│               7B模型全参微调显存占用详细分析                              │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  假设: 模型参数量 P = 7B = 7 × 10⁹                                      │
│                                                                         │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │ 1. 模型权重 (Model Weights)                                     │   │
│  │    ────────────────────────                                     │   │
│  │    • FP32: 7B × 4 bytes = 28 GB                                 │   │
│  │    • FP16/BF16: 7B × 2 bytes = 14 GB  ← 常用                    │   │
│  │    • INT8: 7B × 1 byte = 7 GB                                   │   │
│  │    • INT4: 7B × 0.5 bytes = 3.5 GB                              │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                         │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │ 2. 梯度 (Gradients)                                             │   │
│  │    ────────────────                                             │   │
│  │    与模型权重相同大小                                            │   │
│  │    • FP16训练: 7B × 2 bytes = 14 GB                             │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                         │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │ 3. 优化器状态 (Optimizer States) - 以AdamW为例                  │   │
│  │    ─────────────────────────────────────────────                │   │
│  │    需要存储每个参数的:                                          │   │
│  │    • 一阶矩 m (momentum): 7B × 4 bytes = 28 GB                  │   │
│  │    • 二阶矩 v (variance): 7B × 4 bytes = 28 GB                  │   │
│  │    • 主权重副本 (FP32): 7B × 4 bytes = 28 GB (混合精度训练需要) │   │
│  │    小计: 84 GB                                                   │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                         │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │ 4. 激活值 (Activations) - 取决于序列长度和批次大小               │   │
│  │    ──────────────────────────────────────────────               │   │
│  │    假设: batch_size=4, seq_len=2048, hidden_size=4096           │   │
│  │    每层激活: 4 × 2048 × 4096 × 2 bytes ≈ 64 MB                  │   │
│  │    32层总计: 32 × 64 MB × 2 (前向+后向) ≈ 4 GB                  │   │
│  │    (使用gradient checkpointing可降至 ~1 GB)                     │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                         │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │ 总计 (FP16混合精度训练 + AdamW)                                  │   │
│  │ ═════════════════════════════════                               │   │
│  │                                                                  │   │
│  │   模型权重:      14 GB                                           │   │
│  │   梯度:          14 GB                                           │   │
│  │   优化器状态:    84 GB                                           │   │
│  │   激活值:         4 GB                                           │   │
│  │   ─────────────────────                                         │   │
│  │   总计:        ~116 GB  →  需要 2×A100 80GB 或 4×A100 40GB      │   │
│  │                                                                  │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

3.3 全参微调的优缺点

┌─────────────────────────────────────────────────────────────────────────┐
│                          全参微调 优缺点分析                             │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   ✅ 优点:                                                              │
│   ──────                                                                │
│   • 理论上能达到最佳性能,模型可以完全适应新任务                        │
│   • 不引入额外的模型结构,推理时无额外开销                              │
│   • 对任务的适应能力最强,可以学习复杂的领域特定模式                    │
│   • 实现简单,是最直接的迁移学习方式                                    │
│                                                                         │
│   ❌ 缺点:                                                              │
│   ──────                                                                │
│   • 显存需求巨大: 7B模型需要 ~120GB显存                                 │
│   • 存储开销大: 每个任务需要保存完整模型副本                            │
│   • 容易过拟合: 尤其在小数据集上,需要仔细调整正则化                    │
│   • 灾难性遗忘: 可能损失预训练获得的通用能力                            │
│   • 训练成本高: 需要多卡并行,成本较高                                  │
│                                                                         │
│   适用场景:                                                             │
│   ──────────                                                            │
│   • 数据充足 (>10万样本) 且与预训练数据差异大                           │
│   • 追求极致性能,且有足够计算资源                                      │
│   • 单一任务部署,不需要多任务切换                                      │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

四、参数高效微调(PEFT)方法论

4.1 PEFT的核心思想

参数高效微调(Parameter-Efficient Fine-Tuning, PEFT)的核心假设是:

低秩假设: 模型适应新任务所需的参数更新存在于一个低维子空间中,不需要更新全部参数。

┌─────────────────────────────────────────────────────────────────────────┐
│                         PEFT 的核心假设与动机                            │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   预训练模型参数空间示意:                                               │
│                                                                         │
│        全参数空间 (d×d维)                   任务适应所需的真实子空间     │
│   ┌──────────────────────────┐          ┌────────────────────────┐     │
│   │ ░░░░░░░░░░░░░░░░░░░░░░░░ │          │                        │     │
│   │ ░░░░░░░░░░░░░░░░░░░░░░░░ │          │    ████████            │     │
│   │ ░░░░░░░░░░░░░░░░░░░░░░░░ │    →     │    ████████  低秩子空间 │     │
│   │ ░░░░░░░░░░░░░░░░░░░░░░░░ │          │    ████████            │     │
│   │ ░░░░░░░░░░░░░░░░░░░░░░░░ │          │                        │     │
│   └──────────────────────────┘          └────────────────────────┘     │
│         全参微调:                              PEFT:                    │
│     更新所有 d² 个参数                    只更新 ~2dr 个参数            │
│     (大量冗余更新)                         (聚焦有效更新)               │
│                                                                         │
│   数学表达:                                                             │
│   ──────────                                                            │
│   • 全参微调: W' = W₀ + ΔW,  其中 ΔW ∈ ℝ^{d×d}                         │
│   • PEFT假设: ΔW = BA,      其中 B ∈ ℝ^{d×r}, A ∈ ℝ^{r×d}, r << d     │
│   • 参数量:   d² → 2dr,     压缩比例 d/(2r), 如 d=4096, r=8 → 256倍   │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

4.2 PEFT方法分类体系

┌─────────────────────────────────────────────────────────────────────────┐
│                        PEFT 方法分类体系                                 │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│                           ┌──────────┐                                  │
│                           │  PEFT    │                                  │
│                           └────┬─────┘                                  │
│          ┌─────────┬──────────┼──────────┬─────────┐                   │
│          │         │          │          │         │                   │
│          ▼         ▼          ▼          ▼         ▼                   │
│    ┌─────────┐ ┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐            │
│    │ 加性方法│ │选择性  │ │重参数化│ │ 混合   │ │ 量化   │            │
│    │Additive │ │Selective│ │Reparam │ │ Hybrid │ │ Quant  │            │
│    └────┬────┘ └───┬────┘ └───┬────┘ └───┬────┘ └───┬────┘            │
│         │          │          │          │          │                   │
│    ┌────┴────┐    │     ┌────┴────┐    │     ┌────┴────┐              │
│    │         │    │     │         │    │     │         │              │
│    ▼         ▼    ▼     ▼         ▼    ▼     ▼         ▼              │
│ ┌──────┐ ┌──────┐ │  ┌──────┐ ┌──────┐ │  ┌──────┐ ┌──────┐          │
│ │Adapter│ │Prefix│ │  │LoRA  │ │DoRA  │ │  │QLoRA │ │GPTQ- │          │
│ │Tuning│ │Tuning│ │  │      │ │      │ │  │      │ │LoRA  │          │
│ └──────┘ └──────┘ │  └──────┘ └──────┘ │  └──────┘ └──────┘          │
│           │       │      │             │                              │
│    ┌──────┴──────┐│ ┌────┴────┐  ┌─────┴─────┐                       │
│    │             ││ │         │  │           │                       │
│    ▼             ▼▼ ▼         ▼  ▼           ▼                       │
│ ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐               │
│ │Prompt│ │P-Tun-│ │BitFit│ │AdaLo-│ │Uni-  │ │MAM   │               │
│ │Tuning│ │ing v2│ │      │ │RA    │ │PELT  │ │Adapter│              │
│ └──────┘ └──────┘ └──────┘ └──────┘ └──────┘ └──────┘               │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

方法详细说明:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

┌───────────────┬──────────────────────────────────────────────────────┐
│    方法类别   │                      核心思想                         │
├───────────────┼──────────────────────────────────────────────────────┤
│   加性方法    │ 在模型中添加额外的可训练模块                          │
│  (Additive)   │ • Adapter: 在层间插入小型MLP                         │
│               │ • Prefix: 在输入前添加可学习的"虚拟token"            │
│               │ • Prompt Tuning: 学习软提示嵌入                      │
├───────────────┼──────────────────────────────────────────────────────┤
│   选择性方法  │ 只更新模型中的部分参数                               │
│ (Selective)   │ • BitFit: 只更新偏置项                               │
│               │ • LN-Tuning: 只更新LayerNorm参数                     │
│               │ • 选择特定层进行微调                                 │
├───────────────┼──────────────────────────────────────────────────────┤
│  重参数化方法 │ 用低秩矩阵表示参数更新                               │
│ (Reparameter) │ • LoRA: 低秩适配,最流行的方法                       │
│               │ • DoRA: 分解为幅度和方向                             │
│               │ • AdaLoRA: 自适应秩分配                              │
├───────────────┼──────────────────────────────────────────────────────┤
│   混合方法    │ 组合多种PEFT技术                                     │
│   (Hybrid)    │ • UniPELT: 统一多种PEFT方法                          │
│               │ • MAM Adapter: 混合Adapter和Prefix                   │
├───────────────┼──────────────────────────────────────────────────────┤
│   量化方法    │ 结合量化技术进一步降低资源需求                       │
│  (Quantized)  │ • QLoRA: 4bit量化 + LoRA                             │
│               │ • GPTQ-LoRA: GPTQ量化 + LoRA                         │
└───────────────┴──────────────────────────────────────────────────────┘

4.3 PEFT方法综合对比

方法 可训练参数 显存占用 推理开销 实现复杂度 性能 适用场景
LoRA 0.1-1% 无(可合并) ★★★★★ 通用首选
QLoRA 0.1-1% ★★★★☆ 资源受限
AdaLoRA 自适应 ★★★★★ 追求性能
DoRA 0.1-1% ★★★★★ 追求性能
Adapter 1-5% ★★★★☆ 多任务
Prefix-Tuning <1% ★★★☆☆ 生成任务
Prompt-Tuning <0.1% 极低 ★★★☆☆ 简单任务
BitFit <0.1% 极低 ★★☆☆☆ 快速实验

五、LoRA:深入原理与数学推导

5.1 LoRA的核心思想

LoRA(Low-Rank Adaptation)由微软研究院在2021年提出,是目前最主流的PEFT方法。其核心思想是:

权重更新矩阵具有低秩特性,可以用两个小矩阵的乘积来高效近似。

┌─────────────────────────────────────────────────────────────────────────┐
│                         LoRA 核心原理图解                                │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   传统全参微调:                          LoRA微调:                      │
│   ──────────────                         ──────────                      │
│                                                                         │
│        输入 x                                 输入 x                     │
│          │                                      │                       │
│          ▼                                      │                       │
│   ┌──────────────┐                       ┌──────┴──────┐               │
│   │              │                       │             │               │
│   │  W₀ + ΔW    │                       │             │               │
│   │  (d × d)    │                       ▼             ▼               │
│   │  全部更新   │                  ┌─────────┐   ┌─────────┐          │
│   │              │                  │   W₀    │   │  B × A  │          │
│   └──────┬───────┘                  │  (d×d)  │   │ (d×r×d) │          │
│          │                          │  冻结   │   │  训练   │          │
│          ▼                          └────┬────┘   └────┬────┘          │
│        输出 h                            │             │               │
│                                          └──────┬──────┘               │
│   参数量: d × d                                 │                       │
│   = d²                                          ▼                       │
│                                            h = W₀x + BAx               │
│                                              = (W₀ + BA)x              │
│                                                                         │
│                                          参数量: d×r + r×d = 2dr       │
│                                          压缩比: d²/(2dr) = d/(2r)     │
│                                                                         │
│   例: d=4096, r=8                                                       │
│       全参: 4096² = 16,777,216 参数                                     │
│       LoRA: 2×4096×8 = 65,536 参数                                      │
│       压缩比: 256倍                                                     │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

5.2 数学推导详解

基本公式:

对于预训练权重矩阵 W 0 ∈ R d × k W_0 \in \mathbb{R}^{d \times k} W0Rd×k,LoRA将权重更新分解为:

W = W 0 + Δ W = W 0 + B A W = W_0 + \Delta W = W_0 + BA W=W0+ΔW=W0+BA

其中:

  • B ∈ R d × r B \in \mathbb{R}^{d \times r} BRd×r (升维矩阵)
  • A ∈ R r × k A \in \mathbb{R}^{r \times k} ARr×k (降维矩阵)
  • r ≪ min ⁡ ( d , k ) r \ll \min(d, k) rmin(d,k) (秩远小于原始维度)

前向传播:

h = W x = W 0 x + B A x = W 0 x + B ( A x ) h = Wx = W_0 x + BAx = W_0 x + B(Ax) h=Wx=W0x+BAx=W0x+B(Ax)

┌─────────────────────────────────────────────────────────────────────────┐
│                      LoRA 前向传播计算流程                               │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   输入: x ∈ ℝ^{k×1} (或批量 x ∈ ℝ^{batch×seq×k})                       │
│                                                                         │
│   步骤1: 原始路径                                                       │
│   ─────────────────                                                     │
│   h₁ = W₀ · x                                                           │
│   维度: (d×k) · (k×1) = (d×1)                                           │
│                                                                         │
│   步骤2: LoRA路径                                                       │
│   ────────────────                                                      │
│   z = A · x          # 降维: (r×k) · (k×1) = (r×1)                     │
│   h₂ = B · z         # 升维: (d×r) · (r×1) = (d×1)                     │
│   h₂ = B · A · x     # 等价表示                                         │
│                                                                         │
│   步骤3: 缩放与合并                                                     │
│   ─────────────────                                                     │
│   h = h₁ + (α/r) · h₂                                                   │
│     = W₀·x + (α/r) · B·A·x                                              │
│                                                                         │
│   其中 α (lora_alpha) 是缩放因子,用于控制LoRA的影响程度               │
│   缩放系数 α/r 确保不同r值时,学习率的有效值保持稳定                    │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

初始化策略:

# LoRA的初始化非常关键,确保训练开始时不改变原模型行为

# A矩阵: 使用Kaiming初始化(高斯分布)
nn.init.kaiming_uniform_(A, a=math.sqrt(5))

# B矩阵: 零初始化
nn.init.zeros_(B)

# 这样初始时 BA = 0,模型行为与原模型完全一致

为什么这样初始化?

┌─────────────────────────────────────────────────────────────────────────┐
│                      LoRA 初始化策略解析                                 │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   目标: 训练开始时,LoRA不应改变预训练模型的行为                        │
│                                                                         │
│   ┌──────────────────────────────────────────────────────────────────┐ │
│   │  初始状态:                                                       │ │
│   │  • B = 0 (零矩阵)                                                │ │
│   │  • A ~ N(0, σ²) (高斯随机)                                       │ │
│   │  • BA = 0·A = 0                                                  │ │
│   │  • 输出 h = W₀x + 0·x = W₀x  ← 与原模型一致!                    │ │
│   └──────────────────────────────────────────────────────────────────┘ │
│                                                                         │
│   训练过程中:                                                           │
│   • B逐渐从0学习到有意义的值                                           │
│   • A提供了多样化的初始方向供B选择                                     │
│   • 这种"慢启动"让模型平滑地从预训练状态过渡到微调状态                 │
│                                                                         │
│   对比其他初始化方式:                                                   │
│   ┌────────────────┬───────────────────────────────────────────────┐  │
│   │   初始化方式   │                    效果                        │  │
│   ├────────────────┼───────────────────────────────────────────────┤  │
│   │ A=0, B=random  │ 同样有效,但A无法更新方向                      │  │
│   │ A=random, B=0  │ 推荐方式,B可以学习任意方向组合                │  │
│   │ A=random, B=r  │ 初始即改变模型行为,可能导致训练不稳定        │  │
│   └────────────────┴───────────────────────────────────────────────┘  │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

5.3 LoRA应用于Transformer

在Transformer架构中,LoRA通常应用于注意力层的投影矩阵:

┌─────────────────────────────────────────────────────────────────────────┐
│                   LoRA 在 Transformer 中的应用位置                       │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│                    Transformer Layer                                    │
│   ┌────────────────────────────────────────────────────────────────┐   │
│   │                                                                │   │
│   │   ┌─────────────────────────────────────────────────────────┐ │   │
│   │   │              Multi-Head Attention                       │ │   │
│   │   │  ┌─────────────────────────────────────────────────────┐│ │   │
│   │   │  │                                                     ││ │   │
│   │   │  │  Q = W_q · x + B_q · A_q · x   ← LoRA适用 ✓         ││ │   │
│   │   │  │                                                     ││ │   │
│   │   │  │  K = W_k · x + B_k · A_k · x   ← LoRA适用 ✓         ││ │   │
│   │   │  │                                                     ││ │   │
│   │   │  │  V = W_v · x + B_v · A_v · x   ← LoRA适用 ✓         ││ │   │
│   │   │  │                                                     ││ │   │
│   │   │  │  Attention(Q, K, V) = softmax(QK^T/√d) · V          ││ │   │
│   │   │  │                                                     ││ │   │
│   │   │  │  O = W_o · Attention + B_o · A_o · Attention        ││ │   │
│   │   │  │       ↑ LoRA适用 ✓                                  ││ │   │
│   │   │  │                                                     ││ │   │
│   │   │  └─────────────────────────────────────────────────────┘│ │   │
│   │   └─────────────────────────────────────────────────────────┘ │   │
│   │                            │                                   │   │
│   │                            ▼                                   │   │
│   │   ┌─────────────────────────────────────────────────────────┐ │   │
│   │   │                Feed-Forward Network                     │ │   │
│   │   │  ┌─────────────────────────────────────────────────────┐│ │   │
│   │   │  │                                                     ││ │   │
│   │   │  │  gate = W_gate · x + B_gate · A_gate · x  ← 可选    ││ │   │
│   │   │  │  up   = W_up · x + B_up · A_up · x        ← 可选    ││ │   │
│   │   │  │  down = W_down · (gate ⊙ up)              ← 可选    ││ │   │
│   │   │  │                                                     ││ │   │
│   │   │  └─────────────────────────────────────────────────────┘│ │   │
│   │   └─────────────────────────────────────────────────────────┘ │   │
│   │                                                                │   │
│   └────────────────────────────────────────────────────────────────┘   │
│                                                                         │
│   常用target_modules配置:                                               │
│   ━━━━━━━━━━━━━━━━━━━━━━━                                               │
│   • 最小配置: ["q_proj", "v_proj"]           参数量最少,效果尚可      │
│   • 推荐配置: ["q_proj", "k_proj",           平衡效果和效率            │
│               "v_proj", "o_proj"]                                      │
│   • 完整配置: ["q_proj", "k_proj", "v_proj", 效果最好,参数量较多      │
│               "o_proj", "gate_proj",                                   │
│               "up_proj", "down_proj"]                                  │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

5.4 LoRA关键参数深度解析

from peft import LoraConfig

config = LoraConfig(
    r=8,                    # 秩
    lora_alpha=32,          # 缩放系数
    lora_dropout=0.1,       # Dropout
    target_modules=[...],   # 目标模块
    bias="none",            # 偏置处理
    task_type="CAUSAL_LM",  # 任务类型
)

参数详解:

┌─────────────────────────────────────────────────────────────────────────┐
│                       LoRA 关键参数详解                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  1. r (秩/rank)                                                         │
│  ═══════════════                                                        │
│  定义: 低秩矩阵的维度,决定了LoRA的表达能力                             │
│                                                                         │
│  ┌─────────┬───────────────┬─────────────────────────────────────────┐ │
│  │    r    │   参数量占比  │              适用场景                    │ │
│  ├─────────┼───────────────┼─────────────────────────────────────────┤ │
│  │    4    │    ~0.05%     │  简单任务,快速实验                      │ │
│  │    8    │    ~0.1%      │  常用默认值,大多数任务                  │ │
│  │   16    │    ~0.2%      │  中等复杂度任务                          │ │
│  │   32    │    ~0.4%      │  复杂任务,追求更好效果                  │ │
│  │   64    │    ~0.8%      │  复杂任务,数据量大                      │ │
│  │  128+   │    ~1.6%+     │  接近全参微调,很少使用                  │ │
│  └─────────┴───────────────┴─────────────────────────────────────────┘ │
│                                                                         │
│  选择建议:                                                              │
│  • 从r=8开始实验,根据效果调整                                         │
│  • 训练数据少时,用较小的r防止过拟合                                   │
│  • 任务复杂或数据量大时,可以增大r                                     │
│  • r的增大是边际递减的,r=64后提升很小                                 │
│                                                                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  2. lora_alpha (缩放系数)                                               │
│  ════════════════════════                                               │
│  定义: 控制LoRA更新的缩放比例,实际缩放因子为 alpha/r                   │
│                                                                         │
│  数学解释:                                                              │
│  h = W₀x + (α/r) · BAx                                                  │
│                                                                         │
│  ┌───────────────────────────────────────────────────────────────────┐ │
│  │  为什么需要 alpha/r 而不是直接用 alpha?                           │ │
│  │  ──────────────────────────────────────────────────               │ │
│  │  • 不同r值时,BA的范数不同                                        │ │
│  │  • alpha/r 使得改变r时,有效学习率基本不变                        │ │
│  │  • 这样可以在不同r值之间迁移超参数                                │ │
│  │                                                                   │ │
│  │  例如: alpha=32, r=8 时,缩放因子=4                               │ │
│  │        alpha=32, r=16时,缩放因子=2                               │ │
│  │        r增大2倍,缩放因子减小2倍,保持总影响相对稳定              │ │
│  └───────────────────────────────────────────────────────────────────┘ │
│                                                                         │
│  设置建议:                                                              │
│  • 常用值: alpha = 2r 或 alpha = 4r                                    │
│  • 数据量少时用较大的alpha (如 4r),增强LoRA影响                       │
│  • 数据量大时用较小的alpha (如 r 或 2r),避免过度偏离                  │
│                                                                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  3. lora_dropout                                                        │
│  ════════════════                                                       │
│  定义: 对LoRA层输出应用Dropout,防止过拟合                              │
│                                                                         │
│  ┌─────────────┬───────────────────────────────────────────────────┐   │
│  │   数据规模  │                  建议值                            │   │
│  ├─────────────┼───────────────────────────────────────────────────┤   │
│  │   <1000条   │   0.1 - 0.2                                        │   │
│  │  1k-10k条   │   0.05 - 0.1                                       │   │
│  │   >10k条    │   0 - 0.05                                         │   │
│  └─────────────┴───────────────────────────────────────────────────┘   │
│                                                                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  4. target_modules (目标模块)                                           │
│  ═══════════════════════════                                            │
│  定义: 指定对哪些层应用LoRA                                             │
│                                                                         │
│  常见模型的模块名称:                                                    │
│  ┌─────────────────┬─────────────────────────────────────────────────┐ │
│  │      模型       │              target_modules                      │ │
│  ├─────────────────┼─────────────────────────────────────────────────┤ │
│  │  LLaMA/Qwen     │  q_proj, k_proj, v_proj, o_proj                 │ │
│  │                 │  gate_proj, up_proj, down_proj                  │ │
│  ├─────────────────┼─────────────────────────────────────────────────┤ │
│  │  GPT-2/GPT-J    │  c_attn, c_proj, c_fc                           │ │
│  ├─────────────────┼─────────────────────────────────────────────────┤ │
│  │  BLOOM          │  query_key_value, dense, dense_h_to_4h          │ │
│  ├─────────────────┼─────────────────────────────────────────────────┤ │
│  │  Mistral        │  q_proj, k_proj, v_proj, o_proj                 │ │
│  │                 │  gate_proj, up_proj, down_proj                  │ │
│  └─────────────────┴─────────────────────────────────────────────────┘ │
│                                                                         │
│  如何查找模型的模块名称:                                                │
│  ```python                                                              │
│  from transformers import AutoModel                                     │
│  model = AutoModel.from_pretrained("model_name")                        │
│  for name, module in model.named_modules():                             │
│      print(name, type(module).__name__)                                 │
│  ```                                                               │
│                                                                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  5. bias (偏置处理)                                                     │
│  ═══════════════════                                                    │
│  • "none": 不训练任何偏置 (默认,推荐)                                  │
│  • "all": 训练所有偏置项                                                │
│  • "lora_only": 只训练LoRA层对应的偏置                                  │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

5.5 LoRA的推理优化:权重合并

LoRA的一个重要优势是训练后可以将适配器权重合并到原模型中,消除推理时的额外计算开销

┌─────────────────────────────────────────────────────────────────────────┐
│                       LoRA 权重合并原理                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   训练时的前向传播:                                                     │
│   ─────────────────                                                     │
│   h = W₀·x + (α/r)·B·A·x                                                │
│     = W₀·x + ΔW·x          (其中 ΔW = (α/r)·B·A)                        │
│     = (W₀ + ΔW)·x                                                       │
│                                                                         │
│   权重合并:                                                             │
│   ──────────                                                            │
│   W_merged = W₀ + (α/r)·B·A                                             │
│                                                                         │
│   合并后的推理:                                                         │
│   ──────────────                                                        │
│   h = W_merged·x           ← 与原模型计算量完全相同!                   │
│                                                                         │
│   ┌──────────────────────────────────────────────────────────────────┐ │
│   │                    合并前 vs 合并后                               │ │
│   ├──────────────────────────────────────────────────────────────────┤ │
│   │                                                                  │ │
│   │   合并前 (需要PEFT库):                合并后 (独立运行):         │ │
│   │   ┌─────────┐  ┌─────────┐           ┌─────────────────┐        │ │
│   │   │ 基础模型 │ + │ LoRA    │    →     │   合并后模型     │        │ │
│   │   │  14GB   │  │ Adapter │           │     14GB        │        │ │
│   │   │         │  │  ~20MB  │           │ (无额外开销)    │        │ │
│   │   └─────────┘  └─────────┘           └─────────────────┘        │ │
│   │                                                                  │ │
│   │   • 需要加载两部分                    • 单一模型文件            │ │
│   │   • 每次推理多一次矩阵乘法            • 计算量与原模型相同      │ │
│   │   • 可以热切换不同adapter             • 部署更简单              │ │
│   └──────────────────────────────────────────────────────────────────┘ │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

```python
# 权重合并代码示例
from peft import PeftModel

# 加载基础模型和LoRA适配器
base_model = AutoModelForCausalLM.from_pretrained("base_model_path")
model = PeftModel.from_pretrained(base_model, "lora_adapter_path")

# 合并权重
merged_model = model.merge_and_unload()

# 保存合并后的模型
merged_model.save_pretrained("merged_model_path")

六、QLoRA:量化与低秩的完美结合

6.1 QLoRA的核心创新

QLoRA(Quantized LoRA)由华盛顿大学在2023年提出,通过结合4bit量化LoRA,在几乎不损失性能的情况下,将微调所需显存降低到原来的1/4。

┌─────────────────────────────────────────────────────────────────────────┐
│                        QLoRA 核心架构                                    │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   QLoRA = 4bit量化 + 双重量化 + 分页优化器 + LoRA                       │
│                                                                         │
│   ┌─────────────────────────────────────────────────────────────────┐  │
│   │                          输入 x                                  │  │
│   │                            │                                     │  │
│   │              ┌─────────────┴─────────────┐                      │  │
│   │              │                           │                      │  │
│   │              ▼                           ▼                      │  │
│   │    ┌─────────────────┐         ┌─────────────────┐             │  │
│   │    │    W₀ (4bit)    │         │   B·A (16bit)   │             │  │
│   │    │   NF4量化存储   │         │  LoRA适配器     │             │  │
│   │    │                 │         │                 │             │  │
│   │    │ 计算时反量化为  │         │   正常训练      │             │  │
│   │    │   BF16/FP16    │         │   梯度更新      │             │  │
│   │    └────────┬────────┘         └────────┬────────┘             │  │
│   │             │                           │                      │  │
│   │             │    Dequantize(W₀)·x + B·A·x                      │  │
│   │             │                           │                      │  │
│   │             └─────────────┬─────────────┘                      │  │
│   │                           │                                     │  │
│   │                           ▼                                     │  │
│   │                         输出 h                                  │  │
│   │                                                                 │  │
│   └─────────────────────────────────────────────────────────────────┘  │
│                                                                         │
│   关键技术:                                                             │
│   ─────────                                                             │
│   1. NF4 (4-bit NormalFloat): 专为正态分布权重设计的量化格式           │
│   2. 双重量化: 对量化常数也进行量化,进一步节省显存                     │
│   3. 分页优化器: 显存不足时自动将优化器状态转移到CPU内存               │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

6.2 NF4量化详解

┌─────────────────────────────────────────────────────────────────────────┐
│                      NF4 (4-bit NormalFloat) 量化原理                    │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   背景: 神经网络权重通常近似服从正态分布                                │
│                                                                         │
│   权重分布示意:                                                         │
│                    ┌─────────────────┐                                 │
│                    │      ████       │                                 │
│                    │     ██████      │                                 │
│                    │    ████████     │                                 │
│                    │   ██████████    │                                 │
│                    │  ████████████   │                                 │
│                    │ ██████████████  │                                 │
│                    └─────────────────┘                                 │
│                   -3σ  -2σ  -σ  0  σ  2σ  3σ                           │
│                                                                         │
│   NF4的设计思想:                                                        │
│   ────────────────                                                      │
│   • 传统INT4: 均匀分布16个量化点                                        │
│   • NF4: 按正态分布的分位数分配16个量化点                               │
│   • 中心区域(权重密集)分配更多量化点,边缘区域(权重稀疏)分配较少       │
│                                                                         │
│   量化点分布对比:                                                       │
│   ────────────────                                                      │
│   INT4 (均匀):  |  |  |  |  |  |  |  |  |  |  |  |  |  |  |  |        │
│   NF4 (正态):   |    |   |  | || || |  |   |    |                       │
│                     ↑ 中心区域更密集                                    │
│                                                                         │
│   NF4的16个量化值 (归一化到[-1,1]):                                     │
│   -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,     │
│   0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0          │
│                                                                         │
│   量化公式:                                                             │
│   ──────────                                                            │
│   q = round(W / absmax(W) * 127)  # 先归一化                           │
│   W_nf4 = nearest_nf4(q)          # 找最近的NF4量化点                  │
│                                                                         │
│   反量化公式:                                                           │
│   ────────────                                                          │
│   W_dequant = W_nf4 * absmax(W)   # 恢复原始尺度                       │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

6.3 双重量化 (Double Quantization)

┌─────────────────────────────────────────────────────────────────────────┐
│                          双重量化技术                                    │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   问题: 量化需要为每个block存储一个缩放因子(absmax)                     │
│         block_size=64时,每64个参数需要1个FP32缩放因子                  │
│         这会带来额外的显存开销                                          │
│                                                                         │
│   单重量化的显存占用 (以7B模型为例):                                    │
│   ───────────────────────────────────                                   │
│   • 4bit权重: 7B × 0.5 bytes = 3.5 GB                                  │
│   • FP32缩放因子: 7B/64 × 4 bytes = 0.44 GB                            │
│   • 总计: 3.94 GB                                                       │
│                                                                         │
│   双重量化:                                                             │
│   ──────────                                                            │
│   • 第一层量化: 权重 → 4bit + FP32缩放因子                             │
│   • 第二层量化: FP32缩放因子 → 8bit + FP32二级缩放因子                 │
│                                                                         │
│   ┌─────────────────────────────────────────────────────────────────┐  │
│   │   原始权重 W [64个]                                              │  │
│   │        │                                                         │  │
│   │        ▼                                                         │  │
│   │   ┌─────────────────────────────────────────────────────────┐   │  │
│   │   │  第一层量化                                              │   │  │
│   │   │  W_q = NF4量化(W)  [64×4bit = 32 bytes]                 │   │  │
│   │   │  scale_1 = absmax(W)  [1×FP32 = 4 bytes]                │   │  │
│   │   └─────────────────────────────────────────────────────────┘   │  │
│   │        │                                                         │  │
│   │        ▼  (对256个scale_1进行二次量化)                          │  │
│   │   ┌─────────────────────────────────────────────────────────┐   │  │
│   │   │  第二层量化                                              │   │  │
│   │   │  scale_1_q = INT8量化(scale_1)  [256×1byte = 256 bytes] │   │  │
│   │   │  scale_2 = absmax(scale_1)  [1×FP32 = 4 bytes]          │   │  │
│   │   └─────────────────────────────────────────────────────────┘   │  │
│   └─────────────────────────────────────────────────────────────────┘  │
│                                                                         │
│   双重量化后的显存占用:                                                 │
│   ───────────────────────                                               │
│   • 4bit权重: 7B × 0.5 bytes = 3.5 GB                                  │
│   • 8bit缩放因子: 7B/64 × 1 byte = 0.11 GB                             │
│   • FP32二级缩放: 7B/64/256 × 4 bytes ≈ 0.002 GB                       │
│   • 总计: ~3.61 GB (节省约0.33 GB)                                     │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

6.4 QLoRA显存对比

┌─────────────────────────────────────────────────────────────────────────┐
│                   7B模型不同微调方法显存对比                             │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   ┌────────────────────────────────────────────────────────────────┐   │
│   │                                                                │   │
│   │   全参FP16微调                                                 │   │
│   │   ████████████████████████████████████████████████████████████ │   │
│   │   ~120GB (需要多卡)                                            │   │
│   │                                                                │   │
│   │   LoRA FP16                                                    │   │
│   │   ████████████████████████████                                 │   │
│   │   ~28GB (1×A100 40GB)                                          │   │
│   │                                                                │   │
│   │   LoRA INT8                                                    │   │
│   │   ████████████████████                                         │   │
│   │   ~20GB (1×RTX 3090)                                           │   │
│   │                                                                │   │
│   │   QLoRA 4bit                                                   │   │
│   │   █████████████████                                            │   │
│   │   ~16GB (1×RTX 4090)                                           │   │
│   │                                                                │   │
│   │   QLoRA 4bit + Gradient Checkpointing                          │   │
│   │   ████████████                                                 │   │
│   │   ~10GB (1×RTX 3080)                                           │   │
│   │                                                                │   │
│   └────────────────────────────────────────────────────────────────┘   │
│                                                                         │
│       0GB     20GB     40GB     60GB     80GB    100GB    120GB        │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

详细显存构成对比表:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

┌────────────────┬──────────────┬────────────┬────────────┬────────────┐
│     组成部分   │  全参FP16    │  LoRA FP16 │  QLoRA 4bit│   说明     │
├────────────────┼──────────────┼────────────┼────────────┼────────────┤
│ 模型权重       │    14 GB     │   14 GB    │   3.5 GB   │ 量化效果   │
├────────────────┼──────────────┼────────────┼────────────┼────────────┤
│ 梯度           │    14 GB     │   ~0 GB    │   ~0 GB    │ 只算LoRA   │
├────────────────┼──────────────┼────────────┼────────────┼────────────┤
│ 优化器状态     │    84 GB     │   ~0.2 GB  │   ~0.2 GB  │ 只优化LoRA │
├────────────────┼──────────────┼────────────┼────────────┼────────────┤
│ 激活值         │    4 GB      │   10 GB    │   8 GB     │ 反量化开销 │
├────────────────┼──────────────┼────────────┼────────────┼────────────┤
│ LoRA参数       │     -        │   ~0.02 GB │   ~0.02 GB │ 很小       │
├────────────────┼──────────────┼────────────┼────────────┼────────────┤
│ 总计           │   ~116 GB    │   ~24 GB   │   ~12 GB   │            │
└────────────────┴──────────────┴────────────┴────────────┴────────────┘

七、其他PEFT方法详解

7.1 Adapter Tuning

┌─────────────────────────────────────────────────────────────────────────┐
│                        Adapter Tuning 详解                               │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   核心思想: 在Transformer层中插入小型"适配器"模块                       │
│                                                                         │
│   Adapter结构:                                                          │
│   ┌─────────────────────────────────────────────────────────────────┐  │
│   │                                                                 │  │
│   │   输入 x (hidden_size = d)                                     │  │
│   │       │                                                         │  │
│   │       ▼                                                         │  │
│   │   ┌───────────────┐                                            │  │
│   │   │ Down-project  │  W_down: d → r  (降维)                     │  │
│   │   │   (d → r)     │  参数: d × r                               │  │
│   │   └───────┬───────┘                                            │  │
│   │           │                                                     │  │
│   │           ▼                                                     │  │
│   │   ┌───────────────┐                                            │  │
│   │   │   Non-linear  │  通常是 ReLU 或 GELU                       │  │
│   │   │  Activation   │                                            │  │
│   │   └───────┬───────┘                                            │  │
│   │           │                                                     │  │
│   │           ▼                                                     │  │
│   │   ┌───────────────┐                                            │  │
│   │   │  Up-project   │  W_up: r → d  (升维)                       │  │
│   │   │   (r → d)     │  参数: r × d                               │  │
│   │   └───────┬───────┘                                            │  │
│   │           │                                                     │  │
│   │           ▼                                                     │  │
│   │       + ←───────── x (残差连接)                                │  │
│   │       │                                                         │  │
│   │       ▼                                                         │  │
│   │   输出 (d维)                                                    │  │
│   │                                                                 │  │
│   │   总参数: 2 × d × r + r (bias)                                 │  │
│   │                                                                 │  │
│   └─────────────────────────────────────────────────────────────────┘  │
│                                                                         │
│   在Transformer中的位置:                                                │
│   ┌─────────────────────────────────────────────────────────────────┐  │
│   │                                                                 │  │
│   │   Input                                                         │  │
│   │     │                                                           │  │
│   │     ▼                                                           │  │
│   │   Multi-Head Attention                                          │  │
│   │     │                                                           │  │
│   │     ▼                                                           │  │
│   │   ┌──────────────────┐                                         │  │
│   │   │    Adapter 1     │  ← 注意力后                              │  │
│   │   └──────────────────┘                                         │  │
│   │     │                                                           │  │
│   │     ▼                                                           │  │
│   │   Layer Norm                                                    │  │
│   │     │                                                           │  │
│   │     ▼                                                           │  │
│   │   Feed Forward                                                  │  │
│   │     │                                                           │  │
│   │     ▼                                                           │  │
│   │   ┌──────────────────┐                                         │  │
│   │   │    Adapter 2     │  ← FFN后                                 │  │
│   │   └──────────────────┘                                         │  │
│   │     │                                                           │  │
│   │     ▼                                                           │  │
│   │   Layer Norm                                                    │  │
│   │     │                                                           │  │
│   │     ▼                                                           │  │
│   │   Output                                                        │  │
│   │                                                                 │  │
│   └─────────────────────────────────────────────────────────────────┘  │
│                                                                         │
│   Adapter vs LoRA:                                                      │
│   ┌─────────────────┬─────────────────┬─────────────────────────────┐  │
│   │     特性        │     Adapter     │          LoRA               │  │
│   ├─────────────────┼─────────────────┼─────────────────────────────┤  │
│   │  推理开销       │  有 (串行计算)  │  无 (可合并)                │  │
│   │  多任务切换     │  简单           │  需要重新合并               │  │
│   │  参数效率       │  中等           │  更高                       │  │
│   │  实现复杂度     │  低             │  低                         │  │
│   └─────────────────┴─────────────────┴─────────────────────────────┘  │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

7.2 Prefix-Tuning

┌─────────────────────────────────────────────────────────────────────────┐
│                        Prefix-Tuning 详解                                │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   核心思想: 在每一层的Key和Value前面添加可学习的"前缀"向量              │
│                                                                         │
│   原理图解:                                                             │
│   ┌─────────────────────────────────────────────────────────────────┐  │
│   │                                                                 │  │
│   │   原始注意力:                                                   │  │
│   │   Q = [q₁, q₂, ..., qₙ]     (n个query)                         │  │
│   │   K = [k₁, k₂, ..., kₙ]     (n个key)                           │  │
│   │   V = [v₁, v₂, ..., vₙ]     (n个value)                         │  │
│   │                                                                 │  │
│   │   Prefix-Tuning:                                                │  │
│   │   Q = [q₁, q₂, ..., qₙ]           (query不变)                  │  │
│   │   K = [P_k₁, ..., P_kₘ, k₁, ..., kₙ]  (前面加m个prefix key)    │  │
│   │   V = [P_v₁, ..., P_vₘ, v₁, ..., vₙ]  (前面加m个prefix value)  │  │
│   │                                                                 │  │
│   │   其中 P_k, P_v 是可学习的参数                                  │  │
│   │                                                                 │  │
│   └─────────────────────────────────────────────────────────────────┘  │
│                                                                         │
│   注意力计算变化:                                                       │
│   ┌─────────────────────────────────────────────────────────────────┐  │
│   │                                                                 │  │
│   │   原始: Attention(Q, K, V) = softmax(QK^T/√d) V                │  │
│   │                                                                 │  │
│   │   Prefix-Tuning:                                                │  │
│   │                    ┌───────────────┬───────────────┐            │  │
│   │   Attention = Q × │   P_k (m×d)   │   K (n×d)     │^T          │  │
│   │                    └───────────────┴───────────────┘            │  │
│   │                                                                 │  │
│   │   每个query现在会同时关注:                                      │  │
│   │   • 可学习的prefix (提供任务特定信息)                           │  │
│   │   • 原始的keys (提供输入相关信息)                               │  │
│   │                                                                 │  │
│   └─────────────────────────────────────────────────────────────────┘  │
│                                                                         │
│   参数量计算:                                                           │
│   ────────────                                                          │
│   每层需要: 2 × prefix_length × hidden_size × num_heads                │
│   总参数: num_layers × 2 × m × d                                       │
│                                                                         │
│   例: LLaMA-7B, prefix_length=10                                       │
│   = 32层 × 2 × 10 × 4096 = 2,621,440 ≈ 2.6M参数 (0.04%)               │
│                                                                         │
│   优缺点:                                                               │
│   ┌─────────────────────────────────────────────────────────────────┐  │
│   │  ✅ 优点:                         ❌ 缺点:                      │  │
│   │  • 参数量极少                     • 推理时增加序列长度          │  │
│   │  • 不改变模型结构                 • 效果不如LoRA                │  │
│   │  • 多任务时方便切换               • 复杂任务表现一般            │  │
│   └─────────────────────────────────────────────────────────────────┘  │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

7.3 Prompt-Tuning

┌─────────────────────────────────────────────────────────────────────────┐
│                        Prompt-Tuning 详解                                │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   核心思想: 在输入embedding层添加可学习的"软提示"向量                   │
│                                                                         │
│   Hard Prompt (传统方式):                                               │
│   ┌─────────────────────────────────────────────────────────────────┐  │
│   │  输入: "Classify the sentiment: I love this movie. Answer:"    │  │
│   │        ↑───────── 人工设计的提示词 ─────────↑                   │  │
│   │  • 需要人工设计                                                 │  │
│   │  • 离散token,不可微分                                          │  │
│   │  • 性能依赖提示词质量                                           │  │
│   └─────────────────────────────────────────────────────────────────┘  │
│                                                                         │
│   Soft Prompt (Prompt-Tuning):                                          │
│   ┌─────────────────────────────────────────────────────────────────┐  │
│   │                                                                 │  │
│   │   输入序列:  [soft_1] [soft_2] ... [soft_n] [input tokens]     │  │
│   │              └─────── 可学习 ───────┘ └─── 真实输入 ───┘       │  │
│   │                                                                 │  │
│   │   Embedding层:                                                  │  │
│   │   ┌─────────────────────────────────────────────────────────┐  │  │
│   │   │  [P₁]   [P₂]  ... [Pₙ]   [E₁]   [E₂]  ... [Eₘ]        │  │  │
│   │   │   ↑      ↑         ↑      ↑      ↑         ↑           │  │  │
│   │   │  可学习向量        真实token的embedding                   │  │  │
│   │   │  (从随机初始化     (来自预训练的embedding层)              │  │  │
│   │   │   或token初始化)                                         │  │  │
│   │   └─────────────────────────────────────────────────────────┘  │  │
│   │                                                                 │  │
│   │   Pᵢ ∈ ℝ^{hidden_size}, 共n×hidden_size个可训练参数           │  │
│   │                                                                 │  │
│   └─────────────────────────────────────────────────────────────────┘  │
│                                                                         │
│   vs Prefix-Tuning:                                                     │
│   ┌─────────────────────────────────────────────────────────────────┐  │
│   │                                                                 │  │
│   │   Prompt-Tuning:    只在输入embedding层添加软提示               │  │
│   │                     参数量: n × d                               │  │
│   │                                                                 │  │
│   │   Prefix-Tuning:    在每一层的K,V添加前缀                       │  │
│   │                     参数量: L × 2 × n × d                       │  │
│   │                                                                 │  │
│   │   Prefix-Tuning 参数量约为 Prompt-Tuning 的 2L 倍              │  │
│   │                                                                 │  │
│   └─────────────────────────────────────────────────────────────────┘  │
│                                                                         │
│   适用场景:                                                             │
│   • 超大模型 (100B+),此时效果接近全参微调                             │
│   • 多任务场景,每个任务只需一组软提示                                 │
│   • 对参数量极度敏感的场景                                             │
│                                                                         │
│   局限性:                                                               │
│   • 小模型上效果明显差于LoRA                                           │
│   • 训练不够稳定                                                       │
│   • 推理时有额外的序列长度开销                                         │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

7.4 DoRA:LoRA的改进版

┌─────────────────────────────────────────────────────────────────────────┐
│                      DoRA (Weight-Decomposed LoRA) 详解                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   核心思想: 将权重分解为"幅度"和"方向"两个部分,分别处理                │
│                                                                         │
│   LoRA的局限:                                                           │
│   ────────────                                                          │
│   W = W₀ + BA                                                           │
│   • 同时改变了权重的幅度(magnitude)和方向(direction)                   │
│   • 两者耦合在一起,难以独立控制                                       │
│                                                                         │
│   DoRA的分解:                                                           │
│   ─────────────                                                         │
│   将权重W表示为: W = m · (W / ||W||) = m · V                            │
│   • m: 幅度 (magnitude), 标量或向量                                    │
│   • V = W / ||W||: 方向矩阵 (direction), 归一化后的权重                │
│                                                                         │
│   ┌─────────────────────────────────────────────────────────────────┐  │
│   │                                                                 │  │
│   │   原始权重:  W₀ = m₀ · V₀                                       │  │
│   │                                                                 │  │
│   │   DoRA微调:                                                     │  │
│   │   • 幅度: m = m₀ + Δm    (Δm是可学习的幅度调整)                │  │
│   │   • 方向: V = (V₀ + BA) / ||V₀ + BA||  (用LoRA调整方向)        │  │
│   │                                                                 │  │
│   │   最终: W = m · V = (m₀ + Δm) · normalize(V₀ + BA)             │  │
│   │                                                                 │  │
│   │   直观理解:                                                     │  │
│   │   ┌─────────────────────────────────────────────────────────┐  │  │
│   │   │                                                         │  │  │
│   │   │         原始权重                 DoRA调整后              │  │  │
│   │   │            ↗                        ↗                    │  │  │
│   │   │           /  ||W₀||               / ||W||               │  │  │
│   │   │          /                       /                       │  │  │
│   │   │         O ────→ 方向V₀          O ────→ 方向V           │  │  │
│   │   │                                                         │  │  │
│   │   │   Δm调整箭头长度(幅度), LoRA调整箭头方向                │  │  │
│   │   │                                                         │  │  │
│   │   └─────────────────────────────────────────────────────────┘  │  │
│   │                                                                 │  │
│   └─────────────────────────────────────────────────────────────────┘  │
│                                                                         │
│   参数量对比:                                                           │
│   ─────────────                                                         │
│   • LoRA:  2 × d × r                                                   │
│   • DoRA:  2 × d × r + d (额外的幅度参数)                              │
│   • 增加很少,但效果有明显提升                                         │
│                                                                         │
│   实验效果 (论文数据):                                                  │
│   ┌─────────────────┬──────────────┬──────────────┬────────────────┐   │
│   │    方法         │  LLaMA-7B    │  LLaMA-13B   │   平均提升     │   │
│   ├─────────────────┼──────────────┼──────────────┼────────────────┤   │
│   │    LoRA         │    63.4      │    65.2      │      -         │   │
│   │    DoRA         │    65.1      │    67.0      │    +1.7%       │   │
│   │    全参微调     │    65.8      │    67.5      │      -         │   │
│   └─────────────────┴──────────────┴──────────────┴────────────────┘   │
│   DoRA更接近全参微调的效果                                              │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

八、工业级代码实战

8.1 完整的LoRA微调Pipeline

"""
完整的LoRA微调代码示例
包含: 数据处理、模型配置、训练、评估、保存、推理
"""

import os
import torch
from dataclasses import dataclass, field
from typing import Optional, List, Dict
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
    BitsAndBytesConfig,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType,
)
import wandb

# ═══════════════════════════════════════════════════════════════════════════
# 第一部分: 配置类定义
# ═══════════════════════════════════════════════════════════════════════════

@dataclass
class ModelConfig:
    """模型相关配置"""
    model_name: str = "Qwen/Qwen2-1.5B-Instruct"  # 基础模型路径
    torch_dtype: str = "bfloat16"                   # 模型精度
    use_flash_attention: bool = True                 # 是否使用Flash Attention
    trust_remote_code: bool = True                   # 是否信任远程代码
    
@dataclass
class LoRAConfig:
    """LoRA相关配置"""
    r: int = 16                                      # LoRA秩
    lora_alpha: int = 32                             # 缩放系数
    lora_dropout: float = 0.05                       # Dropout
    target_modules: List[str] = field(default_factory=lambda: [
        "q_proj", "k_proj", "v_proj", "o_proj",     # 注意力层
        "gate_proj", "up_proj", "down_proj"          # FFN层
    ])
    bias: str = "none"                               # 偏置处理
    task_type: str = "CAUSAL_LM"                     # 任务类型

@dataclass  
class TrainConfig:
    """训练相关配置"""
    output_dir: str = "./output"                     # 输出目录
    num_train_epochs: int = 3                        # 训练轮数
    per_device_train_batch_size: int = 4             # 训练批次大小
    per_device_eval_batch_size: int = 4              # 评估批次大小
    gradient_accumulation_steps: int = 4             # 梯度累积步数
    learning_rate: float = 2e-4                      # 学习率
    warmup_ratio: float = 0.1                        # 预热比例
    weight_decay: float = 0.01                       # 权重衰减
    max_seq_length: int = 2048                       # 最大序列长度
    logging_steps: int = 10                          # 日志记录步数
    save_steps: int = 100                            # 保存步数
    eval_steps: int = 100                            # 评估步数
    save_total_limit: int = 3                        # 最大保存数量
    
@dataclass
class DataConfig:
    """数据相关配置"""
    dataset_name: str = "tatsu-lab/alpaca"           # 数据集名称
    train_split: str = "train[:90%]"                 # 训练集划分
    eval_split: str = "train[90%:]"                  # 验证集划分
    max_samples: Optional[int] = None                # 最大样本数(调试用)

# ═══════════════════════════════════════════════════════════════════════════
# 第二部分: 数据处理
# ═══════════════════════════════════════════════════════════════════════════

class DataProcessor:
    """数据处理类"""
    
    def __init__(self, tokenizer, max_length: int = 2048):
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        # 确保tokenizer有pad_token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            
    def format_alpaca(self, example: Dict) -> Dict:
        """
        将Alpaca格式数据转换为对话格式
        
        Alpaca格式:
        {
            "instruction": "指令",
            "input": "输入(可选)",
            "output": "输出"
        }
        """
        instruction = example.get("instruction", "")
        input_text = example.get("input", "")
        output = example.get("output", "")
        
        # 构建prompt
        if input_text:
            prompt = f"""### 指令:
{instruction}

### 输入:
{input_text}

### 回答:
"""
        else:
            prompt = f"""### 指令:
{instruction}

### 回答:
"""
        
        # 完整文本 = prompt + output
        full_text = prompt + output
        
        return {
            "prompt": prompt,
            "output": output,
            "full_text": full_text
        }
    
    def tokenize_function(self, examples: Dict) -> Dict:
        """
        分词函数 - 处理prompt和output的拼接与mask
        
        关键点:
        1. prompt部分的label设为-100 (不计算loss)
        2. output部分正常计算loss
        3. padding部分label设为-100
        """
        prompts = examples["prompt"]
        outputs = examples["output"]
        
        model_inputs = {
            "input_ids": [],
            "attention_mask": [],
            "labels": []
        }
        
        for prompt, output in zip(prompts, outputs):
            # 分别tokenize prompt和output
            prompt_tokens = self.tokenizer(
                prompt, 
                add_special_tokens=True,
                truncation=True,
                max_length=self.max_length
            )
            output_tokens = self.tokenizer(
                output + self.tokenizer.eos_token,  # 添加结束符
                add_special_tokens=False,
                truncation=True,
                max_length=self.max_length - len(prompt_tokens["input_ids"])
            )
            
            # 拼接input_ids
            input_ids = prompt_tokens["input_ids"] + output_tokens["input_ids"]
            attention_mask = [1] * len(input_ids)
            
            # 构建labels: prompt部分为-100, output部分为实际token id
            labels = [-100] * len(prompt_tokens["input_ids"]) + output_tokens["input_ids"]
            
            # 截断到max_length
            if len(input_ids) > self.max_length:
                input_ids = input_ids[:self.max_length]
                attention_mask = attention_mask[:self.max_length]
                labels = labels[:self.max_length]
            
            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append(attention_mask)
            model_inputs["labels"].append(labels)
            
        return model_inputs

    def prepare_dataset(
        self, 
        dataset_name: str,
        train_split: str = "train",
        eval_split: Optional[str] = None,
        max_samples: Optional[int] = None
    ) -> tuple:
        """准备训练和验证数据集"""
        
        # 加载数据集
        print(f"Loading dataset: {dataset_name}")
        raw_dataset = load_dataset(dataset_name, split=train_split)
        
        if max_samples:
            raw_dataset = raw_dataset.select(range(min(max_samples, len(raw_dataset))))
            
        # 格式转换
        print("Formatting dataset...")
        formatted_dataset = raw_dataset.map(
            self.format_alpaca,
            remove_columns=raw_dataset.column_names,
            desc="Formatting"
        )
        
        # 分词
        print("Tokenizing dataset...")
        tokenized_dataset = formatted_dataset.map(
            self.tokenize_function,
            batched=True,
            remove_columns=formatted_dataset.column_names,
            desc="Tokenizing"
        )
        
        # 划分训练集和验证集
        if eval_split:
            # 如果指定了eval_split,单独加载
            eval_raw = load_dataset(dataset_name, split=eval_split)
            if max_samples:
                eval_raw = eval_raw.select(range(min(max_samples // 10, len(eval_raw))))
            eval_formatted = eval_raw.map(self.format_alpaca, remove_columns=eval_raw.column_names)
            eval_tokenized = eval_formatted.map(self.tokenize_function, batched=True, 
                                                 remove_columns=eval_formatted.column_names)
            return tokenized_dataset, eval_tokenized
        else:
            # 自动划分90%/10%
            split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
            return split_dataset["train"], split_dataset["test"]

# ═══════════════════════════════════════════════════════════════════════════
# 第三部分: 模型加载与LoRA配置
# ═══════════════════════════════════════════════════════════════════════════

class ModelLoader:
    """模型加载类"""
    
    @staticmethod
    def load_model_and_tokenizer(
        model_config: ModelConfig,
        use_quantization: bool = False,
        quantization_bits: int = 4
    ):
        """
        加载模型和分词器
        
        Args:
            model_config: 模型配置
            use_quantization: 是否使用量化 (QLoRA)
            quantization_bits: 量化位数 (4 or 8)
        """
        # 确定torch dtype
        dtype_map = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }
        torch_dtype = dtype_map.get(model_config.torch_dtype, torch.bfloat16)
        
        # 量化配置 (QLoRA)
        quantization_config = None
        if use_quantization:
            print(f"Enabling {quantization_bits}-bit quantization (QLoRA)")
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=(quantization_bits == 4),
                load_in_8bit=(quantization_bits == 8),
                bnb_4bit_quant_type="nf4",              # NF4量化
                bnb_4bit_compute_dtype=torch_dtype,     # 计算精度
                bnb_4bit_use_double_quant=True,         # 双重量化
            )
        
        # 加载tokenizer
        print(f"Loading tokenizer: {model_config.model_name}")
        tokenizer = AutoTokenizer.from_pretrained(
            model_config.model_name,
            trust_remote_code=model_config.trust_remote_code,
            padding_side="right",  # 右填充用于生成
        )
        
        # 加载模型
        print(f"Loading model: {model_config.model_name}")
        model_kwargs = {
            "pretrained_model_name_or_path": model_config.model_name,
            "torch_dtype": torch_dtype,
            "trust_remote_code": model_config.trust_remote_code,
            "device_map": "auto",  # 自动分配设备
        }
        
        if quantization_config:
            model_kwargs["quantization_config"] = quantization_config
            
        if model_config.use_flash_attention:
            model_kwargs["attn_implementation"] = "flash_attention_2"
            
        model = AutoModelForCausalLM.from_pretrained(**model_kwargs)
        
        # QLoRA需要准备模型
        if use_quantization:
            model = prepare_model_for_kbit_training(
                model,
                use_gradient_checkpointing=True
            )
        
        return model, tokenizer
    
    @staticmethod
    def apply_lora(model, lora_config: LoRAConfig):
        """应用LoRA配置到模型"""
        
        print("Applying LoRA configuration...")
        
        # 创建PEFT配置
        peft_config = LoraConfig(
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
            lora_dropout=lora_config.lora_dropout,
            target_modules=lora_config.target_modules,
            bias=lora_config.bias,
            task_type=TaskType.CAUSAL_LM,
        )
        
        # 应用LoRA
        model = get_peft_model(model, peft_config)
        
        # 打印可训练参数信息
        model.print_trainable_parameters()
        
        return model

# ═══════════════════════════════════════════════════════════════════════════
# 第四部分: 训练器封装
# ═══════════════════════════════════════════════════════════════════════════

class LoRATrainer:
    """LoRA训练器封装类"""
    
    def __init__(
        self,
        model,
        tokenizer,
        train_dataset,
        eval_dataset,
        train_config: TrainConfig,
        use_wandb: bool = False
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.train_config = train_config
        
        # 初始化wandb
        if use_wandb:
            wandb.init(project="lora-finetuning", name=train_config.output_dir.split("/")[-1])
            
        # 配置训练参数
        self.training_args = self._create_training_args()
        
        # 创建Trainer
        self.trainer = self._create_trainer()
        
    def _create_training_args(self) -> TrainingArguments:
        """创建训练参数"""
        return TrainingArguments(
            output_dir=self.train_config.output_dir,
            num_train_epochs=self.train_config.num_train_epochs,
            per_device_train_batch_size=self.train_config.per_device_train_batch_size,
            per_device_eval_batch_size=self.train_config.per_device_eval_batch_size,
            gradient_accumulation_steps=self.train_config.gradient_accumulation_steps,
            learning_rate=self.train_config.learning_rate,
            warmup_ratio=self.train_config.warmup_ratio,
            weight_decay=self.train_config.weight_decay,
            logging_steps=self.train_config.logging_steps,
            save_steps=self.train_config.save_steps,
            eval_steps=self.train_config.eval_steps,
            evaluation_strategy="steps",
            save_strategy="steps",
            save_total_limit=self.train_config.save_total_limit,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            fp16=False,  # 如果GPU支持,可以启用
            bf16=True,   # 推荐使用bf16
            gradient_checkpointing=True,  # 节省显存
            optim="adamw_torch",
            lr_scheduler_type="cosine",
            report_to="wandb" if wandb.run else "none",
            dataloader_num_workers=4,
            remove_unused_columns=False,
        )
    
    def _create_trainer(self) -> Trainer:
        """创建Trainer实例"""
        
        # 数据整理器
        data_collator = DataCollatorForSeq2Seq(
            tokenizer=self.tokenizer,
            padding=True,
            max_length=self.train_config.max_seq_length,
            pad_to_multiple_of=8,  # 对齐到8的倍数,提高效率
            return_tensors="pt"
        )
        
        return Trainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            tokenizer=self.tokenizer,
            data_collator=data_collator,
        )
    
    def train(self):
        """执行训练"""
        print("=" * 60)
        print("Starting training...")
        print(f"  Total samples: {len(self.train_dataset)}")
        print(f"  Epochs: {self.train_config.num_train_epochs}")
        print(f"  Batch size: {self.train_config.per_device_train_batch_size}")
        print(f"  Gradient accumulation: {self.train_config.gradient_accumulation_steps}")
        print(f"  Effective batch size: {self.train_config.per_device_train_batch_size * self.train_config.gradient_accumulation_steps}")
        print("=" * 60)
        
        # 开始训练
        train_result = self.trainer.train()
        
        # 保存最终模型
        self.trainer.save_model()
        
        # 保存训练指标
        metrics = train_result.metrics
        self.trainer.log_metrics("train", metrics)
        self.trainer.save_metrics("train", metrics)
        
        return train_result
    
    def evaluate(self):
        """执行评估"""
        print("Running evaluation...")
        metrics = self.trainer.evaluate()
        self.trainer.log_metrics("eval", metrics)
        self.trainer.save_metrics("eval", metrics)
        return metrics

# ═══════════════════════════════════════════════════════════════════════════
# 第五部分: 推理与合并
# ═══════════════════════════════════════════════════════════════════════════

class LoRAInference:
    """LoRA推理类"""
    
    def __init__(self, base_model_path: str, lora_adapter_path: str):
        """
        初始化推理器
        
        Args:
            base_model_path: 基础模型路径
            lora_adapter_path: LoRA适配器路径
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # 加载tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            
        # 加载模型
        self.model = self._load_model(base_model_path, lora_adapter_path)
        self.model.eval()
        
    def _load_model(self, base_model_path: str, lora_adapter_path: str):
        """加载模型(两种方式)"""
        from peft import PeftModel
        
        # 加载基础模型
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        
        # 加载LoRA权重
        model = PeftModel.from_pretrained(base_model, lora_adapter_path)
        
        return model
    
    def merge_and_save(self, output_path: str):
        """合并LoRA权重并保存"""
        print("Merging LoRA weights...")
        merged_model = self.model.merge_and_unload()
        
        print(f"Saving merged model to {output_path}")
        merged_model.save_pretrained(output_path)
        self.tokenizer.save_pretrained(output_path)
        
        return merged_model
    
    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        do_sample: bool = True,
    ) -> str:
        """
        生成文本
        
        Args:
            prompt: 输入提示
            max_new_tokens: 最大生成token数
            temperature: 温度参数
            top_p: nucleus sampling参数
            top_k: top-k sampling参数
            do_sample: 是否采样
        """
        # 格式化输入
        formatted_prompt = f"""### 指令:
{prompt}

### 回答:
"""
        
        # 编码输入
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)
        
        # 生成
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        
        # 解码输出
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # 提取回答部分
        if "### 回答:" in generated_text:
            response = generated_text.split("### 回答:")[-1].strip()
        else:
            response = generated_text
            
        return response

# ═══════════════════════════════════════════════════════════════════════════
# 第六部分: 主函数
# ═══════════════════════════════════════════════════════════════════════════

def main():
    """主函数 - 完整的训练流程"""
    
    # 1. 初始化配置
    model_config = ModelConfig(
        model_name="Qwen/Qwen2-1.5B-Instruct",
        torch_dtype="bfloat16",
    )
    
    lora_config = LoRAConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
    )
    
    train_config = TrainConfig(
        output_dir="./output/qwen2-lora-alpaca",
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        max_seq_length=2048,
    )
    
    data_config = DataConfig(
        dataset_name="tatsu-lab/alpaca",
        max_samples=5000,  # 调试时限制样本数
    )
    
    # 2. 加载模型和分词器
    model, tokenizer = ModelLoader.load_model_and_tokenizer(
        model_config,
        use_quantization=False,  # 设为True启用QLoRA
    )
    
    # 3. 应用LoRA
    model = ModelLoader.apply_lora(model, lora_config)
    
    # 4. 准备数据
    data_processor = DataProcessor(tokenizer, train_config.max_seq_length)
    train_dataset, eval_dataset = data_processor.prepare_dataset(
        data_config.dataset_name,
        max_samples=data_config.max_samples,
    )
    
    print(f"Train samples: {len(train_dataset)}")
    print(f"Eval samples: {len(eval_dataset)}")
    
    # 5. 创建训练器并训练
    trainer = LoRATrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        train_config=train_config,
        use_wandb=False,
    )
    
    # 6. 开始训练
    trainer.train()
    
    # 7. 评估
    trainer.evaluate()
    
    print("Training completed!")
    print(f"Model saved to: {train_config.output_dir}")
    
    # 8. 推理示例
    print("\n" + "="*60)
    print("Running inference example...")
    print("="*60)
    
    inference = LoRAInference(
        base_model_path=model_config.model_name,
        lora_adapter_path=train_config.output_dir
    )
    
    test_prompts = [
        "请介绍一下什么是机器学习?",
        "写一首关于春天的诗",
        "如何提高编程能力?",
    ]
    
    for prompt in test_prompts:
        print(f"\n输入: {prompt}")
        response = inference.generate(prompt, max_new_tokens=256)
        print(f"输出: {response}")
        print("-" * 40)

if __name__ == "__main__":
    main()

8.2 QLoRA微调代码

"""
QLoRA微调示例 - 在消费级显卡上微调大模型
"""

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset

def setup_qlora_training():
    """设置QLoRA训练环境"""
    
    model_name = "meta-llama/Llama-2-7b-hf"
    
    # ═══════════════════════════════════════════════════════════════════
    # 步骤1: 配置4bit量化
    # ═══════════════════════════════════════════════════════════════════
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,                  # 启用4bit量化
        bnb_4bit_quant_type="nf4",          # 使用NF4量化类型
        bnb_4bit_compute_dtype=torch.bfloat16,  # 计算时使用bf16
        bnb_4bit_use_double_quant=True,     # 启用双重量化
    )
    
    # ═══════════════════════════════════════════════════════════════════
    # 步骤2: 加载量化模型
    # ═══════════════════════════════════════════════════════════════════
    print("Loading quantized model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    
    # 准备模型用于k-bit训练
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=True,  # 启用梯度检查点
    )
    
    # ═══════════════════════════════════════════════════════════════════
    # 步骤3: 配置LoRA
    # ═══════════════════════════════════════════════════════════════════
    lora_config = LoraConfig(
        r=64,                           # QLoRA通常使用较大的r
        lora_alpha=16,                  # alpha/r = 0.25
        lora_dropout=0.1,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )
    
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    
    # ═══════════════════════════════════════════════════════════════════
    # 步骤4: 配置训练参数 (针对QLoRA优化)
    # ═══════════════════════════════════════════════════════════════════
    training_args = TrainingArguments(
        output_dir="./qlora_output",
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        logging_steps=10,
        save_strategy="steps",
        save_steps=100,
        bf16=True,                          # 使用bf16
        optim="paged_adamw_8bit",           # 分页优化器,进一步节省显存
        gradient_checkpointing=True,         # 梯度检查点
        max_grad_norm=0.3,                  # 梯度裁剪
        group_by_length=True,               # 按长度分组,提高效率
        report_to="none",
    )
    
    return model, tokenizer, training_args

# 显存占用参考 (7B模型):
# ┌─────────────────────┬─────────────────┐
# │      配置           │    显存占用     │
# ├─────────────────────┼─────────────────┤
# │ QLoRA 4bit          │    ~10GB        │
# │ + gradient_ckpt     │    ~8GB         │
# │ + paged_adamw_8bit  │    ~7GB         │
# │ batch_size=4        │    ~12GB        │
# └─────────────────────┴─────────────────┘

8.3 多GPU分布式训练

"""
多GPU分布式LoRA训练
支持: DeepSpeed ZeRO, FSDP
"""

from transformers import TrainingArguments
import os

# ═══════════════════════════════════════════════════════════════════════════
# 方式1: DeepSpeed ZeRO-3 配置
# ═══════════════════════════════════════════════════════════════════════════

deepspeed_config = {
    "bf16": {
        "enabled": True
    },
    "zero_optimization": {
        "stage": 3,                             # ZeRO Stage 3
        "offload_optimizer": {
            "device": "cpu",                    # 优化器状态卸载到CPU
            "pin_memory": True
        },
        "offload_param": {
            "device": "cpu",                    # 参数卸载到CPU
            "pin_memory": True
        },
        "overlap_comm": True,
        "contiguous_gradients": True,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": True
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 10,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": False
}

# 保存配置文件
import json
with open("ds_config.json", "w") as f:
    json.dump(deepspeed_config, f, indent=2)

# 训练参数
training_args = TrainingArguments(
    output_dir="./output",
    deepspeed="ds_config.json",           # 使用DeepSpeed配置
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=2e-4,
    bf16=True,
    logging_steps=10,
    save_steps=100,
)

# 启动命令:
# torchrun --nproc_per_node=4 train.py
# 或
# deepspeed --num_gpus=4 train.py

# ═══════════════════════════════════════════════════════════════════════════
# 方式2: Accelerate + FSDP
# ═══════════════════════════════════════════════════════════════════════════

# accelerate_config.yaml:
"""
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
"""

# 启动命令:
# accelerate launch --config_file accelerate_config.yaml train.py

九、数据工程:微调成功的关键

9.1 数据质量的重要性

┌─────────────────────────────────────────────────────────────────────────┐
│                    数据工程的黄金法则                                    │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   "数据质量 > 数据数量 > 模型大小 > 训练技巧"                          │
│                                                                         │
│   ┌─────────────────────────────────────────────────────────────────┐  │
│   │                                                                 │  │
│   │         高质量数据 + 小模型  >  低质量数据 + 大模型             │  │
│   │                                                                 │  │
│   │   例: 1000条精心标注的数据 + 7B模型                            │  │
│   │       通常优于                                                  │  │
│   │       10000条自动生成的数据 + 13B模型                          │  │
│   │                                                                 │  │
│   └─────────────────────────────────────────────────────────────────┘  │
│                                                                         │
│   数据质量评估维度:                                                     │
│   ─────────────────                                                     │
│   ┌─────────────────┬───────────────────────────────────────────────┐  │
│   │     维度        │                    说明                        │  │
│   ├─────────────────┼───────────────────────────────────────────────┤  │
│   │   准确性        │  答案是否正确、完整                            │  │
│   │   相关性        │  答案是否与问题相关                            │  │
│   │   一致性        │  格式、风格是否统一                            │  │
│   │   多样性        │  是否覆盖各种情况                              │  │
│   │   复杂度分布    │  简单/中等/复杂问题的比例                      │  │
│   │   无害性        │  是否包含有害内容                              │  │
│   └─────────────────┴───────────────────────────────────────────────┘  │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

9.2 数据格式规范

"""
常见数据格式及其适用场景
"""

# ═══════════════════════════════════════════════════════════════════════════
# 格式1: Alpaca格式 (最常用)
# 适用: 通用指令微调
# ═══════════════════════════════════════════════════════════════════════════
alpaca_format = {
    "instruction": "将以下句子翻译成英文",
    "input": "今天天气很好",           # 可选字段
    "output": "The weather is nice today."
}

# ═══════════════════════════════════════════════════════════════════════════
# 格式2: ShareGPT/对话格式
# 适用: 多轮对话、角色扮演
# ═══════════════════════════════════════════════════════════════════════════
sharegpt_format = {
    "conversations": [
        {"from": "system", "value": "你是一个专业的医学助手"},
        {"from": "human", "value": "我最近经常头痛,可能是什么原因?"},
        {"from": "gpt", "value": "头痛可能由多种原因引起..."},
        {"from": "human", "value": "需要去医院检查吗?"},
        {"from": "gpt", "value": "如果头痛持续或加重,建议就医..."}
    ]
}

# ═══════════════════════════════════════════════════════════════════════════
# 格式3: OpenAI格式
# 适用: 与ChatGPT API兼容
# ═══════════════════════════════════════════════════════════════════════════
openai_format = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, who are you?"},
        {"role": "assistant", "content": "I am an AI assistant..."}
    ]
}

# ═══════════════════════════════════════════════════════════════════════════
# 格式4: 纯文本续写格式
# 适用: 继续预训练、领域适应
# ═══════════════════════════════════════════════════════════════════════════
text_format = {
    "text": "这里是一段长文本,用于继续预训练模型,让其学习领域知识..."
}

# ═══════════════════════════════════════════════════════════════════════════
# 格式5: 偏好数据格式 (DPO/RLHF)
# 适用: 人类偏好对齐
# ═══════════════════════════════════════════════════════════════════════════
preference_format = {
    "prompt": "写一首关于春天的诗",
    "chosen": "春风轻拂柳丝长,...",     # 更好的回答
    "rejected": "春天来了,很好看..."     # 较差的回答
}

9.3 数据清洗与增强

"""
数据清洗与增强工具函数
"""

import re
from typing import List, Dict
import random

class DataCleaner:
    """数据清洗类"""
    
    @staticmethod
    def clean_text(text: str) -> str:
        """基础文本清洗"""
        # 移除多余空白
        text = re.sub(r'\s+', ' ', text).strip()
        # 移除特殊字符
        text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
        # 统一标点符号
        text = text.replace(',', ',').replace('。', '.')
        return text
    
    @staticmethod
    def filter_by_length(
        data: List[Dict],
        min_length: int = 10,
        max_length: int = 4096
    ) -> List[Dict]:
        """按长度过滤数据"""
        filtered = []
        for item in data:
            text = item.get("output", "") or item.get("text", "")
            if min_length <= len(text) <= max_length:
                filtered.append(item)
        return filtered
    
    @staticmethod
    def remove_duplicates(data: List[Dict], key: str = "instruction") -> List[Dict]:
        """去重"""
        seen = set()
        unique = []
        for item in data:
            identifier = item.get(key, "")
            if identifier not in seen:
                seen.add(identifier)
                unique.append(item)
        return unique
    
    @staticmethod
    def filter_quality(data: List[Dict]) -> List[Dict]:
        """质量过滤"""
        filtered = []
        for item in data:
            output = item.get("output", "")
            
            # 过滤太短的回答
            if len(output) < 20:
                continue
                
            # 过滤重复内容过多的回答
            words = output.split()
            if len(words) > 10 and len(set(words)) / len(words) < 0.3:
                continue
                
            # 过滤特定无意义模式
            skip_patterns = [
                r'^(I cannot|I don\'t know|I am an AI)',
                r'作为一个AI',
                r'我无法回答',
            ]
            skip = False
            for pattern in skip_patterns:
                if re.search(pattern, output, re.IGNORECASE):
                    skip = True
                    break
            
            if not skip:
                filtered.append(item)
                
        return filtered

class DataAugmenter:
    """数据增强类"""
    
    @staticmethod
    def paraphrase_instruction(instruction: str, variations: int = 3) -> List[str]:
        """
        指令改写 (示例,实际可使用LLM生成)
        """
        templates = [
            f"请{instruction}",
            f"帮我{instruction}",
            f"我需要你{instruction}",
            f"能否{instruction}",
            f"麻烦{instruction}",
        ]
        return random.sample(templates, min(variations, len(templates)))
    
    @staticmethod
    def add_context_variations(data: List[Dict]) -> List[Dict]:
        """添加上下文变体"""
        augmented = []
        for item in data:
            augmented.append(item)  # 保留原始
            
            # 添加系统提示变体
            system_prompts = [
                "你是一个专业的助手。",
                "请认真回答以下问题。",
                "作为专家,请给出你的建议。",
            ]
            
            for sys_prompt in random.sample(system_prompts, 1):
                new_item = item.copy()
                new_item["instruction"] = f"{sys_prompt}\n\n{item['instruction']}"
                augmented.append(new_item)
                
        return augmented

# ═══════════════════════════════════════════════════════════════════════════
# 数据处理Pipeline示例
# ═══════════════════════════════════════════════════════════════════════════

def process_data_pipeline(raw_data: List[Dict]) -> List[Dict]:
    """完整的数据处理流程"""
    
    cleaner = DataCleaner()
    augmenter = DataAugmenter()
    
    print(f"原始数据量: {len(raw_data)}")
    
    # 1. 基础清洗
    data = [{k: cleaner.clean_text(str(v)) for k, v in item.items()} 
            for item in raw_data]
    
    # 2. 长度过滤
    data = cleaner.filter_by_length(data, min_length=20, max_length=2000)
    print(f"长度过滤后: {len(data)}")
    
    # 3. 去重
    data = cleaner.remove_duplicates(data)
    print(f"去重后: {len(data)}")
    
    # 4. 质量过滤
    data = cleaner.filter_quality(data)
    print(f"质量过滤后: {len(data)}")
    
    # 5. 数据增强 (可选)
    # data = augmenter.add_context_variations(data)
    # print(f"增强后: {len(data)}")
    
    # 6. 打乱顺序
    random.shuffle(data)
    
    return data

9.4 数据配比策略

┌─────────────────────────────────────────────────────────────────────────┐
│                        数据配比最佳实践                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   场景1: 通用指令微调                                                   │
│   ────────────────────                                                  │
│   ┌───────────────────────────────────────────────────────────────┐    │
│   │  问答类    ████████████████████████  40%                      │    │
│   │  写作类    ████████████████          25%                      │    │
│   │  代码类    ██████████                15%                      │    │
│   │  推理类    ████████                  12%                      │    │
│   │  翻译类    █████                      8%                      │    │
│   └───────────────────────────────────────────────────────────────┘    │
│                                                                         │
│   场景2: 领域专业化 (如医疗)                                           │
│   ─────────────────────────────                                         │
│   ┌───────────────────────────────────────────────────────────────┐    │
│   │  领域数据   ████████████████████████████████  60%             │    │
│   │  通用数据   ████████████████████              30%             │    │
│   │  对话能力   ██████                            10%             │    │
│   └───────────────────────────────────────────────────────────────┘    │
│   注意: 保留部分通用数据防止能力退化                                   │
│                                                                         │
│   场景3: 多任务微调                                                     │
│   ────────────────                                                      │
│   ┌───────────────────────────────────────────────────────────────┐    │
│   │  主任务     ████████████████████████████  50%                 │    │
│   │  辅助任务1  ████████████                  20%                 │    │
│   │  辅助任务2  ████████                      15%                 │    │
│   │  通用能力   ████████                      15%                 │    │
│   └───────────────────────────────────────────────────────────────┘    │
│                                                                         │
│   数据规模建议:                                                         │
│   ────────────                                                          │
│   ┌─────────────────┬───────────────┬───────────────────────────────┐  │
│   │    任务类型     │   最少数据量  │          说明                  │  │
│   ├─────────────────┼───────────────┼───────────────────────────────┤  │
│   │  简单分类       │    500+       │  如情感分析                    │  │
│   │  信息抽取       │    1000+      │  如NER、关系抽取               │  │
│   │  指令跟随       │    2000+      │  通用指令理解                  │  │
│   │  领域问答       │    5000+      │  需要领域知识                  │  │
│   │  复杂推理       │    10000+     │  需要多步推理                  │  │
│   │  代码生成       │    20000+     │  需要语法正确性                │  │
│   └─────────────────┴───────────────┴───────────────────────────────┘  │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

十、评估与调优策略

10.1 评估指标体系

┌─────────────────────────────────────────────────────────────────────────┐
│                        微调评估指标体系                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   1. 自动评估指标                                                       │
│   ═══════════════                                                       │
│   ┌─────────────────────────────────────────────────────────────────┐  │
│   │                                                                 │  │
│   │   Loss相关:                                                     │  │
│   │   • Training Loss: 训练过程监控,应持续下降                     │  │
│   │   • Validation Loss: 验证集损失,用于早停判断                   │  │
│   │   • Perplexity: exp(loss),越低越好                            │  │
│   │                                                                 │  │
│   │   任务相关:                                                     │  │
│   │   • Accuracy: 分类任务准确率                                    │  │
│   │   • F1-Score: 不平衡数据集的综合指标                           │  │
│   │   • BLEU/ROUGE: 生成任务与参考文本的相似度                     │  │
│   │   • Exact Match: 精确匹配率                                     │  │
│   │                                                                 │  │
│   └─────────────────────────────────────────────────────────────────┘  │
│                                                                         │
│   2. LLM评估指标 (使用GPT-4等进行评估)                                 │
│   ═══════════════════════════════════════                               │
│   ┌─────────────────────────────────────────────────────────────────┐  │
│   │                                                                 │  │
│   │   MT-Bench: 多轮对话能力评估                                    │  │
│   │   AlpacaEval: 指令跟随能力评估                                  │  │
│   │   Arena-Hard: 复杂问题处理能力                                  │  │
│   │   自定义评估: 针对特定领域设计评估prompt                        │  │
│   │                                                                 │  │
│   └─────────────────────────────────────────────────────────────────┘  │
│                                                                         │
│   3. 人工评估 (最可靠但成本高)                                         │
│   ════════════════════════════                                          │
│   ┌─────────────────────────────────────────────────────────────────┐  │
│   │                                                                 │  │
│   │   评估维度:                                                     │  │
│   │   • 准确性 (1-5分): 回答是否正确                               │  │
│   │   • 完整性 (1-5分): 回答是否全面                               │  │
│   │   • 相关性 (1-5分): 回答是否切题                               │  │
│   │   • 流畅性 (1-5分): 表达是否自然                               │  │
│   │   • 有用性 (1-5分): 回答是否有帮助                             │  │
│   │                                                                 │  │
│   │   评估方法:                                                     │  │
│   │   • A/B测试: 对比微调前后                                      │  │
│   │   • 盲评: 隐藏模型来源进行评估                                 │  │
│   │   • 多人评估: 取平均或多数投票                                 │  │
│   │                                                                 │  │
│   └─────────────────────────────────────────────────────────────────┘  │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

10.2 超参数调优指南

┌─────────────────────────────────────────────────────────────────────────┐
│                        超参数调优速查表                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   问题诊断与解决:                                                       │
│   ═══════════════                                                       │
│                                                                         │
│   ┌─────────────────┬─────────────────┬─────────────────────────────┐  │
│   │     现象        │    可能原因     │         解决方案            │  │
│   ├─────────────────┼─────────────────┼─────────────────────────────┤  │
│   │                 │                 │                             │  │
│   │  Loss不下降     │  学习率太低     │  增大学习率 (2x-10x)        │  │
│   │                 │  数据问题       │  检查数据格式和标签         │  │
│   │                 │  梯度消失       │  检查梯度,使用bf16         │  │
│   │                 │                 │                             │  │
│   │  Loss震荡       │  学习率太高     │  降低学习率 (0.5x-0.1x)     │  │
│   │                 │  批次太小       │  增加batch_size或累积       │  │
│   │                 │                 │                             │  │
│   │  训练集Loss低   │  过拟合         │  增大dropout               │  │
│   │  验证集Loss高   │                 │  减小r值                   │  │
│   │                 │                 │  增加数据量                 │  │
│   │                 │                 │  早停                       │  │
│   │                 │                 │                             │  │
│   │  Loss快速降到   │  学习过快       │  减小学习率                 │  │
│   │  很低后不动     │  模型容量不足   │  增大r值                   │  │
│   │                 │                 │                             │  │
│   │  生成重复内容   │  温度太低       │  推理时增大temperature      │  │
│   │                 │  训练数据单一   │  增加数据多样性             │  │
│   │                 │                 │                             │  │
│   │  生成内容偏离   │  过拟合严重     │  减小训练epoch             │  │
│   │  原模型风格     │  数据分布差异大 │  混入通用数据               │  │
│   │                 │                 │                             │  │
│   │  显存不足       │  批次太大       │  减小batch_size            │  │
│   │  (OOM)          │  序列太长       │  减小max_length            │  │
│   │                 │                 │  启用gradient_checkpointing │  │
│   │                 │                 │  使用QLoRA                  │  │
│   │                 │                 │                             │  │
│   └─────────────────┴─────────────────┴─────────────────────────────┘  │
│                                                                         │
│   推荐调优顺序:                                                         │
│   ─────────────────                                                     │
│   1. 学习率 (最关键): 1e-5 → 1e-4 → 5e-4 → 1e-3                        │
│   2. 训练轮数: 1 → 3 → 5 (观察验证集loss)                              │
│   3. LoRA秩r: 8 → 16 → 32 (在loss饱和时尝试增大)                       │
│   4. 批次大小: 在显存允许范围内尽量大                                   │
│   5. Dropout: 如果过拟合再调整                                         │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

10.3 训练监控代码

"""
训练监控与可视化
"""

import matplotlib.pyplot as plt
from transformers import TrainerCallback
import numpy as np

class TrainingMonitorCallback(TrainerCallback):
    """训练监控回调"""
    
    def __init__(self):
        self.train_losses = []
        self.eval_losses = []
        self.learning_rates = []
        self.steps = []
        
    def on_log(self, args, state, control, logs=None, **kwargs):
        """记录训练指标"""
        if logs:
            if "loss" in logs:
                self.train_losses.append(logs["loss"])
                self.steps.append(state.global_step)
            if "eval_loss" in logs:
                self.eval_losses.append(logs["eval_loss"])
            if "learning_rate" in logs:
                self.learning_rates.append(logs["learning_rate"])
                
    def plot_metrics(self, save_path: str = None):
        """绘制训练曲线"""
        fig, axes = plt.subplots(1, 3, figsize=(15, 4))
        
        # Loss曲线
        axes[0].plot(self.steps, self.train_losses, label='Train Loss', alpha=0.7)
        if self.eval_losses:
            eval_steps = [s for i, s in enumerate(self.steps) 
                         if i < len(self.eval_losses)]
            axes[0].plot(eval_steps[:len(self.eval_losses)], 
                        self.eval_losses, label='Eval Loss', alpha=0.7)
        axes[0].set_xlabel('Steps')
        axes[0].set_ylabel('Loss')
        axes[0].set_title('Training & Validation Loss')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        
        # 学习率曲线
        if self.learning_rates:
            axes[1].plot(self.steps[:len(self.learning_rates)], 
                        self.learning_rates)
            axes[1].set_xlabel('Steps')
            axes[1].set_ylabel('Learning Rate')
            axes[1].set_title('Learning Rate Schedule')
            axes[1].grid(True, alpha=0.3)
        
        # Loss分布
        axes[2].hist(self.train_losses[-100:], bins=20, alpha=0.7)
        axes[2].set_xlabel('Loss')
        axes[2].set_ylabel('Frequency')
        axes[2].set_title('Recent Loss Distribution')
        axes[2].grid(True, alpha=0.3)
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=150)
        plt.show()
        
    def print_summary(self):
        """打印训练摘要"""
        print("\n" + "="*60)
        print("Training Summary")
        print("="*60)
        print(f"Total steps: {len(self.steps)}")
        print(f"Initial loss: {self.train_losses[0]:.4f}")
        print(f"Final loss: {self.train_losses[-1]:.4f}")
        print(f"Best loss: {min(self.train_losses):.4f}")
        print(f"Loss reduction: {(1 - self.train_losses[-1]/self.train_losses[0])*100:.1f}%")
        if self.eval_losses:
            print(f"Best eval loss: {min(self.eval_losses):.4f}")
        print("="*60)

# 使用示例:
# monitor = TrainingMonitorCallback()
# trainer = Trainer(..., callbacks=[monitor])
# trainer.train()
# monitor.plot_metrics("training_curves.png")
# monitor.print_summary()

十一、部署与推理优化

11.1 模型部署选项

┌─────────────────────────────────────────────────────────────────────────┐
│                        模型部署方案对比                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   ┌─────────────────┬───────────┬───────────┬───────────┬───────────┐  │
│   │     方案        │  吞吐量   │  延迟     │  显存     │  难度     │  │
│   ├─────────────────┼───────────┼───────────┼───────────┼───────────┤  │
│   │ HF Transformers │   低      │   高      │   高      │   简单    │  │
│   │ vLLM            │   高      │   低      │   中      │   中等    │  │
│   │ TensorRT-LLM    │   最高    │   最低    │   低      │   复杂    │  │
│   │ llama.cpp       │   中      │   中      │   最低    │   简单    │  │
│   │ Ollama          │   中      │   中      │   低      │   最简    │  │
│   └─────────────────┴───────────┴───────────┴───────────┴───────────┘  │
│                                                                         │
│   推荐选择:                                                             │
│   ─────────                                                             │
│   • 快速原型: HF Transformers / Ollama                                 │
│   • 生产部署: vLLM (推荐) / TensorRT-LLM                               │
│   • 边缘设备: llama.cpp / Ollama                                       │
│   • 多LoRA切换: vLLM (原生支持动态LoRA)                                │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

11.2 vLLM部署LoRA模型

"""
使用vLLM部署LoRA微调模型
vLLM原生支持LoRA热加载,适合多租户场景
"""

# 安装: pip install vllm

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# ═══════════════════════════════════════════════════════════════════════════
# 方式1: 合并后部署 (推荐用于单一任务)
# ═══════════════════════════════════════════════════════════════════════════

def deploy_merged_model():
    """部署合并后的模型"""
    
    llm = LLM(
        model="./merged_model",        # 合并后的模型路径
        tensor_parallel_size=1,         # GPU数量
        gpu_memory_utilization=0.9,     # GPU显存利用率
        max_model_len=4096,             # 最大序列长度
    )
    
    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.9,
        max_tokens=512,
    )
    
    prompts = ["请介绍一下机器学习", "写一首诗"]
    outputs = llm.generate(prompts, sampling_params)
    
    for output in outputs:
        print(f"Prompt: {output.prompt}")
        print(f"Output: {output.outputs[0].text}")
        print("-" * 40)

# ═══════════════════════════════════════════════════════════════════════════
# 方式2: 动态LoRA加载 (适合多任务/多租户)
# ═══════════════════════════════════════════════════════════════════════════

def deploy_with_dynamic_lora():
    """动态加载不同LoRA适配器"""
    
    # 启用LoRA支持
    llm = LLM(
        model="Qwen/Qwen2-1.5B-Instruct",  # 基础模型
        enable_lora=True,                    # 启用LoRA
        max_lora_rank=64,                    # 最大LoRA秩
        max_loras=4,                         # 同时加载的最大LoRA数
    )
    
    sampling_params = SamplingParams(temperature=0.7, max_tokens=256)
    
    # 定义不同的LoRA适配器
    lora_medical = LoRARequest("medical", 1, "./lora_medical")
    lora_legal = LoRARequest("legal", 2, "./lora_legal")
    lora_code = LoRARequest("code", 3, "./lora_code")
    
    # 使用医疗LoRA
    medical_output = llm.generate(
        ["患者出现头痛症状,可能的原因是什么?"],
        sampling_params,
        lora_request=lora_medical
    )
    
    # 使用法律LoRA
    legal_output = llm.generate(
        ["劳动合同解除的法定条件有哪些?"],
        sampling_params,
        lora_request=lora_legal
    )
    
    # 不使用LoRA (原始模型)
    base_output = llm.generate(
        ["介绍一下Python编程语言"],
        sampling_params
    )
    
    return medical_output, legal_output, base_output

# ═══════════════════════════════════════════════════════════════════════════
# 方式3: API服务部署
# ═══════════════════════════════════════════════════════════════════════════

"""
启动API服务:

python -m vllm.entrypoints.openai.api_server \
    --model ./merged_model \
    --port 8000 \
    --tensor-parallel-size 1

或启用LoRA:

python -m vllm.entrypoints.openai.api_server \
    --model Qwen/Qwen2-1.5B-Instruct \
    --enable-lora \
    --lora-modules medical=./lora_medical legal=./lora_legal \
    --port 8000

客户端调用:

curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "medical",
        "prompt": "头痛的可能原因",
        "max_tokens": 256
    }'
"""

11.3 量化部署

"""
模型量化部署 - 进一步减小模型体积和推理资源需求
"""

# ═══════════════════════════════════════════════════════════════════════════
# GPTQ量化 (推荐用于GPU推理)
# ═══════════════════════════════════════════════════════════════════════════

from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

def quantize_with_gptq():
    """使用GPTQ进行4bit量化"""
    
    model_id = "./merged_model"
    
    # GPTQ配置
    gptq_config = GPTQConfig(
        bits=4,                     # 量化位数
        dataset="c4",               # 校准数据集
        tokenizer=AutoTokenizer.from_pretrained(model_id),
    )
    
    # 加载并量化模型
    quantized_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=gptq_config,
        device_map="auto",
    )
    
    # 保存量化后的模型
    quantized_model.save_pretrained("./model_gptq_4bit")
    
    return quantized_model

# ═══════════════════════════════════════════════════════════════════════════
# AWQ量化 (更快的推理速度)
# ═══════════════════════════════════════════════════════════════════════════

"""
使用AutoAWQ进行量化:

pip install autoawq

from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_pretrained("./merged_model")
tokenizer = AutoTokenizer.from_pretrained("./merged_model")

# 量化配置
quant_config = {
    "zero_point": True,
    "q_group_size": 128,
    "w_bit": 4,
    "version": "GEMM"
}

# 执行量化
model.quantize(tokenizer, quant_config=quant_config)

# 保存
model.save_quantized("./model_awq_4bit")
"""

# ═══════════════════════════════════════════════════════════════════════════
# llama.cpp GGUF格式 (CPU/边缘设备)
# ═══════════════════════════════════════════════════════════════════════════

"""
转换为GGUF格式:

1. 首先合并LoRA权重
2. 使用llama.cpp转换脚本

git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp

# 转换为GGUF
python convert.py ./merged_model --outfile model.gguf

# 量化
./quantize model.gguf model_q4_k_m.gguf Q4_K_M

量化级别说明:
- Q4_0: 4bit, 速度最快, 质量一般
- Q4_K_M: 4bit, 平衡速度和质量 (推荐)
- Q5_K_M: 5bit, 质量更好, 稍慢
- Q8_0: 8bit, 质量接近原始, 体积大
"""

十二、常见问题与解决方案

12.1 训练问题排查

┌─────────────────────────────────────────────────────────────────────────┐
│                        常见训练问题排查                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   问题1: CUDA Out of Memory (OOM)                                       │
│   ════════════════════════════════                                      │
│   原因: 显存不足                                                        │
│   解决方案 (按优先级):                                                  │
│   1. 减小 batch_size                                                    │
│   2. 减小 max_seq_length                                                │
│   3. 启用 gradient_checkpointing=True                                   │
│   4. 使用 QLoRA (4bit量化)                                              │
│   5. 减小 LoRA r值                                                      │
│   6. 使用 DeepSpeed ZeRO                                                │
│                                                                         │
│   问题2: Loss为NaN或Inf                                                 │
│   ═════════════════════════                                             │
│   原因: 数值溢出                                                        │
│   解决方案:                                                             │
│   1. 降低学习率                                                         │
│   2. 使用 bf16 而不是 fp16                                              │
│   3. 增加 warmup_ratio                                                  │
│   4. 启用梯度裁剪 max_grad_norm=1.0                                     │
│   5. 检查数据中是否有异常值                                             │
│                                                                         │
│   问题3: 训练速度慢                                                     │
│   ═════════════════════                                                 │
│   解决方案:                                                             │
│   1. 启用 Flash Attention 2                                             │
│   2. 使用 bf16 混合精度                                                 │
│   3. 增加 dataloader_num_workers                                        │
│   4. 使用 packing (多样本打包)                                          │
│   5. 检查IO瓶颈 (数据加载)                                              │
│                                                                         │
│   问题4: 验证Loss持续上升 (过拟合)                                      │
│   ════════════════════════════════════                                  │
│   解决方案:                                                             │
│   1. 减少训练epochs                                                     │
│   2. 增加 lora_dropout                                                  │
│   3. 减小 r 值                                                          │
│   4. 增加训练数据量                                                     │
│   5. 使用早停 (load_best_model_at_end=True)                            │
│                                                                         │
│   问题5: 微调后模型能力退化                                             │
│   ════════════════════════════════                                      │
│   原因: 灾难性遗忘                                                      │
│   解决方案:                                                             │
│   1. 混入通用数据 (10-30%)                                              │
│   2. 使用较小的学习率                                                   │
│   3. 减少训练步数                                                       │
│   4. 使用正则化 (weight_decay)                                          │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

12.2 代码级别Debug技巧

"""
调试技巧与工具
"""

import torch
from transformers import Trainer

# ═══════════════════════════════════════════════════════════════════════════
# 1. 检查梯度
# ═══════════════════════════════════════════════════════════════════════════

def check_gradients(model):
    """检查模型梯度"""
    for name, param in model.named_parameters():
        if param.requires_grad and param.grad is not None:
            grad = param.grad
            print(f"{name}:")
            print(f"  grad mean: {grad.mean().item():.6f}")
            print(f"  grad std:  {grad.std().item():.6f}")
            print(f"  grad max:  {grad.max().item():.6f}")
            print(f"  grad min:  {grad.min().item():.6f}")
            
            # 检查异常值
            if torch.isnan(grad).any():
                print(f"  ⚠️ WARNING: NaN gradients detected!")
            if torch.isinf(grad).any():
                print(f"  ⚠️ WARNING: Inf gradients detected!")

# ═══════════════════════════════════════════════════════════════════════════
# 2. 检查LoRA参数
# ═══════════════════════════════════════════════════════════════════════════

def inspect_lora_params(model):
    """检查LoRA参数状态"""
    lora_params = []
    frozen_params = []
    
    for name, param in model.named_parameters():
        if param.requires_grad:
            lora_params.append((name, param.numel()))
        else:
            frozen_params.append((name, param.numel()))
    
    print("=" * 60)
    print("LoRA Parameters (trainable):")
    print("=" * 60)
    total_lora = 0
    for name, count in lora_params:
        print(f"  {name}: {count:,}")
        total_lora += count
    print(f"\nTotal LoRA params: {total_lora:,}")
    
    total_frozen = sum(count for _, count in frozen_params)
    print(f"Total frozen params: {total_frozen:,}")
    print(f"Trainable ratio: {total_lora/(total_lora+total_frozen)*100:.4f}%")

# ═══════════════════════════════════════════════════════════════════════════
# 3. 数据检查
# ═══════════════════════════════════════════════════════════════════════════

def verify_dataset(dataset, tokenizer, num_samples=5):
    """验证数据集格式"""
    print("=" * 60)
    print("Dataset Verification")
    print("=" * 60)
    
    for i in range(min(num_samples, len(dataset))):
        sample = dataset[i]
        print(f"\nSample {i}:")
        print(f"  input_ids length: {len(sample['input_ids'])}")
        print(f"  labels length: {len(sample['labels'])}")
        
        # 检查labels
        non_masked = sum(1 for l in sample['labels'] if l != -100)
        print(f"  non-masked labels: {non_masked}")
        
        # 解码查看
        decoded = tokenizer.decode(sample['input_ids'][:50])
        print(f"  decoded (first 50 tokens): {decoded[:100]}...")
        
        # 检查对齐
        if len(sample['input_ids']) != len(sample['labels']):
            print("  ⚠️ WARNING: input_ids and labels length mismatch!")

# ═══════════════════════════════════════════════════════════════════════════
# 4. 显存监控
# ═══════════════════════════════════════════════════════════════════════════

def print_gpu_memory():
    """打印GPU显存使用情况"""
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            reserved = torch.cuda.memory_reserved(i) / 1024**3
            total = torch.cuda.get_device_properties(i).total_memory / 1024**3
            print(f"GPU {i}:")
            print(f"  Allocated: {allocated:.2f} GB")
            print(f"  Reserved:  {reserved:.2f} GB")
            print(f"  Total:     {total:.2f} GB")
            print(f"  Free:      {total - reserved:.2f} GB")

# ═══════════════════════════════════════════════════════════════════════════
# 5. 自定义Trainer用于调试
# ═══════════════════════════════════════════════════════════════════════════

class DebugTrainer(Trainer):
    """调试用Trainer"""
    
    def training_step(self, model, inputs):
        """重写训练步骤以添加调试信息"""
        loss = super().training_step(model, inputs)
        
        # 每100步打印一次详细信息
        if self.state.global_step % 100 == 0:
            print(f"\n[Step {self.state.global_step}]")
            print(f"  Loss: {loss.item():.4f}")
            print_gpu_memory()
            
        return loss

十三、总结与展望

13.1 核心要点回顾

┌─────────────────────────────────────────────────────────────────────────┐
│                        大模型微调核心要点                                │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│   1. 方法选择                                                           │
│   ════════════                                                          │
│   • LoRA是当前最主流的PEFT方法,适用于绝大多数场景                     │
│   • QLoRA在资源受限时是最佳选择,几乎不损失性能                        │
│   • 全参微调在数据充足+追求极致性能时考虑                              │
│                                                                         │
│   2. 数据工程                                                           │
│   ════════════                                                          │
│   • 数据质量 > 数据数量,1000条高质量数据优于10000条低质量数据         │
│   • 注意数据格式一致性,prompt模板很重要                               │
│   • 混入通用数据防止能力退化                                           │
│                                                                         │
│   3. 训练技巧                                                           │
│   ════════════                                                          │
│   • 从小r值开始实验 (r=8),根据效果逐步增大                            │
│   • 学习率通常在1e-4到5e-4之间                                         │
│   • 使用验证集监控过拟合,必要时早停                                   │
│   • bf16混合精度+gradient_checkpointing节省显存                        │
│                                                                         │
│   4. 部署优化                                                           │
│   ════════════                                                          │
│   • 训练后可合并LoRA权重,消除推理开销                                 │
│   • vLLM是生产部署的推荐选择                                           │
│   • 量化可进一步减小模型体积                                           │
│                                                                         │
│   5. 持续迭代                                                           │
│   ════════════                                                          │
│   • 建立完善的评估体系                                                 │
│   • 收集线上反馈数据持续优化                                           │
│   • 定期rebase到更新的基座模型                                         │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐