1. 准备数据集

步骤:

  1. 准备图片和标签文件

    • 将所有图像文件放置在同一个文件夹中,支持的格式包括 .jpg, .jpeg, .png, .bmp
    • 使用工具(如 X-AnyLabeling-CPU.exe)对图片进行标注,并生成 YOLO 格式的标签文件(与图片同名,扩展名为 .txt)。
    • 标签文件包含目标的类别和边界框坐标,格式为:
      class x_center y_center width height
      
      其中坐标为相对值(相对于图片的宽高)。
      在这里插入图片描述
  2. 拆分训练集和验证集
    使用以下代码将数据分为训练集和验证集:

数据分割代码:
import os
import random
import shutil


def split_dataset(dataset_dir, output_dir, val_ratio=0.2, seed=42):
    """
    将数据集分割为训练集和验证集。

    :param dataset_dir: 数据集路径,包含图像和标注文件。
    :param output_dir: 输出路径,将生成训练集和验证集文件夹。
    :param val_ratio: 验证集比例,默认为 0.2。
    :param seed: 随机种子,确保结果可复现,默认为 42。
    """
    # 设置随机种子
    random.seed(seed)

    # 获取所有的图像文件
    image_extensions = (".jpg", ".jpeg", ".png", ".bmp")
    images = [f for f in os.listdir(dataset_dir) if f.endswith(image_extensions)]
    images.sort()  # 确保排序一致

    # 打乱数据
    random.shuffle(images)

    # 按验证集比例分割数据
    total_images = len(images)
    val_size = int(total_images * val_ratio)
    val_images = images[:val_size]
    train_images = images[val_size:]

    # 创建输出文件夹
    train_dir = os.path.join(output_dir, "train")
    val_dir = os.path.join(output_dir, "val")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    # 复制文件到对应的文件夹
    for image in train_images:
        # 图像文件复制
        shutil.copy(os.path.join(dataset_dir, image), os.path.join(train_dir, image))

        # 标注文件复制
        label_file = os.path.splitext(image)[0] + ".txt"
        if os.path.exists(os.path.join(dataset_dir, label_file)):
            shutil.copy(os.path.join(dataset_dir, label_file), os.path.join(train_dir, label_file))

    for image in val_images:
        # 图像文件复制
        shutil.copy(os.path.join(dataset_dir, image), os.path.join(val_dir, image))

        # 标注文件复制
        label_file = os.path.splitext(image)[0] + ".txt"
        if os.path.exists(os.path.join(dataset_dir, label_file)):
            shutil.copy(os.path.join(dataset_dir, label_file), os.path.join(val_dir, label_file))

    print(f"数据集分割完成!")
    print(f"训练集大小: {len(train_images)}")
    print(f"验证集大小: {len(val_images)}")


if __name__ == "__main__":
    # 数据集路径 (包含图片和标注)
    dataset_dir = "path/to/your/dataset"

    # 输出路径
    output_dir = "path/to/output/dataset"

    # 验证集比例
    val_ratio = 0.2  # 20% 验证集

    # 调用函数进行数据分割
    split_dataset(dataset_dir, output_dir, val_ratio)

执行此代码后,会在 output_dir 中生成 trainval 两个文件夹,分别存放训练集和验证集。


2. 编写配置文件

YOLO 的配置文件使用 YAML 格式,主要包含以下内容:

train:
  - D:\\python\\ultralyticsTest\\train_data  # 训练集路径
val:
  - D:\\python\\ultralyticsTest\\train_data  # 验证集路径
nc: 10  # 类别数量,根据实际数据集修改
names:  # 类别名称
  - 0
  - 1
  - 2
  - 3
  - 4
  - 5
  - 6
  - 7
  - 8
  - 9

将此文件保存为 data.yaml,路径可以自定义。


3. 开始训练

使用 ultralytics 库进行训练。以下是完整代码:

from ultralytics import YOLO
import os


def train_yolo():
    # 加载YOLO模型
    # 可以加载预训练的权重文件,也可以从头开始训练
    # model = YOLO("runs/detect/train3/weights/best.pt")  # 继续训练之前的模型
    model = YOLO("yolo12n.pt")  # 或加载YOLO官方的模型,如 yolo12n.pt

    # 配置训练参数
    training_args = {
        'data': 'main/data.yaml',  # 使用YAML配置文件
        'epochs': 200,  # 训练轮数
        'batch': 20,  # 批次大小
        'imgsz': 640,  # 图片尺寸
        'device': 0,  # 使用GPU设备
        'workers': 8,  # 数据加载的工作进程数
        'patience': 100,  # 早停的耐心值
        'save': True,  # 保存训练结果
        'cache': False  # 是否缓存图片到内存
    }

    # 开始训练
    try:
        results = model.train(**training_args)
        print("训练完成!")
    except Exception as e:
        print(f"训练过程中出现错误: {str(e)}")


if __name__ == "__main__":
    train_yolo()

参数说明:

  • data: 数据配置文件路径。
  • epochs: 训练轮数。
  • batch: 批量大小,建议根据显存大小调整。
  • imgsz: 图片尺寸,默认为 640。
  • device: 训练设备,0 表示使用第一块 GPU。
  • workers: 数据加载进程数。
  • patience: 早停参数,若在指定轮数内验证集指标无提升,则停止训练。
  • save: 是否保存训练结果。
  • cache: 是否将图片缓存到内存以加速训练。

训练完成后,结果(包括权重文件和日志)会保存在 runs/detect/train/ 目录下。


4. 模型验证和测试

训练完成后,可以使用以下代码对模型进行验证或测试:

from ultralytics import YOLO

# 加载训练好的模型
model = YOLO("runs/detect/train/weights/best.pt")

# 验证模型
results = model.val()

# 测试图片
results = model.predict(source="path/to/test/image.jpg", save=True, save_txt=True)

# 查看结果
print(results)

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