Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 17 additions & 37 deletions src/openfecli/commands/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from openfecli import OFECommandPlugin
from openfecli.clicktypes import HyphenAwareChoice
from openfecli.commands.quickrun import _QuickrunResult

FAIL_STR = "Error" # string used to indicate a failed run in output tables.

Expand Down Expand Up @@ -181,30 +182,7 @@ def is_results_json(fpath: os.PathLike | str) -> bool:
return "estimate" in open(fpath, "r").read(20)


def load_json(fpath: os.PathLike | str) -> dict:
"""Load a JSON file containing a gufe object.

Parameters
----------
fpath : os.PathLike | str
The path to a gufe-serialized JSON.


Returns
-------
dict
A dict containing data from the results JSON.

"""
# TODO: move this function to openfe/utils
import json

from gufe.tokenization import JSON_HANDLER

return json.load(open(fpath, "r"), cls=JSON_HANDLER.decoder)


def _get_names(result: dict) -> tuple[str, str]:
def _get_names(result: _QuickrunResult) -> tuple[str, str]:
"""Get the ligand names from a unit's results data.

Parameters
Expand All @@ -219,7 +197,7 @@ def _get_names(result: dict) -> tuple[str, str]:
"""

# TODO: I don't like this [0][0] indexing, but I can't think of a better way currently
protocol_data = list(result["protocol_result"]["data"].values())[0][0]
protocol_data = list(result.protocol_result["data"].values())[0][0]
try:
name_A = protocol_data["inputs"]["setup_results"]["inputs"]["ligandmapping"]["componentA"][
"molprops"
Expand All @@ -234,10 +212,10 @@ def _get_names(result: dict) -> tuple[str, str]:
return str(name_A), str(name_B)


def _get_type(result: dict) -> Literal["vacuum", "solvent", "complex"]:
def _get_type(result: _QuickrunResult) -> Literal["vacuum", "solvent", "complex"]:
"""Determine the simulation type based on the component types."""

protocol_data = list(result["protocol_result"]["data"].values())[0][0]
protocol_data = list(result.protocol_result["data"].values())[0][0]
try:
component_types = [
x["__module__"]
Expand Down Expand Up @@ -276,7 +254,7 @@ def _legacy_get_type(res_fn: os.PathLike | str) -> Literal["vacuum", "solvent",


def _get_result_id(
result: dict, result_fn: os.PathLike | str
result: _QuickrunResult, result_fn: os.PathLike | str
) -> tuple[tuple[str, str], Literal["vacuum", "solvent", "complex"]]:
"""Extract the name and simulation type from a results dict.

Expand All @@ -302,7 +280,9 @@ def _get_result_id(
return (ligA, ligB), simtype


def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dict | None]:
def _load_valid_result_json(
fpath: os.PathLike | str,
) -> tuple[tuple | None, _QuickrunResult | None]:
"""Load the data from a results JSON into a dict.

Parameters
Expand All @@ -317,25 +297,25 @@ def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dic
or None if the JSON file is invalid or missing.

"""

# TODO: only load this once during collection, then pass namedtuple(fname, dict) into this function
# for now though, it's not the bottleneck on performance
result = load_json(fpath)
result = _QuickrunResult.from_json(fpath)

try:
result_id = _get_result_id(result, fpath)
except (ValueError, IndexError):
click.secho(f"{fpath}: Missing ligand names and/or simulation type. Skipping.",err=True, fg="yellow") # fmt: skip
return None, None
if result["estimate"] is None:
if result.estimate is None:
click.secho(f"{fpath}: No 'estimate' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return result_id, None
if result["uncertainty"] is None:
if result.uncertainty is None:
click.secho(f"{fpath}: No 'uncertainty' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return result_id, None
if result["unit_results"] == {}:
if result.unit_results == {}:
click.secho(f"{fpath}: No 'unit_results' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return result_id, None
if all("exception" in u for u in result["unit_results"].values()):
if all("exception" in u for u in result.unit_results.values()):
click.secho(f"{fpath}: Exception found in all 'unit_results', assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return result_id, None

Expand Down Expand Up @@ -679,7 +659,7 @@ def _get_legs_from_result_jsons(
v[0]["outputs"]["unit_estimate"],
v[0]["outputs"]["unit_estimate_error"],
)
for v in result["protocol_result"]["data"].values()
for v in result.protocol_result["data"].values()
]
legs[names][simtype].append(parsed_raw_data)
else:
Expand All @@ -689,7 +669,7 @@ def _get_legs_from_result_jsons(
else:
dGs = [
v[0]["outputs"]["unit_estimate"]
for v in result["protocol_result"]["data"].values()
for v in result.protocol_result["data"].values()
]
legs[names][simtype].extend(dGs)
return legs
Expand Down
38 changes: 20 additions & 18 deletions src/openfecli/commands/gather_abfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
from openfecli.commands.gather import (
_collect_result_jsons,
format_df_with_precision,
load_json,
rich_print_to_stdout,
)
from openfecli.quickrun_result import _QuickrunResult


def _get_name(result: dict) -> str:
def _get_name(result: _QuickrunResult) -> str:
"""Get the ligand name from a unit's results data.

Parameters
Expand All @@ -31,7 +31,7 @@ def _get_name(result: dict) -> str:
Ligand name corresponding to the results.
"""

solvent_data = list(result["protocol_result"]["data"]["solvent"].values())[0][0]
solvent_data = list(result.protocol_result["data"]["solvent"].values())[0][0]
try:
name = solvent_data["inputs"]["setup_results"]["inputs"]["alchemical_components"]["stateA"][
0
Expand All @@ -42,7 +42,9 @@ def _get_name(result: dict) -> str:
return str(name)


def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dict | None]:
def _load_valid_result_json(
fpath: os.PathLike | str,
) -> tuple[tuple | None, _QuickrunResult | None]:
"""Load the data from a results JSON into a dict.

Parameters
Expand All @@ -67,19 +69,19 @@ def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dic

# TODO: only load this once during collection, then pass namedtuple(fname, dict) into this function
# for now though, it's not the bottleneck on performance
result = load_json(fpath)
result = _QuickrunResult.from_json(fpath)
try:
names = _get_name(result)
except (ValueError, IndexError):
click.secho(f"{fpath}: Missing ligand names and/or simulation type. Skipping.",err=True, fg="yellow") # fmt: skip
return None, None
if result["estimate"] is None:
if result.estimate is None:
click.secho(f"{fpath}: No 'estimate' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return names, None
if result["uncertainty"] is None:
if result.uncertainty is None:
click.secho(f"{fpath}: No 'uncertainty' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return names, None
if all("exception" in u for u in result["unit_results"].values()):
if all("exception" in u for u in result.unit_results.values()):
click.secho(f"{fpath}: Exception found in all 'unit_results', assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return names, None
return names, result
Expand Down Expand Up @@ -116,25 +118,25 @@ def _get_legs_from_result_jsons(
if name is None: # this means it couldn't find name and/or simtype
continue

dgs[name]["overall"].append([result["estimate"], result["uncertainty"]])
proto_key = [k for k in result["unit_results"].keys() if k.startswith("ProtocolUnitResult")]
dgs[name]["overall"].append([result.estimate, result.uncertainty])
proto_key = [k for k in result.unit_results.keys() if k.startswith("ProtocolUnitResult")]
for p in proto_key:
# In openfe v1.9+, we only want to pick up results from
# the Analysis Unit. To ensure backwards compatibility with
# prior releases of openfe v1.x, we exclude Setup and Simulation
if (
"Setup" in result["unit_results"][p]["source_key"]
or "Simulation" in result["unit_results"][p]["source_key"]
"Setup" in result.unit_results[p]["source_key"]
or "Simulation" in result.unit_results[p]["source_key"]
):
continue
if "unit_estimate" in result["unit_results"][p]["outputs"]:
simtype = result["unit_results"][p]["outputs"]["simtype"]
dg = result["unit_results"][p]["outputs"]["unit_estimate"]
dg_error = result["unit_results"][p]["outputs"]["unit_estimate_error"]
if "unit_estimate" in result.unit_results[p]["outputs"]:
simtype = result.unit_results[p]["outputs"]["simtype"]
dg = result.unit_results[p]["outputs"]["unit_estimate"]
dg_error = result.unit_results[p]["outputs"]["unit_estimate_error"]

dgs[name][simtype].append([dg, dg_error])
if "standard_state_correction" in result["unit_results"][p]["outputs"]:
corr = result["unit_results"][p]["outputs"]["standard_state_correction"]
if "standard_state_correction" in result.unit_results[p]["outputs"]:
corr = result.unit_results[p]["outputs"]["standard_state_correction"]
# In openfe v1.9+, standard state corrections are set to 0 kcal/mol
# when no correction is being applied (e.g. no restraints).
# To make raw outputs similar to pre-v1.9, we exclude corrections
Expand Down
38 changes: 19 additions & 19 deletions src/openfecli/commands/gather_septop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from openfecli.commands.gather import (
_collect_result_jsons,
format_df_with_precision,
load_json,
rich_print_to_stdout,
)
from openfecli.quickrun_result import _QuickrunResult


def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dict | None]:
Expand Down Expand Up @@ -42,19 +42,19 @@ def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dic

# TODO: only load this once during collection, then pass namedtuple(fname, dict) into this function
# for now though, it's not the bottleneck on performance
result = load_json(fpath)
result = _QuickrunResult.from_json(fpath)
try:
names = _get_names(result)
except (ValueError, IndexError):
click.secho(f"{fpath}: Missing ligand names and/or simulation type. Skipping.",err=True, fg="yellow") # fmt: skip
return None, None
if result["estimate"] is None:
if result.estimate is None:
click.secho(f"{fpath}: No 'estimate' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return names, None
if result["uncertainty"] is None:
if result.uncertainty is None:
click.secho(f"{fpath}: No 'uncertainty' found, assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return names, None
if all("exception" in u for u in result["unit_results"].values()):
if all("exception" in u for u in result.unit_results.values()):
click.secho(f"{fpath}: Exception found in all 'unit_results', assuming to be a failed simulation.",err=True, fg="yellow") # fmt: skip
return names, None
return names, result
Expand Down Expand Up @@ -91,10 +91,10 @@ def _get_legs_from_result_jsons(
if names is None: # this means it couldn't find names and/or simtype
continue

ddgs[names]["overall"].append([result["estimate"], result["uncertainty"]])
ddgs[names]["overall"].append([result.estimate, result.uncertainty])
proto_key = [
k
for k in result["unit_results"].keys()
for k in result.unit_results.keys()
if k.startswith("ProtocolUnitResult")
] # fmt: skip

Expand All @@ -103,27 +103,27 @@ def _get_legs_from_result_jsons(
# we check if there are any analysis units. If so,
# we set a flag and later exclude Setup and Run.
has_analysis_units = any(
["Analysis" in result["unit_results"][p]["source_key"] for p in proto_key]
["Analysis" in result.unit_results[p]["source_key"] for p in proto_key]
)

for p in proto_key:
# Skip non-analysis units if we have any
if has_analysis_units and (
"Setup" in result["unit_results"][p]["source_key"]
or "Run" in result["unit_results"][p]["source_key"]
"Setup" in result.unit_results[p]["source_key"]
or "Run" in result.unit_results[p]["source_key"]
):
continue

if "unit_estimate" in result["unit_results"][p]["outputs"]:
simtype = result["unit_results"][p]["outputs"]["simtype"]
dg = result["unit_results"][p]["outputs"]["unit_estimate"]
dg_error = result["unit_results"][p]["outputs"]["unit_estimate_error"]
if "unit_estimate" in result.unit_results[p]["outputs"]:
simtype = result.unit_results[p]["outputs"]["simtype"]
dg = result.unit_results[p]["outputs"]["unit_estimate"]
dg_error = result.unit_results[p]["outputs"]["unit_estimate_error"]

ddgs[names][simtype].append([dg, dg_error])

if "standard_state_correction_A" in result["unit_results"][p]["outputs"]:
corr_A = result["unit_results"][p]["outputs"]["standard_state_correction_A"]
corr_B = result["unit_results"][p]["outputs"]["standard_state_correction_B"]
if "standard_state_correction_A" in result.unit_results[p]["outputs"]:
corr_A = result.unit_results[p]["outputs"]["standard_state_correction_A"]
corr_B = result.unit_results[p]["outputs"]["standard_state_correction_B"]
ddgs[names]["standard_state_correction_A"].append(
[corr_A, 0 * unit.kilocalorie_per_mole]
)
Expand All @@ -134,7 +134,7 @@ def _get_legs_from_result_jsons(
return ddgs


def _get_names(result: dict) -> tuple[str, str]:
def _get_names(result: _QuickrunResult) -> tuple[str, str]:
"""Get the ligand names from a unit's results data.

Parameters
Expand All @@ -148,7 +148,7 @@ def _get_names(result: dict) -> tuple[str, str]:
Ligand names corresponding to the results.
"""

solvent_data = list(result["protocol_result"]["data"]["solvent"].values())[0][0]
solvent_data = list(result.protocol_result["data"]["solvent"].values())[0][0]

try:
setup_data = solvent_data["inputs"]["setup"]["inputs"]
Expand Down
22 changes: 9 additions & 13 deletions src/openfecli/commands/quickrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
# For details, see https://github.com/OpenFreeEnergy/openfe

import hashlib
import json
import pathlib
import warnings

import click

from openfecli import OFECommandPlugin
from openfecli.quickrun_result import _QuickrunResult
from openfecli.utils import configure_logger, print_duration, write


Expand Down Expand Up @@ -175,17 +174,14 @@ def quickrun(transformation, work_dir, output, resume):
else:
estimate = uncertainty = None # for output file

out_dict = {
"estimate": estimate,
"uncertainty": uncertainty,
"protocol_result": prot_result.to_dict(),
"unit_results": {
unit.key: unit.to_keyed_dict() for unit in dagresult.protocol_unit_results
},
}

with open(output, mode="w") as outf:
json.dump(out_dict, outf, cls=JSON_HANDLER.encoder)
quickrun_result = _QuickrunResult(
estimate=estimate,
uncertainty=uncertainty,
protocol_result=prot_result.to_dict(),
unit_results={unit.key: unit.to_keyed_dict() for unit in dagresult.protocol_unit_results},
)

quickrun_result.to_json(output)

# remove the cached dag since the job has completed
os.remove(cached_dag_path)
Expand Down
Loading
Loading