(1)伪造检测(Detection):对输入的图像,判断其是否包含任何形式的伪造。输出二分类标签label(0-“真实”, 1-"伪造”)。
(2)伪造定位(Grounding):如果图像被判断为“伪造”,系统需输出与原图尺寸相同的二值化掩码(Mask),其中伪造区域被标记为前景(白色像素, 像素值 255),真实区域为背景(黑色像素,像素值 0),最终结果中将其转换为RLE编码。
(3)可解释(Explanation):针对定位出的伪造区域,系统需生成一段自然语言文本,详细描述该区域被判定为伪造的原因。解释归因应具体、有逻辑,并结合图像内容,避免过度幻觉。

# ==============================
# 1. 导入所有需要的库
# ==============================
import os
import cv2
import json
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import segmentation_models_pytorch as smp
from pycocotools import mask as mask_utils
import albumentations as A
from albumentations.pytorch import ToTensorV2


# ==============================
# 2. 配置参数
# ==============================
class Config:
    # --------------------------
    # --------------------------
    TRAIN_BLACK_IMAGE_DIR = r"D:\比赛\ForgeryAnalysis_Stage_1_Train\ForgeryAnalysis_Stage_1_Train\White\Image"
    TRAIN_BLACK_MASK_DIR = r"D:\比赛\ForgeryAnalysis_Stage_1_Train\ForgeryAnalysis_Stage_1_Train\Black\Mask"

    # 测试集路径
    TEST_IMAGE_DIR = r"D:\比赛\ForgeryAnalysis_Stage_1_Train\ForgeryAnalysis_Stage_1_Train\Black\Image"

    # 其他配置
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    IMG_SIZE = (512, 512)
    BATCH_SIZE = 2  # CPU训练设为2,减少内存占用
    EPOCHS =20#先训练5轮测试
    LR = 1e-4
    MODEL_SAVE_PATH = "best_forgery_model.pth"
    SUBMISSION_SAVE_PATH = "submission.csv"


# ==============================
# 3. 数据增强函数
# ==============================
def get_train_transforms(img_size=Config.IMG_SIZE):
    return A.Compose([
        A.Resize(img_size[0], img_size[1]),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.2),
        A.RandomRotate90(p=0.5),
        A.ImageCompression(quality_range=(60, 100), p=0.5),
        A.GaussNoise(p=0.3),
        A.OneOf([
            A.GaussianBlur(blur_limit=3, p=1.0),
            A.MedianBlur(blur_limit=3, p=1.0),
        ], p=0.3),
        A.RandomBrightnessContrast(p=0.3),
        A.HueSaturationValue(p=0.2),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ], is_check_shapes=True)


def get_valid_transforms(img_size=Config.IMG_SIZE):
    return A.Compose([
        A.Resize(img_size[0], img_size[1]),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])


# ==============================
# 4. 数据集类
# ==============================
class ForensicsDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform

        # 兼容中文路径的图像读取函数
        def read_img_chinese(path):
            try:
                stream = open(path, 'rb')
                bytes = bytearray(stream.read())
                arr = np.asarray(bytes, dtype=np.uint8)
                return cv2.imdecode(arr, cv2.IMREAD_COLOR)
            except Exception as e:
                print(f"读取失败: {path}, 错误: {e}")
                return None

        # 检查路径是否存在
        if not os.path.exists(image_dir):
            raise ValueError(f"图像路径不存在: {image_dir}")

        # 列出所有文件并筛选图像
        all_files = os.listdir(image_dir)
        print(f"【日志】路径下总文件数: {len(all_files)}")

        img_ext = ('.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tif')
        img_files = [f for f in all_files if f.lower().endswith(img_ext)]
        print(f"【日志】筛选出图像文件: {len(img_files)} 个")

        # 过滤有效图像
        self.valid_image_names = []
        self.valid_base_names = []
        for f in img_files:
            full_path = os.path.join(image_dir, f)
            img = read_img_chinese(full_path)
            if img is not None and img.size > 0:
                self.valid_image_names.append(f)
                self.valid_base_names.append(os.path.splitext(f)[0])

        # 最终检查
        print(f"【日志】有效图像数: {len(self.valid_image_names)}")
        if len(self.valid_image_names) == 0:
            raise ValueError(
                f"无有效图像文件!\n"
                f"路径: {image_dir}\n"
                f"请检查:1.路径是否正确 2.是否以管理员运行PyCharm"
            )

    def __len__(self):
        return len(self.valid_image_names)

    def __getitem__(self, idx):
        # 兼容中文路径读取
        def read_img_chinese(path):
            stream = open(path, 'rb')
            bytes = bytearray(stream.read())
            arr = np.asarray(bytes, dtype=np.uint8)
            return cv2.imdecode(arr, cv2.IMREAD_COLOR)

        base_name = self.valid_base_names[idx]
        img_name = self.valid_image_names[idx]
        img_path = os.path.join(self.image_dir, img_name)

        # 读取图像并转RGB
        image = read_img_chinese(img_path)
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        else:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # 读取Mask
        mask = None
        mask_path_png = os.path.join(self.mask_dir, f"{base_name}.png")
        mask_path_jpg = os.path.join(self.mask_dir, f"{base_name}.jpg")
        if os.path.exists(mask_path_png):
            mask = read_img_chinese(mask_path_png)
        elif os.path.exists(mask_path_jpg):
            mask = read_img_chinese(mask_path_jpg)

        if mask is None or mask.size == 0:
            mask = np.zeros((Config.IMG_SIZE[0], Config.IMG_SIZE[1]), dtype=np.uint8)
        else:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
            _, mask = cv2.threshold(mask, 127, 1, cv2.THRESH_BINARY)

        # 数据增强
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return {"image": image.float(), "mask": mask.long()}


