最近在绘制目标检测实验的结果时,想将多个模型的P-R曲线绘制在一张图上,但发现YOLO框架只会为单次训练生成results.csv等结果文件,并不直接支持多模型对比。查阅了很多博客后最终解决了这个问题,现将代码和方法分享如下(借鉴了大佬Frommoon的博客):

1.修改P-R曲线绘制源码——目的是保存绘图数据

以YOLO11为例,需要修改的是/ultralytics/utils/metrics.py中的plot_pr_curve函数,其他版本的YOLO框架也大致相同

def _save_xy_csv(path, xs, ys):
    """Write paired ``xs``/``ys`` values to ``path`` as headerless ``x,y`` rows.

    The parent directory is created first, so writing into a not-yet-existing
    folder (e.g. ``pr_data/``) does not fail with ``FileNotFoundError``.
    An existing file is overwritten.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(f"{x},{y}\n" for x, y in zip(xs, ys))


def plot_pr_curve(
    px,
    py,
    ap,
    save_dir=Path("pr_curve.png"),
    names=None,
    on_plot=None,
    pr_csv_path="pr_data/11n.csv",
    per_class_csv_dir=None,
):
    """Plot a precision-recall curve and dump the plotted points to CSV.

    Args:
        px: Recall values shared by every curve (x axis).
        py: Sequence of per-class precision arrays; they are stacked column-wise,
            so column ``i`` is class ``i``.
        ap: AP array; column 0 is what the legend reports (labelled mAP@0.5).
        save_dir: Path of the output image.
        names: Mapping class index -> class name. Per-class curves and legend
            entries are drawn only when ``0 < len(names) < 21``.
        on_plot: Optional callback called with ``save_dir`` after saving.
        pr_csv_path: File that receives the mean (all-classes) curve as
            ``recall,precision`` rows. Use a different path per model so runs do
            not overwrite each other. Its folder is created if missing.
        per_class_csv_dir: If given, each per-class curve is also written to
            ``<dir>/<name>_pose.csv`` (when "Pose" appears in ``save_dir``) or
            ``<dir>/<name>_Box.csv``. Only happens when per-class curves are drawn.
    """
    names = {} if names is None else names
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    py = np.stack(py, axis=1)

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        # Keep pose and box per-class CSVs apart, keyed on the save_dir name.
        task_suffix = "pose" if "Pose" in str(save_dir) else "Box"
        for i, y in enumerate(py.T):
            ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}")  # plot(recall, precision)
            if per_class_csv_dir is not None:
                _save_xy_csv(Path(per_class_csv_dir) / f"{names[i]}_{task_suffix}.csv", px, y)
    else:
        ax.plot(px, py, linewidth=1, color="grey")  # plot(recall, precision)

    mean_py = py.mean(1)  # precision averaged over classes at each recall point
    ax.plot(px, mean_py, linewidth=3, color="blue", label=f"all classes {ap[:, 0].mean():.3f} mAP@0.5")
    # Persist the plotted mean curve so several models can be overlaid later.
    _save_xy_csv(pr_csv_path, px, mean_py)

    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    ax.set_title("Precision-Recall Curve")
    fig.savefig(save_dir, dpi=250)
    plt.close(fig)
    if on_plot:
        on_plot(save_dir)

2.运行val.py,得到用于绘制P-R曲线的CSV数据,即P-R曲线上各点的横纵坐标(每行依次为Recall、Precision)

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    # Trained weights to evaluate; change this for each model being compared.
    weights = '/yolo11/yolo11-2/runs/train/kitti-yolo11/weights/best.pt'
    # Dataset YAML that defines the train/val/test splits.
    data_yaml = '/Object_detection/LS/yolo11/yolo11-1/ultralytics/cfg/datasets/kitti.yaml'

    val_kwargs = dict(
        data=data_yaml,
        split='val',  # train / val / test, depending on what your dataset provides
        imgsz=640,
        batch=16,
        project='runs/val',
        name='exp',
    )
    YOLO(weights).val(**val_kwargs)

3.运行结果如下:

对于其他模型,只需修改第一步中保存的CSV文件名和第二步中的权重文件路径,再重新运行验证,即可得到多个模型的P-R曲线数据

4.最后是绘图脚本(将多个模型的P-R曲线绘制在同一张图上)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

if __name__ == '__main__':
    # One CSV per model, as written by the patched plot_pr_curve:
    # headerless "recall,precision" rows.
    file_list = ['pr_data/v5.csv', 'pr_data/v6.csv', 'pr_data/v7.csv', 'pr_data/v8.csv', 'pr_data/v10.csv', 'pr_data/11n.csv']
    names = ['v5', 'v6', 'v7', 'v8', 'v10', '11']  # legend label for each file, same order
    if len(file_list) != len(names):
        raise ValueError(f"file_list has {len(file_list)} entries but names has {len(names)}")

    plt.figure(figsize=(6, 6))
    for csv_path, label in zip(file_list, names):
        pr_data = pd.read_csv(csv_path, header=None)
        recall, precision = np.asarray(pr_data[0]), np.asarray(pr_data[1])
        plt.plot(recall, precision, label=label)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    # Recall and precision both lie in [0, 1]; fix the axes so plots are comparable.
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.title('Precision-Recall Curve')
    plt.legend()
    plt.tight_layout()
    plt.savefig('pr.png')
    plt.close()

得到的结果如下图

5.我们还可以将mAP50和mAP50-95随训练轮次(epoch)的变化曲线绘制在同一张画布上,数据来自各模型训练目录下的results.csv,具体代码如下(如需把P-R曲线也画在一起,启用代码中对应的P-R部分即可;我的工作只需要两个AP指标):

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def plot_metrics(ax, metric_col_name, y_label, color, modelname, is_pr=False, res_path=None):
    """Draw one model's curve from an Ultralytics-style ``results.csv`` onto ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        metric_col_name: Column plotted against ``epoch`` (e.g. ``'metrics/mAP50(B)'``).
        y_label: Unused; kept so existing call sites keep working.
        color: Line color.
        modelname: Legend label. It is also the lookup key into the module-level
            ``pr_csv_dict`` when ``res_path`` is not given.
        is_pr: If True, plot column ``metrics/precision(B)`` against
            ``metrics/recall(B)`` instead of ``metric_col_name`` against epoch.
            NOTE: these are per-epoch summary values, not a true P-R curve.
        res_path: CSV file to read. Defaults to ``pr_csv_dict[modelname]``, the
            global populated by ``plot_all_metrics``.

    A missing or unreadable file, or a missing column, is reported and that model
    is skipped, so one bad run does not abort the whole figure.
    """
    if res_path is None:
        # Backward-compatible lookup in the dict published by plot_all_metrics().
        res_path = pr_csv_dict[modelname]
    try:
        data = pd.read_csv(res_path)
        data.columns = data.columns.str.strip()  # Remove spaces from column names

        if is_pr:
            x_values = data['metrics/recall(B)'].values
            y_values = data['metrics/precision(B)'].values
        else:
            x_values = data['epoch'].values
            y_values = data[metric_col_name].values
            print(f"Model: {modelname}, Metric: {metric_col_name}, Data: {y_values[:5]}")  # Debug information
        ax.plot(x_values, y_values, label=modelname, color=color, linewidth=2)
    except (OSError, KeyError, ValueError) as e:
        # OSError: missing/unreadable file; KeyError: missing column;
        # ValueError: pandas parse errors (EmptyDataError/ParserError subclass it).
        print(f"Error reading {modelname}: {e}")


def _style_axis(ax, xlabel, ylabel, title, legend_loc, xlim=(0, None)):
    """Apply the shared look: labels, 0-1 y range, legend, thick spines and ticks."""
    ax.set_xlabel(xlabel, fontsize=16)
    ax.set_ylabel(ylabel, fontsize=16)
    ax.set_xlim(*xlim)
    ax.set_ylim(0, 1)
    ax.legend(loc=legend_loc, fontsize=12)
    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_linewidth(2)
    ax.tick_params(width=2, labelsize=14)
    ax.set_title(title, fontsize=18)


# Main function
def plot_all_metrics(include_pr=False):
    """Compare several training runs on mAP@0.5 and mAP@0.5:0.95 versus epoch.

    Each model's ``results.csv`` listed in ``pr_csv_dict`` is read. The figure is
    saved to ``diff_yolo_metrics6.png`` and then shown.

    Args:
        include_pr: If True, add a leading Precision-Recall panel. It is built
            from the per-model ``pr_data/*.csv`` files (headerless
            ``recall,precision`` rows), and each model gets its own axis.
    """
    # Kept global so plot_metrics(...) called without res_path still resolves paths.
    global pr_csv_dict
    pr_csv_dict = {
        'YOLOv5s-OBB': r'runs/train/train9_0.926/results.csv',
        'YOLOv6s-OBB': r'runs/train/exp25_0.929/results.csv',
        'YOLOv8s-OBB': r'runs/train/exp22_0.931/results.csv',
        'YOLOv10s-OBB': r'runs/train/exp21_0.932/results.csv',
        'YOLOv11s-OBB': r'runs/train/exp19_0.941/results.csv',
        'Rotated FasterRCNN-OBB': r'runs/train/exp8_0.875/results.csv',
        'Oriented R-CNN-OBB': r'runs/train/exp8_0.881/results.csv',
        'Ours': r'runs/train/train_0.945/results.csv',
    }

    colors = {
        'YOLOv5s-OBB': '#00EE76',
        'YOLOv6s-OBB': '#EEEE00',
        'YOLOv8s-OBB': '#8470FF',
        'YOLOv10s-OBB': 'orange',
        'YOLOv11s-OBB': '#838B8B',
        'Rotated FasterRCNN-OBB': '#00BFFF',
        'Oriented R-CNN-OBB': 'pink',
        'Ours': 'red',
    }

    # P-R curve data (only used when include_pr=True).
    pr_files = {
        'YOLOv5s-OBB': 'pr_data/5s.csv',
        'YOLOv6s-OBB': 'pr_data/6s.csv',
        'YOLOv8s-OBB': 'pr_data/8s.csv',
        'YOLOv10s-OBB': 'pr_data/10s.csv',
        'YOLOv11s-OBB': 'pr_data/11s.csv',
        'Rotated FasterRCNN-OBB': 'pr_data/Rotated FasterRCNN-OBB.csv',
        'Oriented R-CNN-OBB': 'pr_data/Oriented R-CNN.csv',
        'Ours': 'pr_data/ours.csv',
    }

    # Set the global font size before creating the figure so new text picks it up.
    plt.rcParams.update({'font.size': 16})

    # 1 row: [P-R,] mAP@0.5, mAP@0.5:0.95. tight_layout manages inter-panel spacing.
    n_panels = 3 if include_pr else 2
    fig, axs = plt.subplots(1, n_panels, figsize=(24, 8), tight_layout=True)
    panels = iter(axs)

    if include_pr:
        ax = next(panels)
        for modelname, csv_path in pr_files.items():
            pr_data = pd.read_csv(csv_path, header=None)
            recall, precision = np.asarray(pr_data[0]), np.asarray(pr_data[1])
            ax.plot(recall, precision, label=modelname, color=colors[modelname], linewidth=2)
        _style_axis(ax, 'Recall', 'Precision', 'Precision-Recall Curve', 'lower left', xlim=(0, 1))

    for metric_col, metric_label in (('metrics/mAP50(B)', 'mAP@0.5'),
                                     ('metrics/mAP50-95(B)', 'mAP@0.5:0.95')):
        ax = next(panels)
        for modelname, res_path in pr_csv_dict.items():
            plot_metrics(ax, metric_col, metric_label, colors[modelname], modelname, res_path=res_path)
        _style_axis(ax, 'Epoch', metric_label, metric_label, 'lower right')

    # Save the figure
    fig.savefig('diff_yolo_metrics6.png', dpi=300)  # output location
    plt.show()

# Execute plotting
if __name__ == '__main__':
    plot_all_metrics()

得到的效果如下:

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