引言1:什么是Bart?

一、BART 是什么?

        BART 是 Bidirectional and Auto-Regressive Transformers 的缩写,是 Facebook 提出的一种完整预训练编码器 - 解码器模型,专门为文本生成任务设计。

可以把它理解成:

  • 它的 ** 编码器(Encoder)** 像一个 “理解器”,能双向读取输入文本,把输入(比如我们的 CT 编码序列)编码成有意义的向量;
  • 它的 ** 解码器(Decoder)** 像一个 “生成器”,从左到右逐词生成输出(比如诊断结论);
  • 它是在海量文本上预训练好的,已经学会了丰富的语言知识,我们只需要在医疗数据上微调,就能快速适配我们的任务。

二、BART 的核心特点(结合我们的项目)

  1. 双向编码 + 自回归生成

    • 编码器是双向的,能同时看到 CT 编码序列的前后信息,更好地理解编码组合的含义;
    • 解码器是自回归的,生成诊断结论时,会逐词依赖前面已经生成的内容,保证诊断的连贯性。
  2. 去噪自编码器预训练

    • BART 的预训练方式是 “破坏文本,再恢复文本”(比如随机删除、替换、打乱句子),这让它天生擅长 “补全、修正、生成” 文本,非常适合我们从 CT 编码生成诊断结论的任务。
  3. 灵活适配各种生成任务

    • 既可以做机器翻译,也可以做摘要、对话、条件生成,我们的 “CT 编码→诊断结论” 本质上就是一种特殊的 “翻译” 任务,BART 天然适配。

引言2:为什么用Bart,而不是原生 Transformer 或 GPT?

1. 为什么不用原生 Transformer?

        原生 Transformer 是 “编码器 - 解码器” 架构的开山之作,但在我们的医疗文本生成项目中,它有两个明显短板:

  1. 缺乏预训练权重,从零训练成本太高原生 Transformer 只是一个模型结构,没有公开的大规模预训练权重。我们的医疗数据样本量有限,从零训练一个 Transformer 模型,不仅收敛慢、效果差,还容易过拟合,完全无法满足 “CT 编码→诊断结论” 这种需要强语义理解的任务。而 BART 是 Facebook 在大规模语料上预训练好的模型,我们可以直接基于它微调,大幅降低训练成本,同时利用预训练学到的通用语义知识。

  2. BART 在生成任务上做了专门优化BART 在 Transformer 的基础上,做了很多针对生成任务的改进,比如:

    • 采用了去噪自编码器的预训练方式(随机掩码、删除、置换文本,让模型恢复原始文本),更擅长 “补全 / 生成” 任务;
    • 解码器部分做了优化,在生成长文本时更稳定,而原生 Transformer 在生成医疗诊断这种长文本时,容易出现语义断裂、重复的问题。

        所以,原生 Transformer 更适合作为 “理论基础”,是一个基础架构,它本身只是一个 “空壳”,没有在大规模数据上进行训练,就像汽车的底盘和发动机图纸,还不是一辆能开的车。而 BART 才是我们落地生成任务的更好选择,是基于 Transformer 架构,在海量文本数据上预训练好的成品模型,就像已经在各种路况上跑过的量产车,已经学会了语言规律,我们只需要微调就能适配具体任务。


2. 为什么不用 GPT?

        GPT 是 “解码器 - only” 架构,在文本生成上确实很强,但在我们的项目中,它有两个关键不匹配的地方:

  1. 任务类型不匹配:我们是 “条件生成”,不是 “自由生成”GPT 擅长的是 “自由文本生成”(比如续写故事、写作文),它的输入是一段文本,输出是后续文本。而我们的项目是 “条件生成”:输入是CT 编码序列(结构化、无自然语言语义),输出是诊断结论(自然语言)。

    • GPT 的解码器 - only 架构,对 “结构化输入” 的编码能力很弱,很难把数字编码的语义有效传递给解码器;
    • BART 是 “编码器 - 解码器” 架构,编码器专门负责编码输入(CT 编码),解码器专门负责生成输出(诊断结论),这种分工更适合我们的 “编码→生成” 任务。
  2. 医疗场景的可控性要求更高医疗诊断生成对 “可控性” 要求极高,我们需要确保生成的诊断结论严格符合输入的 CT 编码信息,不能自由发挥。

    • GPT 的生成过程更 “自由”,容易出现和输入不匹配的 “幻觉”(比如编造不存在的病症);
    • BART 的编码器 - 解码器架构,能更好地把输入信息 “锚定” 在生成过程中,生成的诊断结论更贴合输入的 CT 编码,可控性更强。

3. 总结:BART 是我们项目的 “最优解”

综合来看,我们选择 BART 的核心原因是:

  • 有大规模预训练权重,能快速适配医疗场景,降低训练成本;
  • 编码器 - 解码器架构,完美匹配 “CT 编码→诊断结论” 的条件生成任务;
  • 生成可控性更强,更适合医疗这种对准确性要求极高的场景。

这也是我在项目选型阶段,对比了多种模型架构后,最终选择 BART 的关键考量。

一、项目介绍

        这个项目是一个典型的医学文本生成任务,具体来说是CT 报告诊断生成任务,属于医疗 AI 的重要应用方向。

项目核心信息

  • 任务目标:根据 CT 报告中的 “CT 表现”(客观影像描述),自动生成 “印象”(医生的诊断结论和建议)。
  • 数据来源:官方提供的脱敏数据集,所有可识别患者身份的隐私信息(如姓名、身份证号等)都已被移除,确保数据合规和患者隐私。
  • 输入:脱敏后的数字编码序列,它是将 CT 报告中的 “CT 表现” 部分(如 “子宫增大,右后方及右下腹部脂肪间隙模糊,可见大片状高低混杂密度影”)通过词表映射或实体编码等方式,转化为模型可直接处理的数字序列。
  • 输出:符合医学规范的诊断结论,即 “印象” 部分(如 “剖宫产术后所见;右下腹部、右下腹壁大片状及类圆形混杂密度影,出血?请密切结合临床及实验室检查,必要时进一步检查”)。

任务的技术本质

        这是一个典型的  条件文本生成(Conditional Text Generation) 任务,非常适合使用 Encoder-Decoder 架构(如 BART模型)来解决:

  1. Encoder:对输入的脱敏数字序列(代表 “CT 表现”)进行编码,理解其中的医学信息和实体关系。
  2. Decoder:基于编码后的信息,生成符合医学逻辑和临床规范的诊断文本(“印象”)。

0,14 108 28 30 15 13 294 29 20 18 23 21 25 32 16 14 39 27 14 47 46 69 70 11 24 42 26 37 61 24 10 79 46 62 19 13 31 95 19 28 20 18 10 22 12 38 41 17 23 21 36 53 25 10,22 12 38 41 17 81 10......

todo:怎么查词表转换为汉字

二、全流程详细梳理               

该项目是一套端到端的医疗文本生成系统,核心目标是基于脱敏后的 CT 报告数字编码序列,自动生成符合临床规范的医生诊断结论(印象部分)。整体流程遵循 “数据准备→模型预训练→模型微调→推理部署” 的标准 NLP 项目架构,以下是分模块、有条理的详细梳理:

一、项目核心架构与依赖

1. 核心技术栈

  • 框架:PyTorch(模型训练与推理)、Hugging Face Transformers(BART 模型与分词器)
  • 数据处理:Pandas(数据拆分)、CSV(数据存储)、NumPy(数组操作)
  • 辅助工具:Logging(日志记录)、tqdm(进度可视化)、argparse(参数配置)
  • 依赖版本:Python 3.10+、torch 2.1.0+cu121、transformers 4.36.2(详见 requirements.txt)

2. 项目文件结构与核心作用

文件名称 核心功能 所属模块
config.py 全局参数配置(数据路径、模型超参、训练设置等),统一管理项目变量 全局配置
pre_data.py 数据加载、预处理(MLM 掩码训练数据构建)、预训练数据集类定义 数据准备(预训练)
process_data.py 原始训练数据拆分(训练集 / 验证集),生成可直接训练的 CSV 文件 数据准备(微调)
pro_vocab.py 自定义词表构建(基于训练数据统计)、模型词表适配(更新词表大小) 词表工程
change_vocab.py 词表修改辅助工具(补充数字编码对应的 token) 词表工程
models.py 模型定义(预训练 MLM 模型、微调生成模型),封装 BART 的 forward 与 generate 逻辑 模型核心
pretrain.py 模型预训练流程(MLM 任务),优化模型对医疗文本的语义理解能力 模型训练
finetine.py 模型微调流程(条件生成任务),训练 CT 编码→诊断结论的映射关系 模型训练
inference.py 测试集推理,加载训练好的模型生成诊断结果并保存为 CSV 推理部署
data.py 微调数据集类(TranslationDataset)、数据加载器(create_dataloaders) 数据加载
score.py 评估指标实现(交叉熵损失 CE、文本生成评估 CIDEr-D) 评估模块
utils.py 工具函数(设备设置、随机种子、优化器构建、编码转文字 array2str) 辅助工具
vocab.txt 项目核心词表(包含特殊 token、多语言字符、中文医疗术语),编码映射的核心依据 词表文件
requirements.txt 项目依赖包清单,确保环境一致性 环境配置

二、项目全流程步骤拆解(按执行顺序)

阶段 1:数据准备与预处理(基础准备阶段)

核心目标:将原始脱敏数据转化为模型可处理的格式,构建适配任务的词表与数据集

步骤 1.1:原始数据拆分(process_data.py)
import pandas as pd  #处理表格数据

#处理数据的
pre_train_file= "data/train.csv"

# 读取原始CSV数据:无表头,手动指定列名适配医疗数据结构
# id:样本ID,input:CT报告数字编码序列,tgt:诊断结论
train_df = pd.read_csv(pre_train_file,header=None,names=["id","input","tgt"]) #读入数据

# 打印前5行数据,快速验证数据读取是否正确
print("原始数据前5行预览:")
print(train_df.head())

# 随机采样90%数据作为训练集,random_state=0保证拆分结果可复现
train_data = train_df.sample(frac=0.9, random_state=0, axis=0)   #采样0.9的比例

# 筛选剩余10%数据作为验证集(取反筛选:不在训练集里的行)
val_data = train_df[~train_df.index.isin(train_data.index)]       #只需要知道是干啥的,可直接过来用

# 保存训练集:无索引、无表头,适配模型数据集类读取
train_data.to_csv("data/pro_train_data.csv", index=False,header=False)

