File tree Expand file tree Collapse file tree 2 files changed +6
-3
lines changed Expand file tree Collapse file tree 2 files changed +6
-3
lines changed Original file line number Diff line number Diff line change 1919
2020 # 2. Model
2121 device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
22- model = utils .get_model (config ).to (device )
22+ model = utils .get_model (config , pretrained_backbone = True ).to (device )
2323
2424 # 3. Loss function, optimizer, lr scheduler, scaler
2525 criterion = nn .CrossEntropyLoss ()
Original file line number Diff line number Diff line change @@ -24,7 +24,7 @@ def load_config():
2424 return config
2525
2626
27- def get_model (config : dict , pretrained = False ) -> torch .nn .Module :
27+ def get_model (config : dict , pretrained = False , pretrained_backbone = False ) -> torch .nn .Module :
2828 assert isinstance (pretrained , bool )
2929 assert config ['dataset' ]['num_classes' ] == 20 or config ['dataset' ]['num_classes' ] == 8
3030
@@ -33,7 +33,10 @@ def get_model(config: dict, pretrained=False) -> torch.nn.Module:
3333 elif config ['model' ] == 'Backbone' :
3434 model = models .backbone .Backbone (config ['dataset' ]['num_classes' ])
3535 elif config ['model' ] == 'Proposed' :
36- model = models .proposed .Proposed (config ['dataset' ]['num_classes' ], config ['Backbone' ]['pretrained_weights' ])
36+ if pretrained_backbone :
37+ model = models .proposed .Proposed (config ['dataset' ]['num_classes' ], config ['Backbone' ]['pretrained_weights' ])
38+ else :
39+ model = models .proposed .Proposed (config ['dataset' ]['num_classes' ])
3740 else :
3841 raise NameError ('Wrong model name.' )
3942
You can’t perform that action at this time.
0 commit comments