在教育智能化领域,基于大模型的试卷手写体识别是实现自动批改的核心前提。与通用手写体识别不同,试卷场景的手写体数据具有「字体多样(学生书写差异大)、内容结构化(含选择题选项、填空题答案、解答题步骤)、格式固定(试卷排版规范)」等特性,数据准备需重点解决图像数据与文本标签的对齐、噪声去除、格式标准化及模型输入适配等问题。本文以初中数学试卷手写体识别为目标场景,完整讲解从数据收集整理、质量优化、标注处理到模型编码的全流程,通过可复现的样本数据实例与代码,拆解每一步技术细节与实现逻辑。

一、需求背景与样本数据设计

1.1 场景定义

本次案例聚焦「初中数学试卷手写体识别」,目标是微调一个基于视觉大模型(如ViT-L/14、CRNN+Transformer融合模型)的手写体识别模型,实现对试卷中三类核心手写内容的精准识别:① 选择题涂卡/手写选项(A/B/C/D);② 填空题手写答案(数字、字母、符号组合);③ 解答题手写步骤(文字、公式、数字混合),为后续自动批改提供结构化识别结果。

1.2 试卷手写体数据的核心维度与收集要求

试卷手写体数据为「图像数据+文本标签」的配对形式,核心维度及收集要求如下:

数据类型 核心内容 收集要求
图像数据 试卷手写区域截图(选择题、填空题、解答题区域) 1. 分辨率≥300dpi,清晰无模糊;2. 覆盖不同书写风格(工整/潦草/连笔);3. 包含不同纸张底色(白色试卷纸、略带泛黄纸张);4. 排除严重污渍、折痕遮挡的样本
文本标签 图像对应手写内容的标准文本(含格式标注) 1. 选择题标注格式:「题型-题号-选项」(如“选择-3-B”);2. 填空题标注格式:「题型-题号-答案」(如“填空-5-3.14”);3. 解答题标注格式:「题型-题号-步骤-内容」(如“解答-8-步骤1-解:设未知数x”);4. 公式需用LaTeX格式标注(如“x2+2x+1=0x^2 + 2x + 1 = 0x2+2x+1=0”)
辅助信息 试卷版本、年级、出题省份、书写者年级 用于数据分层,提升模型泛化性

1.3 样本数据生成与收集整理(含模拟脏数据)

考虑到手写体数据收集的特殊性,先通过模拟生成符合场景特性的样本数据(含图像模拟与标签配对),再补充真实采集数据的整理逻辑。

1.3.1 样本数据生成代码(模拟试卷手写体数据)


import os
import cv2
import numpy as np
import json
import random
from PIL import Image, ImageDraw, ImageFont

# 1. 基础配置定义
# 试卷区域参数(模拟A4纸张,300dpi)
A4_SIZE = (2480, 3508)  # 300dpi下A4纸像素尺寸(宽×高)
AREA_PARAMS = {
    "选择": {"width": 200, "height": 50, "num_per_row": 5, "start_pos": (200, 300)},
    "填空": {"width": 300, "height": 50, "num_per_row": 2, "start_pos": (200, 800)},
    "解答": {"width": 800, "height": 400, "num_per_row": 1, "start_pos": (200, 1300)}
}
# 模拟手写内容模板
CONTENT_TEMPLATES = {
    "选择": ["A", "B", "C", "D"],
    "填空": ["3.14", "x=5", "2√3", "-1/2", "$S=πr²$", "60°"],
    "解答": [
        "解:设未知数x",
        "由题意得:2x + 3 = 11",
        "解得:x = 4",
        "$∵ AB∥CD ∴ ∠1=∠2$",
        "证明:在△ABC中,AB=AC"
    ]
}
# 模拟书写风格参数(字体大小、倾斜角度、模糊程度)
STYLE_PARAMS = {
    "工整": {"font_size": 30, "angle": 0, "blur": 0},
    "潦草": {"font_size": 28, "angle": 3, "blur": 1},
    "连笔": {"font_size": 32, "angle": -2, "blur": 1}
}
# 生成样本数量
NUM_SAMPLES = 100  # 每类题型100条,共300条

