Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 82 additions & 120 deletions package/src/openflash/meem_engine.py

Large diffs are not rendered by default.

48 changes: 24 additions & 24 deletions package/src/openflash/problem_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# package/src/openflash/problem_cache.py
import numpy as np
from typing import Callable, Dict, Any, Optional, List
from typing import Callable, Dict, Any, Optional, List, Tuple

from openflash.multi_equations import *

Expand All @@ -9,20 +9,29 @@ def __init__(self, problem):
self.problem = problem
self.A_template: Optional[np.ndarray] = None
self.b_template: Optional[np.ndarray] = None

# Scalar entries (legacy/sparse)
self.m0_dependent_A_indices: list[tuple[int, int, Callable]] = []

# [NEW] Vectorized blocks (row_start, col_start, calculator_func)
self.m0_dependent_blocks: list[tuple[int, int, Callable]] = []

self.m0_dependent_b_indices: list[tuple[int, Callable]] = []

self.m_k_entry_func: Optional[Callable] = None
self.N_k_func: Optional[Callable] = None
self.m_k_arr: Optional[np.ndarray] = None
self.N_k_arr: Optional[np.ndarray] = None

# --- FIX: Track the m0 value associated with the current cache ---
self.cached_m0: Optional[float] = None
# -----------------------------------------------------------------

self.I_nm_vals: Optional[List[np.ndarray]] = None
self.named_closures: Dict[str, Any] = {}

# Integration constants
self.int_R1_vals = None
self.int_R2_vals = None
self.int_phi_vals = None

def _set_A_template(self, A_template: np.ndarray):
self.A_template = A_template
Expand All @@ -33,22 +42,24 @@ def _set_b_template(self, b_template: np.ndarray):
def _add_m0_dependent_A_entry(self, row: int, col: int, calc_func: Callable):
self.m0_dependent_A_indices.append((row, col, calc_func))

def _add_m0_dependent_block(self, row_start: int, col_start: int, calc_func: Callable):
"""
Registers a function that returns a dense sub-matrix (block) to be inserted
into A at [row_start:..., col_start:...].
"""
self.m0_dependent_blocks.append((row_start, col_start, calc_func))

def _add_m0_dependent_b_entry(self, row: int, calc_func: Callable):
self.m0_dependent_b_indices.append((row, calc_func))

def _set_m_k_and_N_k_funcs(self, m_k_entry_func: Callable, N_k_func: Callable):
self.m_k_entry_func = m_k_entry_func
self.N_k_func = N_k_func

# --- FIX: Accept m0 as an argument to store it ---
def _set_precomputed_m_k_N_k(self, m_k_arr: np.ndarray, N_k_arr: np.ndarray, m0: float):
"""
Sets the pre-computed m_k and N_k arrays for a specific m0.
"""
self.m_k_arr = m_k_arr
self.N_k_arr = N_k_arr
self.cached_m0 = m0
# ------------------------------------------------

def _set_I_nm_vals(self, I_nm_vals: List[np.ndarray]):
self.I_nm_vals = I_nm_vals
Expand All @@ -68,6 +79,7 @@ def _set_closure(self, key: str, closure):

def _get_closure(self, key: str):
return self.named_closures.get(key, None)

def _set_integration_constants(self, int_R1, int_R2, int_phi):
self.int_R1_vals = int_R1
self.int_R2_vals = int_R2
Expand All @@ -77,22 +89,19 @@ def _get_integration_constants(self):
if self.int_R1_vals is None:
raise ValueError("Integration constants have not been set.")
return self.int_R1_vals, self.int_R2_vals, self.int_phi_vals

def refresh_forcing_terms(self, problem):
"""
Re-calculates b_template and m0_dependent_b_indices based on the
current heaving configuration of the problem.
This allows re-using the cache (and Matrix A) while changing the active mode.
Re-calculates b_template and m0_dependent_b_indices.
"""
domain_list = problem.domain_list
domain_keys = list(domain_list.keys())

# Extract geometry params (same as build_problem_cache)
h = domain_list[0].h
d = [domain_list[idx].di for idx in domain_keys]
a = [domain_list[idx].a for idx in domain_keys]
NMK = [domain.number_harmonics for domain in domain_list.values()]

