Source code for gloryxr.som

"""
Functions to annotate the educt molecule with SOM indices.
"""

import numpy as np
from rdkit.Chem.rdchem import Mol
from rdkit.Chem.rdmolops import GetDistanceMatrix

__all__ = ["annotate_educt_and_product_inplace"]


def annotate_educt_and_product_inplace(
    educt: Mol, product: Mol, strict_soms: bool = False
) -> None:
    """
    Annotate the educt and product molecules with SOM indices.

    Both molecules are modified in place; nothing is returned. The
    annotation is written as atom map numbers.

    Args:
        educt: Molecule the product was generated from. The atom at the
            index stored in a selected product atom's ``react_atom_idx``
            property receives the map number.
        product: Molecule whose atoms are inspected. Selected atoms get
            their atom map number set.
        strict_soms: If true, the product atoms are chosen by
            ``_get_strict_som_indices(educt, product)``; otherwise by
            ``_get_loose_som_indices(product)``.

    Note:
        Every selected product atom must carry an integer
        ``react_atom_idx`` property, since it is read unconditionally.
    """
    if strict_soms:
        som_product_idxs = _get_strict_som_indices(educt, product)
    else:
        som_product_idxs = _get_loose_som_indices(product)

    for product_idx in som_product_idxs:
        product_atom = product.GetAtomWithIdx(product_idx)

        # Reuse the atom's previous map number when it has one; atoms
        # without an "old_mapno" property fall back to map number 1.
        if product_atom.HasProp("old_mapno"):
            map_number = product_atom.GetIntProp("old_mapno")
        else:
            map_number = 1

        product_atom.SetAtomMapNum(map_number)

        # Mirror the same map number onto the corresponding educt atom.
        educt_atom = educt.GetAtomWithIdx(product_atom.GetIntProp("react_atom_idx"))
        educt_atom.SetAtomMapNum(map_number)


def _get_loose_som_indices(product: Mol) -> list[int]:
    """
    Return the indices of all non-hydrogen product atoms that carry an
    ``old_mapno`` property, in atom iteration order.
    """
    som_idxs = []
    for atom in product.GetAtoms():
        if not atom.HasProp("old_mapno"):
            continue
        if atom.GetAtomicNum() == 1:
            # Hydrogens are never reported.
            continue
        som_idxs.append(atom.GetIdx())
    return som_idxs


def _get_strict_som_indices(educt: Mol, product: Mol) -> list[int]:
    """
    Return the product atom indices closest to where the molecule changed.

    Atoms are classified via the ``react_atom_idx`` property on product
    atoms:

    * *retained*: non-hydrogen product atoms that have the property; they
      map an educt index to a product index.
    * *added*: product atoms without the property.
    * *removed*: educt atoms whose index is not a key of the retained
      mapping.

    If any atoms were removed, the retained atoms closest to a removed
    atom (measured in the educt) are returned as product indices.
    Otherwise, if any atoms were added, the retained atoms closest to an
    added atom (measured in the product) are returned. If neither applies
    the result is an empty list.
    """
    # educt atom index -> product atom index for retained heavy atoms.
    educt_to_product = {}
    for atom in product.GetAtoms():
        if atom.HasProp("react_atom_idx") and atom.GetAtomicNum() != 1:
            educt_idx = atom.GetIntProp("react_atom_idx")
            educt_to_product[educt_idx] = atom.GetIdx()

    added_product_idxs = []
    for atom in product.GetAtoms():
        if not atom.HasProp("react_atom_idx"):
            added_product_idxs.append(atom.GetIdx())

    removed_educt_idxs = []
    for atom in educt.GetAtoms():
        if atom.GetIdx() not in educt_to_product:
            removed_educt_idxs.append(atom.GetIdx())

    # Removal takes precedence over addition.
    if removed_educt_idxs:
        nearest_educt_idxs = _get_closest_idxs(
            educt, removed_educt_idxs, list(educt_to_product)
        )
        # Distances were measured in the educt; translate to product indices.
        return [educt_to_product[educt_idx] for educt_idx in nearest_educt_idxs]

    if added_product_idxs:
        return _get_closest_idxs(
            product, added_product_idxs, list(educt_to_product.values())
        )

    return []


def _get_closest_idxs(
    mol: Mol, reference_idx_: list[int], filter_idx_: list[int]
) -> list[int]:
    """
    Return the entries of ``filter_idx_`` that lie closest to any atom in
    ``reference_idx_``.

    Distances come from ``GetDistanceMatrix(mol)`` (RDKit documents this
    as the topological distance matrix). For each candidate atom the
    distance to its nearest reference atom is taken; all candidates tied
    for the overall minimum are returned, in the order they appear in
    ``filter_idx_``.

    An empty ``filter_idx_`` yields an empty list. ``reference_idx_`` is
    expected to be non-empty; the callers in this module only pass
    non-empty reference lists.
    """
    references = np.asarray(reference_idx_, dtype=int)
    candidates = np.asarray(filter_idx_, dtype=int)

    pairwise = GetDistanceMatrix(mol)
    # For every atom: distance to its nearest reference atom.
    dist_to_nearest_reference = pairwise[:, references].min(axis=1)
    candidate_distances = dist_to_nearest_reference[candidates]

    # ``initial=np.inf`` keeps the reduction valid when there are no candidates.
    smallest = candidate_distances.min(initial=np.inf)
    is_closest = candidate_distances == smallest

    return candidates[is_closest].tolist()