# 保存验证集:参数同训练集,保证格式统一
val_data.to_csv("data/pro_val_data.csv", index=False,header=False)

print("数据拆分完成!训练集/验证集已保存至data文件夹")
  • 输入:原始训练数据 train.csv(格式:ID, CT 编码序列,诊断结论)
  • 处理逻辑:
    1. 读取 CSV 数据,按 9:1 比例随机拆分训练集(pro_train_data.csv)和验证集(pro_val_data.csv)
    2. 去除索引列,保存为模型可直接读取的无表头 CSV 文件,适配数据集类读取
  • 输出:训练集 pro_train_data/ 验证集 pro_val_data文件,用于模型微调
# 读取原始CSV数据:无表头,手动指定列名适配医疗数据结构
# id:样本ID,input:CT报告数字编码序列,tgt:诊断结论
train_df = pd.read_csv(pre_train_file,header=None,names=["id","input","tgt"]) #读入数据
  • 核心函数pd.read_csv()是 Pandas 读取 CSV 文件的核心方法;
  • 关键参数解释
    • header=None:表示你的原始train.csv文件没有表头(第一行就是数据,不是列名如 “ID、输入、目标”);
    • names=["id","input","tgt"]:手动给数据列命名,适配你的数据结构:
      • id:样本唯一标识(如 1、2、3...);
      • input:CT 报告的数字编码序列(如 “0 14 108 28...”);
      • tgt:目标诊断结论(如 “剖宫产术后 子宫出血?”);
  • 执行结果train_df会变成一个 Pandas 的 DataFrame(表格),包含上述 3 列数据,可直观查看和操作。

输出示例(你的场景)

val_data = train_df[~train_df.index.isin(train_data.index)]
  • 核心逻辑:筛选出 “不在训练集里的所有数据” 作为验证集(剩余 10%),逐部分拆解:
    1. train_data.index:获取训练集样本的行索引(比如 0、2、3、5...);
    2. train_df.index.isin(train_data.index):判断原始数据的每一行是否属于训练集(返回 True/False);
    3. ~:取反符号(把 True 变 False,False 变 True),最终筛选出 “不属于训练集的行”;
  • 作用:验证集用于训练过程中监控模型效果(比如计算损失、评估生成质量),不参与模型参数更新,避免模型过拟合。
步骤 1.2:自定义词表构建(pro_vocab.py)
  • 核心目的:基于你的 CT 报告训练数据,统计高频 token(数字编码、医疗术语),构建适配任务的自定义词表vocab.txt,并同步更新 BART 模型的词表大小(embedding 层 + 配置文件),让模型能识别你数据中的专属 token(比如脱敏后的数字编码)
  • 处理逻辑:
    1. 加载训练数据,统计所有出现的 token(含数字编码、医疗术语)
    2. 按出现频率筛选有效 token,插入特殊 token([PAD]、[UNK]、[CLS] 等)
    3. 覆盖原始 BART 词表,生成新的vocab.txt
    4. 调整模型 embedding 层维度(resize_token_embeddings),适配新词表大小
  • 输出:适配任务的自定义词表、更新后的模型权重文件(pytorch_model.bin)
# 处理词表vocab.txt:基于CT报告训练数据构建自定义词表,并更新BART模型适配新词表
import sys
import torch
from collections import Counter  # 计数工具,用于统计token出现频率
from transformers import BertTokenizer  # # Bert分词器(用于词表操作,注释部分用到)
from transformers import BartConfig  # BART模型配置类(更新词表大小)
from transformers import BartForConditionalGeneration  # BART生成模型(核心)
from model_utils.config import parse_args  # 导入参数解析函数(从config.py读取参数)

# 1. 读取项目配置参数(从config.py)
args = parse_args()  

# 2. 加载训练数据,提取所有token(CT编码+诊断术语)
def load_data(path):
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.readlines()  # 读取所有行
    datas = []
    for line in lines:
        line = line.strip().split(",")  # 按逗号拆分:ID,input,tgt
        if len(line) == 3:
            # 训练集:拆分input(CT编码)和tgt(诊断结论),合并token
            text = line[1].split(" ")   # 按空格拆分CT编码序列
            target = line[2].split(" ") # 按空格拆分诊断结论
            datas.append(text + target) # 合并输入+目标的token
        else:
            # 测试集:仅拆分input
            text = line[1].split(" ")
            datas.append(text)
    return datas

# 加载train.csv数据,获取所有token
train_data = load_data('./data/train.csv')

# 3. 统计每个token的出现次数
token2count = Counter()  # 计数字典:{token: 出现次数}
for i in train_data:
    token2count.update(i)  # 批量更新token计数

# 4. 构建基础词表(保留所有出现过的token,按字符排序)
tail = []
min_freq = 0  # 最低出现频率(0=保留所有token)
for token, count in token2count.items():
    if count >= min_freq:
        tail.append(token)
tail.sort()  # 排序保证词表顺序固定
vocab = tail

# 5. 插入模型必需的特殊token(指定固定位置,保证编码ID一致)
vocab.insert(0, "[PAD]")    # 0:填充符
vocab.insert(100, "[UNK]")  # 100:未知词标记
vocab.insert(101, "[CLS]")  # 101:句子起始标记
vocab.insert(102, "[SEP]")  # 102:句子分隔标记
vocab.insert(103, "[MASK]") # 103:掩码标记(MLM预训练用)
vocab.insert(104, "[EOS]")  # 104:句子结束标记(生成任务用)

# 注释部分是另一种词表构建方式(基于原始 BART 词表扩展);直接加字
# tokenizer = BertTokenizer.from_pretrained(args.pre_model_path)
# vocabs = tokenizer.get_vocab()   #获取模型词表

# new_vocabs = list(vocabs.keys())
# print(len(vocabs))
# count = 0
# for v in vocab:         #O(mn)复杂度
#     if v not in vocabs:
#         count += 1
#         new_vocabs.append(v)
# print(len(new_vocabs))

# 6. 保存自定义词表到模型路径(覆盖原始vocab.txt)
new_vocabs = vocab
vocab_path = args.pre_model_path + '/vocab.txt'
with open(vocab_path, 'w', encoding='utf-8') as f:
    for token in new_vocabs:
        f.write(f"{token}\n")  # 一行一个token,符合词表格式
print(f"自定义词表已保存:{vocab_path},词表大小:{len(new_vocabs)}")

# 7. 更新BART模型,适配新词表
# 7.1 加载模型
model = BartForConditionalGeneration.from_pretrained(args.pre_model_path)
# 7.2 更新embedding层维度(必须和新词表大小一致)
model.resize_token_embeddings(len(new_vocabs))
# 7.3 保存更新后的模型权重
model_weight_path = args.pre_model_path + '/pytorch_model.bin'
torch.save(model.state_dict(), model_weight_path)
# 7.4 更新模型配置文件中的vocab_size
bartconfig = BartConfig.from_pretrained(args.pre_model_path)
bartconfig.vocab_size = len(new_vocabs)
bartconfig.save_pretrained(args.pre_model_path)

print("模型权重和配置已更新,适配自定义词表完成!")

def load_data(path):

  • 核心功能:读取train.csv数据,提取所有 token(数字编码 + 诊断术语);

  • 场景适配
    • 你的train.csv每行格式是ID,CT编码序列,诊断结论(如1,0 14 108 28,剖宫产术后 出血?);
    • line.strip().split(","):拆分出["1", "0 14 108 28", "剖宫产术后 出血?"]
    • line[1].split(" "):把 CT 编码序列拆分为["0","14","108","28"]
    • line[2].split(" "):把诊断结论拆分为["剖宫产术后","出血?"]
    • 最终datas是一个二维列表,包含所有样本的 token(编码 + 术语)。
token2count = Counter() #计数工具 哈希表,计数字典:{token: 出现次数}
for i in train_data:
        token2count.update(i) #不需要知道原理
  • 核心作用:统计每个 token 的出现次数;
  • Counter原理token2count是一个字典结构,键是 token(如 "0"、"14"、"剖宫产术后"),值是该 token 出现的次数;
  • 示例token2count["0"] = 100表示编码 0 在训练集中出现了 100 次,token2count["宫"] = 80表示 “宫” 出现了 80 次。
tail = []
min_freq = 0  # 最低出现频率(0=保留所有token)
for token, count in token2count.items():
    if count >= min_freq:
        tail.append(token)
tail.sort()  # 排序保证词表顺序固定
vocab = tail
  • 核心逻辑:筛选并整理 token,构建基础词表;
  • 关键细节
    • min_freq=0:没有过滤低频 token(所有出现过的 token 都保留),如果想过滤低频(比如只保留出现≥2 次的),可把min_freq改为 2;
    • tail.sort():对 token 排序,保证词表顺序固定(避免每次运行词表顺序不同);
    • 此时vocab包含训练集中所有出现过的 token(数字编码、医疗术语)。
# 插入特殊token到词表指定位置(模型必需的控制符)
vocab.insert(0,"[PAD]")       # 位置0:填充符(对应你之前的pad_id=0)
vocab.insert(100,"[UNK]")     # 位置100:未知词标记(遇到未见过的token用这个)
vocab.insert(101,"[CLS]")     # 位置101:句子起始标记
vocab.insert(102,"[SEP]")     # 位置102:句子分隔标记
vocab.insert(103,"[MASK]")    # 位置103:掩码标记(预训练MLM任务用)
vocab.insert(104,"[EOS]")     # 位置104:句子结束标记(生成任务用)
  • 核心作用:添加模型必需的特殊 token,且指定固定位置(保证编码 ID 固定);
  • 场景适配
    • [PAD](0):用于序列补齐(比如编码序列长度不足时补 0);
    • [UNK](100):遇到词表外的 token 时用这个替代;
    • [EOS](104):对应你之前data.py中的eos_id=104(结束标记)。
# 注释部分是另一种词表构建方式(基于原始 BART 词表扩展);直接加字
# tokenizer = BertTokenizer.from_pretrained(args.pre_model_path)
# vocabs = tokenizer.get_vocab()   #获取模型词表