# Crucial: Get the NEW heaving flags
heaving = [domain_list[idx].heaving for idx in domain_keys]

boundary_count = len(NMK) - 1
Expand All @@ -116,11 +125,7 @@ def refresh_forcing_terms(self, problem):
self._set_b_template(b_template)

# 2. Reset m0_dependent_b_indices
self.m0_dependent_b_indices = [] # Clear old indices

# Re-populate using the loop logic from build_problem_cache
# Note: We must reset 'index' to match the velocity loop start position
# The velocity loop starts after the potential loop.
self.m0_dependent_b_indices = []

# Calculate offset where velocity equations start
potential_eq_count = 0
Expand All @@ -135,19 +140,14 @@ def refresh_forcing_terms(self, problem):
for bd in range(boundary_count):
if bd == (boundary_count - 1):
for n_local in range(NMK[-1]):
# Closure to capture n_local and heaving state
calc_func = lambda p, m0, mk, Nk, Imk, n=n_local: \
b_velocity_end_entry(n, bd, heaving, a, h, d, m0, NMK, mk, Nk)
self._add_m0_dependent_b_entry(index, calc_func)
index += 1
else:
num_entries = NMK[bd + (d[bd] > d[bd + 1])]
for n in range(num_entries):
# b_velocity_entry is not m0 dependent, so it goes into b_template?
# Wait, look at build_problem_cache in original file.
# b_velocity_entry IS put into b_template.
b_template[index] = b_velocity_entry(n, bd, heaving, a, h, d)
index += 1

# Update the template again with the velocity entries added

self._set_b_template(b_template)
11 changes: 3 additions & 8 deletions package/test/test_high_frequency_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,8 @@
}

RHO = 1023
# FIX: Use a safe "high" wavenumber.
# m0=50 implies wavelength ~ 0.12m, which is << body size (3.0m+)
# m0=1e6 causes exp(-m0*d) underflow and singular matrices.
M0_MAX = 50.0
TOLERANCE = 0.05 # Slightly relaxed for asymptotic convergence
M0_MAX = 1e6
TOLERANCE = 0.01

@pytest.mark.parametrize("name, cfg", CONFIGS.items())
def test_high_frequency_limit(name, cfg):
Expand All @@ -45,10 +42,8 @@ def test_high_frequency_limit(name, cfg):
"""
print(f"\nRunning {name}...")

# --- FIX: Reduced NMK to prevent Matrix Ill-Conditioning ---
# 20-30 modes is standard and sufficient. 100 is unstable.
num_regions = len(cfg['heaving']) + 1
NMK = [30] * num_regions
NMK = [100] * num_regions
# -----------------------------------------------------------

# 1. Solve for m0 = inf
Expand Down
51 changes: 0 additions & 51 deletions package/test/test_matrix_snapshot.py

This file was deleted.

9 changes: 6 additions & 3 deletions package/test/test_meem_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# test_meem_engine.py
# package/test/test_meem_engine.py
import pytest
import numpy as np
import sys
Expand Down Expand Up @@ -236,11 +236,14 @@ def test_build_problem_cache(sample_problem):
assert np.any(cache.b_template != 0)

# 4. Verify that the lists for m0-dependent parts are populated
assert len(cache.m0_dependent_A_indices) > 0
# [FIX]: Optimization replaced scalar entries (m0_dependent_A_indices)
# with vectorized blocks (m0_dependent_blocks).
assert len(cache.m0_dependent_blocks) > 0
assert len(cache.m0_dependent_b_indices) > 0

# 5. Check a specific m0-dependent entry to ensure it's a callable
assert callable(cache.m0_dependent_A_indices[0][2])
# m0_dependent_blocks format: (row_start, col_start, calc_func)
assert callable(cache.m0_dependent_blocks[0][2])
assert callable(cache.m0_dependent_b_indices[0][1])

print("✅ Problem cache build test passed.")
Expand Down
Binary file modified package/test_artifacts/contributions/config3_body_0_real.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"pandas>=1.5",
"matplotlib>=3.5",
"h5netcdf>=0.12",
"h5py>=3.0",
"xarray>=2023.0",
"streamlit>=1.0"
]
Expand Down
Loading