This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 1271955

rahul-tuli authored and bfineran committed
[BugFix][Torchvision] update optimizer state dict before transfer learning (#1358)
* Add: an `_update_checkpoint_optimizer(...)` for deleting mismatching params from saved optimizer(s) state_dict
* Remove: `_update_checkpoint_optimizer` in favor of loading the optim state_dict only when `args.resume` is set
* Remove: un-needed imports
* Address review comments
* Style
1 parent b1ec8a9 commit 1271955

File tree

  • src/sparseml/pytorch/torchvision

1 file changed: +8 −1 lines

src/sparseml/pytorch/torchvision/train.py

Lines changed: 8 additions & 1 deletion
@@ -533,7 +533,14 @@ def collate_fn(batch):
     # load params
     if checkpoint is not None:
         if "optimizer" in checkpoint and not args.test_only:
-            optimizer.load_state_dict(checkpoint["optimizer"])
+            if args.resume:
+                optimizer.load_state_dict(checkpoint["optimizer"])
+            else:
+                warnings.warn(
+                    "Optimizer state dict not loaded from checkpoint. Unless run is "
+                    "resumed with the --resume arg, the optimizer will start from a "
+                    "fresh state"
+                )
         if model_ema and "model_ema" in checkpoint:
             model_ema.load_state_dict(checkpoint["model_ema"])
         if scaler and "scaler" in checkpoint:
