vLLM-Ascend中LoRA核心算子逻辑与整体流程解析
LoRA的核心计算分为“降秩(A矩阵)+升秩(B矩阵)+缩放+叠加”四步,对应→的调用链;sgmv_*函数负责处理多序列的LoRA ID扩展,bgmv_*函数负责核心的矩阵乘法;分片版函数(*_slice)是显存优化手段,适配超大模型/超长序列的推理;变量名在中是笔误,实际对应A矩阵,需注意代码实现的细节;add_inputs参数控制“叠加”还是“赋值”,对应LoRA公式中WxBAxWx + BA
vLLM-Ascend中LoRA核心算子逻辑与整体流程解析
LoRA(Low-Rank Adaptation)是大模型高效微调的核心技术,其核心原理是在原有模型权重旁插入低秩矩阵(A和B),通过优化低秩矩阵实现模型适配,推理时将低秩矩阵的输出与原模型输出叠加。vLLM中针对LoRA的推理优化集中在lora/ops/torch_ops/lora_ops.py文件,以下结合LoRA数学原理和vLLM整体流程,详细解析每个函数的作用、变量含义与执行逻辑。
先明确LoRA核心计算公式
LoRA的核心是对模型的线性层(如Attention的Q/K/V投影层)做低秩分解:
假设原线性层为 y=Wxy = Wxy=Wx(W∈Rd×kW \in \mathbb{R}^{d \times k}W∈Rd×k),LoRA引入:
- A∈Rr×kA \in \mathbb{R}^{r \times k}A∈Rr×k(降秩矩阵,r为低秩,远小于d/k)
- B∈Rd×rB \in \mathbb{R}^{d \times r}B∈Rd×r(升秩矩阵)
最终输出:y=Wx+BAx×α/ry = Wx + BAx \times \alpha/ry=Wx+BAx×α/r(α\alphaα是缩放系数,r是秩,用于平衡LoRA贡献)
vLLM中sgmv(Scaled General Matrix-Vector)系列函数是对LoRA中BAxBAxBAx计算的工程实现,expand/shrink对应LoRA在不同推理阶段(预填充/解码)的计算逻辑,slice则是针对大张量的分片优化。
函数逐模块解析
1. sgmv_expand:预填充阶段LoRA升维计算(核心)
def sgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False,
):
exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)
bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, add_inputs)
变量含义
| 变量名 | 类型 | 核心含义 | 关联LoRA逻辑 |
|---|---|---|---|
inputs |
Tensor | 模型输入张量(对应公式中的x) | LoRA计算的输入,形状通常为[token_nums, hidden_dim](token总数×隐藏层维度) |
lora_b_weights |
Tensor | LoRA的B矩阵权重(升秩矩阵) | 对应公式中的B,形状为[num_loras, out_dim, rank](LoRA数量×输出维度×秩) |
output_tensor |
Tensor | 输出张量(最终要叠加到原模型输出) | 存储BAxBAxBAx的计算结果,后续与原模型输出叠加 |
b_seq_start_loc |
Tensor | 每个序列在batch中的起始位置 | 处理多序列并行推理时的序列边界,比如batch中有3个序列,起始位置可能是[0, 5, 8] |
seq_len_tensor |
Tensor | 每个序列的长度 | 比如3个序列的长度是[5, 3, 4] |
lora_indices_tensor |
Tensor | 每个序列对应的LoRA ID | 多LoRA推理时,标记每个序列用哪个LoRA(比如[0, 1, 0]表示序列1用LoRA0,序列2用LoRA1) |
batches |
int | 批处理中的序列数量 | 对应lora_indices_tensor的长度 |
max_seq_length |
int | batch中最长序列的长度 | 用于对齐张量形状,避免维度不匹配 |
token_nums |
int | batch中所有序列的总token数 | 等于seq_len_tensor求和,是inputs的第一维大小 |
add_inputs |
bool | 是否叠加到原输出(True=叠加,False=直接赋值) | 对应LoRA公式中Wx+BAxWx + BAxWx+BAx的“+”操作 |
执行逻辑
- 第一步:
exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)
「序列级LoRA ID」扩展为「token级LoRA ID」。比如lora_indices_tensor=[0,1],seq_len_tensor=[3,2],扩展后是[0,0,0,1,1],目的是让每个token都知道自己该用哪个LoRA(因为一个序列内的所有token共享同一个LoRA)。 - 第二步:调用
bgmv_expand,把扩展后的LoRA ID传入,执行核心的BAxBAxBAx计算(这里A矩阵的计算在sgmv_shrink中,先完成A的降秩,再到B的升秩)。
2. bgmv_expand:LoRA B矩阵的核心计算(升秩)
def bgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
):
selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(dim=1)
inputs = inputs.to(dtype=output_tensor.dtype)
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
limit = output_tensor.shape[0]
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
limit = 1
common_len = min(outputs.shape[1], output_tensor.shape[1])
if add_inputs:
output_tensor[:, :common_len] += outputs[:limit, :common_len]
else:
output_tensor[:, :common_len] = outputs[:limit, :common_len]
变量补充(新增/关键)
| 变量名 | 核心含义 |
|---|---|
selected_loras |
按token级LoRA ID筛选出的B矩阵,形状从[num_loras, out_dim, rank]变为[token_nums, out_dim, rank] |
outputs |
BAxBAxBAx的计算结果(这里inputs是AxAxAx,乘B后得到升维结果) |
limit |
输出张量的行数限制(处理单序列vs多序列的边界情况) |
common_len |
输出与目标张量的最小列数(避免维度越界) |
逐行执行逻辑
-
selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
根据每个token的LoRA ID,从所有LoRA的B矩阵中筛选出当前token需要的B矩阵,并转换为与输出张量相同的数据类型(保证计算精度一致)。比如lora_b_weights是[2, 1024, 64](2个LoRA,输出维度1024,秩64),lora_indices_tensor是[0,0,1],则selected_loras是[3, 1024, 64]。 -
if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1)
处理特殊情况:如果筛选后的B矩阵多了一个维度(比如形状是[token_nums, 1, out_dim, rank]),则挤压掉第1维,保证形状是[token_nums, out_dim, rank](兼容不同模型的LoRA权重格式)。 -
inputs = inputs.to(dtype=output_tensor.dtype)
将输入张量(AxAxAx)转换为与输出张量相同的类型,避免类型不匹配导致的计算错误。 -
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
LoRA核心计算步骤:用爱因斯坦求和实现Ax×BAx \times BAx×B(升秩)。- 输入形状:
inputs是[token_nums, rank](AxAxAx的结果,rank是低秩维度),selected_loras是[token_nums, out_dim, rank](B矩阵); - 求和规则:
bi(token×秩) ×boi(token×输出维度×秩) →bo(token×输出维度); - 对应公式:BAxBAxBAx中的B(Ax)B(Ax)B(Ax),完成从低秩(rank)到原模型输出维度(out_dim)的升维。
- 输入形状:
-
limit = output_tensor.shape[0]; if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: limit = 1
处理边界:如果输出只有1行(单token)但目标张量有多行,限制只取第1行,避免维度越界。 -
common_len = min(outputs.shape[1], output_tensor.shape[1])
取输出和目标张量的最小列数:比如outputs列数是1024,output_tensor列数是512,则只取前512列(兼容不同模型的输出维度裁剪)。 -
最后分支:
add_inputs=True:output_tensor[:, :common_len] += outputs[:limit, :common_len]→ 对应LoRA公式中的Wx+BAxWx + BAxWx+BAx,将LoRA计算结果叠加到原模型输出;add_inputs=False:直接赋值,用于初始化LoRA输出张量(比如首次计算时)。
3. sgmv_shrink:预填充阶段LoRA降维计算(A矩阵)
def sgmv_shrink(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
):
exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)
bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, scaling)
变量新增/核心
| 变量名 | 类型 | 核心含义 | 关联LoRA逻辑 |
|---|---|---|---|
lora_a_weights |
Tensor | LoRA的A矩阵权重(降秩矩阵) | 对应公式中的A,形状为[num_loras, rank, in_dim] |
scaling |
float | 缩放系数(对应LoRA公式中的α/r\alpha/rα/r) | 平衡LoRA贡献的超参数 |
执行逻辑
- 第一步:和
sgmv_expand一样,将序列级LoRA ID扩展为token级(exploded_indices),保证每个token匹配正确的A矩阵; - 第二步:调用
bgmv_shrink执行A矩阵的核心降维计算(AxAxAx)。
4. bgmv_shrink:LoRA A矩阵的核心计算(降秩)
def bgmv_shrink(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor, # 注:变量名笔误,实际是lora_a_weights
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
):
selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(dim=1)
inputs = inputs.to(dtype=output_tensor.dtype)
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
output_tensor[:, : outputs.shape[1]] = scaling * outputs[:]
关键说明
- 变量名
lora_b_weights是笔误,实际传入的是lora_a_weights(A矩阵),这是代码实现中的命名不规范,需注意; - 核心作用是计算Ax×α/rAx \times \alpha/rAx×α/r(降秩),对应LoRA公式的第一步:将原输入从高维(in_dim)降到低秩(rank)。
逐行执行逻辑
-
selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
筛选当前token需要的A矩阵(降秩矩阵),形状从[num_loras, rank, in_dim]变为[token_nums, rank, in_dim],并转换数据类型。 -
if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1)
挤压多余维度,保证A矩阵形状为[token_nums, rank, in_dim]。 -
inputs = inputs.to(dtype=output_tensor.dtype)
输入张量(原模型输入x,形状[token_nums, in_dim])转换为目标类型。 -
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
LoRA降维核心计算:实现AxAxAx。- 输入形状:
inputs是[token_nums, in_dim](原输入x),selected_loras是[token_nums, rank, in_dim](A矩阵); - 求和规则:
bi(token×输入维度) ×boi(token×秩×输入维度) →bo(token×秩); - 结果:将高维输入降到低秩维度(rank),得到AxAxAx。
- 输入形状:
-
output_tensor[:, : outputs.shape[1]] = scaling * outputs[:]
乘以缩放系数α/r\alpha/rα/r(scaling),并将结果写入输出张量。比如scaling=0.1(对应α=64,r=64\alpha=64, r=64α=64,r=64),则Ax×0.1Ax \times 0.1Ax×0.1,平衡LoRA的贡献强度。
5. sgmv_expand_slice:分片版LoRA升维计算
def sgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False,
):
exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)
bgmv_expand_slice(
inputs,
lora_b_weights,
output_tensor,
exploded_indices,
slice_offset,
slice_size,
add_inputs,
)
新增变量
| 变量名 | 核心含义 |
|---|---|
slice_offset |
int |
slice_size |
int |
执行逻辑
- 与
sgmv_expand逻辑一致,唯一区别是:为了处理超大张量(比如输出维度10240),将计算结果写入输出张量的指定分片区间(而非整个张量),避免一次性占用过多显存。
6. bgmv_expand_slice:分片版B矩阵计算
def bgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
):
selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
inputs = inputs.to(dtype=output_tensor.dtype)
if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(dim=1)
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
if add_inputs:
output_tensor[:, slice_offset : slice_offset + slice_size] += outputs[:]
else:
output_tensor[:, slice_offset : slice_offset + slice_size] = outputs[:]
核心差异
- 计算完BAxBAxBAx后,不是写入整个
output_tensor,而是写入slice_offset到slice_offset+slice_size的区间; - 比如输出维度是2048,slice_offset=1024,slice_size=512,则只写入1024~1536列;
- 适配vLLM的“分块预填充(Chunked prefill)”特性,将大序列拆分成小块计算,降低显存占用。
vLLM-Ascend中LoRA整体执行流程
结合vLLM的推理流程(预填充+解码),LoRA的完整执行路径如下:
阶段1:预填充(Prefill)——处理首个token序列
- 输入阶段:用户输入的prompt被tokenize后,形成
inputs张量(形状[token_nums, in_dim]); - 降维计算(A矩阵):
- 调用
sgmv_shrink→ 扩展LoRA ID到token级 → 调用bgmv_shrink; - 计算Ax×α/rAx \times \alpha/rAx×α/r,得到低秩张量([token_nums, rank]);
- 调用
- 升维计算(B矩阵):
- 调用
sgmv_expand(或分片版sgmv_expand_slice) → 扩展LoRA ID到token级 → 调用bgmv_expand; - 计算B(Ax×α/r)B(Ax \times \alpha/r)B(Ax×α/r),得到升维张量([token_nums, out_dim]);
- 调用
- 结果叠加:将LoRA计算结果(BAx×α/rBAx \times \alpha/rBAx×α/r)叠加到原模型线性层输出(WxWxWx),得到最终输出。
阶段2:解码(Decode)——处理后续生成的token
- 逻辑与预填充一致,但由于解码阶段每次只处理1个token(
token_nums=1),seq_len_tensor恒为[1],exploded_indices也仅含1个LoRA ID; - 分片版函数(
sgmv_expand_slice)基本不会被调用(无需分片),主要走sgmv_expand/sgmv_shrink。
多LoRA并行推理的特殊处理
vLLM支持同时加载多个LoRA适配器(如max_loras=4),核心逻辑在lora_indices_tensor:
- 每个序列(prompt)绑定一个LoRA ID;
- 通过
torch.repeat_interleave将序列级ID扩展为token级,保证每个token使用正确的LoRA矩阵; - 筛选
lora_a_weights/lora_b_weights时,按ID索引到对应LoRA的权重,实现多LoRA并行推理。
关键总结
- LoRA的核心计算分为“降秩(A矩阵)+升秩(B矩阵)+缩放+叠加”四步,对应
sgmv_shrink→sgmv_expand的调用链; sgmv_*函数负责处理多序列的LoRA ID扩展,bgmv_*函数负责核心的矩阵乘法;- 分片版函数(
*_slice)是显存优化手段,适配超大模型/超长序列的推理; - 变量名
lora_b_weights在bgmv_shrink中是笔误,实际对应A矩阵,需注意代码实现的细节; add_inputs参数控制“叠加”还是“赋值”,对应LoRA公式中Wx+BAxWx + BAxWx+BAx的核心逻辑。
通过这套算子设计,vLLM实现了LoRA推理的高效并行:既保证多LoRA适配器的灵活切换,又通过张量操作和分片优化降低显存占用,适配大模型高吞吐量推理的需求。
更多推荐



所有评论(0)