1515if TYPE_CHECKING :
1616 from arc .job .adapter import JobAdapter
1717 from arc .reaction import ARCReaction
18- from rmgpy .molecule .molecule import Molecule
18+ from arc .molecule .molecule import Molecule
1919
2020logger = get_logger ()
2121
22+ # Module-level constants for NMD validation
23+ SIGMA_THRESHOLD = 3.0
24+ STD_FLOOR = 1e-4
25+ DIRECTIONALITY_MIN_DELTA = 0.005
26+
2227
2328def analyze_ts_normal_mode_displacement (reaction : 'ARCReaction' ,
2429 job : Optional ['JobAdapter' ],
@@ -80,13 +85,73 @@ def analyze_ts_normal_mode_displacement(reaction: 'ARCReaction',
8085 return False
8186
8287
88+ def check_bond_directionality (formed_bonds : List [Tuple [int , int ]],
89+ broken_bonds : List [Tuple [int , int ]],
90+ xyzs : Tuple [dict , dict ],
91+ min_delta : float = DIRECTIONALITY_MIN_DELTA ,
92+ ) -> bool :
93+ """
94+ Check that formed and broken bonds move in opposite directions along the normal mode.
95+
96+ For a valid TS mode, formed bonds should shorten in one displacement direction
97+ while broken bonds should lengthen in the same direction (and vice versa).
98+ Only bonds with ``|delta| > min_delta`` are checked to avoid false failures from numerical noise.
99+
100+ Args:
101+ formed_bonds (List[Tuple[int, int]]): The bonds that are formed in the reaction.
102+ broken_bonds (List[Tuple[int, int]]): The bonds that are broken in the reaction.
103+ xyzs (Tuple[dict, dict]): The Cartesian coordinates of the TS displaced along the normal mode.
104+ min_delta (float): Minimum absolute signed difference for a bond to participate in the check.
105+
106+ Returns:
107+ bool: Whether the bond directionality is consistent with a reactive mode.
108+ """
109+ if not formed_bonds and not broken_bonds :
110+ return True
111+
112+ def _get_signed_deltas (bonds ):
113+ deltas = []
114+ for bond in bonds :
115+ # Use unweighted distances for directionality to use the min_delta threshold.
116+ d1 = get_bond_length_in_reaction (bond = bond , xyz = xyzs [0 ], weights = None )
117+ d2 = get_bond_length_in_reaction (bond = bond , xyz = xyzs [1 ], weights = None )
118+ if d1 is None or d2 is None :
119+ continue
120+ delta = d1 - d2
121+ if abs (delta ) > min_delta :
122+ deltas .append (delta )
123+ return deltas
124+
125+ deltas_formed = _get_signed_deltas (formed_bonds )
126+ deltas_broken = _get_signed_deltas (broken_bonds )
127+
128+ if not deltas_formed and not deltas_broken :
129+ logger .debug ('check_bond_directionality: no bonds exceeded min_delta threshold; '
130+ 'returning True (vacuously consistent).' )
131+
132+ # Check internal consistency: all formed bonds must move in the same direction
133+ if deltas_formed and not all (np .sign (d ) == np .sign (deltas_formed [0 ]) for d in deltas_formed ):
134+ return False
135+
136+ # Check internal consistency: all broken bonds must move in the same direction
137+ if deltas_broken and not all (np .sign (d ) == np .sign (deltas_broken [0 ]) for d in deltas_broken ):
138+ return False
139+
140+ # Check that formed and broken bonds move in opposite directions
141+ if deltas_formed and deltas_broken :
142+ if np .sign (deltas_formed [0 ]) == np .sign (deltas_broken [0 ]):
143+ return False
144+
145+ return True
146+
147+
83148def is_nmd_correct_for_any_mapping (reaction : 'ARCReaction' ,
84149 xyzs : Tuple [dict , dict ],
85150 formed_bonds : List [Tuple [int , int ]],
86151 broken_bonds : List [Tuple [int , int ]],
87152 changed_bonds : List [Tuple [int , int ]],
88153 r_eq_atoms : List [List [int ]],
89- weights : np .array ,
154+ weights : Optional [ np .ndarray ] ,
90155 amplitude : float ,
91156 ) -> bool :
92157 """
@@ -111,32 +176,60 @@ def is_nmd_correct_for_any_mapping(reaction: 'ARCReaction',
111176 r_eq_atoms = r_eq_atoms ,
112177 )
113178 for eq_formed_bonds , eq_broken_bonds , eq_changed_bonds in modified_bond_grand_list :
114- reactive_bonds_diffs , report = get_bond_length_changes (bonds = eq_formed_bonds + eq_broken_bonds + eq_changed_bonds ,
115- xyzs = xyzs ,
116- weights = weights ,
117- amplitude = amplitude ,
118- return_none_if_change_is_insignificant = True ,
119- considered_reactive = True ,
120- )
121- if reactive_bonds_diffs is None :
179+ if not check_bond_directionality (formed_bonds = eq_formed_bonds ,
180+ broken_bonds = eq_broken_bonds ,
181+ xyzs = xyzs ):
182+ continue
183+
184+ primary_bonds = eq_formed_bonds + eq_broken_bonds
185+ is_isomerization = not primary_bonds and bool (eq_changed_bonds )
186+
187+ if is_isomerization :
188+ check_bonds = eq_changed_bonds
189+ elif primary_bonds :
190+ check_bonds = primary_bonds
191+ else :
192+ continue
193+
194+ check_diffs , report = get_bond_length_changes (bonds = check_bonds ,
195+ xyzs = xyzs ,
196+ weights = weights ,
197+ amplitude = amplitude ,
198+ return_none_if_change_is_insignificant = True ,
199+ considered_reactive = True ,
200+ )
201+ if check_diffs is None or len (check_diffs ) == 0 :
122202 continue
203+
123204 r_bonds , _ = reaction .get_bonds (r_bonds_only = True )
124- non_reactive_bonds = list ()
125- for bond in r_bonds :
126- if bond not in eq_formed_bonds and bond not in eq_broken_bonds and bond not in eq_changed_bonds :
127- non_reactive_bonds .append (bond )
205+ reactive_bonds = set (eq_formed_bonds ) | set (eq_broken_bonds ) | set (eq_changed_bonds )
206+ non_reactive_bonds = [bond for bond in r_bonds if bond not in reactive_bonds ]
128207 baseline , std = get_bond_length_changes_baseline_and_std (non_reactive_bonds = non_reactive_bonds ,
129208 xyzs = xyzs ,
130209 weights = weights ,
131210 )
211+
212+ min_check_diff = float (np .min (check_diffs ))
213+
132214 if baseline is None :
215+ # Small molecule case: Get UNWEIGHTED changes to compare against a physical distance
216+ unweighted_diffs , _ = get_bond_length_changes (
217+ bonds = check_bonds ,
218+ xyzs = xyzs ,
219+ weights = None ,
220+ amplitude = amplitude ,
221+ return_none_if_change_is_insignificant = False
222+ )
223+ if unweighted_diffs is not None and len (unweighted_diffs ) > 0 :
224+ # For small molecules without a baseline, check that the minimum reactive
225+ # bond-length change exceeds 10% of the displacement amplitude.
226+ if np .min (unweighted_diffs ) > 0.1 * amplitude :
227+ return True
133228 continue
134229
135- min_reactive_bond_diff = np .min (reactive_bonds_diffs )
136- std = std or max (min_reactive_bond_diff * 1e-2 , 1e-8 )
137- sigma = (min_reactive_bond_diff - baseline ) / std
138- if sigma > 2.5 :
139- # print(f'V {report} {baseline[0]:.2e} {std:.2e} {sigma[0]:.2e}') # left for debugging
230+ std = max (std , STD_FLOOR )
231+ sigma = (min_check_diff - baseline ) / std
232+ if sigma > SIGMA_THRESHOLD :
140233 return True
141234 return False
142235
@@ -285,27 +378,31 @@ def get_bond_length_changes_baseline_and_std(non_reactive_bonds: List[Tuple[int,
285378 weights : Optional [np .array ] = None ,
286379 ) -> Tuple [Optional [float ], Optional [float ]]:
287380 """
288- Get the baseline for bond length change of all non-reactive bonds.
381+ Get the baseline and spread of bond length changes for non-reactive bonds using robust statistics.
382+
383+ Uses the median as the baseline and MAD (median absolute deviation) scaled by 1.4826
384+ as the spread estimator, which is robust against outliers from floppy rotors.
289385
290386 Todo:
291387 When we have a comprehensive list of atom maps, we can pass the reaction and the atom map number to use, and do:
292388 non_reactive_bonds = set(r_bonds) & set(p_bonds)
293389
294390 Args:
295- non_reactive_bonds (Set [Tuple[int, int]]): The non-reactive bonds.
391+ non_reactive_bonds (List [Tuple[int, int]]): The non-reactive bonds.
296392 xyzs (Tuple[dict, dict]): The Cartesian coordinates of the TS displaced along the normal mode.
297393 weights (np.array): The weights for the atoms.
298394
299395 Returns:
300- Tuple[float, float]:
301- - The max baseline of bond length differences for non-reactive bonds.
302- - The standard deviation of bond length differences for non-reactive bonds.
396+ Tuple[Optional[ float], Optional[ float] ]:
397+ - The median baseline of bond length differences for non-reactive bonds.
398+ - The MAD-based spread estimate of bond length differences for non-reactive bonds.
303399 """
304400 diffs , _ = get_bond_length_changes (bonds = non_reactive_bonds , xyzs = xyzs , weights = weights )
305- if diffs is None :
401+ if diffs is None or len ( diffs ) == 0 :
306402 return None , None
307- baseline = sum (diffs ) / len (diffs )
308- std = np .std (diffs )
403+ baseline = float (np .median (diffs ))
404+ mad = float (np .median (np .abs (diffs - baseline )))
405+ std = mad * 1.4826 # scale factor for consistency with normal distribution
309406 return baseline , std
310407
311408
@@ -343,8 +440,9 @@ def get_bond_length_changes(bonds: Union[List[Tuple[int, int]], Set[Tuple[int, i
343440 if r_bond_length is None or p_bond_length is None :
344441 continue
345442 diff = abs (r_bond_length - p_bond_length )
346- if amplitude is not None and return_none_if_change_is_insignificant \
347- and abs (diff * amplitude / r_bond_length ) < 0.05 and abs (diff * amplitude / p_bond_length ) < 0.05 :
443+ if return_none_if_change_is_insignificant \
444+ and abs (diff / r_bond_length ) < 0.05 \
445+ and abs (diff / p_bond_length ) < 0.05 :
348446 return None , None
349447 diffs .append (diff )
350448 if considered_reactive :
0 commit comments