# ==============================
# 5. 训练函数
# ==============================
def train_model():
    # 检查训练路径
    for path in [Config.TRAIN_BLACK_IMAGE_DIR, Config.TRAIN_BLACK_MASK_DIR]:
        if not os.path.exists(path):
            raise FileNotFoundError(f"路径不存在:{path}")

    # 加载数据集
    train_dataset = ForensicsDataset(
        image_dir=Config.TRAIN_BLACK_IMAGE_DIR,
        mask_dir=Config.TRAIN_BLACK_MASK_DIR,
        transform=get_train_transforms()
    )

    # 加载数据加载器
    train_loader = DataLoader(
        train_dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        pin_memory=False
    )

    # 初始化模型
    model = smp.UnetPlusPlus(
        encoder_name="efficientnet_b0",
        encoder_weights="imagenet",  # 加载ImageNet预训练权重(核心!)
        in_channels=3,
        classes=1
    ).to(Config.DEVICE)

    # 损失函数和优化器
    dice_loss = smp.losses.DiceLoss(mode='binary')
    criterion = smp.losses.LovaszLoss(mode='binary')

    def combined_loss(outputs, masks):
        return 0.5 * dice_loss(outputs, masks) + 0.5 * bce_loss(outputs, masks)

    optimizer = optim.Adam(model.parameters(), lr=Config.LR)

    # 开始训练
    best_loss = float('inf')
    print(f"\n开始在 {Config.DEVICE} 上训练,共 {Config.EPOCHS} 轮...")
    print(f"训练集样本数:{len(train_dataset)}")

    for epoch in range(Config.EPOCHS):
        model.train()
        total_loss = 0.0
        batch_count = 0

        for batch in tqdm(train_loader, desc=f"第 {epoch + 1} 轮训练"):
            try:
                images = batch["image"].to(Config.DEVICE)
                masks = batch["mask"].to(Config.DEVICE).float().unsqueeze(1)

                # 前向传播
                outputs = model(images)
                loss = combined_loss(outputs, masks)

                # 反向传播
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                batch_count += 1
            except Exception as e:
                print(f"\n跳过错误批次:{e}")
                continue

        # 计算平均损失
        avg_loss = total_loss / batch_count if batch_count > 0 else float('inf')
        print(f"第 {epoch + 1} 轮平均损失:{avg_loss:.4f}")

        # 保存最优模型
        if avg_loss < best_loss and batch_count > 0:
            best_loss = avg_loss
            torch.save(model.state_dict(), Config.MODEL_SAVE_PATH)
            print(f"保存最优模型到:{Config.MODEL_SAVE_PATH}")

    print("\n训练完成!")

