LoRA Fine-Tuning of OpenAI Whisper: A Hands-On Guide to Chinese Speech Recognition with PEFT

1. Background and Principles
  • Whisper: OpenAI's open-source speech recognition model; it supports multilingual recognition and is built on an encoder-decoder Transformer.
  • LoRA (Low-Rank Adaptation): injects low-rank decomposition matrices $A \in \mathbb{R}^{r \times k}$ and $B \in \mathbb{R}^{d \times r}$ ($r \ll d$) so that only a small number of parameters are fine-tuned: $$ W' = W + BA $$ where $W$ is the original weight and $r$ is the rank (typically $r=8$). A minimal sketch of this update follows the list below.
  • PEFT: Hugging Face's parameter-efficient fine-tuning library, which supports LoRA and other lightweight adaptation techniques.
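
A minimal, illustrative sketch of the LoRA update in plain PyTorch (not PEFT's actual implementation; the alpha/r scaling mirrors how lora_alpha is applied in the configuration later in this guide):

import torch

d, k, r, alpha = 768, 768, 8, 32   # whisper-small attention dims, rank, scaling factor
W = torch.randn(d, k)              # frozen pretrained weight
A = torch.randn(r, k) * 0.01       # trainable low-rank factor
B = torch.zeros(d, r)              # trainable; zero init keeps W' = W at the start
x = torch.randn(k)

# Effective forward pass: W'x = Wx + (alpha / r) * B(Ax)
y = W @ x + (alpha / r) * (B @ (A @ x))
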
2. Environment Setup
# Install dependencies (accelerate is needed by the Hugging Face Trainer)
!pip install transformers datasets peft accelerate soundfile librosa evaluate jiwer

3. Data Preparation
  • Dataset: a Chinese speech dataset such as AISHELL-1 or the Chinese subset of Common Voice
  • Preprocessing:
    from datasets import load_dataset, Audio
    
    # Load the dataset (example: Common Voice, first 10% of the train split)
    dataset = load_dataset("mozilla-foundation/common_voice_11_0", "zh-CN", split="train[:10%]")
    
    # Resample to 16 kHz on the fly (Whisper's required input rate); the Audio
    # feature decodes and resamples when an example is accessed, so no manual
    # librosa pass or destructive rewrite of the audio files is needed
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    
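A quick sanity check on a single example (field names follow the Common Voice schema):

    sample = dataset[0]
    print(sample["audio"]["sampling_rate"])  # 16000 after the cast
    print(sample["audio"]["array"].shape)    # 1-D float waveform
    print(sample["sentence"])                # reference transcript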

4. Model and Processor Configuration
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Load the pretrained model (whisper-small as an example)
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="zh", task="transcribe")

# Configure LoRA parameters
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=8,                  # rank
    lora_alpha=32,        # scaling factor
    target_modules=["q_proj", "v_proj"],  # target modules (Whisper attention projections)
    lora_dropout=0.1,
    bias="none"
)
model = get_peft_model(model, config)
model.print_trainable_parameters()  # prints the trainable parameter count, roughly 0.5% of the total
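
Before training, a few config tweaks from the standard Whisper fine-tuning recipe are worth applying (forced_decoder_ids is deprecated in newer transformers releases, so adjust to your version):

# Don't force a fixed language/task token sequence during training,
# and disable the KV cache (it conflicts with gradient checkpointing)
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.config.use_cache = False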

5. Training Pipeline
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

# Preprocessing: extract log-mel features from the audio, tokenize the transcript
def prepare_dataset(batch):
    audio = batch["audio"]
    input_features = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    labels = processor.tokenizer(batch["sentence"]).input_ids
    return {"input_features": input_features, "labels": labels}

dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names)
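
Seq2SeqTrainer also needs a collator that pads the variable-length features and labels separately; the version below is adapted from the standard Hugging Face Whisper fine-tuning recipe and is referenced as data_collator in the Trainer call:

from dataclasses import dataclass
from typing import Any, Dict, List
import torch

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Pad the log-mel features into a single batch tensor
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Pad the label ids; replace padding with -100 so the loss ignores it
        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Cut the start token if tokenization already prepended it; the model
        # re-adds it when shifting the labels right
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)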

# Training arguments (per-epoch evaluation would also require an eval_dataset
# and a compute_metrics function, so it is omitted here; see section 6)
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-lora-zh",
    per_device_train_batch_size=8,
    learning_rate=1e-4,
    num_train_epochs=3,
    fp16=True,
    save_strategy="epoch",
    remove_unused_columns=False,  # required: the Trainer cannot infer a PEFT model's input names
    label_names=["labels"],
)

# Launch training
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator,
    tokenizer=processor.tokenizer,
)
trainer.train()

6. Inference and Evaluation
import evaluate
import soundfile as sf
import torch

wer_metric = evaluate.load("wer")

# Inference function
def transcribe(audio_path):
    audio, _ = sf.read(audio_path)
    input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
    # Force Chinese transcription at decode time
    forced_ids = processor.get_decoder_prompt_ids(language="zh", task="transcribe")
    with torch.no_grad():
        predicted_ids = model.generate(input_features.to(model.device), forced_decoder_ids=forced_ids)
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

# Compute the word error rate (WER)
references = ["这是测试句子"]
predictions = [transcribe("test_audio.wav")]
wer = wer_metric.compute(references=references, predictions=predictions)
print(f"WER: {wer:.2f}")  # example output: WER: 0.15

7. Optimization Tips
  • Data augmentation: add noise and time-shift transforms to improve robustness (see the sketch after this list)
  • Rank selection: tune $r$ experimentally ($r \in [4, 16]$) to balance quality against parameter count
  • Mixed-precision training: enable fp16=True to speed up training
  • Model fusion: combine with an N-gram language model for post-hoc error correction
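
A minimal sketch of the noise and time-shift augmentations (augment is a hypothetical helper; apply it to the raw waveform before feature extraction):

import numpy as np

def augment(audio, noise_std=0.005, max_shift=1600):
    # Additive Gaussian noise
    noisy = audio + np.random.randn(len(audio)) * noise_std
    # Random circular time shift of up to max_shift samples (0.1 s at 16 kHz)
    shift = np.random.randint(-max_shift, max_shift + 1)
    return np.roll(noisy, shift)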

Key point: after LoRA fine-tuning, only the adapter weights (roughly 10 MB) need to be saved; the original Whisper weights remain unchanged, which greatly reduces deployment cost. A save/reload sketch follows.
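
Saving and reloading the adapter (the directory name is illustrative):

# Save only the LoRA adapter weights
model.save_pretrained("./whisper-lora-zh-adapter")

# Later: attach the saved adapter to the unchanged base model
from peft import PeftModel
base = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
model = PeftModel.from_pretrained(base, "./whisper-lora-zh-adapter")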
