vLLM-Ascend中LoRA核心算子逻辑与整体流程解析

LoRA(Low-Rank Adaptation)是大模型高效微调的核心技术,其核心原理是在原有模型权重旁插入低秩矩阵(A和B),通过优化低秩矩阵实现模型适配,推理时将低秩矩阵的输出与原模型输出叠加。vLLM中针对LoRA的推理优化集中在lora/ops/torch_ops/lora_ops.py文件,以下结合LoRA数学原理和vLLM整体流程,详细解析每个函数的作用、变量含义与执行逻辑。

先明确LoRA核心计算公式

LoRA的核心是对模型的线性层(如Attention的Q/K/V投影层)做低秩分解:
假设原线性层为 y=Wxy = Wxy=WxW∈Rd×kW \in \mathbb{R}^{d \times k}WRd×k),LoRA引入:

  • A∈Rr×kA \in \mathbb{R}^{r \times k}ARr×k(降秩矩阵,r为低秩,远小于d/k)
  • B∈Rd×rB \in \mathbb{R}^{d \times r}BRd×r(升秩矩阵)
    最终输出:y=Wx+BAx×α/ry = Wx + BAx \times \alpha/ry=Wx+BAx×α/rα\alphaα是缩放系数,r是秩,用于平衡LoRA贡献)

vLLM中sgmv(Scaled General Matrix-Vector)系列函数是对LoRA中BAxBAxBAx计算的工程实现,expand/shrink对应LoRA在不同推理阶段(预填充/解码)的计算逻辑,slice则是针对大张量的分片优化。

函数逐模块解析

1. sgmv_expand:预填充阶段LoRA升维计算(核心)

def sgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    add_inputs: bool = False,
):
    exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)

    bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, add_inputs)
变量含义
变量名 类型 核心含义 关联LoRA逻辑
inputs Tensor 模型输入张量(对应公式中的x) LoRA计算的输入,形状通常为[token_nums, hidden_dim](token总数×隐藏层维度)
lora_b_weights Tensor LoRA的B矩阵权重(升秩矩阵) 对应公式中的B,形状为[num_loras, out_dim, rank](LoRA数量×输出维度×秩)
output_tensor Tensor 输出张量(最终要叠加到原模型输出) 存储BAxBAxBAx的计算结果,后续与原模型输出叠加
b_seq_start_loc Tensor 每个序列在batch中的起始位置 处理多序列并行推理时的序列边界,比如batch中有3个序列,起始位置可能是[0, 5, 8]
seq_len_tensor Tensor 每个序列的长度 比如3个序列的长度是[5, 3, 4]
lora_indices_tensor Tensor 每个序列对应的LoRA ID 多LoRA推理时,标记每个序列用哪个LoRA(比如[0, 1, 0]表示序列1用LoRA0,序列2用LoRA1)
batches int 批处理中的序列数量 对应lora_indices_tensor的长度
max_seq_length int batch中最长序列的长度 用于对齐张量形状,避免维度不匹配
token_nums int batch中所有序列的总token数 等于seq_len_tensor求和,是inputs的第一维大小
add_inputs bool 是否叠加到原输出(True=叠加,False=直接赋值) 对应LoRA公式中Wx+BAxWx + BAxWx+BAx的“+”操作
执行逻辑
  • 第一步:exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)
    「序列级LoRA ID」扩展为「token级LoRA ID」。比如lora_indices_tensor=[0,1]seq_len_tensor=[3,2],扩展后是[0,0,0,1,1],目的是让每个token都知道自己该用哪个LoRA(因为一个序列内的所有token共享同一个LoRA)。
  • 第二步:调用bgmv_expand,把扩展后的LoRA ID传入,执行核心的BAxBAxBAx计算(这里A矩阵的计算在sgmv_shrink中,先完成A的降秩,再到B的升秩)。

2. bgmv_expand:LoRA B矩阵的核心计算(升秩)

def bgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
):
    selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
    if len(selected_loras.shape) == 4:
        selected_loras = selected_loras.squeeze(dim=1)
    inputs = inputs.to(dtype=output_tensor.dtype)
    outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)

    limit = output_tensor.shape[0]
    if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
        limit = 1

    common_len = min(outputs.shape[1], output_tensor.shape[1])

    if add_inputs:
        output_tensor[:, :common_len] += outputs[:limit, :common_len]
    else:
        output_tensor[:, :common_len] = outputs[:limit, :common_len]
