Skip to content
This repository was archived by the owner on May 5, 2026. It is now read-only.

Commit f9c58ae

Browse files
authored
Add files via upload
1 parent c3c16d0 commit f9c58ae

1 file changed

Lines changed: 87 additions & 41 deletions

File tree

multioptpy/Wrapper/mapper.py

Lines changed: 87 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,10 @@ def compute_priority(self, task):
532532

533533
def __init__(self, rng_seed: int = 42) -> None:
534534
self._tasks: list[ExplorationTask] = []
535+
# Parallel list of negated priorities kept in ascending order so that
536+
# bisect can locate the insertion point in O(log n) without an O(n)
537+
# list comprehension on every push().
538+
self._neg_priorities: list[float] = []
535539
self._submitted: set[tuple] = set()
536540
self._rng = np.random.default_rng(rng_seed)
537541

@@ -541,17 +545,21 @@ def push(self, task: ExplorationTask) -> bool:
541545
return False
542546

543547
task.priority = self.compute_priority(task)
544-
# _tasks is maintained in descending priority order. Since bisect
545-
# assumes ascending order, negate the key to find the correct
546-
# insertion index in O(log n).
547-
keys = [-t.priority for t in self._tasks]
548-
idx = bisect.bisect_right(keys, -task.priority)
548+
# _tasks is maintained in descending priority order. _neg_priorities
549+
# mirrors it as ascending negated values so bisect can find the correct
550+
# insertion index in O(log n) without rebuilding the list each call.
551+
neg_p = -task.priority
552+
idx = bisect.bisect_right(self._neg_priorities, neg_p)
549553
self._tasks.insert(idx, task)
554+
self._neg_priorities.insert(idx, neg_p)
550555
self._submitted.add(key)
551556
return True
552557

553558
def pop(self) -> ExplorationTask | None:
554-
return self._tasks.pop(0) if self._tasks else None
559+
if not self._tasks:
560+
return None
561+
self._neg_priorities.pop(0)
562+
return self._tasks.pop(0)
555563