# 2. 生成模拟手写体图像与标签
def generate_handwriting_text_image(text, style, width, height):
    """生成模拟手写体文本图像"""
    # 创建空白图像(模拟试卷底色,略带灰度)
    img = Image.new("RGB", (width, height), color=(245, 245, 240))
    draw = ImageDraw.Draw(img)
    
    # 加载手写风格字体(需提前准备手写体字体文件,如“手写体.ttf”)
    try:
        font = ImageFont.truetype("handwriting_font.ttf", style["font_size"])
    except:
        # 若无手写体字体,用系统默认字体模拟
        font = ImageFont.load_default(size=style["font_size"])
    
    # 绘制文本(居中对齐)
    text_width, text_height = draw.textsize(text, font=font)
    text_x = (width - text_width) // 2
    text_y = (height - text_height) // 2
    draw.text((text_x, text_y), text, font=font, fill=(0, 0, 0))  # 黑色字体
    
    # 模拟书写倾斜
    img = img.rotate(style["angle"], expand=False, fillcolor=(245, 245, 240))
    
    # 模拟模糊(模拟纸张轻微模糊)
    if style["blur"] > 0:
        img_np = np.array(img)
        img_np = cv2.GaussianBlur(img_np, (3, 3), style["blur"])
        img = Image.fromarray(img_np)
    
    return img

def generate_exam_handwriting_data():
    """生成完整的试卷手写体样本数据(图像+标签)"""
    # 创建保存目录
    base_dir = "exam_handwriting_data"
    os.makedirs(base_dir, exist_ok=True)
    img_dir = os.path.join(base_dir, "images")
    os.makedirs(img_dir, exist_ok=True)
    
    # 初始化标签列表
    all_labels = []
    
    for question_type in AREA_PARAMS.keys():
        area_info = AREA_PARAMS[question_type]
        content_list = CONTENT_TEMPLATES[question_type]
        
        for idx in range(NUM_SAMPLES):
            # 随机选择书写风格
            style_name = random.choice(list(STYLE_PARAMS.keys()))
            style = STYLE_PARAMS[style_name]
            
            # 随机选择手写内容
            content = random.choice(content_list)
            
            # 生成手写体图像
            img = generate_handwriting_text_image(
                text=content,
                style=style,
                width=area_info["width"],
                height=area_info["height"]
            )
            
            # 生成图像文件名(题型-序号-风格.jpg)
            img_filename = f"{question_type}-{idx:03d}-{style_name}.jpg"
            img_path = os.path.join(img_dir, img_filename)
            img.save(img_path)
            
            # 生成标签信息
            label = {
                "sample_id": f"{question_type}_{idx:03d}",
                "image_path": img_path,
                "question_type": question_type,
                "question_id": f"{question_type[0]}{idx+1}",  # 如“选1”“填1”“解1”
                "handwriting_content": content,
                "standard_content": content,  # 标准文本(后续用于模型训练标签)
                "writing_style": style_name,
                "resolution": "300dpi",
                "is_valid": True  # 标记是否有效样本
            }
            
            # 故意混入脏数据(10%比例)
            if idx % 10 == 0:
                if question_type == "选择":
                    # 脏数据1:选项模糊+标签错误
                    label["is_valid"] = False
                    label["handwriting_content"] = "E"  # 无效选项
                    label["standard_content"] = "B"  # 标签与内容不匹配
                elif question_type == "填空":
                    # 脏数据2:内容遮挡(模拟试卷折痕)
                    label["is_valid"] = False
                    # 对图像添加遮挡
                    img_np = np.array(img)
                    cv2.rectangle(img_np, (50, 10), (150, 40), (245, 245, 240), -1)  # 白色遮挡条
                    img = Image.fromarray(img_np)
                    img.save(img_path)
                elif question_type == "解答":
                    # 脏数据3:内容涂改(模拟学生涂改)
                    label["is_valid"] = False
                    # 对图像添加涂改痕迹
                    img_np = np.array(img)
                    draw = ImageDraw.Draw(Image.fromarray(img_np))
                    draw.line((20, 20, 180, 30), fill=(0, 0, 0), width=2)  # 横线涂改
                    img = Image.fromarray(img_np)
                    img.save(img_path)
            
            all_labels.append(label)
    
    # 保存标签文件(JSON格式)
    label_path = os.path.join(base_dir, "raw_labels.json")
    with open(label_path, "w", encoding="utf-8") as f:
        json.dump(all_labels, f, ensure_ascii=False, indent=2)
    
    print(f"样本数据生成完成!共{len(all_labels)}条样本")
    print(f"图像保存目录:{img_dir}")
    print(f"标签文件路径:{label_path}")
    return all_labels

# 执行样本生成
raw_labels = generate_exam_handwriting_data()
# 打印前5条有效样本信息
valid_samples = [s for s in raw_labels if s["is_valid"]][:5]
print("\n前5条有效样本信息:")
for sample in valid_samples:
    print(json.dumps(sample, ensure_ascii=False, indent=2))

1.3.2 真实数据收集与整理补充

