Skip to content

Commit 4c9e331

Browse files
authored
Merge pull request #895 from alan-turing-institute/model_params_fix
Overwrite defaults with user supplied param values
2 parents 3a63ee4 + 33700ce commit 4c9e331

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

autoemulate/core/compare.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -406,14 +406,18 @@ def compare(self):
406406
"parameters",
407407
model_cls.__name__,
408408
)
409-
# extract default parameters from the model's __init__
409+
# Extract default parameters from the model's __init__
410410
init_sig = inspect.signature(model_cls.__init__)
411-
init_params = {
411+
default_params = {
412412
param_name: param.default
413413
for param_name, param in init_sig.parameters.items()
414414
if param_name in model_cls.get_tune_params()
415415
}
416-
best_params_for_this_model = init_params
416+
# Overwrite defaults with user-supplied values
417+
best_params_for_this_model = {
418+
**default_params,
419+
**self.model_params,
420+
}
417421

418422
self.logger.debug(
419423
'Running cross-validation for model "%s" '

tests/core/test_compare.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,15 @@ def test_ae_no_tuning(sample_data_for_ae_compare):
102102
assert "likelihood_cls" in gp_params
103103

104104

105+
def test_ae_no_tuning_fix_params(sample_data_for_ae_compare):
106+
"""Test that model_params are correctly passed when tuning is disabled."""
107+
x, y = sample_data_for_ae_compare
108+
ae = AutoEmulate(
109+
x, y, models=["GaussianProcessRBF"], model_params={"posterior_predictive": True}
110+
)
111+
assert ae.best_result().model.model.posterior_predictive is True # pyright: ignore[reportAttributeAccessIssue]
112+
113+
105114
def test_get_model_subset():
106115
"""Test getting a subset of models based on pytroch and probabilistic flags."""
107116

0 commit comments

Comments
 (0)