Skip to content

Commit f88f3e4

Browse files
committed
address comments
1 parent b5e35e5 commit f88f3e4

File tree

5 files changed

+31
-49
lines changed

5 files changed

+31
-49
lines changed

ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,17 @@ def _filter_results(df: pd.DataFrame, model_name: str) -> tuple[pd.DataFrame, in
5858
return valid, excluded
5959

6060

61-
def _collect_model_data() -> dict[str, dict[str, Any]]:
61+
@pytest.fixture
62+
def elasticity_stats() -> dict[str, dict[str, Any]]:
6263
"""
63-
Collect filtered results, MAEs, and exclusion counts per model.
64+
Load and cache processed benchmark statistics per model.
6465
6566
Returns
6667
-------
6768
dict[str, dict[str, Any]]
68-
Mapping of model name to bulk/shear data and exclusion metadata.
69+
Processed information per model (bulk, shear, exclusion counts).
6970
"""
71+
OUT_PATH.mkdir(parents=True, exist_ok=True)
7072
stats: dict[str, dict[str, Any]] = {}
7173
for model_name in MODELS:
7274
results_path = CALC_PATH / model_name / "moduli_results.csv"
@@ -85,21 +87,8 @@ def _collect_model_data() -> dict[str, dict[str, Any]]:
8587
},
8688
"excluded": excluded,
8789
}
88-
return stats
89-
9090

91-
@pytest.fixture
92-
def elasticity_stats() -> dict[str, dict[str, Any]]:
93-
"""
94-
Load and cache processed benchmark statistics per model.
95-
96-
Returns
97-
-------
98-
dict[str, dict[str, Any]]
99-
Processed information per model (bulk, shear, exclusion counts).
100-
"""
101-
OUT_PATH.mkdir(parents=True, exist_ok=True)
102-
return _collect_model_data()
91+
return stats
10392

10493

10594
@pytest.fixture
@@ -167,7 +156,7 @@ def bulk_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict]
167156
dict[str, dict]
168157
Mapping of model name to density-scatter data.
169158
"""
170-
return build_density_inputs(MODELS, elasticity_stats, "bulk", mae_fn=mae)
159+
return build_density_inputs(MODELS, elasticity_stats, "bulk", metric_fn=mae)
171160

172161

173162
@pytest.fixture
@@ -191,7 +180,7 @@ def shear_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict
191180
dict[str, dict]
192181
Mapping of model name to density-scatter data.
193182
"""
194-
return build_density_inputs(MODELS, elasticity_stats, "shear", mae_fn=mae)
183+
return build_density_inputs(MODELS, elasticity_stats, "shear", metric_fn=mae)
195184

196185

197186
@pytest.fixture

ml_peg/analysis/utils/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def build_density_inputs(
120120
model_stats: dict[str, dict[str, Any]],
121121
property_key: str,
122122
*,
123-
mae_fn: Callable[[list, list], float] | None = None,
123+
metric_fn: Callable[[list, list], float] | None = None,
124124
) -> dict[str, dict[str, Any]]:
125125
"""
126126
Prepare a model->data mapping for density scatter plots.
@@ -135,15 +135,15 @@ def build_density_inputs(
135135
property_key
136136
Key to extract from ``model_stats`` for each model (e.g. ``"bulk"`` or
137137
``"shear"``).
138-
mae_fn
139-
Optional callable to compute MAE. Defaults to :func:`mae` when None.
138+
metric_fn
139+
Optional callable to compute metric. Defaults to :func:`mae` when None.
140140
141141
Returns
142142
-------
143143
dict[str, dict[str, Any]]
144144
Mapping ready for ``plot_density_scatter``.
145145
"""
146-
mae_fn = mae if mae_fn is None else mae_fn
146+
metric_fn = mae if metric_fn is None else metric_fn
147147
inputs: dict[str, dict[str, Any]] = {}
148148

149149
for model_name in models:
@@ -156,7 +156,7 @@ def build_density_inputs(
156156
inputs[model_name] = {
157157
"ref": ref_vals,
158158
"pred": pred_vals,
159-
"mae": mae_fn(ref_vals, pred_vals) if ref_vals else None,
159+
"metric": metric_fn(ref_vals, pred_vals) if ref_vals else None,
160160
"meta": {"excluded": excluded} if excluded is not None else {},
161161
}
162162

ml_peg/app/bulk_crystal/elasticity/app_elasticity.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
MODELS = get_model_names(current_models)
1616
BENCHMARK_NAME = "Elasticity"
17-
DOCS_URL = "https://ddmms.github.io/ml-peg/user_guide/benchmarks/bulk.html#elasticity"
17+
DOCS_URL = (
18+
"https://ddmms.github.io/ml-peg/user_guide/benchmarks/bulk_crystal.html#elasticity"
19+
)
1820
DATA_PATH = APP_ROOT / "data" / "bulk_crystal" / "elasticity"
1921

2022

ml_peg/app/utils/load.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -174,29 +174,25 @@ def _filter_density_figure_for_model(fig_dict: dict, model: str) -> dict:
174174
line.
175175
"""
176176
data = fig_dict.get("data", [])
177-
layout = deepcopy(fig_dict.get("layout", {}))
178-
annotations_meta = (
179-
layout.get("meta") if isinstance(layout.get("meta"), dict) else {}
180-
)
177+
layout = deepcopy(fig_dict.get("layout"))
178+
annotations_meta = layout.get("meta")
181179

182180
fig_data = []
183181
for trace in data:
184182
name = trace.get("name")
185-
if name is None:
186-
line_trace = deepcopy(trace)
187-
line_trace["visible"] = True
188-
line_trace["showlegend"] = False # keep reference line, no legend
189-
fig_data.append(line_trace)
190-
elif name == model:
191-
model_trace = deepcopy(trace)
192-
model_trace["visible"] = True
193-
model_trace["showlegend"] = False # hide legend to avoid overlap
194-
fig_data.append(model_trace)
195-
196-
# Pick the matching annotation when available; otherwise keep a simple fallback.
183+
if name is None or name == model:
184+
# y=x line or matching model trace
185+
trace_copy = deepcopy(trace)
186+
trace_copy["visible"] = True
187+
trace_copy["showlegend"] = False
188+
fig_data.append(trace_copy)
189+
190+
# Pick the matching annotation when available
191+
stored_annotations = (
192+
annotations_meta.get("annotations") if annotations_meta else None
193+
)
194+
model_order = annotations_meta.get("models") if annotations_meta else None
197195
chosen_annotation = None
198-
stored_annotations = annotations_meta.get("annotations")
199-
model_order = annotations_meta.get("models")
200196
if isinstance(stored_annotations, list) and isinstance(model_order, list):
201197
try:
202198
idx = model_order.index(model)
@@ -206,11 +202,6 @@ def _filter_density_figure_for_model(fig_dict: dict, model: str) -> dict:
206202
pass
207203
if chosen_annotation:
208204
layout["annotations"] = [chosen_annotation]
209-
elif layout.get("annotations"):
210-
fallback = deepcopy(layout["annotations"][0])
211-
if isinstance(fallback, dict):
212-
fallback["text"] = model
213-
layout["annotations"] = [fallback]
214205

215206
# Hide legend entirely to prevent overlap with the density colorbar.
216207
layout["showlegend"] = False

ml_peg/calcs/bulk_crystal/elasticity/calc_elasticity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def run_elasticity_benchmark(
7979
calc,
8080
model_name,
8181
n_jobs=n_jobs,
82-
checkpoint_file=str(checkpoint_file) if checkpoint_file else None,
82+
checkpoint_file=checkpoint_file if checkpoint_file else None,
8383
checkpoint_freq=100,
8484
delete_checkpoint_on_finish=False,
8585
)

0 commit comments

Comments
 (0)