1313import os
1414
1515import numpy as np
16+ from scipy .linalg import lu_factor , lu_solve
1617from multioptpy .Wrapper .mapper import ExplorationQueue , ExplorationTask
1718
1819logger = logging .getLogger (__name__ )
@@ -131,6 +132,53 @@ def _save_K_matrix(
131132 except OSError as exc :
132133 logger .warning ("RCMC contracted K matrix could not be saved: %s" , exc )
133134
135+ def _save_population (
136+ self ,
137+ q : "np.ndarray" ,
138+ nodes : list ,
139+ pop_count : int ,
140+ ) -> None :
141+ """Append the per-node transient population distribution to the
142+ contracted K-matrix CSV file.
143+
144+ Called immediately after :meth:`_save_K_matrix` so the two tables
145+ appear in the same file, separated by a blank line.
146+
147+ File path: ``{output_dir}/rcmc_K_contracted.csv``
148+
149+ Format (appended below the K-matrix block)
150+ -------------------------------------------
151+ # RCMC transient population pop_step=N T=300.0 K t=1.0 s
152+ node,population
153+ EQ0,4.56000000e-01
154+ EQ1,3.21000000e-01
155+ ...
156+ """
157+ if self .output_dir is None :
158+ return
159+ try :
160+ fpath = os .path .join (self .output_dir , "rcmc_K_contracted.csv" )
161+ with open (fpath , "a" , encoding = "utf-8" ) as fh :
162+ fh .write ("\n " )
163+ fh .write (
164+ f"# RCMC transient population "
165+ f"pop_step={ pop_count } "
166+ f"T={ self .temperature_K } K "
167+ f"t={ self .reaction_time_s } s\n "
168+ )
169+ fh .write ("node,population\n " )
170+ for i , node in enumerate (nodes ):
171+ fh .write (f"EQ{ node .node_id } ,{ q [i ]:.8e} \n " )
172+ logger .info (
173+ "RCMC population distribution appended: %s "
174+ "(pop_step=%d n_nodes=%d)" ,
175+ fpath ,
176+ pop_count ,
177+ len (nodes ),
178+ )
179+ except OSError as exc :
180+ logger .warning ("RCMC population CSV could not be saved: %s" , exc )
181+
134182 def pop (self ) -> ExplorationTask | None :
135183 if not self ._tasks :
136184 return None
@@ -201,57 +249,86 @@ def pop(self) -> ExplorationTask | None:
201249 # Initially every node is its own super-state.
202250 superstate_members : dict [int , list [int ]] = {i : [i ] for i in range (n_nodes )}
203251
252+ # ── Incremental K_SS buffer ───────────────────────────────────────
253+ # Maintained by block-appending each newly contracted node so we
254+ # avoid rebuilding via fancy indexing (K[np.ix_(S,S)]) every step.
255+ K_SS_buf : np .ndarray = np .empty ((0 , 0 ), dtype = np .float64 )
256+
204257 while len (T ) > 1 :
205- diag_D = np .abs (np .diag (D ))
206- j_local = np .argmax (diag_D )
258+ # Only the diagonal is needed for argmax — extract with np.diag
259+ # rather than keeping a full abs matrix.
260+ j_local = int (np .argmax (np .abs (np .diag (D ))))
207261 j_global = T [j_local ]
208262 D_jj = D [j_local , j_local ]
209263
210264 if abs (D_jj ) < 1e-30 :
211265 break
212266
213- mask = np .arange (len (T )) != j_local
214- D_TT = D [np .ix_ (mask , mask )]
215- D_Tj = D [mask , j_local ].reshape (- 1 , 1 )
216- D_jT = D [j_local , mask ].reshape (1 , - 1 )
217-
218- D_new = D_TT - (D_Tj @ D_jT ) / D_jj
219-
220- # Recalculate diagonal elements to preserve numerical stability
221- for i in range (D_new .shape [0 ]):
222- D_new [i , i ] = - np .sum (D_new [:, i ]) + D_new [i , i ]
267+ # Boolean mask is faster than np.arange comparison for slicing.
268+ mask = np .ones (len (T ), dtype = bool )
269+ mask [j_local ] = False
270+
271+ D_Tj = D [mask , j_local ] # shape (n-1,)
272+ D_jT = D [j_local , mask ] # shape (n-1,)
273+
274+ # Schur-complement rank-1 update.
275+ D_new = D [np .ix_ (mask , mask )] - np .outer (D_Tj , D_jT ) / D_jj
276+
277+ # Vectorised diagonal correction (replaces Python for-loop):
278+ # enforce column-sum-to-zero so numerical drift does not accumulate.
279+ off_diag_col_sums = D_new .sum (axis = 0 ) - D_new .diagonal ()
280+ np .fill_diagonal (D_new , - off_diag_col_sums )
223281
224282 # Assign j to the T state most strongly coupled to it,
225283 # then merge j's member list into that state's members.
226- coupling = np .abs (D [mask , j_local ])
227- if coupling .max () > 0 :
228- absorb_local = int (np .argmax (coupling ))
229- else :
230- absorb_local = 0
284+ # D_Tj was already computed above — reuse it.
285+ coupling = np .abs (D_Tj )
286+ absorb_local = int (np .argmax (coupling )) if coupling .max () > 0 else 0
231287 remaining_T = [t for k , t in enumerate (T ) if k != j_local ]
232288 absorb_global = remaining_T [absorb_local ]
233289 superstate_members [absorb_global ].extend (
234290 superstate_members .pop (j_global )
235291 )
236292
293+ # ── Incremental K_SS expansion ────────────────────────────────
294+ # Append j_global as a new row/column instead of re-indexing K.
295+ if K_SS_buf .size == 0 :
296+ K_SS_buf = np .array ([[K [j_global , j_global ]]])
297+ else :
298+ new_col = K [S , j_global ].reshape (- 1 , 1 )
299+ new_row = K [j_global , S ].reshape (1 , - 1 )
300+ K_SS_buf = np .block ([
301+ [K_SS_buf , new_col ],
302+ [new_row , np .array ([[K [j_global , j_global ]]])]
303+ ])
304+
237305 S .append (j_global )
238306 T .pop (j_local )
239307 D = D_new
240308
241- if len (S ) > 0 :
242- K_SS = K [np .ix_ (S , S )]
243- try :
244- inv_1S = np .linalg .solve (- K_SS , np .ones (len (S )))
245- inv_T_1S = np .linalg .solve (- K_SS .T , np .ones (len (S )))
246- rho_KSS_inv = min (np .max (inv_1S ), np .max (inv_T_1S ))
247- sigma_KSS = 1.0 / rho_KSS_inv if rho_KSS_inv > 0 else 1e-30
248- except np .linalg .LinAlgError :
249- sigma_KSS = 1e-30
250-
251- rho_D = min (np .max (np .sum (np .abs (D ), axis = 1 )), np .max (np .sum (np .abs (D ), axis = 0 )))
309+ # ── Convergence check using a single LU factorisation ─────────
310+ # lu_solve(…, trans=1) solves the transposed system without a
311+ # second factorisation, replacing two separate np.linalg.solve
312+ # calls on -K_SS and -K_SS.T.
313+ try :
314+ lu , piv = lu_factor (- K_SS_buf )
315+ ones_S = np .ones (len (S ))
316+ inv_1S = lu_solve ((lu , piv ), ones_S )
317+ inv_T_1S = lu_solve ((lu , piv ), ones_S , trans = 1 )
318+ rho_KSS_inv = min (float (np .max (inv_1S )), float (np .max (inv_T_1S )))
319+ sigma_KSS = 1.0 / rho_KSS_inv if rho_KSS_inv > 0 else 1e-30
320+ except Exception :
321+ sigma_KSS = 1e-30
322+
323+ # Compute abs(D) once and reuse for both axis sums.
324+ abs_D = np .abs (D )
325+ rho_D = min (
326+ float (np .max (abs_D .sum (axis = 1 ))),
327+ float (np .max (abs_D .sum (axis = 0 ))),
328+ )
252329
253- if sigma_KSS > 0 and rho_D > 0 :
254- time_current = np .log (2 ) / np .sqrt (sigma_KSS * rho_D )
330+ if sigma_KSS > 0 and rho_D > 0 :
331+ time_current = np .log (2 ) / np .sqrt (sigma_KSS * rho_D )
255332
256333 logger .debug (
257334 "RCMC contraction step %d: contracted_nodes=%s, "
@@ -268,20 +345,21 @@ def pop(self) -> ExplorationTask | None:
268345
269346 # ── Update contracted super-state K matrix (D at termination) ────
270347 self ._save_K_matrix (D , T , superstate_members , nodes , self ._pop_count )
271- self ._pop_count += 1
272348
273349 q = np .zeros (n_nodes , dtype = np .float64 )
274350 if len (S ) > 0 and len (T ) > 0 :
275- K_SS = K [ np . ix_ ( S , S )]
351+ # K_SS_buf is already up to date — no need to re-index K.
276352 K_ST = K [np .ix_ (S , T )]
277353 K_TS = K [np .ix_ (T , S )]
278354 p_S = p [S ]
279355 p_T = p [T ]
280356
281357 try :
282- X_ST = np .linalg .solve (K_SS , K_ST )
283- X_pS = np .linalg .solve (K_SS , p_S )
284- X_ST_2 = np .linalg .solve (K_SS , X_ST )
358+ # Factorise K_SS once; reuse for the three back-solves.
359+ lu , piv = lu_factor (K_SS_buf )
360+ X_ST = lu_solve ((lu , piv ), K_ST )
361+ X_pS = lu_solve ((lu , piv ), p_S )
362+ X_ST_2 = lu_solve ((lu , piv ), X_ST )
285363
286364 M = np .eye (len (T )) + K_TS @ X_ST_2
287365 m_vec = np .sum (M , axis = 0 )
@@ -295,15 +373,19 @@ def pop(self) -> ExplorationTask | None:
295373 q_S = np .maximum (q_S , 0.0 )
296374 q [T ] = q_T
297375 q [S ] = q_S
298- total_q = np .sum (q )
376+ total_q = float ( np .sum (q ) )
299377 if total_q > 0.0 :
300378 q /= total_q
301379
302- except np . linalg . LinAlgError :
380+ except Exception :
303381 q = p
304382 else :
305383 q = p
306384
385+ # ── Append population distribution below the K-matrix CSV ────────
386+ self ._save_population (q , nodes , self ._pop_count )
387+ self ._pop_count += 1
388+
307389 for task in self ._tasks :
308390 if task .node_id in node_to_idx :
309391 idx = node_to_idx [task .node_id ]
0 commit comments