Skip to content

Commit 1aeea47

Browse files
committed
NMD
1 parent 1d1c42c commit 1aeea47

1 file changed

Lines changed: 115 additions & 26 deletions

File tree

arc/checks/nmd.py

Lines changed: 115 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,15 @@
1515
if TYPE_CHECKING:
1616
from arc.job.adapter import JobAdapter
1717
from arc.reaction import ARCReaction
18-
from rmgpy.molecule.molecule import Molecule
18+
from arc.molecule.molecule import Molecule
1919

2020
logger = get_logger()
2121

22+
# Module-level constants for NMD validation
23+
# A bond mapping is accepted when the smallest checked bond-length change lies more than
# this many spread units above the baseline of the non-reactive bonds (``sigma > SIGMA_THRESHOLD``).
SIGMA_THRESHOLD = 3.0
24+
# Lower bound applied to the non-reactive spread (``std = max(std, STD_FLOOR)``) before it is
# used as the divisor when computing sigma.
STD_FLOOR = 1e-4
25+
# Default ``min_delta`` of ``check_bond_directionality()``: bonds whose absolute signed
# length difference between the two displaced geometries is not above this value are ignored.
DIRECTIONALITY_MIN_DELTA = 0.005
26+
2227

2328
def analyze_ts_normal_mode_displacement(reaction: 'ARCReaction',
2429
job: Optional['JobAdapter'],
@@ -80,6 +85,63 @@ def analyze_ts_normal_mode_displacement(reaction: 'ARCReaction',
8085
return False
8186

8287

88+
def check_bond_directionality(formed_bonds: List[Tuple[int, int]],
                              broken_bonds: List[Tuple[int, int]],
                              xyzs: Tuple[dict, dict],
                              weights: Optional[np.array] = None,
                              min_delta: float = DIRECTIONALITY_MIN_DELTA,
                              ) -> bool:
    """
    Verify that forming and breaking bonds respond in opposite senses to the normal mode displacement.

    The signed length difference of every bond is taken between the two displaced geometries
    (``xyzs[0]`` minus ``xyzs[1]``). The mode passes if all formed bonds share one sign,
    all broken bonds share one sign, and the formed sign differs from the broken sign.
    Bonds for which a length could not be determined, or whose absolute difference
    does not exceed ``min_delta``, do not take part in the check.

    Args:
        formed_bonds (List[Tuple[int, int]]): The bonds that are formed in the reaction.
        broken_bonds (List[Tuple[int, int]]): The bonds that are broken in the reaction.
        xyzs (Tuple[dict, dict]): The Cartesian coordinates of the TS displaced along the normal mode.
        weights (np.array, optional): The weights for the atoms.
        min_delta (float): Minimum absolute signed difference for a bond to participate in the check.

    Returns:
        bool: Whether the bond directionality is consistent with a reactive mode.
    """
    if not (formed_bonds or broken_bonds):
        # Nothing to compare, so the check passes trivially.
        return True

    def _significant_changes(bond_list):
        """Return the signed length changes of ``bond_list`` that exceed ``min_delta`` in magnitude."""
        length_pairs = [(get_bond_length_in_reaction(bond=bond, xyz=xyzs[0], weights=weights),
                         get_bond_length_in_reaction(bond=bond, xyz=xyzs[1], weights=weights))
                        for bond in bond_list]
        changes = [first - second for first, second in length_pairs
                   if first is not None and second is not None]
        return [change for change in changes if abs(change) > min_delta]

    formed_changes = _significant_changes(formed_bonds)
    broken_changes = _significant_changes(broken_bonds)

    # Within each group, every bond must change in the same sense as the first one.
    for group in (formed_changes, broken_changes):
        if not group:
            continue
        reference_sign = np.sign(group[0])
        if not all(np.sign(change) == reference_sign for change in group):
            return False

    # Across the groups, forming and breaking bonds must change in opposite senses.
    if formed_changes and broken_changes \
            and np.sign(formed_changes[0]) == np.sign(broken_changes[0]):
        return False

    return True
143+
144+
83145
def is_nmd_correct_for_any_mapping(reaction: 'ARCReaction',
84146
xyzs: Tuple[dict, dict],
85147
formed_bonds: List[Tuple[int, int]],
@@ -111,32 +173,55 @@ def is_nmd_correct_for_any_mapping(reaction: 'ARCReaction',
111173
r_eq_atoms=r_eq_atoms,
112174
)
113175
for eq_formed_bonds, eq_broken_bonds, eq_changed_bonds in modified_bond_grand_list:
114-
reactive_bonds_diffs, report = get_bond_length_changes(bonds=eq_formed_bonds + eq_broken_bonds + eq_changed_bonds,
115-
xyzs=xyzs,
116-
weights=weights,
117-
amplitude=amplitude,
118-
return_none_if_change_is_insignificant=True,
119-
considered_reactive=True,
120-
)
121-
if reactive_bonds_diffs is None:
176+
# Directionality gate: formed and broken bonds must move in opposite directions
177+
if not check_bond_directionality(formed_bonds=eq_formed_bonds,
178+
broken_bonds=eq_broken_bonds,
179+
xyzs=xyzs,
180+
weights=weights):
181+
continue
182+
183+
# Separate primary (formed + broken) from secondary (changed-order) bonds
184+
primary_bonds = eq_formed_bonds + eq_broken_bonds
185+
is_isomerization = not primary_bonds and bool(eq_changed_bonds)
186+
187+
if is_isomerization:
188+
check_bonds = eq_changed_bonds
189+
elif primary_bonds:
190+
check_bonds = primary_bonds
191+
else:
192+
continue
193+
194+
check_diffs, report = get_bond_length_changes(bonds=check_bonds,
195+
xyzs=xyzs,
196+
weights=weights,
197+
amplitude=amplitude,
198+
return_none_if_change_is_insignificant=True,
199+
considered_reactive=True,
200+
)
201+
if check_diffs is None or len(check_diffs) == 0:
122202
continue
203+
123204
r_bonds, _ = reaction.get_bonds(r_bonds_only=True)
124-
non_reactive_bonds = list()
125-
for bond in r_bonds:
126-
if bond not in eq_formed_bonds and bond not in eq_broken_bonds and bond not in eq_changed_bonds:
127-
non_reactive_bonds.append(bond)
205+
non_reactive_bonds = [bond for bond in r_bonds
206+
if bond not in eq_formed_bonds
207+
and bond not in eq_broken_bonds
208+
and bond not in eq_changed_bonds]
128209
baseline, std = get_bond_length_changes_baseline_and_std(non_reactive_bonds=non_reactive_bonds,
129210
xyzs=xyzs,
130211
weights=weights,
131212
)
213+
214+
min_check_diff = float(np.min(check_diffs))
215+
132216
if baseline is None:
217+
# No non-reactive bonds (tiny molecule): use absolute displacement threshold
218+
if min_check_diff > 0.02:
219+
return True
133220
continue
134221

135-
min_reactive_bond_diff = np.min(reactive_bonds_diffs)
136-
std = std or max(min_reactive_bond_diff * 1e-2, 1e-8)
137-
sigma = (min_reactive_bond_diff - baseline) / std
138-
if sigma > 2.5:
139-
# print(f'V {report} {baseline[0]:.2e} {std:.2e} {sigma[0]:.2e}') # left for debugging
222+
std = max(std, STD_FLOOR)
223+
sigma = (min_check_diff - baseline) / std
224+
if sigma > SIGMA_THRESHOLD:
140225
return True
141226
return False
142227

@@ -285,27 +370,31 @@ def get_bond_length_changes_baseline_and_std(non_reactive_bonds: List[Tuple[int,
285370
weights: Optional[np.array] = None,
286371
) -> Tuple[Optional[float], Optional[float]]:
287372
"""
288-
Get the baseline for bond length change of all non-reactive bonds.
373+
Get the baseline and spread of bond length changes for non-reactive bonds using robust statistics.
374+
375+
Uses the median as the baseline and MAD (median absolute deviation) scaled by 1.4826
376+
as the spread estimator, which is robust against outliers from floppy rotors.
289377
290378
Todo:
291379
When we have a comprehensive list of atom maps, we can pass the reaction and the atom map number to use, and do:
292380
non_reactive_bonds = set(r_bonds) & set(p_bonds)
293381
294382
Args:
295-
non_reactive_bonds (Set[Tuple[int, int]]): The non-reactive bonds.
383+
non_reactive_bonds (List[Tuple[int, int]]): The non-reactive bonds.
296384
xyzs (Tuple[dict, dict]): The Cartesian coordinates of the TS displaced along the normal mode.
297385
weights (np.array): The weights for the atoms.
298386
299387
Returns:
300-
Tuple[float, float]:
301-
- The max baseline of bond length differences for non-reactive bonds.
302-
- The standard deviation of bond length differences for non-reactive bonds.
388+
Tuple[Optional[float], Optional[float]]:
389+
- The median baseline of bond length differences for non-reactive bonds.
390+
- The MAD-based spread estimate of bond length differences for non-reactive bonds.
303391
"""
304392
diffs, _ = get_bond_length_changes(bonds=non_reactive_bonds, xyzs=xyzs, weights=weights)
305-
if diffs is None:
393+
if diffs is None or len(diffs) == 0:
306394
return None, None
307-
baseline = sum(diffs) / len(diffs)
308-
std = np.std(diffs)
395+
baseline = float(np.median(diffs))
396+
mad = float(np.median(np.abs(diffs - baseline)))
397+
std = mad * 1.4826 # scale factor for consistency with normal distribution
309398
return baseline, std
310399

311400

0 commit comments

Comments
 (0)