Skip to content

Commit 2da6100

Browse files
committed
Improved the TS NMD check
NMD improvements (arc/checks/nmd.py):
- Replaced the mean-based displacement baseline with median + MAD (median absolute deviation) to insulate the check from floppy rotors.
- Implemented a mandatory directionality check: ensures that formed and broken bonds move in anti-correlated directions along the imaginary mode.
- Separated primary (formed/broken) bonds from secondary (changed-order) bonds in the sigma test to reflect physical displacement scales.
- Set a global numerical noise floor (1e-4 Å) for the Hessian and raised the default validation threshold to 3.0 sigma.
1 parent 756a9ee commit 2da6100

1 file changed

Lines changed: 127 additions & 29 deletions

File tree

arc/checks/nmd.py

Lines changed: 127 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,15 @@
1515
if TYPE_CHECKING:
1616
from arc.job.adapter import JobAdapter
1717
from arc.reaction import ARCReaction
18-
from rmgpy.molecule.molecule import Molecule
18+
from arc.molecule.molecule import Molecule
1919

2020
logger = get_logger()
2121

22+
# Module-level constants for NMD validation
23+
SIGMA_THRESHOLD = 3.0
24+
STD_FLOOR = 1e-4
25+
DIRECTIONALITY_MIN_DELTA = 0.005
26+
2227

