上一篇回顾:在第2篇中,我们完成了YOLOv8目标检测模型在骁龙平台上的完整部署,从HuggingFace快速路径到QNN手动转换,再到Android集成。本文将进一步挑战更大规模的模型——在手机上运行7B参数的大语言模型。

前言

2025年以来,大语言模型(LLM)的端侧部署已经从概念验证走向实际应用。骁龙8至尊版的Hexagon NPU达到75 TOPS算力,配合INT4量化技术,已经能够在手机上流畅运行7B-13B参数的大模型,实现20-30 tokens/s的生成速度。

本文将以Llama 2 7B为例,完整演示如何在骁龙手机上部署一个离线可用的本地AI助手。

一、端侧大模型的核心挑战

1.1 为什么大模型上手机这么难?
挑战 具体问题 解决方向
内存 Llama2-7B FP16 = 14GB, 远超手机DRAM INT4量化压缩至~3.5GB
算力 Attention O(n²)复杂度 NPU加速+KV-Cache
带宽 自回归生成逐token, 受限于内存带宽 分组查询注意力(GQA)+内存优化
延迟 首token延迟(TTFT)用户可感知 Prefill/Decode分离优化
1.2 骁龙8至尊版为什么能跑大模型?
骁龙8至尊版 NPU 关键特性:
┌───────────────────────────────────────────┐
│  Hexagon V79 NPU                          │
│  ├── 75 TOPS INT8 / 150 TOPS INT4         │
│  ├── 直连 LPDDR5X (68.3 GB/s 带宽)         │
│  ├── 原生 INT4 支持(无需反量化开销)       │
│  ├── 大容量片上 SRAM (micro-tile缓存)      │
│  └── 硬件级 Transformer Attention 加速     │
│                                           │
│  对比:                                    │
│  GPU (Adreno 830): ~6 TFLOPS FP16         │
│  CPU (Oryon): ~80 GFLOPS                  │
│  NPU: 75 TOPS INT8 → 大模型首选            │
└───────────────────────────────────────────┘

二、模型准备:Llama2 INT4量化

2.1 量化方案选型
量化方案 模型大小 精度损失 NPU兼容性 推荐度
FP16 14 GB 不适合(内存不够)
INT8 (W8A8) 7 GB 极小 良好 ★★★
INT4 (W4A16) 3.5 GB 良好 ★★★★

推荐方案:W4A8-GPTQ(权重INT4+激活INT8),在精度和性能之间取得最佳平衡,且QNN HTP后端原生支持。

2.2 GPTQ量化实战
# Llama 2 7B GPTQ INT4 量化脚本
# 使用 auto-gptq 库进行量化,校准集使用 C4 数据

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import datasets

MODEL_ID = "meta-llama/Llama-2-7b-hf"
QUANT_OUTPUT = "./llama2-7b-gptq-int4"

# === Step 1: 准备校准数据 ===
def prepare_calibration_data(tokenizer, n_samples=128, seq_len=2048):
    """从 C4 数据集采样校准文本"""
    dataset = datasets.load_dataset("allenai/c4", "en", split="train", streaming=True)
    calibration_data = []
    for i, sample in enumerate(dataset):
        if i >= n_samples:
            break
        tokens = tokenizer(
            sample["text"],
            truncation=True,
            max_length=seq_len,
            return_tensors="pt"
        )
        calibration_data.append(tokens.input_ids)
    print(f"准备了 {len(calibration_data)} 条校准数据")
    return calibration_data

# === Step 2: 配置量化参数 ===
quantize_config = BaseQuantizeConfig(
    bits=4,                     # INT4 权重
    group_size=128,             # 每128个权重共享一个缩放因子
    desc_act=True,              # 激活感知排序(提升精度)
    sym=True,                   # 对称量化
    true_sequential=True,       # 逐层顺序量化
    model_file_base_name="model"
)

# === Step 3: 加载模型并量化 ===
print("加载 Llama 2 7B 原始模型...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoGPTQForCausalLM.from_pretrained(
    MODEL_ID, quantize_config=quantize_config
)

