Skip to content

Commit 230a8d5

Browse files
Merge pull request #72 from symbiotic-engineering/recompute
Recompute
2 parents 8c416e5 + 791df5f commit 230a8d5

7 files changed

Lines changed: 116 additions & 206 deletions

File tree

package/src/openflash/meem_engine.py

Lines changed: 82 additions & 120 deletions
Large diffs are not rendered by default.

package/src/openflash/problem_cache.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# package/src/openflash/problem_cache.py
22
import numpy as np
3-
from typing import Callable, Dict, Any, Optional, List
3+
from typing import Callable, Dict, Any, Optional, List, Tuple
44

55
from openflash.multi_equations import *
66

@@ -9,20 +9,29 @@ def __init__(self, problem):
99
self.problem = problem
1010
self.A_template: Optional[np.ndarray] = None
1111
self.b_template: Optional[np.ndarray] = None
12+
13+
# Scalar entries (legacy/sparse)
1214
self.m0_dependent_A_indices: list[tuple[int, int, Callable]] = []
15+
16+
# [NEW] Vectorized blocks (row_start, col_start, calculator_func)
17+
self.m0_dependent_blocks: list[tuple[int, int, Callable]] = []
18+
1319
self.m0_dependent_b_indices: list[tuple[int, Callable]] = []
1420

1521
self.m_k_entry_func: Optional[Callable] = None
1622
self.N_k_func: Optional[Callable] = None
1723
self.m_k_arr: Optional[np.ndarray] = None
1824
self.N_k_arr: Optional[np.ndarray] = None
1925

20-
# --- FIX: Track the m0 value associated with the current cache ---
2126
self.cached_m0: Optional[float] = None
22-
# -----------------------------------------------------------------
2327

2428
self.I_nm_vals: Optional[List[np.ndarray]] = None
2529
self.named_closures: Dict[str, Any] = {}
30+
31+
# Integration constants
32+
self.int_R1_vals = None
33+
self.int_R2_vals = None
34+
self.int_phi_vals = None
2635

2736
def _set_A_template(self, A_template: np.ndarray):
2837
self.A_template = A_template
@@ -33,22 +42,24 @@ def _set_b_template(self, b_template: np.ndarray):
3342
def _add_m0_dependent_A_entry(self, row: int, col: int, calc_func: Callable):
3443
self.m0_dependent_A_indices.append((row, col, calc_func))
3544

45+
def _add_m0_dependent_block(self, row_start: int, col_start: int, calc_func: Callable):
46+
"""
47+
Registers a function that returns a dense sub-matrix (block) to be inserted
48+
into A at [row_start:..., col_start:...].
49+
"""
50+
self.m0_dependent_blocks.append((row_start, col_start, calc_func))
51+
3652
def _add_m0_dependent_b_entry(self, row: int, calc_func: Callable):
3753
self.m0_dependent_b_indices.append((row, calc_func))
3854

3955
def _set_m_k_and_N_k_funcs(self, m_k_entry_func: Callable, N_k_func: Callable):
4056
self.m_k_entry_func = m_k_entry_func
4157
self.N_k_func = N_k_func
4258

43-
# --- FIX: Accept m0 as an argument to store it ---
4459
def _set_precomputed_m_k_N_k(self, m_k_arr: np.ndarray, N_k_arr: np.ndarray, m0: float):
45-
"""
46-
Sets the pre-computed m_k and N_k arrays for a specific m0.
47-
"""
4860
self.m_k_arr = m_k_arr
4961
self.N_k_arr = N_k_arr
5062
self.cached_m0 = m0
51-
# ------------------------------------------------
5263

5364
def _set_I_nm_vals(self, I_nm_vals: List[np.ndarray]):
5465
self.I_nm_vals = I_nm_vals
@@ -68,6 +79,7 @@ def _set_closure(self, key: str, closure):
6879

6980
def _get_closure(self, key: str):
7081
return self.named_closures.get(key, None)
82+
7183
def _set_integration_constants(self, int_R1, int_R2, int_phi):
7284
self.int_R1_vals = int_R1
7385
self.int_R2_vals = int_R2
@@ -77,22 +89,19 @@ def _get_integration_constants(self):
7789
if self.int_R1_vals is None:
7890
raise ValueError("Integration constants have not been set.")
7991
return self.int_R1_vals, self.int_R2_vals, self.int_phi_vals
92+
8093
def refresh_forcing_terms(self, problem):
8194
"""
82-
Re-calculates b_template and m0_dependent_b_indices based on the
83-
current heaving configuration of the problem.
84-
This allows re-using the cache (and Matrix A) while changing the active mode.
95+
Re-calculates b_template and m0_dependent_b_indices.
8596
"""
8697
domain_list = problem.domain_list
8798
domain_keys = list(domain_list.keys())
8899

