前言


本文针对中文 BERT 文本分类项目的核心入口(Main 函数) 展开逐行拆解,适合刚接触 PyTorch 和 BERT 的零基础开发者。全程避开复杂公式,聚焦「代码作用 + 参数含义 + 实操调参技巧 + 常见坑点」,以酒店评论情感分类(二分类)为例,代码可直接复制运行,稍作修改即可适配你自己的文本分类场景(如新闻分类、评论质检等)。


一、完整 Main 函数代码(可直接复制)

import random
import torch
import torch.nn as nn
import numpy as np
import os

# 导入自定义工具模块(项目核心依赖,需根据自身项目路径调整)
from model_utils.data import get_data_loader  # 数据加载工具
from model_utils.model import myBertModel    # 自定义BERT分类模型
from model_utils.train import train_val      # 训练&验证核心函数

# 固定随机种子:保证实验结果可复现(新手必学)
def seed_everything(seed):
    torch.manual_seed(seed)                # 固定PyTorch CPU随机种子
    torch.cuda.manual_seed(seed)           # 固定PyTorch GPU随机种子(单卡)
    torch.cuda.manual_seed_all(seed)       # 固定PyTorch GPU随机种子(多卡)
    torch.backends.cudnn.benchmark = False # 关闭CUDA自动优化(避免结果波动)
    torch.backends.cudnn.deterministic = True # 强制CUDA按固定逻辑计算
    random.seed(seed)                      # 固定Python原生随机种子
    np.random.seed(seed)                   # 固定numpy随机种子
    os.environ['PYTHONHASHSEED'] = str(seed) # 固定Python哈希值(避免随机)

# 调用种子函数,种子值可自定义(0/123/456均可)
seed_everything(0)

# ===================== 1. 配置核心训练超参数(新手重点调参区) =====================
lr = 0.0001                # 学习率(BERT微调关键参数)
batchsize = 16             # 批次大小(根据GPU内存灵活调整)
loss_fn = nn.CrossEntropyLoss() # 损失函数(二分类/多分类通用)
bert_path = "bert-base-chinese" # 预训练BERT路径(中文通用版,自动下载)
num_class = 2              # 分类类别数(二分类设2,多分类按实际改)
data_path = "jiudian.txt"  # 数据集路径(酒店评论数据,需自己准备)
max_acc = 0.6              # 模型保存阈值(超过该准确率才保存)
# 自动选择训练设备(有GPU用GPU,无GPU用CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"

# ===================== 2. 初始化核心组件 =====================
# 初始化BERT模型,并移到指定设备(GPU/CPU)
model = myBertModel(bert_path, num_class, device).to(device)
# 初始化优化器(BERT微调标配AdamW,防过拟合)
optimizer = torch.optim.AdamW(
    model.parameters(),  # 要优化的模型参数(固定写法)
    lr=lr,               # 学习率
    weight_decay=1e-5    # 权重衰减(轻微惩罚大参数,防过拟合)
)
# 初始化数据加载器(自动处理数据、分批喂给模型)
train_loader, val_loader = get_data_loader(data_path, batchsize)

# ===================== 3. 配置训练辅助参数 =====================
epochs = 5                 # 总训练轮数(把所有数据学5遍)
save_path = "model_save/best_model.pth" # 最优模型保存路径
# 学习率调度器(动态调整学习率,避免模型学不动)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=20, eta_min=1e-9
)
val_epoch = 1              # 验证间隔:每训练1轮,验证1次模型效果

# ===================== 4. 打包参数(简化训练函数传参) =====================
train_para = {
    "model": model,          # 待训练模型
    "train_loader": train_loader, # 训练集数据加载器
    "val_loader": val_loader,     # 验证集数据加载器
    "scheduler": scheduler,       # 学习率调度器
    "optimizer": optimizer,       # 优化器
    "loss_fn": loss_fn,           # 损失函数
    "epochs": epochs,             # 总训练轮数
    "device": device,             # 训练设备
    "save_path": save_path,       # 模型保存路径
    "max_acc": max_acc,           # 保存模型的准确率阈值
    "val_epoch": val_epoch        # 验证间隔
}

# ===================== 5. 启动训练&验证(一键启动) =====================
train_val(train_para)

模块 2:固定随机种子(避坑关键)

核心作用

AI 训练中大量用到随机操作(参数初始化、数据打乱等),不固定种子会导致「每次跑代码结果都不一样」(比如这次准确率 80%,下次 75%),固定种子后实验结果可复现。

新手提示

❗️ 注意:种子值(如 0)可随便改,但一旦确定就不要动,否则结果无法复现;多 GPU 训练时需额外配置torch.distributed,新手先单卡训练即可。

模块 3:配置核心超参数(重点调参区)

