Skip to content

Commit 3e31883

Browse files
committed
utils: Add parameter of get model()
1 parent 21baff2 commit 3e31883

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
# 2. Model
2121
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22-
model = utils.get_model(config).to(device)
22+
model = utils.get_model(config, pretrained_backbone=True).to(device)
2323

2424
# 3. Loss function, optimizer, lr scheduler, scaler
2525
criterion = nn.CrossEntropyLoss()

utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def load_config():
2424
return config
2525

2626

27-
def get_model(config: dict, pretrained=False) -> torch.nn.Module:
27+
def get_model(config: dict, pretrained=False, pretrained_backbone=False) -> torch.nn.Module:
2828
assert isinstance(pretrained, bool)
2929
assert config['dataset']['num_classes'] == 20 or config['dataset']['num_classes'] == 8
3030

@@ -33,7 +33,10 @@ def get_model(config: dict, pretrained=False) -> torch.nn.Module:
3333
elif config['model'] == 'Backbone':
3434
model = models.backbone.Backbone(config['dataset']['num_classes'])
3535
elif config['model'] == 'Proposed':
36-
model = models.proposed.Proposed(config['dataset']['num_classes'], config['Backbone']['pretrained_weights'])
36+
if pretrained_backbone:
37+
model = models.proposed.Proposed(config['dataset']['num_classes'], config['Backbone']['pretrained_weights'])
38+
else:
39+
model = models.proposed.Proposed(config['dataset']['num_classes'])
3740
else:
3841
raise NameError('Wrong model name.')
3942

0 commit comments

Comments
 (0)