变量补充(新增/关键)
变量名 核心含义
selected_loras 按token级LoRA ID筛选出的B矩阵,形状从[num_loras, out_dim, rank]变为[token_nums, out_dim, rank]
outputs BAxBAxBAx的计算结果(这里inputs是AxAxAx,乘B后得到升维结果)
limit 输出张量的行数限制(处理单序列vs多序列的边界情况)
common_len 输出与目标张量的最小列数(避免维度越界)
逐行执行逻辑
  1. selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
    根据每个token的LoRA ID,从所有LoRA的B矩阵中筛选出当前token需要的B矩阵,并转换为与输出张量相同的数据类型(保证计算精度一致)。比如lora_b_weights是[2, 1024, 64](2个LoRA,输出维度1024,秩64),lora_indices_tensor是[0,0,1],则selected_loras是[3, 1024, 64]。

  2. if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1)
    处理特殊情况:如果筛选后的B矩阵多了一个维度(比如形状是[token_nums, 1, out_dim, rank]),则挤压掉第1维,保证形状是[token_nums, out_dim, rank](兼容不同模型的LoRA权重格式)。

  3. inputs = inputs.to(dtype=output_tensor.dtype)
    将输入张量(AxAxAx)转换为与输出张量相同的类型,避免类型不匹配导致的计算错误。

  4. outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
    LoRA核心计算步骤:用爱因斯坦求和实现Ax×BAx \times BAx×B(升秩)。

    • 输入形状:inputs是[token_nums, rank](AxAxAx的结果,rank是低秩维度),selected_loras是[token_nums, out_dim, rank](B矩阵);
    • 求和规则:bi(token×秩) × boi(token×输出维度×秩) → bo(token×输出维度);
    • 对应公式:BAxBAxBAx中的B(Ax)B(Ax)B(Ax),完成从低秩(rank)到原模型输出维度(out_dim)的升维。
  5. limit = output_tensor.shape[0]; if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: limit = 1
    处理边界:如果输出只有1行(单token)但目标张量有多行,限制只取第1行,避免维度越界。

  6. common_len = min(outputs.shape[1], output_tensor.shape[1])
    取输出和目标张量的最小列数:比如outputs列数是1024,output_tensor列数是512,则只取前512列(兼容不同模型的输出维度裁剪)。

  7. 最后分支:

    • add_inputs=Trueoutput_tensor[:, :common_len] += outputs[:limit, :common_len] → 对应LoRA公式中的Wx+BAxWx + BAxWx+BAx,将LoRA计算结果叠加到原模型输出;
    • add_inputs=False:直接赋值,用于初始化LoRA输出张量(比如首次计算时)。

3. sgmv_shrink:预填充阶段LoRA降维计算(A矩阵)

def sgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
):
    exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)

    bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, scaling)
变量新增/核心
变量名 类型 核心含义 关联LoRA逻辑
lora_a_weights Tensor LoRA的A矩阵权重(降秩矩阵) 对应公式中的A,形状为[num_loras, rank, in_dim]
scaling float 缩放系数(对应LoRA公式中的α/r\alpha/rα/r 平衡LoRA贡献的超参数
执行逻辑
  • 第一步:和sgmv_expand一样,将序列级LoRA ID扩展为token级(exploded_indices),保证每个token匹配正确的A矩阵;
  • 第二步:调用bgmv_shrink执行A矩阵的核心降维计算(AxAxAx)。

4. bgmv_shrink:LoRA A矩阵的核心计算(降秩)

def bgmv_shrink(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,  # 注:变量名笔误,实际是lora_a_weights
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
):
    selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
    if len(selected_loras.shape) == 4:
        selected_loras = selected_loras.squeeze(dim=1)
    inputs = inputs.to(dtype=output_tensor.dtype)
    outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)

    output_tensor[:, : outputs.shape[1]] = scaling * outputs[:]
关键说明
  • 变量名lora_b_weights是笔误,实际传入的是lora_a_weights(A矩阵),这是代码实现中的命名不规范,需注意;
  • 核心作用是计算Ax×α/rAx \times \alpha/rAx×α/r(降秩),对应LoRA公式的第一步:将原输入从高维(in_dim)降到低秩(rank)。
逐行执行逻辑
  1. selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
    筛选当前token需要的A矩阵(降秩矩阵),形状从[num_loras, rank, in_dim]变为[token_nums, rank, in_dim],并转换数据类型。

  2. if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1)
    挤压多余维度,保证A矩阵形状为[token_nums, rank, in_dim]

  3. inputs = inputs.to(dtype=output_tensor.dtype)
    输入张量(原模型输入x,形状[token_nums, in_dim])转换为目标类型。

  4. outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
    LoRA降维核心计算:实现AxAxAx

    • 输入形状:inputs是[token_nums, in_dim](原输入x),selected_loras是[token_nums, rank, in_dim](A矩阵);
    • 求和规则:bi(token×输入维度) × boi(token×秩×输入维度) → bo(token×秩);
    • 结果:将高维输入降到低秩维度(rank),得到AxAxAx
  5. output_tensor[:, : outputs.shape[1]] = scaling * outputs[:]
    乘以缩放系数α/r\alpha/rα/rscaling),并将结果写入输出张量。比如scaling=0.1(对应α=64,r=64\alpha=64, r=64α=64,r=64),则Ax×0.1Ax \times 0.1Ax×0.1,平衡LoRA的贡献强度。

