1 前言

  本博客记录的内容是通过Yolov5工程训练自制数据集并进行推理。

2 项目内容详细说明

2.1 数据集制作(原始视频图像帧的提取)

  如下图为数据集制作文件夹内容,其中mp4视频即需要制作数据集的画面。
在这里插入图片描述
  首先通过getframe.pyvideo1217.mp4按帧提取出图像到当前文件夹内的image文件夹内。

  然后在当前文件夹内通过命令行打开Labelme。

在这里插入图片描述
  然后在labelme中Open Dir定位到上面保存的Image文件夹。
  由于按帧提取视频可能有大量相似图像的存在,没必要一张张地标注,可以跳着标(也可以修改提取代码,间隔一定时间再提取帧),将标注过的json文件直接放置在image文件夹(这样做的好处是方便想要补充数据的时候直接标)。
  运行copySameName.py将已经经过标注的图片和json文件拷贝到datasets文件夹中去(通过判断image文件夹中哪些同名文件既有jpg格式也有json格式的)。
  文件夹中有一个J8MoniCangSwitch名称的文件夹,里面放置的是一个datasets文件夹cabinswitch_parameter.yaml以及cabinswitch_yolov5s.yaml文件。
  datasets文件夹内有image以及labels文件夹,每个文件夹内又分别有train和test两个文件夹。可以在labels文件夹内再放一个train_json文件夹用来先暂时存放标注的json文件(需要转化成txt格式)。
  datasets通过jsonTotxt.py脚本将label里面的train_json标注全部转换成txt类型标注放入train文件夹。
  下面是cabinswitch_parameter.yaml文件的内容,需要修改nc以及names字段的内容,按照自己的需求填写。
在这里插入图片描述
  下面是cabinswitch_yolov5s.yaml文件的内容,需要修改nc字段的内容。其他内容默认了,不要轻易修改。
在这里插入图片描述

2.2 训练

  将上一步做好的数据集放置在从github下载下来的yolov5官方工程文件夹yolov5-master文件夹内,修改train.py 文件下面红色箭头指向的项目为你需要的项目。
在这里插入图片描述
  在你的环境下,直接python train.py。看到下面的内容则表明正常开始训练了。
在这里插入图片描述

2.3 推理

  在你的环境下,运行python detect.py --weights yolov5s.pt --source 0这样就是加载当前目录下的yolov5s.pt权重文件以及将摄像头0作为输入数据源。
  下图是检测到盒子为Close的状态。
在这里插入图片描述
  下图是检测到盒子为Open的状态。
在这里插入图片描述
  下图是图像中没有检测出目标物体的情况。
在这里插入图片描述

3 代码

3.1 getframe.py

  getframe.py实现如下所示。

import cv2
import os
import sys

def extract_frames_to_image_folder(video_path, frame_interval=1, image_format='jpg', quality=95):
    """
    从MP4视频中提取帧并保存到当前目录的image文件夹
    
    参数:
    video_path: 视频文件路径
    frame_interval: 帧间隔,默认为1(提取每一帧)
    image_format: 图片格式,支持'jpg', 'png'
    quality: 图片质量(仅对jpg有效),1-100
    """
    
    # 检查视频文件是否存在
    if not os.path.exists(video_path):
        print(f"错误:视频文件不存在 - {video_path}")
        return False
    
    # 创建image文件夹
    current_dir = os.getcwd()
    image_folder = os.path.join(current_dir, "image")
    
    # 如果image文件夹不存在则创建
    if not os.path.exists(image_folder):
        os.makedirs(image_folder)
        print(f"创建文件夹: {image_folder}")
    
    print(f"图片将保存到: {image_folder}")
    
    # 打开视频文件
    cap = cv2.VideoCapture(video_path)
    
    if not cap.isOpened():
        print("错误:无法打开视频文件")
        return False
    
    # 获取视频信息
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    
    print(f"\n视频信息:")
    print(f"  文件名: {os.path.basename(video_path)}")
    print(f"  分辨率: {width}x{height}")
    print(f"  总帧数: {total_frames}")
    print(f"  帧率: {fps:.2f}")
    print(f"  图片格式: {image_format}")
    print(f"  帧间隔: {frame_interval}")
    print(f"  文件名格式: 0000.{image_format}, 0001.{image_format}, ...")
    
    frame_count = 0
    saved_count = 0
    
    print(f"\n开始提取帧...")
    
    try:
        while True:
            ret, frame = cap.read()
            
            if not ret:
                break
            
            # 按间隔保存帧
            if frame_count % frame_interval == 0:
                # 生成文件名,使用4位数格式,如 0000.jpg
                if image_format.lower() in ['jpg', 'jpeg']:
                    filename = f"{saved_count:04d}.jpg"
                    filepath = os.path.join(image_folder, filename)
                    cv2.imwrite(filepath, frame, [cv2.IMWRITE_JPEG_QUALITY, quality])
                elif image_format.lower() == 'png':
                    filename = f"{saved_count:04d}.png"
                    filepath = os.path.join(image_folder, filename)
                    cv2.imwrite(filepath, frame)
                else:
                    print(f"错误:不支持的图片格式 {image_format}")
                    return False
                
                saved_count += 1
                
                # 每处理100帧显示一次进度
                if saved_count % 100 == 0:
                    print(f"已保存 {saved_count} 张图片...")
            
            frame_count += 1
    
    except KeyboardInterrupt:
        print("\n\n用户中断,停止提取...")
    except Exception as e:
        print(f"\n\n发生错误: {str(e)}")
        return False
    finally:
        # 释放资源
        cap.release()
        cv2.destroyAllWindows()
    
    print(f"\n完成!")
    print(f"处理总帧数: {frame_count}")
    print(f"保存图片数: {saved_count}")
    print(f"图片保存在: {image_folder}")
    
    return True

