Skip to content

Commit ad21cd7

Browse files
committed
refactor: extract shared helpers from convergence test runners
1 parent cdc9b96 commit ad21cd7

5 files changed

Lines changed: 275 additions & 629 deletions

File tree

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
"""Shared helpers for MFC convergence/order test runners."""
2+
3+
import json
4+
import math
5+
import os
6+
import shutil
7+
import struct
8+
import subprocess
9+
import sys
10+
11+
import numpy as np
12+
13+
MFC = "./mfc.sh"
14+
CONS_TOL = 1e-10
15+
16+
17+
def read_cons_var(run_dir, step, var_idx, num_ranks=1, expected_size=None):
18+
"""Read q_cons_vf{var_idx} from all ranks; rank-order concatenation."""
19+
chunks = []
20+
for rank in range(num_ranks):
21+
path = os.path.join(run_dir, "p_all", f"p{rank}", str(step), f"q_cons_vf{var_idx}.dat")
22+
with open(path, "rb") as f:
23+
rec_len = struct.unpack("i", f.read(4))[0]
24+
data = np.frombuffer(f.read(rec_len), dtype=np.float64)
25+
f.read(4)
26+
chunks.append(data.copy())
27+
combined = np.concatenate(chunks)
28+
if expected_size is not None and combined.size != expected_size:
29+
raise ValueError(f"Expected {expected_size} values across {num_ranks} ranks, got {combined.size}")
30+
return combined
31+
32+
33+
def conservation_errors(run_dir, Nt, cell_vol, var_list, num_ranks, expected_size=None):
34+
"""|Σq(T) - Σq(0)| / |Σq(0)| for each named variable."""
35+
errs = {}
36+
for name, idx in var_list:
37+
q0 = read_cons_var(run_dir, 0, idx, num_ranks, expected_size)
38+
qT = read_cons_var(run_dir, Nt, idx, num_ranks, expected_size)
39+
s0 = float(np.sum(q0)) * cell_vol
40+
sT = float(np.sum(qT)) * cell_vol
41+
errs[name] = abs(sT - s0) / (abs(s0) + 1e-300)
42+
return errs
43+
44+
45+
def l2_norm(diff, scale):
46+
"""sqrt(sum(diff^2) * scale)."""
47+
return float(np.sqrt(np.sum(diff**2) * scale))
48+
49+
50+
def fit_rate(errors, h_values):
51+
"""Least-squares slope of log(error) vs log(h)."""
52+
log_h = np.log(np.array(h_values, dtype=float))
53+
log_err = np.log(np.array(errors, dtype=float))
54+
slope, _ = np.polyfit(log_h, log_err, 1)
55+
return float(slope)
56+
57+
58+
def pairwise_rates(errors, h_values):
59+
"""Pairwise rates aligned with errors (first entry is None)."""
60+
rates = [None]
61+
for i in range(1, len(errors)):
62+
log_h0 = math.log(h_values[i - 1])
63+
log_h1 = math.log(h_values[i])
64+
rates.append((math.log(errors[i]) - math.log(errors[i - 1])) / (log_h1 - log_h0))
65+
return rates
66+
67+
68+
def run_mfc_case(case_path, tmpdir, run_tag, case_args, num_ranks=1):
69+
"""Run case.py once and copy p_all to tmpdir/run_tag. Returns (cfg_dict, run_dir)."""
70+
result = subprocess.run(
71+
[sys.executable, case_path, "--mfc", "{}"] + case_args,
72+
capture_output=True,
73+
text=True,
74+
check=False,
75+
)
76+
if result.returncode != 0:
77+
raise RuntimeError(f"case.py failed:\n{result.stderr}")
78+
cfg = json.loads(result.stdout)
79+
80+
cmd = [MFC, "run", case_path, "-t", "pre_process", "simulation", "-n", str(num_ranks), "--"] + case_args
81+
result = subprocess.run(cmd, capture_output=True, text=True, cwd=os.getcwd(), check=False)
82+
if result.returncode != 0:
83+
print(result.stdout[-3000:])
84+
print(result.stderr)
85+
raise RuntimeError(f"./mfc.sh run failed for {run_tag}")
86+
87+
case_dir = os.path.dirname(case_path)
88+
src = os.path.join(case_dir, "p_all")
89+
dst = os.path.join(tmpdir, run_tag, "p_all")
90+
if os.path.exists(dst):
91+
shutil.rmtree(dst)
92+
shutil.copytree(src, dst)
93+
shutil.rmtree(src, ignore_errors=True)
94+
shutil.rmtree(os.path.join(case_dir, "D"), ignore_errors=True)
95+
96+
return cfg, os.path.join(tmpdir, run_tag)
97+
98+
99+
def print_conservation_check(all_cons_errs, var_list, tol=CONS_TOL):
100+
"""Print per-variable max conservation error and return overall pass."""
101+
print(f"\n Conservation (need rel. error < {tol:.0e}):")
102+
passed = True
103+
for name, _ in var_list:
104+
max_err = max(ce[name] for ce in all_cons_errs)
105+
ok = max_err < tol
106+
print(f" {name:<14}: max = {max_err:.2e} {'OK' if ok else 'FAIL'}")
107+
if not ok:
108+
passed = False
109+
return passed
110+
111+
112+
def print_summary(results, label_width=18):
113+
"""Print PASS/FAIL summary table; return overall bool."""
114+
print(f"\n{'=' * 60}")
115+
print(" Summary")
116+
print(f"{'=' * 60}")
117+
all_pass = True
118+
for label, passed in results.items():
119+
print(f" {label:<{label_width}} {'PASS' if passed else 'FAIL'}")
120+
if not passed:
121+
all_pass = False
122+
return all_pass
123+
124+
125+
def run_with_traceback(label, fn, *args, **kwargs):
126+
"""Run a test function, print traceback on failure, return pass/fail bool."""
127+
try:
128+
return fn(*args, **kwargs)
129+
except Exception as e:
130+
import traceback
131+
132+
print(f" ERROR ({label}): {e}")
133+
traceback.print_exc()
134+
return False

0 commit comments

Comments
 (0)