NER(named_entity_recognition)命名实体识别

完成了一个完整的 NER 项目:

  1. 数据处理:构建了标签映射 categories.json 和词汇表 vocabulary.json

  2. 模型构建:实现了基于双向 GRU 的序列标注模型。

  3. 训练封装:使用 Trainer 类高效训练,早停机制防止过拟合。

  4. 预测部署:预测脚本能对新文本进行实时实体抽取,成功。

笔记:

第一节 命名实体识别

一、什么是命名实体识别(NER)?

想象一下,你读了一句话:“姚明昨天在北京打篮球。”
普通人一眼就能看出:“姚明”是个人,“北京”是个地方。
命名实体识别就是让电脑也学会这种能力——从文字中找出(人名)、哪里(地名)、什么机构(公司名)、什么产品等关键词语,并给它们贴上正确的标签。

常见的实体类型:

  • 人名(PER):姚明、雷军

  • 地名(LOC):北京、故宫

  • 机构名(ORG):阿里巴巴、英伟达

  • 产品名(PROD):黑神话:悟空、iPhone

  • 时间(TIME):昨天、2025年

一个词是不是实体,取决于场景。比如“苹果”:

  • 在水果摊前 → 不是实体(普通名词)

  • 在手机讨论中 → 是组织机构名(Apple公司)


二、NER有什么用?

1. 智能搜索

你在百度搜“姚明的身高”,电脑会:

  • 识别出意图:“属性查询”

  • 抽取出实体:姚明(主体)和身高(属性)

  • 然后从知识库里找到答案“2.26米”,直接显示在搜索结果最上方。

2. 聊天机器人

你在淘宝问:“我的快递到哪了?”
机器人通过NER认出你关心的是“物流信息”,于是自动帮你查快递状态,而不是傻傻地问你“什么快递?”

3. 医疗辅助诊断

医生输入:“病人多饮、多食,血糖检测偏高。”
NER系统可以抽取出:

  • 症状:多饮、多食

  • 检查:血糖检测
    结合知识库,可能提示医生:“怀疑II型糖尿病,建议进一步检查。”


三、怎么教电脑做NER?

方法1:查字典 + 规则(最原始)

  • 准备一个巨大的词典,比如《中国地名词典》,里面收录所有地名。

  • 在文本中一个个词去比对,如果匹配上了,就标为地名。

  • 优点:简单、快。

  • 缺点:词典里没有的新词(比如新上市的公司)就认不出来,而且维护词典很累。

方法2:序列标注(目前最常用)

把NER变成一个“给每个字贴标签”的任务。比如句子:

text

西  安  的  大  雁  塔  门  票  多  少  钱

我们用标签告诉电脑:

  • 西(B-LOC,地点开头)、安(E-LOC,地点结尾)

  • 大(B-LOC)、雁(M-LOC)、塔(E-LOC)

  • 门(B-ATTR,属性开头)、票(E-ATTR,属性结尾)

这样电脑就能学会“西安”和“大雁塔”都是地点,“门票”是属性。

常用标签体系

  • BIO:B(开始)、I(中间)、O(外部)

  • BMES:B(开始)、M(中间)、E(结尾)、S(单字实体)

模型结构
通常用BERT这样的预训练模型来提取每个字的特征,然后接一个分类层,给每个字打标签。

方法3:指针网络 + 片段网络(解决嵌套问题)

有时实体是嵌套的,比如:

他 就 读 于 北 京 大 学

“北京”是地名,“北京大学”是机构名。用普通的序列标注很难同时识别这两个,因为一个字只能有一个标签。

指针网络的做法:

  • 为每个字训练两个二分类器:判断它是不是某个实体的开头,以及是不是某个实体的结尾

  • 比如:“北”是地名开头,也是机构名开头;“京”是地名结尾;“学”是机构名结尾。

  • 然后后处理把“开头-结尾”配对,得到候选实体:北京、北京大学、京大(错误的会被过滤)。

片段网络则是对所有可能的片段(比如“北京”、“京大”、“北京大学”)做分类,判断每个片段属于哪类实体。两者结合,既能解决嵌套,又不会计算量太大。

方法4:用大模型生成(最新潮流)

直接把任务写成提示词,让大模型自己输出结构化结果。例如:

输入:请从句子“姚明昨天在北京打篮球”中抽取出人名、地名。
输出:人名:姚明;地名:北京

这种方法不需要训练,但大模型有时会“幻觉”,输出原文中没有的东西。


第二节 NER项目的数据处理

一、我们要做什么?

让电脑学会命名实体识别(NER),就像教它玩一个“找词贴标签”的游戏。但在玩游戏之前,我们需要把“游戏规则”和“道具”准备好。数据处理就是做这件事——把原始的文字数据,变成电脑能看懂的数字

最终目标

  1. 给电脑的输入(X):是一个数字表格,形状为 [批次大小, 句子长度]。表格里的每个数字,代表句子中每个字在词汇表里的编号。

  2. 给电脑的正确答案(Y):也是一个同样形状的数字表格,每个数字代表对应字的实体标签编号(比如“B-dis”的编号是几)。

  3. 两本“字典”

    • 词汇表:把每个字映射成一个唯一的数字ID。

    • 标签映射表:把每个实体标签(如“B-dis”)也映射成一个唯一的数字ID。