# ==============================
# 7. 推理函数(新增!)
# ==============================
def run_inference():
    # 检查测试集路径
    if not os.path.exists(Config.TEST_IMAGE_DIR):
        raise FileNotFoundError(f"测试集路径不存在:{Config.TEST_IMAGE_DIR}")

    # 加载训练好的模型
    model = smp.UnetPlusPlus(
        encoder_name="resnet34",
        encoder_weights=None,
        in_channels=3,
        classes=1
    ).to(Config.DEVICE)
    model.load_state_dict(torch.load(Config.MODEL_SAVE_PATH, map_location=Config.DEVICE))
    model.eval()

    # 筛选测试集有效图像
    def read_img_chinese(path):
        stream = open(path, 'rb')
        bytes = bytearray(stream.read())
        arr = np.asarray(bytes, dtype=np.uint8)
        return cv2.imdecode(arr, cv2.IMREAD_COLOR)

    valid_test_names = []
    all_test_files = os.listdir(Config.TEST_IMAGE_DIR)
    img_ext = ('.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tif')
    for f in all_test_files:
        if f.lower().endswith(img_ext):
            full_path = os.path.join(Config.TEST_IMAGE_DIR, f)
            img = read_img_chinese(full_path)
            if img is not None and img.size > 0:
                valid_test_names.append(f)

    if len(valid_test_names) == 0:
        raise ValueError("测试集无有效图像!")

    # RLE编码函数
    def mask_to_rle(binary_mask):
        if binary_mask.sum() == 0:
            return ""
        fortran_mask = np.asfortranarray(binary_mask.astype(np.uint8))
        rle_dict = mask_utils.encode(fortran_mask)
        if isinstance(rle_dict['counts'], bytes):
            rle_dict['counts'] = rle_dict['counts'].decode('utf-8')
        return json.dumps(rle_dict)

    # 批量推理
    results = []
    with torch.no_grad():
        for img_name in tqdm(valid_test_names, desc="推理中"):
            img_path = os.path.join(Config.TEST_IMAGE_DIR, img_name)
            original_img = read_img_chinese(img_path)
            h, w = original_img.shape[:2]

            # 预处理
            img = cv2.resize(original_img, Config.IMG_SIZE)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            transform = get_valid_transforms()
            augmented = transform(image=img)
            img_tensor = augmented['image'].unsqueeze(0).to(Config.DEVICE)

            # 预测
            output = model(img_tensor)
            output = torch.sigmoid(output).cpu().numpy()[0, 0]
            pred_mask = (output > 0.5).astype(np.uint8)
            pred_mask_full = cv2.resize(pred_mask, (w, h), interpolation=cv2.INTER_NEAREST)

            # 判定伪造/非伪造
            is_forged = 1 if pred_mask_full.sum() > 100 else 0
            rle_str = mask_to_rle(pred_mask_full) if is_forged else ""

            # 生成解释文本
            if is_forged:
                explanation = "Abnormal noise and edge discontinuities detected, indicating image forgery."
            else:
                explanation = "No forgery traces found, image is authentic."

            results.append({
                "image_name": img_name,
                "label": is_forged,
                "location": rle_str,
                "explanation": explanation
            })

    # 保存提交文件
    df = pd.DataFrame(results)
    df.to_csv(Config.SUBMISSION_SAVE_PATH, index=False, encoding='utf-8')
    print(f"✅ 提交文件已保存:{Config.SUBMISSION_SAVE_PATH}")

# ==============================
# 6. 主函数
# ==============================
if __name__ == "__main__":
    # 注释训练,取消注释推理
    # train_model()
    run_inference()

1.导入库(固定依赖)

python

import os
import cv2
import json
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import segmentation_models_pytorch as smp
from pycocotools import mask as mask_utils
import albumentations as A
from albumentations.pytorch import ToTensorV2

2. 配置参数(可修改的核心设置)

python

class Config:
    TRAIN_BLACK_IMAGE_DIR = r"你的训练图像路径"
    TRAIN_BLACK_MASK_DIR = r"你的训练掩码路径"
    TEST_IMAGE_DIR = r"你的测试图像路径"
    
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    IMG_SIZE = (512, 512)
    BATCH_SIZE = 2  # CPU训练设小值,GPU可增大
    EPOCHS = 5
    LR = 1e-4
    MODEL_SAVE_PATH = "best_forgery_model.pth"
    SUBMISSION_SAVE_PATH = "submission.csv"
  • 核心作用:集中管理所有可配置参数,避免在代码中到处修改路径 / 超参数
  • 关键参数说明
    • DEVICE:自动选择 GPU/CPU 训练(有 GPU 会优先用,速度快 10 倍以上)
    • IMG_SIZE:统一将图像缩放到 512x512(平衡精度和速度)
    • BATCH_SIZE:每次训练喂给模型的图像数量(CPU 设 2,GPU 可设 8/16)
    • EPOCHS:训练轮数(先训 5 轮测试,后续可增至 20-50 轮)
    • MODEL_SAVE_PATH:训练好的模型保存路径
    • SUBMISSION_SAVE_PATH:推理结果的 CSV 保存路径

3. 数据增强函数(提升模型泛化能力)