2328
def analyze_ts_normal_mode_displacement(reaction: 'ARCReaction',
2429
job: Optional['JobAdapter'],
@@ -80,13 +85,73 @@ def analyze_ts_normal_mode_displacement(reaction: 'ARCReaction',
8085
return False
8186

8287

88+
def check_bond_directionality(formed_bonds: List[Tuple[int, int]],
                              broken_bonds: List[Tuple[int, int]],
                              xyzs: Tuple[dict, dict],
                              min_delta: float = DIRECTIONALITY_MIN_DELTA,
                              ) -> bool:
    """
    Check that formed and broken bonds move in opposite directions along the normal mode.

    For a valid TS mode, formed bonds should shorten in one displacement direction
    while broken bonds should lengthen in the same direction (and vice versa).
    Only bonds with ``|delta| > min_delta`` are checked to avoid false failures from numerical noise,
    where ``delta`` is the bond length in ``xyzs[0]`` minus the bond length in ``xyzs[1]``.

    The check passes when:
        - all considered formed bonds share the same sign of ``delta``,
        - all considered broken bonds share the same sign of ``delta``, and
        - if both groups have considered bonds, the two signs differ.

    If neither group has a bond exceeding ``min_delta`` (or no bonds were given at all),
    there is nothing to contradict and the check passes vacuously.

    Args:
        formed_bonds (List[Tuple[int, int]]): The bonds that are formed in the reaction.
        broken_bonds (List[Tuple[int, int]]): The bonds that are broken in the reaction.
        xyzs (Tuple[dict, dict]): The Cartesian coordinates of the TS displaced along the normal mode.
        min_delta (float): Minimum absolute signed difference for a bond to participate in the check.

    Returns:
        bool: Whether the bond directionality is consistent with a reactive mode.
    """
    if not formed_bonds and not broken_bonds:
        return True

    def _get_signed_deltas(bonds):
        """Return the signed length changes (``xyzs[0]`` minus ``xyzs[1]``) whose magnitude exceeds ``min_delta``."""
        deltas = []
        for bond in bonds:
            # Use unweighted distances for directionality to use the min_delta threshold.
            d1 = get_bond_length_in_reaction(bond=bond, xyz=xyzs[0], weights=None)
            d2 = get_bond_length_in_reaction(bond=bond, xyz=xyzs[1], weights=None)
            if d1 is None or d2 is None:
                # The helper could not provide a length for this bond; skip it.
                continue
            delta = d1 - d2
            if abs(delta) > min_delta:
                deltas.append(delta)
        return deltas

    deltas_formed = _get_signed_deltas(formed_bonds)
    deltas_broken = _get_signed_deltas(broken_bonds)

    if not deltas_formed and not deltas_broken:
        logger.debug('check_bond_directionality: no bonds exceeded min_delta threshold; '
                     'returning True (vacuously consistent).')
        # Return explicitly so the behavior matches the log message instead of
        # relying on every later check happening to be skipped for empty lists.
        return True

    # Collapse each group to the set of distinct displacement signs it contains.
    formed_signs = {float(np.sign(d)) for d in deltas_formed}
    broken_signs = {float(np.sign(d)) for d in deltas_broken}

    # Internal consistency: all formed bonds must move in the same direction,
    # and all broken bonds must move in the same direction.
    if len(formed_signs) > 1 or len(broken_signs) > 1:
        return False

    # Formed and broken bonds must move in opposite directions.
    # At this point each non-empty set holds exactly one sign, so set equality compares those signs.
    if formed_signs and broken_signs and formed_signs == broken_signs:
        return False

    return True
146+
147+
83148
def is_nmd_correct_for_any_mapping(reaction: 'ARCReaction',
84149
xyzs: Tuple[dict, dict],
85150
formed_bonds: List[Tuple[int, int]],
86151
broken_bonds: List[Tuple[int, int]],
87152
changed_bonds: List[Tuple[int, int]],
88153
r_eq_atoms: List[List[int]],
89-
weights: np.array,
154+
weights: Optional[np.ndarray],
90155
amplitude: float,
91156
) -> bool:
92157
"""
@@ -111,32 +176,60 @@ def is_nmd_correct_for_any_mapping(reaction: 'ARCReaction',
111176
r_eq_atoms=r_eq_atoms,
112177
)
113178
for eq_formed_bonds, eq_broken_bonds, eq_changed_bonds in modified_bond_grand_list:
114-
reactive_bonds_diffs, report = get_bond_length_changes(bonds=eq_formed_bonds + eq_broken_bonds + eq_changed_bonds,
115-
xyzs=xyzs,
116-
weights=weights,
117-
amplitude=amplitude,
118-
return_none_if_change_is_insignificant=True,
119-
considered_reactive=True,
120-
)
121-
if reactive_bonds_diffs is None:
179+
if not check_bond_directionality(formed_bonds=eq_formed_bonds,
180+
broken_bonds=eq_broken_bonds,
181+
xyzs=xyzs):
182+
continue
183+
184+
primary_bonds = eq_formed_bonds + eq_broken_bonds
185+
is_isomerization = not primary_bonds and bool(eq_changed_bonds)
186+
187+
if is_isomerization:
188+
check_bonds = eq_changed_bonds
189+
elif primary_bonds:
190+
check_bonds = primary_bonds
191+
else:
192+
continue
193+
194+
check_diffs, report = get_bond_length_changes(bonds=check_bonds,
195+
xyzs=xyzs,
196+
weights=weights,
197+
amplitude=amplitude,
198+
return_none_if_change_is_insignificant=True,
199+
considered_reactive=True,
200+
)
201+
if check_diffs is None or len(check_diffs) == 0:
122202
continue
203+
123204
r_bonds, _ = reaction.get_bonds(r_bonds_only=True)
124-
non_reactive_bonds = list()
125-
for bond in r_bonds:
126-
if bond not in eq_formed_bonds and bond not in eq_broken_bonds and bond not in eq_changed_bonds:
127-
non_reactive_bonds.append(bond)
205+
reactive_bonds = set(eq_formed_bonds) | set(eq_broken_bonds) | set(eq_changed_bonds)
206+
non_reactive_bonds = [bond for bond in r_bonds if bond not in reactive_bonds]
128207
baseline, std = get_bond_length_changes_baseline_and_std(non_reactive_bonds=non_reactive_bonds,
129208
xyzs=xyzs,
130209
weights=weights,
131210
)
211+
212+
min_check_diff = float(np.min(check_diffs))
213+
132214
if baseline is None:
215+
# Small molecule case: Get UNWEIGHTED changes to compare against a physical distance
216+
unweighted_diffs, _ = get_bond_length_changes(
217+
bonds=check_bonds,
218+
xyzs=xyzs,
219+
weights=None,
220+
amplitude=amplitude,
221+
return_none_if_change_is_insignificant=False
222+
)
223+
if unweighted_diffs is not None and len(unweighted_diffs) > 0:
224+
# For small molecules without a baseline, check that the minimum reactive
225+
# bond-length change exceeds 10% of the displacement amplitude.
226+
if np.min(unweighted_diffs) > 0.1 * amplitude:
227+
return True
133228
continue
134229

135-
min_reactive_bond_diff = np.min(reactive_bonds_diffs)
136-
std = std or max(min_reactive_bond_diff * 1e-2, 1e-8)
137-
sigma = (min_reactive_bond_diff - baseline) / std
138-
if sigma > 2.5:
139-
# print(f'V {report} {baseline[0]:.2e} {std:.2e} {sigma[0]:.2e}') # left for debugging
230+
std = max(std, STD_FLOOR)
231+
sigma = (min_check_diff - baseline) / std
232+
if sigma > SIGMA_THRESHOLD:
140233
return True
141234
return False
142235

@@ -285,27 +378,31 @@ def get_bond_length_changes_baseline_and_std(non_reactive_bonds: List[Tuple[int,
285378
weights: Optional[np.array] = None,
286379
) -> Tuple[Optional[float], Optional[float]]:
287380
"""
288-
Get the baseline for bond length change of all non-reactive bonds.
381+
Get the baseline and spread of bond length changes for non-reactive bonds using robust statistics.
382+
383+
Uses the median as the baseline and MAD (median absolute deviation) scaled by 1.4826
384+
as the spread estimator, which is robust against outliers from floppy rotors.
289385
290386
Todo:
291387
When we have a comprehensive list of atom maps, we can pass the reaction and the atom map number to use, and do:
292388
non_reactive_bonds = set(r_bonds) & set(p_bonds)
293389
294390
Args:
295-
non_reactive_bonds (Set[Tuple[int, int]]): The non-reactive bonds.
391+
non_reactive_bonds (List[Tuple[int, int]]): The non-reactive bonds.
296392
xyzs (Tuple[dict, dict]): The Cartesian coordinates of the TS displaced along the normal mode.
297393
weights (np.array): The weights for the atoms.
298394
299395
Returns:
300-
Tuple[float, float]:
301-
- The max baseline of bond length differences for non-reactive bonds.
302-
- The standard deviation of bond length differences for non-reactive bonds.
396+
Tuple[Optional[float], Optional[float]]:
397+
- The median baseline of bond length differences for non-reactive bonds.
398+
- The MAD-based spread estimate of bond length differences for non-reactive bonds.
303399
"""
304400
diffs, _ = get_bond_length_changes(bonds=non_reactive_bonds, xyzs=xyzs, weights=weights)
305-
if diffs is None:
401+
if diffs is None or len(diffs) == 0:
306402
return None, None
307-
baseline = sum(diffs) / len(diffs)
308-
std = np.std(diffs)
403+
baseline = float(np.median(diffs))
404+
mad = float(np.median(np.abs(diffs - baseline)))
405+
std = mad * 1.4826 # scale factor for consistency with normal distribution
309406
return baseline, std
310407

311408

@@ -343,8 +440,9 @@ def get_bond_length_changes(bonds: Union[List[Tuple[int, int]], Set[Tuple[int, i
343440
if r_bond_length is None or p_bond_length is None:
344441
continue
345442
diff = abs(r_bond_length - p_bond_length)
346-
if amplitude is not None and return_none_if_change_is_insignificant \
347-
and abs(diff * amplitude / r_bond_length) < 0.05 and abs(diff * amplitude / p_bond_length) < 0.05:
443+
if return_none_if_change_is_insignificant \
444+
and abs(diff / r_bond_length) < 0.05 \
445+
and abs(diff / p_bond_length) < 0.05:
348446
return None, None
349447
diffs.append(diff)
350448
if considered_reactive:

0 commit comments

Comments
 (0)