二、第一步:构建标签映射表

我们的数据是中文医学实体数据集(CMeEE-V2),每条数据包含一段文本和里面标注的实体。比如:

text

文本:(2)室上性心动过速可用常规抗心律失常药物控制,年龄小于5岁。
实体:["室上性心动过速"是“疾病(dis)”,"抗心律失常药物"是“药物(dru)”]

2.1 数据长什么样?

数据是JSON格式,里面是一个大列表,每个元素是一个字典:

json

{
  "text": "(2)室上性心动过速可用常规抗心律失常药物控制,年龄小于5岁。",
  "entities": [
    {"start_idx": 3, "end_idx": 9, "type": "dis", "entity": "室上性心动过速"},
    {"start_idx": 14, "end_idx": 20, "type": "dru", "entity": "抗心律失常药物"}
  ]
}
  • start_idx 和 end_idx 是实体在文本中的起始和结束位置(注意:是包含两端的闭区间)。

  • type 是实体类型,比如 dis(疾病)、dru(药物)。

2.2 提取所有实体类型

我们需要先找出数据里所有可能的实体类型。比如这个数据集里有9种类型:dru(药物)、dis(疾病)、bod(身体部位)、sym(症状)等。

2.3 用BMES方案给标签编码

我们给每个字贴标签时,不能只贴一个类型,还要说明它在实体里的位置。所以用BMES方案:

  • B(Begin):实体开头

  • M(Middle):实体中间

  • E(End):实体结尾

  • S(Single):单字实体

  • O(Outside):不是实体

比如对于“室上性心动过速”这个疾病实体,标签序列是:B-dis M-dis M-dis M-dis E-dis

最后生成的标签映射表(categories.json)就像:

json

{
    "O": 0,
    "B-dis": 1,
    "M-dis": 2,
    "E-dis": 3,
    "S-dis": 4,
    "B-dru": 5,
    "M-dru": 6,
    "E-dru": 7,
    "S-dru": 8,
    ...
}

“O”的编号是0,其他依次递增。

三、第二步:构建词汇表

现在需要把文本里的每个字,也变成一个数字ID。就像查字典一样。

3.1 统计所有出现的字

遍历所有文本,数一数每个字出现了多少次。

3.2 规范化文本(统一全角/半角)

你会发现数据里既有全角字符(如)又有半角字符(如,()。它们在电脑眼里是两个不同的字,但对我们来说意思一样。所以把它们统一转换成半角,减少词汇表大小。

3.3 过滤低频字

可以设置一个阈值,比如只保留出现次数≥2的字,去掉那些只出现一次的罕见字。这样可以进一步精简词汇表。

3.4 添加特殊标记

最后在词汇表最前面加上两个特殊标记:

  • <PAD>:用于把不同长度的句子填充到一样长。

  • <UNK>:当遇到词汇表里没有的字时,就用它代替。

生成的词汇表(vocabulary.json)是一个列表:

json

[
    "<PAD>",
    "<UNK>",
    "一",
    "了",
    "人",
    "他",
    ...
]

四、第三步:封装数据加载器

现在有了两本字典(词汇表和标签映射表),我们需要一个“数据流水线”,能自动把原始数据变成模型能吃的批次。

4.1 Vocabulary类——查字典

创建一个类,加载词汇表,并提供“字→ID”的转换功能。

4.2 NerDataset类——加工单条数据

继承PyTorch的Dataset,对每条数据做:

  1. 把文本转成字符列表。

  2. 用Vocabulary把字符转成token_ids

  3. 初始化一个全是“O”的标签序列。

  4. 根据实体标注,把对应位置的“O”替换成正确的BMES标签。

  5. 用标签映射表把标签字符串转成label_ids

  6. 返回包含token_idslabel_ids的字典。

4.3 collate_fn函数——打包成一个批次

因为每个句子长度不同,不能直接打包。collate_fn做三件事:

  1. 动态填充:找出批次里最长的句子,把其他句子填充到一样长(用<PAD>的ID填充)。

  2. 生成注意力掩码:创建一个和输入一样形状的矩阵,真实位置是1,填充位置是0,告诉模型哪里是真正的字。

  3. 处理标签填充:对标签也用同样的方式填充,但填充值设为-100(PyTorch损失函数会自动忽略这个值)。

4.4 整合成DataLoader

把所有组件组合起来,得到一个能自动产出批次的DataLoader

五、代码

01_build_category.py

import json
import os


def save_json(data, file_path):
    """
    将数据以格式化的 JSON 形式保存到文件
    """
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=4)


def collect_entity_types_from_file(file_path):
    """
    从单个数据文件中提取所有唯一的实体类型
    """
    types = set()
    with open(file_path, 'r', encoding='utf-8') as f:
        all_data = json.load(f)
        for data in all_data:
            # 遍历实体列表,提取 'type' 字段
            for entity in data['entities']:
                types.add(entity['type'])
    return types


