【LLM实战】手把手教你用小模型实现CoT(思维链)微调
思维链(Chain of Thought, CoT)是提升大模型推理能力的关键技术。很多人认为这需要巨大的模型和昂贵的硬件。本文将打破这一迷思,为你详细讲解CoT原理,并提供一套完整的、可在8GB显存(如RTX 4060)上流畅运行的代码,教你如何通过微调(Fine-tuning)让小模型也具备逻辑推理能力。
【LLM实战】手把手教你用小模型实现CoT(思维链)微调
摘要:思维链(Chain of Thought, CoT)是提升大模型推理能力的关键技术。很多人认为这需要巨大的模型和昂贵的硬件。本文将打破这一迷思,为你详细讲解CoT原理,并提供一套完整的、可在8GB显存(如RTX 4060)上流畅运行的代码,教你如何通过微调(Fine-tuning)让小模型也具备逻辑推理能力。
一、 引言:为什么小模型也需要CoT?
在大型语言模型(LLM)时代,思维链(Chain of Thought, CoT) 几乎成了复杂推理任务的标配。它通过引导模型通过中间推理步骤来获得最终答案,显著提升了模型在数学、逻辑和常识推理等任务上的表现。
通常,我们看到的是GPT-4或Claude这样的巨无霸模型在使用CoT。但对于资源有限的个人开发者或研究者,我们能否在自己的消费级显卡上,让小模型也学会这种“深思熟虑”的能力呢?
答案是肯定的。通过有监督微调(Supervised Fine-tuning, SFT),我们可以将CoT的能力“注入”到小模型中。本文将使用Google的flan-t5-base模型,在仅需较少显存的情况下,演示这一全过程。
二、 理论篇:通俗理解CoT思维链
2.1 什么是CoT?
简单来说,CoT就是让模型把“思考过程”写出来,而不是直接给答案。
举个经典的例子:
-
传统提问(Standard Prompting):
问:小明有5个苹果,给了小红2个,又买了3个。现在他有几个?
答:6个。
(模型需要在内部完成所有计算,容易出错) -
思维链提问(CoT Prompting):
问:小明有5个苹果,给了小红2个,又买了3个。现在他有几个?
答:小明一开始有5个苹果。给了小红2个,剩下 5-2=3个。又买了3个,所以现在有 3+3=6个。答案是6个。
(模型把大问题分解成小步骤,准确率大幅提升)
- CoT的两种主要形式
- Zero-Shot CoT (零样本CoT): 无需给模型任何示例,只需在问题末尾加上一句简单的指令,如“让我们一步一步地思考” (Let’s think step by step),就能激发模型自身的推理能力。这适用于能力非常强的大模型。
- Fine-tune CoT (微调CoT): 对于小型模型,它们自身可能不具备强大的零样本推理能力。因此,我们需要“教会”它们如何进行链式思考。方法是:准备一个“问题 -> 思考过程 -> 答案”格式的数据集,然后用这个数据集去 微调(Fine-tune) 小模型。这正是我们接下来要用代码实现的
- 基础CoT(像我们代码实现的)是让模型生成一个非结构化的、自由发挥的思考过程。这能提高推理的准确性。
- 高级CoT 是构建一个结构化的思考模板,强制模型按照我们设计的步骤去思考,并在每一步生成特定的、我们想要的东西。
2.2 为什么小模型需要微调来实现CoT?
- 大模型(>100B参数):通常具备“涌现”能力,只需要在提问时加上一句“让我们一步步思考”(Zero-shot CoT),就能自动生成推理步骤。
- 小模型(<1B参数):原生推理能力较弱,直接问它可能无法生成连贯的思维链。因此,我们需要构建一个包含 “问题 -> 详细推理过程 -> 答案” 的数据集,通过微调来“教会”它这种思考模式。
三、 实战篇:8G显存微调CoT代码详解
3.1 环境准备
本项目对硬件要求极低,8GB显存的显卡(如NVIDIA RTX 3060/4060)绰绰有余,甚至集显或纯CPU也能跑(只是慢一点)。
首先,创建一个新的Python虚拟环境,并安装必要的依赖库:
pip install torch transformers datasets sentencepiece accelerate
- torch:核心的深度学习框架。
- transformers:Hugging Face库,用于加载模型和进行训练。
- datasets:Hugging Face库,用于处理数据。
- sentencepiece:T5模型需要的分词器工具。
3.2 代码全解析
下面是完整的Python代码。你可以将其保存为一个.py文件(例如 train_cot.py)然后直接运行。
# -------------------------------------------------------------------------
# CoT微调实战代码:基于GSM8K数据集微调Flan-T5模型,实现数学推理能力
# -------------------------------------------------------------------------
# 导入必要的库
import torch # 深度学习框架
import os # 操作系统相关操作
import re # 正则表达式,用于提取答案中的数字
import numpy as np # 数值计算库
from datasets import load_dataset # 加载Hugging Face数据集
from transformers import (
T5ForConditionalGeneration, # T5生成式模型
T5Tokenizer, # T5对应的分词器
Trainer, # Hugging Face训练工具
TrainingArguments, # 训练参数配置
EvalPrediction # 评估预测结果的数据结构
)
# --- 环境与设备设置 ---
# 设置Hugging Face镜像地址,加速国内下载
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# 选择计算设备(优先GPU,无GPU则用CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("=" * 40 + f"\n将使用设备: {device.upper()}\n" + "=" * 40)
# 1. 加载训练集与测试集
# ---------------------------------------------------------
print("正在从 Hugging Face Hub 加载 GSM8K 数据集...")
# GSM8K是一个数学推理数据集,包含8.5k个小学数学问题及分步解答
dataset = load_dataset("gsm8k", "main")
print("数据集加载完成!")
# 为了加速演示,选择部分样本进行训练和评估(实际应用可使用完整数据集)
train_dataset = dataset['train'].select(range(2000)) # 取前2000条训练数据
eval_dataset = dataset['test'].select(range(200)) # 取前200条测试数据
# 2. 加载模型与分词器
# ------------------------------------
model_name = "google/flan-t5-base" # 选择基础版Flan-T5模型(性能与效率平衡)
# 加载分词器:将文本转换为模型可理解的token
tokenizer = T5Tokenizer.from_pretrained(model_name)
# 加载预训练模型:T5的条件生成版本(适合文本生成任务)
model = T5ForConditionalGeneration.from_pretrained(model_name)
# 将模型移动到指定计算设备(GPU/CPU)
model.to(device)
print(f"模型 '{model_name}' 已成功加载到 {device.upper()} 设备上。")
# 3. 数据预处理
# ------------------------------------
def preprocess_function(examples):
"""
数据预处理函数:将原始文本转换为模型输入格式
Args:
examples: 数据集样本(包含question和answer字段)
Returns:
处理后的模型输入(包含input_ids, attention_mask, labels)
"""
# 构造输入:在问题前添加"solve: "前缀,提示模型进行求解
inputs = ["solve: " + q for q in examples["question"]]
# 目标输出:原始数据中的答案(包含思维链和最终结果)
targets = [ans for ans in examples["answer"]]
# 对输入文本进行分词:
# max_length=512:最大序列长度(超过截断)
# truncation=True:超长时截断
# padding="max_length":不足时用pad token填充到最大长度
model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")
# 对目标文本进行同样处理(作为模型的标签)
labels = tokenizer(targets, max_length=512, truncation=True, padding="max_length")
# 将标签存入模型输入字典(T5训练时需要labels字段)
model_inputs["labels"] = labels["input_ids"]
return model_inputs
# 对训练集和测试集批量应用预处理函数
tokenized_train_dataset = train_dataset.map(
preprocess_function,
batched=True, # 批量处理(加速)
desc="正在处理训练集..." # 显示处理进度描述
)
tokenized_eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
desc="正在处理测试集..."
)
# 4. 定义评估指标计算函数
# ------------------------------------
def extract_last_number(text):
"""从文本中提取最后一个数字(GSM8K答案通常以数字结尾)"""
# 用正则表达式匹配所有数字(支持整数、小数、带逗号的数字)
numbers = re.findall(r'[\d\.\,]+', text)
if not numbers: # 没有找到数字
return None
try:
# 移除数字中的逗号(如"1,000" -> "1000")并转换为浮点数
return float(numbers[-1].replace(',', ''))
except ValueError: # 数字格式错误
return None
def compute_metrics(p: EvalPrediction):
"""
计算评估指标:精确匹配准确率(预测结果与标签的最后数字是否一致)
Args:
p: 评估预测对象(包含predictions和label_ids)
Returns:
包含准确率的字典
"""
# 提取预测结果(处理可能的元组格式)
preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
labels = p.label_ids # 标签ID
# 将预测结果和解码为文本(跳过特殊token,如pad、eos)
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
# 将标签中的-100(T5中用于忽略的token)替换为pad token,再解码
labels[labels == -100] = tokenizer.pad_token_id
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
# 计算精确匹配数量
correct = 0
for pred, label in zip(decoded_preds, decoded_labels):
# 提取预测和标签中的最后一个数字
pred_ans = extract_last_number(pred)
label_ans = extract_last_number(label)
# 两者都有效且相等则视为正确
if pred_ans is not None and label_ans is not None and pred_ans == label_ans:
correct += 1
# 计算准确率
accuracy = correct / len(decoded_preds)
return {"exact_match_accuracy": accuracy}
# 5. 设置训练参数
# ------------------------------------
training_args = TrainingArguments(
output_dir="./cot_finetuned_model_gsm8k_base_eval", # 模型保存路径
num_train_epochs=3, # 训练轮数(小数据集3轮足够演示)
per_device_train_batch_size=2, # 每个设备的训练批次大小(根据GPU显存调整)
warmup_steps=100, # 学习率热身步数(避免初始学习率过大)
weight_decay=0.01, # 权重衰减(防止过拟合)
logging_dir="./logs_gsm8k_base_eval", # 日志保存路径
logging_steps=50, # 每50步记录一次日志
do_eval=True, # 训练过程中进行评估
eval_steps=500, # 每500步评估一次
save_steps=500, # 每500步保存一次模型
)
# 初始化Trainer:封装训练逻辑
trainer = Trainer(
model=model, # 待训练的模型
args=training_args, # 训练参数
train_dataset=tokenized_train_dataset, # 训练数据集
eval_dataset=tokenized_eval_dataset, # 评估数据集
compute_metrics=compute_metrics, # 评估指标计算函数
tokenizer=tokenizer, # 分词器(用于保存模型时同步保存)
)
# 6. 开始训练和评估
# ------------------------------------
print("开始使用 GSM8K 数据集进行CoT微调 (base 模型)...")
trainer.train() # 启动训练
print("模型微调完成!")
# 7. 使用最终模型进行推理测试
# ------------------------------------
def ask_question(question, trained_model, tokenizer):
"""
使用训练好的模型进行推理
Args:
question: 输入的问题
trained_model: 训练好的模型
tokenizer: 分词器
Returns:
模型生成的回答(包含思维链)
"""
# 构造提示词(与训练时的输入格式保持一致)
prompt = "solve: " + question
# 将文本转换为模型输入格式,并移动到设备
inputs = tokenizer(prompt, return_tensors="pt").to(device)
# 生成回答:
# max_length=512:最大生成长度
# num_beams=5:beam search宽度(提升生成质量)
# early_stopping=True:遇到结束符时停止生成
outputs = trained_model.generate(**inputs, max_length=512, num_beams=5, early_stopping=True)
# 解码生成结果(跳过特殊token)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
# 测试问题:简单的数学应用题
new_question = "A person has 6 bananas. They eat 2 of the bananas. Then, a friend gives them 5 more bananas. How many bananas does the person have now?"
print("\n--- 使用训练完成后的 base 模型进行推理测试 ---")
print(f"问题: {new_question}")
# 生成回答并打印
generated_answer = ask_question(new_question, trainer.model, tokenizer)
print(f"模型的回答 (包含思考过程):\n{generated_answer}")
3.3 运行结果与分析
1.将上面的代码保存为 train_cot.py。
2.在终端中,cd 到该文件所在的目录。
3.运行命令:python train_cot.py
4.程序会自动下载模型(首次运行时需要),然后开始训练。你会看到训练的进度条和loss变化。
5.训练结束后,程序会自动用一个新问题来测试模型,并打印出推理结果。
预期结果展示
训练完成后,你将看到类似下面的输出:
... (此处为漫长的训练日志) ...
{'train_runtime': 45820.12, 'train_loss': 0.2531, 'epoch': 3.0}
模型微调完成!
--- 使用训练完成后的 base 模型进行最终测试 ---
问题: A person has 6 bananas. They eat 2 of the bananas. Then, a friend gives them 5 more bananas. How many bananas does the person have now?
模型的回答 (包含思考过程):
The person started with 6 bananas and ate 2, so they had 6 - 2 = 4 bananas. Then they were given 5 more bananas, so they have 4 + 5 = 9 bananas. #### 9
结果分析
由我们实际的运行结果看出,base 模型,已经学会了如何“扮演”一个会做CoT推理的AI。它掌握了所谓的的“套路”和“格式”。但这也暴露了我们的局限性:
- 思维固化 (Over-Specialization):
- 由于训练数据过于单一(全是同一种格式的数学题),模型的大脑被“格式化”了。它只会模仿解题的形式(写出计算步骤),却丧失了理解问题和进行真实计算的能力。
- 灾难性遗忘 (Catastrophic Forgetting):
- 这场高度重复的数学特训,让模型忘记了它在预训练阶段学到的、渊博的通用语言知识和常识推理能力。它从一个“通才”退化成了一个有缺陷的“专才”。
- 模型容量的误用 (Capacity Misuse):
- 无论是small还是base模型,它们都把有限的“脑容量”用错了地方——全部用来死记硬背解题的 “套路”和“格式”,而没有足够的能力去学习更深层次的语义理解 和数学逻辑
四、总结
本文带你从理论到实践,完整地走了一遍CoT微调的全过程。希望你通过这篇文章,能够理解:
- CoT是什么: 一种通过引导模型输出思考步骤来提升推理能力的技术。
- CoT为何有效: 它将复杂问题分解,为模型提供了推理的“草稿纸”。
- 如何动手实现: 如何准备CoT格式数据,并使用Hugging Face transformers 库微调一个小型模型。
现在,你已经掌握了开启语言模型推理能力大门的一把钥匙。快去尝试解决更有趣的问题吧!
更多推荐


所有评论(0)