def main():
    if len(sys.argv) < 2:
        print("使用方法: python extract_frames.py <视频文件> [帧间隔] [图片格式] [质量]")
        print("示例:")
        print("  python extract_frames.py video.mp4")
        print("  python extract_frames.py video.mp4 1 jpg 95")
        print("  python extract_frames.py video.mp4 5 png")
        print("\n参数说明:")
        print("  视频文件: 要处理的MP4视频文件路径")
        print("  帧间隔: 可选,默认1(提取每一帧)")
        print("  图片格式: 可选,默认jpg")
        print("  质量: 可选,仅对jpg有效,默认95")
        print("\n输出:")
        print("  图片保存在当前目录的image文件夹中")
        print("  文件名格式: 0000.jpg, 0001.jpg, 0002.jpg, ...")
        return
    
    # 获取参数
    video_path = sys.argv[1]
    
    # 设置默认值
    frame_interval = 1
    image_format = 'jpg'
    quality = 95
    
    # 解析可选参数
    if len(sys.argv) >= 3:
        try:
            frame_interval = int(sys.argv[2])
        except:
            print(f"警告:无效的帧间隔参数,使用默认值1")
    
    if len(sys.argv) >= 4:
        if sys.argv[3].lower() in ['jpg', 'jpeg', 'png']:
            image_format = sys.argv[3].lower()
        else:
            print(f"警告:无效的图片格式,使用默认值jpg")
    
    if len(sys.argv) >= 5 and image_format in ['jpg', 'jpeg']:
        try:
            quality = int(sys.argv[4])
            if quality < 1 or quality > 100:
                quality = 95
                print(f"警告:质量参数应在1-100之间,使用默认值95")
        except:
            print(f"警告:无效的质量参数,使用默认值95")
    
    # 执行提取
    extract_frames_to_image_folder(
        video_path=video_path,
        frame_interval=frame_interval,
        image_format=image_format,
        quality=quality
    )

if __name__ == "__main__":
    print("=" * 60)
    print("MP4视频帧提取工具")
    print("=" * 60)
    
    # 检查OpenCV是否安装
    try:
        cv2_version = cv2.__version__
        print(f"OpenCV版本: {cv2_version}")
    except:
        print("错误:未安装OpenCV!")
        print("请先安装: pip install opencv-python")
        sys.exit(1)
    
    main()

3.2 copySameName.py

  copySameName.py实现如下所示。

import os
import shutil
import argparse

# 在代码中直接指定路径
source_folder = "/media/ai/5c45cbac-396a-4328-b602-e47bc899eb89/ai/XR_SERVER/1222/video12221459/image"  # 替换为你的源文件夹路径
target_folder = "/media/ai/5c45cbac-396a-4328-b602-e47bc899eb89/ai/XR_SERVER/1222/video12221459/datasets"  # 替换为你的目标文件夹路径