模拟数据仅用于流程验证,真实场景需通过以下方式收集整理数据:

  1. 数据采集:收集真实初中数学试卷(扫描件/高清拍照,确保300dpi以上),通过图像分割工具(如OpenCV、LabelMe)裁剪出选择题、填空题、解答题的手写区域,单独保存为图像文件;

  2. 标签标注:组织教师团队对裁剪后的图像进行人工标注,严格遵循「题型-题号-内容」的标注格式,公式部分需统一用LaTeX格式标注(可借助LaTeX编辑器辅助);

  3. 数据分层:按「年级(初一/初二/初三)、书写风格(工整/潦草/连笔)、试卷难度(基础/提升/压轴)」进行分层,确保数据分布均匀;

  4. 原始数据归档:建立结构化目录(如“exam_handwriting_data/真实数据/初一/基础/选择/”),并生成对应的标签文件,包含图像路径、标注内容、采集时间等信息。

1.4 高质量试卷手写体数据的核心特性

针对试卷批改场景,高质量手写体数据需满足:

  1. 图像质量达标:无模糊、无严重污渍、无遮挡,文字边缘清晰,分辨率≥300dpi;

  2. 标签精准对齐:图像与标注内容严格一一对应,无标签错误、遗漏,公式标注格式统一;

  3. 覆盖场景全面:包含不同年级、不同书写风格、不同题型的样本,避免数据偏倚;

  4. 格式规范性:图像文件命名统一(如“初一-基础-选择-001.jpg”),标签文件结构固定,便于后续批量处理;

  5. 模型适配性:图像尺寸可调整为模型输入要求(如224×224、384×384),标签可转换为模型可识别的文本编码格式。

1.5 模型输入要求

本次选择「ViT-L/14(视觉Transformer)+ CRNN(循环神经网络)」融合模型作为微调模型,其输入要求:

  1. 图像输入:统一尺寸为384×384像素,3通道RGB图像,像素值归一化到[0,1]区间;

  2. 文本标签:将标注文本转换为字符级编码(如使用Unicode编码或自定义字符集映射),公式的LaTeX文本需保留原始格式并一同编码;

  3. 序列长度:单条文本标签的字符长度不超过128(解答题步骤较长时可分段处理);

  4. 数据配对:每个图像样本需与对应的文本标签严格配对,形成「图像张量-文本编码」的训练对。

二、数据清洗:剔除低质数据,优化图像质量

试卷手写体数据清洗分为「图像质量清洗」和「标签清洗」两部分,核心目标是去除无效样本、优化图像可读性、修正标签错误。

2.1 清洗步骤与代码实现


import os
import json
import cv2
import numpy as np
from PIL import Image
from collections import Counter

# 1. 加载原始数据(图像+标签)
base_dir = "exam_handwriting_data"
raw_label_path = os.path.join(base_dir, "raw_labels.json")
with open(raw_label_path, "r", encoding="utf-8") as f:
    raw_labels = json.load(f)

