Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit cfbfedf

Browse files
bfineran, KSGulin, dbogunowicz
authored
[cherry-pick-1.4][torchvision][Bug-fix] ignore state dict error on transfer learning tasks + use PythonLogger default logger #1455 (#1460)
* [torchvision][Bug-fix] ignore state dict error on transfer learning tasks + use PythonLogger default logger (#1455) * Remove cf from native torchvision models * * do not pass default logger to PythonLogger * comments --------- Co-authored-by: Damian <[email protected]> Co-authored-by: Benjamin <[email protected]> * [torchvision] add ignore error tensors back to optional checkpoint load (#1459) --------- Co-authored-by: Konstantin Gulin <[email protected]> Co-authored-by: Damian <[email protected]>
1 parent 84deda6 commit cfbfedf

File tree

1 file changed

+30
-5
lines changed
  • src/sparseml/pytorch/torchvision

1 file changed

+30
-5
lines changed

src/sparseml/pytorch/torchvision/train.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ def collate_fn(batch):
572572

573573
if utils.is_main_process():
574574
loggers = [
575-
PythonLogger(logger=_LOGGER),
575+
PythonLogger(),
576576
]
577577
try:
578578
loggers.append(TensorBoardLogger(log_path=args.output_dir))
@@ -757,11 +757,36 @@ def _create_model(
757757
model, arch_key = model
758758
elif arch_key in torchvision.models.__dict__:
759759
# fall back to torchvision
760-
model = torchvision.models.__dict__[arch_key](
761-
pretrained=pretrained, num_classes=num_classes
762-
)
760+
# load initial, untrained model with correct number of classes
761+
model = torchvision.models.__dict__[arch_key](num_classes=num_classes)
762+
if pretrained is not None:
763+
# in transfer learning cases, final FC layer may not match dimensions
764+
# load base pretrained model and load state dict with strict=False
765+
pretrained_model = torchvision.models.__dict__[arch_key](
766+
pretrained=pretrained
767+
)
768+
if (
769+
getattr(pretrained_model, "classifier", None)
770+
and pretrained_model.classifier.out_features != num_classes
771+
):
772+
del pretrained_model.classifier
773+
model.load_state_dict(pretrained_model.state_dict(), strict=False)
763774
if checkpoint_path is not None:
764-
load_model(checkpoint_path, model, strict=True)
775+
load_model(
776+
checkpoint_path,
777+
model,
778+
strict=True,
779+
ignore_error_tensors=[
780+
"classifier.fc.weight",
781+
"classifier.fc.bias",
782+
"classifier.1.weight",
783+
"classifier.1.bias",
784+
"fc.weight",
785+
"fc.bias",
786+
"classifier.weight",
787+
"classifier.bias",
788+
],
789+
)
765790
else:
766791
raise ValueError(
767792
f"Unable to find {arch_key} in ModelRegistry or in torchvision.models"

0 commit comments

Comments
 (0)