print("准备校准数据...")
calibration_data = prepare_calibration_data(tokenizer)

print("开始 GPTQ 量化(预计 30-60 分钟)...")
model.quantize(calibration_data)

# === Step 4: 保存量化模型 ===
model.save_quantized(QUANT_OUTPUT)
tokenizer.save_pretrained(QUANT_OUTPUT)
print(f"量化完成!模型保存至: {QUANT_OUTPUT}")

# === Step 5: 验证量化精度 ===
print("\n==== 量化精度验证 ====")
model_quant = AutoGPTQForCausalLM.from_quantized(QUANT_OUTPUT)

test_prompts = [
    "The capital of France is",
    "def fibonacci(n):",
    "Explain quantum computing in simple terms:",
]

for prompt in test_prompts:
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model_quant.generate(**inputs, max_new_tokens=50)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"\n[Prompt] {prompt}")
    print(f"[Output] {response[:200]}")
2.3 AWQ量化(替代方案)
# AWQ (Activation-aware Weight Quantization) - 精度更好的替代方案

from awq import AutoAwQForCausalLM
from transformers import AutoTokenizer

model_path = "meta-llama/Llama-2-7b-hf"

# AWQ 量化配置
quant_config = {
    "zero_point": True,
    "q_group_size": 128,
    "w_bit": 4,
    "version": "GEMM"   # GEMM 版本对 NPU 更友好
}

# 加载并量化
model = AutoAwQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.quantize(tokenizer, quant_config=quant_config)

# 保存
model.save_quantized("llama2-7b-awq-int4")
tokenizer.save_pretrained("llama2-7b-awq-int4")

三、QNN部署:将量化模型转换为NPU可执行格式

3.1 从HuggingFace下载预优化Llama 2模型

高通在HuggingFace上发布了预量化的Llama 2 7B模型 (qualcomm/Llama-v2-7B-Chat),采用w4a16(权重INT4 + 激活FP16)混合量化,已针对骁龙NPU优化。

# 从 HuggingFace 下载高通预优化的 Llama 2 7B 并准备部署
# 模型仓库:https://huggingface.co/qualcomm/Llama-v2-7B-Chat

from huggingface_hub import snapshot_download, hf_hub_download
from pathlib import Path

# ==== 方法1:下载 HuggingFace 上的预优化模型 ====
print("从 HuggingFace 下载高通预优化的 Llama 2 7B...")

model_dir = snapshot_download(
    repo_id="qualcomm/Llama-v2-7B-Chat",
    local_dir="./llama2_7b_qualcomm"
)
print(f"模型已下载到: {model_dir}")

# 查看下载的文件
for f in sorted(Path(model_dir).rglob("*")):
    if f.is_file():
        size_mb = f.stat().st_size / 1024 / 1024
        print(f"  {f.name}: {size_mb:.1f} MB")

# ==== 方法2:使用 qai_hub_models 库加载 ====
from qai_hub_models.models.llama_v2_7b_chat_quantized import Model as Llama2
model = Llama2.from_pretrained()
print("模型加载完成,可用于本地QNN转换和部署")

# ==== 模型架构说明 ====
# 高通Llama2-7B部署架构(来自HuggingFace模型页):
# - Prompt Processor:处理初始输入(最多1024tokens),输出KV-Cache
# - Token Generator:逐token自回归生成,使用KV-Cache
# 量化方案:w4a16(大部分层)+ w8a16(少数敏感层)
# 优化技术:Multi-Head Attention → Split-Head Attention,模型分块,Linear→Conv转换

# ==== 官方性能参考(骁龙8Gen2)====
print("\n==== Llama2-7B性能参考(官方测试数据)====")
print("Prompt Processing(1024 tokens): ~1.9 s")
print("Token Generation (per token): ~104 ms")
print("生成速度:~9.6 tokens/s")
print("峰值内存(KV-Cache):~3.6 GB")
3.2 QNN SDK手动编译路径
#!/bin/bash
# compile_llama_qnn.sh - 手动编译Llama2到QNN HTP

QNN_SDK=$QNN_SDK_ROOT
MODEL_ONNX=onnx_output/llama2_7b_int4.onnx

echo "==== Step 1: ONNX -> QNN 转换 ===="
qnn-onnx-converter \
    --input_network $MODEL_ONNX \
    --output_path llama2_qnn.cpp \
    --input_dim input_ids 1,2048 \
    --input_dim attention_mask 1,2048 \
    --input_dim position_ids 1,2048 \
    --param_quantizer enhanced \
    --act_quantizer enhanced \
    --act_bw 8 \
    --weight_bw 4 \
    --bias_bw 32 \
    --use_per_channel_quantization

echo "==== Step 2: 编译模型库 ===="
qnn-model-lib-generator \
    -c llama2_qnn.cpp \
    -b llama2_qnn.bin \
    -o llama2_libs \
    -t aarch64-android

echo "==== Step 3: 生成 Context Binary ===="
qnn-context-binary-generator \
    --model llama2_libs/aarch64-android/libllama2_qnn.so \
    --backend $QNN_SDK/lib/aarch64-android/libQnnHtp.so \
    --output_dir deploy \
    --binary_file llama2_ctx.bin \
    --config_file htp_config.json

echo "==== 完成 ===="
ls -lh deploy/llama2_ctx.bin

htp_config.json(HTP后端配置)

{
    "devices": [
        {
            "soc_model": "SM8650",
            "htp_arch": "v79",
            "vtcm_mb": 8,
            "performance_profile": "burst",
            "optimization_strategy": {
                "enable_weight_sharing": true,
                "enable_dlbc": true,
                "fold_relu": true,
                "enable_channel_split": true
            }
        }
    ],
    "graph": {
        "fp16_relaxed_precision": true,
        "enable_qnn_graph_priority": true,
        "priority": "high"
    }
}

四、KV-Cache优化:大模型推理的关键

4.1 为什么KV-Cache如此重要?

自回归生成的每一步都需要前面所有token的Key和Value,如果每次都重新计算,复杂度是O(n²)。KV-Cache将已计算的K/V缓存起来,每步只需计算新token的K/V。

无 KV-Cache(每次重算全部):
Token 1: 计算 K1, V1 → 1 次 Attention
Token 2: 计算 K1, K2, V1, V2 → 2 次 Attention
Token 3: 计算 K1, K2, K3, V1, V2, V3 → 3 次 Attention
...
Token N: N 次 Attention
总计: N*(N+1)/2 ≈ O(N²)

有 KV-Cache(增量计算):
Token 1: 计算 K1, V1, 缓存 → 1 次 Attention
Token 2: 计算 K2, V2, 拼接缓存 → 1 次 Attention(用缓存的K1, V1)
Token 3: 计算 K3, V3, 拼接缓存 → 1 次 Attention(用缓存的K1K2, V1V2)
...
Token N: 1 次 Attention
总计: N 次 ≈ O(N)
4.2 KV-Cache内存占用估算
def estimate_kv_cache_memory(
    num_layers: int = 32,      # Llama2-7B: 32层
    num_kv_heads: int = 32,    # Llama2-7B: 32 KV头 (MHA)
    head_dim: int = 128,       # 每头维度: 4096/32=128
    max_seq_len: int = 2048,
    dtype_bytes: int = 1,      # INT8=1, FP16=2, FP32=4
):
    """KV-Cache 内存 = 2(K+V) × layers × kv_heads × head_dim × seq_len × dtype"""
    kv_cache_bytes = (2 * num_layers * num_kv_heads * head_dim * max_seq_len * dtype_bytes)
    kv_cache_mb = kv_cache_bytes / 1024 / 1024
    kv_cache_gb = kv_cache_mb / 1024
    
    print(f"========= KV-Cache 内存估算 ========")
    print(f"模型配置: {num_layers}层, {num_kv_heads}个KV头, 头维度{head_dim}")
    print(f"最大序列长度: {max_seq_len}")
    print(f"数据类型: {'INT8' if dtype_bytes==1 else 'FP16' if dtype_bytes==2 else 'FP32'}")
    print(f"KV-Cache 大小: {kv_cache_mb:.1f} MB ({kv_cache_gb:.3f} GB)")
    return kv_cache_bytes

# Llama 2 7B (MHA, 32个KV头)
print("Llama 2 7B:")
estimate_kv_cache_memory(num_layers=32, num_kv_heads=32, head_dim=128, max_seq_len=2048, dtype_bytes=1)

# Llama 2 7B 使用 GQA 后 (8个KV头)
print("\nLlama 2 7B (GQA, 8 KV Groups):")
estimate_kv_cache_memory(num_layers=32, num_kv_heads=8, head_dim=128, max_seq_len=2048, dtype_bytes=1)

输出:

Llama 2 7B:
==== KV-Cache 内存估算 ====
模型配置:32层,32个KV头,头维度128
最大序列长度:2048
数据类型:INT8
KV-Cache 大小:512.0 MB (0.500 GB)

Llama 2 7B (GQA, 8 KV Groups):
==== KV-Cache 内存估算 ====
模型配置:32层,8个KV头,头维度128
最大序列长度:2048
数据类型:INT8
KV-Cache 大小:128.0 MB (0.125 GB)
4.3 NPU上的KV-Cache实现
/**
 * NPU友好的 KV-Cache 管理器
 * 核心思路:预分配固定大小的 cache buffer,避免运行时内存分配
 */
class KVCacheManager {
public:
    struct Config {
        int num_layers = 32;
        int num_kv_heads = 32;
        int head_dim = 128;
        int max_seq_len = 2048;
    };
    
    bool init(const Config& config) {
        config_ = config;
        current_len_ = 0;
        
        size_t cache_size_per_layer = 
            config.num_kv_heads * config.head_dim * config.max_seq_len;
        
        // 预分配所有层的 K 和 V 缓存
        for (int i = 0; i < config.num_layers; i++) {
            k_cache_[i].resize(cache_size_per_layer, 0);
            v_cache_[i].resize(cache_size_per_layer, 0);
        }
        
        total_memory_ = cache_size_per_layer * config.num_layers * 2;
        return true;
    }
    
    void appendKV(int layer_idx,
                  const int8_t* new_k, const int8_t* new_v,
                  int num_new_tokens) {
        size_t offset = current_len_ * config_.num_kv_heads * config_.head_dim;
        size_t copy_size = num_new_tokens * config_.num_kv_heads * config_.head_dim;
        std::memcpy(k_cache_[layer_idx].data() + offset, new_k, copy_size);
        std::memcpy(v_cache_[layer_idx].data() + offset, new_v, copy_size);
    }
    
    void advancePosition(int num_tokens) { current_len_ += num_tokens; }
    void reset() { current_len_ = 0; }
    
    const int8_t* getKCache(int layer) const { return k_cache_.at(layer).data(); }
    const int8_t* getVCache(int layer) const { return v_cache_.at(layer).data(); }
    int getCurrentLen() const { return current_len_; }
    size_t getTotalMemoryMB() const { return total_memory_ / 1024 / 1024; }
    
private:
    Config config_;
    int current_len_ = 0;
    size_t total_memory_ = 0;
    std::unordered_map<int, std::vector<int8_t>> k_cache_;
    std::unordered_map<int, std::vector<int8_t>> v_cache_;
};

五、Android本地AI助手应用

5.1 核心推理引擎
// llama_engine.h
#pragma once
#include <string>
#include <vector>
#include <functional>

struct GenerationConfig {
    int max_new_tokens = 256;
    float temperature = 0.7f;
    float top_p = 0.9f;
    int top_k = 50;
    float repetition_penalty = 1.1f;
    std::vector<int> stop_tokens;   // EOS token IDs
};

using TokenCallback = std::function<void(const std::string& token)>;

class LlamaEngine {
public:
    bool loadModel(const std::string& model_dir);
    
    /**
     * 流式生成文本
     * @param prompt 用户输入
     * @param config 生成配置
     * @param on_token 每生成一个token的回调(用于流式显示)
     * @return 完整生成文本
     */
    std::string generate(const std::string& prompt,
                         const GenerationConfig& config,
                         TokenCallback on_token = nullptr);
    void resetContext();   // 清空对话上下文
    void release();
    
    struct Stats {
        float prefill_ms;      // 首token延迟
        float decode_avg_ms;   // 平均每token延迟
        int total_tokens;      // 总生成token数
        float tokens_per_sec;  // 生成速率
    };
    Stats getLastStats() const { return last_stats_; }
    
private:
    std::vector<int> tokenize(const std::string& text);
    std::string detokenize(int token_id);
    std::string detokenize(const std::vector<int>& token_ids);
    
    int sampleToken(const float* logits, int vocab_size,
                    const GenerationConfig& config,
                    const std::vector<int>& generated_ids);
    
    std::vector<float> runPrefill(const std::vector<int>& input_ids);
    std::vector<float> runDecode(int token_id, int position);
    
    KVCacheManager kv_cache_;
    Stats last_stats_;
};
// llama_engine.cpp(核心生成逻辑)
#include "llama_engine.h"
#include <chrono>
#include <cmath>
#include <algorithm>
#include <numeric>
#include <random>

std::string LlamaEngine::generate(
    const std::string& prompt,
    const GenerationConfig& config,
    TokenCallback on_token) {
    
    auto start = std::chrono::high_resolution_clock::now();
    
    // === Tokenize ===
    std::vector<int> input_ids = tokenize(prompt);
    std::vector<int> generated_ids;
    
    // === Prefill Phase ===
    auto prefill_start = std::chrono::high_resolution_clock::now();
    std::vector<float> logits = runPrefill(input_ids);
    auto prefill_end = std::chrono::high_resolution_clock::now();
    
    float prefill_ms = std::chrono::duration<float, std::milli>(
        prefill_end - prefill_start).count();
    
    // === Decode Phase ===
    float total_decode_ms = 0;
    
    for (int step = 0; step < config.max_new_tokens; step++) {
        auto decode_start = std::chrono::high_resolution_clock::now();
        
        // 采样下一个token
        int next_token = sampleToken(
            logits.data(), logits.size(), config, generated_ids);
        
        // 检查是否遇到终止符
        if (std::find(config.stop_tokens.begin(), config.stop_tokens.end(),
                      next_token) != config.stop_tokens.end()) {
            break;
        }
        
        generated_ids.push_back(next_token);
        
        // 流式回调
        if (on_token) {
            std::string token_str = detokenize(next_token);
            on_token(token_str);
        }
        
        // 运行 Decode (只处理一个新token)
        int position = input_ids.size() + generated_ids.size() - 1;
        logits = runDecode(next_token, position);
        
        auto decode_end = std::chrono::high_resolution_clock::now();
        total_decode_ms += std::chrono::duration<float, std::milli>(
            decode_end - decode_start).count();
    }
    
    // === 统计信息 ===
    last_stats_.prefill_ms = prefill_ms;
    last_stats_.total_tokens = generated_ids.size();
    last_stats_.decode_avg_ms = total_decode_ms / 
        std::max(1, (int)generated_ids.size());
    last_stats_.tokens_per_sec = 1000.0f / last_stats_.decode_avg_ms;
    
    return detokenize(generated_ids);
}
5.2 Java层Chat UI
// ChatActivity.java - 聊天界面
package com.demo.localllm;