# 2. 定义清洗规则与工具函数
class ExamHandwritingCleaner:
    def __init__(self):
        # 有效字符集(覆盖初中数学常见字符、LaTeX符号)
        self.valid_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+-×÷=()[]{}√π°$//:;.,。")
        # 图像质量阈值
        self.blur_threshold = 100  # 清晰度阈值(越高越清晰)
        self.noise_threshold = 0.05  # 噪声占比阈值(越低噪声越少)
    
    def calculate_image_sharpness(self, img_path):
        """计算图像清晰度(用拉普拉斯方差衡量)"""
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            return 0
        laplacian = cv2.Laplacian(img, cv2.CV_64F)
        return np.var(laplacian)
    
    def calculate_noise_ratio(self, img_path):
        """计算图像噪声占比(用椒盐噪声检测)"""
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            return 1.0
        # 检测椒盐噪声(像素值为0或255的异常点)
        height, width = img.shape
        total_pixels = height * width
        salt_noise = np.sum(img == 255)
        pepper_noise = np.sum(img == 0)
        noise_pixels = salt_noise + pepper_noise
        return noise_pixels / total_pixels
    
    def clean_image(self, img_path, save_path):
        """图像优化:去噪、增强对比度"""
        img = cv2.imread(img_path)
        if img is None:
            return False
        
        # 步骤1:去噪(高斯滤波)
        img_denoised = cv2.GaussianBlur(img, (3, 3), 0)
        
        # 步骤2:增强对比度(直方图均衡化,针对灰度图)
        img_gray = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2GRAY)
        img_equalized = cv2.equalizeHist(img_gray)
        img_processed = cv2.cvtColor(img_equalized, cv2.COLOR_GRAY2BGR)
        
        # 步骤3:保存优化后的图像
        cv2.imwrite(save_path, img_processed)
        return True
    
    def validate_label(self, label):
        """验证标签有效性:字符有效性、格式规范性"""
        # 检查必填字段
        required_fields = ["sample_id", "image_path", "question_type", "question_id", "standard_content"]
        for field in required_fields:
            if field not in label or not label[field]:
                return False, "缺失必填字段"
        
        # 检查字符有效性(仅允许有效字符集中的字符)
        for char in label["standard_content"]:
            if char not in self.valid_chars and char != "$":  # $为LaTeX标识,单独允许
                return False, f"包含无效字符:{char}"
        
        # 检查题型格式
        if label["question_type"] not in ["选择", "填空", "解答"]:
            return False, "题型格式错误"
        
        return True, "验证通过"
    
    def clean(self, raw_labels, base_dir):
        """完整清洗流程"""
        # 创建清洗后的数据目录
        cleaned_img_dir = os.path.join(base_dir, "cleaned_images")
        os.makedirs(cleaned_img_dir, exist_ok=True)
        
        cleaned_labels = []
        invalid_records = []
        
        for label in raw_labels:
            sample_id = label["sample_id"]
            img_path = label["image_path"]
            
            # 跳过已标记为无效的样本
            if not label.get("is_valid", True):
                invalid_records.append({"sample_id": sample_id, "reason": "原始标记无效"})
                continue
            
            # 步骤1:验证图像文件是否存在
            if not os.path.exists(img_path):
                invalid_records.append({"sample_id": sample_id, "reason": "图像文件不存在"})
                continue
            
            # 步骤2:评估图像质量(清晰度、噪声)
            sharpness = self.calculate_image_sharpness(img_path)
            noise_ratio = self.calculate_noise_ratio(img_path)
            if sharpness < self.blur_threshold:
                invalid_records.append({"sample_id": sample_id, "reason": f"图像模糊,清晰度:{sharpness:.2f}"})
                continue
            if noise_ratio > self.noise_threshold:
                invalid_records.append({"sample_id": sample_id, "reason": f"噪声过多,噪声占比:{noise_ratio:.2f}"})
                continue
            
            # 步骤3:图像优化(去噪、增强对比度)
            img_filename = os.path.basename(img_path)
            cleaned_img_path = os.path.join(cleaned_img_dir, img_filename)
            if not self.clean_image(img_path, cleaned_img_path):
                invalid_records.append({"sample_id": sample_id, "reason": "图像优化失败"})
                continue
            
            # 步骤4:验证标签有效性
            label_valid, label_msg = self.validate_label(label)
            if not label_valid:
                invalid_records.append({"sample_id": sample_id, "reason": f"标签无效:{label_msg}"})
                continue
            
            # 步骤5:整理清洗后的标签
            cleaned_label = {
                "sample_id": sample_id,
                "raw_image_path": img_path,
                "cleaned_image_path": cleaned_img_path,
                "question_type": label["question_type"],
                "question_id": label["question_id"],
                "standard_content": label["standard_content"],
                "writing_style": label["writing_style"],
                "sharpness": round(sharpness, 2),
                "noise_ratio": round(noise_ratio, 2)
            }
            cleaned_labels.append(cleaned_label)
        
        # 保存清洗结果
        cleaned_label_path = os.path.join(base_dir, "cleaned_labels.json")
        with open(cleaned_label_path, "w", encoding="utf-8") as f:
            json.dump(cleaned_labels, f, ensure_ascii=False, indent=2)
        
        # 保存无效样本记录
        invalid_path = os.path.join(base_dir, "invalid_samples.json")
        with open(invalid_path, "w", encoding="utf-8") as f:
            json.dump(invalid_records, f, ensure_ascii=False, indent=2)
        
        print(f"数据清洗完成!")
        print(f"有效样本数:{len(cleaned_labels)}")
        print(f"无效样本数:{len(invalid_records)}")
        print(f"清洗后标签文件:{cleaned_label_path}")
        print(f"无效样本记录:{invalid_path}")
        
        # 打印无效样本原因分布
        reason_counter = Counter([r["reason"].split(":")[0] for r in invalid_records])
        print(f"\n无效样本原因分布:{dict(reason_counter)}")
        return cleaned_labels, invalid_records

# 执行清洗
cleaner = ExamHandwritingCleaner()
cleaned_labels, invalid_records = cleaner.clean(raw_labels, base_dir)

2.2 清洗步骤说明

  1. 图像存在性校验:过滤路径错误、文件损坏的图像样本;

  2. 图像质量评估:通过拉普拉斯方差计算清晰度(阈值设为100,低于则为模糊),通过椒盐噪声检测计算噪声占比(阈值设为5%,高于则为噪声过多);

  3. 图像优化:采用高斯滤波去除轻微噪声,通过直方图均衡化增强图像对比度,提升文字可读性;

  4. 标签验证:检查标签是否缺失必填字段、是否包含无效字符、题型格式是否正确,确保标签精准性;

  5. 结果归档:保存清洗后的有效样本(含优化后图像路径、质量指标)和无效样本记录,便于后续追溯与补充采集。

