diff --git a/src/openfe_analysis/rmsd.py b/src/openfe_analysis/rmsd.py
index b54beb9..2ee8e32 100644
--- a/src/openfe_analysis/rmsd.py
+++ b/src/openfe_analysis/rmsd.py
@@ -344,7 +344,11 @@ def gather_rms_data(
             output["protein_2D_RMSD"].append(prot_rmsd2d.results.rmsd2d)
 
         if ligand:
+            # For now, leave it at the normal RMSD
             lig_rmsd = RMSDAnalysis(ligand, mass_weighted=True).run(step=skip)
+            # state_lig = select_state_atoms(u, end_state="A").select_atoms("resname UNK")
+            # guess_ligand_bonds(state_lig, delete_existing=True)
+            # lig_rmsd = SymmetryCorrectedLigandRMSD(state_lig).run(step=skip)
             output["ligand_RMSD"].append(lig_rmsd.results.rmsd)
 
             lig_com_drift = LigandCOMDrift(ligand).run(step=skip)
diff --git a/src/openfe_analysis/tests/utils/test_universe_utils.py b/src/openfe_analysis/tests/utils/test_universe_utils.py
index e104cb8..839d641 100644
--- a/src/openfe_analysis/tests/utils/test_universe_utils.py
+++ b/src/openfe_analysis/tests/utils/test_universe_utils.py
@@ -1,7 +1,11 @@
+import MDAnalysis as mda
+import numpy as np
 import pytest
+from rdkit import Chem
 
 from openfe_analysis.utils import apply_transformations
 from openfe_analysis.utils.universe_utils import (
+    correct_elements,
     create_universe_single_state,
     guess_ligand_bonds,
     select_state_atoms,
@@ -77,3 +81,106 @@ def test_select_state_atoms_shared_atoms(universe):
     shared_a = set(atom.ix for atom in state_a if atom.bfactor == 0.5)
     shared_b = set(atom.ix for atom in state_b if atom.bfactor == 0.5)
     assert shared_a == shared_b
+
+
+def test_correct_elements_fixes_element():
+    """correct_elements should update element where rdmol differs."""
+
+    # Build a minimal universe with a C atom
+    u = mda.Universe.empty(2, n_residues=1, trajectory=True)
+    u.add_TopologyAttr("elements", ["C", "C"])  # second atom is wrong
+    u.add_TopologyAttr("names", ["C1", "C2"])
+    u.add_TopologyAttr("resnames", ["UNK"])
+    u.add_TopologyAttr("resids", [1])
+    u.load_new(
+        np.array([[[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]]),
+        order="fac",
+    )
+    ag = u.select_atoms("all")
+
+    mol = Chem.RWMol()
+    mol.AddAtom(Chem.Atom(6))  # C
+    mol.AddAtom(Chem.Atom(7))  # N
+    rdmol = mol.GetMol()
+
+    with pytest.warns(UserWarning, match="No atom_mapping provided"):
+        correct_elements(ag, rdmol)
+
+    assert ag[0].element == "C"
+    assert ag[1].element == "N"
+    assert ag[1].name == "N"
+
+
+def test_correct_elements_no_change_when_correct():
+    """correct_elements should not modify atoms that already have correct elements."""
+
+    u = mda.Universe.empty(2, n_residues=1, trajectory=True)
+    u.add_TopologyAttr("elements", ["C", "N"])
+    u.add_TopologyAttr("names", ["C1", "N1"])
+    u.add_TopologyAttr("resnames", ["UNK"])
+    u.add_TopologyAttr("resids", [1])
+    u.load_new(
+        np.array([[[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]]),
+        order="fac",
+    )
+    ag = u.select_atoms("all")
+
+    mol = Chem.RWMol()
+    mol.AddAtom(Chem.Atom(6))  # C
+    mol.AddAtom(Chem.Atom(7))  # N
+    rdmol = mol.GetMol()
+
+    with pytest.warns(UserWarning, match="No atom_mapping provided"):
+        correct_elements(ag, rdmol)
+
+    assert ag[0].element == "C"
+    assert ag[0].name == "C1"  # name unchanged
+    assert ag[1].element == "N"
+    assert ag[1].name == "N1"  # name unchanged
+
+
+def test_correct_elements_with_atom_mapping():
+    """correct_elements with atom_mapping should use mapping"""
+
+    u = mda.Universe.empty(2, n_residues=1, trajectory=True)
+    u.add_TopologyAttr("elements", ["C", "C"])  # second atom is wrong
+    u.add_TopologyAttr("names", ["C1", "C2"])
+    u.add_TopologyAttr("resnames", ["UNK"])
+    u.add_TopologyAttr("resids", [1])
+    u.load_new(
+        np.array([[[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]]),
+        order="fac",
+    )
+    ag = u.select_atoms("all")
+
+    # rdmol has atoms in reverse order: N, C
+    mol = Chem.RWMol()
+    mol.AddAtom(Chem.Atom(7))  # N at rdmol index 0
+    mol.AddAtom(Chem.Atom(6))  # C at rdmol index 1
+    rdmol = mol.GetMol()
+
+    # explicitly map ag index 0 -> rdmol index 1 (C), ag index 1 -> rdmol index 0 (N)
+    correct_elements(ag, rdmol, atom_mapping={0: 1, 1: 0})
+
+    assert ag[0].element == "C"  # mapped to rdmol index 1 (C)
+    assert ag[1].element == "N"  # mapped to rdmol index 0 (N)
+    assert ag[1].name == "N"
+
+
+def test_correct_elements_raises_size_error():
+    """correct_elements should raise ValueError if atom counts don't match."""
+
+    u = mda.Universe.empty(2, n_residues=1, trajectory=True)
+    u.add_TopologyAttr("elements", ["C", "N"])
+    u.add_TopologyAttr("names", ["C1", "N1"])
+    u.add_TopologyAttr("resnames", ["UNK"])
+    u.add_TopologyAttr("resids", [1])
+    u.load_new(np.array([[[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]]), order="fac")
+    ag = u.select_atoms("all")
+
+    mol = Chem.RWMol()
+    mol.AddAtom(Chem.Atom(6))  # only 1 atom
+    rdmol = mol.GetMol()
+
+    with pytest.raises(ValueError, match="atomgroup has 2 atoms but rdmol has 1"):
+        correct_elements(ag, rdmol)
diff --git a/src/openfe_analysis/utils/universe_utils.py b/src/openfe_analysis/utils/universe_utils.py
index f1db828..341c5a1 100644
--- a/src/openfe_analysis/utils/universe_utils.py
+++ b/src/openfe_analysis/utils/universe_utils.py
@@ -1,11 +1,13 @@
 from __future__ import annotations
 
+import warnings
 from pathlib import Path
 from typing import Literal
 
 import MDAnalysis as mda
 import netCDF4 as nc
 from MDAnalysis.guesser.tables import vdwradii as MDA_VDWRADII
+from rdkit import Chem
 
 from ..reader import FEReader
 
@@ -92,6 +94,72 @@ def guess_ligand_bonds(
     atomgroup.guess_bonds(vdwradii)
 
 
+def correct_elements(
+    atomgroup: mda.AtomGroup,
+    rdmol: Chem.Mol,
+    atom_mapping: dict[int, int] | None = None,
+) -> None:
+    """
+    Correct element and atom names in an AtomGroup in-place
+    using an RDKit molecule.
+
+    This is needed for hybrid topologies where mapped atoms that
+    undergo element changes carry state A's element types, even when
+    state B's ligand is selected.
+
+    Only atoms whose element differs from the corresponding ``rdmol``
+    atom are modified: the element is replaced and the atom name is
+    reset to the RDKit atom symbol.
+
+    Parameters
+    ----------
+    atomgroup : mda.AtomGroup
+        Ligand atoms whose elements and names will be corrected.
+    rdmol : Chem.Mol
+        RDKit molecule with the correct element and atom name information.
+    atom_mapping : dict[int, int], optional
+        A mapping of ``{atomgroup_index: rdmol_index}`` defining the
+        correspondence between atoms in ``atomgroup`` and ``rdmol``. Only
+        atoms listed in the mapping are checked. If ``None``, a
+        ``UserWarning`` is emitted and atoms are matched by position,
+        which gives wrong results if the atom order is not the same.
+
+    Raises
+    ------
+    ValueError
+        If the number of atoms in ``atomgroup`` and ``rdmol`` do not match.
+    """
+    n_ag_atoms = len(atomgroup)
+    n_rd_atoms = rdmol.GetNumAtoms()
+    if n_ag_atoms != n_rd_atoms:
+        raise ValueError(
+            f"atomgroup has {n_ag_atoms} atoms but rdmol has {n_rd_atoms} atoms."
+        )
+
+    if atom_mapping is None:
+        warnings.warn(
+            "No atom_mapping provided to correct_elements. Assuming that "
+            "atom ordering is the same between atomgroup and rdmol. This may "
+            "give incorrect results if the atom ordering differs between the two.",
+            UserWarning,
+            # Attribute the warning to the caller, not to this module.
+            stacklevel=2,
+        )
+        # Positional matching: atom i in atomgroup <-> atom i in rdmol.
+        index_pairs = ((idx, idx) for idx in range(n_ag_atoms))
+    else:
+        index_pairs = atom_mapping.items()
+
+    periodic_table = Chem.GetPeriodicTable()
+    for ag_idx, rd_idx in index_pairs:
+        mda_atom = atomgroup[ag_idx]
+        rd_atom = rdmol.GetAtomWithIdx(rd_idx)
+        element = periodic_table.GetElementSymbol(rd_atom.GetAtomicNum())
+        if mda_atom.element != element:
+            mda_atom.element = element
+            mda_atom.name = rd_atom.GetSymbol()
+
+
 def create_universe_single_state(
     top: Path | mda.core.topology.Topology, trj: nc.Dataset, state: int
 ) -> mda.Universe: