Skip to content
This repository was archived by the owner on May 5, 2026. It is now read-only.

Commit 4480f4f

Browse files
authored
Add files via upload
1 parent 6d62782 commit 4480f4f

1 file changed

Lines changed: 190 additions & 60 deletions

File tree

multioptpy/Wrapper/mapper.py

Lines changed: 190 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import bisect
1212
import glob
13+
import heapq
1314
import json
1415
import logging
1516
import os
@@ -92,8 +93,8 @@ def parse_xyz(filepath: str) -> tuple[list[str], np.ndarray]:
9293
return symbols, np.array(coords_raw, dtype=float)
9394

9495
def distance_matrix(coords: np.ndarray) -> np.ndarray:
    """Return the matrix of pairwise Euclidean distances between the rows of ``coords``.

    The same point set is handed to ``cdist`` on both sides, which avoids
    the (N, N, 3) intermediate array produced by manual broadcasting.
    """
    points = coords
    return cdist(points, points, metric="euclidean")
9798

9899

99100
# ===========================================================================
@@ -110,6 +111,15 @@ class StructureChecker:
110111
# Relative tolerance for declaring two eigenvalues degenerate.
111112
_DEGENERACY_REL_TOL: float = 0.02
112113

114+
# The 4 proper rotations (det=+1) from PCA sign-flip ambiguity.
115+
# Built once at class definition time instead of on every call.
116+
_SIGN_FLIP_MATS: tuple[np.ndarray, ...] = (
117+
np.diag([ 1.0, 1.0, 1.0]),
118+
np.diag([-1.0, -1.0, 1.0]),
119+
np.diag([-1.0, 1.0, -1.0]),
120+
np.diag([ 1.0, -1.0, -1.0]),
121+
)
122+
113123
def __init__(self, rmsd_threshold: float = 0.30) -> None:
114124
self.rmsd_threshold = rmsd_threshold
115125

@@ -194,16 +204,34 @@ def _try_candidates(
194204
sym_a: list[str], ca: np.ndarray,
195205
sym_b: list[str], cb: np.ndarray,
196206
) -> float:
197-
"""Evaluate every rotation candidate and return the minimum RMSD found."""
207+
"""Evaluate every rotation candidate and return the minimum RMSD found.
208+
209+
Element-to-index groupings are precomputed once before the rotation
210+
loop so that the O(N) list comprehensions inside ``_optimal_mapping``
211+
are not repeated for every candidate.
212+
"""
213+
# --- Precompute element groups once (not per rotation) ---
214+
elem_groups_a: dict[str, np.ndarray] = {}
215+
for i, s in enumerate(sym_a):
216+
elem_groups_a.setdefault(s, []).append(i) # type: ignore[arg-type]
217+
groups_a = {e: np.array(v, dtype=np.intp) for e, v in elem_groups_a.items()}
218+
219+
elem_groups_b: dict[str, np.ndarray] = {}
220+
for i, s in enumerate(sym_b):
221+
elem_groups_b.setdefault(s, []).append(i) # type: ignore[arg-type]
222+
groups_b = {e: np.array(v, dtype=np.intp) for e, v in elem_groups_b.items()}
223+
198224
min_rmsd = float("inf")
199225
for R in candidates:
200226
cb_rot = cb @ R.T
201-
perm = self._optimal_mapping(sym_a, ca, sym_b, cb_rot)
227+
perm = self._optimal_mapping_fast(ca, cb_rot, groups_a, groups_b)
202228
if perm is None:
203229
continue
204230
rmsd = self._kabsch_rmsd(ca, cb_rot[perm])
205231
if rmsd < min_rmsd:
206232
min_rmsd = rmsd
233+
if min_rmsd < self.rmsd_threshold:
234+
return min_rmsd # Early exit: threshold already met
207235
return min_rmsd
208236

