#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
"""
Training script using the new "LazyConfig" python config files.

This script reads a given python config file and runs the training or evaluation.
It can be used to train any model or dataset, as long as they can be
instantiated by the recursive construction defined in the given config file.

Besides lazy construction of models, dataloaders, etc., this script expects a
few common configuration parameters currently defined in "configs/common/train.py".
To add more complicated training logic, you can easily add other configs
in the config file and implement a new train_net.py to handle them.
"""
import logging
import os
import sys
import time
import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import LazyConfig, instantiate
from detectron2.engine import (
    SimpleTrainer,
    default_argument_parser,
    default_setup,
    default_writers,
    hooks,
    launch,
)
from detectron2.engine.defaults import create_ddp_model
from detectron2.evaluation import inference_on_dataset, print_csv_format
from detectron2.utils import comm

# add the repo root (this file's parent directory) to sys.path so sibling
# packages can be imported when the script is run directly
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))

logger = logging.getLogger("detrex")


def match_name_keywords(n, name_keywords):
    """Return True if any of the given keywords occurs in the parameter name `n`."""
    return any(keyword in n for keyword in name_keywords)
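

# Illustrative behavior (the parameter names below are hypothetical):
#   match_name_keywords("backbone.res2.conv1.weight", ["backbone"]) -> True
#   match_name_keywords("transformer.decoder.norm.weight", ["backbone"]) -> False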


class Trainer(SimpleTrainer):
    """
    A trainer that combines detectron2's SimpleTrainer and AMPTrainer,
    adding optional mixed-precision (AMP) training and gradient clipping.
    """

    def __init__(
        self,
        model,
        dataloader,
        optimizer,
        amp=False,
        clip_grad_params=None,
        grad_scaler=None,
    ):
        super().__init__(model=model, data_loader=dataloader, optimizer=optimizer)

        unsupported = "AMPTrainer does not support single-process multi-device training!"
        if isinstance(model, DistributedDataParallel):
            assert not (model.device_ids and len(model.device_ids) > 1), unsupported
        assert not isinstance(model, DataParallel), unsupported

        if amp:
            if grad_scaler is None:
                from torch.cuda.amp import GradScaler

                grad_scaler = GradScaler()
            self.grad_scaler = grad_scaler

        # set True to use amp training
        self.amp = amp

        # gradient clip hyper-params
        self.clip_grad_params = clip_grad_params
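        # clip_grad_params is forwarded verbatim to torch.nn.utils.clip_grad_norm_,
        # so a typical value looks like {"max_norm": 0.1, "norm_type": 2}; these
        # numbers are only illustrative, the real ones come from the config.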
    def run_step(self):
        """
        Implement the standard training logic, with optional AMP and gradient clipping.
        """
        assert self.model.training, "[Trainer] model was changed to eval mode!"
        if self.amp:
            assert torch.cuda.is_available(), "[Trainer] CUDA is required for AMP training!"
        from torch.cuda.amp import autocast

        start = time.perf_counter()
        """
        If you want to do something with the data, you can wrap the dataloader.
        """
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        """
        If you want to do something with the losses, you can wrap the model.
        """
        # the forward pass must run inside autocast, otherwise AMP never applies
        with autocast(enabled=self.amp):
            loss_dict = self.model(data)
            if isinstance(loss_dict, torch.Tensor):
                losses = loss_dict
                loss_dict = {"total_loss": loss_dict}
            else:
                losses = sum(loss_dict.values())

        """
        If you need to accumulate gradients or do something similar, you can
        wrap the optimizer with your custom `zero_grad()` method.
        """
        self.optimizer.zero_grad()

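        # With AMP, gradients must be unscaled before clipping so the norm is
        # computed on true gradient values; grad_scaler.step() then skips the
        # optimizer update for iterations whose gradients overflowed.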
        if self.amp:
            self.grad_scaler.scale(losses).backward()
            if self.clip_grad_params is not None:
                self.grad_scaler.unscale_(self.optimizer)
                self.clip_grads(self.model.parameters())
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            losses.backward()
            if self.clip_grad_params is not None:
                self.clip_grads(self.model.parameters())
            self.optimizer.step()

        self._write_metrics(loss_dict, data_time)

    def clip_grads(self, params):
        params = list(filter(lambda p: p.requires_grad and p.grad is not None, params))
        if len(params) > 0:
            return torch.nn.utils.clip_grad_norm_(
                parameters=params,
                **self.clip_grad_params,
            )


def do_test(cfg, model):
    if "evaluator" in cfg.dataloader:
        ret = inference_on_dataset(
            model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
        )
        print_csv_format(ret)
        return ret
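

# If the config defines no "evaluator", do_test implicitly returns None;
# detectron2's EvalHook tolerates a None result and simply logs nothing.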


def do_train(args, cfg):
    """
    Args:
        cfg: an object with the following attributes:
            model: instantiate to a module
            dataloader.{train,test}: instantiate to dataloaders
            dataloader.evaluator: instantiate to evaluator for test set
            optimizer: instantiate to an optimizer
            lr_multiplier: instantiate to a fvcore scheduler
            train: other misc config defined in `configs/common/train.py`, including:
                output_dir (str)
                init_checkpoint (str)
                amp.enabled (bool)
                max_iter (int)
                eval_period, log_period (int)
                device (str)
                checkpointer (dict)
                ddp (dict)
    """
    model = instantiate(cfg.model)
    logger = logging.getLogger("detectron2")
    logger.info("Model:\n{}".format(model))
    model.to(cfg.train.device)

    # NOTE: instead of instantiating cfg.optimizer, this script hard-codes the
    # parameter groups so that the backbone and the deformable-attention modules
    # (reference_points / sampling_offsets) train with a 10x lower learning rate.
    param_dicts = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if not match_name_keywords(n, ["backbone"])
                and not match_name_keywords(n, ["reference_points", "sampling_offsets"])
                and p.requires_grad
            ],
            "lr": 2e-4,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if match_name_keywords(n, ["backbone"]) and p.requires_grad
            ],
            "lr": 2e-5,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if match_name_keywords(n, ["reference_points", "sampling_offsets"])
                and p.requires_grad
            ],
            "lr": 2e-5,
        },
    ]
    optim = torch.optim.AdamW(param_dicts, lr=2e-4, weight_decay=1e-4)
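    # The top-level lr=2e-4 is only AdamW's default; every group above sets
    # its own "lr", so the per-group values are what actually take effect.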

    train_loader = instantiate(cfg.dataloader.train)

    model = create_ddp_model(model, **cfg.train.ddp)

    trainer = Trainer(
        model=model,
        dataloader=train_loader,
        optimizer=optim,
        amp=cfg.train.amp.enabled,
        clip_grad_params=cfg.train.clip_grad.params if cfg.train.clip_grad.enabled else None,
    )

    checkpointer = DetectionCheckpointer(
        model,
        cfg.train.output_dir,
        trainer=trainer,
    )

    trainer.register_hooks(
        [
            hooks.IterationTimer(),
            hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
            hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
            if comm.is_main_process()
            else None,
            hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
            hooks.PeriodicWriter(
                default_writers(cfg.train.output_dir, cfg.train.max_iter),
                period=cfg.train.log_period,
            )
            if comm.is_main_process()
            else None,
        ]
    )
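    # register_hooks filters out the None entries produced on non-main
    # processes, so only rank 0 checkpoints and writes periodic logs.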

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
    if args.resume and checkpointer.has_checkpoint():
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration
        start_iter = trainer.iter + 1
    else:
        start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)


def main(args):
    cfg = LazyConfig.load(args.config_file)
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    default_setup(cfg, args)

    if args.eval_only:
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        model = create_ddp_model(model)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
        print(do_test(cfg, model))
    else:
        do_train(args, cfg)


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
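
# Example invocations (the config and checkpoint paths are illustrative):
#   python train_net.py --config-file path/to/config.py --num-gpus 8
#   python train_net.py --config-file path/to/config.py --eval-only \
#       train.init_checkpoint=path/to/model.pth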