Testing the JDE (Towards Real-Time Multi-Object Tracking) Code: A Beginner's Guide
Methods and lessons I summed up while running the JDE multi-object tracking code.
The training walkthrough is in another article of mine:
JDE (Towards Real-Time Multi-Object Tracking) Code Training: A Beginner's Guide (my CSDN blog post)
1. Download the official dataset toolkits
UAVDT toolkit (UAV-benchmark-MOTD_v1.0):
VisDrone toolkit (VisDrone2018-MOT-toolkit):
2. Test the model and obtain txt results
First, run the track.py file in the JDE codebase; here we need to put the UAVDT and VisDrone test-set paths and sequence names into the script. The code is listed below. My main changes: I added the test-set paths and sequence names, and replaced the --test-mot16 flag with a --dataset flag so I can switch between the datasets to test. --debug-detection-results is backed by code I added in multitracker.py to visualize detection results; see my training article linked above, which explains it. I also changed the dataset-loading path and installed the openpyxl library (pip install openpyxl) so the metric results can be saved as an .xlsx file.
Dataset loading path:
Because my datasets have no seqinfo.ini file, I commented out the code that reads seqinfo.ini and set the frame_rate parameter myself.
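If you want to stay compatible with MOT-style sequences that do ship a seqinfo.ini, a small fallback works too. This is a minimal sketch (get_frame_rate is a hypothetical helper, not part of JDE) that reuses the parsing logic commented out in main() below:

import os

def get_frame_rate(seq_dir, default=30):
    # parse seqinfo.ini when it exists (MOT16/17); otherwise fall back to a default (UAVDT/VisDrone)
    ini_path = os.path.join(seq_dir, 'seqinfo.ini')
    if not os.path.isfile(ini_path):
        return default
    meta_info = open(ini_path).read()
    return int(meta_info[meta_info.find('frameRate') + 10:meta_info.find('\nseqLength')])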
Output path of the generated txt files:
The code of track.py:
import os
import os.path as osp
import cv2
import logging
import argparse
import motmetrics as mm
import torch
from tracker.multitracker import JDETracker
from utils import visualization as vis
from utils.log import logger
from utils.timer import Timer
from utils.evaluation import Evaluator
from utils.parse_config import parse_model_cfg
import utils.datasets as datasets
from utils.utils import *
def write_results(filename, results, data_type):
    if data_type == 'mot':
        save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
    elif data_type == 'kitti':
        save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
    else:
        raise ValueError(data_type)

    with open(filename, 'w') as f:
        for frame_id, tlwhs, track_ids in results:
            if data_type == 'kitti':
                frame_id -= 1
            for tlwh, track_id in zip(tlwhs, track_ids):
                if track_id < 0:
                    continue
                x1, y1, w, h = tlwh
                x2, y2 = x1 + w, y1 + h
                line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h)
                f.write(line)
    logger.info('save results to {}'.format(filename))
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
    '''
    Processes the given video sequence and writes the tracking results to a file
    (optionally also rendering them as annotated frames).
    It uses the JDE model to obtain the online targets.

    Parameters
    ----------
    opt : Namespace
        Contains the information passed as command-line arguments.
    dataloader : LoadImages
        Dataloader instance used for fetching the image sequence and associated data.
    data_type : String
        Type of dataset corresponding to the given video.
    result_filename : String
        The name (path) of the file for storing results.
    save_dir : String
        Path to the folder for storing the frames annotated with bounding boxes (result frames).
    show_image : bool
        Option for showing individual frames during run-time.
    frame_rate : int
        Frame rate of the given video.

    Returns
    -------
    (the return values are not significant here)
    frame_id : int
        Index of the last processed frame.
    '''
    if save_dir:
        mkdir_if_missing(save_dir)
    tracker = JDETracker(opt, frame_rate=frame_rate)
    timer = Timer()
    results = []
    frame_id = 0
    for path, img, img0 in dataloader:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))

        # run tracking
        timer.tic()
        blob = torch.from_numpy(img).cuda().unsqueeze(0)
        online_targets = tracker.update(blob, img0)
        online_tlwhs = []
        online_ids = []
        for t in online_targets:
            tlwh = t.tlwh
            tid = t.track_id
            # vertical = tlwh[2] / tlwh[3] > 1.6  # the original JDE tracks pedestrians, so it filters out boxes with aspect ratio > 1.6
            # see https://blog.csdn.net/sinat_33486980/article/details/106213731
            if tlwh[2] * tlwh[3] > opt.min_box_area:  # vertical-box filter disabled here: my targets are vehicles, not pedestrians
            # if tlwh[2] * tlwh[3] > opt.min_box_area and not vertical:
                online_tlwhs.append(tlwh)
                online_ids.append(tid)
        timer.toc()
        # save results
        results.append((frame_id + 1, online_tlwhs, online_ids))
        if show_image or save_dir is not None:
            online_im = vis.plot_tracking(img0, online_tlwhs, online_ids, frame_id=frame_id,
                                          fps=1. / timer.average_time)
        if show_image:
            cv2.imshow('online_im', online_im)
        if save_dir is not None:
            cv2.imwrite(os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)
        frame_id += 1
    # save results
    write_results(result_filename, results, data_type)
    return frame_id, timer.average_time, timer.calls
def main(opt, data_root='/data/MOT16/train', det_root=None, seqs=('MOT16-05',), exp_name='demo',
         save_images=False, save_videos=False, show_image=True):
    logger.setLevel(logging.INFO)
    result_root = os.path.join(data_root, 'results', exp_name)
    mkdir_if_missing(result_root)
    data_type = 'mot'

    # Read config
    cfg_dict = parse_model_cfg(opt.cfg)
    opt.img_size = [int(cfg_dict[0]['width']), int(cfg_dict[0]['height'])]

    # run tracking
    accs = []
    n_frame = 0
    timer_avgs, timer_calls = [], []
    for seq in seqs:
        output_dir = os.path.join(data_root, '..', 'outputs', exp_name, seq) if save_images or save_videos else None
        logger.info('start seq: {}'.format(seq))
        if 'VisDrone' in data_root:
            dataloader = datasets.LoadImages(osp.join(data_root, seq), opt.img_size)
        else:
            dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'), opt.img_size)
        result_filename = os.path.join(result_root, '{}.txt'.format(seq))
        # my datasets have no seqinfo.ini, so the frame rate is set manually instead of being parsed:
        # meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read()
        # frame_rate = int(meta_info[meta_info.find('frameRate')+10:meta_info.find('\nseqLength')])
        frame_rate = 30
        nf, ta, tc = eval_seq(opt, dataloader, data_type, result_filename,
                              save_dir=output_dir, show_image=show_image, frame_rate=frame_rate)
        n_frame += nf
        timer_avgs.append(ta)
        timer_calls.append(tc)

        # eval
        logger.info('Evaluate seq: {}'.format(seq))
        evaluator = Evaluator(data_root, seq, data_type)
        accs.append(evaluator.eval_file(result_filename))
        if save_videos:
            output_video_path = osp.join(output_dir, '{}.mp4'.format(seq))
            cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -c:v copy {}'.format(output_dir, output_video_path)
            os.system(cmd_str)
    timer_avgs = np.asarray(timer_avgs)
    timer_calls = np.asarray(timer_calls)
    all_time = np.dot(timer_avgs, timer_calls)
    avg_time = all_time / np.sum(timer_calls)
    logger.info('Time elapsed: {:.2f} seconds, FPS: {:.2f}'.format(all_time, 1.0 / avg_time))

    # get summary
    metrics = mm.metrics.motchallenge_metrics
    mh = mm.metrics.create()
    summary = Evaluator.get_summary(accs, seqs, metrics)
    strsummary = mm.io.render_summary(
        summary,
        formatters=mh.formatters,
        namemap=mm.io.motchallenge_metric_names
    )
    print(strsummary)
    Evaluator.save_summary(summary, os.path.join(result_root, 'summary_{}.xlsx'.format(exp_name)))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog='track.py')
    parser.add_argument('--cfg', type=str, default='cfg/yolov3_1088x608.cfg', help='cfg file path')
    parser.add_argument('--weights', type=str, default='weights/latest.pt', help='path to weights file')
    parser.add_argument('--iou-thres', type=float, default=0.5, help='iou threshold required to qualify as detected')
    parser.add_argument('--conf-thres', type=float, default=0.5, help='object confidence threshold')
    parser.add_argument('--nms-thres', type=float, default=0.4, help='iou threshold for non-maximum suppression')
    parser.add_argument('--min-box-area', type=float, default=200, help='filter out tiny boxes')
    parser.add_argument('--track-buffer', type=int, default=30, help='tracking buffer')
    # parser.add_argument('--test-mot16', action='store_true', help='tracking buffer')
    parser.add_argument('--dataset', type=str, default='UAVDT', help='tracking dataset to test')  # xyl 20221027
    parser.add_argument('--save-images', action='store_true', help='save tracking results (image)')
    parser.add_argument('--save-videos', action='store_true', help='save tracking results (video)')
    parser.add_argument('--debug-detection-results', action='store_true', help='whether to visualize detection results')  # xyl 20221019, inspect detection results; see multitracker.py line 219
    opt = parser.parse_args()
    print(opt, end='\n\n')

    if opt.dataset == 'MOT17':
        seqs_str = '''MOT17-02-SDP
                      MOT17-04-SDP
                      MOT17-05-SDP
                      MOT17-09-SDP
                      MOT17-10-SDP
                      MOT17-11-SDP
                      MOT17-13-SDP'''
        data_root = '/home/wangzd/datasets/MOT/MOT17/images/train'
    elif opt.dataset == 'MOT16':
        seqs_str = '''MOT16-01
                      MOT16-03
                      MOT16-06
                      MOT16-07
                      MOT16-08
                      MOT16-12
                      MOT16-14'''
        data_root = '/home/wangzd/datasets/MOT/MOT16/images/test'
    elif opt.dataset == 'UAVDT':
        seqs_str = '''M0203
                      M0205
                      M0208
                      M0209
                      M0403
                      M0601
                      M0602
                      M0606
                      M0701
                      M0801
                      M0802
                      M1001
                      M1004
                      M1007
                      M1009
                      M1101
                      M1301
                      M1302
                      M1303
                      M1401'''
        data_root = '/home/xyl/xyl-code/MOT/0.datasets/UAVDT_M/images/test'
    elif opt.dataset == 'VisDrone2019':
        seqs_str = '''uav0000009_03358_v
                      uav0000073_00600_v
                      uav0000073_04464_v
                      uav0000077_00720_v
                      uav0000119_02301_v
                      uav0000120_04775_v
                      uav0000161_00000_v
                      uav0000188_00000_v
                      uav0000201_00000_v
                      uav0000249_00001_v
                      uav0000249_02688_v
                      uav0000297_00000_v
                      uav0000297_02761_v
                      uav0000306_00230_v
                      uav0000355_00001_v
                      uav0000370_00001_v'''
        data_root = '/home/xyl/xyl-code/MOT/0.datasets/VisDrone2019-MOT/VisDrone2019-MOT-test-dev/images/test'

    seqs = [seq.strip() for seq in seqs_str.split()]

    main(opt,
         data_root=data_root,
         seqs=seqs,
         exp_name=opt.weights.split('/')[-2],
         show_image=False,
         save_images=opt.save_images,
         save_videos=opt.save_videos)
Then pass in your model weights path and the cfg config file, and run it. Note that the dataset is selected with the --dataset flag defined above:

python track.py --dataset UAVDT --cfg cfg/yolov3_1088x608.cfg --weights weights/latest.pt
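To test on VisDrone instead, or to also save the rendered frames and videos, the same flags from the argparse block apply (the weights path here is a placeholder; use your own):

python track.py --dataset VisDrone2019 --cfg cfg/yolov3_1088x608.cfg --weights weights/latest.pt --save-images --save-videos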
The run output looks like this:
The metric results are displayed and saved as an .xlsx file:
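For reference, saving the summary to .xlsx goes through pandas, which needs openpyxl as its .xlsx engine (hence pip install openpyxl). A minimal sketch of what Evaluator.save_summary amounts to, assuming it follows JDE's utils/evaluation.py:

import pandas as pd

def save_summary(summary, filename):
    # 'summary' is the pandas DataFrame produced by Evaluator.get_summary / motmetrics
    with pd.ExcelWriter(filename) as writer:  # pandas uses openpyxl to write .xlsx files
        summary.to_excel(writer)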
3. Evaluate the txt results with the official toolkits
1. On UAVDT
Put the folder of generated results under a detector folder inside RES_MOT (I used FRCNN, where FRCNN means the detector is Faster R-CNN), and name your own folder JDE; see the copying sketch below.
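A minimal sketch for copying the result files into that toolkit layout; both paths are placeholders standing in for my local layout, so adjust them to yours:

import os
import shutil

src = '/path/to/UAVDT_M/images/test/results/weights'        # txt files written by track.py (placeholder)
dst = '/path/to/UAV-benchmark-MOTD_v1.0/RES_MOT/FRCNN/JDE'  # detector folder + my tracker name (placeholder)

os.makedirs(dst, exist_ok=True)
for name in os.listdir(src):
    if name.endswith('.txt'):
        shutil.copy(os.path.join(src, name), dst)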
All generated result txt files:
Contents of a single result txt:
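Each line follows the MOT format produced by write_results above, i.e. frame,id,x1,y1,w,h,conf,-1,-1,-1 with conf hard-coded to 1; for example (the values here are made up for illustration):

1,1,912.0,484.0,97.0,109.0,1,-1,-1,-1
1,2,1338.0,418.0,167.0,379.0,1,-1,-1,-1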
Change the trackerName in CalculateTrakingAcc.m to match the name of your txt results folder (mine is JDE), then it can be run.
Run window:
Final test results (note they differ from JDE's own evaluation; the metrics here are a bit higher):
Files produced after the run:
The clean folder is just the tracking results saved again as a backup; the relevant backup code is in utils/preprocessResult.m:
2. Testing on VisDrone:
The main evaluation file is evalMOT.m (the toolkit is MATLAB). There are two cases:
2.1 Evaluate with the toolkit's built-in tracker (no tracking-result txt files yet)
You need to modify the datasetPath, detPath, and resPath paths. detPath holds the detector results shipped with the dataset (VisDrone2019 does not include them; VisDrone2018 does). Task4a and Task4b stand for different evaluation modes: Task4a is without detection input and Task4b is with detection input.
The places to change are:
The built-in tracker is GOG:
The runTrackerAllClass function generates the tracking results, written to the resPath defined at the start.
isSeqDisplay = false; % flag to display the detections. This switch toggles visualization of the detection results.
The process of generating the tracking results:
Per-sequence evaluation begins:
Overall evaluation:
2.2 Evaluate your own tracker (tracking-result txt files already exist)
As before, modify the datasetPath, detPath, and resPath paths, and also change seqPath, gtPath, etc. according to your own dataset layout. In addition, my model only tracks vehicles, so evalClassSet contains only the class car. (To evaluate more classes, list them here; the classSplit_single_class function below must then split out the corresponding classes, and the test results will report metrics for each of them.)
The places to change are:
PS:
Because VisDrone is multi-class multi-object tracking, the tracking results are split by class. First the ground truth is split by class with classSplit: gtsortdata = classSplit(gtdata);
The classSplit function:
In the ground-truth annotations, the third-from-last field, trackData(:,8), takes values 0-11:
My model only used categories 4, 5, 6, and 9, and the tracking results predict them all as one class, vehicle. So in my result files the third-from-last field, i.e. trackData(:,8) in the classSplit code, is always -1.
So I defined an extra function, classSplit_single_class, to split my own result txt files; a sketch of the idea follows.
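The toolkit itself is MATLAB; the following is only a hedged Python/numpy rendering of the splitting logic as I understand it (the function names mirror the MATLAB ones, and column 8 is the object_category field):

import numpy as np

def class_split(track_data):
    # group ground-truth rows by object_category (1-based column 8 -> 0-based index 7; values 0-11)
    return {int(c): track_data[track_data[:, 7] == c] for c in np.unique(track_data[:, 7])}

def class_split_single_class(track_data, class_id=4):
    # my result files carry -1 in that field; relabel every row as the single class (car = 4 in VisDrone)
    data = track_data.copy()
    data[:, 7] = class_id
    return {class_id: data}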
Where the split function is called:
The run process:
Final results (I only have one class, car, so the single-class and overall metrics are identical):
Afterword
I have only just started working on multi-object tracking, and I will keep updating the issues I record while running this code. If you have any questions, I hope we can discuss and learn from each other. O(∩_∩)O