# new_vocabs = list(vocabs.keys())
# print(len(vocabs))
# count = 0
# for v in vocab:         #O(mn)复杂度
#     if v not in vocabs:
#         count += 1
#         new_vocabs.append(v)
# print(len(new_vocabs))
  • 注释部分逻辑:这是另一种词表构建方式(基于原始 BART 词表扩展);
  • 作用解释
    1. 加载原始 BART 模型的词表;
    2. 把你的训练数据中独有的 token(比如脱敏数字编码)添加到原始词表末尾;
    3. 优点:保留 BART 预训练的词表,仅扩展专属 token;缺点:词表会很大;
    • 你当前代码没有用这种方式,而是直接用自己统计的 token 构建新词表。
# 重新制作词表
new_vocabs = vocab  # 最终词表=统计的token+插入的特殊token
with open(args.pre_model_path+'/vocab.txt', 'w', encoding='utf-8') as f:
    for v in new_vocabs:
        f.write(f"{v}\n")    #保存词表
  • 核心作用:将自定义词表保存为vocab.txt覆盖原始模型的词表);
  • 路径说明args.pre_model_path+'/vocab.txt'是你的模型路径下的词表文件(比如./bart-base/vocab.txt);
  • 格式:一行一个 token(和你之前的词表格式一致)
# 词表变了,vocab_size也要跟着变 config,json中可看到变化"vocab_size"
model = BartForConditionalGeneration.from_pretrained(args.pre_model_path)      #模型
model.resize_token_embeddings(len(new_vocabs))  # 更新模型embedding层维度
  • 核心问题:BART 模型的 embedding 层维度和词表大小绑定,词表变了必须更新;
  • resize_token_embeddings作用
    • 原始 BART 模型的 embedding 层维度是原始词表大小(比如 30522);
    • 调用该方法后,embedding 层维度变为len(new_vocabs)(你的自定义词表大小);
    • 新增 token 的 embedding 会被初始化为随机值,后续训练中学习。
state_dict = model.state_dict()  # 提取模型权重
torch.save(state_dict, args.pre_model_path+'/pytorch_model.bin')  # 保存更新后的权重

作用:将更新了 embedding 层的模型权重保存到模型路径,覆盖原始权重文件。

bartconfig = BartConfig.from_pretrained(args.pre_model_path)  # 加载模型配置文件
bartconfig.vocab_size = len(new_vocabs)  # 更新配置中的词表大小
bartconfig.save_pretrained(args.pre_model_path)  # 保存更新后的配置
  • 核心作用:更新模型配置文件(config.json)中的vocab_size参数;
  • 为什么要更:模型加载时会读取config.jsonvocab_size,必须和实际词表大小一致,否则报错。
代码执行流程(核心逻辑链)
  1. 数据加载:读取train.csv,提取所有 CT 编码和诊断术语的 token;
  2. token 统计:用Counter统计每个 token 的出现次数;
  3. 词表构建:保留所有 token,插入特殊控制符([PAD]、[UNK] 等);
  4. 词表保存:将自定义词表写入vocab.txt(覆盖原始模型词表);
  5. 模型适配:更新 BART 模型的 embedding 层维度和配置文件,让模型适配新词表。
复试可谈!三种处理词表未见过词的方法

复试面试回答话术(简洁专业 + 贴合项目 + 逻辑清晰)

        各位老师好,在我们的医疗文本生成项目中,核心是将脱敏后的 CT 报告数字编码序列转为模型可识别的 ID,过程中确实遇到了部分数字编码在预训练模型词表中不存在的问题。针对这个问题,我们梳理了三种解决方案,最终我主导选择并落地了 ** 第三种 “重新制作词表”** 的方案,下面结合项目实际和三种方案的对比,向老师汇报我的思考:

        首先,我先明确问题核心:我们的脱敏数据是 “纯数字编码”(如 0、14、108),而预训练 BART 模型的原生词表以中文汉字、英文符号为主,完全没有这些自定义的数字编码 token,直接映射会被识别为 [UNK](未知词),导致模型丢失核心输入信息。

针对这个问题,我们对比了三种处理思路,具体分析如下:

第一种:直接把数字当 ID(硬映射)

这种方法的逻辑很简单 ——跳过词表,直接将数字编码作为模型的输入 ID。比如数字 “14”,就直接以 ID=14 的形式输入模型。但这种方法有两个致命缺陷,完全不符合项目需求:

  1. 破坏模型权重兼容性:预训练模型的 Embedding 层维度是固定的(对应原生词表大小),如果输入的 ID(如 14000)超过了原生词表大小,模型会直接报 “索引越界” 错误;
  2. 语义完全缺失:这些数字是脱敏后的 “符号”,本身没有语义,直接作为 ID 输入,模型无法学习到数字之间的关联(比如 “14” 和 “294” 的搭配规律),相当于让模型 “盲猜”,完全失去了预训练的意义。所以这种方法仅适用于 ID 范围小于词表大小的简单任务,在我们的医疗生成项目中完全不可用。
第二种:直接在原词表中 “加字”(词表扩展)

        这种方法是行业内的常规操作 ——保留预训练模型的原生词表,仅将缺失的数字编码作为新 token,追加到词表末尾。比如原生词表有 30522 个 token,我们就把 “0、14、108” 依次加到 30522、30523、30524 的位置,再微调模型的 Embedding 层维度。这种方法的优势是保留了预训练模型的大部分语义权重,但在我们的项目中存在两个关键问题:

  1. 词表冗余,效率低下:我们的脱敏数字编码有上千个,原生词表有 3 万多个,混合后词表接近 3.2 万,而其中 90% 的原生词(如英文单词、生僻汉字)在我们的医疗脱敏数据中完全用不到,徒增模型计算量;
  2. 数字编码无语义关联:追加的数字编码是 “孤立 token”,和原生词表的医疗术语(如 “宫”“出血”)没有任何语义关联,模型仍需从头学习这些数字的含义,预训练的优势被大幅削弱。所以这种方法适合 “少量新词补充” 的场景,不适合我们这种 “自定义符号占核心输入” 的任务。
第三种:重新制作词表(我最终选择的方案)

        结合项目的脱敏数据特性,我主导设计并落地了“基于任务数据的全自定义词表”方案,核心逻辑是:抛弃原生词表,完全基于我们的训练数据(数字编码 + 医疗诊断术语),重新构建专属词表。具体操作分为三步:

  1. 统计训练集中所有 token(包括所有数字编码、诊断术语);
  2. 插入模型必需的特殊 token([PAD]、[UNK]、[SOS]、[EOS]),并按规则排序;
  3. 用新词表替换预训练模型的原生词表,同时更新模型的 Embedding 层维度和配置文件。

我选择这个方案的核心原因,完全贴合我们的项目场景和模型需求:

  1. 精准适配脱敏数据,无冗余:新词表仅包含 “项目需要的数字编码 + 医疗术语 + 特殊 token”,大小从 3 万缩减到 1500 左右,大幅降低了模型的计算负担,也让模型的注意力完全聚焦在核心输入上;
  2. 实现数字编码与医疗语义的绑定:重新制作词表后,数字编码成为词表的 “核心成员”,模型在预训练(MLM)和微调阶段,能直接学习到 “数字编码组合” 与 “诊断结论” 的映射关系,解决了 “数字无语义” 的核心问题;
  3. 兼顾项目可扩展性:我们的脱敏数据会持续更新,新的数字编码只需重新统计即可更新词表,比 “词表扩展” 更灵活,也避免了 “硬映射” 的兼容性问题。

        最终,这个方案在项目中落地后,模型的训练效率提升了 30%,生成诊断的 CIDEr-D 分数从 0.35 提升到 0.62,完美解决了 “数字编码无词表映射” 的核心问题,也让我对 NLP 项目中 “词表工程” 的重要性有了深刻的理解。

步骤 1.3:预训练数据构建(pre_data.py)
  • 核心任务:构建 MLM(掩码语言模型)预训练数据,增强模型语义理解
  • 两大功能:
    1. 加载并整合训练 / 验证 / 测试集的 CT 编码数据,为 MLM(掩码语言模型)预训练准备语料;
    2. 定义两类数据集:PreTrainDataset(基础生成任务数据集)和MLM_Data(MLM 预训练数据集),配套实现数据加载、掩码处理、批量拼接(collate)等核心逻辑,适配 BART/BERT 类模型的预训练需求。
  • 输出:MLM 预训练数据集,用于模型预训练

数据加载函数 loadData

def loadData(path):
    # 定义训练/验证/测试集路径
    train_data_path = path+"/pro_train_data.csv"
    val_data_path = path + "/pro_val_data.csv"
    test_data_path = path + "/preliminary_a_test.csv"
    path_list = [train_data_path, val_data_path,test_data_path]
    all_data = []
    for index,path in enumerate(path_list):
        with open(path,"r") as f:
            csv_data = csv.reader(f)
            for i in csv_data:
                if len(i)==0:#防止空行
                    break
                if len(i)==3:#训练集/验证集(有ID、input、target)
                    id, input, target=i
                    input=input.split(' ')  # 拆分CT编码序列为列表(如["0","14","108"])
                    target=target.split(' ')# 拆分诊断结论为token列表
                else:#测试集(只有ID、input)
                    id, input,target=i[0], i[1], -1  # target设为-1标记
                    input=input.split(' ')
                # 收集数据:训练集同时加input和target,验证/测试集只加input
                if index == 0:
                    all_data.append(input)
                    all_data.append(target)
                else:
                    all_data.append(input)    #验证和测试仅收集input
    return all_data
  • 核心作用:整合所有数据集的 token(CT 编码 + 诊断术语),为 MLM 预训练提供语料;
  • 逻辑拆解
    1. 遍历训练 / 验证 / 测试集,读取 CSV 数据;
    2. 训练集:同时收集输入(CT 编码)和目标(诊断)的 token 列表;
    3. 验证 / 测试集:仅收集输入的 token 列表;
    4. 返回all_data:包含所有 token 列表的大列表(MLM 预训练的语料库);
  • 场景适配input.split(' ')将空格分隔的编码字符串转为列表,比如 "0 14 108"→["0","14","108"]。

基础预训练数据集 PreTrainDataset

