Skip to content
This repository was archived by the owner on Jun 14, 2026. It is now read-only.

Commit c08252f

Browse files
MaxGhenisclaude
andcommitted
Make downstream aggregate weighting explicit; seed regime-aware imputer
downstream.py - Replace reliance on MicroSeries ``.sum()`` semantics with an explicit ``compute_downstream_weighted_aggregate`` helper that pulls the correct entity weight variable (tax_unit_weight / spm_unit_weight / person_weight / ...) from PE's variable metadata and takes the numpy dot product. Same numerics as ``.sum()`` on the v11 artifact, but test-covered and robust to simulator changes. - ``ENTITY_WEIGHT_VARIABLES`` table maps PE entity keys to weight variable names. RegimeAwareDonorImputer - Add ``seed`` constructor arg and deterministic ``_reset_prediction_rngs`` during ``generate`` so repeated calls with the same seed produce byte-identical output. scripts/run_b2_batched.py - Classify each h5 variable by PE's variable metadata first, then fall back to length matching; raises on ambiguous length matches rather than silently picking one. Added structural-variable overrides for IDs / weights / link columns. - Wire batched runner's per-chunk aggregate through ``compute_downstream_weighted_aggregate``. scripts/run_b2_validation.py / run_b2_validation_single_var.py - Use ``compute_downstream_weighted_aggregate`` for consistency with the other callers and explicit weighting. Tests: 3 new entity-resolution tests in test_run_b2_batched.py; 3 new weighted-aggregate tests in test_downstream.py; 2 new seed-determinism tests in test_regime_aware_donor_imputer.py. 21 tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 5796bf7 commit c08252f

11 files changed

Lines changed: 499 additions & 31 deletions

File tree

AGENTS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ To avoid rebuilding long prompts in chat:
8484
<!-- gitnexus:start -->
8585
# GitNexus — Code Intelligence
8686

87-
This project is indexed by GitNexus as **microplex-us** (4732 symbols, 12777 relationships, 300 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely.
87+
This project is indexed by GitNexus as **microplex-us** (4778 symbols, 12879 relationships, 300 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely.
8888

8989
> If any GitNexus tool warns the index is stale, run `npx gitnexus analyze` in terminal first.
9090

CLAUDE.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
<!-- gitnexus:start -->
22
# GitNexus — Code Intelligence
33

4-
This project is indexed by GitNexus as **microplex-us** (4732 symbols, 12777 relationships, 300 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely.
4+
This project is indexed by GitNexus as **microplex-us** (4778 symbols, 12879 relationships, 300 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely.
55

66
> If any GitNexus tool warns the index is stale, run `npx gitnexus analyze` in terminal first.
77
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""Copy the calibration targets DB and add direct targets on SSI / CTC / ACA PTC.
2+
3+
The v11 downstream validation showed those three aggregates drifting
4+
+64% / +32% / -76% from their benchmark totals. They weren't in the
5+
original calibration target set (which focuses on AGI / income
6+
marginals, not downstream-disbursed amounts). Adding them as direct
7+
national targets should drive their calibrated aggregates toward the
8+
benchmark values.
9+
10+
Stratum 1 is "United States" (from the existing DB). Period 2024 and
11+
reform_id=0 (baseline) match the rest of the 2024 target set.
12+
"""
13+
14+
from __future__ import annotations
15+
16+
import argparse
17+
import shutil
18+
import sqlite3
19+
from pathlib import Path
20+
21+
from microplex_us.validation.downstream import DOWNSTREAM_BENCHMARKS_2024
22+
23+
24+
def main() -> int:
25+
parser = argparse.ArgumentParser()
26+
parser.add_argument("--source", required=True, type=Path)
27+
parser.add_argument("--output", required=True, type=Path)
28+
parser.add_argument(
29+
"--variables",
30+
nargs="+",
31+
default=["ssi", "ctc", "aca_ptc"],
32+
)
33+
parser.add_argument("--period", default=2024, type=int)
34+
args = parser.parse_args()
35+
36+
args.output.parent.mkdir(parents=True, exist_ok=True)
37+
shutil.copyfile(args.source, args.output)
38+
39+
benchmarks_by_name = {spec.name: spec for spec in DOWNSTREAM_BENCHMARKS_2024}
40+
41+
con = sqlite3.connect(args.output)
42+
cur = con.cursor()
43+
for variable in args.variables:
44+
spec = benchmarks_by_name.get(variable)
45+
if spec is None:
46+
raise KeyError(f"No 2024 benchmark spec for {variable}")
47+
cur.execute(
48+
"SELECT COUNT(*) FROM targets WHERE variable=? AND period=? "
49+
"AND stratum_id=1 AND reform_id=0",
50+
(variable, args.period),
51+
)
52+
if cur.fetchone()[0] > 0:
53+
print(f"[skip] {variable} already has a national 2024 target")
54+
continue
55+
cur.execute(
56+
"INSERT INTO targets "
57+
"(variable, period, stratum_id, reform_id, value, active, source, notes) "
58+
"VALUES (?, ?, 1, 0, ?, 1, ?, ?)",
59+
(
60+
variable,
61+
args.period,
62+
float(spec.benchmark),
63+
spec.source,
64+
f"B2 follow-up direct target for {variable}",
65+
),
66+
)
67+
print(
68+
f"[add ] {variable} @ 2024 national: ${spec.benchmark/1e9:.1f}B ({spec.source})"
69+
)
70+
con.commit()
71+
con.close()
72+
print(f"\nWrote augmented DB to {args.output}")
73+
return 0
74+
75+
76+
if __name__ == "__main__":
77+
raise SystemExit(main())

scripts/run_b2_batched.py

Lines changed: 82 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import h5py
2727
import numpy as np
2828

29-
3029
HOUSEHOLD_ID = "household_id"
3130

3231
ENTITY_ID_COLUMNS = {
@@ -44,6 +43,25 @@
4443
"family": "person_family_id",
4544
"marital_unit": "person_marital_unit_id",
4645
}
46+
STRUCTURAL_VARIABLE_ENTITIES = {
47+
"household_id": "household",
48+
"household_weight": "household",
49+
"person_id": "person",
50+
"person_household_id": "person",
51+
"person_weight": "person",
52+
"tax_unit_id": "tax_unit",
53+
"person_tax_unit_id": "person",
54+
"tax_unit_weight": "tax_unit",
55+
"spm_unit_id": "spm_unit",
56+
"person_spm_unit_id": "person",
57+
"spm_unit_weight": "spm_unit",
58+
"family_id": "family",
59+
"person_family_id": "person",
60+
"family_weight": "family",
61+
"marital_unit_id": "marital_unit",
62+
"person_marital_unit_id": "person",
63+
"marital_unit_weight": "marital_unit",
64+
}
4765

4866

4967
def _load_all_arrays(h5_path: Path, period_key: str) -> dict[str, np.ndarray]:
@@ -55,17 +73,51 @@ def _load_all_arrays(h5_path: Path, period_key: str) -> dict[str, np.ndarray]:
5573
return out
5674

5775

58-
def _entity_of(variable: str, arrays: dict[str, np.ndarray]) -> str:
59-
"""Classify a variable by matching its array length to an entity's id column."""
76+
def _load_policyengine_variable_entities() -> dict[str, str]:
77+
try:
78+
from policyengine_us import (
79+
system as policyengine_system_module, # noqa: PLC0415
80+
)
81+
except ImportError:
82+
return {}
83+
84+
tax_benefit_system = getattr(policyengine_system_module, "system", None)
85+
if tax_benefit_system is None:
86+
return {}
87+
variables = getattr(tax_benefit_system, "variables", {})
88+
entity_map: dict[str, str] = {}
89+
for name, metadata in variables.items():
90+
entity_key = getattr(getattr(metadata, "entity", None), "key", None)
91+
if entity_key is not None:
92+
entity_map[str(name)] = str(entity_key)
93+
return entity_map
94+
95+
96+
def _entity_of(
97+
variable: str,
98+
arrays: dict[str, np.ndarray],
99+
*,
100+
variable_entities: dict[str, str] | None = None,
101+
) -> str:
102+
"""Classify a variable, preferring PE metadata over fragile length matching."""
103+
explicit_entity = STRUCTURAL_VARIABLE_ENTITIES.get(variable)
104+
if explicit_entity is not None:
105+
return explicit_entity
106+
if variable_entities is not None and variable in variable_entities:
107+
return variable_entities[variable]
60108
n = len(arrays[variable])
61109
entity_lengths = {
62110
entity: len(arrays[id_col])
63111
for entity, id_col in ENTITY_ID_COLUMNS.items()
64112
if id_col in arrays
65113
}
66-
for entity, length in entity_lengths.items():
67-
if length == n:
68-
return entity
114+
matches = [entity for entity, length in entity_lengths.items() if length == n]
115+
if len(matches) == 1:
116+
return matches[0]
117+
if len(matches) > 1:
118+
raise ValueError(
119+
f"Ambiguous entity for variable {variable!r}: matched {matches} by length"
120+
)
69121
return "unknown"
70122

71123

@@ -74,7 +126,6 @@ def _build_entity_masks(
74126
) -> dict[str, np.ndarray]:
75127
"""Produce boolean masks into each entity array for the households in ``chunk_hh_ids``."""
76128
hh_id = arrays["household_id"]
77-
chunk_set = set(chunk_hh_ids.tolist())
78129
masks: dict[str, np.ndarray] = {}
79130
masks["household"] = np.isin(hh_id, chunk_hh_ids)
80131
person_hh = arrays["person_household_id"]
@@ -94,11 +145,17 @@ def _write_chunk_h5(
94145
entity_masks: dict[str, np.ndarray],
95146
period_key: str,
96147
tmp_path: Path,
148+
*,
149+
variable_entities: dict[str, str] | None = None,
97150
) -> None:
98151
"""Write a subset h5 keeping only rows matching each variable's entity mask."""
99152
with h5py.File(tmp_path, "w") as f:
100153
for variable, values in arrays.items():
101-
entity = _entity_of(variable, arrays)
154+
entity = _entity_of(
155+
variable,
156+
arrays,
157+
variable_entities=variable_entities,
158+
)
102159
mask = entity_masks.get(entity)
103160
if mask is None or len(values) != len(mask):
104161
continue
@@ -118,6 +175,7 @@ def main() -> int:
118175
period_key = str(args.period)
119176
print(f"[{time.strftime('%H:%M:%S')}] loading all arrays from {args.dataset}", flush=True)
120177
arrays = _load_all_arrays(args.dataset, period_key)
178+
variable_entities = _load_policyengine_variable_entities()
121179
print(
122180
f"[{time.strftime('%H:%M:%S')}] loaded {len(arrays)} variables",
123181
flush=True,
@@ -132,6 +190,10 @@ def main() -> int:
132190

133191
from policyengine_us import Microsimulation # noqa: PLC0415
134192

193+
from microplex_us.validation.downstream import ( # noqa: PLC0415
194+
compute_downstream_weighted_aggregate,
195+
)
196+
135197
for batch_idx in range(n_batches):
136198
start = batch_idx * args.batch_size
137199
end = min(start + args.batch_size, n_hh)
@@ -141,12 +203,21 @@ def main() -> int:
141203

142204
with tempfile.TemporaryDirectory() as tmp:
143205
tmp_path = Path(tmp) / "chunk.h5"
144-
_write_chunk_h5(arrays, entity_masks, period_key, tmp_path)
206+
_write_chunk_h5(
207+
arrays,
208+
entity_masks,
209+
period_key,
210+
tmp_path,
211+
variable_entities=variable_entities,
212+
)
145213

146214
t0 = time.time()
147215
sim = Microsimulation(dataset=str(tmp_path))
148-
values = sim.calculate(args.variable, args.period)
149-
chunk_sum = float(values.sum())
216+
chunk_sum = compute_downstream_weighted_aggregate(
217+
sim,
218+
args.variable,
219+
args.period,
220+
)
150221
total += chunk_sum
151222
elapsed = time.time() - t0
152223

scripts/run_b2_validation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from microplex_us.validation.downstream import (
1717
DOWNSTREAM_BENCHMARKS_2024,
1818
compute_downstream_comparison,
19+
compute_downstream_weighted_aggregate,
1920
)
2021

2122

@@ -42,7 +43,7 @@ def main() -> int:
4243
t0 = time.time()
4344
print(f"[{time.strftime('%H:%M:%S')}] computing {variable} ...", flush=True)
4445
try:
45-
total = float(sim.calculate(variable, args.period).sum())
46+
total = compute_downstream_weighted_aggregate(sim, variable, args.period)
4647
except Exception as exc:
4748
print(f" {variable}: FAILED ({exc})", flush=True)
4849
aggregates[variable] = float("nan")

scripts/run_b2_validation_single_var.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from microplex_us.validation.downstream import (
1717
DOWNSTREAM_BENCHMARKS_2024,
1818
compute_downstream_comparison,
19+
compute_downstream_weighted_aggregate,
1920
)
2021

2122

@@ -33,7 +34,7 @@ def main() -> int:
3334
sim = Microsimulation(dataset=str(args.dataset))
3435
print(f"[{time.strftime('%H:%M:%S')}] loaded — computing {args.variable}", flush=True)
3536
t0 = time.time()
36-
total = float(sim.calculate(args.variable, args.period).sum())
37+
total = compute_downstream_weighted_aggregate(sim, args.variable, args.period)
3738
elapsed = time.time() - t0
3839
print(
3940
f"[{time.strftime('%H:%M:%S')}] {args.variable} = ${total/1e9:.2f}B "
@@ -42,11 +43,6 @@ def main() -> int:
4243
)
4344

4445
args.output.parent.mkdir(parents=True, exist_ok=True)
45-
if args.output.exists():
46-
existing = json.loads(args.output.read_text())
47-
else:
48-
existing = {}
49-
5046
# Re-read intermediate file if present (accumulates across runs).
5147
raw_agg_path = args.output.with_suffix(".raw.json")
5248
raw_aggs = (

0 commit comments

Comments
 (0)