Skip to content

Commit 828ebef

Browse files
authored
Fam recognition (#852)
Added support for recognizing reaction families that were shown to be problematic in the past. Specifically, this addresses #813, #787, #738, and #606, as well as the families `Bimolec_Hydroperoxide_Decomposition`, `Birad_R_Recombination`, `Birad_recombination`, and `Br_Abstraction`, which were mentioned offline by @kfir4444.
2 parents e8ce91a + fa8eba2 commit 828ebef

3 files changed

Lines changed: 559 additions & 55 deletions

File tree

arc/family/family.py

Lines changed: 126 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import re
99

1010
from arc.common import clean_text, get_logger
11+
from arc.exceptions import InvalidAdjacencyListError
1112
from arc.imports import settings
1213
from arc.molecule import Bond, Group, Molecule
1314
from arc.molecule.resonance import generate_resonance_structures_safely
@@ -324,8 +325,14 @@ def apply_recipe(self,
324325
if action[0] in ['CHANGE_BOND', 'FORM_BOND', 'BREAK_BOND']:
325326
structure.reset_connectivity_values()
326327
label_1, info, label_2 = action[1:]
327-
atom_1 = structure.get_labeled_atoms(label_1)[0]
328-
atom_2 = structure.get_labeled_atoms(label_2)[0]
328+
labeled_1 = structure.get_labeled_atoms(label_1)
329+
atom_1 = labeled_1[0] if labeled_1 else None
330+
if label_1 == label_2 and len(labeled_1) >= 2:
331+
# Same label on two different atoms (e.g., R_Recombination: * + * → *-*)
332+
atom_2 = labeled_1[1]
333+
else:
334+
labeled_2 = structure.get_labeled_atoms(label_2)
335+
atom_2 = labeled_2[0] if labeled_2 else None
329336
if atom_1 is None or atom_2 is None or atom_1 is atom_2:
330337
raise ValueError('Invalid atom labels in reaction recipe.')
331338
if action[0] == 'CHANGE_BOND':
@@ -356,10 +363,11 @@ def apply_recipe(self,
356363
elif action[0] in ['LOSE_RADICAL', 'GAIN_RADICAL', 'LOSE_PAIR', 'GAIN_PAIR']:
357364
label, change = action[1:]
358365
change = int(change)
359-
atom = structure.get_labeled_atoms(label)[0]
360-
if atom is None:
366+
labeled_atoms = structure.get_labeled_atoms(label)
367+
if not labeled_atoms:
361368
raise ValueError(f'Unable to find atom with label "{label}" while applying reaction recipe.')
362-
atom.apply_action([action[0], label, change])
369+
for atom in labeled_atoms:
370+
atom.apply_action([action[0], label, change])
363371
else:
364372
raise ValueError(f'Unknown action "{action[0]}" encountered.')
365373
if 'validAromatic' in structure.props and not structure.props['validAromatic']:
@@ -378,10 +386,18 @@ def apply_recipe(self,
378386
def get_reactant_num(self) -> int:
379387
"""
380388
Get the number of reactants for this family.
389+
Uses the explicit ``reactantNum`` value from the groups file if available,
390+
otherwise infers from the template group structure.
381391
382392
Returns:
383393
int: The number of reactants.
384394
"""
395+
for line in self.groups_as_lines:
396+
stripped = line.strip()
397+
if stripped.startswith('reactantNum'):
398+
match = re.search(r'reactantNum\s*=\s*(\d+)', stripped)
399+
if match:
400+
return int(match.group(1))
385401
if len(self.reactants) == 1:
386402
group = Group().from_adjacency_list(get_group_adjlist(self.groups_as_lines, entry_label=self.reactants[0][0]))
387403
groups = group.split()
@@ -424,28 +440,31 @@ def get_reaction_family_products(rxn: 'ARCReaction',
424440
consider_arc_families=consider_arc_families)
425441
product_dicts = list()
426442
for family_label in family_labels:
427-
# Forward:
428-
products = determine_possible_reaction_products_from_family(rxn=rxn,
429-
family_label=family_label,
430-
consider_arc_families=consider_arc_families,
431-
reverse=False,
432-
)
433-
if len(products):
434-
product_dicts.extend(filter_products_by_reaction(rxn=rxn, product_dicts=products))
435-
436-
# Reverse:
437-
flipped_rxn = rxn.flip_reaction(report_family=False)
438-
products = determine_possible_reaction_products_from_family(rxn=flipped_rxn,
439-
family_label=family_label,
440-
consider_arc_families=consider_arc_families,
441-
reverse=True,
442-
)
443-
if len(products):
444-
filtered_products = filter_products_by_reaction(rxn=flipped_rxn, product_dicts=products)
445-
if not discover_own_reverse_rxns_in_reverse:
446-
product_dicts.extend([prod for prod in filtered_products if not prod['own_reverse']])
447-
else:
448-
product_dicts.extend(filtered_products)
443+
try:
444+
# Forward:
445+
products = determine_possible_reaction_products_from_family(rxn=rxn,
446+
family_label=family_label,
447+
consider_arc_families=consider_arc_families,
448+
reverse=False,
449+
)
450+
if len(products):
451+
product_dicts.extend(filter_products_by_reaction(rxn=rxn, product_dicts=products))
452+
453+
# Reverse:
454+
flipped_rxn = rxn.flip_reaction(report_family=False)
455+
products = determine_possible_reaction_products_from_family(rxn=flipped_rxn,
456+
family_label=family_label,
457+
consider_arc_families=consider_arc_families,
458+
reverse=True,
459+
)
460+
if len(products):
461+
filtered_products = filter_products_by_reaction(rxn=flipped_rxn, product_dicts=products)
462+
if not discover_own_reverse_rxns_in_reverse:
463+
product_dicts.extend([prod for prod in filtered_products if not prod['own_reverse']])
464+
else:
465+
product_dicts.extend(filtered_products)
466+
except (KeyError, ValueError, InvalidAdjacencyListError) as e:
467+
logger.debug(f'Skipping family {family_label} due to unsupported group definition: {type(e).__name__}: {e}')
449468
return product_dicts
450469

451470

@@ -462,7 +481,7 @@ def determine_possible_reaction_products_from_family(rxn: 'ARCReaction',
462481
[{'family': str: Family label,
463482
'group_labels': Tuple[str, str]: Group labels used to generate the products,
464483
'products': List['Molecule']: The generated products,
465-
'r_label_map': Dict[int, str]: Mapping of reactant atom indices to labels,
484+
'r_label_map': Dict[str, int]: Mapping of reactant labels to atom indices,
466485
'p_label_map': Dict[str, int]: Mapping of product labels to atom indices
467486
(refers to the given 'products' in this dict
468487
and not to the products of the original reaction),
@@ -489,7 +508,18 @@ def determine_possible_reaction_products_from_family(rxn: 'ARCReaction',
489508
template_mols, r_label_dict = product_list[0], product_list[1]
490509
if not isomorphic_products(rxn=rxn, products=template_mols):
491510
continue
492-
r_label_map = {val: key for key, val in r_label_dict.items() if val}
511+
# Build r_label_map preserving duplicate labels by suffixing
512+
# (e.g., R_Recombination has two atoms labeled '*' → '*' and '*_2').
513+
r_label_map = {}
514+
for key, val in r_label_dict.items():
515+
if not val:
516+
continue
517+
label = val
518+
suffix = 2
519+
while label in r_label_map:
520+
label = f'{val}_{suffix}'
521+
suffix += 1
522+
r_label_map[label] = key
493523
offsets = [0]
494524
for mol in template_mols:
495525
offsets.append(offsets[-1] + len(mol.atoms))
@@ -498,7 +528,12 @@ def determine_possible_reaction_products_from_family(rxn: 'ARCReaction',
498528
base = offsets[i]
499529
for j, atom in enumerate(mol.atoms):
500530
if atom.label:
501-
p_label_map[atom.label] = base + j
531+
label = atom.label
532+
suffix = 2
533+
while label in p_label_map:
534+
label = f'{atom.label}_{suffix}'
535+
suffix += 1
536+
p_label_map[label] = base + j
502537
product_dicts.append({
503538
'family': family_label,
504539
'group_labels': group_labels,
@@ -543,36 +578,65 @@ def check_product_isomorphism(products: List['Molecule'],
543578
) -> bool:
544579
"""
545580
Check whether the products are isomorphic to the given species.
546-
Supports unimolecular and bimolecular reactions.
581+
Falls back to InChI comparison when graph isomorphism fails
582+
(e.g., different Lewis structures perceived from XYZ vs SMILES).
547583
548584
Args:
549-
products (Tuple[List['Molecule'], Dict[int, str]]): The products to check.
585+
products (List['Molecule']): The products to check.
550586
p_species (List['ARCSpecies']): The species to check against.
551587
552588
Returns:
553589
bool: Whether the products are isomorphic to the species.
554590
"""
591+
if len(products) != len(p_species):
592+
return False
555593
prods_a = [generate_resonance_structures_safely(mol) or [mol.copy(deep=True)] for mol in products]
556594
prods_b = [spc.mol_list or [spc.mol] for spc in p_species]
557-
if len(prods_a) == 1:
558-
prod_a = prods_a[0]
559-
prod_b = prods_b[0]
560-
for mol_a in prod_a:
561-
if any(mol_b.is_isomorphic(mol_a) for mol_b in prod_b):
562-
return True
563-
if len(products) == 2:
564-
isomorphic = [False, False]
565-
for i, prod_a in enumerate(prods_a):
566-
skip = False
567-
for prod_b in prods_b:
568-
if skip:
595+
isomorphic = [False] * len(products)
596+
for i, prod_a in enumerate(prods_a):
597+
for prod_b in prods_b:
598+
if isomorphic[i]:
599+
break
600+
for mol_a in prod_a:
601+
if any(mol_b.is_isomorphic(mol_a) for mol_b in prod_b):
602+
isomorphic[i] = True
569603
break
570-
for mol_a in prod_a:
571-
if any(mol_b.is_isomorphic(mol_a) for mol_b in prod_b):
572-
isomorphic[i] = True
573-
skip = True
574-
return all(isomorphic)
575-
return False
604+
if all(isomorphic):
605+
return True
606+
# Fall back to InChI comparison for unmatched products.
607+
# Different Lewis structures perceived from XYZ vs SMILES (e.g., O=C=C(O)C=O vs O=C[C-](O)C#[O+])
608+
# may not be graph-isomorphic but share the same InChI.
609+
# InChI does not encode radical electrons, so also require matching multiplicity
610+
# to avoid false matches between biradicals and closed-shell species
611+
# (e.g., [CH2][CH2] vs C=C both give InChI=1S/C2H4/c1-2/h1-2H2).
612+
# Gate behind molecular-formula check to avoid expensive InChI generation
613+
# for products that can't possibly match.
614+
species_fingerprints = {spc.mol.fingerprint for spc in p_species if spc.mol is not None}
615+
needs_inchi = False
616+
for i in range(len(products)):
617+
if not isomorphic[i] and products[i].fingerprint in species_fingerprints:
618+
needs_inchi = True
619+
break
620+
if not needs_inchi:
621+
return False
622+
# Precompute p_species InChIs once (not per candidate product).
623+
try:
624+
species_inchi_mult = [(spc.mol.to_inchi(), spc.mol.multiplicity) for spc in p_species]
625+
except Exception:
626+
return False
627+
for i in range(len(products)):
628+
if not isomorphic[i]:
629+
if products[i].fingerprint not in species_fingerprints:
630+
continue
631+
try:
632+
inchi_a = products[i].to_inchi()
633+
mult_a = products[i].multiplicity
634+
except Exception:
635+
continue
636+
if any(inchi_a == inchi_b and mult_a == mult_b
637+
for inchi_b, mult_b in species_inchi_mult):
638+
isomorphic[i] = True
639+
return all(isomorphic)
576640

577641

578642
def get_all_families(rmg_family_set: Union[List[str], str] = 'default',
@@ -722,13 +786,21 @@ def get_reactant_groups_from_template(groups_as_lines: List[str]) -> List[List[s
722786
def get_product_num(groups_as_lines: List[str]) -> int:
723787
"""
724788
Get the number of products from a template content string.
789+
Uses the explicit ``productNum`` value from the groups file if available,
790+
otherwise infers from the template product labels.
725791
726792
Args:
727793
groups_as_lines (List[str]): The template content string.
728794
729795
Returns:
730796
int: The number of products.
731797
"""
798+
for line in groups_as_lines:
799+
stripped = line.strip()
800+
if stripped.startswith('productNum'):
801+
match = re.search(r'productNum\s*=\s*(\d+)', stripped)
802+
if match:
803+
return int(match.group(1))
732804
return len(get_initial_reactant_labels_from_template(groups_as_lines, products=True))
733805

734806

@@ -744,7 +816,7 @@ def descent_complex_group(group: str) -> List[str]:
744816
List[str]: The non-complex reactant group labels, e.g.: ['Xtrirad_H', 'Xbirad_H', 'Xrad_H', 'X_H'].
745817
"""
746818
if group.startswith('OR{') and group.endswith('}'):
747-
group = group[3:-1].split(', ')
819+
group = [g.strip() for g in group[3:-1].split(',')]
748820
if isinstance(group, str):
749821
group = [group]
750822
return group
@@ -789,7 +861,9 @@ def get_recipe_actions(groups_as_lines: List[str]) -> List[List[str]]:
789861
j = 0
790862
while '])' not in groups_as_lines[i + 1 + j]:
791863
if "['" in groups_as_lines[i + 1 + j]:
792-
actions.append(ast.literal_eval(groups_as_lines[i + 1 + j].strip())[0])
864+
line = groups_as_lines[i + 1 + j].strip()
865+
if not line.startswith('#'):
866+
actions.append(ast.literal_eval(line)[0])
793867
j += 1
794868
break
795869
return actions

0 commit comments

Comments
 (0)