def generate_tag_map(data_files, output_file):
    """
    从数据文件构建 BMES 标签映射并保存
    """
    # 1. 从所有文件中收集实体类型
    all_entity_types = set()
    for file_path in data_files:
        all_entity_types.update(collect_entity_types_from_file(file_path))

    # 2. 排序以保证映射一致性
    sorted_types = sorted(list(all_entity_types))
    print(f"发现的实体类型: {sorted_types}")

    # 3. 构建 BMES 标签映射
    tag_to_id = {'O': 0}  # 'O' 代表非实体
    for entity_type in sorted_types:
        for prefix in ['B', 'M', 'E', 'S']:
            tag_name = f"{prefix}-{entity_type}"
            tag_to_id[tag_name] = len(tag_to_id)

    print(f"\n已生成 {len(tag_to_id)} 个标签映射。")

    # 4. 保存映射文件
    save_json(tag_to_id, output_file)
    print(f"标签映射已保存至: {output_file}")


if __name__ == '__main__':
    # 定义输入的数据文件和期望的输出路径
    train_file = 'E:/Datawhale 2026/base-llm202602/06_homework/CMeEE-V2_train.json'
    dev_file = 'E:/Datawhale 2026/base-llm202602/06_homework/CMeEE-V2_dev.json'
    output_path = './categories.json'

    generate_tag_map(data_files=[train_file, dev_file], output_file=output_path)

输出:
 

(base) PS E:\Datawhale 2026\base-llm202602\06_homework> & D:/Users/app/miniconda3/envs/base-llm/python.exe "e:/Datawhale 2026/base-llm202602/06_homework/01_build_category.py"
发现的实体类型: ['bod', 'dep', 'dis', 'dru', 'equ', 'ite', 'mic', 'pro', 'sym']

已生成 37 个标签映射。
标签映射已保存至: ./categories.json

02_build_vocabulary.py

import json
import os
from collections import Counter


def save_json(data, file_path):
    """
    将数据以易于阅读的格式保存为 JSON 文件
    """
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=4)


def normalize_text(text):
    """
    规范化文本
    """
    full_width = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!#$%&’()*+,-./:;<=>?@[\]^_`{|}~""
    half_width = r"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!#$%&'" + r'()*+,-./:;<=>?@[\]^_`{|}~".'
    mapping = str.maketrans(full_width, half_width)
    return text.translate(mapping)


def create_char_vocab(data_files, output_file, min_freq=1):
    """
    从数据文件创建字符级词汇表
    """
    char_counts = Counter()
    for file_path in data_files:
        with open(file_path, 'r', encoding='utf-8') as f:
            all_data = json.load(f)
            for data in all_data:
                text = normalize_text(data['text'])
                char_counts.update(list(text))

    # 过滤低频词
    frequent_chars = [char for char, count in char_counts.items() if count >= min_freq]
    
    # 保证每次生成结果一致
    frequent_chars.sort()

    # 添加特殊标记
    special_tokens = ["<PAD>", "<UNK>"]
    final_vocab_list = special_tokens + frequent_chars
    
    print(f"词汇表大小 (min_freq={min_freq}): {len(final_vocab_list)}")

    # 保存词汇表
    save_json(final_vocab_list, output_file)
    print(f"词汇表已保存至: {output_file}")


if __name__ == '__main__':
    # 定义输入的数据文件和期望的输出路径
    train_file = 'E:/Datawhale 2026/base-llm202602/06_homework/CMeEE-V2_train.json'
    dev_file = 'E:/Datawhale 2026/base-llm202602/06_homework/CMeEE-V2_dev.json'
    output_path = './vocabulary.json'

    # 设置字符最低频率,1表示包含所有出现过的字符
    create_char_vocab(data_files=[train_file, dev_file], output_file=output_path, min_freq=1)

输出:

(base) PS E:\Datawhale 2026\base-llm202602\06_homework> & D:/Users/app/miniconda3/envs/base-llm/python.exe "e:/Datawhale 2026/base-llm202602/06_homework/02_build_vocabulary.py"
词汇表大小 (min_freq=1): 2970
词汇表已保存至: ./vocabulary.json

03_data_loader.py

import json
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence


def normalize_text(text):
    """
    规范化文本,例如将全角字符转换为半角字符。
    """
    full_width = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!#$%&’()*+,-./:;<=>?@[\]^_`{|}~""
    half_width = r"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!#$%&'" + r'()*+,-./:;<=>?@[\]^_`{|}~".'
    mapping = str.maketrans(full_width, half_width)
    return text.translate(mapping)


class Vocabulary:
    """
    负责管理词汇表和 token 到 id 的映射。
    """
    def __init__(self, vocab_path):
        with open(vocab_path, 'r', encoding='utf-8') as f:
            self.tokens = json.load(f)
        self.token_to_id = {token: i for i, token in enumerate(self.tokens)}
        self.pad_id = self.token_to_id['<PAD>']
        self.unk_id = self.token_to_id['<UNK>']

    def __len__(self):
        return len(self.tokens)

    def convert_tokens_to_ids(self, tokens):
        return [self.token_to_id.get(token, self.unk_id) for token in tokens]