556564
def should_add(self, node: "EQNode", reference_energy: float, **kwargs) -> bool:
557565
"""Decide probabilistically whether to enqueue a node.
@@ -602,6 +610,8 @@ def refresh_priorities(self, ref_e: float | None) -> None:
602610
task.priority = self.compute_priority(task)
603611

604612
self._tasks.sort(key=lambda t: t.priority, reverse=True)
613+
# Rebuild the parallel negated-priority list to stay in sync after sort.
614+
self._neg_priorities = [-t.priority for t in self._tasks]
605615

606616
def export_queue_status(self) -> list[dict]:
607617
return [
@@ -684,6 +694,8 @@ def __init__(self, filepath: str) -> None:
684694
self._filepath = filepath
685695
# In-memory set for O(1) look-up: (node_id, atom_i, atom_j, gamma_sign)
686696
self._explored: set[tuple[int, int, int, str]] = set()
697+
# Fast O(1) membership check: which node_ids have been explored at all.
698+
self._explored_node_ids: set[int] = set()
687699
self._load()
688700

689701
# ------------------------------------------------------------------
@@ -711,6 +723,7 @@ def _load(self) -> None:
711723
if gamma_sign not in ("+", "-"):
712724
continue
713725
self._explored.add((node_id, atom_i, atom_j, gamma_sign))
726+
self._explored_node_ids.add(node_id)
714727
except (ValueError, IndexError):
715728
continue
716729
logger.info(
@@ -732,6 +745,7 @@ def record(self, node_id: int, atom_i: int, atom_j: int, gamma_sign: str) -> Non
732745
if key in self._explored:
733746
return
734747
self._explored.add(key)
748+
self._explored_node_ids.add(node_id)
735749
with open(self._filepath, "a", encoding="utf-8") as fh:
736750
fh.write(f"EQ{node_id:06d} {atom_i} {atom_j} {gamma_sign}\n")
737751

@@ -799,36 +813,65 @@ def _build_candidates(
799813
"""Return all atom pairs that satisfy the distance and covalency filters.
800814
801815
Pairs are expressed as 0-based index tuples ``(i, j)`` with ``i < j``.
816+
817+
Implementation uses numpy broadcasting to evaluate all N*(N-1)/2 pairs
818+
in a single vectorised pass, avoiding an O(N²) Python-level loop.
802819
"""
803820
n = len(symbols)
804821
if n < 2:
805822
return []
806823

807-
dmat = distance_matrix(coords)
808-
candidates: list[tuple[int, int]] = []
809-
810824
# Build the pool of atom indices subject to active_atoms restriction.
811825
# active_atoms stores 1-based labels; convert to 0-based for indexing.
812826
if self.active_atoms is not None:
813-
atom_indices = [i for i in range(n) if (i + 1) in self.active_atoms]
827+
atom_indices = np.array(
828+
[i for i in range(n) if (i + 1) in self.active_atoms], dtype=np.intp
829+
)
814830
else:
815-
atom_indices = list(range(n))
831+
atom_indices = np.arange(n, dtype=np.intp)
816832

817-
for idx, i in enumerate(atom_indices):
818-
for j in atom_indices[idx + 1:]:
819-
dist = dmat[i, j]
820-
if self.dist_lower_ang <= dist <= self.dist_upper_ang:
821-
if covalent_radii_lib is not None:
822-
try:
823-
r_i = covalent_radii_lib(symbols[i]) * _BOHR2ANG
824-
r_j = covalent_radii_lib(symbols[j]) * _BOHR2ANG
825-
if dist <= self.covalent_margin * (r_i + r_j):
826-
continue
827-
except KeyError:
828-
pass
829-
candidates.append((i, j))
833+
if len(atom_indices) < 2:
834+
return []
835+
836+
# Restrict coords/symbols to the active subset.
837+
sub_coords = coords[atom_indices] # (m, 3)
838+
sub_symbols = [symbols[i] for i in atom_indices]
839+
m = len(atom_indices)
840+
841+
# Pairwise distances for the active subset — vectorised.
842+
diff = sub_coords[:, np.newaxis, :] - sub_coords[np.newaxis, :, :] # (m,m,3)
843+
dmat = np.sqrt((diff ** 2).sum(axis=-1)) # (m,m)
830844

831-
return candidates
845+
# Upper-triangle indices (i < j).
846+
ii, jj = np.triu_indices(m, k=1) # each has shape (P,) where P = m*(m-1)/2
847+
dists = dmat[ii, jj] # (P,)
848+
849+
# ── Distance window filter ────────────────────────────────────────
850+
dist_mask = (dists >= self.dist_lower_ang) & (dists <= self.dist_upper_ang)
851+
852+
# ── Covalent-bond exclusion filter ────────────────────────────────
853+
if covalent_radii_lib is not None:
854+
# Build per-atom radii array for the active subset, falling back to
855+
# a generic value for unknown elements so the filter still works.
856+
radii = np.empty(m, dtype=np.float64)
857+
for k, sym in enumerate(sub_symbols):
858+
try:
859+
radii[k] = covalent_radii_lib(sym) * _BOHR2ANG
860+
except KeyError:
861+
radii[k] = 0.75 # generic fallback [Å]
862+
# Vectorised threshold for every upper-triangle pair.
863+
cov_thresh = self.covalent_margin * (radii[ii] + radii[jj]) # (P,)
864+
cov_mask = dists > cov_thresh
865+
else:
866+
# No radii library — apply a fixed generic threshold.
867+
cov_mask = dists > (self.covalent_margin * 1.5)
868+
869+
keep = dist_mask & cov_mask # (P,) boolean
870+
871+
# Convert back to original (0-based) atom indices.
872+
orig_ii = atom_indices[ii[keep]]
873+
orig_jj = atom_indices[jj[keep]]
874+
return list(zip(orig_ii.tolist(), orig_jj.tolist()))
832875

833876
# ------------------------------------------------------------------
834877
# Public interface
@@ -2124,11 +2167,10 @@ def _find_or_register_node(
21242167
def _node_has_been_explored(self, node_id: int) -> bool:
21252168
"""Return ``True`` if *node_id* has at least one recorded exploration.
21262169
2127-
Used to implement "skip-after-first" semantics for excluded nodes:
2128-
an excluded node is still allowed one round of AFIR exploration
2129-
when it is first registered, but is skipped on all subsequent calls.
2170+
Uses the O(1) ``_explored_node_ids`` set on :class:`ExploredPairsLog`
2171+
instead of a linear scan over all explored tuples.
21302172
"""
2131-
return any(nid == node_id for (nid, *_) in self.explored_log._explored)
2173+
return node_id in self.explored_log._explored_node_ids
21322174

21332175
def _enqueue_perturbations(self, node: EQNode, force_add: bool = False) -> None:
21342176
if node.coords.size == 0:
@@ -2182,17 +2224,21 @@ def _enqueue_perturbations(self, node: EQNode, force_add: bool = False) -> None:
21822224
gamma_signs.append("-")
21832225

21842226
# ── Filter out already-explored and already-queued pairs ─────────
2185-
unexplored: list[tuple[int, int, str]] = []
2186-
for (i0, j0) in all_candidates:
2187-
atom_i, atom_j = i0 + 1, j0 + 1 # convert to 1-based
2188-
for sign in gamma_signs:
2189-
if self.explored_log.has(node.node_id, atom_i, atom_j, sign):
2190-
continue # already run — skip
2191-
gamma_str = neg_gamma_str if sign == "-" else pos_gamma_str
2192-
queue_key = (node.node_id, (gamma_str, str(atom_i), str(atom_j)))
2193-
if queue_key in self.queue._submitted:
2194-
continue # already in the queue — skip
2195-
unexplored.append((i0, j0, sign))
2227+
# Build unexplored list with a single list comprehension to avoid
2228+
# Python-level nested loop overhead. Local variable aliases avoid
2229+
# repeated attribute lookups inside the hot path.
2230+
nid = node.node_id
2231+
explored_set = self.explored_log._explored # set[(nid,i,j,sign)]
2232+
submitted_set = self.queue._submitted # set[(nid, tuple(params))]
2233+
sign_to_gamma = {"+": pos_gamma_str, "-": neg_gamma_str}
2234+
2235+
unexplored: list[tuple[int, int, str]] = [
2236+
(i0, j0, sign)
2237+
for (i0, j0) in all_candidates
2238+
for sign in gamma_signs
2239+
if (nid, i0 + 1, j0 + 1, sign) not in explored_set
2240+
and (nid, (sign_to_gamma[sign], str(i0 + 1), str(j0 + 1))) not in submitted_set
2241+
]
21962242

21972243
if not unexplored:
21982244
logger.debug(
@@ -2215,7 +2261,7 @@ def _enqueue_perturbations(self, node: EQNode, force_add: bool = False) -> None:
22152261
for idx in chosen_indices:
22162262
i0, j0, sign = unexplored[int(idx)]
22172263
atom_i, atom_j = i0 + 1, j0 + 1
2218-
gamma_str = neg_gamma_str if sign == "-" else pos_gamma_str
2264+
gamma_str = sign_to_gamma[sign]
22192265
afir_params = [gamma_str, str(atom_i), str(atom_j)]
22202266
task = ExplorationTask(
22212267
node_id=node.node_id,

0 commit comments

Comments
 (0)