diff --git a/src/openfecli/commands/gather.py b/src/openfecli/commands/gather.py index 9692b2fbc..18aa2ec18 100644 --- a/src/openfecli/commands/gather.py +++ b/src/openfecli/commands/gather.py @@ -11,6 +11,7 @@ from openfecli import OFECommandPlugin from openfecli.clicktypes import HyphenAwareChoice +from openfecli.commands.quickrun import _QuickrunResult FAIL_STR = "Error" # string used to indicate a failed run in output tables. @@ -181,30 +182,7 @@ def is_results_json(fpath: os.PathLike | str) -> bool: return "estimate" in open(fpath, "r").read(20) -def load_json(fpath: os.PathLike | str) -> dict: - """Load a JSON file containing a gufe object. - - Parameters - ---------- - fpath : os.PathLike | str - The path to a gufe-serialized JSON. - - - Returns - ------- - dict - A dict containing data from the results JSON. - - """ - # TODO: move this function to openfe/utils - import json - - from gufe.tokenization import JSON_HANDLER - - return json.load(open(fpath, "r"), cls=JSON_HANDLER.decoder) - - -def _get_names(result: dict) -> tuple[str, str]: +def _get_names(result: _QuickrunResult) -> tuple[str, str]: """Get the ligand names from a unit's results data. Parameters @@ -219,7 +197,7 @@ def _get_names(result: dict) -> tuple[str, str]: """ # TODO: I don't like this [0][0] indexing, but I can't think of a better way currently - protocol_data = list(result["protocol_result"]["data"].values())[0][0] + protocol_data = list(result.protocol_result["data"].values())[0][0] try: name_A = protocol_data["inputs"]["setup_results"]["inputs"]["ligandmapping"]["componentA"][ "molprops" @@ -234,10 +212,10 @@ def _get_names(result: dict) -> tuple[str, str]: return str(name_A), str(name_B) -def _get_type(result: dict) -> Literal["vacuum", "solvent", "complex"]: +def _get_type(result: _QuickrunResult) -> Literal["vacuum", "solvent", "complex"]: """Determine the simulation type based on the component types.""" - protocol_data = list(result["protocol_result"]["data"].values())[0][0] + protocol_data = list(result.protocol_result["data"].values())[0][0] try: component_types = [ x["__module__"] @@ -276,7 +254,7 @@ def _legacy_get_type(res_fn: os.PathLike | str) -> Literal["vacuum", "solvent", def _get_result_id( - result: dict, result_fn: os.PathLike | str + result: _QuickrunResult, result_fn: os.PathLike | str ) -> tuple[tuple[str, str], Literal["vacuum", "solvent", "complex"]]: """Extract the name and simulation type from a results dict. @@ -302,7 +280,9 @@ def _get_result_id( return (ligA, ligB), simtype -def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dict | None]: +def _load_valid_result_json( + fpath: os.PathLike | str, +) -> tuple[tuple | None, _QuickrunResult | None]: """Load the data from a results JSON into a dict. Parameters @@ -317,25 +297,25 @@ def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dic or None if the JSON file is invalid or missing. """ - # TODO: only load this once during collection, then pass namedtuple(fname, dict) into this function # for now though, it's not the bottleneck on performance - result = load_json(fpath) + result = _QuickrunResult.from_json(fpath) + try: result_id = _get_result_id(result, fpath) except (ValueError, IndexError): click.secho(f"{fpath}: Missing ligand names and/or simulation type. Skipping.",err=True, fg="yellow") # fmt: skip return None, None - if result["estimate"] is None: + if result.estimate is None: click.secho(f"{fpath}: No 'estimate' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return result_id, None - if result["uncertainty"] is None: + if result.uncertainty is None: click.secho(f"{fpath}: No 'uncertainty' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return result_id, None - if result["unit_results"] == {}: + if result.unit_results == {}: click.secho(f"{fpath}: No 'unit_results' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return result_id, None - if all("exception" in u for u in result["unit_results"].values()): + if all("exception" in u for u in result.unit_results.values()): click.secho(f"{fpath}: Exception found in all 'unit_results', assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return result_id, None @@ -679,7 +659,7 @@ def _get_legs_from_result_jsons( v[0]["outputs"]["unit_estimate"], v[0]["outputs"]["unit_estimate_error"], ) - for v in result["protocol_result"]["data"].values() + for v in result.protocol_result["data"].values() ] legs[names][simtype].append(parsed_raw_data) else: @@ -689,7 +669,7 @@ def _get_legs_from_result_jsons( else: dGs = [ v[0]["outputs"]["unit_estimate"] - for v in result["protocol_result"]["data"].values() + for v in result.protocol_result["data"].values() ] legs[names][simtype].extend(dGs) return legs diff --git a/src/openfecli/commands/gather_abfe.py b/src/openfecli/commands/gather_abfe.py index da2cc1525..5d33738ba 100644 --- a/src/openfecli/commands/gather_abfe.py +++ b/src/openfecli/commands/gather_abfe.py @@ -12,12 +12,12 @@ from openfecli.commands.gather import ( _collect_result_jsons, format_df_with_precision, - load_json, rich_print_to_stdout, ) +from openfecli.quickrun_result import _QuickrunResult -def _get_name(result: dict) -> str: +def _get_name(result: _QuickrunResult) -> str: """Get the ligand name from a unit's results data. Parameters @@ -31,7 +31,7 @@ def _get_name(result: dict) -> str: Ligand name corresponding to the results. """ - solvent_data = list(result["protocol_result"]["data"]["solvent"].values())[0][0] + solvent_data = list(result.protocol_result["data"]["solvent"].values())[0][0] try: name = solvent_data["inputs"]["setup_results"]["inputs"]["alchemical_components"]["stateA"][ 0 @@ -42,7 +42,9 @@ def _get_name(result: dict) -> str: return str(name) -def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dict | None]: +def _load_valid_result_json( + fpath: os.PathLike | str, +) -> tuple[tuple | None, _QuickrunResult | None]: """Load the data from a results JSON into a dict. Parameters @@ -67,19 +69,19 @@ def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dic # TODO: only load this once during collection, then pass namedtuple(fname, dict) into this function # for now though, it's not the bottleneck on performance - result = load_json(fpath) + result = _QuickrunResult.from_json(fpath) try: names = _get_name(result) except (ValueError, IndexError): click.secho(f"{fpath}: Missing ligand names and/or simulation type. Skipping.",err=True, fg="yellow") # fmt: skip return None, None - if result["estimate"] is None: + if result.estimate is None: click.secho(f"{fpath}: No 'estimate' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return names, None - if result["uncertainty"] is None: + if result.uncertainty is None: click.secho(f"{fpath}: No 'uncertainty' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return names, None - if all("exception" in u for u in result["unit_results"].values()): + if all("exception" in u for u in result.unit_results.values()): click.secho(f"{fpath}: Exception found in all 'unit_results', assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return names, None return names, result @@ -116,25 +118,25 @@ def _get_legs_from_result_jsons( if name is None: # this means it couldn't find name and/or simtype continue - dgs[name]["overall"].append([result["estimate"], result["uncertainty"]]) - proto_key = [k for k in result["unit_results"].keys() if k.startswith("ProtocolUnitResult")] + dgs[name]["overall"].append([result.estimate, result.uncertainty]) + proto_key = [k for k in result.unit_results.keys() if k.startswith("ProtocolUnitResult")] for p in proto_key: # In openfe v1.9+, we only want to pick up results from # the Analysis Unit. To ensure backwards compatibility with # prior releases of openfe v1.x, we exclude Setup and Simulation if ( - "Setup" in result["unit_results"][p]["source_key"] - or "Simulation" in result["unit_results"][p]["source_key"] + "Setup" in result.unit_results[p]["source_key"] + or "Simulation" in result.unit_results[p]["source_key"] ): continue - if "unit_estimate" in result["unit_results"][p]["outputs"]: - simtype = result["unit_results"][p]["outputs"]["simtype"] - dg = result["unit_results"][p]["outputs"]["unit_estimate"] - dg_error = result["unit_results"][p]["outputs"]["unit_estimate_error"] + if "unit_estimate" in result.unit_results[p]["outputs"]: + simtype = result.unit_results[p]["outputs"]["simtype"] + dg = result.unit_results[p]["outputs"]["unit_estimate"] + dg_error = result.unit_results[p]["outputs"]["unit_estimate_error"] dgs[name][simtype].append([dg, dg_error]) - if "standard_state_correction" in result["unit_results"][p]["outputs"]: - corr = result["unit_results"][p]["outputs"]["standard_state_correction"] + if "standard_state_correction" in result.unit_results[p]["outputs"]: + corr = result.unit_results[p]["outputs"]["standard_state_correction"] # In openfe v1.9+, standard state corrections are set to 0 kcal/mol # when no correction is being applied (e.g. no restraints). # To make raw outputs similar to pre-v1.9, we exclude corrections diff --git a/src/openfecli/commands/gather_septop.py b/src/openfecli/commands/gather_septop.py index 104b384d1..4326b0a84 100644 --- a/src/openfecli/commands/gather_septop.py +++ b/src/openfecli/commands/gather_septop.py @@ -12,9 +12,9 @@ from openfecli.commands.gather import ( _collect_result_jsons, format_df_with_precision, - load_json, rich_print_to_stdout, ) +from openfecli.quickrun_result import _QuickrunResult def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dict | None]: @@ -42,19 +42,19 @@ def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dic # TODO: only load this once during collection, then pass namedtuple(fname, dict) into this function # for now though, it's not the bottleneck on performance - result = load_json(fpath) + result = _QuickrunResult.from_json(fpath) try: names = _get_names(result) except (ValueError, IndexError): click.secho(f"{fpath}: Missing ligand names and/or simulation type. Skipping.",err=True, fg="yellow") # fmt: skip return None, None - if result["estimate"] is None: + if result.estimate is None: click.secho(f"{fpath}: No 'estimate' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return names, None - if result["uncertainty"] is None: + if result.uncertainty is None: click.secho(f"{fpath}: No 'uncertainty' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return names, None - if all("exception" in u for u in result["unit_results"].values()): + if all("exception" in u for u in result.unit_results.values()): click.secho(f"{fpath}: Exception found in all 'unit_results', assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip return names, None return names, result @@ -91,10 +91,10 @@ def _get_legs_from_result_jsons( if names is None: # this means it couldn't find names and/or simtype continue - ddgs[names]["overall"].append([result["estimate"], result["uncertainty"]]) + ddgs[names]["overall"].append([result.estimate, result.uncertainty]) proto_key = [ k - for k in result["unit_results"].keys() + for k in result.unit_results.keys() if k.startswith("ProtocolUnitResult") ] # fmt: skip @@ -103,27 +103,27 @@ def _get_legs_from_result_jsons( # we check if there are any analysis units. If so, # we set a flag and later exclude Setup and Run. has_analysis_units = any( - ["Analysis" in result["unit_results"][p]["source_key"] for p in proto_key] + ["Analysis" in result.unit_results[p]["source_key"] for p in proto_key] ) for p in proto_key: # Skip non-analysis units if we have any if has_analysis_units and ( - "Setup" in result["unit_results"][p]["source_key"] - or "Run" in result["unit_results"][p]["source_key"] + "Setup" in result.unit_results[p]["source_key"] + or "Run" in result.unit_results[p]["source_key"] ): continue - if "unit_estimate" in result["unit_results"][p]["outputs"]: - simtype = result["unit_results"][p]["outputs"]["simtype"] - dg = result["unit_results"][p]["outputs"]["unit_estimate"] - dg_error = result["unit_results"][p]["outputs"]["unit_estimate_error"] + if "unit_estimate" in result.unit_results[p]["outputs"]: + simtype = result.unit_results[p]["outputs"]["simtype"] + dg = result.unit_results[p]["outputs"]["unit_estimate"] + dg_error = result.unit_results[p]["outputs"]["unit_estimate_error"] ddgs[names][simtype].append([dg, dg_error]) - if "standard_state_correction_A" in result["unit_results"][p]["outputs"]: - corr_A = result["unit_results"][p]["outputs"]["standard_state_correction_A"] - corr_B = result["unit_results"][p]["outputs"]["standard_state_correction_B"] + if "standard_state_correction_A" in result.unit_results[p]["outputs"]: + corr_A = result.unit_results[p]["outputs"]["standard_state_correction_A"] + corr_B = result.unit_results[p]["outputs"]["standard_state_correction_B"] ddgs[names]["standard_state_correction_A"].append( [corr_A, 0 * unit.kilocalorie_per_mole] ) @@ -134,7 +134,7 @@ def _get_legs_from_result_jsons( return ddgs -def _get_names(result: dict) -> tuple[str, str]: +def _get_names(result: _QuickrunResult) -> tuple[str, str]: """Get the ligand names from a unit's results data. Parameters @@ -148,7 +148,7 @@ def _get_names(result: dict) -> tuple[str, str]: Ligand names corresponding to the results. """ - solvent_data = list(result["protocol_result"]["data"]["solvent"].values())[0][0] + solvent_data = list(result.protocol_result["data"]["solvent"].values())[0][0] try: setup_data = solvent_data["inputs"]["setup"]["inputs"] diff --git a/src/openfecli/commands/quickrun.py b/src/openfecli/commands/quickrun.py index 677552435..e4bf7e408 100644 --- a/src/openfecli/commands/quickrun.py +++ b/src/openfecli/commands/quickrun.py @@ -2,13 +2,12 @@ # For details, see https://github.com/OpenFreeEnergy/openfe import hashlib -import json import pathlib -import warnings import click from openfecli import OFECommandPlugin +from openfecli.quickrun_result import _QuickrunResult from openfecli.utils import configure_logger, print_duration, write @@ -175,17 +174,14 @@ def quickrun(transformation, work_dir, output, resume): else: estimate = uncertainty = None # for output file - out_dict = { - "estimate": estimate, - "uncertainty": uncertainty, - "protocol_result": prot_result.to_dict(), - "unit_results": { - unit.key: unit.to_keyed_dict() for unit in dagresult.protocol_unit_results - }, - } - - with open(output, mode="w") as outf: - json.dump(out_dict, outf, cls=JSON_HANDLER.encoder) + quickrun_result = _QuickrunResult( + estimate=estimate, + uncertainty=uncertainty, + protocol_result=prot_result.to_dict(), + unit_results={unit.key: unit.to_keyed_dict() for unit in dagresult.protocol_unit_results}, + ) + + quickrun_result.to_json(output) # remove the cached dag since the job has completed os.remove(cached_dag_path) diff --git a/src/openfecli/quickrun_result.py b/src/openfecli/quickrun_result.py new file mode 100644 index 000000000..e3c79118b --- /dev/null +++ b/src/openfecli/quickrun_result.py @@ -0,0 +1,51 @@ +import json +from dataclasses import asdict, dataclass +from os import PathLike +from typing import Any, Self + +from gufe.tokenization import JSON_HANDLER +from openff.units import Quantity + + +@dataclass +class _QuickrunResult: + """ + Class for storing protocol result data along with useful metadata. + Could ProtocolResults store this data alongside ``n_protocol_dag_results``? + """ + + estimate: Quantity + uncertainty: Quantity + protocol_result: dict[str, Any] + unit_results: dict[int, dict] + + def to_json(self, filepath) -> None: + with open(filepath, mode="w") as file: + json.dump(asdict(self), file, cls=JSON_HANDLER.encoder) + + @classmethod + def from_json(cls, file: PathLike | None, content: str | None = None) -> Self: + """Load a JSON file containing a gufe object. + + Parameters + ---------- + fpath : os.PathLike | str + The path to a results JSON generated by ``openfe quickrun``. + + + Returns + ------- + _QuickrunResult + A _QuickrunResult instance containing the data from ``file, co``. + + """ + # similar to gufe.tokenization.from_json + if content is not None and file is not None: + raise ValueError("Cannot specify both `content` and `file`; only one input allowed") + elif content is None and file is None: + raise ValueError("Must specify either `content` and `file` for JSON input") + if file: + data = json.load(open(file, "r"), cls=JSON_HANDLER.decoder) + if content: + data = json.loads(content, cls=JSON_HANDLER.decoder) + return cls(**data) diff --git a/src/openfecli/tests/commands/test_gather.py b/src/openfecli/tests/commands/test_gather.py index 9d8defba1..b76447fce 100644 --- a/src/openfecli/tests/commands/test_gather.py +++ b/src/openfecli/tests/commands/test_gather.py @@ -18,6 +18,7 @@ ) from openfecli.commands.gather_abfe import gather_abfe from openfecli.commands.gather_septop import gather_septop +from openfecli.commands.quickrun import _QuickrunResult from openfecli.data._registry import ( POOCH_CACHE, zenodo_abfe_data, @@ -80,62 +81,88 @@ def test_get_column(val, col): assert _get_column(val) == col +@pytest.fixture +def min_valid_quickrun_result(min_result_json): + return _QuickrunResult(**min_result_json) + + class TestResultLoading: - def test_minimal_valid_results(self, capsys, min_result_json): - with mock.patch("openfecli.commands.gather.load_json", return_value=min_result_json): + def test_minimal_valid_results(self, capsys, min_valid_quickrun_result): + with mock.patch( + "openfecli.commands.gather._QuickrunResult.from_json", + return_value=min_valid_quickrun_result, + ): result = _load_valid_result_json(fpath="") captured = capsys.readouterr() - assert result == ((("lig_ejm_31", "lig_ejm_42"), "solvent"), min_result_json) + assert result == ((("lig_ejm_31", "lig_ejm_42"), "solvent"), min_valid_quickrun_result) assert captured.err == "" - def test_skip_missing_unit_result(self, capsys, min_result_json): - min_result_json["unit_results"] = {} + def test_skip_missing_unit_result(self, capsys, min_valid_quickrun_result): + min_valid_quickrun_result.unit_results = {} - with mock.patch("openfecli.commands.gather.load_json", return_value=min_result_json): + with mock.patch( + "openfecli.commands.gather._QuickrunResult.from_json", + return_value=min_valid_quickrun_result, + ): result = _load_valid_result_json(fpath="") captured = capsys.readouterr() assert result == ((("lig_ejm_31", "lig_ejm_42"), "solvent"), None) assert "No 'unit_results' found" in captured.err - def test_skip_missing_estimate(self, capsys, min_result_json): - min_result_json["estimate"] = None + def test_skip_missing_estimate(self, capsys, min_valid_quickrun_result): + min_valid_quickrun_result.estimate = None - with mock.patch("openfecli.commands.gather.load_json", return_value=min_result_json): + with mock.patch( + "openfecli.commands.gather._QuickrunResult.from_json", + return_value=min_valid_quickrun_result, + ): result = _load_valid_result_json(fpath="") captured = capsys.readouterr() assert result == ((("lig_ejm_31", "lig_ejm_42"), "solvent"), None) assert "No 'estimate' found" in captured.err - def test_skip_missing_uncertainty(self, capsys, min_result_json): - min_result_json["uncertainty"] = None + def test_skip_missing_uncertainty(self, capsys, min_valid_quickrun_result): + min_valid_quickrun_result.uncertainty = None - with mock.patch("openfecli.commands.gather.load_json", return_value=min_result_json): + with mock.patch( + "openfecli.commands.gather._QuickrunResult.from_json", + return_value=min_valid_quickrun_result, + ): result = _load_valid_result_json(fpath="") captured = capsys.readouterr() assert result == ((("lig_ejm_31", "lig_ejm_42"), "solvent"), None) assert "No 'uncertainty' found" in captured.err - def test_skip_all_failed_runs(self, capsys, min_result_json): - del min_result_json["unit_results"]["ProtocolUnitResult-e85"] - with mock.patch("openfecli.commands.gather.load_json", return_value=min_result_json): + def test_skip_all_failed_runs(self, capsys, min_valid_quickrun_result): + del min_valid_quickrun_result.unit_results["ProtocolUnitResult-e85"] + with mock.patch( + "openfecli.commands.gather._QuickrunResult.from_json", + return_value=min_valid_quickrun_result, + ): result = _load_valid_result_json(fpath="") captured = capsys.readouterr() assert result == ((("lig_ejm_31", "lig_ejm_42"), "solvent"), None) assert "Exception found in all" in captured.err - def test_missing_pr_data(self, capsys, min_result_json): - min_result_json["protocol_result"]["data"] = {} - with mock.patch("openfecli.commands.gather.load_json", return_value=min_result_json): + def test_missing_pr_data(self, capsys, min_valid_quickrun_result): + min_valid_quickrun_result.protocol_result["data"] = {} + with mock.patch( + "openfecli.commands.gather._QuickrunResult.from_json", + return_value=min_valid_quickrun_result, + ): result = _load_valid_result_json(fpath="") captured = capsys.readouterr() assert result == (None, None) assert "Missing ligand names and/or simulation type. Skipping" in captured.err - def test_get_legs_from_result_jsons(self, capsys, min_result_json): + def test_get_legs_from_result_jsons(self, capsys, min_valid_quickrun_result): """Test that exceptions are handled correctly at the _get_legs_from_results_json level.""" - min_result_json["protocol_result"]["data"] = {} + min_valid_quickrun_result.protocol_result["data"] = {} - with mock.patch("openfecli.commands.gather.load_json", return_value=min_result_json): + with mock.patch( + "openfecli.commands.gather._QuickrunResult.from_json", + return_value=min_valid_quickrun_result, + ): result = _get_legs_from_result_jsons(result_fns=[""], report="dg") captured = capsys.readouterr() assert result == {}