python

def get_train_transforms(img_size=Config.IMG_SIZE):
    return A.Compose([...])  # 包含翻转、旋转、噪声、模糊、亮度调整等

def get_valid_transforms(img_size=Config.IMG_SIZE):
    return A.Compose([...])  # 仅缩放+归一化,无随机增强
  • 核心作用
    • get_train_transforms:训练时对图像做随机增强(翻转、旋转、加噪声等),让模型见过更多样的图像,避免过拟合
    • get_valid_transforms:推理时仅做必要的预处理(缩放 + 归一化),保证结果稳定
  • 关键操作
    • Normalize:用 ImageNet 的均值 / 标准差归一化(符合预训练模型的输入要求)
    • ToTensorV2:将 numpy 数组转为 PyTorch 张量(模型只能处理张量)
    • is_check_shapes=True:检查增强后图像 / 掩码形状是否匹配,避免报错

4. 数据集类(核心数据加载逻辑)

python

class ForensicsDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        # 1. 兼容中文路径的图像读取(解决Windows中文路径读取失败问题)
        # 2. 筛选有效图像文件(过滤非图像/损坏文件)
        # 3. 统计有效图像数量,做路径校验
        
    def __len__(self):
        return len(self.valid_image_names)  # 返回数据集大小
    
    def __getitem__(self, idx):
        # 1. 读取图像并转换为RGB(统一格式,避免灰度/透明通道问题)
        # 2. 读取对应的掩码文件,转为二值掩码(0=正常,1=伪造)
        # 3. 应用数据增强
        # 4. 返回图像张量和掩码张量
  • 核心作用:实现 PyTorch 的 Dataset 接口,负责「读取图像 / 掩码、格式转换、数据增强」,是连接原始数据和模型的桥梁
  • 关键
    • 兼容中文路径读取(解决 Windows 下 cv2.imread 读中文路径失败的问题)
    • 自动过滤损坏 / 无效图像,避免训练中断
    • 掩码自动转为二值格式(1 表示伪造区域),符合分割任务要求

5. 训练函数(模型训练核心逻辑)

python

def train_model():
    # 1. 校验训练路径,加载数据集和DataLoader
    # 2. 初始化Unet++模型(resnet34作为骨干网络)
    # 3. 定义混合损失函数(DiceLoss + BCEWithLogitsLoss)
    # 4. 训练循环:前向传播→计算损失→反向传播→更新参数
    # 5. 保存损失最低的最优模型
  • 核心作用:完成模型的端到端训练
  • 关键细节
    • 损失函数:DiceLoss(适合分割任务,关注区域重叠)+ BCE(二分类损失),兼顾精度和稳定性
    • 优化器:Adam(自适应学习率,收敛快)
    • 进度显示:用 tqdm 显示每轮训练的进度
    • 容错处理:跳过错误批次,避免单个坏数据导致训练终止
    • 模型保存:只保存损失最低的模型,避免保存训练后期过拟合的模型

6. 推理函数(模型预测核心逻辑)

python

def run_inference():
    # 1. 加载训练好的模型,设置为评估模式
    # 2. 筛选测试集有效图像
    # 3. 定义RLE编码函数(将掩码转为比赛要求的字符串格式)
    # 4. 逐张图像推理:预处理→模型预测→后处理→生成结果
    # 5. 保存结果为CSV文件(包含图像名、是否伪造、伪造区域、解释文本)
  • 核心作用:用训练好的模型对新图像进行检测,输出标准化结果
  • 关键步骤
    • 模型加载:加载保存的最优模型权重
    • 图像预处理:和训练时的验证预处理保持一致,保证输入格式匹配
    • 预测后处理:
      • 对模型输出用 sigmoid 激活(转为 0-1 概率)
      • 阈值 0.5 分割(>0.5 判定为伪造区域)
      • 将掩码缩放回原图尺寸(匹配原始图像大小)
    • 伪造判定:伪造区域像素数 > 100 才判定为伪造(避免微小噪声误判)
    • RLE 编码:将伪造区域转为字符串格式(方便存储和提交)
    • 结果保存:生成包含 4 列的 CSV

7. 主函数(程序入口)

python

if __name__ == "__main__":
    # train_model()  # 训练时取消注释
    run_inference()   # 推理时取消注释
  • 核心作用:程序的入口,控制运行「训练」或「推理」
  • 使用方式
    • 先取消train_model()注释,运行训练,得到模型文件
    • 训练完成后,注释train_model(),取消run_inference()注释,运行推理

最终效果图

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