Skip to content
This repository was archived by the owner on May 5, 2026. It is now read-only.

Commit c3c16d0

Browse files
authored
Add files via upload
1 parent 97f8c90 commit c3c16d0

1 file changed

Lines changed: 119 additions & 37 deletions

File tree

multioptpy/Utils/rcmc.py

Lines changed: 119 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import os
1414

1515
import numpy as np
16+
from scipy.linalg import lu_factor, lu_solve
1617
from multioptpy.Wrapper.mapper import ExplorationQueue, ExplorationTask
1718

1819
logger = logging.getLogger(__name__)
@@ -131,6 +132,53 @@ def _save_K_matrix(
131132
except OSError as exc:
132133
logger.warning("RCMC contracted K matrix could not be saved: %s", exc)
133134

135+
def _save_population(
136+
self,
137+
q: "np.ndarray",
138+
nodes: list,
139+
pop_count: int,
140+
) -> None:
141+
"""Append the per-node transient population distribution to the
142+
contracted K-matrix CSV file.
143+
144+
Called immediately after :meth:`_save_K_matrix` so the two tables
145+
appear in the same file, separated by a blank line.
146+
147+
File path: ``{output_dir}/rcmc_K_contracted.csv``
148+
149+
Format (appended below the K-matrix block)
150+
-------------------------------------------
151+
# RCMC transient population pop_step=N T=300.0 K t=1.0 s
152+
node,population
153+
EQ0,4.56000000e-01
154+
EQ1,3.21000000e-01
155+
...
156+
"""
157+
if self.output_dir is None:
158+
return
159+
try:
160+
fpath = os.path.join(self.output_dir, "rcmc_K_contracted.csv")
161+
with open(fpath, "a", encoding="utf-8") as fh:
162+
fh.write("\n")
163+
fh.write(
164+
f"# RCMC transient population "
165+
f"pop_step={pop_count} "
166+
f"T={self.temperature_K} K "
167+
f"t={self.reaction_time_s} s\n"
168+
)
169+
fh.write("node,population\n")
170+
for i, node in enumerate(nodes):
171+
fh.write(f"EQ{node.node_id},{q[i]:.8e}\n")
172+
logger.info(
173+
"RCMC population distribution appended: %s "
174+
"(pop_step=%d n_nodes=%d)",
175+
fpath,
176+
pop_count,
177+
len(nodes),
178+
)
179+
except OSError as exc:
180+
logger.warning("RCMC population CSV could not be saved: %s", exc)
181+
134182
def pop(self) -> ExplorationTask | None:
135183
if not self._tasks:
136184
return None
@@ -201,57 +249,86 @@ def pop(self) -> ExplorationTask | None:
201249
# Initially every node is its own super-state.
202250
superstate_members: dict[int, list[int]] = {i: [i] for i in range(n_nodes)}
203251

252+
# ── Incremental K_SS buffer ───────────────────────────────────────
253+
# Maintained by block-appending each newly contracted node so we
254+
# avoid rebuilding via fancy indexing (K[np.ix_(S,S)]) every step.
255+
K_SS_buf: np.ndarray = np.empty((0, 0), dtype=np.float64)
256+
204257
while len(T) > 1:
205-
diag_D = np.abs(np.diag(D))
206-
j_local = np.argmax(diag_D)
258+
# Only the diagonal is needed for argmax — extract with np.diag
259+
# rather than keeping a full abs matrix.
260+
j_local = int(np.argmax(np.abs(np.diag(D))))
207261
j_global = T[j_local]
208262
D_jj = D[j_local, j_local]
209263

210264
if abs(D_jj) < 1e-30:
211265
break
212266

213-
mask = np.arange(len(T)) != j_local
214-
D_TT = D[np.ix_(mask, mask)]
215-
D_Tj = D[mask, j_local].reshape(-1, 1)
216-
D_jT = D[j_local, mask].reshape(1, -1)
217-
218-
D_new = D_TT - (D_Tj @ D_jT) / D_jj
219-
220-
# Recalculate diagonal elements to preserve numerical stability
221-
for i in range(D_new.shape[0]):
222-
D_new[i, i] = -np.sum(D_new[:, i]) + D_new[i, i]
267+
# Boolean mask is faster than np.arange comparison for slicing.
268+
mask = np.ones(len(T), dtype=bool)
269+
mask[j_local] = False
270+
271+
D_Tj = D[mask, j_local] # shape (n-1,)
272+
D_jT = D[j_local, mask] # shape (n-1,)
273+
274+
# Schur-complement rank-1 update.
275+
D_new = D[np.ix_(mask, mask)] - np.outer(D_Tj, D_jT) / D_jj
276+
277+
# Vectorised diagonal correction (replaces Python for-loop):
278+
# enforce column-sum-to-zero so numerical drift does not accumulate.
279+
off_diag_col_sums = D_new.sum(axis=0) - D_new.diagonal()
280+
np.fill_diagonal(D_new, -off_diag_col_sums)
223281

224282
# Assign j to the T state most strongly coupled to it,
225283
# then merge j's member list into that state's members.
226-
coupling = np.abs(D[mask, j_local])
227-
if coupling.max() > 0:
228-
absorb_local = int(np.argmax(coupling))
229-
else:
230-
absorb_local = 0
284+
# D_Tj was already computed above — reuse it.
285+
coupling = np.abs(D_Tj)
286+
absorb_local = int(np.argmax(coupling)) if coupling.max() > 0 else 0
231287
remaining_T = [t for k, t in enumerate(T) if k != j_local]
232288
absorb_global = remaining_T[absorb_local]
233289
superstate_members[absorb_global].extend(
234290
superstate_members.pop(j_global)
235291
)
236292

293+
# ── Incremental K_SS expansion ────────────────────────────────
294+
# Append j_global as a new row/column instead of re-indexing K.
295+
if K_SS_buf.size == 0:
296+
K_SS_buf = np.array([[K[j_global, j_global]]])
297+
else:
298+
new_col = K[S, j_global].reshape(-1, 1)
299+
new_row = K[j_global, S].reshape(1, -1)
300+
K_SS_buf = np.block([
301+
[K_SS_buf, new_col ],
302+
[new_row, np.array([[K[j_global, j_global]]])]
303+
])
304+
237305
S.append(j_global)
238306
T.pop(j_local)
239307
D = D_new
240308

241-
if len(S) > 0:
242-
K_SS = K[np.ix_(S, S)]
243-
try:
244-
inv_1S = np.linalg.solve(-K_SS, np.ones(len(S)))
245-
inv_T_1S = np.linalg.solve(-K_SS.T, np.ones(len(S)))
246-
rho_KSS_inv = min(np.max(inv_1S), np.max(inv_T_1S))
247-
sigma_KSS = 1.0 / rho_KSS_inv if rho_KSS_inv > 0 else 1e-30
248-
except np.linalg.LinAlgError:
249-
sigma_KSS = 1e-30
250-
251-
rho_D = min(np.max(np.sum(np.abs(D), axis=1)), np.max(np.sum(np.abs(D), axis=0)))
309+
# ── Convergence check using a single LU factorisation ─────────
310+
# lu_solve(…, trans=1) solves the transposed system without a
311+
# second factorisation, replacing two separate np.linalg.solve
312+
# calls on -K_SS and -K_SS.T.
313+
try:
314+
lu, piv = lu_factor(-K_SS_buf)
315+
ones_S = np.ones(len(S))
316+
inv_1S = lu_solve((lu, piv), ones_S)
317+
inv_T_1S = lu_solve((lu, piv), ones_S, trans=1)
318+
rho_KSS_inv = min(float(np.max(inv_1S)), float(np.max(inv_T_1S)))
319+
sigma_KSS = 1.0 / rho_KSS_inv if rho_KSS_inv > 0 else 1e-30
320+
except Exception:
321+
sigma_KSS = 1e-30
322+
323+
# Compute abs(D) once and reuse for both axis sums.
324+
abs_D = np.abs(D)
325+
rho_D = min(
326+
float(np.max(abs_D.sum(axis=1))),
327+
float(np.max(abs_D.sum(axis=0))),
328+
)
252329

253-
if sigma_KSS > 0 and rho_D > 0:
254-
time_current = np.log(2) / np.sqrt(sigma_KSS * rho_D)
330+
if sigma_KSS > 0 and rho_D > 0:
331+
time_current = np.log(2) / np.sqrt(sigma_KSS * rho_D)
255332

256333
logger.debug(
257334
"RCMC contraction step %d: contracted_nodes=%s, "
@@ -268,20 +345,21 @@ def pop(self) -> ExplorationTask | None:
268345

269346
# ── Update contracted super-state K matrix (D at termination) ────
270347
self._save_K_matrix(D, T, superstate_members, nodes, self._pop_count)
271-
self._pop_count += 1
272348

273349
q = np.zeros(n_nodes, dtype=np.float64)
274350
if len(S) > 0 and len(T) > 0:
275-
K_SS = K[np.ix_(S, S)]
351+
# K_SS_buf is already up to date — no need to re-index K.
276352
K_ST = K[np.ix_(S, T)]
277353
K_TS = K[np.ix_(T, S)]
278354
p_S = p[S]
279355
p_T = p[T]
280356

281357
try:
282-
X_ST = np.linalg.solve(K_SS, K_ST)
283-
X_pS = np.linalg.solve(K_SS, p_S)
284-
X_ST_2 = np.linalg.solve(K_SS, X_ST)
358+
# Factorise K_SS once; reuse for the three back-solves.
359+
lu, piv = lu_factor(K_SS_buf)
360+
X_ST = lu_solve((lu, piv), K_ST)
361+
X_pS = lu_solve((lu, piv), p_S)
362+
X_ST_2 = lu_solve((lu, piv), X_ST)
285363

286364
M = np.eye(len(T)) + K_TS @ X_ST_2
287365
m_vec = np.sum(M, axis=0)
@@ -295,15 +373,19 @@ def pop(self) -> ExplorationTask | None:
295373
q_S = np.maximum(q_S, 0.0)
296374
q[T] = q_T
297375
q[S] = q_S
298-
total_q = np.sum(q)
376+
total_q = float(np.sum(q))
299377
if total_q > 0.0:
300378
q /= total_q
301379

302-
except np.linalg.LinAlgError:
380+
except Exception:
303381
q = p
304382
else:
305383
q = p
306384

385+
# ── Append population distribution below the K-matrix CSV ────────
386+
self._save_population(q, nodes, self._pop_count)
387+
self._pop_count += 1
388+
307389
for task in self._tasks:
308390
if task.node_id in node_to_idx:
309391
idx = node_to_idx[task.node_id]

0 commit comments

Comments
 (0)