2424logger = get_logger ()
2525
2626
27+ REACTION_FAMILY_CACHE : dict [tuple [str , bool ], 'ReactionFamily' ] = {}
28+
29+ # Pre-compiled regex patterns
30+ ENTRY_PATTERN = re .compile (r'entry\((.*?)\)' , re .DOTALL )
31+ LABEL_PATTERN = re .compile (r'label\s*=\s*(["\'])(.*?)\1|label\s*=\s*(\w+)' )
32+ GROUP_PATTERN = re .compile (r'group\s*=\s*(?:("""(.*?)"""|"(.*?)"|\'(.*?)\')|(OR\{.*?\}))' , re .DOTALL )
33+ REVERSIBLE_PATTERN = re .compile (r'reversible\s*=\s*(True|False)' )
34+ OWN_REVERSE_PATTERN = re .compile (r'ownReverse\s*=\s*(True|False)' )
35+ RECIPE_PATTERN = re .compile (r'recipe\((.*?)\)' , re .DOTALL )
36+ REACTANTS_PATTERN = re .compile (r'reactants\s*=\s*\[(.*?)\]' , re .DOTALL )
37+ PRODUCTS_PATTERN = re .compile (r'products\s*=\s*\[(.*?)\]' , re .DOTALL )
38+ ACTIONS_PATTERN = re .compile (r'actions\s*=\s*\[(.*?)\]' , re .DOTALL )
39+
40+
41+ def get_reaction_family (label : str , consider_arc_families : bool = True ) -> 'ReactionFamily' :
42+ """
43+ A helper function for getting a cached ReactionFamily object.
44+
45+ Args:
46+ label (str): The reaction family label.
47+ consider_arc_families (bool, optional): Whether to consider ARC's custom families.
48+
49+ Returns:
50+ ReactionFamily: The ReactionFamily object.
51+ """
52+ key = (label , consider_arc_families )
53+ if key not in REACTION_FAMILY_CACHE :
54+ REACTION_FAMILY_CACHE [key ] = ReactionFamily (label = label , consider_arc_families = consider_arc_families )
55+ return REACTION_FAMILY_CACHE [key ]
56+
57+
2758def get_rmg_db_subpath (* parts : str , must_exist : bool = False ) -> str :
2859 """Return a path under the RMG database, handling both source and packaged layouts."""
2960 if RMG_DB_PATH is None :
@@ -108,7 +139,17 @@ def __init__(self,
108139 self .groups_as_lines = read_groups_file_lines (label , consider_arc_families )
109140 self .reversible = is_reversible (self .groups_as_lines )
110141 self .own_reverse = is_own_reverse (self .groups_as_lines )
111- self .reactants = get_reactant_groups_from_template (self .groups_as_lines )
142+
143+ reactant_labels = get_initial_reactant_labels_from_template (self .groups_as_lines )
144+ all_necessary_entries = get_entries (self .groups_as_lines , entry_labels = reactant_labels , recursive = True )
145+ self .reactants = get_reactant_groups_from_template (self .groups_as_lines , entries = all_necessary_entries )
146+ self .entries = all_necessary_entries
147+
148+ self .groups = {}
149+ for reactant_group in self .reactants :
150+ for label in reactant_group :
151+ if label not in self .groups and label in self .entries :
152+ self .groups [label ] = Group ().from_adjacency_list (self .entries [label ])
112153 self .reactant_num = self .get_reactant_num ()
113154 self .product_num = get_product_num (self .groups_as_lines )
114155 entry_labels = list ()
@@ -156,6 +197,9 @@ def generate_products(self,
156197 for group_label in group_labels :
157198 group = self .groups_by_label [group_label ]
158199 for mol in reactant .mol_list or [reactant .mol ]:
200+ if not any (a .atomtype for a in mol .atoms ):
201+ # Update atomtypes if they are missing (e.g., from SMILES)
202+ mol .update_atomtypes (log_species = False , raise_exception = False )
159203 splits = group .split ()
160204 if mol .is_subgraph_isomorphic (other = group , save_order = True ) \
161205 or len (splits ) > 1 and any (mol .is_subgraph_isomorphic (other = g , save_order = True ) for g in splits ):
@@ -297,9 +341,15 @@ def generate_bimolecular_products(self,
297341 group_2 = self .groups_by_label [reactant_to_group_map_2 ['subgroup' ]]
298342 isomorphic_subgraphs_1 = mol_1 .find_subgraph_isomorphisms (other = group_1 , save_order = True )
299343 isomorphic_subgraphs_2 = mol_2 .find_subgraph_isomorphisms (other = group_2 , save_order = True )
344+
300345 if len (isomorphic_subgraphs_1 ) and len (isomorphic_subgraphs_2 ):
301346 for isomorphic_subgraph_1 in isomorphic_subgraphs_1 :
302347 for isomorphic_subgraph_2 in isomorphic_subgraphs_2 :
348+ # Create the combined isomorphic subgraph.
349+ # Note: get_isomorphic_subgraph needs to know which subgraph corresponds to which template index.
350+ # It assumes mol_1 corresponds to the first group match and mol_2 to the second.
351+ # The labels are already inside the group_atom.label.
352+
303353 isomorphic_subgraph_dicts .append (
304354 {'mols' : [mol_1 , mol_2 ],
305355 'subgroups' : (reactant_to_group_map_1 ['subgroup' ],
@@ -422,7 +472,7 @@ def get_reactant_num(self) -> int:
422472 if match :
423473 return int (match .group (1 ))
424474 if len (self .reactants ) == 1 :
425- group = Group (). from_adjacency_list ( get_group_adjlist ( self .groups_as_lines , entry_label = self .reactants [0 ][0 ]))
475+ group = self .groups [ self .reactants [0 ][0 ]]
426476 groups = group .split ()
427477 return len (groups )
428478 else :
@@ -523,7 +573,7 @@ def determine_possible_reaction_products_from_family(rxn: ARCReaction,
523573 and whether the family's template also represents its own reverse.
524574 """
525575 product_dicts = list ()
526- family = ReactionFamily (label = family_label , consider_arc_families = consider_arc_families )
576+ family = get_reaction_family (label = family_label , consider_arc_families = consider_arc_families )
527577 products = family .generate_products (reactants = rxn .get_reactants_and_products (return_copies = True )[0 ])
528578 if products :
529579 for group_labels , product_lists in products .items ():
@@ -765,11 +815,10 @@ def is_reversible(groups_as_lines: list[str]) -> bool:
765815 Returns:
766816 bool: Whether the reaction family is reversible.
767817 """
768- for line in groups_as_lines :
769- if 'reversible = True' in line :
770- return True
771- if 'reversible = False' in line :
772- return False
818+ groups_str = '' .join (groups_as_lines )
819+ match = REVERSIBLE_PATTERN .search (groups_str )
820+ if match :
821+ return match .group (1 ) == 'True'
773822 return True
774823
775824
@@ -780,36 +829,41 @@ def is_own_reverse(groups_as_lines: list[str]) -> bool:
780829 Returns:
781830 bool: Whether the reaction family's template also represents its own reverse.
782831 """
783- for line in groups_as_lines :
784- if 'ownReverse=True' in line :
785- return True
786- if 'ownReverse=False' in line :
787- return False
832+ groups_str = '' .join (groups_as_lines )
833+ match = OWN_REVERSE_PATTERN .search (groups_str )
834+ if match :
835+ return match .group (1 ) == 'True'
788836 return False
789837
790838
791- def get_reactant_groups_from_template (groups_as_lines : list [str ]) -> list [list [str ]]:
839+ def get_reactant_groups_from_template (groups_as_lines : list [str ],
840+ entries : dict [str , str ] | None = None ,
841+ ) -> list [list [str ]]:
792842 """
793843 Get the reactant groups from a template content string.
794844 Descends the entries if a group is defined as an OR complex,
795845 e.g.: group = "OR{Xtrirad_H, Xbirad_H, Xrad_H, X_H}"
796846
797847 Args:
798848 groups_as_lines (list[str]): The template content string.
849+ entries (dict[str, str], optional): Pre-extracted entries.
799850
800851 Returns:
801852 list[list[str]]: The non-complex reactant groups.
802853 """
803854 reactant_labels = get_initial_reactant_labels_from_template (groups_as_lines )
855+ if entries is None :
856+ entries = get_entries (groups_as_lines , entry_labels = reactant_labels )
804857 result = list ()
805858 for reactant_label in reactant_labels :
806- if 'OR{' not in get_group_adjlist (groups_as_lines , entry_label = reactant_label ):
859+ adj = get_group_adjlist (groups_as_lines , entry_label = reactant_label , entries = entries )
860+ if 'OR{' not in adj :
807861 result .append ([reactant_label ])
808862 else :
809863 stack = [reactant_label ]
810- while any ('OR{' in get_group_adjlist (groups_as_lines , entry_label = label ) for label in stack ):
864+ while any ('OR{' in get_group_adjlist (groups_as_lines , entry_label = label , entries = entries ) for label in stack ):
811865 label = stack .pop (0 )
812- group_adjlist = get_group_adjlist (groups_as_lines , entry_label = label )
866+ group_adjlist = get_group_adjlist (groups_as_lines , entry_label = label , entries = entries )
813867 if 'OR{' not in group_adjlist :
814868 stack .append (label )
815869 else :
@@ -851,7 +905,7 @@ def descent_complex_group(group: str) -> list[str]:
851905 list[str]: The non-complex reactant group labels, e.g.: ['Xtrirad_H', 'Xbirad_H', 'Xrad_H', 'X_H'].
852906 """
853907 if group .startswith ('OR{' ) and group .endswith ('}' ):
854- group = [g .strip () for g in group [3 :- 1 ].split (',' )]
908+ group = [c .strip () for c in group [3 :- 1 ].split (',' )]
855909 if isinstance (group , str ):
856910 group = [group ]
857911 return group
@@ -871,13 +925,15 @@ def get_initial_reactant_labels_from_template(groups_as_lines: list[str],
871925 Returns:
872926 list[str]: The reactant groups.
873927 """
874- labels = list ()
875- for line in groups_as_lines :
876- match = re .search (r'products=\[(.*?)\]' , line ) if products else re .search (r'reactants=\[(.*?)\]' , line )
877- if match :
878- labels = match .group (1 ).replace ('"' , '' ).split (', ' )
879- break
880- return labels
928+ groups_str = '' .join (groups_as_lines )
929+ pattern = PRODUCTS_PATTERN if products else REACTANTS_PATTERN
930+ match = pattern .search (groups_str )
931+ if match :
932+ content = match .group (1 )
933+ # Use regex to find all quoted strings (with backreferences) or unquoted words
934+ labels = re .findall (r'(["\'])(.*?)\1|(\w+)' , content )
935+ return [label [1 ] or label [2 ] for label in labels ]
936+ return list ()
881937
882938
883939def get_recipe_actions (groups_as_lines : list [str ]) -> list [list [str ]]:
@@ -982,43 +1038,77 @@ def split_entries(groups_str: str) -> list[str]:
9821038
9831039def get_entries (groups_as_lines : list [str ],
9841040 entry_labels : list [str ],
1041+ recursive : bool = False ,
9851042 ) -> dict [str , str ]:
9861043 """
987- Get the requested entries grom a template content string.
1044+ Get the requested entries from a template content string.
9881045
9891046 Args:
9901047 groups_as_lines (list[str]): The template content string.
991- entry_labels (list[str]): The entry labels to extract.
1048+ entry_labels (list[str], optional): The entry labels to extract. If None, all entries are extracted.
1049+ recursive (bool, optional): Whether to recursively extract child entries for OR complexes.
9921050
9931051 Returns:
9941052 dict[str, str]: The extracted entries, keys are the labels, values are the groups.
9951053 """
996- groups_str = '' .join (groups_as_lines )
997- entries = split_entries (groups_str )
998- specific_entries = dict ()
999- for i , entry in enumerate (entries ):
1000- label_match = re .search (r'label\s*=\s*"(.*?)"' , entry )
1001- group_match = re .search (r'group\s*=(.*?)(?=\w+\s*=)' , entry , re .DOTALL )
1002- if label_match is not None and group_match is not None and label_match .group (1 ) in entry_labels :
1003- specific_entries [label_match .group (1 )] = clean_text (group_match .group (1 ))
1004- if i > 2000 :
1005- break
1006- return specific_entries
1054+ groups_str = "\n " + "" .join (groups_as_lines )
1055+ # Split by `entry(` but keep the delimiter-ish part
1056+ parts = re .split (r"\nentry\s*\(" , groups_str )
1057+
1058+ temp_entries = {}
1059+ label_pat = re .compile (r"label\s*=\s*(?:([\"'])(.*?)\1|(\w+))" )
1060+ group_pat = re .compile (r"group\s*=\s*(?:\"\"\"(.*?)\"\"\"|([\"'])(.*?)\2|(OR\{.*?\}))" , re .DOTALL )
1061+
1062+ for part in parts [1 :]: # Skip the header
1063+ label_match = label_pat .search (part )
1064+ group_match = group_pat .search (part )
1065+ if label_match and group_match :
1066+ label = label_match .group (2 ) or label_match .group (3 )
1067+ # Extract the matched regex group (1 for triple quotes, 3 for single/double quotes, 4 for OR complex)
1068+ adj = group_match .group (1 ) or group_match .group (3 ) or group_match .group (4 )
1069+ temp_entries [label ] = clean_text (adj )
1070+
1071+ if entry_labels is None :
1072+ return temp_entries
1073+
1074+ all_entries = {}
1075+ to_process = list (entry_labels )
1076+ processed = set ()
1077+ while to_process :
1078+ label = to_process .pop ()
1079+ if label in processed or label not in temp_entries :
1080+ continue
1081+ processed .add (label )
1082+ adj = temp_entries [label ]
1083+ if recursive and 'OR{' in adj :
1084+ # Match OR{label1, label2, ...}
1085+ or_match = re .search (r'OR\s*\{\s*(.*?)\s*\}' , adj , re .DOTALL )
1086+ if or_match :
1087+ children_str = or_match .group (1 )
1088+ children = [c .strip () for c in children_str .split (',' )]
1089+ to_process .extend (children )
1090+ else :
1091+ all_entries [label ] = adj
1092+ return all_entries
10071093
10081094
10091095def get_group_adjlist (groups_as_lines : list [str ],
10101096 entry_label : str ,
1097+ entries : dict [str , str ] | None = None ,
10111098 ) -> str :
10121099 """
10131100 Get the corresponding group value for the given entry label.
10141101
10151102 Args:
10161103 groups_as_lines (list[str]): The template content string.
10171104 entry_label (str): The entry label to extract.
1105+ entries (dict[str, str], optional): Pre-extracted entries.
10181106
10191107 Returns:
10201108 str: The extracted group.
10211109 """
1110+ if entries is not None and entry_label in entries :
1111+ return entries [entry_label ]
10221112 specific_entries = get_entries (groups_as_lines , entry_labels = [entry_label ])
10231113 return specific_entries [entry_label ]
10241114
0 commit comments