-
Notifications
You must be signed in to change notification settings - Fork 4
Add elasticity benchmark #151
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
2e5f976
ignore csv
joehart2001 6dad8a3
add density plot decorator
joehart2001 63b6304
update pyproject.toml
joehart2001 8767f21
add calc, analysis, app and half written docs
joehart2001 3e0a41c
add elas docs
joehart2001 660149b
tidy elasticity, move to interactive cells (but required a bit more c…
joehart2001 1c0e2b2
Delete unnecessary files
ElliottKasoar 6482f02
Fix docs
ElliottKasoar 6134fb8
address comments
joehart2001 a712e4a
Tidy docs
ElliottKasoar 93c9cd8
address comments
joehart2001 e1ab6d0
Update ml_peg/analysis/utils/decorators.py
joehart2001 62dd141
Tidy docs
ElliottKasoar 3500ed0
address comments about utils and decorators
joehart2001 b1fd29a
Update ml_peg/analysis/utils/decorators.py
joehart2001 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,7 @@ | |
| *.svg | ||
| *.xz | ||
| *.txt | ||
| *.csv | ||
| ~* | ||
| *~ | ||
| .project | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
235 changes: 235 additions & 0 deletions
235
ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,235 @@ | ||
| """Analyse elasticity benchmark.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| import pandas as pd | ||
| import pytest | ||
|
|
||
| from ml_peg.analysis.utils.decorators import ( | ||
| build_table, | ||
| plot_density_scatter, | ||
| ) | ||
| from ml_peg.analysis.utils.utils import ( | ||
| build_density_inputs, | ||
| load_metrics_config, | ||
| mae, | ||
| ) | ||
| from ml_peg.app import APP_ROOT | ||
| from ml_peg.calcs import CALCS_ROOT | ||
| from ml_peg.models.get_models import get_model_names | ||
| from ml_peg.models.models import current_models | ||
|
|
||
| MODELS = get_model_names(current_models) | ||
| CALC_PATH = CALCS_ROOT / "bulk_crystal" / "elasticity" / "outputs" | ||
| OUT_PATH = APP_ROOT / "data" / "bulk_crystal" / "elasticity" | ||
|
|
||
| METRICS_CONFIG_PATH = Path(__file__).with_name("metrics.yml") | ||
| DEFAULT_THRESHOLDS, DEFAULT_TOOLTIPS, DEFAULT_WEIGHTS = load_metrics_config( | ||
| METRICS_CONFIG_PATH | ||
| ) | ||
|
|
||
| K_COLUMN = "K_vrh" | ||
| G_COLUMN = "G_vrh" | ||
|
|
||
|
|
||
| def _filter_results(df: pd.DataFrame, model_name: str) -> tuple[pd.DataFrame, int]: | ||
| """ | ||
| Filter outlier predictions and return remaining data with exclusion count. | ||
ElliottKasoar marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Parameters | ||
| ---------- | ||
| df | ||
| Dataframe containing raw benchmark results. | ||
| model_name | ||
| Model whose columns should be filtered. | ||
| Returns | ||
| ------- | ||
| tuple[pd.DataFrame, int] | ||
| Filtered dataframe and number of excluded systems. | ||
| """ | ||
| mask_bulk = df[f"{K_COLUMN}_{model_name}"].between(-50, 600) | ||
| mask_shear = df[f"{G_COLUMN}_{model_name}"].between(-50, 600) | ||
| valid = df[mask_bulk & mask_shear].copy() | ||
| excluded = len(df) - len(valid) | ||
| return valid, excluded | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def elasticity_stats() -> dict[str, dict[str, Any]]: | ||
| """ | ||
| Load and cache processed benchmark statistics per model. | ||
| Returns | ||
| ------- | ||
| dict[str, dict[str, Any]] | ||
| Processed information per model (bulk, shear, exclusion counts). | ||
| """ | ||
| OUT_PATH.mkdir(parents=True, exist_ok=True) | ||
| stats: dict[str, dict[str, Any]] = {} | ||
| for model_name in MODELS: | ||
| results_path = CALC_PATH / model_name / "moduli_results.csv" | ||
| df = pd.read_csv(results_path) | ||
|
|
||
| filtered, excluded = _filter_results(df, model_name) | ||
|
|
||
| stats[model_name] = { | ||
| "bulk": { | ||
| "ref": filtered[f"{K_COLUMN}_DFT"].tolist(), | ||
| "pred": filtered[f"{K_COLUMN}_{model_name}"].tolist(), | ||
| }, | ||
| "shear": { | ||
| "ref": filtered[f"{G_COLUMN}_DFT"].tolist(), | ||
| "pred": filtered[f"{G_COLUMN}_{model_name}"].tolist(), | ||
| }, | ||
| "excluded": excluded, | ||
| } | ||
|
|
||
| return stats | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def bulk_mae(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, float | None]: | ||
| """ | ||
| Mean absolute error for bulk modulus predictions. | ||
| Parameters | ||
| ---------- | ||
| elasticity_stats | ||
| Aggregated bulk/shear data per model. | ||
| Returns | ||
| ------- | ||
| dict[str, float | None] | ||
| MAE values for each model (``None`` if no data). | ||
| """ | ||
| results: dict[str, float | None] = {} | ||
| for model_name in MODELS: | ||
| prop = elasticity_stats.get(model_name, {}).get("bulk") | ||
| results[model_name] = mae(prop["ref"], prop["pred"]) | ||
| return results | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def shear_mae(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, float | None]: | ||
| """ | ||
| Mean absolute error for shear modulus predictions. | ||
| Parameters | ||
| ---------- | ||
| elasticity_stats | ||
| Aggregated bulk/shear data per model. | ||
| Returns | ||
| ------- | ||
| dict[str, float | None] | ||
| MAE values for each model (``None`` if no data). | ||
| """ | ||
| results: dict[str, float | None] = {} | ||
| for model_name in MODELS: | ||
| prop = elasticity_stats.get(model_name, {}).get("shear") | ||
| results[model_name] = mae(prop["ref"], prop["pred"]) | ||
| return results | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| @plot_density_scatter( | ||
| filename=OUT_PATH / "figure_bulk_density.json", | ||
| title="Bulk modulus density plot", | ||
| x_label="Reference bulk modulus / GPa", | ||
| y_label="Predicted bulk modulus / GPa", | ||
| ) | ||
| def bulk_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict]: | ||
| """ | ||
| Density scatter inputs for bulk modulus. | ||
| Parameters | ||
| ---------- | ||
| elasticity_stats | ||
| Aggregated bulk/shear data per model. | ||
| Returns | ||
| ------- | ||
| dict[str, dict] | ||
| Mapping of model name to density-scatter data. | ||
| """ | ||
| return build_density_inputs(MODELS, elasticity_stats, "bulk", metric_fn=mae) | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| @plot_density_scatter( | ||
joehart2001 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| filename=OUT_PATH / "figure_shear_density.json", | ||
| title="Shear modulus density plot", | ||
| x_label="Reference shear modulus / GPa", | ||
| y_label="Predicted shear modulus / GPa", | ||
| ) | ||
| def shear_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict]: | ||
| """ | ||
| Density scatter inputs for shear modulus. | ||
| Parameters | ||
| ---------- | ||
| elasticity_stats | ||
| Aggregated bulk/shear data per model. | ||
| Returns | ||
| ------- | ||
| dict[str, dict] | ||
| Mapping of model name to density-scatter data. | ||
| """ | ||
| return build_density_inputs(MODELS, elasticity_stats, "shear", metric_fn=mae) | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| @build_table( | ||
| filename=OUT_PATH / "elasticity_metrics_table.json", | ||
| metric_tooltips=DEFAULT_TOOLTIPS, | ||
| thresholds=DEFAULT_THRESHOLDS, | ||
| weights=DEFAULT_WEIGHTS, | ||
| ) | ||
| def metrics( | ||
| bulk_mae: dict[str, float | None], | ||
| shear_mae: dict[str, float | None], | ||
| ) -> dict[str, dict]: | ||
| """ | ||
| All elasticity metrics. | ||
| Parameters | ||
| ---------- | ||
| bulk_mae | ||
| Bulk modulus MAE per model. | ||
| shear_mae | ||
| Shear modulus MAE per model. | ||
| Returns | ||
| ------- | ||
| dict[str, dict] | ||
| Mapping of metric name to model-value dictionaries. | ||
| """ | ||
| return { | ||
| "Bulk modulus MAE": bulk_mae, | ||
| "Shear modulus MAE": shear_mae, | ||
| } | ||
|
|
||
|
|
||
| def test_elasticity( | ||
| metrics: dict[str, dict], | ||
| bulk_density: dict[str, dict], | ||
| shear_density: dict[str, dict], | ||
| ) -> None: | ||
| """ | ||
| Run elasticity analysis. | ||
| Parameters | ||
| ---------- | ||
| metrics | ||
| Benchmark metric values. | ||
| bulk_density | ||
| Density scatter inputs for bulk modulus. | ||
| shear_density | ||
| Density scatter inputs for shear modulus. | ||
| """ | ||
| return | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| metrics: | ||
| Bulk modulus MAE: | ||
| good: 2.0 | ||
| bad: 30.0 | ||
| unit: GPa | ||
| tooltip: Mean absolute error of VRH bulk modulus (lower is better). Excludes systems with bulk moduli < -50 GPa and > 500 GPa. | ||
| Shear modulus MAE: | ||
| good: 2.0 | ||
| bad: 30.0 | ||
| unit: GPa | ||
| tooltip: Mean absolute error of VRH shear modulus (lower is better). Excludes systems with shear moduli < -50 GPa and > 500 GPa. |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.