11# package/src/openflash/problem_cache.py
22import numpy as np
3- from typing import Callable , Dict , Any , Optional , List
3+ from typing import Callable , Dict , Any , Optional , List , Tuple
44
55from openflash .multi_equations import *
66
@@ -9,20 +9,29 @@ def __init__(self, problem):
99 self .problem = problem
1010 self .A_template : Optional [np .ndarray ] = None
1111 self .b_template : Optional [np .ndarray ] = None
12+
13+ # Scalar entries (legacy/sparse)
1214 self .m0_dependent_A_indices : list [tuple [int , int , Callable ]] = []
15+
16+ # [NEW] Vectorized blocks (row_start, col_start, calculator_func)
17+ self .m0_dependent_blocks : list [tuple [int , int , Callable ]] = []
18+
1319 self .m0_dependent_b_indices : list [tuple [int , Callable ]] = []
1420
1521 self .m_k_entry_func : Optional [Callable ] = None
1622 self .N_k_func : Optional [Callable ] = None
1723 self .m_k_arr : Optional [np .ndarray ] = None
1824 self .N_k_arr : Optional [np .ndarray ] = None
1925
20- # --- FIX: Track the m0 value associated with the current cache ---
2126 self .cached_m0 : Optional [float ] = None
22- # -----------------------------------------------------------------
2327
2428 self .I_nm_vals : Optional [List [np .ndarray ]] = None
2529 self .named_closures : Dict [str , Any ] = {}
30+
31+ # Integration constants
32+ self .int_R1_vals = None
33+ self .int_R2_vals = None
34+ self .int_phi_vals = None
2635
2736 def _set_A_template (self , A_template : np .ndarray ):
2837 self .A_template = A_template
@@ -33,22 +42,24 @@ def _set_b_template(self, b_template: np.ndarray):
3342 def _add_m0_dependent_A_entry (self , row : int , col : int , calc_func : Callable ):
3443 self .m0_dependent_A_indices .append ((row , col , calc_func ))
3544
45+ def _add_m0_dependent_block (self , row_start : int , col_start : int , calc_func : Callable ):
46+ """
47+ Registers a function that returns a dense sub-matrix (block) to be inserted
48+ into A at [row_start:..., col_start:...].
49+ """
50+ self .m0_dependent_blocks .append ((row_start , col_start , calc_func ))
51+
3652 def _add_m0_dependent_b_entry (self , row : int , calc_func : Callable ):
3753 self .m0_dependent_b_indices .append ((row , calc_func ))
3854
3955 def _set_m_k_and_N_k_funcs (self , m_k_entry_func : Callable , N_k_func : Callable ):
4056 self .m_k_entry_func = m_k_entry_func
4157 self .N_k_func = N_k_func
4258
43- # --- FIX: Accept m0 as an argument to store it ---
4459 def _set_precomputed_m_k_N_k (self , m_k_arr : np .ndarray , N_k_arr : np .ndarray , m0 : float ):
45- """
46- Sets the pre-computed m_k and N_k arrays for a specific m0.
47- """
4860 self .m_k_arr = m_k_arr
4961 self .N_k_arr = N_k_arr
5062 self .cached_m0 = m0
51- # ------------------------------------------------
5263
5364 def _set_I_nm_vals (self , I_nm_vals : List [np .ndarray ]):
5465 self .I_nm_vals = I_nm_vals
@@ -68,6 +79,7 @@ def _set_closure(self, key: str, closure):
6879
6980 def _get_closure (self , key : str ):
7081 return self .named_closures .get (key , None )
82+
7183 def _set_integration_constants (self , int_R1 , int_R2 , int_phi ):
7284 self .int_R1_vals = int_R1
7385 self .int_R2_vals = int_R2
@@ -77,22 +89,19 @@ def _get_integration_constants(self):
7789 if self .int_R1_vals is None :
7890 raise ValueError ("Integration constants have not been set." )
7991 return self .int_R1_vals , self .int_R2_vals , self .int_phi_vals
92+
8093 def refresh_forcing_terms (self , problem ):
8194 """
82- Re-calculates b_template and m0_dependent_b_indices based on the
83- current heaving configuration of the problem.
84- This allows re-using the cache (and Matrix A) while changing the active mode.
95+ Re-calculates b_template and m0_dependent_b_indices.
8596 """
8697 domain_list = problem .domain_list
8798 domain_keys = list (domain_list .keys ())
8899
89- # Extract geometry params (same as build_problem_cache)
90100 h = domain_list [0 ].h
91101 d = [domain_list [idx ].di for idx in domain_keys ]
92102 a = [domain_list [idx ].a for idx in domain_keys ]
93103 NMK = [domain .number_harmonics for domain in domain_list .values ()]
94104
95- # Crucial: Get the NEW heaving flags
96105 heaving = [domain_list [idx ].heaving for idx in domain_keys ]
97106
98107 boundary_count = len (NMK ) - 1
@@ -116,11 +125,7 @@ def refresh_forcing_terms(self, problem):
116125 self ._set_b_template (b_template )
117126
118127 # 2. Reset m0_dependent_b_indices
119- self .m0_dependent_b_indices = [] # Clear old indices
120-
121- # Re-populate using the loop logic from build_problem_cache
122- # Note: We must reset 'index' to match the velocity loop start position
123- # The velocity loop starts after the potential loop.
128+ self .m0_dependent_b_indices = []
124129
125130 # Calculate offset where velocity equations start
126131 potential_eq_count = 0
@@ -135,19 +140,14 @@ def refresh_forcing_terms(self, problem):
135140 for bd in range (boundary_count ):
136141 if bd == (boundary_count - 1 ):
137142 for n_local in range (NMK [- 1 ]):
138- # Closure to capture n_local and heaving state
139143 calc_func = lambda p , m0 , mk , Nk , Imk , n = n_local : \
140144 b_velocity_end_entry (n , bd , heaving , a , h , d , m0 , NMK , mk , Nk )
141145 self ._add_m0_dependent_b_entry (index , calc_func )
142146 index += 1
143147 else :
144148 num_entries = NMK [bd + (d [bd ] > d [bd + 1 ])]
145149 for n in range (num_entries ):
146- # b_velocity_entry is not m0 dependent, so it goes into b_template?
147- # Wait, look at build_problem_cache in original file.
148- # b_velocity_entry IS put into b_template.
149150 b_template [index ] = b_velocity_entry (n , bd , heaving , a , h , d )
150151 index += 1
151-
152- # Update the template again with the velocity entries added
152+
153153 self ._set_b_template (b_template )
0 commit comments