209237
# ------------------------------------------------------------------ #
@@ -246,19 +274,14 @@ def _pca_align(coords: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
246274
# Rotation candidates #
247275
# ------------------------------------------------------------------ #
248276

249-
@classmethod
def _sign_flip_candidates(cls) -> list[np.ndarray]:
    """
    Return the 4 proper rotations (det = +1) arising from the sign-flip
    ambiguity of PCA eigenvectors.  Always necessary; sufficient when no
    eigenvalue degeneracy is present.

    Only the list is new on each call: its entries are the array objects
    stored in the class-level ``_SIGN_FLIP_MATS`` tuple, so callers should
    treat them as read-only.
    """
    return [*cls._SIGN_FLIP_MATS]
262285

263286
@classmethod
264287
def _build_planar_candidates(
@@ -323,25 +346,54 @@ def _so3_grid(n: int) -> list[np.ndarray]:
323346
so that grid points are roughly uniformly distributed on S².
324347
325348
Total: n³ rotation matrices (512 for n = 8).
326-
"""
327-
rotations: list[np.ndarray] = []
328-
for i in range(n):
329-
alpha = 2 * np.pi * i / n
330-
ca, sa = np.cos(alpha), np.sin(alpha)
331-
Rz_alpha = np.array([[ca, -sa, 0.0], [sa, ca, 0.0], [0.0, 0.0, 1.0]])
332-
333-
for j in range(n):
334-
beta = np.arccos(np.clip(1.0 - 2.0 * (j + 0.5) / n, -1.0, 1.0))
335-
cb, sb = np.cos(beta), np.sin(beta)
336-
Ry_beta = np.array([[cb, 0.0, sb], [0.0, 1.0, 0.0], [-sb, 0.0, cb]])
337349
338-
for k in range(n):
339-
gamma = 2 * np.pi * k / n
340-
cg, sg = np.cos(gamma), np.sin(gamma)
341-
Rz_gamma = np.array([[cg, -sg, 0.0], [sg, cg, 0.0], [0.0, 0.0, 1.0]])
342-
rotations.append(Rz_alpha @ Ry_beta @ Rz_gamma)
343-
344-
return rotations
350+
Vectorised: all n³ matrix products are computed with batched numpy
351+
operations instead of a triple Python for-loop.
352+
"""
353+
# --- Angle arrays ---
354+
k = np.arange(n)
355+
alphas = 2.0 * np.pi * k / n # (n,)
356+
betas = np.arccos(np.clip(1.0 - 2.0 * (k + 0.5) / n, -1.0, 1.0)) # (n,)
357+
gammas = 2.0 * np.pi * k / n # (n,)
358+
359+
# --- Component trig values ---
360+
ca, sa = np.cos(alphas), np.sin(alphas) # (n,)
361+
cb, sb = np.cos(betas), np.sin(betas) # (n,)
362+
cg, sg = np.cos(gammas), np.sin(gammas) # (n,)
363+
364+
# --- Batched Rz(alpha): shape (n, 3, 3) ---
365+
zero = np.zeros(n)
366+
one = np.ones(n)
367+
Rz_a = np.stack([
368+
np.stack([ ca, -sa, zero], axis=-1),
369+
np.stack([ sa, ca, zero], axis=-1),
370+
np.stack([zero, zero, one], axis=-1),
371+
], axis=1) # (n, 3, 3)
372+
373+
# --- Batched Ry(beta): shape (n, 3, 3) ---
374+
Ry_b = np.stack([
375+
np.stack([ cb, zero, sb], axis=-1),
376+
np.stack([zero, one, zero], axis=-1),
377+
np.stack([-sb, zero, cb], axis=-1),
378+
], axis=1) # (n, 3, 3)
379+
380+
# --- Batched Rz(gamma): shape (n, 3, 3) ---
381+
Rz_g = np.stack([
382+
np.stack([ cg, -sg, zero], axis=-1),
383+
np.stack([ sg, cg, zero], axis=-1),
384+
np.stack([zero, zero, one], axis=-1),
385+
], axis=1) # (n, 3, 3)
386+
387+
# --- ZYZ product over all (n, n, n) combinations ---
388+
# Rz_a[:, None, None] @ Ry_b[None, :, None] @ Rz_g[None, None, :]
389+
# Broadcasting shapes: (n,1,1,3,3) @ (1,n,1,3,3) @ (1,1,n,3,3)
390+
Rza = Rz_a[:, None, None] # (n, 1, 1, 3, 3)
391+
Ryb = Ry_b[None, :, None] # (1, n, 1, 3, 3)
392+
Rzg = Rz_g[None, None, :] # (1, 1, n, 3, 3)
393+
R_all = Rza @ Ryb @ Rzg # (n, n, n, 3, 3)
394+
395+
# Flatten to list of (3,3) matrices.
396+
return list(R_all.reshape(-1, 3, 3))
345397

346398
@staticmethod
347399
def _Rx(t: float) -> np.ndarray:
@@ -357,13 +409,36 @@ def _Rz(t: float) -> np.ndarray:
357409
# Atom mapping (Hungarian algorithm) #
358410
# ------------------------------------------------------------------ #
359411

412+
@staticmethod
def _optimal_mapping_fast(
    coords_a: np.ndarray,
    coords_b: np.ndarray,
    groups_a: dict[str, np.ndarray],
    groups_b: dict[str, np.ndarray],
) -> list[int] | None:
    """Find the atom permutation of B minimising total squared distance to A.

    Accepts precomputed element-to-index arrays (``groups_a``, ``groups_b``)
    so that the grouping step is not repeated for every rotation candidate.
    The assignment is solved independently per element with
    ``linear_sum_assignment`` on a squared-Euclidean cost matrix.

    Returns a list ``perm`` in which ``perm[i]`` is the index of the atom of
    B matched to atom ``i`` of A.  Returns ``None`` if stoichiometry is
    inconsistent: an element occurs in only one of the two structures, or
    its atom counts differ.
    """
    # The element sets must agree in BOTH directions.  Checking only the
    # elements of A would let an element that exists solely in B slip
    # through and produce a mapping that ignores some of B's atoms.
    if groups_a.keys() != groups_b.keys():
        return None

    n_atoms = sum(len(idx) for idx in groups_a.values())
    perm = np.full(n_atoms, -1, dtype=np.intp)  # -1 marks "not yet assigned"

    for elem, raw_idx_a in groups_a.items():
        idx_a = np.asarray(raw_idx_a, dtype=np.intp)
        idx_b = np.asarray(groups_b[elem], dtype=np.intp)
        if len(idx_a) != len(idx_b):
            return None
        cost = cdist(coords_a[idx_a], coords_b[idx_b], metric="sqeuclidean")
        row_ind, col_ind = linear_sum_assignment(cost)
        # Translate the group-local assignment back to global atom indices.
        perm[idx_a[row_ind]] = idx_b[col_ind]

    # Any slot still at -1 means the groups did not cover every atom of A.
    if (perm < 0).any():
        return None
    return perm.tolist()
435+
360436
@staticmethod
361437
def _optimal_mapping(
362438
sym_a: list[str], coords_a: np.ndarray,
363439
sym_b: list[str], coords_b: np.ndarray,
364440
) -> list[int] | None:
365-
"""
366-
Find the permutation of B's atoms that minimises the total
441+
"""Find the permutation of B's atoms that minimises the total
367442
squared distance to A, solved independently per element.
368443
Returns None if stoichiometry is inconsistent.
369444
"""
@@ -443,17 +518,41 @@ def fingerprint(
443518
Each key is a ``(elem_a, elem_b)`` tuple with elements in sorted
444519
order (so C–H and H–C map to the same key). The value is the
445520
number of such bonds.
521+
522+
Radii are precomputed per unique element so ``_bond_threshold`` is
523+
called at most once per element instead of once per atom-pair.
524+
Distances are computed with ``cdist`` to avoid a Python-level O(N²)
525+
loop.
446526
"""
447527
n = len(symbols)
448-
dmat = distance_matrix(coords)
449-
counts: dict[tuple[str, str], int] = {}
528+
# Precompute covalent radius for each unique element.
529+
unique_elems = set(symbols)
530+
elem_radius: dict[str, float] = {}
531+
for elem in unique_elems:
532+
if covalent_radii_lib is not None:
533+
try:
534+
elem_radius[elem] = covalent_radii_lib(elem) * _BOHR2ANG
535+
except KeyError:
536+
elem_radius[elem] = 0.75 # generic fallback [Å]
537+
else:
538+
elem_radius[elem] = 0.75
539+
540+
radii_arr = np.array([elem_radius[s] for s in symbols], dtype=np.float64)
541+
542+
# Pairwise distances — vectorised via cdist.
543+
dmat = cdist(coords, coords)
544+
ii, jj = np.triu_indices(n, k=1)
545+
dists = dmat[ii, jj]
450546

451-
for i in range(n):
452-
for j in range(i + 1, n):
453-
threshold = self._bond_threshold(symbols[i], symbols[j])
454-
if dmat[i, j] <= threshold:
455-
key = (min(symbols[i], symbols[j]), max(symbols[i], symbols[j]))
456-
counts[key] = counts.get(key, 0) + 1
547+
# Per-pair bonding threshold.
548+
thresholds = self.covalent_margin * (radii_arr[ii] + radii_arr[jj])
549+
bonded_idx = np.where(dists <= thresholds)[0]
550+
551+
counts: dict[tuple[str, str], int] = {}
552+
for k in bonded_idx:
553+
si, sj = symbols[ii[k]], symbols[jj[k]]
554+
key = (si, sj) if si <= sj else (sj, si)
555+
counts[key] = counts.get(key, 0) + 1
457556

458557
return counts
459558

@@ -531,11 +630,15 @@ def compute_priority(self, task):
531630
"""
532631

533632
def __init__(self, rng_seed: int = 42) -> None:
    """Create an empty queue whose random generator is seeded with ``rng_seed``."""
    self._rng = np.random.default_rng(rng_seed)
    self._submitted: set[tuple] = set()
    # Canonical list of ExplorationTask objects.  It stays a real list so
    # that subclasses (e.g. RCMCQueue) can call .sort() and .pop(0) on it
    # directly without breaking.
    self._tasks: list[ExplorationTask] = []
    # Parallel min-heap of (-priority, counter, task) entries used by the
    # base-class push()/pop() for O(log n) insertion and extraction.
    # RCMCQueue overrides pop() entirely and never touches it.
    self._heap: list[tuple[float, int, ExplorationTask]] = []
    # Running number stored as the second field of each new heap entry.
    self._push_counter: int = 0
541644

@@ -545,21 +648,27 @@ def push(self, task: ExplorationTask) -> bool:
545648
return False
546649

547650
task.priority = self.compute_priority(task)
548-
# _tasks is maintained in descending priority order. _neg_priorities
549-
# mirrors it as ascending negated values so bisect can find the correct
550-
# insertion index in O(log n) without rebuilding the list each call.
551-
neg_p = -task.priority
552-
idx = bisect.bisect_right(self._neg_priorities, neg_p)
553-
self._tasks.insert(idx, task)
554-
self._neg_priorities.insert(idx, neg_p)
651+
# Update both _tasks (for subclass access) and _heap (for base-class pop).
652+
self._tasks.append(task)
653+
heapq.heappush(self._heap, (-task.priority, self._push_counter, task))
654+
self._push_counter += 1
555655
self._submitted.add(key)
556656
return True
557657

558658
def pop(self) -> ExplorationTask | None:
    """Pop the highest-priority task from the heap, or ``None`` if it is empty.

    The popped task is also removed from ``_tasks`` so subclasses that
    iterate ``_tasks`` see a consistent state.  Extraction from the heap is
    O(log n); the scan of ``_tasks`` is O(n).
    """
    if not self._heap:
        return None
    _, _, task = heapq.heappop(self._heap)

    # Remove by identity rather than with list.remove(): remove() compares
    # with ``==`` and could drop a *different* task that merely compares
    # equal, or swallow a ValueError raised inside a task's __eq__.
    for pos, queued in enumerate(self._tasks):
        if queued is task:
            del self._tasks[pos]
            break
    # No match means the task was already taken out of ``_tasks`` elsewhere;
    # that case is deliberately tolerated (best effort).
    return task
563672

564673
def should_add(self, node: "EQNode", reference_energy: float, **kwargs) -> bool:
565674
"""Decide probabilistically whether to enqueue a node.
@@ -631,9 +740,10 @@ def refresh_priorities(self, ref_e: float | None) -> None:
631740
task.metadata["delta_E_hartree"] = eff_e - ref_e
632741
task.priority = self.compute_priority(task)
633742

634-
self._tasks.sort(key=lambda t: t.priority, reverse=True)
635-
# Rebuild the parallel negated-priority list to stay in sync after sort.
636-
self._neg_priorities = [-t.priority for t in self._tasks]
743+
# Rebuild the heap from the updated _tasks list.
744+
self._heap = [(-t.priority, i, t) for i, t in enumerate(self._tasks)]
745+
heapq.heapify(self._heap)
746+
self._push_counter = len(self._heap)
637747

638748
def export_queue_status(self) -> list[dict]:
639749
return [
@@ -1047,15 +1157,23 @@ def to_dict(self) -> dict:
10471157
data[k] = str(v)
10481158
return data
10491159

1160+
# Sentinel object used by NetworkGraph to distinguish "not yet computed"
1161+
# from a cached value of None (meaning no energy is available).
1162+
_UNSET = object()
1163+
1164+
10501165
class NetworkGraph:
10511166
def __init__(self) -> None:
    """Create an empty graph with zeroed id counters and no cached reference energy."""
    # Node and edge stores, keyed by integer id.
    self._nodes: dict[int, EQNode] = {}
    self._edges: dict[int, TSEdge] = {}
    self._node_counter: int = 0
    self._edge_counter: int = 0
    # Cached reference energy.  The _UNSET sentinel means "not computed or
    # invalidated", which keeps it distinct from a cached value of None.
    self._ref_energy_cache: float | None = _UNSET  # type: ignore[assignment]
10561173

10571174
def add_node(self, node: EQNode) -> None:
    """Store ``node`` under its ``node_id`` and invalidate the cached reference energy."""
    node_key = node.node_id
    self._nodes[node_key] = node
    # Reset the cache to the sentinel so a stale value is not reused.
    self._ref_energy_cache = _UNSET  # type: ignore[assignment]
10591177

10601178
def get_node(self, node_id: int) -> EQNode | None:
    """Return the node stored under ``node_id``, or ``None`` if there is none."""
    try:
        return self._nodes[node_id]
    except KeyError:
        return None
@@ -1091,15 +1209,25 @@ def reference_energy(self) -> float | None:
10911209
intended mixed-mode behaviour (see Q2 design decision).
10921210
* When **no** node has free energy, fall back to the minimum
10931211
electronic energy, preserving the original behaviour.
1212+
1213+
The result is cached and automatically invalidated whenever a new
1214+
node is added via :meth:`add_node`.
10941215
"""
1216+
if self._ref_energy_cache is not _UNSET:
1217+
return self._ref_energy_cache # type: ignore[return-value]
1218+
10951219
free_energies = [
10961220
n.free_energy for n in self._nodes.values()
10971221
if n.free_energy is not None
10981222
]
10991223
if free_energies:
1100-
return min(free_energies)
1101-
real_energies = [n.energy for n in self._nodes.values() if n.has_real_energy]
1102-
return min(real_energies) if real_energies else None
1224+
result: float | None = min(free_energies)
1225+
else:
1226+
real_energies = [n.energy for n in self._nodes.values() if n.has_real_energy]
1227+
result = min(real_energies) if real_energies else None
1228+
1229+
self._ref_energy_cache = result # type: ignore[assignment]
1230+
return result
11031231

11041232
def save(self, filepath: str) -> None:
11051233
data = {
@@ -2285,6 +2413,8 @@ def _flush_node_energy_updates(self) -> None:
22852413
)
22862414

22872415
self._pending_node_updates.clear()
2416+
# Node energies may have changed; invalidate the cached reference energy.
2417+
self.graph._ref_energy_cache = _UNSET # type: ignore[assignment]
22882418

22892419
def _find_or_register_node(
22902420
self,

0 commit comments

Comments
 (0)