AI-大语言模型LLM-模型微调6-LoRA
LoRA(Low-Rank Adaptation)是一种通过低秩矩阵分解来高效微调大模型的参数高效微调方法。其核心思想是:工程经验证明:大模型在适应新任务时,其参数变化具有低秩特性。因此可以用极小的参数代价实现高质量的模型适配。秩是矩阵中线性无关的行或列的最大数量,反映了矩阵包含的独立信息维度。例如,一个秩为2的矩阵,其所有行或列都可以由2个独立的基底组合而成,就像一个三维空间的影子实际上可能只由
目的
为避免一学就会、一用就废,这里做下笔记
说明
- 本文内容紧承前文-模型微调1-基础理论,欲渐进,请循序
- 前面学完了4种不同的微调方法,这里选择第5种微调方法LoRA进行学习和实战

LoRA介绍

核心理念:低秩分解的智能适配
LoRA(Low-Rank Adaptation)是一种通过低秩矩阵分解来高效微调大模型的参数高效微调方法。其核心思想是:工程经验证明:大模型在适应新任务时,其参数变化具有低秩特性。因此可以用极小的参数代价实现高质量的模型适配。
秩的概念:
秩是矩阵中线性无关的行或列的最大数量,反映了矩阵包含的独立信息维度。例如,一个秩为2的矩阵,其所有行或列都可以由2个独立的基底组合而成,就像一个三维空间的影子实际上可能只由两个方向决定。秩越小,矩阵包含的冗余信息越多。
核心原理:低秩更新矩阵
基本公式
对于预训练权重矩阵 W ∈ R d × k W ∈ ℝ^{d×k} W∈Rd×k,LoRA保持其冻结,只训练两个小的低秩矩阵:
前向计算: h = W + Δ W = W + B ⋅ A h = W + ΔW = W + B·A h=W+ΔW=W+B⋅A
其中:
- B ∈ R d × r , A ∈ R r × k B ∈ ℝ^{d×r},A ∈ ℝ^{r×k} B∈Rd×r,A∈Rr×k
- r 是远小于 d 和 k 的秩(通常为4-64)
- 训练时更新 A 和 B,推理时合并 ΔW = B·A 到 W 中
参数效率对比
| 方法 | 训练参数量 | 存储开销 | 推理延迟 |
|---|---|---|---|
| 全量微调 | 100% | 每个任务一个完整模型 | 无增加 |
| LoRA | 0.1%-1% | 仅存低秩矩阵 | 几乎无增加 |
| Prefix Tuning | 0.5%-5% | 存储前缀向量 | 略有增加 |
关键设计优势
1. 无推理延迟
- 训练完成后可将 BA 合并到 W 中
- 推理时与原始模型结构完全一致
- 不增加任何计算复杂度
2. 模块化与可组合性
多个LoRA适配器可以线性组合:
W f i n a l = W + α 1 B 1 A 1 + α 2 B 2 A 2 + . . . W_{final} = W + α₁B₁A₁ + α₂B₂A₂ + ... Wfinal=W+α1B1A1+α2B2A2+...
这使得:
- 多技能融合:组合不同任务的适配器
- 权重调节:通过α系数控制适配强度
- 快速切换:无需重新加载模型
3. 广泛的适用性
| 模型组件 | 适用性 | 典型秩( r ) |
|---|---|---|
| 自注意力投影矩阵 | 高度有效 | 4-16 |
| FFN层矩阵 | 有效 | 8-32 |
| 所有线性层 | 通常有效 | 4-64 |
应用配置策略
参数配置指南
| 模型规模 | 推荐秩( r ) | 适配模块选择 | 参数量占比 |
|---|---|---|---|
| <1B参数 | 8-32 | Q,V投影矩阵 | 0.2%-0.5% |
| 1B-10B | 4-16 | Q,K,V,输出投影 | 0.1%-0.3% |
| >10B | 4-8 | 仅Q,V投影 | 0.05%-0.1% |
初始化技巧
- 矩阵A:使用随机高斯初始化
- 矩阵B:初始化为零矩阵
- 保证训练开始时 ΔW = BA = 0,不影响原始模型性能
性能表现
任务效果对比
| 任务类型 | LoRA效果/全微调 | 训练速度提升 | 显存节省 |
|---|---|---|---|
| 文本分类 | 98%-100% | 1.5-2倍 | 60%-75% |
| 指令跟随 | 95%-99% | 2-3倍 | 70%-80% |
| 代码生成 | 90%-95% | 1.8-2.5倍 | 65%-75% |
独特优势场景
- 多任务适配:同时维护数十个任务的适配器
- 资源受限部署:无法存储多个完整大模型的场景
- 快速实验迭代:大幅缩短训练周期
实践建议
调优要点
- 秩的选择:从较小值开始(如4),根据效果逐步增加
- α参数:控制适配强度,通常设为2r效果较好
- 模块选择:优先适配注意力层的Q、V投影矩阵
注意事项
- 过大秩可能导致过拟合
- 不同任务可能需要不同的最佳秩配置
- 极低秩(如1-2)在某些简单任务上也可能有效
生态影响
社区应用
LoRA已成为大模型微调的事实标准,因为其:
- 易于实现:改动最小,兼容性好
- 资源友好:普通消费级GPU即可训练
- 生态丰富:主流训练框架均内置支持
创新衍生
基于LoRA的改进不断涌现:
- QLoRA:进一步量化压缩
- AdaLoRA:动态调整秩分配
- LoRA+:差异化学习率策略
总结
LoRA通过低秩矩阵分解这一简洁而深刻的洞察,解决了大模型微调的核心矛盾:既要保持原模型能力,又要高效适配新任务。其“训练小矩阵,合并大模型”的设计哲学,在效果、效率、灵活性之间达到了出色平衡。
实战代码(Jupyter)
Step1 导包
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer
import datasets
datasets.__version__
'4.5.0'
import transformers
transformers.__version__
'4.56.2'
import warnings
warnings.filterwarnings('ignore')
Step2 加载数据集
ds = load_dataset("json", data_dir="./alpaca_data_zh/")
ds = ds['train']
ds
Dataset({
features: ['instruction', 'input', 'output'],
num_rows: 48818
})
ds[:3]
{'instruction': ['保持健康的三个提示。', '三原色是什么?', '描述原子的结构。'],
'input': ['', '', ''],
'output': ['以下是保持健康的三个提示:\n\n1. 保持身体活动。每天做适当的身体运动,如散步、跑步或游泳,能促进心血管健康,增强肌肉力量,并有助于减少体重。\n\n2. 均衡饮食。每天食用新鲜的蔬菜、水果、全谷物和脂肪含量低的蛋白质食物,避免高糖、高脂肪和加工食品,以保持健康的饮食习惯。\n\n3. 睡眠充足。睡眠对人体健康至关重要,成年人每天应保证 7-8 小时的睡眠。良好的睡眠有助于减轻压力,促进身体恢复,并提高注意力和记忆力。',
'三原色通常指的是红色、绿色和蓝色(RGB)。它们是通过加色混合原理创建色彩的三种基础颜色。在以发光为基础的显示设备中(如电视、计算机显示器、智能手机和平板电脑显示屏), 三原色可混合产生大量色彩。其中红色和绿色可以混合生成黄色,红色和蓝色可以混合生成品红色,蓝色和绿色可以混合生成青色。当红色、绿色和蓝色按相等比例混合时,可以产生白色或灰色。\n\n此外,在印刷和绘画中,三原色指的是以颜料为基础的红、黄和蓝颜色(RYB)。这三种颜色用以通过减色混合原理来创建色彩。不过,三原色的具体定义并不唯一,不同的颜色系统可能会采用不同的三原色。',
'原子是物质的基本单位,它由三种基本粒子组成:质子、中子和电子。质子和中子形成原子核,位于原子中心,核外的电子围绕着原子核运动。\n\n原子结构具有层次性。原子核中,质子带正电,中子不带电(中性)。原子核非常小且致密,占据了原子总质量的绝大部分。电子带负电,通常围绕核运动,形成若干层次,称为壳层或电子层。电子数量与质子数量相等,使原子呈电中性。\n\n电子在每个壳层中都呈规律分布,并且不同壳层所能容纳的电子数也不同。在最里面的壳层一般只能容纳2个电子,其次一层最多可容纳8个电子,再往外的壳层可容纳的电子数逐层递增。\n\n原子核主要受到两种相互作用力的影响:强力和电磁力。强力的作用范围非常小,主要限制在原子核内,具有极强的吸引作用,使核子(质子和中子)紧密结合在一起。电磁力的作用范围较大,主要通过核外的电子与原子核相互作用,发挥作用。\n\n这就是原子的基本结构。原子内部结构复杂多样,不同元素的原子核中质子、中子数量不同,核外电子排布分布也不同,形成了丰富多彩的化学世界。']}
Step3 数据集预处理
tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-1b4-zh")
tokenizer
BloomTokenizerFast(name_or_path='Langboat/bloom-1b4-zh', vocab_size=46145, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
3: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)
def process_func(example):
MAX_LENGTH = 256
input_ids, attention_mask, labels = [], [], []
instruction = tokenizer("\n".join(["Human: "+ example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
response = tokenizer(example["output"] + tokenizer.eos_token)
input_ids = instruction["input_ids"] + response["input_ids"]
attention_mask = instruction["attention_mask"] + response["attention_mask"]
labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
if len(input_ids) > MAX_LENGTH:
input_ids = input_ids[:MAX_LENGTH]
attention_mask = attention_mask[:MAX_LENGTH]
labels = labels[:MAX_LENGTH]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels
}
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
tokenized_ds
Dataset({
features: ['input_ids', 'attention_mask', 'labels'],
num_rows: 48818
})
tokenizer.decode(tokenized_ds[2]["input_ids"])
'Human: 描述原子的结构。\n\nAssistant: 原子是物质的基本单位,它由三种基本粒子组成:质子、中子和电子。质子和中子形成原子核,位于原子中心,核外的电子围绕着原子核运动。\n\n原子结构具有层次性。原子核中,质子带正电,中子不带电(中性)。原子核非常小且致密,占据了原子总质量的绝大部分。电子带负电,通常围绕核运动,形成若干层次,称为壳层或电子层。电子数量与质子数量相等,使原子呈电中性。\n\n电子在每个壳层中都呈规律分布,并且不同壳层所能容纳的电子数也不同。在最里面的壳层一般只能容纳2个电子,其次一层最多可容纳8个电子,再往外的壳层可容纳的电子数逐层递增。\n\n原子核主要受到两种相互作用力的影响:强力和电磁力。强力的作用范围非常小,主要限制在原子核内,具有极强的吸引作用,使核子(质子和中子)紧密结合在一起。电磁力的作用范围较大,主要通过核外的电子与原子核相互作用,发挥作用。\n\n这就是原子的'
tokenizer.decode(list(filter(lambda x: x!=-100, tokenized_ds[2]["labels"])))
'原子是物质的基本单位,它由三种基本粒子组成:质子、中子和电子。质子和中子形成原子核,位于原子中心,核外的电子围绕着原子核运动。\n\n原子结构具有层次性。原子核中,质子带正电,中子不带电(中性)。原子核非常小且致密,占据了原子总质量的绝大部分。电子带负电,通常围绕核运动,形成若干层次,称为壳层或电子层。电子数量与质子数量相等,使原子呈电中性。\n\n电子在每个壳层中都呈规律分布,并且不同壳层所能容纳的电子数也不同。在最里面的壳层一般只能容纳2个电子,其次一层最多可容纳8个电子,再往外的壳层可容纳的电子数逐层递增。\n\n原子核主要受到两种相互作用力的影响:强力和电磁力。强力的作用范围非常小,主要限制在原子核内,具有极强的吸引作用,使核子(质子和中子)紧密结合在一起。电磁力的作用范围较大,主要通过核外的电子与原子核相互作用,发挥作用。\n\n这就是原子的'
len(tokenized_ds[2]["input_ids"])
256
len(tokenized_ds[2]["labels"])
256
Step4 模型创建
model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-1b4-zh", low_cpu_mem_usage=True)
model.device
device(type='cpu')
model
BloomForCausalLM(
(transformer): BloomModel(
(word_embeddings): Embedding(46145, 2048)
(word_embeddings_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
(h): ModuleList(
(0-23): 24 x BloomBlock(
(input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
(self_attention): BloomAttention(
(query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
(dense): Linear(in_features=2048, out_features=2048, bias=True)
(attention_dropout): Dropout(p=0.0, inplace=False)
)
(post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
(mlp): BloomMLP(
(dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
(gelu_impl): BloomGelu()
(dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
)
)
)
(ln_f): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=2048, out_features=46145, bias=False)
)
sum(param.numel() for param in model.parameters())
1303111680
LoRA
PEFT Step1 配置文件
import peft
peft.__version__
'0.18.1'
# conda install peft --channel conda-forge
from peft import LoraConfig, get_peft_model, TaskType
config = LoraConfig(task_type=TaskType.CAUSAL_LM,
# target_modules=['query_key_value', 'dense_4h_to_h']
# modules_to_save=['word_embeddings']
)
config
LoraConfig(task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, peft_version='0.18.1', base_model_name_or_path=None, revision=None, inference_mode=False, r=8, target_modules=None, exclude_modules=None, lora_alpha=8, lora_dropout=0.0, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', trainable_token_indices=None, loftq_config={}, eva_config=None, corda_config=None, use_dora=False, alora_invocation_tokens=None, use_qalora=False, qalora_group_size=16, layer_replication=None, runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=False), lora_bias=False, target_parameters=None, arrow_config=None, ensure_weight_tying=False)
for name, parameter in model.named_parameters():
print(name)
transformer.word_embeddings.weight
transformer.word_embeddings_layernorm.weight
transformer.word_embeddings_layernorm.bias
transformer.h.0.input_layernorm.weight
transformer.h.0.input_layernorm.bias
transformer.h.0.self_attention.query_key_value.weight
transformer.h.0.self_attention.query_key_value.bias
transformer.h.0.self_attention.dense.weight
transformer.h.0.self_attention.dense.bias
transformer.h.0.post_attention_layernorm.weight
transformer.h.0.post_attention_layernorm.bias
transformer.h.0.mlp.dense_h_to_4h.weight
transformer.h.0.mlp.dense_h_to_4h.bias
transformer.h.0.mlp.dense_4h_to_h.weight
transformer.h.0.mlp.dense_4h_to_h.bias
transformer.h.1.input_layernorm.weight
transformer.h.1.input_layernorm.bias
transformer.h.1.self_attention.query_key_value.weight
transformer.h.1.self_attention.query_key_value.bias
transformer.h.1.self_attention.dense.weight
transformer.h.1.self_attention.dense.bias
transformer.h.1.post_attention_layernorm.weight
transformer.h.1.post_attention_layernorm.bias
transformer.h.1.mlp.dense_h_to_4h.weight
transformer.h.1.mlp.dense_h_to_4h.bias
transformer.h.1.mlp.dense_4h_to_h.weight
transformer.h.1.mlp.dense_4h_to_h.bias
transformer.h.2.input_layernorm.weight
transformer.h.2.input_layernorm.bias
transformer.h.2.self_attention.query_key_value.weight
transformer.h.2.self_attention.query_key_value.bias
transformer.h.2.self_attention.dense.weight
transformer.h.2.self_attention.dense.bias
transformer.h.2.post_attention_layernorm.weight
transformer.h.2.post_attention_layernorm.bias
transformer.h.2.mlp.dense_h_to_4h.weight
transformer.h.2.mlp.dense_h_to_4h.bias
transformer.h.2.mlp.dense_4h_to_h.weight
transformer.h.2.mlp.dense_4h_to_h.bias
transformer.h.3.input_layernorm.weight
transformer.h.3.input_layernorm.bias
transformer.h.3.self_attention.query_key_value.weight
transformer.h.3.self_attention.query_key_value.bias
transformer.h.3.self_attention.dense.weight
transformer.h.3.self_attention.dense.bias
transformer.h.3.post_attention_layernorm.weight
transformer.h.3.post_attention_layernorm.bias
transformer.h.3.mlp.dense_h_to_4h.weight
transformer.h.3.mlp.dense_h_to_4h.bias
transformer.h.3.mlp.dense_4h_to_h.weight
transformer.h.3.mlp.dense_4h_to_h.bias
transformer.h.4.input_layernorm.weight
transformer.h.4.input_layernorm.bias
transformer.h.4.self_attention.query_key_value.weight
transformer.h.4.self_attention.query_key_value.bias
transformer.h.4.self_attention.dense.weight
transformer.h.4.self_attention.dense.bias
transformer.h.4.post_attention_layernorm.weight
transformer.h.4.post_attention_layernorm.bias
transformer.h.4.mlp.dense_h_to_4h.weight
transformer.h.4.mlp.dense_h_to_4h.bias
transformer.h.4.mlp.dense_4h_to_h.weight
transformer.h.4.mlp.dense_4h_to_h.bias
transformer.h.5.input_layernorm.weight
transformer.h.5.input_layernorm.bias
transformer.h.5.self_attention.query_key_value.weight
transformer.h.5.self_attention.query_key_value.bias
transformer.h.5.self_attention.dense.weight
transformer.h.5.self_attention.dense.bias
transformer.h.5.post_attention_layernorm.weight
transformer.h.5.post_attention_layernorm.bias
transformer.h.5.mlp.dense_h_to_4h.weight
transformer.h.5.mlp.dense_h_to_4h.bias
transformer.h.5.mlp.dense_4h_to_h.weight
transformer.h.5.mlp.dense_4h_to_h.bias
transformer.h.6.input_layernorm.weight
transformer.h.6.input_layernorm.bias
transformer.h.6.self_attention.query_key_value.weight
transformer.h.6.self_attention.query_key_value.bias
transformer.h.6.self_attention.dense.weight
transformer.h.6.self_attention.dense.bias
transformer.h.6.post_attention_layernorm.weight
transformer.h.6.post_attention_layernorm.bias
transformer.h.6.mlp.dense_h_to_4h.weight
transformer.h.6.mlp.dense_h_to_4h.bias
transformer.h.6.mlp.dense_4h_to_h.weight
transformer.h.6.mlp.dense_4h_to_h.bias
transformer.h.7.input_layernorm.weight
transformer.h.7.input_layernorm.bias
transformer.h.7.self_attention.query_key_value.weight
transformer.h.7.self_attention.query_key_value.bias
transformer.h.7.self_attention.dense.weight
transformer.h.7.self_attention.dense.bias
transformer.h.7.post_attention_layernorm.weight
transformer.h.7.post_attention_layernorm.bias
transformer.h.7.mlp.dense_h_to_4h.weight
transformer.h.7.mlp.dense_h_to_4h.bias
transformer.h.7.mlp.dense_4h_to_h.weight
transformer.h.7.mlp.dense_4h_to_h.bias
transformer.h.8.input_layernorm.weight
transformer.h.8.input_layernorm.bias
transformer.h.8.self_attention.query_key_value.weight
transformer.h.8.self_attention.query_key_value.bias
transformer.h.8.self_attention.dense.weight
transformer.h.8.self_attention.dense.bias
transformer.h.8.post_attention_layernorm.weight
transformer.h.8.post_attention_layernorm.bias
transformer.h.8.mlp.dense_h_to_4h.weight
transformer.h.8.mlp.dense_h_to_4h.bias
transformer.h.8.mlp.dense_4h_to_h.weight
transformer.h.8.mlp.dense_4h_to_h.bias
transformer.h.9.input_layernorm.weight
transformer.h.9.input_layernorm.bias
transformer.h.9.self_attention.query_key_value.weight
transformer.h.9.self_attention.query_key_value.bias
transformer.h.9.self_attention.dense.weight
transformer.h.9.self_attention.dense.bias
transformer.h.9.post_attention_layernorm.weight
transformer.h.9.post_attention_layernorm.bias
transformer.h.9.mlp.dense_h_to_4h.weight
transformer.h.9.mlp.dense_h_to_4h.bias
transformer.h.9.mlp.dense_4h_to_h.weight
transformer.h.9.mlp.dense_4h_to_h.bias
transformer.h.10.input_layernorm.weight
transformer.h.10.input_layernorm.bias
transformer.h.10.self_attention.query_key_value.weight
transformer.h.10.self_attention.query_key_value.bias
transformer.h.10.self_attention.dense.weight
transformer.h.10.self_attention.dense.bias
transformer.h.10.post_attention_layernorm.weight
transformer.h.10.post_attention_layernorm.bias
transformer.h.10.mlp.dense_h_to_4h.weight
transformer.h.10.mlp.dense_h_to_4h.bias
transformer.h.10.mlp.dense_4h_to_h.weight
transformer.h.10.mlp.dense_4h_to_h.bias
transformer.h.11.input_layernorm.weight
transformer.h.11.input_layernorm.bias
transformer.h.11.self_attention.query_key_value.weight
transformer.h.11.self_attention.query_key_value.bias
transformer.h.11.self_attention.dense.weight
transformer.h.11.self_attention.dense.bias
transformer.h.11.post_attention_layernorm.weight
transformer.h.11.post_attention_layernorm.bias
transformer.h.11.mlp.dense_h_to_4h.weight
transformer.h.11.mlp.dense_h_to_4h.bias
transformer.h.11.mlp.dense_4h_to_h.weight
transformer.h.11.mlp.dense_4h_to_h.bias
transformer.h.12.input_layernorm.weight
transformer.h.12.input_layernorm.bias
transformer.h.12.self_attention.query_key_value.weight
transformer.h.12.self_attention.query_key_value.bias
transformer.h.12.self_attention.dense.weight
transformer.h.12.self_attention.dense.bias
transformer.h.12.post_attention_layernorm.weight
transformer.h.12.post_attention_layernorm.bias
transformer.h.12.mlp.dense_h_to_4h.weight
transformer.h.12.mlp.dense_h_to_4h.bias
transformer.h.12.mlp.dense_4h_to_h.weight
transformer.h.12.mlp.dense_4h_to_h.bias
transformer.h.13.input_layernorm.weight
transformer.h.13.input_layernorm.bias
transformer.h.13.self_attention.query_key_value.weight
transformer.h.13.self_attention.query_key_value.bias
transformer.h.13.self_attention.dense.weight
transformer.h.13.self_attention.dense.bias
transformer.h.13.post_attention_layernorm.weight
transformer.h.13.post_attention_layernorm.bias
transformer.h.13.mlp.dense_h_to_4h.weight
transformer.h.13.mlp.dense_h_to_4h.bias
transformer.h.13.mlp.dense_4h_to_h.weight
transformer.h.13.mlp.dense_4h_to_h.bias
transformer.h.14.input_layernorm.weight
transformer.h.14.input_layernorm.bias
transformer.h.14.self_attention.query_key_value.weight
transformer.h.14.self_attention.query_key_value.bias
transformer.h.14.self_attention.dense.weight
transformer.h.14.self_attention.dense.bias
transformer.h.14.post_attention_layernorm.weight
transformer.h.14.post_attention_layernorm.bias
transformer.h.14.mlp.dense_h_to_4h.weight
transformer.h.14.mlp.dense_h_to_4h.bias
transformer.h.14.mlp.dense_4h_to_h.weight
transformer.h.14.mlp.dense_4h_to_h.bias
transformer.h.15.input_layernorm.weight
transformer.h.15.input_layernorm.bias
transformer.h.15.self_attention.query_key_value.weight
transformer.h.15.self_attention.query_key_value.bias
transformer.h.15.self_attention.dense.weight
transformer.h.15.self_attention.dense.bias
transformer.h.15.post_attention_layernorm.weight
transformer.h.15.post_attention_layernorm.bias
transformer.h.15.mlp.dense_h_to_4h.weight
transformer.h.15.mlp.dense_h_to_4h.bias
transformer.h.15.mlp.dense_4h_to_h.weight
transformer.h.15.mlp.dense_4h_to_h.bias
transformer.h.16.input_layernorm.weight
transformer.h.16.input_layernorm.bias
transformer.h.16.self_attention.query_key_value.weight
transformer.h.16.self_attention.query_key_value.bias
transformer.h.16.self_attention.dense.weight
transformer.h.16.self_attention.dense.bias
transformer.h.16.post_attention_layernorm.weight
transformer.h.16.post_attention_layernorm.bias
transformer.h.16.mlp.dense_h_to_4h.weight
transformer.h.16.mlp.dense_h_to_4h.bias
transformer.h.16.mlp.dense_4h_to_h.weight
transformer.h.16.mlp.dense_4h_to_h.bias
transformer.h.17.input_layernorm.weight
transformer.h.17.input_layernorm.bias
transformer.h.17.self_attention.query_key_value.weight
transformer.h.17.self_attention.query_key_value.bias
transformer.h.17.self_attention.dense.weight
transformer.h.17.self_attention.dense.bias
transformer.h.17.post_attention_layernorm.weight
transformer.h.17.post_attention_layernorm.bias
transformer.h.17.mlp.dense_h_to_4h.weight
transformer.h.17.mlp.dense_h_to_4h.bias
transformer.h.17.mlp.dense_4h_to_h.weight
transformer.h.17.mlp.dense_4h_to_h.bias
transformer.h.18.input_layernorm.weight
transformer.h.18.input_layernorm.bias
transformer.h.18.self_attention.query_key_value.weight
transformer.h.18.self_attention.query_key_value.bias
transformer.h.18.self_attention.dense.weight
transformer.h.18.self_attention.dense.bias
transformer.h.18.post_attention_layernorm.weight
transformer.h.18.post_attention_layernorm.bias
transformer.h.18.mlp.dense_h_to_4h.weight
transformer.h.18.mlp.dense_h_to_4h.bias
transformer.h.18.mlp.dense_4h_to_h.weight
transformer.h.18.mlp.dense_4h_to_h.bias
transformer.h.19.input_layernorm.weight
transformer.h.19.input_layernorm.bias
transformer.h.19.self_attention.query_key_value.weight
transformer.h.19.self_attention.query_key_value.bias
transformer.h.19.self_attention.dense.weight
transformer.h.19.self_attention.dense.bias
transformer.h.19.post_attention_layernorm.weight
transformer.h.19.post_attention_layernorm.bias
transformer.h.19.mlp.dense_h_to_4h.weight
transformer.h.19.mlp.dense_h_to_4h.bias
transformer.h.19.mlp.dense_4h_to_h.weight
transformer.h.19.mlp.dense_4h_to_h.bias
transformer.h.20.input_layernorm.weight
transformer.h.20.input_layernorm.bias
transformer.h.20.self_attention.query_key_value.weight
transformer.h.20.self_attention.query_key_value.bias
transformer.h.20.self_attention.dense.weight
transformer.h.20.self_attention.dense.bias
transformer.h.20.post_attention_layernorm.weight
transformer.h.20.post_attention_layernorm.bias
transformer.h.20.mlp.dense_h_to_4h.weight
transformer.h.20.mlp.dense_h_to_4h.bias
transformer.h.20.mlp.dense_4h_to_h.weight
transformer.h.20.mlp.dense_4h_to_h.bias
transformer.h.21.input_layernorm.weight
transformer.h.21.input_layernorm.bias
transformer.h.21.self_attention.query_key_value.weight
transformer.h.21.self_attention.query_key_value.bias
transformer.h.21.self_attention.dense.weight
transformer.h.21.self_attention.dense.bias
transformer.h.21.post_attention_layernorm.weight
transformer.h.21.post_attention_layernorm.bias
transformer.h.21.mlp.dense_h_to_4h.weight
transformer.h.21.mlp.dense_h_to_4h.bias
transformer.h.21.mlp.dense_4h_to_h.weight
transformer.h.21.mlp.dense_4h_to_h.bias
transformer.h.22.input_layernorm.weight
transformer.h.22.input_layernorm.bias
transformer.h.22.self_attention.query_key_value.weight
transformer.h.22.self_attention.query_key_value.bias
transformer.h.22.self_attention.dense.weight
transformer.h.22.self_attention.dense.bias
transformer.h.22.post_attention_layernorm.weight
transformer.h.22.post_attention_layernorm.bias
transformer.h.22.mlp.dense_h_to_4h.weight
transformer.h.22.mlp.dense_h_to_4h.bias
transformer.h.22.mlp.dense_4h_to_h.weight
transformer.h.22.mlp.dense_4h_to_h.bias
transformer.h.23.input_layernorm.weight
transformer.h.23.input_layernorm.bias
transformer.h.23.self_attention.query_key_value.weight
transformer.h.23.self_attention.query_key_value.bias
transformer.h.23.self_attention.dense.weight
transformer.h.23.self_attention.dense.bias
transformer.h.23.post_attention_layernorm.weight
transformer.h.23.post_attention_layernorm.bias
transformer.h.23.mlp.dense_h_to_4h.weight
transformer.h.23.mlp.dense_h_to_4h.bias
transformer.h.23.mlp.dense_4h_to_h.weight
transformer.h.23.mlp.dense_4h_to_h.bias
transformer.ln_f.weight
transformer.ln_f.bias
PEFT Step2 创建模型
peft_model = get_peft_model(model, config)
peft_model
PeftModelForCausalLM(
(base_model): LoraModel(
(model): BloomForCausalLM(
(transformer): BloomModel(
(word_embeddings): Embedding(46145, 2048)
(word_embeddings_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
(h): ModuleList(
(0-23): 24 x BloomBlock(
(input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
(self_attention): BloomAttention(
(query_key_value): lora.Linear(
(base_layer): Linear(in_features=2048, out_features=6144, bias=True)
(lora_dropout): ModuleDict(
(default): Identity()
)
(lora_A): ModuleDict(
(default): Linear(in_features=2048, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=6144, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
(lora_magnitude_vector): ModuleDict()
)
(dense): Linear(in_features=2048, out_features=2048, bias=True)
(attention_dropout): Dropout(p=0.0, inplace=False)
)
(post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
(mlp): BloomMLP(
(dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
(gelu_impl): BloomGelu()
(dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
)
)
)
(ln_f): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=2048, out_features=46145, bias=False)
)
)
)
config
LoraConfig(task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, peft_version='0.18.1', base_model_name_or_path='Langboat/bloom-1b4-zh', revision=None, inference_mode=False, r=8, target_modules={'query_key_value'}, exclude_modules=None, lora_alpha=8, lora_dropout=0.0, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', trainable_token_indices=None, loftq_config={}, eva_config=None, corda_config=None, use_dora=False, alora_invocation_tokens=None, use_qalora=False, qalora_group_size=16, layer_replication=None, runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=False), lora_bias=False, target_parameters=None, arrow_config=None, ensure_weight_tying=False)
peft_model.print_trainable_parameters()
trainable params: 1,572,864 || all params: 1,304,684,544 || trainable%: 0.1206
Step5 配置训练参数
args = TrainingArguments(
output_dir="./chatbot", # 输出文件夹存储模型的预测结果和模型文件checkpoints
per_device_train_batch_size=1, # 默认8, 对于训练的时候每个 GPU核或者CPU 上面对应的一个批次的样本数
gradient_accumulation_steps=8, # 默认1, 在执行反向传播/更新参数之前, 对应梯度计算累积了多少次
logging_steps=10, # 每隔10迭代落地一次日志
num_train_epochs=1 # 整体上数据集让模型学习多少遍
)
Step6 创建训练器
trainer = Trainer(
model=model,
args=args,
train_dataset=tokenized_ds,
# 构建一个个批次数据所需要的
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)
)
Step7 模型训练
trainer.train()
Step8 模型推理
省略…
更多推荐


所有评论(0)