大模型初识(常用的基础模型+微调框架+优化工具)
大模型初识
常用的基础模型
GPT (以 ChatGPT 为代表) | BERT | Llama (以 Llama 3 为代表) |
---|---|---|
仅解码器 (Decoder-Only) | 仅编码器 (Encoder-Only) | 仅解码器 (Decoder-Only) |
训练方式 自回归 (Causal LM) | 自编码 (Masked LM) | 自回归 (Causal LM) |
上下文理解单向(只能看前面的词) | 双向(能同时看前后所有词) | 单向(只能看前面的词) |
文本生成 | 文本理解 | 文本生成 |
GPT-1/2 开源,GPT-3/4 闭源 | 完全开源 | 完全开源 (Meta 发布) |
典型交互方式 “请写一首关于春天的诗” | “这句话的情感是积极还是消极?” | 同GPT,可作为开源替代 |
生成连贯、长篇的文本 | 输出一个分类标签或对词语的理解 | 生成连贯、长篇的文本 |
除了以上三种,还有:
1. 闭源模型(通过API提供服务)
- GPT-4 (OpenAI): 当前能力的标杆,多模态(支持图像输入),推理能力极强。(通过AutoModelForCausalLM.from_pretrained(model_name, device_map=“auto”)加载)
- Claude (Anthropic): 强调 Constitutional AI(宪法AI),致力于构建安全、可靠、可控的AI,长上下文能力是其亮点(最高支持200K tokens)。
- Gemini (Google DeepMind): Google 的多模态模型家族,从轻量级到超大规模(Nano, Pro, Flash,Ultra),深度集成到Google生态中。
2. 开源模型(可自行部署微调)
-
Mistral / Mixtral (Mistral AI): 一家欧洲公司推出的强大开源模型。Mixtral 是 混合专家(MoE)模型,用更少的激活参数实现了更强的性能,效率极高。
-
Command R/R+ (Cohere): 专为企业级应用优化,特别擅长 RAG(检索增强生成) 和工具调用,适合构建复杂的生产系统。
-
Qwen (通义千问, 阿里云): 强大的中英文双语开源模型系列,覆盖不同尺寸,性能优异。
-
OLMo (Allen Institute for AI):真正完全开源(包括代码、数据、训练全过程),旨在推动AI研究的透明度和可重复性。
-
Gemma (Google): 基于 Gemini 技术推出的轻量级开源模型系列,适合轻量级部署和入门。
-
T5 (Text-to-Text Transfer Transformer, Google):提出了“所有任务都是文本到文本的生成”的框架,统一了理解和生成任务的范式。例如,翻译任务输入 “translate English to German: That is good.”,输出 “Das ist gut.”;分类任务输入 “mnli premise: I hate this. hypothesis: My feelings are positive.”,输出 “contradiction”。(通过AutoModelForSeq2SeqLM进行加载)
-
BART (Facebook):一个结合了BERT和GPT思想的编码器-解码器(Encoder-Decoder)架构模型。非常适合文本生成和重构类任务,如文本摘要、去噪、句子修正等。
需求 | 首选模型类型 |
---|---|
需要最强的能力,不想自己维护 | 闭源API (GPT-4, Claude 3) |
需要生成文本(对话、创作) | GPT-like / Llama (Llama 3, Mistral, 闭源API) |
需要理解文本(分类、情感分析) | BERT 及其变体 (在NLU任务上仍是常青树) |
注重成本、数据隐私、需要定制化 | 开源模型 (Llama, Mistral, Qwen) + 自行微调部署 |
需要处理长文档 | Claude (闭源) 或 Mistral (开源) |
需要多模态能力(识图) | GPT-4V, Gemini 1.5 |
关于DeepSeek模型
模型名称 | 特点 | 应用场景 |
---|---|---|
DeepSeek V3 | 6710亿参数混合专家模型(MoE),支持长上下文,每个token仅激活370亿参数,高效节能。 | 通用对话、内容创作、多语言任务、代码辅助 |
DeepSeek R1 | 专注逻辑推理和复杂问题解决,有满血版(671B)和多种蒸馏版(如1.5B-32B),强化学习优化 | 数学推理、编程挑战、决策支持、需要深度思考的任务 |
DeepSeek Coder | 专为代码生成与理解设计,架构类似Llama | 程序员编程辅助、代码补全、调试 |
DeepSeek Math | 专注于数学推理,通过强化学习优化 | 数学解题辅助、奥数训练 |
DeepSeek LLM | 开源通用大语言模型,通过监督微调提升多任务处理能力,有7B/67B等版本 | 学术研究、企业应用开发、微调基础模型 |
DeepSeek MoE | 采用混合专家(MoE) 架构,提升模型效率,例如DeepSeek MoE-16B推理成本接近7B模型 | 高性价比的复杂任务处理,如长文本生成和多轮对话 |
全能王 (闭源):GPT-4。如果你想要目前最全面、最省心的模型,选它。
全能王 (开源):Llama 3。如果你想要一个免费、可修改、可商用的顶级模型作为基础,选它。
理解王:BERT。如果你的任务是传统的文本分类、情感分析,它轻量且高效。
长文本专家:DeepSeek-V3。如果你的核心需求是处理长文档、代码库,或者非常看重中文能力和推理性价比,选它。
微调框架
框架/方法 | 类型 | 核心思想 | 优点 | 缺点 | 使用场景 |
---|---|---|---|---|---|
全参数微调 | 全量 | 更新模型所有参数 | 效果潜力最大 | 计算成本极高,显存需求大,易发生灾难性遗忘 | 数据量极大且与预训练数据分布差异巨大,不计成本追求极致性能 |
LoRA (Low-Rank Adaptation) | PEFT | 注入低秩矩阵来近似权重更新 | 极大节省显存和存储,多个任务适配器可快速切换,效果接近全量微调 | 需要选择目标模块和设置秩®,超参增多 | 最通用,适用于绝大多数企业场景(客服、内容生成等) |
QLoRA | PEFT | LoRA + 4bit量化 | 在LoRA基础上进一步大幅降低显存需求,可在单卡消费级GPU上微调大模型 | 量化可能带来极轻微的精度损失 | 资源受限的中小企业或个人开发者,低成本实验和部署 |
Adapter | PEFT | 在Transformer层中插入小型全连接层 | 模块化设计,易于添加和移除 | 会增加模型推理的延迟 | 需要为不同任务动态加载不同模块的多任务学习系统 |
P-Tuning v2 | PEFT | 在输入序列中加入可训练的连续提示向量 | 不修改模型原始权重,无推理延迟 | 效果通常略逊于LoRA,提示向量长度是超参 | 模型权重无法修改的黑盒API微调,或对推理延迟极度敏感的场景 |
首选 LoRA/QLoRA:对于绝大多数企业微调任务,LoRA 是起点。它是性能、效率和易用性的最佳平衡点。如果硬件资源紧张,直接使用 QLoRA。
特殊场景考虑其他方案:需要动态多任务时考虑 Adapters;模型权重完全不可修改时考虑 P-Tuning。
放弃全量微调:除非你有海量数据和巨量算力,且性能提升至关重要,否则在2024年及以后,PEFT方法几乎总是更优的选择。
智能客服问答系统(使用 LoRA)
背景:一家电商公司希望将其客服历史记录微调一个开源模型(如Llama 3),以更准确地回答关于产品、物流和退款的问题。
优势选择:LoRA 是此场景的最佳选择。它效果出色,且训练出的适配器(Adapter)文件很小(通常几MB到几百MB),便于管理和部署。可以为一个产品线训练一个LoRA适配器,灵活切换。
代码实现 (使用 Hugging Face TRL 和 PEFT):
# 安装: pip install transformers peft accelerate datasets trl bitsandbytes
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
from datasets import load_dataset
import torch
# 1. 加载模型并配置4bit量化 (QLoRA)
model_name = "meta-llama/Meta-Llama-3-8B"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token # 设置pad_token
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True
)
model = prepare_model_for_kbit_training(model) # 为K比特训练准备模型
# 2. 配置 LoRA
lora_config = LoraConfig(
r=8, # Rank
lora_alpha=32, # 缩放参数
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], # 针对Llama架构
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters() # 查看可训练参数量,通常不足1%
# 3. 加载并处理企业数据
dataset = load_dataset("json", data_files={"train": "customer_service_data.jsonl"})
def format_function(examples):
text = [f"<|user|>\n{q}\n<|assistant|>\n{a}{tokenizer.eos_token}"
for q, a in zip(examples['question'], examples['answer'])]
return {"text": text}
formatted_dataset = dataset.map(format_function, batched=True)
# 4. 使用SFTTrainer进行训练
training_args = TrainingArguments(
output_dir="./llama3-customer-service-lora",
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=2e-4,
num_train_epochs=3,
logging_dir="./logs",
logging_steps=10,
save_strategy="epoch",
fp16=True,
max_grad_norm=0.3,
)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=formatted_dataset["train"],
dataset_text_field="text",
max_seq_length=1024,
tokenizer=tokenizer,
)
trainer.train()
trainer.save_model() # 保存的只有LoRA权重,非常小
# 5. 推理时加载并使用
from peft import PeftModel
base_model = AutoModelForCausalLM.from_pretrained(...)
model = PeftModel.from_pretrained(base_model, "./llama3-customer-service-lora")
低成本概念验证(使用 QLoRA)
背景:一个初创团队想用一个较小的预算,在单张RTX 4090 (24GB) 上微调一个7B模型,验证一个法律文档分析的想法。
优势选择:QLoRA 是资源受限情况下的不二之选。它让微调大模型的门槛从数万人民币的A100/H100降低到了消费级显卡。
上述LoRA代码已经使用了4bit量化(BitsAndBytesConfig),这本身就是QLoRA。关键区别在于量化配置和prepare_model_for_kbit_training步骤。代码完全通用。
多任务企业助手(使用 Adapters)
背景:一家大型企业需要同一个模型(如Llama 3)扮演不同角色:有时是IT支持助手(解决电脑问题),有时是HR助手(回答请假政策),有时是财务助手(解释报销流程)。
优势选择:Adapter 的模块化特性非常适合此场景。可以为每个任务训练一个独立的Adapter。在推理时,根据用户请求的类型,动态加载对应的Adapter到模型中,实现“一个模型,多种专业角色”的效果。
# 安装: pip install adapter-transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from adapter_transformers import AdapterConfig
# 1. 加载基础模型
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
# 2. 添加并训练IT支持任务的Adapter
it_config = AdapterConfig.load("pfeiffer") # 一种Adapter配置
model.add_adapter("it_support", config=it_config)
model.train_adapter("it_support") # 只训练Adapter参数
# ... (在此处准备IT支持数据并进行训练) ...
model.save_adapter("./adapters/it_support", "it_support")
# 3. 添加并训练HR任务的Adapter
hr_config = AdapterConfig.load("pfeiffer")
model.add_adapter("hr_policy", config=hr_config)
model.train_adapter("hr_policy")
# ... (在此处准备HR数据并进行训练) ...
model.save_adapter("./adapters/hr_policy", "hr_policy")
# 4. 推理时动态切换Adapter
# 用户询问IT问题
model.set_active_adapters("it_support")
it_response = generate_response(user_it_question, model, tokenizer)
# 用户询问HR问题
model.set_active_adapters("hr_policy")
hr_response = generate_response(user_hr_question, model, tokenizer)
# 5. 也可以将所有Adapter合并到主权重中(可选,会失去灵活性)
model.merge_adapter("it_support")
大模型推理优化的工具、方法
量化 (Quantization)
将模型权重和激活值从高精度(如FP32, FP16)转换为低精度(如INT8, INT4)的过程。
这是最重要的优化手段,能直接减少模型体积和内存占用,提升计算速度。
GPTQ:一种训练后量化方法,需要少量校准数据,精度损失极小。非常适合GPU部署。
优点:高精度,支持到2-4bit,性能提升显著。
缺点:量化过程需要GPU和校准时间。
AWQ:一种新兴的量化方法,认为“权重并非同等重要”,通过激活感知来保护 salient weight(重要权重)
优点:相比GPTQ,可能在泛化性和精度上更有优势。
缺点:生态相对GPTQ稍新。
GGUF (原GGML):专为CPU优化的量化格式,通常与llama.cpp绑定。支持多种量化级别(q4_0, q5_0, q8_0等)
优点:可在CPU上高效运行,无需GPU,部署极其简单。
缺点:在GPU上性能不如GPTQ专为GPU优化的版本。
模型编译与加速运行时 (Compilation & Acceleration Runtimes)
将模型计算图编译优化,并利用高性能内核来最大化硬件利用率。
vLLM:当前GPU上首选的推理服务器。其核心是PagedAttention算法,有效解决了KV缓存的内存管理问题,极大地提升了吞吐量。
优点:高吞吐、动态批处理、开源、支持Continuous Batching(连续批处理)。
缺点:主要专注于GPU,对量化的支持在逐步完善。
TensorRT-LLM:NVIDIA官方推出的推理SDK。对NVIDIA GPU进行了极致优化,包括算子融合、内核优化、量化支持等。
优点:延迟最低,性能极致。
缺点:生态相对封闭(需NVIDIA环境),编译过程稍复杂。
ONNX Runtime:微软推出的高性能推理引擎。支持多硬件后端(CPU, GPU, NPU),通过将模型转换为ONNX格式并进行图优化来加速。
优点:硬件支持广泛,工业化程度高。
缺点:对于动态形状的生成式任务,优化效果有时不如vLLM/TRT-LLM。
投机采样 (Speculative Sampling)
一种先进的解码策略。使用一个小而快的“草稿模型”来预先生成多个token,然后让原始大模型”一次性地“验证这些token。如果验证通过,则一步生成多个token,极大提升生成速度。
优点:理论上可提升2-3倍解码速度,与大模型兼容性好。
优点:理论上可提升2-3倍解码速度,与大模型兼容性好。
高并发智能客服API服务 (优化工具: vLLM)
背景:一家企业将微调好的Llama 3客服模型部署为REST API,需要应对来自官网、APP的数千并发用户请求,要求高吞吐、低延迟。
解决方案:使用vLLM部署经过GPTQ量化的模型,利用其PagedAttention和Continuous Batching能力。
①、首先对模型进行GPTQ量化 (使用auto_gptq库)
# 安装: pip install auto-gptq
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from transformers import AutoTokenizer
model_name = "meta-llama/Llama-3-8B"
quantized_model_dir = "./llama3-8b-gptq-4bit"
tokenizer = AutoTokenizer.from_pretrained(model_name)
quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)
model = AutoGPTQForCausalLM.from_pretrained(
model_name,
quantize_config,
trust_remote_code=True
)
# 准备校准数据(示例)
examples = [
tokenizer("A large language model (LLM) is", return_tensors="pt").input_ids[0]
for _ in range(100)
]
# 量化并保存
model.quantize(examples)
model.save_quantized(quantized_model_dir)
tokenizer.save_pretrained(quantized_model_dir)
②、使用vLLM部署量化后的模型
vLLM自动处理动态批处理,当一个请求生成完毕时,可以立即处理下一个请求,GPU利用率极高,完美满足高并发场景。
# 安装: pip install vllm
from vllm import SamplingParams, LLM
from fastapi import FastAPI
import uvicorn
app = FastAPI()
# vLLM引擎是核心,它自动处理批处理、内存管理等
llm = LLM(
model=quantized_model_dir,
quantization="gptq", # 指定量化格式
max_model_len=1024,
gpu_memory_utilization=0.9,
enforce_eager=True, # 对于量化模型,有时需要设置
)
@app.post("/generate")
async def generate_endpoint(request: dict):
prompt = request["prompt"]
sampling_params = SamplingParams(
temperature=0.7,
max_tokens=200,
stop=["\n用户:"]
)
outputs = llm.generate(prompt, sampling_params)
return {"response": outputs[0].outputs[0].text}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
游戏内实时NPC对话 (优化工具: TensorRT-LLM)
背景:一款3A游戏需要在玩家与NPC对话时提供极低延迟(<100ms)的文本生成,确保游戏体验不中断。
解决方案:使用TensorRT-LLM将模型编译为高度优化的引擎,追求极致单次推理延迟。
TensorRT-LLM提供了最低的延迟,确保了游戏的实时性。编译后的引擎是高度优化的,消除了框架开销。
# 1. 构建TRT-LLM引擎(通常在开发环境完成)
# 假设我们已经有一个Hugging Face格式的模型
# 使用TRT-LLM的命令行工具构建引擎
python build.py --model_dir ./my_npc_model \
--dtype float16 \ # 也可用--quant_config指定量化
--use_gpt_attention_plugin \
--use_gemm_plugin \
--output_dir ./trt_engines/npc_engine \
--max_batch_size 8 \
--max_input_len 512 \
--max_output_len 128
# 2. 在游戏推理服务器中部署TRT引擎(C++/Python API)
from tensorrt_llm.runtime import ModelRunner
import numpy as np
# 加载编译好的引擎
runner = ModelRunner.from_dir(
engine_dir='./trt_engines/npc_engine',
lora_dir=None # 可支持LoRA
)
def generate_response(prompt):
sampling_params = {
"temperature": 0.85,
"max_new_tokens": 100,
}
# TRT-LLM的输入输出通常是numpy数组
input_ids = tokenizer.encode(prompt, return_tensors="np").astype(np.int32)
output_ids = runner.generate(input_ids, sampling_params)
response = tokenizer.decode(output_ids[0])
return response
# 游戏服务器循环
while True:
player_input = get_player_input() # 从游戏网络获取输入
npc_response = generate_response(player_input)
send_to_game_client(npc_response) # 将响应发送回游戏客户端
边缘设备上的离线文档摘要 (优化工具: GGUF + llama.cpp)
背景:一家法律科技公司需要在其律师的笔记本电脑上部署一个离线文档摘要模型,处理敏感的客户案宗,无法使用GPU,且要求部署简单。
解决方案:将模型转换为GGUF格式,使用llama.cpp在CPU上运行。
①、将模型转换为GGUF格式 (使用llama.cpp项目中的convert.py)
# 克隆 llama.cpp 项目
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
# 安装依赖并编译项目
make
# 将Hugging Face模型转换为GGUF格式(通常为FP16)
python convert.py ~/models/my_summary_model --outtype f16 --outfile ~/models/my_summary_model.gguf
# 进一步量化(可选,以减少体积和内存)
./quantize ~/models/my_summary_model.gguf ~/models/my_summary_model_q4_0.gguf q4_0
②、在客户笔记本电脑上使用llama.cpp的Python绑定进行推理
# 安装: pip install llama-cpp-python
from llama_cpp import Llama
import requests
# 初始化模型 - 指定GGUF文件路径
llm = Llara(
model_path="./models/my_summary_model_q4_0.gguf", # 量化后模型
n_ctx=2048, # 上下文长度
n_gpu_layers=0, # 0表示纯CPU运行。如果有Mac M芯片或GPU,可设置offload的层数
verbose=False
)
def summarize_document(document_text):
prompt = f"请为以下法律文档生成摘要:\n{document_text}"
# 同步生成
output = llm.create_completion(
prompt,
max_tokens=300,
temperature=0.1, # 摘要任务温度低一些
stop=["。", "\n"]
)
return output['choices'][0]['text']
# 使用案例
with open("legal_doc.txt", "r") as f:
doc_text = f.read()
summary = summarize_document(doc_text)
print(summary)
DeepSpeed大模型训练优化工具
主要用于分布式训练领域的优化,尤其擅长显存管理和计算加速,与之前提到的专注于推理优化的工具(vLLM, TensorRT-LLM)处于不同的赛道。
DeepSpeed 是一个专注于训练阶段、尤其擅长显存管理和分布式计算优化的强大工具。它与 vLLM、TensorRT-LLM 等推理优化工具不是替代关系,而是互补关系。如下情况时,就可以使用DeepSpeed
- 需要训练非常大的模型,而单个GPU或服务器的显存无法满足时,DeepSpeed 几乎是必需品
- 拥有多GPU服务器或集群,并希望最大化利用硬件资源,加速训练过程时
- 使用 Hugging Face Transformers 进行训练,并且希望以相对简单的方式引入强大的分布式优化时
常见的工作流:
- 使用 DeepSpeed(配合其丰富的ZeRO阶段和Offload策略)在有限的硬件上高效地训练或微调大型模型
- 训练完成后,使用推理优化工具(如vLLM, TensorRT-LLM)对训练好的模型进行部署和服务,以实现高吞吐、低延迟的推理。
微调大型语言模型
企业场景:一家企业希望在一台具有多张GPU(每张显存有限,如40GB)的服务器上,微调一个超过700亿参数的大模型(如Falcon 180B或LLaMA 2 70B),用于内部的金融文档分析。
挑战:模型本身远超单卡甚至多卡显存容量,无法使用常规方法加载和训练。
解决方案:使用 DeepSpeed 的 ZeRO-Stage 3 结合 CPU Offload 来分布式地存储和优化模型参数、梯度、优化器状态,从而将显存占用控制在可行范围内。
代码实现(结合 Hugging Face Transformers Trainer):
①、安装 DeepSpeed:
pip install deepspeed
②、创建 DeepSpeed 配置文件 (ds_config.json):
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": "auto",
"zero_optimization": {
"stage": 3, //使用ZeRO第三阶段,将参数、梯度、优化器状态全部分区。
"offload_optimizer": { //将它们卸载到CPU内存,这是节省GPU显存的关键
"device": "cpu",
"pin_memory": true
},
"offload_param": { //将它们卸载到CPU内存,这是节省GPU显存的关键
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": 5e8,
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e6,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9
},
"bf16": { //使用bfloat16混合精度训练,进一步节省显存并加速计算
"enabled": true
},
"gradient_clipping": 1.0,
"wall_clock_breakdown": false
}
③、修改训练代码(基于 Transformers Trainer):
from transformers import TrainingArguments, Trainer, AutoModelForCausalLM, AutoTokenizer
# 1. 加载模型和分词器
model_name = "tiiuae/falcon-180B" # 以 Falcon 180B 为例
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 注意:使用 device_map="auto" 可能无法直接处理如此大的模型,这正是DeepSpeed的用武之地
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
# 2. 设置 TrainingArguments,关键是指定 deepspeed 配置文件路径
training_args = TrainingArguments(
output_dir="./results",
per_device_train_batch_size=1, # 根据显存调整,Micro batch size
gradient_accumulation_steps=8, # 梯度累积步数,模拟更大的全局batch size
learning_rate=2e-5,
num_train_epochs=3,
logging_dir="./logs",
logging_steps=10,
# 最重要的一行:启用 DeepSpeed
deepspeed="./ds_config.json", # 指向你的配置文件
fp16=False, # 如果在配置文件中启用了 bf16,这里就关闭 fp16
dataloader_pin_memory=True,
)
# 3. 创建 Trainer 并开始训练
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset, # 你的训练数据集
data_collator=data_collator,
)
trainer.train()
④、启动训练
使用 deepspeed 启动器来运行你的训练脚本,它会自动处理分布式环境
deepspeed --num_gpus=4 train.py
可以使用 accelerate 配置后启动:
accelerate config # 进行配置
accelerate launch --deepspeed_config ds_config.json train.py
:cite[2]
更多推荐
所有评论(0)