Skip to content

Commit 6482352

Browse files
MaxGhenis and claude committed
Add B2 validation runner script with per-variable checkpointing
The one-shot ``python -c '...'`` run on the v11 output got SIGKILL'd before producing output — Python buffered stdout was lost on signal, and no per-variable state was saved to disk. This script runs the same computation with ``python -u`` for line-buffered stdout and writes a ``<output>.partial.json`` after each variable so a late kill still leaves N-of-6 aggregates recoverable. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 4b35735 commit 6482352

1 file changed

Lines changed: 81 additions & 0 deletions

File tree

scripts/run_b2_validation.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""Run B2 downstream validation on a calibrated PE-US h5.
2+
3+
One variable at a time, flushing progress and intermediate output to
4+
disk so a partial run leaves usable state. Uses the
5+
``microplex_us.validation.downstream`` module for the benchmark set.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import argparse
11+
import json
12+
import sys
13+
import time
14+
from pathlib import Path
15+
16+
from microplex_us.validation.downstream import (
17+
DOWNSTREAM_BENCHMARKS_2024,
18+
compute_downstream_comparison,
19+
)
20+
21+
22+
def _write_json_atomic(path: Path, payload: object) -> None:
    """Serialise *payload* as indented JSON to *path* via a temp-file rename.

    The text is first written to a sibling ``<name>.tmp`` file and then
    renamed over *path*.  A kill that lands in the middle of the write
    therefore leaves the previous contents of *path* untouched rather than
    an empty or half-written file (at worst a stale ``.tmp`` is left behind,
    which the next call overwrites).
    """
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(json.dumps(payload, indent=2))
    # Same-directory rename, so the destination is swapped in a single step.
    tmp_path.replace(path)


def main() -> int:
    """Compute each downstream benchmark aggregate and write a comparison report.

    Command-line arguments:
        --dataset: path of the dataset handed to ``Microsimulation``.
        --output: path of the JSON comparison report to write.
        --period: period passed to ``sim.calculate`` (default 2024).

    After every variable, the aggregates gathered so far are checkpointed to
    the output path with its suffix replaced by ``.partial.json``; that
    checkpoint is deleted once the final report has been written.  A variable
    whose calculation raises is reported on stdout, recorded as NaN, and the
    run continues with the next variable.

    Returns:
        0 after the report has been written and the summary table printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", required=True, type=Path)
    parser.add_argument("--output", required=True, type=Path)
    parser.add_argument("--period", default=2024, type=int)
    args = parser.parse_args()

    print(f"[{time.strftime('%H:%M:%S')}] loading Microsimulation from {args.dataset}", flush=True)
    # Imported here, after the progress line above has been flushed, so the
    # log already shows what is happening while the import runs.
    from policyengine_us import Microsimulation

    sim = Microsimulation(dataset=str(args.dataset))
    print(f"[{time.strftime('%H:%M:%S')}] loaded", flush=True)

    variables = [spec.name for spec in DOWNSTREAM_BENCHMARKS_2024]
    aggregates: dict[str, float] = {}

    args.output.parent.mkdir(parents=True, exist_ok=True)
    intermediate_path = args.output.with_suffix(".partial.json")

    for variable in variables:
        t0 = time.time()
        print(f"[{time.strftime('%H:%M:%S')}] computing {variable} ...", flush=True)
        try:
            total = float(sim.calculate(variable, args.period).sum())
        except Exception as exc:
            # Deliberate best-effort: one failing variable must not abort the
            # remaining ones.  NaN marks the failure in the aggregates (it is
            # serialised by ``json.dumps`` as the non-standard token ``NaN``).
            print(f" {variable}: FAILED ({exc})", flush=True)
            aggregates[variable] = float("nan")
        else:
            aggregates[variable] = total
            elapsed = time.time() - t0
            print(
                f" {variable}: ${total/1e9:,.2f}B (in {elapsed:.1f}s)",
                flush=True,
            )
        # Checkpoint after each variable so an OOM kill after N variables
        # still leaves N results on disk.  The write goes through a temp file
        # and rename so that a kill *during* the checkpoint cannot truncate
        # the previous one.
        _write_json_atomic(intermediate_path, aggregates)

    comparison = compute_downstream_comparison(aggregates, DOWNSTREAM_BENCHMARKS_2024)
    report = {name: rec.to_dict() for name, rec in comparison.items()}
    _write_json_atomic(args.output, report)
    # Only drop the checkpoint once the full report is safely in place.
    intermediate_path.unlink(missing_ok=True)

    print(f"\n[{time.strftime('%H:%M:%S')}] B2 validation complete", flush=True)
    print(f"Wrote {args.output}", flush=True)

    print(
        f"\n{'variable':<12s} {'computed':>12s} {'benchmark':>12s} {'rel_error':>10s}",
        flush=True,
    )
    for name, rec in sorted(comparison.items()):
        rel = rec.rel_error
        rel_str = f"{rel*100:+.1f}%" if rel is not None else "N/A"
        print(
            f"{name:<12s} ${rec.computed/1e9:>9.2f}B "
            f"${rec.benchmark/1e9:>9.2f}B {rel_str:>10s}",
            flush=True,
        )
    return 0
78+
79+
80+
# Script entry point: hand main()'s return value to the interpreter as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())

0 commit comments

Comments
 (0)