|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. |
| 2 | +import argparse |
| 3 | +import glob |
| 4 | +import multiprocessing as mp |
| 5 | +import numpy as np |
| 6 | +import os |
| 7 | +import sys |
| 8 | +import tempfile |
| 9 | +import time |
| 10 | +import warnings |
| 11 | +import cv2 |
| 12 | +import tqdm |
| 13 | + |
| 14 | +sys.path.insert(0, "./") # noqa |
| 15 | +from demo.mot_predictors import VisualizationDemo |
| 16 | +from detectron2.checkpoint import DetectionCheckpointer |
| 17 | +from detectron2.config import LazyConfig, instantiate |
| 18 | +from detectron2.data.detection_utils import read_image |
| 19 | +from detectron2.utils.logger import setup_logger |
| 20 | + |
| 21 | + |
# constants
# Title of the OpenCV preview window used by the image and webcam branches
# of the script below when no --output is given.
WINDOW_NAME = "MOT"
| 24 | + |
| 25 | + |
def setup(args):
    """Build the lazy config for this run.

    Loads the config file named by ``args.config_file`` and then applies the
    command-line overrides collected in ``args.opts`` on top of it.

    Args:
        args: parsed command-line namespace providing ``config_file`` and
            ``opts``.

    Returns:
        The config object returned by ``LazyConfig.apply_overrides``.
    """
    loaded = LazyConfig.load(args.config_file)
    return LazyConfig.apply_overrides(loaded, args.opts)
