Skip to content

Commit 2d687fe

Browse files
authored
Several more atom mapping improvements (#839)
This PR improves several aspects of atom mappings. two key fixes: 1) Stopping conformer generation in `ARCSpecies._scissors`. This is a major issue that improves both accuracy and performence: 1) Performence - less RDKit calls 2) Accuracy - Generating conformers distorts the XYZ of the scissored products. Moreover, this is critical for cyclic species, since debugging showed we did not include the original XYZ in the ARCSpecies, which caused errors. 2) the `identify_superimposable_candidates`, which was being called $\mathcal{O}(N^2)$ times and reduces to $\mathcal{O}(N)$. More calls are not required, and does not increase accuracy. (There is some more math behind it, and it can be discussed offline @alongd). Other fixes added, and more generalization of AM testing (manually verified)
2 parents a56feff + a56bd2c commit 2d687fe

13 files changed

Lines changed: 414 additions & 140 deletions

File tree

arc/common.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@
3636
logger = logging.getLogger('arc')
3737
logging.getLogger('matplotlib.font_manager').disabled = True
3838

39+
try:
40+
from rdkit import RDLogger
41+
RDLogger.DisableLog('rdApp.*')
42+
except ImportError:
43+
pass
44+
3945
# Absolute path to the ARC folder.
4046
ARC_PATH = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
4147
ARC_TESTING_PATH = os.path.join(ARC_PATH, 'arc', 'testing')

arc/family/family.py

Lines changed: 128 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,37 @@
2424
logger = get_logger()
2525

2626