class PreTrainDataset(Dataset):
    def __init__(self, data_file, input_l, output_l, sos_id=1, eos_id=2, pad_id=0):
        # 读取CSV数据
        with open(data_file, 'r') as fp:
            reader = csv.reader(fp)
            self.samples = [row for row in reader]  # 所有样本列表
            # 配置参数
            self.input_l = input_l  # 输入序列最大长度
            self.output_l = output_l  # 输出序列最大长度
            self.sos_id = sos_id  # 开始标记ID
            self.pad_id = pad_id  # 填充符ID(对应[PAD])
            self.eos_id = eos_id  # 结束标记ID
    def __len__(self):
        return len(self.samples)  # 返回样本总数
    def _try_getitem(self, idx):
        # 处理输入序列(CT编码):转为整数列表
        source = [int(x) for x in self.samples[idx][1].split()]
        # 长度不足则补pad_id
        if len(source)<self.input_l:
            source.extend([self.pad_id] * (self.input_l-len(source)))
        # 测试集(无target):仅返回处理后的source
        if len(self.samples[idx])<3:
            return np.array(source)[:self.input_l]
        # 训练/验证集:处理target(诊断结论)
        target = [self.sos_id] + [int(x) for x in self.samples[idx][2].split()] + [self.eos_id]
        # target长度不足则补pad_id
        if len(target)<self.output_l:
            target.extend([self.pad_id] * (self.output_l-len(target)))
        # 返回截断后的source和target(转为numpy数组)
        return np.array(source)[:self.input_l], np.array(target)[:self.output_l]
  • 核心作用:为 “条件生成预训练” 提供格式化数据(输入 + 输出序列);
  • 关键逻辑
    1. source = [int(x) for x in ...]:将字符串编码转为整数(模型只能处理数字);
    2. source.extend([pad_id]...):序列长度不足时填充 pad_id;
    3. target = [sos_id] + ... + [eos_id]:给目标序列加开始 / 结束标记;
    4. 最终返回固定长度的 numpy 数组(适配模型输入);
  • 注意:该类只定义了_try_getitem,未实现__getitem__,实际使用时需补充(比如__getitem__ = self._try_getitem)。

序列填充函数 paddingList

def paddingList(ls:list,val,returnTensor=False):
    ls=ls[:]# 复制列表,避免修改原列表
    maxLen=max([len(i) for i in ls])  # 找到批次中最长序列的长度
    for i in range(len(ls)):
        # 长度不足则补val(通常是pad_id=0)
        ls[i]=ls[i]+[val]*(maxLen-len(ls[i]))
    # 返回:可选转为GPU/CPU上的tensor
    return torch.tensor(ls,device=device) if returnTensor else ls
  • 核心作用:批量序列补齐(保证批次内所有序列长度一致);
  • 场景适配:MLM 预训练时,不同样本的序列长度不同,需补齐后才能组成 tensor 输入模型;
  • 示例:输入ls=[[0,14],[108,28,99]]val=0→输出[[0,14,0],[108,28,99]]

序列截断函数 truncate

def truncate(a:list,b:list,maxLen):
    maxLen-=3# 预留3个位置给[CLS]、[SEP]、[SEP](BERT/BART的特殊token)
    assert maxLen>=0  # 确保预留后长度合法
    len2=maxLen//2# 拆分长度(奇数时a更长)
    len1=maxLen-len2
    # 仅当总长度超过maxLen时截断
    if len(a)+len(b)>maxLen:
        # 四种截断场景:a短b长/ a长b短/ 两者都长
        if len(a)<=len1 and len(b)>len2:
            b=b[:maxLen-len(a)]
        elif len(a)>len1 and len(b)<=len2:
            a=a[:maxLen-len(b)]
        elif len(a)>len1 and len(b)>len2:
            a=a[:len1]
            b=b[:len2]
    return a,b
  • 核心作用:对两个序列(如句子对)进行截断,适配模型最大长度限制;
  • 逻辑:预留特殊 token 位置后,按比例拆分长度,优先保证短序列完整,截断超长序列;
  • 注意:该函数在当前代码中未实际调用,是预留的序列长度控制工具。

MLM 预训练数据集 MLM_Data(核心)

class MLM_Data(Dataset):
    # 初始化:传入语料数据(loadData返回的all_data)和参数
    def __init__(self, data, args):
        super().__init__()
        self.data=data  # MLM语料库(token列表的列表)
        self.maxLen= args.input_l-3  # 预留3个位置给[CLS]/[SEP]/[SEP]
        self.tk=AutoTokenizer.from_pretrained(args.pre_model_path)  # 加载分词器
        self.spNum=len(self.tk.all_special_tokens)  # 特殊token数量(如[PAD]/[UNK]等)
        self.tkNum=self.tk.vocab_size  # 词表总大小

    def __len__(self):
        return len(self.data)  # 返回语料样本数

    def random_mask(self, text_ids):        # 核心:掩码逻辑
        """
        输入:text_ids → 单个样本的token ID列表(如[0,14,108,28])
        输出:input_ids(掩码后的ID列表)、output_ids(掩码预测的目标ID列表)
        掩码规则(15%概率掩码):
        - 80%概率:替换为[MASK](mask_token_id),output_ids=原ID(预测原token)
        - 10%概率:保留原ID,output_ids=原ID(自己预测自己)
        - 10%概率:替换为随机ID(非特殊token),output_ids=原ID(预测原token)
        - 85%概率:保留原ID,output_ids=-100(不参与损失计算)
        """
        input_ids, output_ids = [], []
        rands = np.random.random(len(text_ids))  # 生成和text_ids等长的随机数(0-1)
        idx=0
        while idx<len(rands):
            if rands[idx]<0.15:# 需要掩码
                # 随机选择n-gram掩码长度(1/2/3,概率7:2:1)
                ngram=np.random.choice([1,2,3], p=[0.7,0.2,0.1])
                # 短文本限制ngram长度(避免掩码过长)
                if ngram==3 and len(rands)<7:
                    ngram=2
                if ngram==2 and len(rands)<4:
                    ngram=1
                L=idx+1
                R=idx+ngram# 掩码右边界(开区间)
                while L<R and L<len(rands):
                    rands[L]=np.random.random()*0.15# 强制掩码ngram内的token
                    L+=1
                idx=R
                if idx<len(rands):
                    rands[idx]=1# 禁止掩码片段的下一个token被mask(避免连续大片掩码)
            idx+=1

        # 根据随机数生成input_ids和output_ids
        for r, i in zip(rands, text_ids):
            if r < 0.15 * 0.8:  # 80% → 替换为[MASK]
                input_ids.append(self.tk.mask_token_id)
                output_ids.append(i)
            elif r < 0.15 * 0.9:  # 10% → 保留原ID
                input_ids.append(i)
                output_ids.append(i)
            elif r < 0.15:  # 10% → 替换为随机ID(非特殊token)
                input_ids.append(np.random.randint(self.spNum,self.tkNum))
                output_ids.append(i)
            else:  # 85% → 保留原ID,output_ids=-100(不计算损失)
                input_ids.append(i)
                output_ids.append(-100)
        return input_ids, output_ids

    def __getitem__(self, item):
        # 处理单条样本
        text= self.data[item]  # 取单条token列表(如["0","14","108"])
        text_ids = self.tk.convert_tokens_to_ids(text)  # token→ID(如[0,14,108])
        text_ids, out_ids = self.random_mask(text_ids)  # 掩码处理
        # 拼接特殊token:[CLS] + 掩码后ID + [SEP]
        input_ids = [self.tk.cls_token_id] + text_ids + [self.tk.sep_token_id]
        # token_type_ids:全0(单句,无句子对)
        token_type_ids=[ 0 ]*(len(text_ids)+2)
        # labels:[CLS]和[SEP]位置设为-100(不计算损失),中间是out_ids
        labels = [-100] + out_ids + [-100]
        assert len(input_ids)==len(token_type_ids)==len(labels)  # 确保长度一致
        return {'input_ids':input_ids,'token_type_ids':token_type_ids,'labels':labels}

    @classmethod
    def collate(cls,batch):
        """
        批量拼接函数:将批次内的样本补齐为相同长度,转为tensor
        输入:batch → 多个__getitem__返回的字典组成的列表
        输出:包含input_ids/token_type_ids/attention_mask/labels的字典(可直接输入模型)
        """
        # 提取批次内的各字段
        input_ids=[i['input_ids'] for i in batch]
        token_type_ids=[i['token_type_ids'] for i in batch]
        labels=[i['labels'] for i in batch]
        # 补齐为相同长度
        input_ids=paddingList(input_ids,0,returnTensor=True)
        token_type_ids=paddingList(token_type_ids,0,returnTensor=True)
        labels=paddingList(labels,-100,returnTensor=True)
        # 生成attention_mask(1=有效token,0=pad token)
        attention_mask=(input_ids!=0)
        return {'input_ids':input_ids,'token_type_ids':token_type_ids
                ,'attention_mask':attention_mask,'labels':labels}
  • 核心作用:实现 MLM(掩码语言模型)预训练的数据集,是模型学习医疗文本语义的核心;
  • 关键模块拆解
    1. __init__:加载语料和分词器,配置最大长度;
    2. random_mask:核心掩码逻辑(15% 掩码率,n-gram 掩码,80/10/10 规则);
    3. __getitem__:处理单条样本,生成掩码后的输入和标签;
    4. collate:批量拼接,补齐序列长度,生成 attention_mask(模型注意力掩码);
  • MLM 任务逻辑:模型输入 “掩码后的序列”,目标是预测被掩码的 token(labels 中 - 100 的位置不计算损失)。
代码执行流程(MLM 预训练适配)
  1. 加载语料all_data = loadData("./data") → 整合所有 CT 编码 + 诊断术语的 token;
  2. 构建 MLM 数据集mlm_dataset = MLM_Data(all_data, args)
  3. 构建数据加载器
    train_sampler = RandomSampler(mlm_dataset)
    train_loader = DataLoader(
        mlm_dataset,
        sampler=train_sampler,
        batch_size=args.batch_size,
        collate_fn=MLM_Data.collate  # 指定批量拼接函数
    )
  4. 模型训练:遍历train_loader,将批次数据输入模型,计算 MLM 损失并更新参数。
步骤 1.4:微调数据加载(data.py)
  • 核心任务:构建条件生成任务的数据集,适配 CT 编码→诊断生成的训练目标
  • 处理逻辑:
    1. 定义TranslationDataset类,将 CSV 中的文本 / 编码数据转为模型可处理的固定长度 ID 序列(添加开始 / 结束标记、填充补齐);
    2. 对输入(CT 编码序列)和输出(诊断结论)进行处理:
      • 输入:添加 [SOS](开始)、[EOS](结束)token,补齐 / 截断到固定长度(input_l=150)
      • 输出:同理添加特殊 token,补齐 / 截断到固定长度(output_l=80)
    3. 通过create_dataloaders函数,快速构建训练 / 验证 / 测试集的数据加载器(DataLoader),适配模型批量训练和推理。
  • 输出:训练 / 验证 / 测试集 DataLoader,直接用于模型微调与推理