class NerDataset(Dataset):
    """
    处理 NER 数据,并将其转换为适用于 PyTorch 模型的格式。
    """
    def __init__(self, data_path, vocab: Vocabulary, tag_map: dict):
        self.vocab = vocab
        self.tag_to_id = tag_map
        with open(data_path, 'r', encoding='utf-8') as f:
            self.records = json.load(f)

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        record = self.records[idx]
        text = normalize_text(record['text'])
        tokens = list(text)
        
        # 将文本 tokens 转换为 ids
        token_ids = self.vocab.convert_tokens_to_ids(tokens)

        # 初始化标签序列为 'O'
        tags = ['O'] * len(tokens)
        for entity in record.get('entities', []):
            entity_type = entity['type']
            start = entity['start_idx']
            end = entity['end_idx']

            if end >= len(tokens): continue

            if start == end:
                tags[start] = f'S-{entity_type}'
            else:
                tags[start] = f'B-{entity_type}'
                tags[end] = f'E-{entity_type}'
                for i in range(start + 1, end):
                    tags[i] = f'M-{entity_type}'
        
        # 将标签转换为 ids
        label_ids = [self.tag_to_id[tag] for tag in tags]

        return {
            "token_ids": torch.tensor(token_ids, dtype=torch.long),
            "label_ids": torch.tensor(label_ids, dtype=torch.long)
        }


def create_ner_dataloader(data_path, vocab, tag_map, batch_size, shuffle=False):
    """
    创建 NER 任务的 DataLoader。
    """
    dataset = NerDataset(data_path, vocab, tag_map)
    
    def collate_batch(batch):
        token_ids_list = [item['token_ids'] for item in batch]
        label_ids_list = [item['label_ids'] for item in batch]

        padded_token_ids = pad_sequence(token_ids_list, batch_first=True, padding_value=vocab.pad_id)
        padded_label_ids = pad_sequence(label_ids_list, batch_first=True, padding_value=-100)  # -100 用于在计算损失时忽略填充部分

        attention_mask = (padded_token_ids != vocab.pad_id).long()

        return {
            "token_ids": padded_token_ids,
            "label_ids": padded_label_ids,
            "attention_mask": attention_mask
        }

    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_batch)


if __name__ == '__main__':
    # 文件路径
    # 定义输入的数据文件和期望的输出路径
    train_file = 'E:/Datawhale 2026/base-llm202602/06_homework/CMeEE-V2_train.json'
    dev_file = 'E:/Datawhale 2026/base-llm202602/06_homework/CMeEE-V2_dev.json'
    vocab_file = './vocabulary.json'
    categories_file = './categories.json'

    # 1. 加载词汇表和标签映射
    vocabulary = Vocabulary(vocab_path=vocab_file)
    with open(categories_file, 'r', encoding='utf-8') as f:
        tag_map = json.load(f)
    print("词汇表和标签映射加载完成。")

    # 2. 创建 DataLoader
    train_loader = create_ner_dataloader(
        data_path=train_file,
        vocab=vocabulary,
        tag_map=tag_map,
        batch_size=4,
        shuffle=True
    )
    print("DataLoader 创建完成。")

    # 3. 验证一个批次的数据
    print("\n--- 验证一个批次的数据 ---")
    batch = next(iter(train_loader))
    
    print(f"  Token IDs (shape): {batch['token_ids'].shape}")
    print(f"  Label IDs (shape): {batch['label_ids'].shape}")
    print(f"  Attention Mask (shape): {batch['attention_mask'].shape}")
    print(f"  Token IDs (sample): {batch['token_ids'][0][:20]}...")
    print(f"  Label IDs (sample): {batch['label_ids'][0][:20]}...")
    print(f"  Attention Mask (sample): {batch['attention_mask'][0][:20]}...")

输出:

(base) PS E:\Datawhale 2026\base-llm202602\06_homework> & D:/Users/app/miniconda3/envs/base-llm/python.exe "e:/Datawhale 2026/base-llm202602/06_homework/03_data_loader.py"
词汇表和标签映射加载完成。
DataLoader 创建完成。

--- 验证一个批次的数据 ---
  Token IDs (shape): torch.Size([4, 60])
  Label IDs (shape): torch.Size([4, 60])
  Attention Mask (shape): torch.Size([4, 60])
  Token IDs (sample): tensor([ 342, 1793, 2295,  340, 1410, 1374, 1804, 1268, 1321, 1262,   38,   65,
         575, 2420,  283, 1507, 1880,  147,    0,    0])...
  Label IDs (sample): tensor([  33,   34,   34,   34,   34,   34,   34,   34,   34,   34,   21,   23,
          34,   21,   23,   34,   35,    0, -100, -100])...
  Attention Mask (sample): tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0])...

📊 输出解读

词汇表和标签映射加载完成。
DataLoader 创建完成。