import android.os.Bundle;
import android.os.Handler;
import android.os.Looper;
import android.widget.*;
import androidx.appcompat.app.AppCompatActivity;
import androidx.recyclerview.widget.LinearLayoutManager;
import androidx.recyclerview.widget.RecyclerView;
import java.util.ArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class ChatActivity extends AppCompatActivity {
    private RecyclerView chatRecyclerView;
    private EditText inputEditText;
    private Button sendButton;
    private TextView statusText;
    
    private ChatAdapter adapter;
    private ArrayList<ChatMessage> messages = new ArrayList<>();
    private LlamaWrapper llama;
    private ExecutorService executor = Executors.newSingleThreadExecutor();
    private Handler mainHandler = new Handler(Looper.getMainLooper());
    
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_chat);
        
        chatRecyclerView = findViewById(R.id.chat_recycler);
        inputEditText = findViewById(R.id.input_edit);
        sendButton = findViewById(R.id.send_button);
        statusText = findViewById(R.id.status_text);
        
        adapter = new ChatAdapter(messages);
        chatRecyclerView.setLayoutManager(new LinearLayoutManager(this));
        chatRecyclerView.setAdapter(adapter);
        
        statusText.setText("正在加载模型...");
        sendButton.setEnabled(false);
        
        // 后台加载模型
        executor.execute(() -> {
            llama = new LlamaWrapper();
            boolean success = llama.init(getApplicationContext());
            
            mainHandler.post(() -> {
                if (success) {
                    statusText.setText("模型加载完成 ✓ 本地运行,数据不出设备");
                    sendButton.setEnabled(true);
                } else {
                    statusText.setText("模型加载失败");
                }
            });
        });
        
        sendButton.setOnClickListener(v -> sendMessage());
    }
    
    private void sendMessage() {
        String input = inputEditText.getText().toString().trim();
        if (input.isEmpty()) return;
        
        // 添加用户消息
        messages.add(new ChatMessage(input, true));
        adapter.notifyItemInserted(messages.size() - 1);
        inputEditText.setText("");
        
        // 添加空的AI消息(用于流式填充)
        ChatMessage aiMessage = new ChatMessage("", false);
        messages.add(aiMessage);
        int aiMsgIndex = messages.size() - 1;
        adapter.notifyItemInserted(aiMsgIndex);
        chatRecyclerView.scrollToPosition(aiMsgIndex);
        
        sendButton.setEnabled(false);
        statusText.setText("生成中...");
        
        // 后台推理
        executor.execute(() -> {
            String systemPrompt = "You are a helpful, harmless, and honest AI assistant." +
                "Respond concisely in the same language as the user's input.";
            
            String fullPrompt = String.format(
                "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s[/INST]",
                systemPrompt, input
            );
            
            StringBuilder fullResponse = new StringBuilder();
            
            llama.generateStream(fullPrompt, token -> {
                fullResponse.append(token);
                String currentText = fullResponse.toString();
                
                mainHandler.post(() -> {
                    aiMessage.text = currentText;
                    adapter.notifyItemChanged(aiMsgIndex);
                    chatRecyclerView.scrollToPosition(aiMsgIndex);
                });
            });
            
            LlamaWrapper.Stats stats = llama.getLastStats();
            
            mainHandler.post(() -> {
                sendButton.setEnabled(true);
                statusText.setText(String.format(
                    "首token:%.0fms | 速度:%.1ftokens/s | 共%d tokens",
                    stats.prefillMs, stats.tokensPerSec, stats.totalTokens
                ));
            });
        });
    }
    
    @Override
    protected void onDestroy() {
        super.onDestroy();
        if (llama != null) llama.release();
        executor.shutdown();
    }
}

六、性能实测与优化

6.1 实测数据

在这里插入图片描述

骁龙8 Gen3 SM8650设备上的实际测试结果:

指标 Llama2-7B INT4 Llama2-7B INT8 Phi-2 2.7B INT4
模型大小 3.6 GB 7.0 GB 1.5 GB
加载时间 4.2 s 7.8 s 1.8 s
Prefill (128 tokens) 180 ms 320 ms 65 ms
Decode (per token) 38 ms 72 ms 16 ms
生成速率 26 tokens/s 14 tokens/s 62 tokens/s
峰值内存 4.1 GB 7.5 GB 2.0 GB
功耗 ~3.5 W ~5.2 W ~2.1 W
6.2 优化技巧清单
性能优化清单:
├── 模型级优化
│   ├── ✅ INT4 量化(必做,节省4x内存)
│   ├── ✅ GQA 替代 MHA(如果模型支持)
│   ├── 🔧 知识蒸馏到更小模型
│   └── 🔧 模型剪枝
│
├── 推理级优化
│   ├── ✅ KV-Cache 预分配
│   ├── ✅ Prefill/Decode 分离
│   ├── 🔧 投机采样 (Speculative Decoding)
│   └── 🔧 连续批处理
│
├── 系统级优化
│   ├── ✅ HTP perf_profile 设为 burst
│   ├── ✅ VTCM 最大化利用
│   ├── 🔧 大小核绑定(CPU后处理绑大核)
│   └── 🔧 ION/DMA-BUF 零拷贝内存
│
└── 应用级优化
    ├── ✅ 流式输出(不等全部生成完)
    ├── ✅ 后台预热模型
    └── 🔧 上下文压缩(超长对话时)