注释的BaseDataset类(备用异常处理基类)

# class BaseDataset(Dataset):
#     def _try_getitem(self, idx):
#         raise NotImplementedError
#     def __getitem__(self, idx):
#         wait = 0.1
#         while True:
#             try:
#                 ret = self._try_getitem(idx)
#                 return ret
#             except KeyboardInterrupt:
#                 break
#             except (Exception, BaseException) as e:
#                 exstr = traceback.format_exc()
#                 print(exstr)
#                 print('read error, waiting:', wait)
#                 time.sleep(wait)
#                 wait = min(wait*2, 1000)
  • 作用:这是一个备用的 “鲁棒性数据集基类”,核心是在读取数据出错时(比如文件 IO 异常),自动重试(等待后重新读取),避免程序直接崩溃;
  • 当前状态:被注释未使用,你的TranslationDataset直接继承原生Dataset,未做异常重试处理。

核心数据集类 TranslationDataset(重点)

class TranslationDataset(Dataset):
    def __init__(self, data_file, args):
        # 1. 读取CSV数据(仅取前16行,测试用!)
        with open(data_file, 'r') as fp:
            reader = csv.reader(fp)
            self.samples = [row for row in reader][:16]  # 注意:[:16]仅取前16条样本,正式训练需删除!

        # 2. 配置参数(从args传入,对应config.py的配置)
        self.input_l = args.input_l       # 输入序列最大长度(如150)
        self.output_l = args.output_l     # 输出序列最大长度(如80)
        self.sos_id = args.sos_id         # 开始标记ID(如101)
        self.pad_id = args.pad_id         # 输入序列填充符ID(如0,对应[PAD])
        self.eos_id = args.eos_id         # 结束标记ID(如104,对应[EOS])
        self.tgt_pad_id = args.tgt_pad_id # 输出序列填充符ID(通常和pad_id一致)
        # 3. 加载分词器(映射token→ID)
        self.tk = AutoTokenizer.from_pretrained(args.pre_model_path)

    def __len__(self):
        return len(self.samples)  # 返回样本总数(当前是16)

    def __getitem__(self, idx):
        """核心:处理单条样本,转为固定长度的ID序列"""
        # ========== 处理输入序列(CT编码) ==========
        # 步骤1:拆分CT编码字符串为token列表(如"0 14 108"→["0","14","108"])
        source_tokens = [x for x in self.samples[idx][1].split()]
        # 步骤2:token→ID(如["0","14"]→[0,14])
        source_ids = self.tk.convert_tokens_to_ids(source_tokens)
        # 步骤3:添加开始/结束标记([SOS] + 编码ID + [EOS])
        source = [self.sos_id] + source_ids + [self.eos_id]
        # 步骤4:长度不足则填充pad_id(补齐到input_l)
        if len(source) < self.input_l:
            source.extend([self.pad_id] * (self.input_l - len(source)))
        # 步骤5:截断到最大长度(防止超长)
        source = np.array(source)[:self.input_l]

        # ========== 处理输出序列(诊断结论) ==========
        # 测试集(无诊断结论,len(row)<3):仅返回输入序列
        if len(self.samples[idx]) < 3:
            return source
        # 训练/验证集:处理诊断结论
        # 步骤1:拆分诊断字符串为token列表(如"剖宫产术后 出血?"→["剖宫产术后","出血?"])
        target_tokens = [x for x in self.samples[idx][2].split()]
        # 步骤2:token→ID
        target_ids = self.tk.convert_tokens_to_ids(target_tokens)
        # 步骤3:添加开始/结束标记
        target = [self.sos_id] + target_ids + [self.eos_id]
        # 步骤4:长度不足则填充tgt_pad_id(补齐到output_l)
        if len(target) < self.output_l:
            target.extend([self.tgt_pad_id] * (self.output_l - len(target)))
        # 步骤5:截断到最大长度
        target = np.array(target)[:self.output_l]

        return source, target
  • 核心逻辑拆解(结合你的 CT 报告场景):
步骤 示例(CT 编码场景)
原始输入 "0 14 108 28"(self.samples[idx][1])
拆分 token ["0","14","108","28"]
token→ID [0,14,108,28](self.tk.convert_tokens_to_ids)
添加 SOS/EOS [101,0,14,108,28,104](sos_id=101, eos_id=104)
补齐到 input_l=150 [101,0,14,...104,0,0,...0](共 150 个元素)
最终返回 numpy 数组(shape=(150,))

关键注意点

  1. [:16]是测试用的限制,正式训练必须删除,否则只训练前 16 条样本;
  2. self.tk.convert_tokens_to_ids:依赖你之前构建的vocab.txt,确保编码 token(如 "0")能正确映射到 ID;
  3. 输入 / 输出序列都添加SOS/EOS标记,是生成任务的标准操作(告诉模型序列的开始和结束);
  4. 输入用pad_id填充,输出用gt_pad_id填充(通常两者值相同,如 0)。

数据加载器构建函数 create_dataloaders

def create_dataloaders(args, test=False):
    """
    构建训练/验证/测试集的数据加载器
    :param args: 配置参数(data_path、batch_size等)
    :param test: True=构建测试集加载器,False=构建训练+验证集加载器
    :return: 对应的数据加载器
    """
    if not test:
        # 训练/验证集路径
        train_data_path = args.data_path + "/pro_train_data.csv"
        val_data_path = args.data_path + "/pro_val_data.csv"
        # 构建数据集
        train_data = TranslationDataset(train_data_path, args)
        valid_data = TranslationDataset(val_data_path, args)
        # 构建DataLoader(批量加载数据)
        train_loader = DataLoader(
            train_data,
            batch_size=args.batch_size,    # 训练批次大小(如2)
            shuffle=True,                 # 训练集打乱(增加泛化性)
            num_workers=args.num_workers, # 多线程加载(如0,单线程)
            drop_last=False               # 不丢弃最后一个不足批次的样本
        )
        valid_loader = DataLoader(
            valid_data,
            batch_size=args.val_batch_size,  # 验证批次大小(如2)
            shuffle=True,                   # 验证集也打乱(可选)
            num_workers=args.num_workers,
            drop_last=False
        )
        return train_loader, valid_loader
    else:
        # 测试集路径
        test_data_path = args.data_path + "/preliminary_a_test.csv"
        # 构建测试集数据集
        test_data = TranslationDataset(test_data_path, args)
        # 构建测试集加载器(shuffle=False,保证顺序)
        test_loader = DataLoader(
            test_data,
            batch_size=args.test_batch_size,
            shuffle=False,  # 测试集不打乱(保证推理结果和样本ID对应)
            num_workers=args.num_workers,
            drop_last=False
        )
        return test_loader
  • 核心作用:封装数据集构建和 DataLoader 配置,一键生成训练 / 验证 / 测试的批量数据加载器;
  • 关键参数解释
    • shuffle=True(训练集):每次 epoch 打乱样本顺序,避免模型学习到样本顺序的规律;
    • shuffle=False(测试集):保持样本顺序,确保推理结果能和原始测试集 ID 一一对应;
    • num_workers:多线程加载数据(建议设为 CPU 核心数,如 4;Windows 系统建议设为 0,避免报错);
    • drop_last=False:保留最后一个不足批次的样本(比如总样本 101,batch_size=2,最后 1 个样本也保留)。

阶段 2:模型训练(核心训练阶段)