--- 验证一个批次的数据 ---
  Token IDs (shape): torch.Size([4, 60])
  Label IDs (shape): torch.Size([4, 60])
  Attention Mask (shape): torch.Size([4, 60])
  • Token IDs:一个批次有 4 条样本,每条样本被填充/截断为长度 60。每个数字代表一个字符在词汇表中的 ID。

  • Label IDs:同样形状,每个数字代表该字符对应的实体标签 ID(如 0 代表 O,1 代表 B-bod 等)。注意填充位置的标签是 -100,在计算损失时会被忽略。

  • Attention Mask:1 表示真实 token,0 表示填充位置,模型在计算注意力时会忽略填充部分。

样例数据解释

Token IDs (sample): tensor([ 342, 1793, 2295,  340, 1410, 1374, 1804, 1268, 1321, 1262,   38,   65,
         575, 2420,  283, 1507, 1880,  147,    0,    0])...
  • 这是一条样本的前 20 个 token。末尾的 0 是 <PAD> 的 ID,表示填充部分。

Label IDs (sample): tensor([  33,   34,   34,   34,   34,   34,   34,   34,   34,   34,   21,   23,
          34,   21,   23,   34,   35,    0, -100, -100])...
  • 对应的标签 ID:非零值代表实体标签(如 33 可能是某个实体的 B-xxx34 是 M-xxx35 是 E-xxx 等)。最后的 -100 是填充位置的标签(损失计算时会忽略)。

Attention Mask (sample): tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0])...
  • 前 18 个位置是真实 token(mask=1),后 2 个是填充(mask=0)。

这些结果完全符合预期,说明数据处理流程已经正确实现,数据可以被模型直接使用了!


✅ 已完成的部分

  • 构建了标签映射表 categories.json

  • 构建了字符级词汇表 vocabulary.json

  • 创建了可批量加载数据的 DataLoader,并验证成功

04_model.py

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn

class BiGRUNerNetWork(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_tags, num_gru_layers=1):
        super().__init__()
        # 1. Token Embedding 层
        self.embedding = nn.Embedding(vocab_size, hidden_size)

        # 2. 使用 ModuleList 构建多层双向 GRU
        self.gru_layers = nn.ModuleList()
        for _ in range(num_gru_layers):
            self.gru_layers.append(
                nn.GRU(
                    input_size=hidden_size,
                    hidden_size=hidden_size,
                    num_layers=1,
                    batch_first=True,
                    bidirectional=True  # 开启双向
                )
            )

        # 3. 特征融合层
        self.fc = nn.Linear(hidden_size * 2, hidden_size)

        # 4. 分类决策层 (Classifier)
        self.classifier = nn.Linear(hidden_size, num_tags)

    def forward(self, token_ids, attention_mask):
        # 1. 计算真实长度
        lengths = attention_mask.sum(dim=1).cpu()

        # 2. 获取词向量
        embedded_text = self.embedding(token_ids)

        # 3. 打包序列
        current_packed_input = rnn.pack_padded_sequence(
            embedded_text, lengths, batch_first=True, enforce_sorted=False
        )

        # 4. 循环通过 GRU 层
        for gru_layer in self.gru_layers:
            # GRU 输出 (packed)
            packed_output, _ = gru_layer(current_packed_input)

            # 解包以进行后续操作,并指定 total_length
            output, _ = rnn.pad_packed_sequence(
                packed_output, batch_first=True, total_length=token_ids.shape[1]
            )

            # 特征融合
            features = self.fc(output)

            # 残差连接
            # 同样需要解包上一层的输入
            input_padded, _ = rnn.pad_packed_sequence(
                current_packed_input, batch_first=True, total_length=token_ids.shape[1]
            )
            current_input = features + input_padded

            # 重新打包作为下一层的输入
            current_packed_input = rnn.pack_padded_sequence(
                current_input, lengths, batch_first=True, enforce_sorted=False
            )

        # 5. 解包最终输出用于分类
        final_output, _ = rnn.pad_packed_sequence(
            current_packed_input, batch_first=True, total_length=token_ids.shape[1]
        )

        # 6. 分类
        logits = self.classifier(final_output)

        return logits


if __name__ == '__main__':

    token_ids = torch.tensor([
        [210,   18,  871, 147,   0,   0,   0,   0],
        [922, 2962,  842, 210,  18, 871, 147,   0]
    ], dtype=torch.int64)

    # attention_mask 标记哪些是真实 token (1) 哪些是填充 (0)
    attention_mask = torch.tensor([
        [1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0]
    ], dtype=torch.int64)

    label_ids = torch.tensor([
        [0, 0, 0, 0, -100, -100, -100, -100],
        [0, 0, 0, 0,    0,    0,    0, -100]
    ], dtype=torch.int64)

    # 实例化模型
    model = BiGRUNerNetWork(
        vocab_size=10000,
        hidden_size=128,
        num_tags=37,
        num_gru_layers=2
    )

    # 3. 执行前向传播
    logits = model(token_ids=token_ids, attention_mask=attention_mask)

    # 4. 构造损失函数
    loss_fn = nn.CrossEntropyLoss(ignore_index=-100, reduction='none')

    # 5. 计算损失
    # CrossEntropyLoss 要求类别维度在前,所以需要交换最后两个维度
    # [batch, seq_len, num_tags] -> [batch, num_tags, seq_len]
    permuted_logits = torch.permute(logits, dims=(0, 2, 1))
    loss = loss_fn(permuted_logits, label_ids)

    # 6. 打印结果
    print(f"Logits shape: {logits.shape}")
    print(f"Loss shape: {loss.shape}")
    print("\n每个 Token 的损失:")
    print(loss)

