@@ -532,6 +532,10 @@ def compute_priority(self, task):
532532
533533 def __init__ (self , rng_seed : int = 42 ) -> None :
534534 self ._tasks : list [ExplorationTask ] = []
535+ # Parallel list of negated priorities kept in ascending order so that
536+ # bisect can locate the insertion point in O(log n) without an O(n)
537+ # list comprehension on every push().
538+ self ._neg_priorities : list [float ] = []
535539 self ._submitted : set [tuple ] = set ()
536540 self ._rng = np .random .default_rng (rng_seed )
537541
@@ -541,17 +545,21 @@ def push(self, task: ExplorationTask) -> bool:
541545 return False
542546
543547 task .priority = self .compute_priority (task )
544- # _tasks is maintained in descending priority order. Since bisect
545- # assumes ascending order, negate the key to find the correct
546- # insertion index in O(log n).
547- keys = [ - t .priority for t in self . _tasks ]
548- idx = bisect .bisect_right (keys , - task . priority )
548+ # _tasks is maintained in descending priority order. _neg_priorities
549+ # mirrors it as ascending negated values so bisect can find the correct
550+ # insertion index in O(log n) without rebuilding the list each call .
551+ neg_p = - task .priority
552+ idx = bisect .bisect_right (self . _neg_priorities , neg_p )
549553 self ._tasks .insert (idx , task )
554+ self ._neg_priorities .insert (idx , neg_p )
550555 self ._submitted .add (key )
551556 return True
552557
553558 def pop (self ) -> ExplorationTask | None :
554- return self ._tasks .pop (0 ) if self ._tasks else None
559+ if not self ._tasks :
560+ return None
561+ self ._neg_priorities .pop (0 )
562+ return self ._tasks .pop (0 )
555563
556564 def should_add (self , node : "EQNode" , reference_energy : float , ** kwargs ) -> bool :
557565 """Decide probabilistically whether to enqueue a node.
@@ -602,6 +610,8 @@ def refresh_priorities(self, ref_e: float | None) -> None:
602610 task .priority = self .compute_priority (task )
603611
604612 self ._tasks .sort (key = lambda t : t .priority , reverse = True )
613+ # Rebuild the parallel negated-priority list to stay in sync after sort.
614+ self ._neg_priorities = [- t .priority for t in self ._tasks ]
605615
606616 def export_queue_status (self ) -> list [dict ]:
607617 return [
@@ -684,6 +694,8 @@ def __init__(self, filepath: str) -> None:
684694 self ._filepath = filepath
685695 # In-memory set for O(1) look-up: (node_id, atom_i, atom_j, gamma_sign)
686696 self ._explored : set [tuple [int , int , int , str ]] = set ()
697+ # Fast O(1) membership check: which node_ids have been explored at all.
698+ self ._explored_node_ids : set [int ] = set ()
687699 self ._load ()
688700
689701 # ------------------------------------------------------------------
@@ -711,6 +723,7 @@ def _load(self) -> None:
711723 if gamma_sign not in ("+" , "-" ):
712724 continue
713725 self ._explored .add ((node_id , atom_i , atom_j , gamma_sign ))
726+ self ._explored_node_ids .add (node_id )
714727 except (ValueError , IndexError ):
715728 continue
716729 logger .info (
@@ -732,6 +745,7 @@ def record(self, node_id: int, atom_i: int, atom_j: int, gamma_sign: str) -> Non
732745 if key in self ._explored :
733746 return
734747 self ._explored .add (key )
748+ self ._explored_node_ids .add (node_id )
735749 with open (self ._filepath , "a" , encoding = "utf-8" ) as fh :
736750 fh .write (f"EQ{ node_id :06d} { atom_i } { atom_j } { gamma_sign } \n " )
737751
@@ -799,36 +813,65 @@ def _build_candidates(
799813 """Return all atom pairs that satisfy the distance and covalency filters.
800814
801815 Pairs are expressed as 0-based index tuples ``(i, j)`` with ``i < j``.
816+
817+ Implementation uses numpy broadcasting to evaluate all N*(N-1)/2 pairs
818+ in a single vectorised pass, avoiding an O(N²) Python-level loop.
802819 """
803820 n = len (symbols )
804821 if n < 2 :
805822 return []
806823
807- dmat = distance_matrix (coords )
808- candidates : list [tuple [int , int ]] = []
809-
810824 # Build the pool of atom indices subject to active_atoms restriction.
811825 # active_atoms stores 1-based labels; convert to 0-based for indexing.
812826 if self .active_atoms is not None :
813- atom_indices = [i for i in range (n ) if (i + 1 ) in self .active_atoms ]
827+ atom_indices = np .array (
828+ [i for i in range (n ) if (i + 1 ) in self .active_atoms ], dtype = np .intp
829+ )
814830 else :
815- atom_indices = list ( range ( n ) )
831+ atom_indices = np . arange ( n , dtype = np . intp )
816832
817- for idx , i in enumerate (atom_indices ):
818- for j in atom_indices [idx + 1 :]:
819- dist = dmat [i , j ]
820- if self .dist_lower_ang <= dist <= self .dist_upper_ang :
821- if covalent_radii_lib is not None :
822- try :
823- r_i = covalent_radii_lib (symbols [i ]) * _BOHR2ANG
824- r_j = covalent_radii_lib (symbols [j ]) * _BOHR2ANG
825- if dist <= self .covalent_margin * (r_i + r_j ):
826- continue
827- except KeyError :
828- pass
829- candidates .append ((i , j ))
833+ if len (atom_indices ) < 2 :
834+ return []
835+
836+ # Restrict coords/symbols to the active subset.
837+ sub_coords = coords [atom_indices ] # (m, 3)
838+ sub_symbols = [symbols [i ] for i in atom_indices ]
839+ m = len (atom_indices )
840+
841+ # Pairwise distances for the active subset — vectorised.
842+ diff = sub_coords [:, np .newaxis , :] - sub_coords [np .newaxis , :, :] # (m,m,3)
843+ dmat = np .sqrt ((diff ** 2 ).sum (axis = - 1 )) # (m,m)
830844
831- return candidates
845+ # Upper-triangle indices (i < j).
846+ ii , jj = np .triu_indices (m , k = 1 ) # each has shape (P,) where P = m*(m-1)/2
847+ dists = dmat [ii , jj ] # (P,)
848+
849+ # ── Distance window filter ────────────────────────────────────────
850+ dist_mask = (dists >= self .dist_lower_ang ) & (dists <= self .dist_upper_ang )
851+
852+ # ── Covalent-bond exclusion filter ────────────────────────────────
853+ if covalent_radii_lib is not None :
854+ # Build per-atom radii array for the active subset, falling back to
855+ # a generic value for unknown elements so the filter still works.
856+ radii = np .empty (m , dtype = np .float64 )
857+ for k , sym in enumerate (sub_symbols ):
858+ try :
859+ radii [k ] = covalent_radii_lib (sym ) * _BOHR2ANG
860+ except KeyError :
861+ radii [k ] = 0.75 # generic fallback [Å]
862+ # Vectorised threshold for every upper-triangle pair.
863+ cov_thresh = self .covalent_margin * (radii [ii ] + radii [jj ]) # (P,)
864+ cov_mask = dists > cov_thresh
865+ else :
866+ # No radii library — apply a fixed generic threshold.
867+ cov_mask = dists > (self .covalent_margin * 1.5 )
868+
869+ keep = dist_mask & cov_mask # (P,) boolean
870+
871+ # Convert back to original (0-based) atom indices.
872+ orig_ii = atom_indices [ii [keep ]]
873+ orig_jj = atom_indices [jj [keep ]]
874+ return list (zip (orig_ii .tolist (), orig_jj .tolist ()))
832875
833876 # ------------------------------------------------------------------
834877 # Public interface
@@ -2124,11 +2167,10 @@ def _find_or_register_node(
21242167 def _node_has_been_explored (self , node_id : int ) -> bool :
21252168 """Return ``True`` if *node_id* has at least one recorded exploration.
21262169
2127- Used to implement "skip-after-first" semantics for excluded nodes:
2128- an excluded node is still allowed one round of AFIR exploration
2129- when it is first registered, but is skipped on all subsequent calls.
2170+ Uses the O(1) ``_explored_node_ids`` set on :class:`ExploredPairsLog`
2171+ instead of a linear scan over all explored tuples.
21302172 """
2131- return any ( nid == node_id for ( nid , * _ ) in self .explored_log ._explored )
2173+ return node_id in self .explored_log ._explored_node_ids
21322174
21332175 def _enqueue_perturbations (self , node : EQNode , force_add : bool = False ) -> None :
21342176 if node .coords .size == 0 :
@@ -2182,17 +2224,21 @@ def _enqueue_perturbations(self, node: EQNode, force_add: bool = False) -> None:
21822224 gamma_signs .append ("-" )
21832225
21842226 # ── Filter out already-explored and already-queued pairs ─────────
2185- unexplored : list [tuple [int , int , str ]] = []
2186- for (i0 , j0 ) in all_candidates :
2187- atom_i , atom_j = i0 + 1 , j0 + 1 # convert to 1-based
2188- for sign in gamma_signs :
2189- if self .explored_log .has (node .node_id , atom_i , atom_j , sign ):
2190- continue # already run — skip
2191- gamma_str = neg_gamma_str if sign == "-" else pos_gamma_str
2192- queue_key = (node .node_id , (gamma_str , str (atom_i ), str (atom_j )))
2193- if queue_key in self .queue ._submitted :
2194- continue # already in the queue — skip
2195- unexplored .append ((i0 , j0 , sign ))
2227+ # Build unexplored list with a single list comprehension to avoid
2228+ # Python-level nested loop overhead. Local variable aliases avoid
2229+ # repeated attribute lookups inside the hot path.
2230+ nid = node .node_id
2231+ explored_set = self .explored_log ._explored # set[(nid,i,j,sign)]
2232+ submitted_set = self .queue ._submitted # set[(nid, tuple(params))]
2233+ sign_to_gamma = {"+" : pos_gamma_str , "-" : neg_gamma_str }
2234+
2235+ unexplored : list [tuple [int , int , str ]] = [
2236+ (i0 , j0 , sign )
2237+ for (i0 , j0 ) in all_candidates
2238+ for sign in gamma_signs
2239+ if (nid , i0 + 1 , j0 + 1 , sign ) not in explored_set
2240+ and (nid , (sign_to_gamma [sign ], str (i0 + 1 ), str (j0 + 1 ))) not in submitted_set
2241+ ]
21962242
21972243 if not unexplored :
21982244 logger .debug (
@@ -2215,7 +2261,7 @@ def _enqueue_perturbations(self, node: EQNode, force_add: bool = False) -> None:
22152261 for idx in chosen_indices :
22162262 i0 , j0 , sign = unexplored [int (idx )]
22172263 atom_i , atom_j = i0 + 1 , j0 + 1
2218- gamma_str = neg_gamma_str if sign == "-" else pos_gamma_str
2264+ gamma_str = sign_to_gamma [ sign ]
22192265 afir_params = [gamma_str , str (atom_i ), str (atom_j )]
22202266 task = ExplorationTask (
22212267 node_id = node .node_id ,
0 commit comments