89-
# Extract geometry params (same as build_problem_cache)
90100
h = domain_list[0].h
91101
d = [domain_list[idx].di for idx in domain_keys]
92102
a = [domain_list[idx].a for idx in domain_keys]
93103
NMK = [domain.number_harmonics for domain in domain_list.values()]
94104

95-
# Crucial: Get the NEW heaving flags
96105
heaving = [domain_list[idx].heaving for idx in domain_keys]
97106

98107
boundary_count = len(NMK) - 1
@@ -116,11 +125,7 @@ def refresh_forcing_terms(self, problem):
116125
self._set_b_template(b_template)
117126

118127
# 2. Reset m0_dependent_b_indices
119-
self.m0_dependent_b_indices = [] # Clear old indices
120-
121-
# Re-populate using the loop logic from build_problem_cache
122-
# Note: We must reset 'index' to match the velocity loop start position
123-
# The velocity loop starts after the potential loop.
128+
self.m0_dependent_b_indices = []
124129

125130
# Calculate offset where velocity equations start
126131
potential_eq_count = 0
@@ -135,19 +140,14 @@ def refresh_forcing_terms(self, problem):
135140
for bd in range(boundary_count):
136141
if bd == (boundary_count - 1):
137142
for n_local in range(NMK[-1]):
138-
# Closure to capture n_local and heaving state
139143
calc_func = lambda p, m0, mk, Nk, Imk, n=n_local: \
140144
b_velocity_end_entry(n, bd, heaving, a, h, d, m0, NMK, mk, Nk)
141145
self._add_m0_dependent_b_entry(index, calc_func)
142146
index += 1
143147
else:
144148
num_entries = NMK[bd + (d[bd] > d[bd + 1])]
145149
for n in range(num_entries):
146-
# b_velocity_entry is not m0 dependent, so it goes into b_template?
147-
# Wait, look at build_problem_cache in original file.
148-
# b_velocity_entry IS put into b_template.
149150
b_template[index] = b_velocity_entry(n, bd, heaving, a, h, d)
150151
index += 1
151-
152-
# Update the template again with the velocity entries added
152+
153153
self._set_b_template(b_template)

package/test/test_high_frequency_convergence.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,8 @@
3131
}
3232

3333
RHO = 1023
34-
# FIX: Use a safe "high" wavenumber.
35-
# m0=50 implies wavelength ~ 0.12m, which is << body size (3.0m+)
36-
# m0=1e6 causes exp(-m0*d) underflow and singular matrices.
37-
M0_MAX = 50.0
38-
TOLERANCE = 0.05 # Slightly relaxed for asymptotic convergence
34+
M0_MAX = 1e6
35+
TOLERANCE = 0.01
3936

4037
@pytest.mark.parametrize("name, cfg", CONFIGS.items())
4138
def test_high_frequency_limit(name, cfg):
@@ -45,10 +42,8 @@ def test_high_frequency_limit(name, cfg):
4542
"""
4643
print(f"\nRunning {name}...")
4744

48-
# --- FIX: Reduced NMK to prevent Matrix Ill-Conditioning ---
49-
# 20-30 modes is standard and sufficient. 100 is unstable.
5045
num_regions = len(cfg['heaving']) + 1
51-
NMK = [30] * num_regions
46+
NMK = [100] * num_regions
5247
# -----------------------------------------------------------
5348

5449
# 1. Solve for m0 = inf

package/test/test_matrix_snapshot.py

Lines changed: 0 additions & 51 deletions
This file was deleted.

package/test/test_meem_engine.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# test_meem_engine.py
1+
# package/test/test_meem_engine.py
22
import pytest
33
import numpy as np
44
import sys
@@ -236,11 +236,14 @@ def test_build_problem_cache(sample_problem):
236236
assert np.any(cache.b_template != 0)
237237

238238
# 4. Verify that the lists for m0-dependent parts are populated
239-
assert len(cache.m0_dependent_A_indices) > 0
239+
# [FIX]: Optimization replaced scalar entries (m0_dependent_A_indices)
240+
# with vectorized blocks (m0_dependent_blocks).
241+
assert len(cache.m0_dependent_blocks) > 0
240242
assert len(cache.m0_dependent_b_indices) > 0
241243

242244
# 5. Check a specific m0-dependent entry to ensure it's a callable
243-
assert callable(cache.m0_dependent_A_indices[0][2])
245+
# m0_dependent_blocks format: (row_start, col_start, calc_func)
246+
assert callable(cache.m0_dependent_blocks[0][2])
244247
assert callable(cache.m0_dependent_b_indices[0][1])
245248

246249
print("✅ Problem cache build test passed.")
0 Bytes
Loading

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ dependencies = [
3030
"pandas>=1.5",
3131
"matplotlib>=3.5",
3232
"h5netcdf>=0.12",
33+
"h5py>=3.0",
3334
"xarray>=2023.0",
3435
"streamlit>=1.0"
3536
]

0 commit comments

Comments
 (0)