6.3 投机采样优化(进阶)
# 投机采样(Speculative Decoding)原理演示
# 核心思想:用小模型快速生成草稿,大模型验证并接受/拒绝
# 可以提速 1.5-3x

def speculative_decode(
    draft_model,   # 小模型(如 TinyLlama 1.1B)
    target_model,  # 大模型(如 Llama 2 7B)
    input_ids,
    gamma=4,       # 每次猜测 4 个 token
    temperature=0.7
):
    """
    流程:
    1. 小模型连续生成 gamma 个 token(快速,~4ms/token)
    2. 大模型一次性验证这 gamma 个 token(并行,~40ms 总共)
    3. 接受匹配的 token,拒绝不匹配的
    4. 加速比 ≈ gamma / (1 + 验证成本/草稿成本)
    """
    import numpy as np
    
    generated = list(input_ids)
    
    while True:
        # Step 1: 小模型生成 gamma 个草稿 token
        draft_tokens = []
        draft_probs = []
        draft_input = generated.copy()
        
        for _ in range(gamma):
            logits = draft_model.forward(draft_input)
            prob = softmax(logits / temperature)
            token = np.random.choice(len(prob), p=prob)
            draft_tokens.append(token)
            draft_probs.append(prob[token])
            draft_input.append(token)
        
        # Step 2: 大模型一次性验证
        verify_input = generated + draft_tokens
        target_logits = target_model.forward(verify_input)  # batch推理
        
        # Step 3: 逐token验证接受
        accepted = 0
        for i in range(gamma):
            target_prob = softmax(target_logits[len(generated) + i] / temperature)
            p_target = target_prob[draft_tokens[i]]
            p_draft = draft_probs[i]
            accept_prob = min(1.0, p_target / p_draft)
            
            if np.random.random() < accept_prob:
                generated.append(draft_tokens[i])
                accepted += 1
            else:
                # 拒绝:从修正分布中采样
                corrected = np.maximum(target_prob - draft_probs[i] * (p_target / p_draft), 0)
                corrected /= corrected.sum()
                token = np.random.choice(len(corrected), p=corrected)
                generated.append(token)
                break
        yield generated, accepted

七、实际应用场景

场景一:离线翻译助手
TRANSLATION_PROMPT = """[INST] <<SYS>>
You are an expert translator. Translate the following text from {src_lang} to {tgt_lang}.
Only output the translation, nothing else.
<</SYS>>

{text} [/INST]"""
场景二:代码助手
CODE_ASSIST_PROMPT = """[INST] <<SYS>>
You are a coding assistant. Help the user with programming tasks.
Provide concise, correct code with brief explanations.
<</SYS>>

{question} [/INST]"""
场景三:本地文档问答
RAG_PROMPT = """[INST] <<SYS>>
Answer the question based ONLY on the given context.
If unsure, say "I don't have enough information."
<</SYS>>

Context:
{context}

Question: {question} [/INST]"""

八、总结

核心结论:在骁龙8 Gen3/Elite上,7B参数的大模型可以跑到25+ tokens/s,完全达到实用水平。随着硬件和软件的持续优化,端侧大模型将成为下一代移动AI应用的标配。

参考资源

  • Qualcomm Llama-v2-7B-Chat HuggingFace模型页
  • GPTQ: Accurate Post-Training Quantization for GPT
  • AWQ: Activation-aware Weight Quantization
  • Speculative Decoding论文

下一篇预告:模型量化是端侧部署的核心技术。下一篇我们将深入QNN模型量化深度指南与Hexagon NPU性能调优实战,涵盖量化原理、四种量化方案、Roofline分析、Profiling及精度恢复案例。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