def copy_matching_files(source_dir, target_dir):
    """
    将源文件夹中所有同名的jpg和json文件复制到目标文件夹
    
    Args:
        source_dir (str): 源文件夹路径
        target_dir (str): 目标文件夹路径
    """
    # 创建目标文件夹(如果不存在)
    os.makedirs(target_dir, exist_ok=True)
    
    # 获取源文件夹中所有文件
    all_files = os.listdir(source_dir)
    
    # 分离jpg和json文件
    jpg_files = [f for f in all_files if f.lower().endswith('.jpg')]
    json_files = [f for f in all_files if f.lower().endswith('.json')]
    
    # 获取不带扩展名的文件名
    jpg_names = {os.path.splitext(f)[0] for f in jpg_files}
    json_names = {os.path.splitext(f)[0] for f in json_files}
    
    # 找到同名的文件(既有jpg又有json)
    matching_names = jpg_names & json_names
    
    print(f"找到 {len(matching_names)} 对匹配的jpg和json文件")
    
    # 复制文件
    copied_count = 0
    for name in matching_names:
        jpg_file = name + '.jpg'
        json_file = name + '.json'
        
        # 构建完整路径
        jpg_src = os.path.join(source_dir, jpg_file)
        json_src = os.path.join(source_dir, json_file)
        jpg_dst = os.path.join(target_dir, jpg_file)
        json_dst = os.path.join(target_dir, json_file)
        
        # 复制文件
        try:
            shutil.copy2(jpg_src, jpg_dst)
            shutil.copy2(json_src, json_dst)
            copied_count += 2
            print(f"已复制: {jpg_file} 和 {json_file}")
        except Exception as e:
            print(f"复制失败 {jpg_file}/{json_file}: {e}")
    
    print(f"复制完成!共复制了 {copied_count} 个文件 ({len(matching_names)} 对)")

def main():

    
    copy_matching_files(source_folder, target_folder)

if __name__ == "__main__":
    main()

3.3 jasonTotxt.py

  jasonTotxt.py如下所示。

import json
import os

name2id = {'AUTOPILOT_PATH': 0, 
	   'AUTOPILOT_ALT/HDG': 1,
	   'AUTOPILOT_ALT': 2, 
	   'EAC_ARM': 3,
	   'EAC_OFF': 4, 
	   'FLAPS_UP': 5,
	   'FLAPS_MVR': 6, 
	   'FLAPS_DN': 7,
	   'RDRALTM_NRM': 8, 
	   'RDRALTM_DIS': 9	   
	   }  # 标签名称


def convert(img_size, box):
    dw = 1. / (img_size[0])
    dh = 1. / (img_size[1])
    x = (box[0] + box[2]) / 2.0 - 1
    y = (box[1] + box[3]) / 2.0 - 1
    w = abs(box[2] - box[0])
    h = abs(box[3] - box[1])
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return (x, y, w, h)


def decode_json(json_floder_path, json_name):
    txt_name = '/media/ai/5c45cbac-396a-4328-b602-e47bc899eb89/ai/XR_SERVER/1222/video12221459/J8MoniCangSwitch/datasets/labels/train/' + json_name[0:-5] + '.txt'
    # 存放txt的绝对路径

    print(txt_name)
    txt_file = open(txt_name, 'w')

    json_path = os.path.join(json_floder_path, json_name)
    data = json.load(open(json_path, 'r', encoding='gb2312', errors='ignore'))

    img_w = data['imageWidth']
    img_h = data['imageHeight']

    for i in data['shapes']:

        label_name = i['label']
        if (i['shape_type'] == 'rectangle'):
            x1 = int(i['points'][0][0])
            y1 = int(i['points'][0][1])
            x2 = int(i['points'][1][0])
            y2 = int(i['points'][1][1])

            bb = (x1, y1, x2, y2)
            bbox = convert((img_w, img_h), bb)
            txt_file.write(str(name2id[label_name]) + " " + " ".join([str(a) for a in bbox]) + '\n')


if __name__ == "__main__":

    json_floder_path = '/media/ai/5c45cbac-396a-4328-b602-e47bc899eb89/ai/XR_SERVER/1222/video12221459/J8MoniCangSwitch/datasets/labels/train_json'
    # 存放json的文件夹的绝对路径
    json_names = os.listdir(json_floder_path)
    for json_name in json_names:
        print(json_name)
        decode_json(json_floder_path, json_name)
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