参数 取值 通俗解释 调参技巧
lr=0.0001 1e-4 模型更新参数的 “步长” BERT 微调优先用 2e-5/1e-4/5e-5,步长太小学得慢,太大易震荡
batchsize=16 16 每次喂给模型 16 个样本 GPU 内存小设 8/16,内存大设 32/64,超了会报「CUDA out of memory」
loss_fn=CrossEntropyLoss() 交叉熵损失 模型的 “纠错工具” 二分类 / 多分类通用,不用改;回归任务换MSELoss
bert_path="bert-base-chinese" 中文 BERT 预训练模型路径 需联网,首次运行自动下载;也可下载到本地,填本地路径
num_class=2 2 分类类别数 情感分类(正 / 负)设 2,新闻分类(体育 / 财经 / 娱乐)设 3,依实际改
data_path="jiudian.txt" 酒店评论数据 数据集路径 数据格式要求:每行「标签,文本」(如好评,这家酒店超棒
device="cuda/cpu" 自动选择 训练设备 有 GPU 一定要用,训练速度比 CPU 快 10 倍以上

模块 4:初始化核心组件(训练的 “核心动力”)

4.1 初始化 BERT 模型
model = myBertModel(bert_path, num_class, device).to(device)
  • 作用:加载预训练 BERT,拼接分类头(适配 2 分类任务),并把模型移到 GPU/CPU;
  • 新手提示:myBertModel是自定义的,核心逻辑是「BERT + 全连接层」,不用自己写,直接用 HuggingFace 的BertForSequenceClassification也可。
4.2 初始化优化器(AdamW)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
  • 核心作用:根据模型的 “错误程度” 更新参数,让模型越来越准;
  • 为什么选 AdamW?:BERT 微调标配,比普通 Adam 多了「权重衰减」,能防过拟合;
  • 新手提示:weight_decay(权重衰减)取值 1e-5~1e-4,小数据集用 1e-5 更稳。
4.3 初始化数据加载器
train_loader, val_loader = get_data_loader(data_path, batchsize)
  • 核心作用:把jiudian.txt的原始文本转换成模型能看懂的数字(token_id),并按批次打包;
  • 内部逻辑(不用写代码,懂流程即可):
    1. 读取文本 + 标签 → 2. 划分训练集 / 验证集(8:2) → 3. 文本编码 → 4. 分批打包;
  • 新手提示:数据格式必须是「标签,文本」,不要有空行 / 乱码,否则加载失败。

模块 5:配置辅助参数 + 打包传参

5.1 辅助参数说明
参数 作用
epochs=5 总训练轮数,新手先设 5~10 轮,看效果再调整
save_path="model_save/best_model.pth" 保存最优模型,后续可加载模型做预测
scheduler 动态调整学习率,避免模型后期 “学不动”
val_epoch=1 每训练 1 轮验证 1 次,及时监控效果
5.2 打包参数字典

把所有参数打包成train_para字典,避免给train_val传 10 + 个参数,代码更简洁,新手后期加参数(如dropout)直接往字典里加即可。

模块 6:启动训练(一键运行)

train_val(train_para)
  • 核心作用:调用自定义的train_val函数,自动执行「训练→验证→保存模型」全流程;
  • 新手提示:train_val内部会遍历train_loader喂数据、计算损失、更新参数,验证时计算准确率,超过max_acc就保存模型。

三、实操提示(新手避坑)

1. 环境配置

# 安装核心依赖,一行命令搞定
pip install torch transformers pandas numpy
  • transformers:HuggingFace 库,加载 BERT 的核心;
  • pandas/numpy:处理数据用。

2. 数据格式要求(必须遵守)

jiudian.txt示例:

好评,这家酒店服务超棒,环境也很好
差评,房间小还吵,性价比极低
好评,早餐种类多,离地铁口近

❗️ 注意:标签要和num_class对应(二分类设 2 个标签,如好评 / 差评),文本不要有特殊符号(如¥、&)。

3. 显存溢出解决

  • 降低batchsize(从 16→8→4);
  • 模型移到 CPU(device="cpu"),但训练速度慢;
  • 用更小的预训练模型(如bert-base-chinese-mini)。

四、常见问题 & 解决方法

问题 原因 解决
训练结果不可复现 没固定随机种子 / 开了 CUDA 自动优化 严格执行seed_everything函数
准确率一直低于max_acc 学习率太大 / 数据量太少 / 模型没适配 lr降到 2e-5,增加数据量,检查模型分类头是否正确
报错「找不到 model_utils」 自定义模块路径不对 确认model_utils和 Main 函数在同一目录,或改导入路径
模型保存失败 没有model_save文件夹 先执行os.makedirs("model_save", exist_ok=True)创建文件夹

五、总结

本文拆解的 Main 函数是 BERT 文本分类的「通用模板」,核心逻辑可总结为:

  1. 准备工具(导入依赖)→ 2. 定规矩(固定种子、设参数)→ 3. 备物料(初始化模型 / 优化器 / 数据)→ 4. 打包启动训练;新手只需改「数据集路径、类别数、学习率 / 批次大小」,就能适配自己的文本分类任务。

如果运行中遇到问题,优先检查「数据格式、GPU 内存、参数取值」,这是新手最容易踩坑的地方。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