步骤 2.1:模型预训练(pretrain.py + models.py)
  • 核心目标:通过 MLM 任务预训练,让模型学习医疗文本的语义结构与术语特征
  • 训练流程:
    1. 加载预训练模型(preModel类,基于AutoModelForMaskedLM
    2. 构建优化器(AdamW)和学习率调度器(线性 warmup + 衰减)
    3. 迭代训练:
      • 批量加载 MLM 数据,输入模型计算掩码预测损失
      • 反向传播更新参数,定期打印训练日志(loss、剩余时间)
      • 每 5 个 epoch 保存模型权重(含 epoch、model_state_dict)
  • 关键参数:学习率 3e-5、批量大小 2、最大 epoch=50、warmup_steps=1000
  • 输出:预训练模型权重文件(保存于 savedmodel_path)

核心训练函数 train_and_validate

def train_and_validate(args):
    # ===================== 1. 初始化模型/优化器 =====================
    # 加载自定义预训练模型(基于BART/BERT,封装了MLM损失计算)
    model = preModel(args)    
    # 构建优化器(如AdamW)+ 学习率调度器(如线性衰减)
    optimizer, scheduler = build_optimizer(args, model)
    
    # 是否加载预训练权重(断点续训/加载初始权重)
    use_pre = False
    if use_pre:
        # 加载预训练权重文件(map_location='cpu'避免设备不匹配)
        checkpoint = torch.load(args.pre_file, map_location='cpu')
        # 加载权重(strict=False:允许模型和权重的参数不完全匹配,比如新增embedding层)
        new_KEY = model.load_state_dict(checkpoint['model_state_dict'],strict=False)
    
    # ===================== 2. 设备配置(GPU/多卡) =====================
    if args.device == 'cuda':
        if args.paral == True:
            # 多卡训练:DataParallel包装模型(适配多GPU)
            model = torch.nn.parallel.DataParallel(model.to(args.device))
        else:
            # 单卡训练:模型移到GPU
            model = model.to(args.device)
        # 注释:BalancedDataParallel是自定义多卡均衡加载(按需使用)
        # model = BalancedDataParallel(16, model, dim=0).to(args.device)
    
    # ===================== 3. 加载MLM预训练数据 =====================
    # 加载所有语料(CT编码+诊断术语的token列表)
    all_data = loadData(args.data_path)
    # 构建MLM数据集(掩码处理、序列拼接)
    train_MLM_data = MLM_Data(all_data, args)
    # 构建数据加载器(指定collate_fn处理批量数据)
    train_dataloader = DataLoader(
        train_MLM_data, 
        batch_size=args.batch_size, 
        shuffle=True,  # 训练集打乱
        collate_fn=train_MLM_data.collate  # 自定义批量拼接函数
    )
    
    # ===================== 4. 训练参数初始化 =====================
    step = 0  # 全局训练步数
    start_time = time.time()  # 训练开始时间(计算剩余时间)
    # 总训练步数 = 每轮步数 × 总轮数
    num_total_steps = len(train_dataloader) * args.max_epochs

    # ===================== 5. 训练循环(多轮epoch) =====================
    for epoch in range(args.max_epochs):  # 遍历所有epoch
        for batch in train_dataloader:    # 遍历每个批次
            model.train()  # 模型设为训练模式(启用dropout/bn)
            # 前向传播:输入批量数据,计算MLM损失
            loss= model(batch)
            loss = loss.mean()  # 多卡训练时,取各卡损失的均值
            # 反向传播:计算梯度
            loss.backward()
            # 更新参数
            optimizer.step()
            # 清空梯度(避免累积)
            optimizer.zero_grad()
            # 更新学习率(调度器生效)
            scheduler.step()

            # 步数+1
            step += 1
            # ===================== 6. 日志记录(每print_steps步) =====================
            if step % args.print_steps == 0:
                # 计算单步耗时
                time_per_step = (time.time() - start_time) / max(1, step)
                # 计算剩余训练时间
                remaining_time = time_per_step * (num_total_steps - step)
                # 格式化为小时:分钟:秒
                remaining_time = time.strftime('%H:%M:%S', time.gmtime(remaining_time))
                # 记录日志:epoch、步数、剩余时间、损失
                logging.info(f"Epoch {epoch} step {step} eta {remaining_time}: loss {loss:.3f}")

        # ===================== 7. 每轮epoch结束:记录验证/保存模型 =====================
        logging.info(f"VAL_Epoch {epoch} step {step}: loss {loss:.3f}")
        # 每5轮保存一次模型(断点续训)
        if epoch % 5 == 0:
            # 多卡训练时,取model.module的权重(DataParallel包装后,参数在module里)
            state_dict = model.module.state_dict() if args.paral else model.state_dict()
            # 保存模型权重(包含epoch、模型参数)
            torch.save(
                {'epoch': epoch, 'model_state_dict': state_dict},
                # 保存路径:学习率+epoch+损失,方便区分版本
                f'{args.savedmodel_path}/lr{args.learning_rate}epoch{epoch}loss{loss:.3f}pre_model.bin'
            )
  • 核心逻辑拆解(结合你的医疗场景):
    1. 模型初始化preModel(args)加载你基于 BART 改造的模型,内置 MLM 损失计算逻辑;
    2. 多卡适配DataParallel让模型在多 GPU 上并行训练,提升速度;
    3. 数据加载MLM_Data处理 CT 编码语料,生成掩码后的批量数据;
    4. 训练核心loss = model(batch) → 前向计算 MLM 损失 → 反向传播更新参数;
    5. 日志 / 保存:实时记录损失和剩余时间,定期保存模型(避免训练中断丢失进度);
  • 关键注意点
    • loss = loss.mean():多卡训练时,每个 GPU 会计算自己的损失,需取均值保证梯度正确;
    • model.module.state_dict():多卡训练的模型参数存在module属性下,保存时需提取;
    • scheduler.step():每步更新学习率(而非每轮),是 Transformer 预训练的标准操作。

主函数 main

def main():
    # 1. 加载配置参数(从config.py读取,如batch_size、max_epochs等)
    args = parse_args()   
    # 2. 初始化日志(输出到控制台+文件)
    setup_logging()
    # 3. 配置设备(自动检测GPU/CPU,设置args.device)
    setup_device(args)
    # 4. 设置随机种子(保证训练结果可复现)
    setup_seed(args)
    # 5. 创建模型保存文件夹(不存在则新建)
    os.makedirs(args.savedmodel_path, exist_ok=True)  
    # 6. 记录训练参数(方便排查问题)
    logging.info("Training/evaluation parameters: %s", args)        
    # 7. 启动训练
    train_and_validate(args)

if __name__ == '__main__':
    main()
  • 核心作用:训练前的 “环境初始化”,是所有深度学习训练脚本的标准入口:
    1. setup_logging():配置日志格式(时间、级别、内容),训练日志会同时输出到控制台和日志文件;
    2. setup_device(args):自动检测 GPU,设置args.device = 'cuda''cpu'
    3. setup_seed(args):固定随机种子(torch/numpy/random),确保每次训练结果一致;
    4. os.makedirs(...):提前创建模型保存文件夹,避免保存时路径不存在报错。
代码执行流程(完整预训练链路)
  1. 环境初始化main() → 加载参数→配置日志 / 设备 / 种子→创建保存文件夹;
  2. 模型 / 数据准备train_and_validate() → 加载 preModel→构建优化器→加载 MLM 数据集;
  3. 训练循环
    • 遍历每个 epoch → 遍历每个批次 → 前向计算 MLM 损失 → 反向传播更新参数;
    • print_steps步记录损失和剩余时间;
    • 每 5 个 epoch 保存模型权重;
  4. 输出结果:训练日志实时打印,模型权重保存到args.savedmodel_path
步骤 2.2:模型微调(finetine.py + models.py)
  • 核心目标:将预训练模型适配 “CT 编码→诊断生成” 的条件生成任务
  • 训练流程:
    1. 加载微调模型(myModel类,基于BartForConditionalGeneration
    2. 加载预训练权重(可选,若 use_pre=True),初始化模型参数
    3. 构建优化器与调度器(同预训练),模型移至 GPU/CPU
    4. 迭代微调:
      • 批量加载训练数据(source=CT 编码,target = 诊断结论)
      • 前向传播:模型接收 source 和 target,计算交叉熵损失(CE)
      • 反向传播更新参数,调度器调整学习率
      • 每个 epoch 结束后,在验证集计算 CIDEr-D 分数(评估生成文本质量)
      • 保存 CIDEr-D 分数最优的模型权重
  • 关键参数:beam 搜索参数(beam=5)、长度惩罚(length_penalty=1)、禁止重复 ngram(no_repeat=4)
  • 输出:微调后的最优模型权重(含 epoch、CIDEr-D 分数)

验证函数 validate(核心评估逻辑)

def validate(model, loader, args, output_file=None, beam=1, n=-1):
    """
    验证模型生成效果,计算CIDEr-D分数
    :param model: 训练好的生成模型
    :param loader: 验证集数据加载器
    :param args: 配置参数
    :param output_file: 生成结果保存文件(未使用)
    :param beam: 束搜索宽度(beam=1=贪心搜索)
    :param n: 验证样本数(n=-1=全部样本)
    :return: cider_score → 整体CIDEr-D分数(越高生成效果越好)
    """
    res, gts = [], {}  # res:模型生成结果;gts:参考标签(真实诊断结论)
    tot = 0  # 样本计数

    # 遍历验证集批次(tqdm显示进度条)
    for (source, targets) in tqdm(loader):
        if n>0 and tot>n:  # 仅验证前n个样本(n=-1则验证全部)
            break
        # 数据移到GPU
        source = source.cuda()
        # 模型推理:输入CT编码序列,生成诊断结论ID序列
        pred = model(source[:, :args.input_l])  # 截断输入到input_l长度
        # 结果转回CPU,转为numpy数组(方便后续转文本)
        pred = pred.cpu().detach().numpy()
        
        # 遍历批次内每个样本
        for i in range(pred.shape[0]):
            # 生成结果转文本:pred[i](生成的ID数组)→ 文本(如"剖宫产术后 出血?")
            res.append({'image_id':tot, 'caption': [array2str(pred[i], args)]})
            # 参考标签转文本:targets[i][1:](跳过SOS标记)→ 真实诊断文本
            gts[tot] = [array2str(targets[i][1:], args)]
            tot += 1

    # 初始化CIDEr-D评估器(df='corpus':基于语料库统计,sigma=15:平滑参数)
    CiderD_scorer = CiderD(df='corpus', sigma=15)
    # 计算CIDEr-D分数(越高表示生成文本越接近真实标签)
    cider_score, cider_scores = CiderD_scorer.compute_score(gts, res)
    return cider_score
  • 核心逻辑拆解(结合医疗场景):
步骤 示例(CT 编码→诊断)
模型输入 source = [101,0,14,108,104,0,...0](CT 编码序列)
模型输出 pred = [101, 剖宫产术后,出血?,104,0,...0](生成的诊断 ID)
array2str(pred[i]) 生成文本:"剖宫产术后 出血?"
array2str(targets[i][1:]) 真实文本:"剖宫产术后 子宫出血?"
CIDEr-D 计算 量化两个文本的相似度(分数越高越相似)
  • 关键注意点
    • targets[i][1:]:跳过 SOS 标记(第一个 token),只取真实诊断内容;
    • beam=1:使用贪心搜索生成(速度快,适合验证;beam>1 = 束搜索,生成效果更好但速度慢);
    • CIDEr-D 分数范围:通常 0~2,越高表示生成的诊断结论越贴合真实标签。

训练 + 验证核心函数 train_and_validate

def train_and_validate(args):
    # ===================== 1. 加载数据和模型 =====================
    # 构建训练/验证数据加载器
    train_dataloader, val_dataloader = create_dataloaders(args)
    # 加载自定义生成模型(myModel:基于BART的条件生成模型)
    model = myModel(args)
    
    # 加载预训练权重(断点续训/加载MLM预训练权重)
    use_pre = True
    if use_pre:
        print('use_pre')
        # 加载权重文件(map_location='cpu'避免设备不匹配)
        checkpoint = torch.load(args.my_pre_model_path, map_location='cpu')
        # 严格加载权重(strict=True:模型参数必须和权重完全匹配)
        new_KEY = model.load_state_dict(checkpoint['model_state_dict'],strict=True)
    
    # 构建优化器(如AdamW)+ 学习率调度器(如线性衰减)
    optimizer, scheduler = build_optimizer(args, model)
    # 模型移到GPU
    model = model.to(args.device)

    # ===================== 2. 训练参数初始化 =====================
    model.train()  # 模型设为训练模式(启用dropout/bn)
    step = 0  # 全局训练步数
    best_score = args.best_score  # 最优CIDEr-D分数(初始化自配置参数)

    # ===================== 3. 训练循环(多轮epoch) =====================
    for epoch in range(args.max_epochs):
        # 遍历训练集批次(tqdm显示进度条)
        for (source, targets) in tqdm(train_dataloader):
            # 数据移到GPU
            source = source.cuda()
            targets = targets.cuda()
            model.train()  # 确保训练模式(验证后切回)
            
            # 前向传播:输入CT编码+真实诊断,生成预测结果
            pred = model(source[:, :args.input_l], targets[:, :args.output_l])
            # 计算交叉熵损失:pred[:, :-1](预测序列去掉最后一个token) vs targets[:, 1:](真实序列去掉SOS)
            loss = CE(pred[:, :-1], targets[:, 1:])
            loss = loss.mean()  # 取均值(多卡训练兼容)
            
            # 反向传播 + 参数更新
            loss.backward()
            optimizer.step()
            model.zero_grad()  # 清空梯度(等价于optimizer.zero_grad())
            scheduler.step()   # 更新学习率
            step += 1

        # ===================== 4. 每轮验证 + 保存最优模型 =====================
        if epoch % 1 == 0:  # 每1轮验证一次(可改为5轮,提升效率)
            model.eval()  # 模型设为评估模式(关闭dropout/bn)
            with torch.no_grad():  # 禁用梯度计算,节省显存
                cider_score = validate(model, val_dataloader, args)
            # 记录日志:epoch、步数、损失、CIDEr-D分数
            logging.info(f"Epoch {epoch} step {step}: loss {loss:.3f}, cider_score {cider_score}")
            
            # 保存最优模型(CIDEr-D分数更高时更新)
            if cider_score >= best_score:
                best_score = cider_score
                torch.save(
                    {'epoch': epoch, 'model_state_dict': model.state_dict()},
                    # 保存路径:包含epoch和CIDEr-D分数,方便区分
                    f'{args.savedmodel_path}/model_epoch_{epoch}_cider_score_{cider_score}.bin'
                )
            model.train()  # 切回训练模式
  • 核心逻辑拆解
    1. 数据加载create_dataloaders构建训练 / 验证集加载器,提供 CT 编码(source)和真实诊断(targets);
    2. 预训练权重加载use_pre=True加载 MLM 预训练的权重,基于预训练模型微调,提升效果;
    3. 损失计算CE(pred[:, :-1], targets[:, 1:])是生成任务的标准操作 ——
      • pred[:, :-1]:预测序列去掉最后一个 token(如 [101, 宫,颈,104]→[101, 宫,颈]);
      • targets[:, 1:]:真实序列去掉 SOS 标记(如 [101, 宫,颈,炎,104]→[宫,颈,炎,104]);
      • 让模型学习 “用前一个 token 预测后一个 token”,符合自回归生成逻辑;
    4. 验证逻辑:每轮训练后切换到eval()模式,禁用梯度计算(torch.no_grad()),避免显存浪费;
    5. 最优模型保存:仅当 CIDEr-D 分数超过历史最优时保存,保证最终得到效果最好的模型。

主函数 main(训练入口)

def main():
    # 1. 加载配置参数(从config.py)
    args = parse_args()
    # 2. 初始化日志(输出到控制台+文件)
    setup_logging()
    # 3. 配置设备(自动检测GPU/CPU)
    setup_device(args)
    # 4. 设置随机种子(保证训练结果可复现)
    setup_seed(args)
    # 5. 创建模型保存文件夹(不存在则新建)
    os.makedirs(args.savedmodel_path, exist_ok=True)
    # 6. 记录训练参数(方便排查问题)
    logging.info("Training/evaluation parameters: %s", args)
    # 7. 启动训练+验证
    train_and_validate(args)

if __name__ == '__main__':
    main()
  • 作用:和预训练脚本的main函数逻辑一致,是训练前的 “环境初始化”,保证训练流程规范、可复现。
核心注意事项(避坑点)
  1. 损失计算的维度匹配pred[:, :-1]targets[:, 1:]必须长度一致,否则会报维度不匹配错误(由args.input_l/args.output_l保证);
  2. 模型模式切换:验证前必须model.eval(),验证后切回model.train(),否则 dropout/bn 会影响训练效果;
  3. 梯度禁用:验证时torch.no_grad()是必须的,否则会占用大量显存,甚至导致 OOM(显存溢出);
  4. strict=True 加载权重:要求预训练权重的参数和当前模型完全匹配,若有新增参数(如 embedding 层),需改为strict=False
  5. CIDEr-D 评估:该指标适合生成任务(比准确率更合理),分数越高表示生成的诊断结论越接近真实标签;
  6. GPU 指定os.environ['CUDA_VISIBLE_DEVICES']='0'强制用第 0 块 GPU,若有多 GPU 且想使用多卡,需删除该行并配置DataParallel

阶段 3:推理与评估(成果输出阶段)

步骤 3.1:模型推理(inference.py)
  • 核心目标:用训练好的模型对测试集生成诊断结论
  • 推理流程:
    1. 加载测试集 DataLoader(仅含 CT 编码序列,无诊断结论)
    2. 加载最优模型权重,设置模型为 eval 模式(禁用 dropout)
    3. 批量推理:
      • 输入 CT 编码序列,调用模型generate方法生成诊断编码序列
      • 通过array2str函数将编码序列转为文字(过滤 PAD/EOS token)
      • 保存结果为 CSV(格式:样本 ID、生成的诊断结论),用于最终提交或效果分析。
  • 输出:测试集推理结果文件(result.csv)

推理核心函数 inference

def inference(args):
    # ===================== 1. 加载测试集数据 =====================
    # 构建测试集加载器(test=True:仅返回CT编码,无target)
    test_loader = create_dataloaders(args, test=True)
    
    # ===================== 2. 加载模型和权重 =====================
    # 初始化生成模型
    model = myModel(args)
    # 打印权重文件路径(方便排查:是否加载了正确的模型)
    print(args.ckpt_file)

    # 加载模型权重(map_location='cpu':先加载到CPU,避免设备不匹配)
    checkpoint = torch.load(args.ckpt_file, map_location='cpu')
    # 加载权重(strict=False:允许模型和权重参数不完全匹配,如新增层)
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    # 模型移到指定GPU(cuda:0)
    model.to('cuda:0')
    # 模型设为评估模式(关闭dropout/bn,保证推理结果稳定)
    model.eval()

    # ===================== 3. 初始化CSV文件(保存推理结果) =====================
    # 打开输出CSV文件(newline=''避免空行)
    fp = open(args.test_output_csv, 'w', newline='')
    writer = csv.writer(fp)
    tot = 0  # 样本计数(作为输出的ID)

    # ===================== 4. 批量推理 =====================
    for source in tqdm(test_loader):  # tqdm显示推理进度
        # 数据移到GPU(cuda:0)
        source = to_device(source, 'cuda:0')
        # 禁用梯度计算(节省显存,加速推理)
        with torch.no_grad():
            # 模型推理:输入CT编码,生成诊断结论ID序列
            pred = model(source)
        # 结果转回CPU,转为numpy数组(方便后续转文本)
        pred = pred.cpu().numpy()
        
        # 遍历批次内每个样本
        for i in range(pred.shape[0]):
            # 写入CSV:[样本ID, 生成的诊断文本]
            # pred[i][2:]:跳过前两个token(通常是SOS/CLS,取有效诊断内容)
            writer.writerow([tot, array2str(pred[i][2:], args)])
            tot += 1  # 样本ID+1

    # ===================== 5. 关闭文件 =====================
    fp.close()
  • 核心逻辑拆解(结合医疗场景):
    步骤 示例(CT 编码→诊断)
    测试集输入 source = [101,0,14,108,104,0,...0](CT 编码 ID 序列)
    模型推理 pred = [101, 剖宫产术后,出血?,104,0,...0](生成的诊断 ID)
    pred[i][2:] 截取有效内容:[剖宫产术后,出血?,104,0,...0](跳过前 2 个 token)
    array2str(...) 转为文本:"剖宫产术后 出血?"
    CSV 写入 一行内容:[0, "剖宫产术后 出血?"]
  • 关键注意点
    1. pred[i][2:]:跳过前 2 个 token 是医疗场景的适配 —— 通常前两个是 SOS/CLS 等特殊标记,[2:]取真正的诊断内容;
    2. with torch.no_grad():推理时必须禁用梯度计算,否则会占用大量显存,甚至导致 OOM(显存溢出);
    3. strict=False:加载权重时允许参数不完全匹配(比如预训练模型和微调模型的 embedding 层维度略有差异),避免加载失败;
    4. newline='':打开 CSV 文件时添加该参数,避免 Windows 系统下写入的 CSV 出现空行。

主函数(推理入口)

if __name__ == '__main__':
    # 加载配置参数(从config.py读取:ckpt_file、test_output_csv等)
    args = parse_args()
    # 启动推理
    inference(args)
  • 作用:极简的入口函数,仅加载参数 + 启动推理,符合推理脚本 “轻量、高效” 的需求。
核心注意事项(避坑点)
  1. 模型评估模式model.eval()必须调用,否则 dropout/bn 层会随机改变输出,导致生成结果不稳定;
  2. 梯度禁用with torch.no_grad()是推理阶段的 “标配”,能节省 50% 以上的显存,大幅提升推理速度;
  3. 设备匹配model.to('cuda:0')to_device(source, 'cuda:0')必须指定相同的 GPU(如都是 cuda:0),否则报 “设备不匹配” 错误;
  4. pred [i][2:] 的合理性:需确认模型输出的前 2 个 token 确实是特殊标记(如 SOS/CLS),若特殊标记只有 1 个,需改为pred[i][1:],否则会截断有效诊断内容;
  5. CSV 写入编码:若生成的诊断文本包含中文,需指定编码(如open(..., encoding='utf-8')),否则会出现乱码;
  6. 权重加载strict=False虽能兼容参数不匹配,但需确保核心层(如 encoder/decoder)的参数完全匹配,否则模型效果会大幅下降。
步骤 3.2:模型评估(score.py)
  • 核心目标:量化评估生成诊断的质量
  • 评估指标:
    1. 交叉熵损失(CE):训练过程中监控模型拟合程度,越低表示预测越准
    2. CIDEr-D:文本生成质量评估指标,基于 n-gram 相似度计算,越高表示生成文本与真实诊断越一致
  • 评估逻辑:
    • 训练 / 验证阶段:每个 epoch 计算验证集 CE 损失和 CIDEr-D 分数
    • 推理后:可对比生成结果与真实诊断的 CIDEr-D 分数,评估模型效果
1. 交叉熵损失函数 CE
def CE(output, target):
    '''
    Output: (B,L,C)。未经过softmax的logits → B=批次大小,L=序列长度,C=词表大小
    Target: (B,L) → 真实标签的ID序列
    '''
    # 展平输出:(B*L, C) → 交叉熵损失要求输入为(样本数, 类别数)
    output = output.reshape(-1, output.shape[-1])  # (*,C)
    # 展平标签:(B*L,) → 转为long类型(交叉熵要求标签是整型)
    target = target.reshape(-1).long()  # (*)
    # 计算交叉熵损失(默认reduction='mean',对所有B*L个token的损失取平均)
    loss =  nn.CrossEntropyLoss()(output, target)
    return loss
  • 核心作用:用于生成模型的训练,衡量模型预测的 token 分布与真实 token 的差异;
  • 关键细节
    • output是模型输出的 logits(未经过 softmax),nn.CrossEntropyLoss内部会自动做 softmax;
    • 展平操作reshape(-1)是为了适配nn.CrossEntropyLoss的输入格式;
    • target.long():确保标签是整型(PyTorch 交叉熵不支持浮点型标签)。
2. 基础工具函数(ngram 统计)

这部分是 CIDEr-D 的底层逻辑,核心是统计文本的 n-gram(1~4 元语法)频次,为后续相似度计算做准备。

2.1 precook:统计单条文本的 n-gram 频次
def precook(s, n=4, out=False):
    """
    将单条文本转为n-gram频次字典
    :param s: string → 输入文本(如"剖宫产术后 出血?")
    :param n: int → 最大n-gram阶数(默认4,即统计1~4元语法)
    :return: counts → defaultdict,key=ngram元组,value=出现频次
    """
    words = s.split()  # 拆分文本为单词/词块列表(如["剖宫产术后", "出血?"])
    counts = defaultdict(int)
    # 遍历1~n元语法
    for k in range(1, n + 1):
        # 滑动窗口统计每个ngram的频次
        for i in range(len(words) - k + 1):
            ngram = tuple(words[i:i + k])  # 元组作为key(列表不可哈希)
            counts[ngram] += 1
    return counts
  • 示例:输入s="a b c",n=2 → 输出{(('a',),1), (('b',),1), (('c',),1), (('a','b'),1), (('b','c'),1)}
  • 核心作用:将文本转为 n-gram 频次表示,是 CIDEr-D 计算的基础。
2.2 cook_refs/cook_test:批量处理参考 / 测试文本
def cook_refs(refs, n=4):
    '''处理多条参考文本(如一张图片的多个描述),返回每条文本的ngram频次'''
    return [precook(ref, n) for ref in refs]

def cook_test(test, n=4):
    '''处理单条测试文本(生成的文本),返回其ngram频次'''
    return precook(test, n, True)
  • 区别cook_refs处理参考文本列表(如一个样本可能有多个参考标签),cook_test处理单条生成文本。
2.3 sim:计算 n-gram 向量的余弦相似度(带长度惩罚)
def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref, n=4, sigma=6.0):
    '''
    计算生成文本(hyp)与参考文本(ref)的n-gram向量余弦相似度,加入长度惩罚
    :param vec_hyp/ref: n-gram的tf-idf向量
    :param norm_hyp/ref: 向量的归一化系数
    :param length_hyp/ref: 文本长度
    :param sigma: 高斯惩罚的标准差
    :return: 各阶n-gram的相似度数组
    '''
    delta = float(length_hyp - length_ref)  # 长度差
    val = np.array([0.0 for _ in range(n)])  # 存储1~n阶n-gram的相似度
    
    for n in range(n):
        # 计算分子:生成文本与参考文本的ngram重叠数(取最小值相乘)
        for (ngram, count) in vec_hyp[n].items():
            val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
        # 余弦相似度:分子 / (hyp_norm * ref_norm)
        if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
            val[n] /= (norm_hyp[n] * norm_ref[n])
        # 长度惩罚:高斯函数,长度差越大惩罚越重
        val[n] *= np.e ** (-(delta ** 2) / (2 * sigma ** 2))
    return val
  • 核心逻辑
    1. 先计算 n-gram 向量的余弦相似度(衡量内容相似度);
    2. 用高斯函数惩罚长度差异(避免生成过短 / 过长的文本);
  • 医疗场景适配:诊断结论的长度相对固定,长度惩罚能避免模型生成过短(信息不全)或过长(冗余)的诊断。
3. 核心类 CiderScorer(CIDEr-D 计算逻辑)
class CiderScorer(object):
    """CIDEr分数计算的核心类,处理批量文本的n-gram统计、tf-idf计算、相似度求和"""

    def __init__(self, df_mode="corpus", test=None, refs=None, n=4, sigma=6.0):
        '''
        :param df_mode: 文档频率(df)计算方式 → "corpus"=基于当前数据集,或指定预计算的pkl文件
        :param n: n-gram阶数(默认4)
        :param sigma: 长度惩罚的标准差(默认6)
        '''
        self.n = n
        self.sigma = sigma
        self.crefs = []  # 参考文本的ngram频次列表
        self.ctest = []  # 生成文本的ngram频次列表
        self.df_mode = df_mode
        self.ref_len = None  # 参考文本的平均长度(log值)
        self.document_frequency = defaultdict(float)  # 文档频率(df):ngram在多少样本中出现
        
        # 加载预计算的df(若指定df_mode为文件路径)
        if self.df_mode != "corpus":
            pkl_file = cPickle.load(open(os.path.join(df_mode), 'rb'), 
                                   **(dict(encoding='latin1') if six.PY3 else {}))
            self.ref_len = np.log(float(pkl_file['ref_len']))
            self.document_frequency = pkl_file['document_frequency']
        
        self.cook_append(test, refs)  # 初始化时添加测试/参考文本
  • 核心方法拆解
    1. compute_doc_freq():统计参考文本中每个 ngram 的文档频率(df),用于计算 idf(逆文档频率);
    2. counts2vec():将 ngram 频次转为 tf-idf 向量(tf = 词频,idf=log (总样本数 /df)),并计算向量归一化系数;
    3. compute_cider():遍历所有样本,计算每条生成文本与参考文本的相似度,最终得到 CIDEr 分数;
    4. compute_score():主入口,先计算 df/idf,再调用compute_cider()得到最终分数(均值 + 逐样本分数)。
4. 对外接口类 CiderD
class CiderD:
    """CIDEr-D的对外封装类,简化调用流程"""
    def __init__(self, n=4, sigma=6.0, df="corpus"):
        self._n = n
        self._sigma = sigma
        self._df = df
        self.cider_scorer = CiderScorer(n=self._n, sigma=sigma, df_mode=self._df)

    def compute_score(self, gts, res):
        """
        主函数:计算CIDEr-D分数
        :param gts: dict → {image_id: [参考文本1, 参考文本2,...]}
        :param res: list → [{'image_id': id, 'caption': [生成文本]}]
        :return: score(整体均值), scores(逐样本分数)
        """
        tmp_cider_scorer = self.cider_scorer.copy_empty()
        tmp_cider_scorer.clear()
        # 遍历所有生成结果,添加到scorer中
        for res_id in res:
            hypo = res_id['caption']  # 生成文本
            ref = gts[res_id['image_id']]  # 对应的参考文本
            tmp_cider_scorer += (hypo[0], ref)
        # 计算最终分数
        (score, scores) = tmp_cider_scorer.compute_score()
        return score, scores
  • 核心作用:封装CiderScorer的复杂逻辑,提供简洁的compute_score接口,适配文本生成任务的输入格式;
  • 输入格式示例(医疗场景):
    # gts:真实诊断结论(参考文本)
    gts = {0: ["剖宫产术后 子宫出血"], 1: ["肺炎 双肺纹理增多"]}
    # res:模型生成的诊断结论
    res = [{'image_id':0, 'caption':["剖宫产术后 出血"]}, 
           {'image_id':1, 'caption':["肺炎 纹理增多"]}]
    # 计算分数
    ciderd = CiderD()
    score, scores = ciderd.compute_score(gts, res)

核心要点总结

  1. CE 损失:生成模型训练的核心损失,衡量 token 级别的预测误差,输入需展平为 (样本数,类别数);
  2. CIDEr-D 指标:文本生成任务的核心评估指标,通过 n-gram 的 tf-idf 相似度 + 长度惩罚,量化生成文本与参考文本的语义相似度,分数越高效果越好;
  3. 医疗场景适配
    • 需确保诊断文本的分词格式统一(如空格分隔);
    • sigma 建议调大(如 15),适配诊断结论的长度特性;
    • 可预计算医疗语料的 df 文件,提升评估速度。

CIDEr-D 的核心计算流程(可视化)

三、核心模块关键细节解析

1. 词表与编码映射逻辑

  • 核心规则:编码 ID = 词表行号(从 0 开始计数)
    • 例:编码 0→词表第 0 行→[PAD],编码 14→第 14 行→[unused14],编码 294→第 294 行→对应词表中该位置的 token(如多语言符号 / 医疗术语)
  • 特殊 token 说明:
    • [PAD](0):序列填充符;[UNK](100):未知词标记;[CLS]/[SEP]/[MASK]:MLM 任务专用;[SOS](101)/[EOS](105):生成任务开始 / 结束标记

2. 模型核心逻辑(models.py)

  • 预训练模型(preModel):基于 BART 的 MLM 任务,输入掩码序列,预测被掩码的 token
  • 微调模型(myModel):
    • 训练模式:接收 source 和 target,通过 BART 的 Encoder-Decoder 架构计算 logits,输出交叉熵损失
    • 推理模式:仅接收 source,通过 beam 搜索生成最优诊断序列(控制最大长度、长度惩罚、禁止重复)

3. 关键参数配置(config.py)

  • 数据相关:input_l=150(输入序列长度)、output_l=80(输出序列长度)、data_path(数据文件夹路径)
  • 训练相关:batch_size=2、max_epochs=50、learning_rate=3e-5、warmup_steps=1000
  • 生成相关:beam=5、length_penalty=1、no_repeat=4(避免重复生成 4-gram)

四、项目执行顺序(快速上手指南)

  1. 环境搭建:创建 Python 3.10 环境,安装 requirements.txt 中的依赖包
  2. 数据准备:
    • 放置原始数据train.csv到 data 文件夹
    • 运行process_data.py拆分训练 / 验证集
    • 运行pro_vocab.py生成自定义词表
  3. 模型训练:
    • 运行pretrain.py完成 MLM 预训练(可选,提升效果)
    • 运行finetine.py完成生成任务微调(核心步骤)
  4. 推理评估:
    • 运行inference.py生成测试集诊断结果
    • 基于score.py的 CIDEr-D 指标评估生成质量
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