三、数据标准化:统一图像格式与标签编码

标准化的核心是将清洗后的图像和标签统一为模型可直接处理的格式,包括图像尺寸归一化、像素值标准化,以及标签的字符集映射与序列长度统一。

3.1 标准化步骤与代码实现


import os
import json
import cv2
import numpy as np
from PIL import Image
from collections import defaultdict

# 1. 加载清洗后的数据
base_dir = "exam_handwriting_data"
cleaned_label_path = os.path.join(base_dir, "cleaned_labels.json")
with open(cleaned_label_path, "r", encoding="utf-8") as f:
    cleaned_labels = json.load(f)

# 2. 定义标准化配置
STANDARD_SIZE = (384, 384)  # 模型输入标准尺寸(宽×高)
NORMALIZE_RANGE = (0, 1)    # 像素值标准化范围
MAX_SEQ_LENGTH = 128        # 标签序列最大长度
PAD_TOKEN = "[PAD]"         # 填充符
UNK_TOKEN = "[UNK]"         # 未知字符

# 3. 构建字符集(基于清洗后的标签)
def build_vocab(cleaned_labels):
    """基于标签内容构建字符集映射"""
    vocab = defaultdict(int)
    # 先添加特殊符号
    vocab[PAD_TOKEN] = 0
    vocab[UNK_TOKEN] = 1
    
    # 统计所有标签中的字符
    for label in cleaned_labels:
        content = label["standard_content"]
        for char in content:
            vocab[char] += 1
    
    # 生成字符→索引的映射(过滤出现次数≤1的稀有字符)
    char2idx = {char: idx for idx, (char, count) in enumerate(vocab.items()) if count > 1}
    # 若稀有字符仍需保留,映射到UNK_TOKEN
    char2idx[UNK_TOKEN] = 1
    idx2char = {idx: char for char, idx in char2idx.items()}
    
    print(f"字符集构建完成!共包含{len(char2idx)}个字符")
    print(f"字符集示例:{list(char2idx.items())[:10]}")
    return char2idx, idx2char

