diff --git a/.gitignore b/.gitignore index c92733bd..98c52996 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ *.svg *.xz *.txt +*.csv ~* *~ .project diff --git a/docs/source/user_guide/benchmarks/bulk_crystal.rst b/docs/source/user_guide/benchmarks/bulk_crystal.rst index fb6db0f0..67aa5ca9 100644 --- a/docs/source/user_guide/benchmarks/bulk_crystal.rst +++ b/docs/source/user_guide/benchmarks/bulk_crystal.rst @@ -1,6 +1,6 @@ -============ -Bulk Crystal -============ +============= +Bulk Crystals +============= Lattice constants ================= @@ -39,6 +39,7 @@ Computational cost Low: tests are likely to less than a minute to run on CPU. + Data availability ----------------- @@ -57,3 +58,60 @@ Reference data: foundation model for atomistic materials chemistry. The Journal of Chemical Physics, 163(18). * PBE-D3(BJ) + + +Elastic Moduli +============== + +Summary +------- + +Bulk and shear moduli calculated for 12122 bulk crystals from the materials project. + + +Metrics +------- + +(1) Bulk modulus MAE + +Mean absolute error (MAE) between predicted and reference bulk modulus (B) values. + +MatCalc's ElasticityCalc is used to deform the structures with normal (diagonal) strain +magnitudes of ±0.01 and ±0.005 for ϵ11, ϵ22, ϵ33, and off-diagonal strain magnitudes of +±0.06 and ±0.03 for ϵ23, ϵ13, ϵ12. The Voigt-Reuss-Hill (VRH) average is used to obtain +the bulk and shear moduli from the stress tensor. Both the initial and deformed +structures are relaxed with MatCalc's default ElasticityCalc settings. For more information, see +`MatCalc's ElasticityCalc documentation +`_. + +Analysis excludes materials with: + * B ≤ 0, B > 500 and G ≥ 0, G > 500 structures. + * H2, N2, O2, F2, Cl2, He, Xe, Ne, Kr, Ar + * Materials with density < 0.5 (less dense than Li, the lowest density solid element) + +(2) Shear modulus MAE + +Mean absolute error (MAE) between predicted and reference shear modulus (G) values + +Calculated alongside (1), with the same exclusion criteria used in analysis. + + +Computational cost +------------------ + +High: tests are likely to take hours-days to run on GPU. + + +Data availability +----------------- + +Input structures: + +* 1. De Jong, M. et al. Charting the complete elastic properties of + inorganic crystalline compounds. Sci Data 2, 150009 (2015). +* Dataset release: mp-pbe-elasticity-2025.3.json.gz from the Materials Project database. + +Reference data: + +* Same as input data +* PBE diff --git a/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py b/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py new file mode 100644 index 00000000..6472f071 --- /dev/null +++ b/ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py @@ -0,0 +1,235 @@ +"""Analyse elasticity benchmark.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pandas as pd +import pytest + +from ml_peg.analysis.utils.decorators import ( + build_table, + plot_density_scatter, +) +from ml_peg.analysis.utils.utils import ( + build_density_inputs, + load_metrics_config, + mae, +) +from ml_peg.app import APP_ROOT +from ml_peg.calcs import CALCS_ROOT +from ml_peg.models.get_models import get_model_names +from ml_peg.models.models import current_models + +MODELS = get_model_names(current_models) +CALC_PATH = CALCS_ROOT / "bulk_crystal" / "elasticity" / "outputs" +OUT_PATH = APP_ROOT / "data" / "bulk_crystal" / "elasticity" + +METRICS_CONFIG_PATH = Path(__file__).with_name("metrics.yml") +DEFAULT_THRESHOLDS, DEFAULT_TOOLTIPS, DEFAULT_WEIGHTS = load_metrics_config( + METRICS_CONFIG_PATH +) + +K_COLUMN = "K_vrh" +G_COLUMN = "G_vrh" + + +def _filter_results(df: pd.DataFrame, model_name: str) -> tuple[pd.DataFrame, int]: + """ + Filter outlier predictions and return remaining data with exclusion count. + + Parameters + ---------- + df + Dataframe containing raw benchmark results. + model_name + Model whose columns should be filtered. + + Returns + ------- + tuple[pd.DataFrame, int] + Filtered dataframe and number of excluded systems. + """ + mask_bulk = df[f"{K_COLUMN}_{model_name}"].between(-50, 600) + mask_shear = df[f"{G_COLUMN}_{model_name}"].between(-50, 600) + valid = df[mask_bulk & mask_shear].copy() + excluded = len(df) - len(valid) + return valid, excluded + + +@pytest.fixture +def elasticity_stats() -> dict[str, dict[str, Any]]: + """ + Load and cache processed benchmark statistics per model. + + Returns + ------- + dict[str, dict[str, Any]] + Processed information per model (bulk, shear, exclusion counts). + """ + OUT_PATH.mkdir(parents=True, exist_ok=True) + stats: dict[str, dict[str, Any]] = {} + for model_name in MODELS: + results_path = CALC_PATH / model_name / "moduli_results.csv" + df = pd.read_csv(results_path) + + filtered, excluded = _filter_results(df, model_name) + + stats[model_name] = { + "bulk": { + "ref": filtered[f"{K_COLUMN}_DFT"].tolist(), + "pred": filtered[f"{K_COLUMN}_{model_name}"].tolist(), + }, + "shear": { + "ref": filtered[f"{G_COLUMN}_DFT"].tolist(), + "pred": filtered[f"{G_COLUMN}_{model_name}"].tolist(), + }, + "excluded": excluded, + } + + return stats + + +@pytest.fixture +def bulk_mae(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, float | None]: + """ + Mean absolute error for bulk modulus predictions. + + Parameters + ---------- + elasticity_stats + Aggregated bulk/shear data per model. + + Returns + ------- + dict[str, float | None] + MAE values for each model (``None`` if no data). + """ + results: dict[str, float | None] = {} + for model_name in MODELS: + prop = elasticity_stats.get(model_name, {}).get("bulk") + results[model_name] = mae(prop["ref"], prop["pred"]) + return results + + +@pytest.fixture +def shear_mae(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, float | None]: + """ + Mean absolute error for shear modulus predictions. + + Parameters + ---------- + elasticity_stats + Aggregated bulk/shear data per model. + + Returns + ------- + dict[str, float | None] + MAE values for each model (``None`` if no data). + """ + results: dict[str, float | None] = {} + for model_name in MODELS: + prop = elasticity_stats.get(model_name, {}).get("shear") + results[model_name] = mae(prop["ref"], prop["pred"]) + return results + + +@pytest.fixture +@plot_density_scatter( + filename=OUT_PATH / "figure_bulk_density.json", + title="Bulk modulus density plot", + x_label="Reference bulk modulus / GPa", + y_label="Predicted bulk modulus / GPa", +) +def bulk_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict]: + """ + Density scatter inputs for bulk modulus. + + Parameters + ---------- + elasticity_stats + Aggregated bulk/shear data per model. + + Returns + ------- + dict[str, dict] + Mapping of model name to density-scatter data. + """ + return build_density_inputs(MODELS, elasticity_stats, "bulk", metric_fn=mae) + + +@pytest.fixture +@plot_density_scatter( + filename=OUT_PATH / "figure_shear_density.json", + title="Shear modulus density plot", + x_label="Reference shear modulus / GPa", + y_label="Predicted shear modulus / GPa", +) +def shear_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict]: + """ + Density scatter inputs for shear modulus. + + Parameters + ---------- + elasticity_stats + Aggregated bulk/shear data per model. + + Returns + ------- + dict[str, dict] + Mapping of model name to density-scatter data. + """ + return build_density_inputs(MODELS, elasticity_stats, "shear", metric_fn=mae) + + +@pytest.fixture +@build_table( + filename=OUT_PATH / "elasticity_metrics_table.json", + metric_tooltips=DEFAULT_TOOLTIPS, + thresholds=DEFAULT_THRESHOLDS, + weights=DEFAULT_WEIGHTS, +) +def metrics( + bulk_mae: dict[str, float | None], + shear_mae: dict[str, float | None], +) -> dict[str, dict]: + """ + All elasticity metrics. + + Parameters + ---------- + bulk_mae + Bulk modulus MAE per model. + shear_mae + Shear modulus MAE per model. + + Returns + ------- + dict[str, dict] + Mapping of metric name to model-value dictionaries. + """ + return { + "Bulk modulus MAE": bulk_mae, + "Shear modulus MAE": shear_mae, + } + + +def test_elasticity( + metrics: dict[str, dict], + bulk_density: dict[str, dict], + shear_density: dict[str, dict], +) -> None: + """ + Run elasticity analysis. + + Parameters + ---------- + metrics + Benchmark metric values. + bulk_density + Density scatter inputs for bulk modulus. + shear_density + Density scatter inputs for shear modulus. + """ + return diff --git a/ml_peg/analysis/bulk_crystal/elasticity/metrics.yml b/ml_peg/analysis/bulk_crystal/elasticity/metrics.yml new file mode 100644 index 00000000..615504da --- /dev/null +++ b/ml_peg/analysis/bulk_crystal/elasticity/metrics.yml @@ -0,0 +1,11 @@ +metrics: + Bulk modulus MAE: + good: 2.0 + bad: 30.0 + unit: GPa + tooltip: Mean absolute error of VRH bulk modulus (lower is better). Excludes systems with bulk moduli < -50 GPa and > 500 GPa. + Shear modulus MAE: + good: 2.0 + bad: 30.0 + unit: GPa + tooltip: Mean absolute error of VRH shear modulus (lower is better). Excludes systems with shear moduli < -50 GPa and > 500 GPa. diff --git a/ml_peg/analysis/utils/decorators.py b/ml_peg/analysis/utils/decorators.py index 06479244..c95c48a8 100644 --- a/ml_peg/analysis/utils/decorators.py +++ b/ml_peg/analysis/utils/decorators.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections import defaultdict from collections.abc import Callable import functools from json import dump @@ -248,6 +249,270 @@ def plot_scatter_wrapper(*args, **kwargs) -> dict[str, Any]: return plot_scatter_decorator +def plot_density_scatter( + *, + title: str | None = None, + x_label: str | None = None, + y_label: str | None = None, + filename: str = "density_scatter.json", + colorbar_title: str = "Density", + grid_size: int = 80, + max_points_per_cell: int = 5, + seed: int = 0, +) -> Callable: + """ + Plot density-coloured parity scatter with legend-based model toggling. + + The decorated function must return a mapping of model name to a dictionary with + ``ref`` and ``pred`` arrays (and optional ``mae``). Each model is rendered as a + scatter trace with marker colours indicating local data density. + Only one model is shown at a time; use the legend to toggle models. + + Parameters + ---------- + title + Graph title shown above dropdown. Default is None. + x_label + Label for x-axis. Default is None. + y_label + Label for y-axis. Default is None. + filename + Filename to save plot as JSON. Default is "density_scatter.json". + colorbar_title + Title shown next to the density colour bar. Default is "Density". + grid_size + Number of bins per axis used to estimate local density. Default is 80. + max_points_per_cell + Maximum number of examples plotted per cell to keep renders responsive. + seed + Seed for deterministic sub-sampling. Default is 0. + + Returns + ------- + Callable + Decorator to wrap function. + """ + + def plot_density_decorator(func: Callable) -> Callable: + """ + Decorate function to plot density scatter. + + Parameters + ---------- + func + Function being wrapped. + + Returns + ------- + Callable + Wrapped function. + """ + + @functools.wraps(func) + def plot_density_wrapper(*args, **kwargs) -> dict[str, Any]: + """ + Wrap function to plot density scatter. + + Parameters + ---------- + *args + Arguments to pass to the function being wrapped. + **kwargs + Key word arguments to pass to the function being wrapped. + + Returns + ------- + dict + Results dictionary. + """ + + def _downsample( + ref_vals: np.ndarray, pred_vals: np.ndarray + ) -> tuple[list[float], list[float], list[int]]: + """ + Downsample data points while keeping dense regions representative. + + Parameters + ---------- + ref_vals + Reference (x-axis) values for all systems. + pred_vals + Predicted (y-axis) values for all systems. + + Returns + ------- + tuple[list[float], list[float], list[int]] + Downsampled reference values, predicted values, and density counts + corresponding to each retained point. + """ + if ref_vals.size == 0: + return [], [], [] + + delta_x = ref_vals.max() - ref_vals.min() + delta_y = pred_vals.max() - pred_vals.min() + eps = 1e-9 + + norm_x = np.clip( + # Normalise to [0, 1). Clamp to avoid hitting the upper bound so + # bin indices always live in [0, grid_size - 1]. + (ref_vals - ref_vals.min()) / max(delta_x, eps), + 0.0, + 0.999999, + ) + norm_y = np.clip( + (pred_vals - pred_vals.min()) / max(delta_y, eps), + 0.0, + 0.999999, + ) + bins_x = (norm_x * grid_size).astype(int) + bins_y = (norm_y * grid_size).astype(int) + cell_points: dict[tuple[int, int], list[int]] = defaultdict(list) + for idx, (cx, cy) in enumerate(zip(bins_x, bins_y, strict=True)): + cell_points[(int(cx), int(cy))].append(idx) + + rng = np.random.default_rng(seed) + sampled_x: list[float] = [] + sampled_y: list[float] = [] + sampled_density: list[int] = [] + for indices in cell_points.values(): + if len(indices) > max_points_per_cell: + chosen = rng.choice( + indices, size=max_points_per_cell, replace=False + ) + else: + chosen = indices + density = len(indices) + for idx in chosen: + sampled_x.append(float(ref_vals[idx])) + sampled_y.append(float(pred_vals[idx])) + sampled_density.append(density) + + return sampled_x, sampled_y, sampled_density + + results = func(*args, **kwargs) + if not isinstance(results, dict): + raise TypeError( + "Density plot decorator expects a mapping of model results." + ) + + if not results: + raise ValueError("No results provided for density plot.") + + global_min = np.inf + global_max = -np.inf + processed = {} + annotations = [] + for model in results: + data = results[model] + ref_vals = np.asarray(data.get("ref", []), dtype=float) + pred_vals = np.asarray(data.get("pred", []), dtype=float) + meta = data.get("meta") or {} + excluded = meta.get("excluded") + excluded_text = str(excluded) if excluded is not None else "n/a" + if ref_vals.size == 0 or pred_vals.size == 0: + sampled = ([], [], []) + else: + sampled = _downsample(ref_vals, pred_vals) + global_min = min(global_min, ref_vals.min(), pred_vals.min()) + global_max = max(global_max, ref_vals.max(), pred_vals.max()) + # Top left corner annotation for each model with exclusion info + annotations.append( + { + "text": f"{model} | Excluded: {excluded_text}", + "xref": "paper", + "yref": "paper", + "x": 0.02, + "y": 0.98, + "showarrow": False, + "bgcolor": "rgba(255,255,255,0.8)", + "bordercolor": "rgba(0,0,0,0.3)", + "borderpad": 4, + } + ) + processed[model] = { + "samples": sampled, + "counts": len(ref_vals), + "meta": excluded_text, + } + + if not np.isfinite(global_min) or not np.isfinite(global_max): + global_min, global_max = 0.0, 1.0 + + padding = 0.05 * ( + global_max - global_min if global_max != global_min else 1.0 + ) + line_start = global_min - padding + line_end = global_max + padding + + fig = go.Figure() + hovertemplate = ( + "Reference: %{x:.3f}
" + "Predicted: %{y:.3f}
" + "Density: %{customdata[0]:.0f}
" + "Excluded: %{meta[0]}" + ) + + for idx, model in enumerate(results): + sample_x, sample_y, density = processed[model]["samples"] + fig.add_trace( + go.Scattergl( + x=sample_x, + y=sample_y, + mode="markers", + name=model, + visible=idx == 0, + marker={ + "size": 6, + "color": density, + "colorscale": "Viridis", + "showscale": True, + "colorbar": {"title": colorbar_title}, + }, + customdata=np.array(density, dtype=float)[:, None] + if density + else None, + meta=[processed[model]["meta"]], + hovertemplate=hovertemplate, + ) + ) + + fig.add_trace( + go.Scatter( + x=[line_start, line_end], + y=[line_start, line_end], + mode="lines", + showlegend=False, + line={"color": "black", "dash": "dash"}, + visible=True, + ) + ) + + # Store all annotations and model order in layout meta so consumers + # can swap annotation text when filtering per-model on the frontend. + layout_meta = { + "annotations": annotations, + "models": list(results), + } + + fig.update_layout( + title={"text": title} if title else None, + xaxis={"title": {"text": x_label}}, + yaxis={"title": {"text": y_label}}, + annotations=[annotations[0]], + meta=layout_meta, + showlegend=True, + legend_title_text="Model", + ) + + Path(filename).parent.mkdir(parents=True, exist_ok=True) + fig.write_json(filename) + return results + + return plot_density_wrapper + + return plot_density_decorator + + def build_table( *, thresholds: Thresholds, diff --git a/ml_peg/analysis/utils/utils.py b/ml_peg/analysis/utils/utils.py index a2346c9d..12c059a5 100644 --- a/ml_peg/analysis/utils/utils.py +++ b/ml_peg/analysis/utils/utils.py @@ -4,6 +4,7 @@ from collections.abc import Callable from pathlib import Path +from typing import Any from matplotlib import cm from matplotlib.colors import Colormap @@ -113,6 +114,55 @@ def rmse(ref: list, prediction: list) -> float: return mean_squared_error(ref, prediction) +def build_density_inputs( + models: list[str], + model_results: dict[str, dict[str, Any]], + property_key: str, + metric_fn: Callable[[list, list], float], +) -> dict[str, dict[str, Any]]: + """ + Prepare a model->data mapping for density scatter plots. + + Parameters + ---------- + models + Ordered list of model names to include. + model_results + Mapping of model -> {"": {"ref": [...], "pred": [...]}, + "excluded": int}. These per-model property arrays come from the analysis step + (e.g. filtered bulk/shear values and metadata). + property_key + Key to extract from ``model_results`` for each model (e.g. ``"bulk"`` or + ``"shear"``). + metric_fn + Function that turns the ``ref`` and ``pred`` lists into a single value (for + example, MAE). This number is stored in the result so the plotting code can show + it in hover text/annotations. + + Returns + ------- + dict[str, dict[str, Any]] + Mapping ready for ``plot_density_scatter``. + """ + inputs: dict[str, dict[str, Any]] = {} + + for model_name in models: + stats = model_results.get(model_name, {}) + prop = stats.get(property_key) + excluded = stats.get("excluded") + + ref_vals = prop.get("ref", []) + pred_vals = prop.get("pred", []) + inputs[model_name] = { + "ref": ref_vals, + "pred": pred_vals, + "metric": metric_fn(ref_vals, pred_vals) if ref_vals else None, + "meta": {"excluded": excluded} if excluded is not None else {}, + } + + return inputs + + def calc_metric_scores( metrics_data: list[MetricRow], thresholds: Thresholds | None = None, diff --git a/ml_peg/app/bulk_crystal/bulk_crystal.yml b/ml_peg/app/bulk_crystal/bulk_crystal.yml new file mode 100644 index 00000000..a384ea5f --- /dev/null +++ b/ml_peg/app/bulk_crystal/bulk_crystal.yml @@ -0,0 +1,2 @@ +title: Bulk Crystals +description: Bulk crystal properties, including elastic moduli, phonons, and lattice constants. diff --git a/ml_peg/app/bulk_crystal/elasticity/app_elasticity.py b/ml_peg/app/bulk_crystal/elasticity/app_elasticity.py new file mode 100644 index 00000000..1c2ee23c --- /dev/null +++ b/ml_peg/app/bulk_crystal/elasticity/app_elasticity.py @@ -0,0 +1,79 @@ +"""Run elasticity benchmark app.""" + +from __future__ import annotations + +from dash import Dash +from dash.html import Div + +from ml_peg.app import APP_ROOT +from ml_peg.app.base_app import BaseApp +from ml_peg.app.utils.build_callbacks import plot_from_table_cell +from ml_peg.app.utils.load import read_density_plot_for_model +from ml_peg.models.get_models import get_model_names +from ml_peg.models.models import current_models + +MODELS = get_model_names(current_models) +BENCHMARK_NAME = "Elasticity" +DOCS_URL = ( + "https://ddmms.github.io/ml-peg/user_guide/benchmarks/bulk_crystal.html#elasticity" +) +DATA_PATH = APP_ROOT / "data" / "bulk_crystal" / "elasticity" + + +class ElasticityApp(BaseApp): + """Elasticity benchmark app layout and callbacks.""" + + def register_callbacks(self) -> None: + """Register callbacks to app.""" + density_plots = { + model: { + "Bulk modulus MAE": read_density_plot_for_model( + filename=DATA_PATH / "figure_bulk_density.json", + model=model, + id=f"{BENCHMARK_NAME}-{model}-bulk-figure", + ), + "Shear modulus MAE": read_density_plot_for_model( + filename=DATA_PATH / "figure_shear_density.json", + model=model, + id=f"{BENCHMARK_NAME}-{model}-shear-figure", + ), + } + for model in MODELS + } + + plot_from_table_cell( + table_id=self.table_id, + plot_id=f"{BENCHMARK_NAME}-figure-placeholder", + cell_to_plot=density_plots, + ) + + +def get_app() -> ElasticityApp: + """ + Get elasticity benchmark app layout and callback registration. + + Returns + ------- + ElasticityApp + Benchmark layout and callback registration. + """ + return ElasticityApp( + name=BENCHMARK_NAME, + description=( + "Performance when predicting VRH bulk and shear moduli for crystalline " + "materials compared against Materials Project reference data." + ), + docs_url=DOCS_URL, + table_path=DATA_PATH / "elasticity_metrics_table.json", + extra_components=[ + Div(id=f"{BENCHMARK_NAME}-figure-placeholder"), + ], + ) + + +if __name__ == "__main__": + full_app = Dash(__name__, assets_folder=DATA_PATH.parent.parent) + elasticity_app = get_app() + full_app.layout = elasticity_app.layout + elasticity_app.register_callbacks() + full_app.run(port=8054, debug=True) diff --git a/ml_peg/app/utils/load.py b/ml_peg/app/utils/load.py index d616ea27..daf4beab 100644 --- a/ml_peg/app/utils/load.py +++ b/ml_peg/app/utils/load.py @@ -2,6 +2,7 @@ from __future__ import annotations +from copy import deepcopy import json from pathlib import Path @@ -146,3 +147,85 @@ def read_plot(filename: str | Path, id: str = "figure-1") -> Graph: Loaded plotly Graph. """ return Graph(id=id, figure=read_json(filename)) + + +def _filter_density_figure_for_model(fig_dict: dict, model: str) -> dict: + """ + Filter a density-plot figure dict to a single model trace. + + Keeps the y=x reference line and swaps to the annotation matching the model, + using metadata stored by ``plot_density_scatter``. + + Parameters + ---------- + fig_dict + Figure dictionary loaded from saved density-plot JSON. + model + Model name to keep visible in the filtered figure. + + Returns + ------- + dict + Filtered figure dictionary with only the requested model trace and reference + line. + """ + data = fig_dict.get("data", []) + layout = deepcopy(fig_dict.get("layout")) + annotations_meta = layout.get("meta") + + fig_data = [] + for trace in data: + name = trace.get("name") + if name is None or name == model: + # ``name`` is ``None`` for the y=x reference line; keep that and the + # requested model trace visible while hiding their legend entries. + trace_copy = deepcopy(trace) + trace_copy["visible"] = True + trace_copy["showlegend"] = False + fig_data.append(trace_copy) + + # Pick the matching annotation (Plotly layout annotation with MAE/exclusion text) + stored_annotations = ( + annotations_meta.get("annotations") if annotations_meta else None + ) + model_order = annotations_meta.get("models") if annotations_meta else None + chosen_annotation = None + if isinstance(stored_annotations, list) and isinstance(model_order, list): + try: + idx = model_order.index(model) + if idx < len(stored_annotations): + chosen_annotation = stored_annotations[idx] + except ValueError: + pass + if chosen_annotation: + layout["annotations"] = [chosen_annotation] + + # Hide legend entirely to prevent overlap with the density colorbar. + layout["showlegend"] = False + + return {"data": fig_data, "layout": layout} + + +def read_density_plot_for_model( + filename: str | Path, model: str, id: str = "figure-1" +) -> Graph: + """ + Read a density-plot JSON and return a Graph filtered to a single model. + + Parameters + ---------- + filename + Path to saved density-plot JSON. + model + Model name to keep visible in the returned figure. + id + Dash component id for the Graph. + + Returns + ------- + Graph + Dash Graph displaying only the requested model (plus reference line). + """ + with open(filename) as f: + fig_dict = json.load(f) + return Graph(id=id, figure=_filter_density_figure_for_model(fig_dict, model)) diff --git a/ml_peg/calcs/bulk_crystal/elasticity/calc_elasticity.py b/ml_peg/calcs/bulk_crystal/elasticity/calc_elasticity.py new file mode 100644 index 00000000..39e6e553 --- /dev/null +++ b/ml_peg/calcs/bulk_crystal/elasticity/calc_elasticity.py @@ -0,0 +1,106 @@ +"""Run calculations for elasticity benchmark.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from matcalc.benchmark import ElasticityBenchmark +import pytest + +from ml_peg.models.get_models import load_models +from ml_peg.models.models import current_models + +MODELS = load_models(current_models) +OUT_PATH = Path(__file__).parent / "outputs" + + +def run_elasticity_benchmark( + *, + calc, + model_name: str, + out_dir: Path, + n_jobs: int = 1, + norm_strains: tuple[float, float, float, float] = (-0.1, -0.05, 0.05, 0.1), + shear_strains: tuple[float, float, float, float] = (-0.02, -0.01, 0.01, 0.02), + relax_structure: bool = True, + relax_deformed_structures: bool = True, + use_checkpoint: bool = True, + n_materials: int | None = None, + fmax: float = 0.05, +) -> None: + """ + Run elasticity benchmark and write results to CSV. + + Parameters + ---------- + calc + ASE calculator for evaluating structures. + model_name + Name of MLIP model. + out_dir + Directory to write per-model outputs. + n_jobs + Number of parallel workers for the benchmark. + norm_strains + Tuple of normal strains to apply. + shear_strains + Tuple of shear strains to apply. + relax_structure + Whether to relax the equilibrium structure before deformations. + relax_deformed_structures + Whether to relax each strained structure. + use_checkpoint + If True, writes intermediate checkpoints inside ``out_dir/checkpoints``. + n_materials + Number of materials sampled from the benchmark set. If None, use all materials. + fmax + Force threshold for structural relaxations. + """ + benchmark = ElasticityBenchmark( + n_samples=n_materials, + seed=2025, + fmax=fmax, + relax_structure=relax_structure, + relax_deformed_structures=relax_deformed_structures, + norm_strains=norm_strains, + shear_strains=shear_strains, + benchmark_name="mp-pbe-elasticity-2025.3.json.gz", + ) + + out_dir.mkdir(parents=True, exist_ok=True) + checkpoint_file = None + if use_checkpoint: + checkpoint_dir = out_dir / "checkpoints" + checkpoint_dir.mkdir(parents=True, exist_ok=True) + checkpoint_file = checkpoint_dir / f"{model_name}_checkpoint.json" + + results = benchmark.run( + calc, + model_name, + n_jobs=n_jobs, + checkpoint_file=checkpoint_file if checkpoint_file else None, + checkpoint_freq=100, + delete_checkpoint_on_finish=False, + ) + + results.to_csv(out_dir / "moduli_results.csv", index=False) + + +@pytest.mark.parametrize("mlip", MODELS.items()) +def test_elasticity(mlip: tuple[str, Any]) -> None: + """ + Run elasticity benchmark for a single model. + + Parameters + ---------- + mlip + Model entry containing name and object capable of providing a calculator. + """ + model_name, model = mlip + calc = model.get_calculator() + run_elasticity_benchmark( + calc=calc, + model_name=model_name, + out_dir=OUT_PATH / model_name, + ) diff --git a/pyproject.toml b/pyproject.toml index a87205cd..1dae5fa1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,8 @@ dependencies = [ "mlipx<0.2,>=0.1.5", "scikit-learn>=1.7.1", "typer<1.0.0,>=0.19.1", + "matcalc", + "matminer", ] [project.optional-dependencies]