05_train.py

pip install tensorboard -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com

import os
import torch
import torch.nn as nn
import sys
from dataclasses import asdict
# 导入定义的所有组件
from src.configs.configs import config
from src.data.data_loader import create_ner_dataloader
from src.tokenizer.vocabulary import Vocabulary
from src.tokenizer.char_tokenizer import CharTokenizer
from src.models.ner_model import BiGRUNerNetWork
from src.loss.ner_loss import NerLoss
from src.trainer.trainer import Trainer
from src.utils.file_io import load_json, save_json
from src.metrics.entity_metrics import calculate_entity_level_metrics

def seed_everything(seed: int = 42):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def main():
    """
    主函数,负责组装所有组件并启动NER训练任务。
    """
    # --- 0. 设置随机数种子---
    seed_everything(getattr(config, 'seed', 42))

    # --- 1. 加载词汇表和标签映射, 并创建分词器 ---
    vocab_path = os.path.join(config.data_dir, config.vocab_file)
    tags_path = os.path.join(config.data_dir, config.tags_file)
    train_path = os.path.join(config.data_dir, config.train_file)
    dev_path = os.path.join(config.data_dir, config.dev_file)
    
    vocab = Vocabulary.load_from_file(vocab_path)
    tokenizer = CharTokenizer(vocab)
    tag_map = load_json(tags_path)
    id2tag = {v: k for k, v in tag_map.items()}

    # --- 2. 创建数据加载器 ---
    train_loader = create_ner_dataloader(
        data_path=train_path,
        tokenizer=tokenizer,
        tag_map=tag_map,
        batch_size=config.batch_size,
        shuffle=True,
        device=config.device
    )
    dev_loader = create_ner_dataloader(
        data_path=dev_path,
        tokenizer=tokenizer,
        tag_map=tag_map,
        batch_size=config.batch_size,
        shuffle=False,
        device=config.device
    )

    # --- 3. 初始化模型、优化器、损失函数 ---
    model = BiGRUNerNetWork(
        vocab_size=len(vocab),
        hidden_size=config.hidden_size,
        num_tags=len(tag_map),
        num_gru_layers=config.num_gru_layers
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    
    # 根据配置选择损失函数
    if config.loss_type == "cross_entropy":
        loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
    else:
        loss_fn = NerLoss(
            loss_type=config.loss_type,
            entity_weight=config.entity_loss_weight,
            hard_negative_ratio=config.hard_negative_ratio
        )

    # --- 4. 定义评估函数 ---
    def eval_metric_fn(all_logits, all_labels, all_attention_mask):
        all_preds_ids = [torch.argmax(logits, dim=-1) for logits in all_logits]
        
        all_labels_cpu = [labels.cpu() for labels in all_labels]
        all_preds_ids_cpu = [preds.cpu() for preds in all_preds_ids]
        all_attention_mask_cpu = [mask.cpu() for mask in all_attention_mask]
        
        active_masks = [mask.bool() for mask in all_attention_mask_cpu]

        # 计算基于 mask 的 token 级准确率
        total_equal_tokens = 0
        total_effective_tokens = 0
        for preds, labels, mask in zip(all_preds_ids_cpu, all_labels_cpu, active_masks):
            # preds/labels/mask: [B, T]
            equal = (preds == labels) & mask
            total_equal_tokens += int(equal.sum().item())
            total_effective_tokens += int(mask.sum().item())
        token_acc = (total_equal_tokens / total_effective_tokens) if total_effective_tokens > 0 else 0.0

        metrics = calculate_entity_level_metrics(
            all_preds_ids_cpu, 
            all_labels_cpu, 
            active_masks, 
            id2tag
        )
        metrics['token_acc'] = token_acc
        return metrics

    # --- 5. 初始化并启动训练器 ---
    # 在初始化 Trainer 前,检查检查点文件是否存在
    if config.resume_checkpoint and not os.path.exists(config.resume_checkpoint):
        print(f"Checkpoint file not found: {config.resume_checkpoint}. Starting training from scratch.")
        config.resume_checkpoint = None # 设为 None, 避免 Trainer 报错

    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        train_loader=train_loader,
        dev_loader=dev_loader,
        eval_metric_fn=eval_metric_fn,
        output_dir=config.output_dir,
        device=config.device,
        # 新增参数
        summary_writer_dir=config.output_summary_dir,
        early_stopping_patience=config.early_stopping_patience,
        resume_checkpoint=config.resume_checkpoint
    )

    # 在训练开始前,保存配置文件
    os.makedirs(config.output_dir, exist_ok=True)
    save_json(asdict(config), os.path.join(config.output_dir, "config.json"))
    print(f"Configuration saved to {os.path.join(config.output_dir, 'config.json')}")

    trainer.fit(epochs=config.epochs)