# 4. 定义标准化类
class ExamHandwritingStandardizer:
    def __init__(self, char2idx, idx2char, standard_size, normalize_range):
        self.char2idx = char2idx
        self.idx2char = idx2char
        self.standard_size = standard_size
        self.normalize_range = normalize_range
    
    def standardize_image(self, img_path):
        """图像标准化:尺寸归一化、像素值标准化"""
        # 读取图像(3通道RGB)
        img = cv2.imread(img_path)
        if img is None:
            return None
        
        # 步骤1:尺寸归一化(保持比例,填充黑边)
        h, w = img.shape[:2]
        target_h, target_w = self.standard_size
        # 计算缩放比例
        scale = min(target_w / w, target_h / h)
        new_w = int(w * scale)
        new_h = int(h * scale)
        # 缩放图像
        img_resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
        # 创建标准尺寸的空白图像(黑色背景)
        img_standard = np.zeros((target_h, target_w, 3), dtype=np.uint8)
        # 计算填充位置(居中对齐)
        x_offset = (target_w - new_w) // 2
        y_offset = (target_h - new_h) // 2
        # 填充图像
        img_standard[y_offset:y_offset+new_h, x_offset:x_offset+new_w, :] = img_resized
        
        # 步骤2:像素值标准化(转换为[0,1]区间)
        img_normalized = img_standard / 255.0
        # 若需要转换为[-1,1]区间,可使用:img_normalized = (img_standard / 127.5) - 1.0
        
        # 步骤3:调整维度(HWC→CHW,适配PyTorch模型输入)
        img_tensor = np.transpose(img_normalized, (2, 0, 1))
        return img_tensor
    
    def standardize_label(self, content):
        """标签标准化:字符映射为索引,统一序列长度"""
        # 字符→索引映射
        idx_sequence = []
        for char in content:
            idx = self.char2idx.get(char, self.char2idx[UNK_TOKEN])
            idx_sequence.append(idx)
        
        # 统一序列长度(填充或截断)
        if len(idx_sequence) < MAX_SEQ_LENGTH:
            # 不足则填充PAD_TOKEN
            pad_length = MAX_SEQ_LENGTH - len(idx_sequence)
            idx_sequence += [self.char2idx[PAD_TOKEN]] * pad_length
        else:
            # 超过则截断
            idx_sequence = idx_sequence[:MAX_SEQ_LENGTH]
        
        # 转换为numpy数组
        idx_sequence = np.array(idx_sequence, dtype=np.int64)
        # 生成注意力掩码(PAD_TOKEN对应0,其他对应1)
        attention_mask = np.where(idx_sequence == self.char2idx[PAD_TOKEN], 0, 1)
        return idx_sequence, attention_mask
    
    def standardize(self, cleaned_labels, base_dir):
        """完整标准化流程"""
        # 创建标准化后的数据目录(用于保存可视化结果,模型训练直接用张量)
        standard_img_dir = os.path.join(base_dir, "standardized_images")
        os.makedirs(standard_img_dir, exist_ok=True)
        
        standardized_data = []
        failed_samples = []
        
        for label in cleaned_labels:
            sample_id = label["sample_id"]
            cleaned_img_path = label["cleaned_image_path"]
            
            # 步骤1:图像标准化
            img_tensor = self.standardize_image(cleaned_img_path)
            if img_tensor is None:
                failed_samples.append({"sample_id": sample_id, "reason": "图像标准化失败"})
                continue
            
            # 步骤2:标签标准化
            content = label["standard_content"]
            idx_sequence, attention_mask = self.standardize_label(content)
            
            # 步骤3:保存标准化后的图像(用于校验)
            # 转换回HWC格式并反归一化
            img_vis = np.transpose(img_tensor, (1, 2, 0)) * 255.0
            img_vis = img_vis.astype(np.uint8)
            img_filename = os.path.basename(cleaned_img_path)
            standard_img_path = os.path.join(standard_img_dir, img_filename)
            cv2.imwrite(standard_img_path, img_vis)
            
            # 整理标准化后的数据
            standardized_sample = {
                "sample_id": sample_id,
                "cleaned_image_path": cleaned_img_path,
                "standard_image_path": standard_img_path,
                "question_type": label["question_type"],
                "question_id": label["question_id"],
                "standard_content": content,
                "image_tensor": img_tensor.tolist(),  # 列表格式便于保存JSON(实际训练用numpy)
                "label_sequence": idx_sequence.tolist(),
                "attention_mask": attention_mask.tolist()
            }
            standardized_data.append(standardized_sample)
        
        # 保存标准化后的数据(含张量信息,用于模型训练加载)
        standard_data_path = os.path.join(base_dir, "standardized_data.json")
        with open(standard_data_path, "w", encoding="utf-8") as f:
            json.dump(standardized_data, f, ensure_ascii=False, indent=2)
        
        # 保存字符集映射
        vocab_path = os.path.join(base_dir, "char2idx.json")
        with open(vocab_path, "w", encoding="utf-8") as f:
            json.dump(self.char2idx, f, ensure_ascii=False, indent=2)
        
        idx2char_path = os.path.join(base_dir, "idx2char.json")
        with open(idx2char_path, "w", encoding="utf-8") as f:
            json.dump(self.idx2char, f, ensure_ascii=False, indent=2)
        
        # 保存失败样本记录
        failed_path = os.path.join(base_dir, "standardization_failed.json")
        with open(failed_path, "w", encoding="utf-8") as f:
            json.dump(failed_samples, f, ensure_ascii=False, indent=2)
        
        print(f"数据标准化完成!")
        print(f"标准化成功样本数:{len(standardized_data)}")
        print(f"标准化失败样本数:{len(failed_samples)}")
        print(f"标准化数据文件:{standard_data_path}")
        print(f"字符集映射文件:{vocab_path}")
        return standardized_data, failed_samples

# 5. 执行标准化
# 第一步:构建字符集
char2idx, idx2char = build_vocab(cleaned_labels)

# 第二步:初始化标准化器并执行
standardizer = ExamHandwritingStandardizer(
    char2idx=char2idx,
    idx2char=idx2char,
    standard_size=STANDARD_SIZE,
    normalize_range=NORMALIZE_RANGE
)
standardized_data, failed_samples = standardizer.standardize(cleaned_labels, base_dir)

# 打印标准化示例
if standardized_data:
    sample = standardized_data[0]
    print("\n标准化示例:")
    print(f"样本ID:{sample['sample_id']}")
    print(f"原始内容:{sample['standard_content']}")
    print(f"标签序列(前10个索引):{sample['label_sequence'][:10]}")
    print(f"注意力掩码(前10个):{sample['attention_mask'][:10]}")
    print(f"图像张量形状:{np.array(sample['image_tensor']).shape}")  # (3, 384, 384)

