1010
1111import bisect
1212import glob
13+ import heapq
1314import json
1415import logging
1516import os
@@ -92,8 +93,8 @@ def parse_xyz(filepath: str) -> tuple[list[str], np.ndarray]:
9293 return symbols , np .array (coords_raw , dtype = float )
9394
9495def distance_matrix (coords : np .ndarray ) -> np .ndarray :
95- diff = coords [:, np . newaxis , :] - coords [ np . newaxis , :, :]
96- return np . sqrt (( diff ** 2 ). sum ( axis = - 1 ) )
96+ # cdist avoids the (N,N,3) intermediate array produced by manual broadcasting.
97+ return cdist ( coords , coords )
9798
9899
99100# ===========================================================================
@@ -110,6 +111,15 @@ class StructureChecker:
110111 # Relative tolerance for declaring two eigenvalues degenerate.
111112 _DEGENERACY_REL_TOL : float = 0.02
112113
114+ # The 4 proper rotations (det=+1) from PCA sign-flip ambiguity.
115+ # Built once at class definition time instead of on every call.
116+ _SIGN_FLIP_MATS : tuple [np .ndarray , ...] = (
117+ np .diag ([ 1.0 , 1.0 , 1.0 ]),
118+ np .diag ([- 1.0 , - 1.0 , 1.0 ]),
119+ np .diag ([- 1.0 , 1.0 , - 1.0 ]),
120+ np .diag ([ 1.0 , - 1.0 , - 1.0 ]),
121+ )
122+
113123 def __init__ (self , rmsd_threshold : float = 0.30 ) -> None :
114124 self .rmsd_threshold = rmsd_threshold
115125
@@ -194,16 +204,34 @@ def _try_candidates(
194204 sym_a : list [str ], ca : np .ndarray ,
195205 sym_b : list [str ], cb : np .ndarray ,
196206 ) -> float :
197- """Evaluate every rotation candidate and return the minimum RMSD found."""
207+ """Evaluate every rotation candidate and return the minimum RMSD found.
208+
209+ Element-to-index groupings are precomputed once before the rotation
210+ loop so that the O(N) list comprehensions inside ``_optimal_mapping``
211+ are not repeated for every candidate.
212+ """
213+ # --- Precompute element groups once (not per rotation) ---
214+ elem_groups_a : dict [str , np .ndarray ] = {}
215+ for i , s in enumerate (sym_a ):
216+ elem_groups_a .setdefault (s , []).append (i ) # type: ignore[arg-type]
217+ groups_a = {e : np .array (v , dtype = np .intp ) for e , v in elem_groups_a .items ()}
218+
219+ elem_groups_b : dict [str , np .ndarray ] = {}
220+ for i , s in enumerate (sym_b ):
221+ elem_groups_b .setdefault (s , []).append (i ) # type: ignore[arg-type]
222+ groups_b = {e : np .array (v , dtype = np .intp ) for e , v in elem_groups_b .items ()}
223+
198224 min_rmsd = float ("inf" )
199225 for R in candidates :
200226 cb_rot = cb @ R .T
201- perm = self ._optimal_mapping ( sym_a , ca , sym_b , cb_rot )
227+ perm = self ._optimal_mapping_fast ( ca , cb_rot , groups_a , groups_b )
202228 if perm is None :
203229 continue
204230 rmsd = self ._kabsch_rmsd (ca , cb_rot [perm ])
205231 if rmsd < min_rmsd :
206232 min_rmsd = rmsd
233+ if min_rmsd < self .rmsd_threshold :
234+ return min_rmsd # Early exit: threshold already met
207235 return min_rmsd
208236
209237 # ------------------------------------------------------------------ #
@@ -246,19 +274,14 @@ def _pca_align(coords: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
246274 # Rotation candidates #
247275 # ------------------------------------------------------------------ #
248276
249- @staticmethod
250- def _sign_flip_candidates () -> list [np .ndarray ]:
277+ @classmethod
278+ def _sign_flip_candidates (cls ) -> list [np .ndarray ]:
251279 """
252280 The 4 proper rotations (det = +1) arising from sign-flip ambiguity
253281 of PCA eigenvectors. Always necessary; sufficient when no
254282 eigenvalue degeneracy is present.
255283 """
256- return [
257- np .diag ([ 1.0 , 1.0 , 1.0 ]),
258- np .diag ([- 1.0 , - 1.0 , 1.0 ]),
259- np .diag ([- 1.0 , 1.0 , - 1.0 ]),
260- np .diag ([ 1.0 , - 1.0 , - 1.0 ]),
261- ]
284+ return list (cls ._SIGN_FLIP_MATS )
262285
263286 @classmethod
264287 def _build_planar_candidates (
@@ -323,25 +346,54 @@ def _so3_grid(n: int) -> list[np.ndarray]:
323346 so that grid points are roughly uniformly distributed on S².
324347
325348 Total: n³ rotation matrices (512 for n = 8).
326- """
327- rotations : list [np .ndarray ] = []
328- for i in range (n ):
329- alpha = 2 * np .pi * i / n
330- ca , sa = np .cos (alpha ), np .sin (alpha )
331- Rz_alpha = np .array ([[ca , - sa , 0.0 ], [sa , ca , 0.0 ], [0.0 , 0.0 , 1.0 ]])
332-
333- for j in range (n ):
334- beta = np .arccos (np .clip (1.0 - 2.0 * (j + 0.5 ) / n , - 1.0 , 1.0 ))
335- cb , sb = np .cos (beta ), np .sin (beta )
336- Ry_beta = np .array ([[cb , 0.0 , sb ], [0.0 , 1.0 , 0.0 ], [- sb , 0.0 , cb ]])
337349
338- for k in range (n ):
339- gamma = 2 * np .pi * k / n
340- cg , sg = np .cos (gamma ), np .sin (gamma )
341- Rz_gamma = np .array ([[cg , - sg , 0.0 ], [sg , cg , 0.0 ], [0.0 , 0.0 , 1.0 ]])
342- rotations .append (Rz_alpha @ Ry_beta @ Rz_gamma )
343-
344- return rotations
350+ Vectorised: all n³ matrix products are computed with batched numpy
351+ operations instead of a triple Python for-loop.
352+ """
353+ # --- Angle arrays ---
354+ k = np .arange (n )
355+ alphas = 2.0 * np .pi * k / n # (n,)
356+ betas = np .arccos (np .clip (1.0 - 2.0 * (k + 0.5 ) / n , - 1.0 , 1.0 )) # (n,)
357+ gammas = 2.0 * np .pi * k / n # (n,)
358+
359+ # --- Component trig values ---
360+ ca , sa = np .cos (alphas ), np .sin (alphas ) # (n,)
361+ cb , sb = np .cos (betas ), np .sin (betas ) # (n,)
362+ cg , sg = np .cos (gammas ), np .sin (gammas ) # (n,)
363+
364+ # --- Batched Rz(alpha): shape (n, 3, 3) ---
365+ zero = np .zeros (n )
366+ one = np .ones (n )
367+ Rz_a = np .stack ([
368+ np .stack ([ ca , - sa , zero ], axis = - 1 ),
369+ np .stack ([ sa , ca , zero ], axis = - 1 ),
370+ np .stack ([zero , zero , one ], axis = - 1 ),
371+ ], axis = 1 ) # (n, 3, 3)
372+
373+ # --- Batched Ry(beta): shape (n, 3, 3) ---
374+ Ry_b = np .stack ([
375+ np .stack ([ cb , zero , sb ], axis = - 1 ),
376+ np .stack ([zero , one , zero ], axis = - 1 ),
377+ np .stack ([- sb , zero , cb ], axis = - 1 ),
378+ ], axis = 1 ) # (n, 3, 3)
379+
380+ # --- Batched Rz(gamma): shape (n, 3, 3) ---
381+ Rz_g = np .stack ([
382+ np .stack ([ cg , - sg , zero ], axis = - 1 ),
383+ np .stack ([ sg , cg , zero ], axis = - 1 ),
384+ np .stack ([zero , zero , one ], axis = - 1 ),
385+ ], axis = 1 ) # (n, 3, 3)
386+
387+ # --- ZYZ product over all (n, n, n) combinations ---
388+ # Rz_a[:, None, None] @ Ry_b[None, :, None] @ Rz_g[None, None, :]
389+ # Broadcasting shapes: (n,1,1,3,3) @ (1,n,1,3,3) @ (1,1,n,3,3)
390+ Rza = Rz_a [:, None , None ] # (n, 1, 1, 3, 3)
391+ Ryb = Ry_b [None , :, None ] # (1, n, 1, 3, 3)
392+ Rzg = Rz_g [None , None , :] # (1, 1, n, 3, 3)
393+ R_all = Rza @ Ryb @ Rzg # (n, n, n, 3, 3)
394+
395+ # Flatten to list of (3,3) matrices.
396+ return list (R_all .reshape (- 1 , 3 , 3 ))
345397
346398 @staticmethod
347399 def _Rx (t : float ) -> np .ndarray :
@@ -357,13 +409,36 @@ def _Rz(t: float) -> np.ndarray:
357409 # Atom mapping (Hungarian algorithm) #
358410 # ------------------------------------------------------------------ #
359411
412+ @staticmethod
413+ def _optimal_mapping_fast (
414+ coords_a : np .ndarray ,
415+ coords_b : np .ndarray ,
416+ groups_a : dict [str , np .ndarray ],
417+ groups_b : dict [str , np .ndarray ],
418+ ) -> list [int ] | None :
419+ """Find the atom permutation of B minimising total squared distance to A.
420+
421+ Accepts precomputed element-to-index arrays (``groups_a``, ``groups_b``)
422+ so that the grouping step is not repeated for every rotation candidate.
423+ Returns ``None`` if stoichiometry is inconsistent.
424+ """
425+ perm : list [int | None ] = [None ] * sum (len (v ) for v in groups_a .values ())
426+ for elem , idx_a in groups_a .items ():
427+ idx_b = groups_b .get (elem )
428+ if idx_b is None or len (idx_a ) != len (idx_b ):
429+ return None
430+ cost = cdist (coords_a [idx_a ], coords_b [idx_b ], metric = "sqeuclidean" )
431+ row_ind , col_ind = linear_sum_assignment (cost )
432+ for r , c in zip (row_ind , col_ind ):
433+ perm [idx_a [r ]] = idx_b [c ]
434+ return None if None in perm else perm # type: ignore[return-value]
435+
360436 @staticmethod
361437 def _optimal_mapping (
362438 sym_a : list [str ], coords_a : np .ndarray ,
363439 sym_b : list [str ], coords_b : np .ndarray ,
364440 ) -> list [int ] | None :
365- """
366- Find the permutation of B's atoms that minimises the total
441+ """Find the permutation of B's atoms that minimises the total
367442 squared distance to A, solved independently per element.
368443 Returns None if stoichiometry is inconsistent.
369444 """
@@ -443,17 +518,41 @@ def fingerprint(
443518 Each key is a ``(elem_a, elem_b)`` tuple with elements in sorted
444519 order (so C–H and H–C map to the same key). The value is the
445520 number of such bonds.
521+
522+ Radii are precomputed per unique element so ``_bond_threshold`` is
523+ called at most once per element instead of once per atom-pair.
524+ Distances are computed with ``cdist`` to avoid a Python-level O(N²)
525+ loop.
446526 """
447527 n = len (symbols )
448- dmat = distance_matrix (coords )
449- counts : dict [tuple [str , str ], int ] = {}
528+ # Precompute covalent radius for each unique element.
529+ unique_elems = set (symbols )
530+ elem_radius : dict [str , float ] = {}
531+ for elem in unique_elems :
532+ if covalent_radii_lib is not None :
533+ try :
534+ elem_radius [elem ] = covalent_radii_lib (elem ) * _BOHR2ANG
535+ except KeyError :
536+ elem_radius [elem ] = 0.75 # generic fallback [Å]
537+ else :
538+ elem_radius [elem ] = 0.75
539+
540+ radii_arr = np .array ([elem_radius [s ] for s in symbols ], dtype = np .float64 )
541+
542+ # Pairwise distances — vectorised via cdist.
543+ dmat = cdist (coords , coords )
544+ ii , jj = np .triu_indices (n , k = 1 )
545+ dists = dmat [ii , jj ]
450546
451- for i in range (n ):
452- for j in range (i + 1 , n ):
453- threshold = self ._bond_threshold (symbols [i ], symbols [j ])
454- if dmat [i , j ] <= threshold :
455- key = (min (symbols [i ], symbols [j ]), max (symbols [i ], symbols [j ]))
456- counts [key ] = counts .get (key , 0 ) + 1
547+ # Per-pair bonding threshold.
548+ thresholds = self .covalent_margin * (radii_arr [ii ] + radii_arr [jj ])
549+ bonded_idx = np .where (dists <= thresholds )[0 ]
550+
551+ counts : dict [tuple [str , str ], int ] = {}
552+ for k in bonded_idx :
553+ si , sj = symbols [ii [k ]], symbols [jj [k ]]
554+ key = (si , sj ) if si <= sj else (sj , si )
555+ counts [key ] = counts .get (key , 0 ) + 1
457556
458557 return counts
459558
@@ -531,11 +630,15 @@ def compute_priority(self, task):
531630 """
532631
533632 def __init__ (self , rng_seed : int = 42 ) -> None :
633+ # _tasks: canonical list of ExplorationTask objects.
634+ # Kept as a real list so that subclasses (e.g. RCMCQueue) can call
635+ # .sort() and .pop(0) on it directly without breaking.
534636 self ._tasks : list [ExplorationTask ] = []
535- # Parallel list of negated priorities kept in ascending order so that
536- # bisect can locate the insertion point in O(log n) without an O(n)
537- # list comprehension on every push().
538- self ._neg_priorities : list [float ] = []
637+ # _heap: parallel min-heap of (-priority, counter, task) used by the
638+ # base-class push()/pop() for O(log n) insertion and extraction.
639+ # RCMCQueue overrides pop() entirely and never touches _heap.
640+ self ._heap : list [tuple [float , int , ExplorationTask ]] = []
641+ self ._push_counter : int = 0
539642 self ._submitted : set [tuple ] = set ()
540643 self ._rng = np .random .default_rng (rng_seed )
541644
@@ -545,21 +648,27 @@ def push(self, task: ExplorationTask) -> bool:
545648 return False
546649
547650 task .priority = self .compute_priority (task )
548- # _tasks is maintained in descending priority order. _neg_priorities
549- # mirrors it as ascending negated values so bisect can find the correct
550- # insertion index in O(log n) without rebuilding the list each call.
551- neg_p = - task .priority
552- idx = bisect .bisect_right (self ._neg_priorities , neg_p )
553- self ._tasks .insert (idx , task )
554- self ._neg_priorities .insert (idx , neg_p )
651+ # Update both _tasks (for subclass access) and _heap (for base-class pop).
652+ self ._tasks .append (task )
653+ heapq .heappush (self ._heap , (- task .priority , self ._push_counter , task ))
654+ self ._push_counter += 1
555655 self ._submitted .add (key )
556656 return True
557657
558658 def pop (self ) -> ExplorationTask | None :
559- if not self ._tasks :
659+ """Pop the highest-priority task using the heap (O(log n)).
660+
661+ Also removes the task from ``_tasks`` so subclasses that iterate
662+ ``_tasks`` see a consistent state.
663+ """
664+ if not self ._heap :
560665 return None
561- self ._neg_priorities .pop (0 )
562- return self ._tasks .pop (0 )
666+ _ , _ , task = heapq .heappop (self ._heap )
667+ try :
668+ self ._tasks .remove (task ) # O(n) but only called in base-class path
669+ except ValueError :
670+ pass
671+ return task
563672
564673 def should_add (self , node : "EQNode" , reference_energy : float , ** kwargs ) -> bool :
565674 """Decide probabilistically whether to enqueue a node.
@@ -631,9 +740,10 @@ def refresh_priorities(self, ref_e: float | None) -> None:
631740 task .metadata ["delta_E_hartree" ] = eff_e - ref_e
632741 task .priority = self .compute_priority (task )
633742
634- self ._tasks .sort (key = lambda t : t .priority , reverse = True )
635- # Rebuild the parallel negated-priority list to stay in sync after sort.
636- self ._neg_priorities = [- t .priority for t in self ._tasks ]
743+ # Rebuild the heap from the updated _tasks list.
744+ self ._heap = [(- t .priority , i , t ) for i , t in enumerate (self ._tasks )]
745+ heapq .heapify (self ._heap )
746+ self ._push_counter = len (self ._heap )
637747
638748 def export_queue_status (self ) -> list [dict ]:
639749 return [
@@ -1047,15 +1157,23 @@ def to_dict(self) -> dict:
10471157 data [k ] = str (v )
10481158 return data
10491159
1160+ # Sentinel object used by NetworkGraph to distinguish "not yet computed"
1161+ # from a cached value of None (meaning no energy is available).
1162+ _UNSET = object ()
1163+
1164+
10501165class NetworkGraph :
10511166 def __init__ (self ) -> None :
10521167 self ._nodes : dict [int , EQNode ] = {}
10531168 self ._edges : dict [int , TSEdge ] = {}
10541169 self ._node_counter : int = 0
10551170 self ._edge_counter : int = 0
1171+ # Cached reference energy; set to _UNSET when invalidated.
1172+ self ._ref_energy_cache : float | None = _UNSET # type: ignore[assignment]
10561173
10571174 def add_node (self , node : EQNode ) -> None :
10581175 self ._nodes [node .node_id ] = node
1176+ self ._ref_energy_cache = _UNSET # type: ignore[assignment]
10591177
10601178 def get_node (self , node_id : int ) -> EQNode | None :
10611179 return self ._nodes .get (node_id )
@@ -1091,15 +1209,25 @@ def reference_energy(self) -> float | None:
10911209 intended mixed-mode behaviour (see Q2 design decision).
10921210 * When **no** node has free energy, fall back to the minimum
10931211 electronic energy, preserving the original behaviour.
1212+
1213+ The result is cached and automatically invalidated whenever a new
1214+ node is added via :meth:`add_node`.
10941215 """
1216+ if self ._ref_energy_cache is not _UNSET :
1217+ return self ._ref_energy_cache # type: ignore[return-value]
1218+
10951219 free_energies = [
10961220 n .free_energy for n in self ._nodes .values ()
10971221 if n .free_energy is not None
10981222 ]
10991223 if free_energies :
1100- return min (free_energies )
1101- real_energies = [n .energy for n in self ._nodes .values () if n .has_real_energy ]
1102- return min (real_energies ) if real_energies else None
1224+ result : float | None = min (free_energies )
1225+ else :
1226+ real_energies = [n .energy for n in self ._nodes .values () if n .has_real_energy ]
1227+ result = min (real_energies ) if real_energies else None
1228+
1229+ self ._ref_energy_cache = result # type: ignore[assignment]
1230+ return result
11031231
11041232 def save (self , filepath : str ) -> None :
11051233 data = {
@@ -2285,6 +2413,8 @@ def _flush_node_energy_updates(self) -> None:
22852413 )
22862414
22872415 self ._pending_node_updates .clear ()
2416+ # Node energies may have changed; invalidate the cached reference energy.
2417+ self .graph ._ref_energy_cache = _UNSET # type: ignore[assignment]
22882418
22892419 def _find_or_register_node (
22902420 self ,
0 commit comments