Datawhale 大模型算法全栈基础篇 202602第6次笔记
NER(named_entity_recognition)命名实体识别
完成了一个完整的 NER 项目:
-
数据处理:构建了标签映射
categories.json和词汇表vocabulary.json。 -
模型构建:实现了基于双向 GRU 的序列标注模型。
-
训练封装:使用 Trainer 类高效训练,早停机制防止过拟合。
-
预测部署:预测脚本能对新文本进行实时实体抽取,成功。
笔记:
第一节 命名实体识别
一、什么是命名实体识别(NER)?
想象一下,你读了一句话:“姚明昨天在北京打篮球。”
普通人一眼就能看出:“姚明”是个人,“北京”是个地方。
命名实体识别就是让电脑也学会这种能力——从文字中找出谁(人名)、哪里(地名)、什么机构(公司名)、什么产品等关键词语,并给它们贴上正确的标签。
常见的实体类型:
-
人名(PER):姚明、雷军
-
地名(LOC):北京、故宫
-
机构名(ORG):阿里巴巴、英伟达
-
产品名(PROD):黑神话:悟空、iPhone
-
时间(TIME):昨天、2025年
一个词是不是实体,取决于场景。比如“苹果”:
-
在水果摊前 → 不是实体(普通名词)
-
在手机讨论中 → 是组织机构名(Apple公司)
二、NER有什么用?
1. 智能搜索
你在百度搜“姚明的身高”,电脑会:
-
识别出意图:“属性查询”
-
抽取出实体:姚明(主体)和身高(属性)
-
然后从知识库里找到答案“2.26米”,直接显示在搜索结果最上方。
2. 聊天机器人
你在淘宝问:“我的快递到哪了?”
机器人通过NER认出你关心的是“物流信息”,于是自动帮你查快递状态,而不是傻傻地问你“什么快递?”
3. 医疗辅助诊断
医生输入:“病人多饮、多食,血糖检测偏高。”
NER系统可以抽取出:
-
症状:多饮、多食
-
检查:血糖检测
结合知识库,可能提示医生:“怀疑II型糖尿病,建议进一步检查。”
三、怎么教电脑做NER?
方法1:查字典 + 规则(最原始)
-
准备一个巨大的词典,比如《中国地名词典》,里面收录所有地名。
-
在文本中一个个词去比对,如果匹配上了,就标为地名。
-
优点:简单、快。
-
缺点:词典里没有的新词(比如新上市的公司)就认不出来,而且维护词典很累。
方法2:序列标注(目前最常用)
把NER变成一个“给每个字贴标签”的任务。比如句子:
text
西 安 的 大 雁 塔 门 票 多 少 钱
我们用标签告诉电脑:
-
西(B-LOC,地点开头)、安(E-LOC,地点结尾)
-
大(B-LOC)、雁(M-LOC)、塔(E-LOC)
-
门(B-ATTR,属性开头)、票(E-ATTR,属性结尾)
这样电脑就能学会“西安”和“大雁塔”都是地点,“门票”是属性。
常用标签体系:
-
BIO:B(开始)、I(中间)、O(外部)
-
BMES:B(开始)、M(中间)、E(结尾)、S(单字实体)
模型结构:
通常用BERT这样的预训练模型来提取每个字的特征,然后接一个分类层,给每个字打标签。
方法3:指针网络 + 片段网络(解决嵌套问题)
有时实体是嵌套的,比如:
他 就 读 于 北 京 大 学
“北京”是地名,“北京大学”是机构名。用普通的序列标注很难同时识别这两个,因为一个字只能有一个标签。
指针网络的做法:
-
为每个字训练两个二分类器:判断它是不是某个实体的开头,以及是不是某个实体的结尾。
-
比如:“北”是地名开头,也是机构名开头;“京”是地名结尾;“学”是机构名结尾。
-
然后后处理把“开头-结尾”配对,得到候选实体:北京、北京大学、京大(错误的会被过滤)。
片段网络则是对所有可能的片段(比如“北京”、“京大”、“北京大学”)做分类,判断每个片段属于哪类实体。两者结合,既能解决嵌套,又不会计算量太大。
方法4:用大模型生成(最新潮流)
直接把任务写成提示词,让大模型自己输出结构化结果。例如:
输入:请从句子“姚明昨天在北京打篮球”中抽取出人名、地名。
输出:人名:姚明;地名:北京
这种方法不需要训练,但大模型有时会“幻觉”,输出原文中没有的东西。
第二节 NER项目的数据处理
一、我们要做什么?
让电脑学会命名实体识别(NER),就像教它玩一个“找词贴标签”的游戏。但在玩游戏之前,我们需要把“游戏规则”和“道具”准备好。数据处理就是做这件事——把原始的文字数据,变成电脑能看懂的数字。
最终目标
-
给电脑的输入(X):是一个数字表格,形状为
[批次大小, 句子长度]。表格里的每个数字,代表句子中每个字在词汇表里的编号。 -
给电脑的正确答案(Y):也是一个同样形状的数字表格,每个数字代表对应字的实体标签编号(比如“B-dis”的编号是几)。
-
两本“字典”:
-
词汇表:把每个字映射成一个唯一的数字ID。
-
标签映射表:把每个实体标签(如“B-dis”)也映射成一个唯一的数字ID。
-
二、第一步:构建标签映射表
我们的数据是中文医学实体数据集(CMeEE-V2),每条数据包含一段文本和里面标注的实体。比如:
text
文本:(2)室上性心动过速可用常规抗心律失常药物控制,年龄小于5岁。 实体:["室上性心动过速"是“疾病(dis)”,"抗心律失常药物"是“药物(dru)”]
2.1 数据长什么样?
数据是JSON格式,里面是一个大列表,每个元素是一个字典:
json
{
"text": "(2)室上性心动过速可用常规抗心律失常药物控制,年龄小于5岁。",
"entities": [
{"start_idx": 3, "end_idx": 9, "type": "dis", "entity": "室上性心动过速"},
{"start_idx": 14, "end_idx": 20, "type": "dru", "entity": "抗心律失常药物"}
]
}
-
start_idx和end_idx是实体在文本中的起始和结束位置(注意:是包含两端的闭区间)。 -
type是实体类型,比如dis(疾病)、dru(药物)。
2.2 提取所有实体类型
我们需要先找出数据里所有可能的实体类型。比如这个数据集里有9种类型:dru(药物)、dis(疾病)、bod(身体部位)、sym(症状)等。
2.3 用BMES方案给标签编码
我们给每个字贴标签时,不能只贴一个类型,还要说明它在实体里的位置。所以用BMES方案:
-
B(Begin):实体开头
-
M(Middle):实体中间
-
E(End):实体结尾
-
S(Single):单字实体
-
O(Outside):不是实体
比如对于“室上性心动过速”这个疾病实体,标签序列是:B-dis M-dis M-dis M-dis E-dis。
最后生成的标签映射表(categories.json)就像:
json
{
"O": 0,
"B-dis": 1,
"M-dis": 2,
"E-dis": 3,
"S-dis": 4,
"B-dru": 5,
"M-dru": 6,
"E-dru": 7,
"S-dru": 8,
...
}
“O”的编号是0,其他依次递增。
三、第二步:构建词汇表
现在需要把文本里的每个字,也变成一个数字ID。就像查字典一样。
3.1 统计所有出现的字
遍历所有文本,数一数每个字出现了多少次。
3.2 规范化文本(统一全角/半角)
你会发现数据里既有全角字符(如,、()又有半角字符(如,、()。它们在电脑眼里是两个不同的字,但对我们来说意思一样。所以把它们统一转换成半角,减少词汇表大小。
3.3 过滤低频字
可以设置一个阈值,比如只保留出现次数≥2的字,去掉那些只出现一次的罕见字。这样可以进一步精简词汇表。
3.4 添加特殊标记
最后在词汇表最前面加上两个特殊标记:
-
<PAD>:用于把不同长度的句子填充到一样长。 -
<UNK>:当遇到词汇表里没有的字时,就用它代替。
生成的词汇表(vocabulary.json)是一个列表:
json
[
"<PAD>",
"<UNK>",
"一",
"了",
"人",
"他",
...
]
四、第三步:封装数据加载器
现在有了两本字典(词汇表和标签映射表),我们需要一个“数据流水线”,能自动把原始数据变成模型能吃的批次。
4.1 Vocabulary类——查字典
创建一个类,加载词汇表,并提供“字→ID”的转换功能。
4.2 NerDataset类——加工单条数据
继承PyTorch的Dataset,对每条数据做:
-
把文本转成字符列表。
-
用Vocabulary把字符转成
token_ids。 -
初始化一个全是“O”的标签序列。
-
根据实体标注,把对应位置的“O”替换成正确的BMES标签。
-
用标签映射表把标签字符串转成
label_ids。 -
返回包含
token_ids和label_ids的字典。
4.3 collate_fn函数——打包成一个批次
因为每个句子长度不同,不能直接打包。collate_fn做三件事:
-
动态填充:找出批次里最长的句子,把其他句子填充到一样长(用
<PAD>的ID填充)。 -
生成注意力掩码:创建一个和输入一样形状的矩阵,真实位置是1,填充位置是0,告诉模型哪里是真正的字。
-
处理标签填充:对标签也用同样的方式填充,但填充值设为
-100(PyTorch损失函数会自动忽略这个值)。
4.4 整合成DataLoader
把所有组件组合起来,得到一个能自动产出批次的DataLoader。
五、代码
01_build_category.py
import json
import os
def save_json(data, file_path):
"""
将数据以格式化的 JSON 形式保存到文件
"""
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
def collect_entity_types_from_file(file_path):
"""
从单个数据文件中提取所有唯一的实体类型
"""
types = set()
with open(file_path, 'r', encoding='utf-8') as f:
all_data = json.load(f)
for data in all_data:
# 遍历实体列表,提取 'type' 字段
for entity in data['entities']:
types.add(entity['type'])
return types
def generate_tag_map(data_files, output_file):
"""
从数据文件构建 BMES 标签映射并保存
"""
# 1. 从所有文件中收集实体类型
all_entity_types = set()
for file_path in data_files:
all_entity_types.update(collect_entity_types_from_file(file_path))
# 2. 排序以保证映射一致性
sorted_types = sorted(list(all_entity_types))
print(f"发现的实体类型: {sorted_types}")
# 3. 构建 BMES 标签映射
tag_to_id = {'O': 0} # 'O' 代表非实体
for entity_type in sorted_types:
for prefix in ['B', 'M', 'E', 'S']:
tag_name = f"{prefix}-{entity_type}"
tag_to_id[tag_name] = len(tag_to_id)
print(f"\n已生成 {len(tag_to_id)} 个标签映射。")
# 4. 保存映射文件
save_json(tag_to_id, output_file)
print(f"标签映射已保存至: {output_file}")
if __name__ == '__main__':
# 定义输入的数据文件和期望的输出路径
train_file = 'E:/Datawhale 2026/base-llm202602/06_homework/CMeEE-V2_train.json'
dev_file = 'E:/Datawhale 2026/base-llm202602/06_homework/CMeEE-V2_dev.json'
output_path = './categories.json'
generate_tag_map(data_files=[train_file, dev_file], output_file=output_path)
输出:
(base) PS E:\Datawhale 2026\base-llm202602\06_homework> & D:/Users/app/miniconda3/envs/base-llm/python.exe "e:/Datawhale 2026/base-llm202602/06_homework/01_build_category.py"
发现的实体类型: ['bod', 'dep', 'dis', 'dru', 'equ', 'ite', 'mic', 'pro', 'sym']
已生成 37 个标签映射。
标签映射已保存至: ./categories.json
02_build_vocabulary.py
import json
import os
from collections import Counter
def save_json(data, file_path):
"""
将数据以易于阅读的格式保存为 JSON 文件
"""
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
def normalize_text(text):
"""
规范化文本
"""
full_width = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!#$%&’()*+,-./:;<=>?@[\]^_`{|}~""
half_width = r"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!#$%&'" + r'()*+,-./:;<=>?@[\]^_`{|}~".'
mapping = str.maketrans(full_width, half_width)
return text.translate(mapping)
def create_char_vocab(data_files, output_file, min_freq=1):
"""
从数据文件创建字符级词汇表
"""
char_counts = Counter()
for file_path in data_files:
with open(file_path, 'r', encoding='utf-8') as f:
all_data = json.load(f)
for data in all_data:
text = normalize_text(data['text'])
char_counts.update(list(text))
# 过滤低频词
frequent_chars = [char for char, count in char_counts.items() if count >= min_freq]
# 保证每次生成结果一致
frequent_chars.sort()
# 添加特殊标记
special_tokens = ["<PAD>", "<UNK>"]
final_vocab_list = special_tokens + frequent_chars
print(f"词汇表大小 (min_freq={min_freq}): {len(final_vocab_list)}")
# 保存词汇表
save_json(final_vocab_list, output_file)
print(f"词汇表已保存至: {output_file}")
if __name__ == '__main__':
# 定义输入的数据文件和期望的输出路径
train_file = 'E:/Datawhale 2026/base-llm202602/06_homework/CMeEE-V2_train.json'
dev_file = 'E:/Datawhale 2026/base-llm202602/06_homework/CMeEE-V2_dev.json'
output_path = './vocabulary.json'
# 设置字符最低频率,1表示包含所有出现过的字符
create_char_vocab(data_files=[train_file, dev_file], output_file=output_path, min_freq=1)
输出:
(base) PS E:\Datawhale 2026\base-llm202602\06_homework> & D:/Users/app/miniconda3/envs/base-llm/python.exe "e:/Datawhale 2026/base-llm202602/06_homework/02_build_vocabulary.py"
词汇表大小 (min_freq=1): 2970
词汇表已保存至: ./vocabulary.json
03_data_loader.py
import json
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
def normalize_text(text):
"""
规范化文本,例如将全角字符转换为半角字符。
"""
full_width = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!#$%&’()*+,-./:;<=>?@[\]^_`{|}~""
half_width = r"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!#$%&'" + r'()*+,-./:;<=>?@[\]^_`{|}~".'
mapping = str.maketrans(full_width, half_width)
return text.translate(mapping)
class Vocabulary:
"""
负责管理词汇表和 token 到 id 的映射。
"""
def __init__(self, vocab_path):
with open(vocab_path, 'r', encoding='utf-8') as f:
self.tokens = json.load(f)
self.token_to_id = {token: i for i, token in enumerate(self.tokens)}
self.pad_id = self.token_to_id['<PAD>']
self.unk_id = self.token_to_id['<UNK>']
def __len__(self):
return len(self.tokens)
def convert_tokens_to_ids(self, tokens):
return [self.token_to_id.get(token, self.unk_id) for token in tokens]
class NerDataset(Dataset):
"""
处理 NER 数据,并将其转换为适用于 PyTorch 模型的格式。
"""
def __init__(self, data_path, vocab: Vocabulary, tag_map: dict):
self.vocab = vocab
self.tag_to_id = tag_map
with open(data_path, 'r', encoding='utf-8') as f:
self.records = json.load(f)
def __len__(self):
return len(self.records)
def __getitem__(self, idx):
record = self.records[idx]
text = normalize_text(record['text'])
tokens = list(text)
# 将文本 tokens 转换为 ids
token_ids = self.vocab.convert_tokens_to_ids(tokens)
# 初始化标签序列为 'O'
tags = ['O'] * len(tokens)
for entity in record.get('entities', []):
entity_type = entity['type']
start = entity['start_idx']
end = entity['end_idx']
if end >= len(tokens): continue
if start == end:
tags[start] = f'S-{entity_type}'
else:
tags[start] = f'B-{entity_type}'
tags[end] = f'E-{entity_type}'
for i in range(start + 1, end):
tags[i] = f'M-{entity_type}'
# 将标签转换为 ids
label_ids = [self.tag_to_id[tag] for tag in tags]
return {
"token_ids": torch.tensor(token_ids, dtype=torch.long),
"label_ids": torch.tensor(label_ids, dtype=torch.long)
}
def create_ner_dataloader(data_path, vocab, tag_map, batch_size, shuffle=False):
"""
创建 NER 任务的 DataLoader。
"""
dataset = NerDataset(data_path, vocab, tag_map)
def collate_batch(batch):
token_ids_list = [item['token_ids'] for item in batch]
label_ids_list = [item['label_ids'] for item in batch]
padded_token_ids = pad_sequence(token_ids_list, batch_first=True, padding_value=vocab.pad_id)
padded_label_ids = pad_sequence(label_ids_list, batch_first=True, padding_value=-100) # -100 用于在计算损失时忽略填充部分
attention_mask = (padded_token_ids != vocab.pad_id).long()
return {
"token_ids": padded_token_ids,
"label_ids": padded_label_ids,
"attention_mask": attention_mask
}
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_batch)
if __name__ == '__main__':
# 文件路径
# 定义输入的数据文件和期望的输出路径
train_file = 'E:/Datawhale 2026/base-llm202602/06_homework/CMeEE-V2_train.json'
dev_file = 'E:/Datawhale 2026/base-llm202602/06_homework/CMeEE-V2_dev.json'
vocab_file = './vocabulary.json'
categories_file = './categories.json'
# 1. 加载词汇表和标签映射
vocabulary = Vocabulary(vocab_path=vocab_file)
with open(categories_file, 'r', encoding='utf-8') as f:
tag_map = json.load(f)
print("词汇表和标签映射加载完成。")
# 2. 创建 DataLoader
train_loader = create_ner_dataloader(
data_path=train_file,
vocab=vocabulary,
tag_map=tag_map,
batch_size=4,
shuffle=True
)
print("DataLoader 创建完成。")
# 3. 验证一个批次的数据
print("\n--- 验证一个批次的数据 ---")
batch = next(iter(train_loader))
print(f" Token IDs (shape): {batch['token_ids'].shape}")
print(f" Label IDs (shape): {batch['label_ids'].shape}")
print(f" Attention Mask (shape): {batch['attention_mask'].shape}")
print(f" Token IDs (sample): {batch['token_ids'][0][:20]}...")
print(f" Label IDs (sample): {batch['label_ids'][0][:20]}...")
print(f" Attention Mask (sample): {batch['attention_mask'][0][:20]}...")
输出:
(base) PS E:\Datawhale 2026\base-llm202602\06_homework> & D:/Users/app/miniconda3/envs/base-llm/python.exe "e:/Datawhale 2026/base-llm202602/06_homework/03_data_loader.py"
词汇表和标签映射加载完成。
DataLoader 创建完成。
--- 验证一个批次的数据 ---
Token IDs (shape): torch.Size([4, 60])
Label IDs (shape): torch.Size([4, 60])
Attention Mask (shape): torch.Size([4, 60])
Token IDs (sample): tensor([ 342, 1793, 2295, 340, 1410, 1374, 1804, 1268, 1321, 1262, 38, 65,
575, 2420, 283, 1507, 1880, 147, 0, 0])...
Label IDs (sample): tensor([ 33, 34, 34, 34, 34, 34, 34, 34, 34, 34, 21, 23,
34, 21, 23, 34, 35, 0, -100, -100])...
Attention Mask (sample): tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0])...
📊 输出解读
词汇表和标签映射加载完成。 DataLoader 创建完成。 --- 验证一个批次的数据 --- Token IDs (shape): torch.Size([4, 60]) Label IDs (shape): torch.Size([4, 60]) Attention Mask (shape): torch.Size([4, 60])
-
Token IDs:一个批次有 4 条样本,每条样本被填充/截断为长度 60。每个数字代表一个字符在词汇表中的 ID。
-
Label IDs:同样形状,每个数字代表该字符对应的实体标签 ID(如 0 代表
O,1 代表B-bod等)。注意填充位置的标签是-100,在计算损失时会被忽略。 -
Attention Mask:1 表示真实 token,0 表示填充位置,模型在计算注意力时会忽略填充部分。
样例数据解释:
Token IDs (sample): tensor([ 342, 1793, 2295, 340, 1410, 1374, 1804, 1268, 1321, 1262, 38, 65,
575, 2420, 283, 1507, 1880, 147, 0, 0])...
-
这是一条样本的前 20 个 token。末尾的
0是<PAD>的 ID,表示填充部分。
Label IDs (sample): tensor([ 33, 34, 34, 34, 34, 34, 34, 34, 34, 34, 21, 23,
34, 21, 23, 34, 35, 0, -100, -100])...
-
对应的标签 ID:非零值代表实体标签(如
33可能是某个实体的B-xxx,34是M-xxx,35是E-xxx等)。最后的-100是填充位置的标签(损失计算时会忽略)。
Attention Mask (sample): tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0])...
-
前 18 个位置是真实 token(mask=1),后 2 个是填充(mask=0)。
这些结果完全符合预期,说明数据处理流程已经正确实现,数据可以被模型直接使用了!
✅ 已完成的部分
-
构建了标签映射表
categories.json -
构建了字符级词汇表
vocabulary.json -
创建了可批量加载数据的
DataLoader,并验证成功
04_model.py
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
class BiGRUNerNetWork(nn.Module):
def __init__(self, vocab_size, hidden_size, num_tags, num_gru_layers=1):
super().__init__()
# 1. Token Embedding 层
self.embedding = nn.Embedding(vocab_size, hidden_size)
# 2. 使用 ModuleList 构建多层双向 GRU
self.gru_layers = nn.ModuleList()
for _ in range(num_gru_layers):
self.gru_layers.append(
nn.GRU(
input_size=hidden_size,
hidden_size=hidden_size,
num_layers=1,
batch_first=True,
bidirectional=True # 开启双向
)
)
# 3. 特征融合层
self.fc = nn.Linear(hidden_size * 2, hidden_size)
# 4. 分类决策层 (Classifier)
self.classifier = nn.Linear(hidden_size, num_tags)
def forward(self, token_ids, attention_mask):
# 1. 计算真实长度
lengths = attention_mask.sum(dim=1).cpu()
# 2. 获取词向量
embedded_text = self.embedding(token_ids)
# 3. 打包序列
current_packed_input = rnn.pack_padded_sequence(
embedded_text, lengths, batch_first=True, enforce_sorted=False
)
# 4. 循环通过 GRU 层
for gru_layer in self.gru_layers:
# GRU 输出 (packed)
packed_output, _ = gru_layer(current_packed_input)
# 解包以进行后续操作,并指定 total_length
output, _ = rnn.pad_packed_sequence(
packed_output, batch_first=True, total_length=token_ids.shape[1]
)
# 特征融合
features = self.fc(output)
# 残差连接
# 同样需要解包上一层的输入
input_padded, _ = rnn.pad_packed_sequence(
current_packed_input, batch_first=True, total_length=token_ids.shape[1]
)
current_input = features + input_padded
# 重新打包作为下一层的输入
current_packed_input = rnn.pack_padded_sequence(
current_input, lengths, batch_first=True, enforce_sorted=False
)
# 5. 解包最终输出用于分类
final_output, _ = rnn.pad_packed_sequence(
current_packed_input, batch_first=True, total_length=token_ids.shape[1]
)
# 6. 分类
logits = self.classifier(final_output)
return logits
if __name__ == '__main__':
token_ids = torch.tensor([
[210, 18, 871, 147, 0, 0, 0, 0],
[922, 2962, 842, 210, 18, 871, 147, 0]
], dtype=torch.int64)
# attention_mask 标记哪些是真实 token (1) 哪些是填充 (0)
attention_mask = torch.tensor([
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 0]
], dtype=torch.int64)
label_ids = torch.tensor([
[0, 0, 0, 0, -100, -100, -100, -100],
[0, 0, 0, 0, 0, 0, 0, -100]
], dtype=torch.int64)
# 实例化模型
model = BiGRUNerNetWork(
vocab_size=10000,
hidden_size=128,
num_tags=37,
num_gru_layers=2
)
# 3. 执行前向传播
logits = model(token_ids=token_ids, attention_mask=attention_mask)
# 4. 构造损失函数
loss_fn = nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
# 5. 计算损失
# CrossEntropyLoss 要求类别维度在前,所以需要交换最后两个维度
# [batch, seq_len, num_tags] -> [batch, num_tags, seq_len]
permuted_logits = torch.permute(logits, dims=(0, 2, 1))
loss = loss_fn(permuted_logits, label_ids)
# 6. 打印结果
print(f"Logits shape: {logits.shape}")
print(f"Loss shape: {loss.shape}")
print("\n每个 Token 的损失:")
print(loss)
05_train.py
pip install tensorboard -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
import os
import torch
import torch.nn as nn
import sys
from dataclasses import asdict
# 导入定义的所有组件
from src.configs.configs import config
from src.data.data_loader import create_ner_dataloader
from src.tokenizer.vocabulary import Vocabulary
from src.tokenizer.char_tokenizer import CharTokenizer
from src.models.ner_model import BiGRUNerNetWork
from src.loss.ner_loss import NerLoss
from src.trainer.trainer import Trainer
from src.utils.file_io import load_json, save_json
from src.metrics.entity_metrics import calculate_entity_level_metrics
def seed_everything(seed: int = 42):
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def main():
"""
主函数,负责组装所有组件并启动NER训练任务。
"""
# --- 0. 设置随机数种子---
seed_everything(getattr(config, 'seed', 42))
# --- 1. 加载词汇表和标签映射, 并创建分词器 ---
vocab_path = os.path.join(config.data_dir, config.vocab_file)
tags_path = os.path.join(config.data_dir, config.tags_file)
train_path = os.path.join(config.data_dir, config.train_file)
dev_path = os.path.join(config.data_dir, config.dev_file)
vocab = Vocabulary.load_from_file(vocab_path)
tokenizer = CharTokenizer(vocab)
tag_map = load_json(tags_path)
id2tag = {v: k for k, v in tag_map.items()}
# --- 2. 创建数据加载器 ---
train_loader = create_ner_dataloader(
data_path=train_path,
tokenizer=tokenizer,
tag_map=tag_map,
batch_size=config.batch_size,
shuffle=True,
device=config.device
)
dev_loader = create_ner_dataloader(
data_path=dev_path,
tokenizer=tokenizer,
tag_map=tag_map,
batch_size=config.batch_size,
shuffle=False,
device=config.device
)
# --- 3. 初始化模型、优化器、损失函数 ---
model = BiGRUNerNetWork(
vocab_size=len(vocab),
hidden_size=config.hidden_size,
num_tags=len(tag_map),
num_gru_layers=config.num_gru_layers
)
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
# 根据配置选择损失函数
if config.loss_type == "cross_entropy":
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
else:
loss_fn = NerLoss(
loss_type=config.loss_type,
entity_weight=config.entity_loss_weight,
hard_negative_ratio=config.hard_negative_ratio
)
# --- 4. 定义评估函数 ---
def eval_metric_fn(all_logits, all_labels, all_attention_mask):
all_preds_ids = [torch.argmax(logits, dim=-1) for logits in all_logits]
all_labels_cpu = [labels.cpu() for labels in all_labels]
all_preds_ids_cpu = [preds.cpu() for preds in all_preds_ids]
all_attention_mask_cpu = [mask.cpu() for mask in all_attention_mask]
active_masks = [mask.bool() for mask in all_attention_mask_cpu]
# 计算基于 mask 的 token 级准确率
total_equal_tokens = 0
total_effective_tokens = 0
for preds, labels, mask in zip(all_preds_ids_cpu, all_labels_cpu, active_masks):
# preds/labels/mask: [B, T]
equal = (preds == labels) & mask
total_equal_tokens += int(equal.sum().item())
total_effective_tokens += int(mask.sum().item())
token_acc = (total_equal_tokens / total_effective_tokens) if total_effective_tokens > 0 else 0.0
metrics = calculate_entity_level_metrics(
all_preds_ids_cpu,
all_labels_cpu,
active_masks,
id2tag
)
metrics['token_acc'] = token_acc
return metrics
# --- 5. 初始化并启动训练器 ---
# 在初始化 Trainer 前,检查检查点文件是否存在
if config.resume_checkpoint and not os.path.exists(config.resume_checkpoint):
print(f"Checkpoint file not found: {config.resume_checkpoint}. Starting training from scratch.")
config.resume_checkpoint = None # 设为 None, 避免 Trainer 报错
trainer = Trainer(
model=model,
optimizer=optimizer,
loss_fn=loss_fn,
train_loader=train_loader,
dev_loader=dev_loader,
eval_metric_fn=eval_metric_fn,
output_dir=config.output_dir,
device=config.device,
# 新增参数
summary_writer_dir=config.output_summary_dir,
early_stopping_patience=config.early_stopping_patience,
resume_checkpoint=config.resume_checkpoint
)
# 在训练开始前,保存配置文件
os.makedirs(config.output_dir, exist_ok=True)
save_json(asdict(config), os.path.join(config.output_dir, "config.json"))
print(f"Configuration saved to {os.path.join(config.output_dir, 'config.json')}")
trainer.fit(epochs=config.epochs)
if __name__ == "__main__":
main()
输出:
(base-llm) PS E:\Datawhale 2026\base-llm202602\06_homework> & D:/Users/app/miniconda3/envs/base-llm/python.exe "e:/Datawhale 2026/base-llm202602/06_homework/05_train.py"
Trainer will run on device: cuda
Configuration saved to output\config.json
--- Epoch 1/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:38<00:00, 12.11it/s]
Train Metrics: Total Loss: 1.2551
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 50.77it/s]
Validation Metrics: precision: 0.6491, recall: 0.3397, f1: 0.4460, token_acc: 0.7471, loss: 0.8552
New best model found! Saving to output
--- Epoch 2/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.06it/s]
Train Metrics: Total Loss: 0.7105
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.26it/s]
Validation Metrics: precision: 0.6991, recall: 0.4457, f1: 0.5444, token_acc: 0.7829, loss: 0.7056
New best model found! Saving to output
--- Epoch 3/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.21it/s]
Train Metrics: Total Loss: 0.5795
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 50.81it/s]
Validation Metrics: precision: 0.6987, recall: 0.4990, f1: 0.5822, token_acc: 0.7977, loss: 0.6493
New best model found! Saving to output
--- Epoch 4/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.14it/s]
Train Metrics: Total Loss: 0.5009
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.53it/s]
Validation Metrics: precision: 0.7107, recall: 0.4989, f1: 0.5863, token_acc: 0.8010, loss: 0.6356
New best model found! Saving to output
--- Epoch 5/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.10it/s]
Train Metrics: Total Loss: 0.4343
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.07it/s]
Validation Metrics: precision: 0.7036, recall: 0.5375, f1: 0.6095, token_acc: 0.8035, loss: 0.6252
New best model found! Saving to output
--- Epoch 6/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:26<00:00, 18.03it/s]
Train Metrics: Total Loss: 0.3782
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.34it/s]
Validation Metrics: precision: 0.7031, recall: 0.5380, f1: 0.6096, token_acc: 0.8031, loss: 0.6481
New best model found! Saving to output
--- Epoch 7/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.16it/s]
Train Metrics: Total Loss: 0.3447
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.30it/s]
Validation Metrics: precision: 0.7077, recall: 0.5311, f1: 0.6068, token_acc: 0.8023, loss: 0.6710
EarlyStopping counter: 1 out of 5
--- Epoch 8/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.11it/s]
Train Metrics: Total Loss: 0.2886
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 50.98it/s]
Validation Metrics: precision: 0.7001, recall: 0.5646, f1: 0.6251, token_acc: 0.8076, loss: 0.6812
New best model found! Saving to output
--- Epoch 9/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.17it/s]
Train Metrics: Total Loss: 0.2451
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.65it/s]
Validation Metrics: precision: 0.7005, recall: 0.5552, f1: 0.6194, token_acc: 0.8057, loss: 0.7110
EarlyStopping counter: 1 out of 5
--- Epoch 10/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.18it/s]
Train Metrics: Total Loss: 0.2043
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.29it/s]
Validation Metrics: precision: 0.6868, recall: 0.5493, f1: 0.6104, token_acc: 0.8011, loss: 0.7942
EarlyStopping counter: 2 out of 5
--- Epoch 11/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.12it/s]
Train Metrics: Total Loss: 0.1726
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.55it/s]
Validation Metrics: precision: 0.6831, recall: 0.5558, f1: 0.6129, token_acc: 0.8008, loss: 0.8374
EarlyStopping counter: 3 out of 5
--- Epoch 12/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.25it/s]
Train Metrics: Total Loss: 0.1435
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.50it/s]
Validation Metrics: precision: 0.6852, recall: 0.5556, f1: 0.6136, token_acc: 0.7856, loss: 0.9739
EarlyStopping counter: 4 out of 5
--- Epoch 13/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.24it/s]
Train Metrics: Total Loss: 0.1540
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.69it/s]
Validation Metrics: precision: 0.6797, recall: 0.5601, f1: 0.6141, token_acc: 0.8036, loss: 0.9249
EarlyStopping counter: 5 out of 5
Early stopping triggered.
📊 训练结果解读
关键指标说明
-
precision (精确率):预测为实体的片段中,真正正确的比例。约 70%。
-
recall (召回率):真实存在的实体中,被模型找出来的比例。约 56%。
-
f1 (F1值):精确率和召回率的调和平均,是评估实体识别模型的核心指标。最佳 62.5%。
-
token_acc (Token级别准确率):每个字分类的准确率,约 80%(这个值通常比实体F1高,因为大部分字是“O”非实体)。
-
loss:损失值,训练集损失逐渐降低,验证集损失在后期略有上升(可能轻微过拟合,早停及时停止了)。
这些指标表明模型已经学会识别部分医学实体,但仍有改进空间(例如调整学习率、增加训练数据、使用预训练模型等)。
python 06_predict.py --model_dir ./output --text "患者有高血压病史,需服用降压药。"
06_predict.py
import torch
import json
import os
import argparse
from src.models.ner_model import BiGRUNerNetWork
from src.tokenizer.vocabulary import Vocabulary
from src.tokenizer.char_tokenizer import CharTokenizer
from src.utils.file_io import load_json
class NerPredictor:
def __init__(self, model_dir, device='cpu'):
self.device = torch.device(device)
# --- 1. 加载配置文件以获取模型参数 ---
config_path = os.path.join(model_dir, 'config.json')
self.config = load_json(config_path)
# --- 2. 加载词汇表和标签映射 ---
vocab_path = os.path.join(self.config["data_dir"], self.config["vocab_file"])
tags_path = os.path.join(self.config["data_dir"], self.config["tags_file"])
self.vocab = Vocabulary.load_from_file(vocab_path)
self.tokenizer = CharTokenizer(self.vocab)
tag_map = load_json(tags_path)
self.id2tag = {v: k for k, v in tag_map.items()}
# --- 3. 初始化模型并加载权重 ---
self.model = BiGRUNerNetWork(
vocab_size=len(self.vocab),
hidden_size=self.config["hidden_size"],
num_tags=len(tag_map),
num_gru_layers=self.config["num_gru_layers"]
)
model_path = os.path.join(model_dir, 'best_model.pth')
self.model.load_state_dict(torch.load(model_path, map_location=self.device)['model_state_dict'])
self.model.to(self.device)
self.model.eval()
def predict(self, text):
tokens = self.tokenizer.text_to_tokens(text)
token_ids = self.tokenizer.tokens_to_ids(tokens)
# --- 预处理 ---
token_ids_tensor = torch.tensor([token_ids], dtype=torch.long).to(self.device)
attention_mask = torch.ones_like(token_ids_tensor)
# --- 模型预测 ---
with torch.no_grad():
logits = self.model(token_ids_tensor, attention_mask)
# --- 后处理 ---
predictions = torch.argmax(logits, dim=-1).squeeze(0)
tags = [self.id2tag[id_.item()] for id_ in predictions]
return self._extract_entities(tokens, tags)
def _extract_entities(self, tokens, tags):
entities = []
current_entity = None
for i, tag in enumerate(tags):
if tag.startswith('B-'):
# 如果前一个实体未正确结束,则放弃
if current_entity:
pass # 或者可以根据业务逻辑决定是否保存不完整的实体
current_entity = {"text": tokens[i], "type": tag[2:], "start": i}
elif tag.startswith('M-'):
# M 标签必须跟在 B- 或 M- 之后
if current_entity and current_entity["type"] == tag[2:]:
current_entity["text"] += tokens[i]
else:
# 非法 M 标签,重置当前实体
current_entity = None
elif tag.startswith('E-'):
# E 标签必须跟在 B- 或 M- 之后
if current_entity and current_entity["type"] == tag[2:]:
current_entity["text"] += tokens[i]
current_entity["end"] = i + 1
entities.append(current_entity)
# 实体已结束,重置
current_entity = None
elif tag.startswith('S-'):
# S 标签表示单个字符的实体
# 如果有未结束的实体,则放弃
current_entity = None
entities.append({"text": tokens[i], "type": tag[2:], "start": i, "end": i + 1})
else: # 'O' 标签
# O 标签意味着没有实体,或者实体已经结束
current_entity = None
# 循环结束后,不再处理任何未闭合的实体
return entities
def main():
parser = argparse.ArgumentParser(description="NER Prediction")
parser.add_argument("--model_dir", type=str, required=True, help="Directory of the saved model and config.")
parser.add_argument("--text", type=str, required=True, help="Text to predict.")
args = parser.parse_args()
predictor = NerPredictor(model_dir=args.model_dir)
entities = predictor.predict(args.text)
print(f"Text: {args.text}")
print(f"Entities: {json.dumps(entities, ensure_ascii=False, indent=2)}")
if __name__ == "__main__":
main()
输出:
(base-llm) PS E:\Datawhale 2026\base-llm202602\06_homework> python 06_predict.py --model_dir ./output --text "患者有高血压病史,需服用降压药。"
Text: 患者有高血压病史,需服用降压药。
Entities: [
{
"text": "高血压",
"type": "dis",
"start": 3,
"end": 6
},
{
"text": "降压药",
"type": "dru",
"start": 12,
"end": 15
}
]
解读:
NER 模型成功识别出了文本中的实体!🎉 输出结果完全正确,模型准确地找到了:
-
高血压 (类型
dis,疾病) 从第3个字符到第6个字符(注意:在字符串中,索引从0开始,所以3到6正好是“高血压”)。 -
降压药 (类型
dru,药物) 从第12到15个字符。
这表明训练的模型已经学会了从医学文本中抽取关键信息,这正是命名实体识别的核心任务。
更多推荐



所有评论(0)