3.2 标准化步骤说明

  1. 字符集构建:基于清洗后的标签内容,统计出现次数≥2的字符(过滤稀有字符),构建「字符→索引」映射,包含填充符[PAD]和未知字符[UNK];

  2. 图像标准化:

    • 尺寸归一化:采用「保持比例+填充黑边」的方式,将图像统一缩放到384×384像素,避免拉伸导致文字变形;

    • 像素值标准化:将像素值从[0,255]转换为[0,1]区间,适配模型对输入数据范围的要求;

    • 维度调整:将OpenCV默认的HWC(高度×宽度×通道)格式转换为PyTorch模型要求的CHW(通道×高度×宽度)格式;

  3. 标签标准化:

    • 字符编码:将文本标签中的每个字符映射为字符集中的索引;

    • 序列长度统一:通过填充[PAD]或截断,将所有标签序列长度统一为128;

    • 注意力掩码生成:标记有效字符(1)和填充字符(0),便于模型在训练时忽略填充部分;

  4. 结果保存:保存标准化后的图像(用于人工校验)、张量数据(图像张量、标签序列、注意力掩码)及字符集映射,为后续模型编码做好准备。

四、数据编码:转换为模型可直接训练的格式

标准化后的图像张量和标签序列仍为列表/JSON格式,需进一步转换为模型训练框架(如PyTorch)可直接加载的张量格式,并按训练/验证/测试集拆分,最终生成批量训练数据。

4.1 编码步骤与代码实现


import os
import json
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# 1. 加载标准化后的数据
base_dir = "exam_handwriting_data"
standard_data_path = os.path.join(base_dir, "standardized_data.json")
char2idx_path = os.path.join(base_dir, "char2idx.json")

with open(standard_data_path, "r", encoding="utf-8") as f:
    standardized_data = json.load(f)

with open(char2idx_path, "r", encoding="utf-8") as f:
    char2idx = json.load(f)

# 2. 定义数据集类(适配PyTorch)
class ExamHandwritingDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        # 转换图像张量(列表→PyTorch张量,float32类型)
        image_tensor = torch.tensor(sample["image_tensor"], dtype=torch.float32)
        # 转换标签序列(列表→PyTorch张量,int64类型)
        label_sequence = torch.tensor(sample["label_sequence"], dtype=torch.int64)
        # 转换注意力掩码(列表→PyTorch张量,int64类型)
        attention_mask = torch.tensor(sample["attention_mask"], dtype=torch.int64)
        
        # 返回字典格式,便于模型训练
        return {
            "sample_id": sample["sample_id"],
            "image": image_tensor,
            "label": label_sequence,
            "attention_mask": attention_mask
        }

# 3. 数据拆分(训练集:验证集:测试集 = 7:2:1)
def split_data(standardized_data, test_size=0.1, val_size=0.2):
    """拆分训练集、验证集、测试集"""
    # 先拆分训练集和测试集
    train_data, test_data = train_test_split(
        standardized_data, test_size=test_size, random_state=42, shuffle=True
    )
    # 再从训练集中拆分验证集
    train_data, val_data = train_test_split(
        train_data, test_size=val_size/(1-test_size), random_state=42, shuffle=True
    )
    
    print(f"数据拆分完成!")
    print(f"训练集样本数:{len(train_data)}")
    print(f"验证集样本数:{len(val_data)}")
    print(f"测试集样本数:{len(test_data)}")
    return train_data, val_data, test_data

# 4. 生成批量数据加载器
def create_data_loaders(train_data, val_data, test_data, batch_size=16):
    """创建训练、验证、测试数据加载器"""
    train_dataset = ExamHandwritingDataset(train_data)
    val_dataset = ExamHandwritingDataset(val_data)
    test_dataset = ExamHandwritingDataset(test_data)
    
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=0  # num_workers根据CPU核心调整
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=0
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=0
    )
    
    return train_loader, val_loader, test_loader

# 5. 保存拆分后的数据(用于后续复用)
def save_split_data(train_data, val_data, test_data, base_dir):
    split_dir = os.path.join(base_dir, "split_data")
    os.makedirs(split_dir, exist_ok=True)
    
    # 保存训练集
    train_path = os.path.join(split_dir, "train_data.json")
    with open(train_path, "w", encoding="utf-8") as f:
        json.dump(train_data, f, ensure_ascii=False, indent=2)
    
    # 保存验证集
    val_path = os.path.join(split_dir, "val_data.json")
    with open(val_path, "w", encoding="utf-8") as f:
        json.dump(val_data, f, ensure_ascii=False, indent=2)
    
    # 保存测试集
    test_path = os.path.join(split_dir, "test_data.json")
    with open(test_path, "w", encoding="utf-8") as f:
        json.dump(test_data, f, ensure_ascii=False, indent=2)
    
    print(f"拆分后数据保存完成:{split_dir}")