if __name__ == "__main__":
    main()

输出:
(base-llm) PS E:\Datawhale 2026\base-llm202602\06_homework> & D:/Users/app/miniconda3/envs/base-llm/python.exe "e:/Datawhale 2026/base-llm202602/06_homework/05_train.py" 
Trainer will run on device: cuda
Configuration saved to output\config.json
--- Epoch 1/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:38<00:00, 12.11it/s]
Train Metrics: Total Loss: 1.2551
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 50.77it/s] 
Validation Metrics: precision: 0.6491, recall: 0.3397, f1: 0.4460, token_acc: 0.7471, loss: 0.8552
New best model found! Saving to output
--- Epoch 2/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.06it/s]
Train Metrics: Total Loss: 0.7105
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.26it/s] 
Validation Metrics: precision: 0.6991, recall: 0.4457, f1: 0.5444, token_acc: 0.7829, loss: 0.7056
New best model found! Saving to output
--- Epoch 3/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.21it/s]
Train Metrics: Total Loss: 0.5795
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 50.81it/s] 
Validation Metrics: precision: 0.6987, recall: 0.4990, f1: 0.5822, token_acc: 0.7977, loss: 0.6493
New best model found! Saving to output
--- Epoch 4/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.14it/s]
Train Metrics: Total Loss: 0.5009
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.53it/s] 
Validation Metrics: precision: 0.7107, recall: 0.4989, f1: 0.5863, token_acc: 0.8010, loss: 0.6356
New best model found! Saving to output
--- Epoch 5/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.10it/s]
Train Metrics: Total Loss: 0.4343
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.07it/s] 
Validation Metrics: precision: 0.7036, recall: 0.5375, f1: 0.6095, token_acc: 0.8035, loss: 0.6252
New best model found! Saving to output
--- Epoch 6/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:26<00:00, 18.03it/s]
Train Metrics: Total Loss: 0.3782
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.34it/s] 
Validation Metrics: precision: 0.7031, recall: 0.5380, f1: 0.6096, token_acc: 0.8031, loss: 0.6481
New best model found! Saving to output
--- Epoch 7/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.16it/s]
Train Metrics: Total Loss: 0.3447
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.30it/s] 
Validation Metrics: precision: 0.7077, recall: 0.5311, f1: 0.6068, token_acc: 0.8023, loss: 0.6710
EarlyStopping counter: 1 out of 5
--- Epoch 8/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.11it/s] 
Train Metrics: Total Loss: 0.2886
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 50.98it/s] 
Validation Metrics: precision: 0.7001, recall: 0.5646, f1: 0.6251, token_acc: 0.8076, loss: 0.6812
New best model found! Saving to output
--- Epoch 9/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.17it/s]
Train Metrics: Total Loss: 0.2451
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.65it/s] 
Validation Metrics: precision: 0.7005, recall: 0.5552, f1: 0.6194, token_acc: 0.8057, loss: 0.7110
EarlyStopping counter: 1 out of 5
--- Epoch 10/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.18it/s]
Train Metrics: Total Loss: 0.2043
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.29it/s] 
Validation Metrics: precision: 0.6868, recall: 0.5493, f1: 0.6104, token_acc: 0.8011, loss: 0.7942
EarlyStopping counter: 2 out of 5
--- Epoch 11/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.12it/s] 
Train Metrics: Total Loss: 0.1726
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.55it/s] 
Validation Metrics: precision: 0.6831, recall: 0.5558, f1: 0.6129, token_acc: 0.8008, loss: 0.8374
EarlyStopping counter: 3 out of 5
--- Epoch 12/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.25it/s] 
Train Metrics: Total Loss: 0.1435
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.50it/s] 
Validation Metrics: precision: 0.6852, recall: 0.5556, f1: 0.6136, token_acc: 0.7856, loss: 0.9739
EarlyStopping counter: 4 out of 5
--- Epoch 13/20 ---
Training Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [00:25<00:00, 18.24it/s] 
Train Metrics: Total Loss: 0.1540
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 51.69it/s] 
Validation Metrics: precision: 0.6797, recall: 0.5601, f1: 0.6141, token_acc: 0.8036, loss: 0.9249
EarlyStopping counter: 5 out of 5
Early stopping triggered.

📊 训练结果解读

关键指标说明

  • precision (精确率):预测为实体的片段中,真正正确的比例。约 70%。

  • recall (召回率):真实存在的实体中,被模型找出来的比例。约 56%。

  • f1 (F1值):精确率和召回率的调和平均,是评估实体识别模型的核心指标。最佳 62.5%

  • token_acc (Token级别准确率):每个字分类的准确率,约 80%(这个值通常比实体F1高,因为大部分字是“O”非实体)。

  • loss:损失值,训练集损失逐渐降低,验证集损失在后期略有上升(可能轻微过拟合,早停及时停止了)。

这些指标表明模型已经学会识别部分医学实体,但仍有改进空间(例如调整学习率、增加训练数据、使用预训练模型等)。


