Skip to content

Commit a8f2af0

Browse files
fengxiuyaun and yangmasheng authored
add CO-MOT for multi object tracking (#266)
* add CO-MOT for multi-object tracking * add CO-MOT for multi-object tracking * Simplify the code of CO_MOT * merge data+dataloader to co-mot --------- Co-authored-by: yangmasheng <[email protected]>
1 parent 940625d commit a8f2af0

28 files changed

+5399
-3
lines changed

README.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ Results and models are available in [model zoo](https://detrex.readthedocs.io/en
116116
- [x] [DINO (ICLR'2023)](./projects/dino/)
117117
- [x] [H-Deformable-DETR (CVPR'2023)](./projects/h_deformable_detr/)
118118
- [x] [MaskDINO (CVPR'2023)](./projects/maskdino/)
119-
119+
- [x] [CO-MOT (ArXiv'2023)](./projects/co_mot/)
120120

121121
Please see [projects](./projects/) for the details about projects that are built based on detrex.
122122

@@ -222,8 +222,15 @@ relevant publications:
222222
archivePrefix={arXiv},
223223
primaryClass={cs.CV}
224224
}
225+
@article{yan2023bridging,
226+
title={Bridging the Gap Between End-to-end and Non-End-to-end Multi-Object Tracking},
227+
author={Yan, Feng and Luo, Weixin and Zhong, Yujie and Gan, Yiyang and Ma, Lin},
228+
journal={arXiv preprint arXiv:2305.12724},
229+
year={2023}
230+
}
225231
```
226232

233+
227234
</details>
228235

229236

demo/mot_demo.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
import argparse
3+
import glob
4+
import multiprocessing as mp
5+
import numpy as np
6+
import os
7+
import sys
8+
import tempfile
9+
import time
10+
import warnings
11+
import cv2
12+
import tqdm
13+
14+
sys.path.insert(0, "./") # noqa
15+
from demo.mot_predictors import VisualizationDemo
16+
from detectron2.checkpoint import DetectionCheckpointer
17+
from detectron2.config import LazyConfig, instantiate
18+
from detectron2.data.detection_utils import read_image
19+
from detectron2.utils.logger import setup_logger
20+
21+
22+
# constants
23+
WINDOW_NAME = "MOT"
24+
25+
26+
def setup(args):
    """Load the LazyConfig named by ``args.config_file`` and apply any ``--opts`` overrides."""
    config = LazyConfig.load(args.config_file)
    config = LazyConfig.apply_overrides(config, args.opts)
    return config
30+
31+
32+
def get_parser():
    """Build the command-line argument parser for the MOT demo.

    Returns:
        argparse.ArgumentParser: parser exposing the input source (``--input``
        images/glob, ``--webcam``, or ``--video-input``), an optional
        ``--output`` destination, and inference options (test-time resize
        bounds, image format, metadata dataset, confidence threshold, and
        free-form config overrides via ``--opts``).
    """
    parser = argparse.ArgumentParser(description="detrex demo for visualizing customized inputs")
    parser.add_argument(
        "--config-file",
        default="projects/dino/configs/dino_r50_4scale_12ep.py",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.")
    parser.add_argument("--video-input", help="Path to video file.")
    parser.add_argument(
        "--input",
        nargs="+",
        help="A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument(
        "--output",
        help="A file or directory to save output visualizations. "
        "If not given, will show output in an OpenCV window.",
    )
    parser.add_argument(
        "--min_size_test",
        type=int,
        default=800,
        help="Size of the smallest side of the image during testing. Set to zero to disable resize in testing.",
    )
    parser.add_argument(
        "--max_size_test",
        type=float,
        default=1333,
        help="Maximum size of the side of the image during testing.",
    )
    parser.add_argument(
        "--img_format",
        type=str,
        default="RGB",
        help="The format of the loading images.",
    )
    parser.add_argument(
        "--metadata_dataset",
        type=str,
        default="coco_2017_val",
        # BUGFIX: help text previously read "infomation".
        help="The metadata information to be used. Default to COCO val metadata.",
    )
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=0.5,
        help="Minimum score for instance predictions to be shown",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser
90+
91+
92+
def test_opencv_video_format(codec, file_ext):
    """Probe whether this OpenCV build can write video with ``codec``/``file_ext``.

    Writes 30 blank frames to a throwaway file in a temp directory and reports
    whether the file was actually created (VideoWriter fails silently when a
    codec is unsupported).
    """
    with tempfile.TemporaryDirectory(prefix="video_format_test") as tmp_dir:
        probe_path = os.path.join(tmp_dir, "test_file" + file_ext)
        writer = cv2.VideoWriter(
            filename=probe_path,
            fourcc=cv2.VideoWriter_fourcc(*codec),
            fps=30.0,
            frameSize=(10, 10),
            isColor=True,
        )
        blank_frame = np.zeros((10, 10, 3), np.uint8)
        for _ in range(30):
            writer.write(blank_frame)
        writer.release()
        return os.path.isfile(probe_path)
107+
108+
109+
if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    args = get_parser().parse_args()
    setup_logger(name="fvcore")
    logger = setup_logger()
    logger.info("Arguments: " + str(args))

    cfg = setup(args)

    # Instantiate the model from the lazy config and load checkpoint weights.
    model = instantiate(cfg.model)
    model.to(cfg.train.device)
    checkpointer = DetectionCheckpointer(model)
    checkpointer.load(cfg.train.init_checkpoint)

    model.eval()

    demo = VisualizationDemo(
        model=model,
        min_size_test=args.min_size_test,
        max_size_test=args.max_size_test,
        img_format=args.img_format,
        metadata_dataset=args.metadata_dataset,
    )

    if args.input:
        # Image input: either an explicit list of paths or a single glob pattern.
        if len(args.input) == 1:
            args.input = glob.glob(os.path.expanduser(args.input[0]))
            assert args.input, "The input path(s) was not found"
        # Sort so that frames of a sequence are processed in order (matters for tracking).
        args.input = sorted(args.input)
        for path in tqdm.tqdm(args.input, disable=not args.output):
            # use PIL, to be consistent with evaluation
            img = read_image(path, format="BGR")
            start_time = time.time()
            predictions, visualized_output = demo.run_on_image(img, args.confidence_threshold)
            logger.info(
                "{}: {} in {:.2f}s".format(
                    path,
                    "detected {} instances".format(len(predictions["instances"]))
                    if "instances" in predictions
                    else "finished",
                    time.time() - start_time,
                )
            )

            if args.output:
                if os.path.isdir(args.output):
                    # Directory output: one visualization per input image.
                    out_filename = os.path.join(args.output, os.path.basename(path))
                else:
                    assert len(args.input) == 1, "Please specify a directory with args.output"
                    out_filename = args.output
                visualized_output.save(out_filename)
            else:
                cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
                cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])
                if cv2.waitKey(0) == 27:
                    break  # esc to quit
    elif args.webcam:
        assert args.input is None, "Cannot have both --input and --webcam!"
        assert args.output is None, "output not yet supported with --webcam!"
        cam = cv2.VideoCapture(0)
        for vis in tqdm.tqdm(demo.run_on_video(cam)):
            cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
            cv2.imshow(WINDOW_NAME, vis)
            if cv2.waitKey(1) == 27:
                break  # esc to quit
        cam.release()
        cv2.destroyAllWindows()
    elif args.video_input:
        video = cv2.VideoCapture(args.video_input)
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames_per_second = video.get(cv2.CAP_PROP_FPS)
        num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        basename = os.path.basename(args.video_input)
        codec, file_ext = (
            ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4")
        )
        # BUGFIX: the fallback codec string is "mp4v" (no leading dot); the old
        # comparison against ".mp4v" could never be true, so the warning never fired.
        if codec == "mp4v":
            warnings.warn("x264 codec not available, switching to mp4v")
        if args.output:
            if os.path.isdir(args.output):
                output_fname = os.path.join(args.output, basename)
                output_fname = os.path.splitext(output_fname)[0] + file_ext
            else:
                output_fname = args.output

            # assert not os.path.isfile(output_fname), output_fname
            output_file = cv2.VideoWriter(
                filename=output_fname,
                # some installation of opencv may not support x264 (due to its license),
                # you can try other format (e.g. MPEG)
                fourcc=cv2.VideoWriter_fourcc(*codec),
                fps=float(frames_per_second),
                frameSize=(width, height),
                isColor=True,
            )
        assert os.path.isfile(args.video_input)
        for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames):
            if args.output:
                output_file.write(vis_frame)
            else:
                cv2.namedWindow(basename, cv2.WINDOW_NORMAL)
                cv2.imshow(basename, vis_frame)
                if cv2.waitKey(1) == 27:
                    break  # esc to quit
        video.release()
        if args.output:
            output_file.release()
        else:
            cv2.destroyAllWindows()

0 commit comments

Comments (0)