5. sgmv_expand_slice:分片版LoRA升维计算

def sgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = False,
):
    exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)

    bgmv_expand_slice(
        inputs,
        lora_b_weights,
        output_tensor,
        exploded_indices,
        slice_offset,
        slice_size,
        add_inputs,
    )
新增变量
变量名 核心含义
slice_offset int
slice_size int
执行逻辑
  • sgmv_expand逻辑一致,唯一区别是:为了处理超大张量(比如输出维度10240),将计算结果写入输出张量的指定分片区间(而非整个张量),避免一次性占用过多显存。

6. bgmv_expand_slice:分片版B矩阵计算

def bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
):
    selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
    inputs = inputs.to(dtype=output_tensor.dtype)
    if len(selected_loras.shape) == 4:
        selected_loras = selected_loras.squeeze(dim=1)
    outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)

    if add_inputs:
        output_tensor[:, slice_offset : slice_offset + slice_size] += outputs[:]
    else:
        output_tensor[:, slice_offset : slice_offset + slice_size] = outputs[:]
核心差异
  • 计算完BAxBAxBAx后,不是写入整个output_tensor,而是写入slice_offsetslice_offset+slice_size的区间;
  • 比如输出维度是2048,slice_offset=1024,slice_size=512,则只写入1024~1536列;
  • 适配vLLM的“分块预填充(Chunked prefill)”特性,将大序列拆分成小块计算,降低显存占用。

vLLM-Ascend中LoRA整体执行流程

结合vLLM的推理流程(预填充+解码),LoRA的完整执行路径如下:

阶段1:预填充(Prefill)——处理首个token序列

  1. 输入阶段:用户输入的prompt被tokenize后,形成inputs张量(形状[token_nums, in_dim]);
  2. 降维计算(A矩阵):
    • 调用sgmv_shrink → 扩展LoRA ID到token级 → 调用bgmv_shrink
    • 计算Ax×α/rAx \times \alpha/rAx×α/r,得到低秩张量([token_nums, rank]);
  3. 升维计算(B矩阵):
    • 调用sgmv_expand(或分片版sgmv_expand_slice) → 扩展LoRA ID到token级 → 调用bgmv_expand
    • 计算B(Ax×α/r)B(Ax \times \alpha/r)B(Ax×α/r),得到升维张量([token_nums, out_dim]);
  4. 结果叠加:将LoRA计算结果(BAx×α/rBAx \times \alpha/rBAx×α/r)叠加到原模型线性层输出(WxWxWx),得到最终输出。

阶段2:解码(Decode)——处理后续生成的token

  • 逻辑与预填充一致,但由于解码阶段每次只处理1个token(token_nums=1),seq_len_tensor恒为[1],exploded_indices也仅含1个LoRA ID;
  • 分片版函数(sgmv_expand_slice)基本不会被调用(无需分片),主要走sgmv_expand/sgmv_shrink

多LoRA并行推理的特殊处理

vLLM支持同时加载多个LoRA适配器(如max_loras=4),核心逻辑在lora_indices_tensor

  • 每个序列(prompt)绑定一个LoRA ID;
  • 通过torch.repeat_interleave将序列级ID扩展为token级,保证每个token使用正确的LoRA矩阵;
  • 筛选lora_a_weights/lora_b_weights时,按ID索引到对应LoRA的权重,实现多LoRA并行推理。

关键总结

  1. LoRA的核心计算分为“降秩(A矩阵)+升秩(B矩阵)+缩放+叠加”四步,对应sgmv_shrinksgmv_expand的调用链;
  2. sgmv_*函数负责处理多序列的LoRA ID扩展,bgmv_*函数负责核心的矩阵乘法;
  3. 分片版函数(*_slice)是显存优化手段,适配超大模型/超长序列的推理;
  4. 变量名lora_b_weightsbgmv_shrink中是笔误,实际对应A矩阵,需注意代码实现的细节;
  5. add_inputs参数控制“叠加”还是“赋值”,对应LoRA公式中Wx+BAxWx + BAxWx+BAx的核心逻辑。

通过这套算子设计,vLLM实现了LoRA推理的高效并行:既保证多LoRA适配器的灵活切换,又通过张量操作和分片优化降低显存占用,适配大模型高吞吐量推理的需求。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