python 06_predict.py --model_dir ./output --text "患者有高血压病史,需服用降压药。"

06_predict.py

import torch
import json
import os
import argparse
from src.models.ner_model import BiGRUNerNetWork
from src.tokenizer.vocabulary import Vocabulary
from src.tokenizer.char_tokenizer import CharTokenizer
from src.utils.file_io import load_json


class NerPredictor:
    def __init__(self, model_dir, device='cpu'):
        self.device = torch.device(device)
        
        # --- 1. 加载配置文件以获取模型参数 ---
        config_path = os.path.join(model_dir, 'config.json')
        self.config = load_json(config_path)

        # --- 2. 加载词汇表和标签映射 ---
        vocab_path = os.path.join(self.config["data_dir"], self.config["vocab_file"])
        tags_path = os.path.join(self.config["data_dir"], self.config["tags_file"])

        self.vocab = Vocabulary.load_from_file(vocab_path)
        self.tokenizer = CharTokenizer(self.vocab)
        tag_map = load_json(tags_path)
        self.id2tag = {v: k for k, v in tag_map.items()}

        # --- 3. 初始化模型并加载权重 ---
        self.model = BiGRUNerNetWork(
            vocab_size=len(self.vocab),
            hidden_size=self.config["hidden_size"],
            num_tags=len(tag_map),
            num_gru_layers=self.config["num_gru_layers"]
        )
        model_path = os.path.join(model_dir, 'best_model.pth')
        self.model.load_state_dict(torch.load(model_path, map_location=self.device)['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()

    def predict(self, text):
        tokens = self.tokenizer.text_to_tokens(text)
        token_ids = self.tokenizer.tokens_to_ids(tokens)
        
        # --- 预处理 ---
        token_ids_tensor = torch.tensor([token_ids], dtype=torch.long).to(self.device)
        attention_mask = torch.ones_like(token_ids_tensor)

        # --- 模型预测 ---
        with torch.no_grad():
            logits = self.model(token_ids_tensor, attention_mask)
        
        # --- 后处理 ---
        predictions = torch.argmax(logits, dim=-1).squeeze(0)
        tags = [self.id2tag[id_.item()] for id_ in predictions]

        return self._extract_entities(tokens, tags)

    def _extract_entities(self, tokens, tags):
        entities = []
        current_entity = None
        for i, tag in enumerate(tags):
            if tag.startswith('B-'):
                # 如果前一个实体未正确结束,则放弃
                if current_entity:
                    pass # 或者可以根据业务逻辑决定是否保存不完整的实体
                current_entity = {"text": tokens[i], "type": tag[2:], "start": i}
            elif tag.startswith('M-'):
                # M 标签必须跟在 B- 或 M- 之后
                if current_entity and current_entity["type"] == tag[2:]:
                    current_entity["text"] += tokens[i]
                else:
                    # 非法 M 标签,重置当前实体
                    current_entity = None
            elif tag.startswith('E-'):
                # E 标签必须跟在 B- 或 M- 之后
                if current_entity and current_entity["type"] == tag[2:]:
                    current_entity["text"] += tokens[i]
                    current_entity["end"] = i + 1
                    entities.append(current_entity)
                # 实体已结束,重置
                current_entity = None
            elif tag.startswith('S-'):
                # S 标签表示单个字符的实体
                # 如果有未结束的实体,则放弃
                current_entity = None
                entities.append({"text": tokens[i], "type": tag[2:], "start": i, "end": i + 1})
            else: # 'O' 标签
                # O 标签意味着没有实体,或者实体已经结束
                current_entity = None
        
        # 循环结束后,不再处理任何未闭合的实体
        return entities

def main():
    parser = argparse.ArgumentParser(description="NER Prediction")
    parser.add_argument("--model_dir", type=str, required=True, help="Directory of the saved model and config.")
    parser.add_argument("--text", type=str, required=True, help="Text to predict.")
    args = parser.parse_args()

    predictor = NerPredictor(model_dir=args.model_dir)
    entities = predictor.predict(args.text)
    print(f"Text: {args.text}")
    print(f"Entities: {json.dumps(entities, ensure_ascii=False, indent=2)}")

if __name__ == "__main__":
    main()

输出:
(base-llm) PS E:\Datawhale 2026\base-llm202602\06_homework> python 06_predict.py --model_dir ./output --text "患者有高血压病史,需服用降压药。"
Text: 患者有高血压病史,需服用降压药。
Entities: [
  {
    "text": "高血压",
    "type": "dis",
    "start": 3,
    "end": 6
  },
  {
    "text": "降压药",
    "type": "dru",
    "start": 12,
    "end": 15
  }
]

解读:

NER 模型成功识别出了文本中的实体!🎉 输出结果完全正确,模型准确地找到了:

  • 高血压 (类型 dis,疾病) 从第3个字符到第6个字符(注意:在字符串中,索引从0开始,所以3到6正好是“高血压”)。

  • 降压药 (类型 dru,药物) 从第12到15个字符。

这表明训练的模型已经学会了从医学文本中抽取关键信息,这正是命名实体识别的核心任务。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