# 6. 执行编码与数据拆分
if __name__ == "__main__":
    # 步骤1:数据拆分
    train_data, val_data, test_data = split_data(standardized_data)
    
    # 步骤2:保存拆分后的数据
    save_split_data(train_data, val_data, test_data, base_dir)
    
    # 步骤3:创建数据加载器
    batch_size = 16
    train_loader, val_loader, test_loader = create_data_loaders(
        train_data, val_data, test_data, batch_size=batch_size
    )
    
    # 验证数据加载器
    print(f"\n验证数据加载器:")
    for batch in train_loader:
        print(f"批量样本ID:{batch['sample_id'][:5]}...")  # 打印前5个样本ID
        print(f"图像批量形状:{batch['image'].shape}")  # (batch_size, 3, 384, 384)
        print(f"标签批量形状:{batch['label'].shape}")  # (batch_size, 128)
        print(f"注意力掩码批量形状:{batch['attention_mask'].shape}")  # (batch_size, 128)
        break  # 只打印第一个批量
    
    # 保存数据加载器配置
    loader_config = {
        "batch_size": batch_size,
        "train_loader_length": len(train_loader),
        "val_loader_length": len(val_loader),
        "test_loader_length": len(test_loader)
    }
    config_path = os.path.join(base_dir, "data_loader_config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(loader_config, f, ensure_ascii=False, indent=2)
    
    print(f"\n数据编码完成!数据加载器配置已保存:{config_path}")
    print("后续可直接使用train_loader、val_loader、test_loader进行模型微调训练")

4.2 编码步骤说明

  1. 数据集类定义:自定义PyTorch的Dataset类,实现__len__和__getitem__方法,将标准化后的列表格式数据转换为PyTorch张量(图像为float32,标签和掩码为int64);

  2. 数据拆分:采用分层随机拆分策略,按7:2:1的比例将数据分为训练集(用于模型参数更新)、验证集(用于训练过程中的超参数调整)、测试集(用于最终模型性能评估),确保各集数据分布一致;

  3. 批量数据加载:通过DataLoader将数据集按指定批量大小(如16)打包,训练集开启shuffle(打乱数据顺序,提升泛化性),验证集和测试集关闭shuffle;

  4. 结果验证与保存:验证批量数据的形状是否符合模型输入要求(图像批量形状为(batch_size, 3, 384, 384),标签批量形状为(batch_size, 128)),并保存数据拆分结果和加载器配置,便于后续模型训练直接复用。

五、完整流程总结与扩展建议

5.1 试卷手写体识别数据准备全流程

数据收集整理:试卷扫描→区域裁剪→人工标注

数据清洗:图像质量筛选→图像优化→标签验证

数据标准化:图像尺寸/像素归一化→标签字符编码→序列长度统一

数据编码:张量转换→数据拆分→批量加载器生成

模型微调训练:ViT-L/14+CRNN融合模型

5.2 核心要点回顾

  1. 场景适配性:试卷手写体数据需重点关注图像清晰度(≥300dpi)和标签格式规范性(题型-题号-内容,公式LaTeX标注),避免因数据问题影响识别精度;

  2. 图像优化关键:采用「高斯滤波去噪+直方图均衡化增强对比度」的组合策略,有效提升手写文字的可读性;

  3. 字符集构建:基于真实标签内容构建字符集,过滤稀有字符,确保模型能覆盖场景内所有常见字符;

  4. 数据拆分合理性:分层随机拆分数据,保证训练/验证/测试集的分布一致性,避免模型过拟合或评估偏差。

5.3 扩展建议

  1. 数据增强:针对手写体数据量不足的问题,可采用图像增强技术(如随机旋转±5°、轻微缩放、亮度调整、添加轻微噪声)扩充训练集,提升模型泛化性;

  2. 多题型适配:若需扩展到语文、英语等其他学科,需调整字符集(如添加汉字、英语字母、作文标点)和图像标准化参数(如解答题步骤较长时,可将MAX_SEQ_LENGTH调整为256);

  3. 半监督学习:对于标注成本高的场景,可采用半监督学习策略,先用少量标注数据训练基础模型,再用模型预测未标注数据,筛选置信度高的样本加入训练集;

  4. 模型适配优化:若使用不同模型(如CNN+Transformer、MAE微调),需调整图像标准化尺寸(如ViT-B/16对应224×224)和标签编码方式(如部分模型需使用CTC损失,无需注意力掩码);

  5. 人工校验:每一步数据处理后需加入人工抽检环节(如抽检10%的清洗后图像、20%的标准化标签),确保数据质量符合训练要求。

通过以上完整流程,可将试卷手写体的原始图像数据,逐步转换为大模型微调所需的高质量、结构化、可直接训练的格式。数据准备的核心是「场景适配+质量优先」,只有保证数据的精准性和全面性,才能为后续模型微调奠定坚实基础,最终实现试卷手写体的高精度识别与自动批改。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