| 30 | + |
| 31 | + |
def get_parser():
    """Create the command-line parser for the visualization demo.

    Exactly one input source is expected at runtime: ``--input`` (images),
    ``--webcam`` or ``--video-input``. Everything after ``--opts`` is
    collected verbatim and later applied as config overrides.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description="detrex demo for visualizing customized inputs")
    parser.add_argument(
        "--config-file",
        default="projects/dino/configs/dino_r50_4scale_12ep.py",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.")
    parser.add_argument("--video-input", help="Path to video file.")
    parser.add_argument(
        "--input",
        nargs="+",
        help="A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument(
        "--output",
        help="A file or directory to save output visualizations. "
        "If not given, will show output in an OpenCV window.",
    )
    parser.add_argument(
        "--min_size_test",
        type=int,
        default=800,
        help="Size of the smallest side of the image during testing. Set to zero to disable resize in testing.",
    )
    # NOTE(review): parsed as float while --min_size_test is int; kept as-is so
    # existing invocations passing values such as "1333.0" keep working.
    parser.add_argument(
        "--max_size_test",
        type=float,
        default=1333,
        help="Maximum size of the side of the image during testing.",
    )
    parser.add_argument(
        "--img_format",
        type=str,
        default="RGB",
        help="The format of the loading images.",
    )
    parser.add_argument(
        "--metadata_dataset",
        type=str,
        default="coco_2017_val",
        help="The metadata information to be used. Default to COCO val metadata.",
    )
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=0.5,
        help="Minimum score for instance predictions to be shown",
    )
    # REMAINDER swallows every token after --opts, so it must come last on
    # the command line.
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser
| 90 | + |
| 91 | + |
def test_opencv_video_format(codec, file_ext):
    """Probe whether OpenCV can write a video with the given codec/container.

    Writes 30 black 10x10 frames into a throwaway file inside a temporary
    directory and reports whether the file exists afterwards.

    NOTE: despite the ``test_`` prefix this is a runtime capability probe
    called by the script, not a unit test.

    Args:
        codec: four-character codec code, e.g. ``"x264"`` or ``"mp4v"``.
        file_ext: container extension including the dot, e.g. ``".mkv"``.

    Returns:
        bool: True if the probe file exists after the writer is released.
    """
    with tempfile.TemporaryDirectory(prefix="video_format_test") as tmp_dir:
        filename = os.path.join(tmp_dir, "test_file" + file_ext)
        writer = cv2.VideoWriter(
            filename=filename,
            fourcc=cv2.VideoWriter_fourcc(*codec),
            fps=float(30),
            frameSize=(10, 10),
            isColor=True,
        )
        try:
            blank_frame = np.zeros((10, 10, 3), np.uint8)
            for _ in range(30):
                writer.write(blank_frame)
        finally:
            # Always release so the file handle is closed before the
            # temporary directory is removed.
            writer.release()
        return os.path.isfile(filename)
| 107 | + |
| 108 | + |
if __name__ == "__main__":
    # Script entry point: build the model from the lazy config, load its
    # checkpoint, then visualize predictions from the first source given,
    # checked in this order: --input (images), --webcam, --video-input.
    mp.set_start_method("spawn", force=True)
    args = get_parser().parse_args()
    setup_logger(name="fvcore")
    logger = setup_logger()
    logger.info("Arguments: " + str(args))

    cfg = setup(args)

    model = instantiate(cfg.model)
    model.to(cfg.train.device)
    checkpointer = DetectionCheckpointer(model)
    checkpointer.load(cfg.train.init_checkpoint)

    model.eval()

    demo = VisualizationDemo(
        model=model,
        min_size_test=args.min_size_test,
        max_size_test=args.max_size_test,
        img_format=args.img_format,
        metadata_dataset=args.metadata_dataset,
    )

    if args.input:
        if len(args.input) == 1:
            # A single argument is treated as a glob pattern.
            args.input = glob.glob(os.path.expanduser(args.input[0]))
            assert args.input, "The input path(s) was not found"
        args.input = sorted(args.input)
        # The progress bar is only shown when results are written to disk.
        for path in tqdm.tqdm(args.input, disable=not args.output):
            # use PIL, to be consistent with evaluation
            img = read_image(path, format="BGR")
            start_time = time.time()
            predictions, visualized_output = demo.run_on_image(img, args.confidence_threshold)
            logger.info(
                "{}: {} in {:.2f}s".format(
                    path,
                    "detected {} instances".format(len(predictions["instances"]))
                    if "instances" in predictions
                    else "finished",
                    time.time() - start_time,
                )
            )

            if args.output:
                if os.path.isdir(args.output):
                    out_filename = os.path.join(args.output, os.path.basename(path))
                else:
                    # A single output file can only hold a single input image.
                    assert len(args.input) == 1, "Please specify a directory with args.output"
                    out_filename = args.output
                visualized_output.save(out_filename)
            else:
                cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
                # Reverse the channel axis before handing the image to OpenCV.
                cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])
                if cv2.waitKey(0) == 27:
                    break  # esc to quit
    elif args.webcam:
        assert args.input is None, "Cannot have both --input and --webcam!"
        assert args.output is None, "output not yet supported with --webcam!"
        cam = cv2.VideoCapture(0)
        for vis in tqdm.tqdm(demo.run_on_video(cam)):
            cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
            cv2.imshow(WINDOW_NAME, vis)
            if cv2.waitKey(1) == 27:
                break  # esc to quit
        cam.release()
        cv2.destroyAllWindows()
    elif args.video_input:
        # Check the input first so a bad path fails before any capture or
        # output file is created.
        assert os.path.isfile(args.video_input), args.video_input
        video = cv2.VideoCapture(args.video_input)
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames_per_second = video.get(cv2.CAP_PROP_FPS)
        num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        basename = os.path.basename(args.video_input)
        codec, file_ext = (
            ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4")
        )
        # codec is a bare fourcc string ("x264" or "mp4v"), without a dot.
        if codec == "mp4v":
            warnings.warn("x264 codec not available, switching to mp4v")
        if args.output:
            if os.path.isdir(args.output):
                # Reuse the input's base name with the chosen container extension.
                output_fname = os.path.join(args.output, basename)
                output_fname = os.path.splitext(output_fname)[0] + file_ext
            else:
                output_fname = args.output

            output_file = cv2.VideoWriter(
                filename=output_fname,
                # some installation of opencv may not support x264 (due to its license),
                # you can try other format (e.g. MPEG)
                fourcc=cv2.VideoWriter_fourcc(*codec),
                fps=float(frames_per_second),
                frameSize=(width, height),
                isColor=True,
            )
        for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames):
            if args.output:
                output_file.write(vis_frame)
            else:
                cv2.namedWindow(basename, cv2.WINDOW_NORMAL)
                cv2.imshow(basename, vis_frame)
                if cv2.waitKey(1) == 27:
                    break  # esc to quit
        video.release()
        if args.output:
            output_file.release()
        else:
            cv2.destroyAllWindows()
0 commit comments