27+
REACTION_FAMILY_CACHE: dict[tuple[str, bool], 'ReactionFamily'] = {}
28+
29+
# Pre-compiled regex patterns
30+
ENTRY_PATTERN = re.compile(r'entry\((.*?)\)', re.DOTALL)
31+
LABEL_PATTERN = re.compile(r'label\s*=\s*(["\'])(.*?)\1|label\s*=\s*(\w+)')
32+
GROUP_PATTERN = re.compile(r'group\s*=\s*(?:("""(.*?)"""|"(.*?)"|\'(.*?)\')|(OR\{.*?\}))', re.DOTALL)
33+
REVERSIBLE_PATTERN = re.compile(r'reversible\s*=\s*(True|False)')
34+
OWN_REVERSE_PATTERN = re.compile(r'ownReverse\s*=\s*(True|False)')
35+
RECIPE_PATTERN = re.compile(r'recipe\((.*?)\)', re.DOTALL)
36+
REACTANTS_PATTERN = re.compile(r'reactants\s*=\s*\[(.*?)\]', re.DOTALL)
37+
PRODUCTS_PATTERN = re.compile(r'products\s*=\s*\[(.*?)\]', re.DOTALL)
38+
ACTIONS_PATTERN = re.compile(r'actions\s*=\s*\[(.*?)\]', re.DOTALL)
39+
40+
41+
def get_reaction_family(label: str, consider_arc_families: bool = True) -> 'ReactionFamily':
42+
"""
43+
A helper function for getting a cached ReactionFamily object.
44+
45+
Args:
46+
label (str): The reaction family label.
47+
consider_arc_families (bool, optional): Whether to consider ARC's custom families.
48+
49+
Returns:
50+
ReactionFamily: The ReactionFamily object.
51+
"""
52+
key = (label, consider_arc_families)
53+
if key not in REACTION_FAMILY_CACHE:
54+
REACTION_FAMILY_CACHE[key] = ReactionFamily(label=label, consider_arc_families=consider_arc_families)
55+
return REACTION_FAMILY_CACHE[key]
56+
57+
2758
def get_rmg_db_subpath(*parts: str, must_exist: bool = False) -> str:
2859
"""Return a path under the RMG database, handling both source and packaged layouts."""
2960
if RMG_DB_PATH is None:
@@ -108,7 +139,17 @@ def __init__(self,
108139
self.groups_as_lines = read_groups_file_lines(label, consider_arc_families)
109140
self.reversible = is_reversible(self.groups_as_lines)
110141
self.own_reverse = is_own_reverse(self.groups_as_lines)
111-
self.reactants = get_reactant_groups_from_template(self.groups_as_lines)
142+
143+
reactant_labels = get_initial_reactant_labels_from_template(self.groups_as_lines)
144+
all_necessary_entries = get_entries(self.groups_as_lines, entry_labels=reactant_labels, recursive=True)
145+
self.reactants = get_reactant_groups_from_template(self.groups_as_lines, entries=all_necessary_entries)
146+
self.entries = all_necessary_entries
147+
148+
self.groups = {}
149+
for reactant_group in self.reactants:
150+
for label in reactant_group:
151+
if label not in self.groups and label in self.entries:
152+
self.groups[label] = Group().from_adjacency_list(self.entries[label])
112153
self.reactant_num = self.get_reactant_num()
113154
self.product_num = get_product_num(self.groups_as_lines)
114155
entry_labels = list()
@@ -156,6 +197,9 @@ def generate_products(self,
156197
for group_label in group_labels:
157198
group = self.groups_by_label[group_label]
158199
for mol in reactant.mol_list or [reactant.mol]:
200+
if not any(a.atomtype for a in mol.atoms):
201+
# Update atomtypes if they are missing (e.g., from SMILES)
202+
mol.update_atomtypes(log_species=False, raise_exception=False)
159203
splits = group.split()
160204
if mol.is_subgraph_isomorphic(other=group, save_order=True) \
161205
or len(splits) > 1 and any(mol.is_subgraph_isomorphic(other=g, save_order=True) for g in splits):
@@ -297,9 +341,15 @@ def generate_bimolecular_products(self,
297341
group_2 = self.groups_by_label[reactant_to_group_map_2['subgroup']]
298342
isomorphic_subgraphs_1 = mol_1.find_subgraph_isomorphisms(other=group_1, save_order=True)
299343
isomorphic_subgraphs_2 = mol_2.find_subgraph_isomorphisms(other=group_2, save_order=True)
344+
300345
if len(isomorphic_subgraphs_1) and len(isomorphic_subgraphs_2):
301346
for isomorphic_subgraph_1 in isomorphic_subgraphs_1:
302347
for isomorphic_subgraph_2 in isomorphic_subgraphs_2:
348+
# Create the combined isomorphic subgraph.
349+
# Note: get_isomorphic_subgraph needs to know which subgraph corresponds to which template index.
350+
# It assumes mol_1 corresponds to the first group match and mol_2 to the second.
351+
# The labels are already inside the group_atom.label.
352+
303353
isomorphic_subgraph_dicts.append(
304354
{'mols': [mol_1, mol_2],
305355
'subgroups': (reactant_to_group_map_1['subgroup'],
@@ -422,7 +472,7 @@ def get_reactant_num(self) -> int:
422472
if match:
423473
return int(match.group(1))
424474
if len(self.reactants) == 1:
425-
group = Group().from_adjacency_list(get_group_adjlist(self.groups_as_lines, entry_label=self.reactants[0][0]))
475+
group = self.groups[self.reactants[0][0]]
426476
groups = group.split()
427477
return len(groups)
428478
else:
@@ -523,7 +573,7 @@ def determine_possible_reaction_products_from_family(rxn: ARCReaction,
523573
and whether the family's template also represents its own reverse.
524574
"""
525575
product_dicts = list()
526-
family = ReactionFamily(label=family_label, consider_arc_families=consider_arc_families)
576+
family = get_reaction_family(label=family_label, consider_arc_families=consider_arc_families)
527577
products = family.generate_products(reactants=rxn.get_reactants_and_products(return_copies=True)[0])
528578
if products:
529579
for group_labels, product_lists in products.items():
@@ -765,11 +815,10 @@ def is_reversible(groups_as_lines: list[str]) -> bool:
765815
Returns:
766816
bool: Whether the reaction family is reversible.
767817
"""
768-
for line in groups_as_lines:
769-
if 'reversible = True' in line:
770-
return True
771-
if 'reversible = False' in line:
772-
return False
818+
groups_str = ''.join(groups_as_lines)
819+
match = REVERSIBLE_PATTERN.search(groups_str)
820+
if match:
821+
return match.group(1) == 'True'
773822
return True
774823

775824

@@ -780,36 +829,41 @@ def is_own_reverse(groups_as_lines: list[str]) -> bool:
780829
Returns:
781830
bool: Whether the reaction family's template also represents its own reverse.
782831
"""
783-
for line in groups_as_lines:
784-
if 'ownReverse=True' in line:
785-
return True
786-
if 'ownReverse=False' in line:
787-
return False
832+
groups_str = ''.join(groups_as_lines)
833+
match = OWN_REVERSE_PATTERN.search(groups_str)
834+
if match:
835+
return match.group(1) == 'True'
788836
return False
789837

790838

791-
def get_reactant_groups_from_template(groups_as_lines: list[str]) -> list[list[str]]:
839+
def get_reactant_groups_from_template(groups_as_lines: list[str],
840+
entries: dict[str, str] | None = None,
841+
) -> list[list[str]]:
792842
"""
793843
Get the reactant groups from a template content string.
794844
Descends the entries if a group is defined as an OR complex,
795845
e.g.: group = "OR{Xtrirad_H, Xbirad_H, Xrad_H, X_H}"
796846
797847
Args:
798848
groups_as_lines (list[str]): The template content string.
849+
entries (dict[str, str], optional): Pre-extracted entries.
799850
800851
Returns:
801852
list[list[str]]: The non-complex reactant groups.
802853
"""
803854
reactant_labels = get_initial_reactant_labels_from_template(groups_as_lines)
855+
if entries is None:
856+
entries = get_entries(groups_as_lines, entry_labels=reactant_labels)
804857
result = list()
805858
for reactant_label in reactant_labels:
806-
if 'OR{' not in get_group_adjlist(groups_as_lines, entry_label=reactant_label):
859+
adj = get_group_adjlist(groups_as_lines, entry_label=reactant_label, entries=entries)
860+
if 'OR{' not in adj:
807861
result.append([reactant_label])
808862
else:
809863
stack = [reactant_label]
810-
while any('OR{' in get_group_adjlist(groups_as_lines, entry_label=label) for label in stack):
864+
while any('OR{' in get_group_adjlist(groups_as_lines, entry_label=label, entries=entries) for label in stack):
811865
label = stack.pop(0)
812-
group_adjlist = get_group_adjlist(groups_as_lines, entry_label=label)
866+
group_adjlist = get_group_adjlist(groups_as_lines, entry_label=label, entries=entries)
813867
if 'OR{' not in group_adjlist:
814868
stack.append(label)
815869
else:
@@ -851,7 +905,7 @@ def descent_complex_group(group: str) -> list[str]:
851905
list[str]: The non-complex reactant group labels, e.g.: ['Xtrirad_H', 'Xbirad_H', 'Xrad_H', 'X_H'].
852906
"""
853907
if group.startswith('OR{') and group.endswith('}'):
854-
group = [g.strip() for g in group[3:-1].split(',')]
908+
group = [c.strip() for c in group[3:-1].split(',')]
855909
if isinstance(group, str):
856910
group = [group]
857911
return group
@@ -871,13 +925,15 @@ def get_initial_reactant_labels_from_template(groups_as_lines: list[str],
871925
Returns:
872926
list[str]: The reactant groups.
873927
"""
874-
labels = list()
875-
for line in groups_as_lines:
876-
match = re.search(r'products=\[(.*?)\]', line) if products else re.search(r'reactants=\[(.*?)\]', line)
877-
if match:
878-
labels = match.group(1).replace('"', '').split(', ')
879-
break
880-
return labels
928+
groups_str = ''.join(groups_as_lines)
929+
pattern = PRODUCTS_PATTERN if products else REACTANTS_PATTERN
930+
match = pattern.search(groups_str)
931+
if match:
932+
content = match.group(1)
933+
# Use regex to find all quoted strings (with backreferences) or unquoted words
934+
labels = re.findall(r'(["\'])(.*?)\1|(\w+)', content)
935+
return [label[1] or label[2] for label in labels]
936+
return list()
881937

882938

883939
def get_recipe_actions(groups_as_lines: list[str]) -> list[list[str]]:
@@ -982,43 +1038,77 @@ def split_entries(groups_str: str) -> list[str]:
9821038

9831039
def get_entries(groups_as_lines: list[str],
9841040
entry_labels: list[str],
1041+
recursive: bool = False,
9851042
) -> dict[str, str]:
9861043
"""
987-
Get the requested entries grom a template content string.
1044+
Get the requested entries from a template content string.
9881045
9891046
Args:
9901047
groups_as_lines (list[str]): The template content string.
991-
entry_labels (list[str]): The entry labels to extract.
1048+
entry_labels (list[str], optional): The entry labels to extract. If None, all entries are extracted.
1049+
recursive (bool, optional): Whether to recursively extract child entries for OR complexes.
9921050
9931051
Returns:
9941052
dict[str, str]: The extracted entries, keys are the labels, values are the groups.
9951053
"""
996-
groups_str = ''.join(groups_as_lines)
997-
entries = split_entries(groups_str)
998-
specific_entries = dict()
999-
for i, entry in enumerate(entries):
1000-
label_match = re.search(r'label\s*=\s*"(.*?)"', entry)
1001-
group_match = re.search(r'group\s*=(.*?)(?=\w+\s*=)', entry, re.DOTALL)
1002-
if label_match is not None and group_match is not None and label_match.group(1) in entry_labels:
1003-
specific_entries[label_match.group(1)] = clean_text(group_match.group(1))
1004-
if i > 2000:
1005-
break
1006-
return specific_entries
1054+
groups_str = "\n" + "".join(groups_as_lines)
1055+
# Split by `entry(` but keep the delimiter-ish part
1056+
parts = re.split(r"\nentry\s*\(", groups_str)
1057+
1058+
temp_entries = {}
1059+
label_pat = re.compile(r"label\s*=\s*(?:([\"'])(.*?)\1|(\w+))")
1060+
group_pat = re.compile(r"group\s*=\s*(?:\"\"\"(.*?)\"\"\"|([\"'])(.*?)\2|(OR\{.*?\}))", re.DOTALL)
1061+
1062+
for part in parts[1:]: # Skip the header
1063+
label_match = label_pat.search(part)
1064+
group_match = group_pat.search(part)
1065+
if label_match and group_match:
1066+
label = label_match.group(2) or label_match.group(3)
1067+
# Extract the matched regex group (1 for triple quotes, 3 for single/double quotes, 4 for OR complex)
1068+
adj = group_match.group(1) or group_match.group(3) or group_match.group(4)
1069+
temp_entries[label] = clean_text(adj)
1070+
1071+
if entry_labels is None:
1072+
return temp_entries
1073+
1074+
all_entries = {}
1075+
to_process = list(entry_labels)
1076+
processed = set()
1077+
while to_process:
1078+
label = to_process.pop()
1079+
if label in processed or label not in temp_entries:
1080+
continue
1081+
processed.add(label)
1082+
adj = temp_entries[label]
1083+
if recursive and 'OR{' in adj:
1084+
# Match OR{label1, label2, ...}
1085+
or_match = re.search(r'OR\s*\{\s*(.*?)\s*\}', adj, re.DOTALL)
1086+
if or_match:
1087+
children_str = or_match.group(1)
1088+
children = [c.strip() for c in children_str.split(',')]
1089+
to_process.extend(children)
1090+
else:
1091+
all_entries[label] = adj
1092+
return all_entries
10071093

10081094

10091095
def get_group_adjlist(groups_as_lines: list[str],
10101096
entry_label: str,
1097+
entries: dict[str, str] | None = None,
10111098
) -> str:
10121099
"""
10131100
Get the corresponding group value for the given entry label.
10141101
10151102
Args:
10161103
groups_as_lines (list[str]): The template content string.
10171104
entry_label (str): The entry label to extract.
1105+
entries (dict[str, str], optional): Pre-extracted entries.
10181106
10191107
Returns:
10201108
str: The extracted group.
10211109
"""
1110+
if entries is not None and entry_label in entries:
1111+
return entries[entry_label]
10221112
specific_entries = get_entries(groups_as_lines, entry_labels=[entry_label])
10231113
return specific_entries[entry_label]
10241114

arc/job/adapters/ts/heuristics_test.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636
from arc.species.species import ARCSpecies
3737
from arc.species.zmat import _compare_zmats, get_parameter_from_atom_indices
3838

39+
from arc.species.species import check_isomorphism
40+
from arc.species.zmat import remove_zmat_atom_0
41+
from arc.species.converter import relocate_zmat_dummy_atoms_to_the_end
42+
3943

4044
class TestHeuristicsAdapter(unittest.TestCase):
4145
"""
@@ -1409,11 +1413,32 @@ def test_get_new_zmat2_map(self):
14091413
# expected_new_map = {0: 12, 1: 13, 2: 'X24', 3: 14, 4: 15, 5: 16, 6: 'X25', 7: 17, 8: 'X26', 9: 18, 10: 19,
14101414
# 11: 20, 12: 21, 13: 22, 14: 'X27', 15: 23, 16: 'X28', 17: 2, 18: 3, 19: 1, 21: 4, 23: 0,
14111415
# 25: 7, 26: 6, 28: 5, 20: 'X8', 22: 'X9', 24: 'X10', 27: 'X11'}
1412-
expected_new_map = {0: 12, 1: 13, 2: 'X24', 3: 14, 4: 15, 5: 16, 6: 'X25', 7: 17, 8: 'X26', 9: 18, 10: 19,
1413-
11: 20, 12: 21, 13: 22, 14: 'X27', 15: 23, 16: 'X28', 17: 2, 18: 1, 19: 3, 21: 0, 23: 4,
1414-
25: 5, 26: 6, 28: 7, 20: 'X8', 22: 'X9', 24: 'X10', 27: 'X11'}
1415-
1416-
self.assertEqual(new_map, expected_new_map)
1416+
1417+
# Test isomorphism of the mapped reactant_2 part
1418+
zmat_2_mod = remove_zmat_atom_0(self.zmat_6)
1419+
zmat_2_mod['map'] = relocate_zmat_dummy_atoms_to_the_end(zmat_2_mod['map'])
1420+
spc_from_zmat_2 = ARCSpecies(label='spc_from_zmat_2', xyz=zmat_2_mod, multiplicity=reactant_2.multiplicity,
1421+
number_of_radicals=reactant_2.number_of_radicals, charge=reactant_2.charge)
1422+
1423+
# Verify that all physical atom indices in new_map that came from zmat_2 correctly map to reactant_2
1424+
# Atom indices in new_map are for the combined species.
1425+
# Atoms 0-16 in self.zmat_5, atoms 1-12 in self.zmat_6 (13 atoms total, index 0 removed).
1426+
# In get_new_zmat_2_map, zmat_2 atoms are mapped to indices in new_map.
1427+
1428+
num_atoms_1 = len(self.zmat_5['symbols'])
1429+
atom_map = dict()
1430+
for i in range(1, len(self.zmat_6['symbols'])):
1431+
if not isinstance(self.zmat_6['symbols'][i], str) or self.zmat_6['symbols'][i] != 'X':
1432+
# This is a physical atom in zmat_2 (at index i)
1433+
# Its index in the combined Z-Matrix is num_atoms_1 + i - 1
1434+
combined_idx = num_atoms_1 + i - 1
1435+
if combined_idx in new_map:
1436+
# new_map[combined_idx] is the index in reactant_2
1437+
# i-1 is the index in spc_from_zmat_2
1438+
atom_map[i-1] = new_map[combined_idx]
1439+
1440+
# Verify the atom_map is a valid isomorphism
1441+
self.assertTrue(check_isomorphism(spc_from_zmat_2.mol, reactant_2.mol, atom_map))
14171442

14181443
def test_get_new_map_based_on_zmat_1(self):
14191444
"""Test the get_new_map_based_on_zmat_1() function."""

arc/job/adapters/ts/linear_test.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,9 +247,9 @@ def _make_rxn_2() -> ARCReaction:
247247
H 0.27058353 -0.73979548 1.43184405""")])
248248

249249

250-
class TestHeuristicsAdapter(unittest.TestCase):
250+
class TestLinearAdapter(unittest.TestCase):
251251
"""
252-
Contains unit tests for the HeuristicsAdapter class.
252+
Contains unit tests for the LinearAdapter class.
253253
"""
254254

255255
@classmethod
@@ -1133,7 +1133,7 @@ def test_interpolate_1_plus_2_cycloaddition(self):
11331133
self.assertEqual(len(ts_xyz['symbols']), 10)
11341134
self.assertFalse(colliding_atoms(ts_xyz),
11351135
msg=f'Collision in 1+2_Cycloaddition TS:\n{xyz_to_str(ts_xyz)}')
1136-
expected_ts = """C 1.59999925 -0.11618654 -0.14166302
1136+
expected_ts_1 = """C 1.59999925 -0.11618654 -0.14166302
11371137
C 0.29517860 -0.02143486 -0.02613492
11381138
C -1.15821797 -1.12490772 0.14486040
11391139
C -0.81238032 0.84414025 0.04444949
@@ -1143,7 +1143,18 @@ def test_interpolate_1_plus_2_cycloaddition(self):
11431143
H -1.52801447 -1.64655150 -0.72678867
11441144
H -0.94547237 1.40230195 0.96062403
11451145
H -1.11212744 1.33826544 -0.86912905"""
1146-
self.assertTrue(any(almost_equal_coords(ts, str_to_xyz(expected_ts)) for ts in ts_xyzs))
1146+
expected_ts_2 = """C 1.59999925 -0.11618654 -0.14166302
1147+
C 0.29517860 -0.02143486 -0.02613492
1148+
C -0.92013120 -0.71833111 0.10894610
1149+
C -0.99229728 1.28107087 0.04554500
1150+
H 2.21797993 0.77036923 -0.22897655
1151+
H 2.09015362 -1.08321135 -0.15246324
1152+
H -1.12327237 -1.17593811 1.06705013
1153+
H -1.28992770 -1.23997489 -0.76270297
1154+
H -1.12538933 1.83923257 0.96171954
1155+
H -1.29204440 1.77519606 -0.86803354"""
1156+
self.assertTrue(any(almost_equal_coords(ts, str_to_xyz(expected_ts_1)) for ts in ts_xyzs) or
1157+
any(almost_equal_coords(ts, str_to_xyz(expected_ts_2)) for ts in ts_xyzs))
11471158
# The TS should have extended forming bonds compared to the product.
11481159
# Get atom map to find the carbene atom in the product.
11491160
atom_map = map_rxn(rxn=rxn)

0 commit comments

Comments
 (0)