Skip to content

Commit 310bb73

Browse files
baogorekclaude
andcommitted
Salt takeup draws with hh_id:clone_idx instead of block:hh_id
Replace block-based RNG salting with (hh_id, clone_idx) salting. Draws are now tied to the donor household identity and independent across clones, eliminating the multi-clone-same-block collision issue (#597). Geographic variation comes through the rate threshold, not the draw. Closes #597 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3df0d91 commit 310bb73

4 files changed

Lines changed: 154 additions & 138 deletions

File tree

policyengine_us_data/calibration/publish_local_area.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,7 @@ def build_h5(
511511
hh_blocks=active_blocks,
512512
hh_state_fips=hh_state_fips,
513513
hh_ids=original_hh_ids,
514+
hh_clone_indices=active_geo.astype(np.int64),
514515
entity_hh_indices=entity_hh_indices,
515516
entity_counts=entity_counts,
516517
time_period=time_period,

policyengine_us_data/calibration/unified_matrix_builder.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,11 +635,13 @@ def _process_single_clone(
635635
ent_hh = entity_hh_idx_map[entity]
636636
ent_blocks = clone_blocks[ent_hh]
637637
ent_hh_ids = household_ids[ent_hh]
638+
ent_ci = np.full(len(ent_hh), clone_idx, dtype=np.int64)
638639
draws = compute_block_takeup_for_entities(
639640
var_name,
640641
precomputed_rates[rate_key],
641642
ent_blocks,
642643
ent_hh_ids,
644+
ent_ci,
643645
)
644646
wf_draws[entity] = draws
645647
if var_name in person_vars:
@@ -706,12 +708,14 @@ def _process_single_clone(
706708

707709
ent_blocks = clone_blocks[ent_hh]
708710
ent_hh_ids = household_ids[ent_hh]
711+
ent_ci = np.full(n_ent, clone_idx, dtype=np.int64)
709712

710713
ent_takeup = compute_block_takeup_for_entities(
711714
takeup_var,
712715
precomputed_rates[info["rate_key"]],
713716
ent_blocks,
714717
ent_hh_ids,
718+
ent_ci,
715719
)
716720

717721
ent_values = (ent_eligible * ent_takeup).astype(np.float32)
@@ -2290,11 +2294,13 @@ def build_matrix(
22902294
ent_hh = entity_hh_idx_map[entity]
22912295
ent_blocks = clone_blocks[ent_hh]
22922296
ent_hh_ids = household_ids[ent_hh]
2297+
ent_ci = np.full(len(ent_hh), clone_idx, dtype=np.int64)
22932298
draws = compute_block_takeup_for_entities(
22942299
var_name,
22952300
precomputed_rates[rate_key],
22962301
ent_blocks,
22972302
ent_hh_ids,
2303+
ent_ci,
22982304
)
22992305
wf_draws[entity] = draws
23002306
if var_name in person_vars:
@@ -2368,12 +2374,14 @@ def build_matrix(
23682374

23692375
ent_blocks = clone_blocks[ent_hh]
23702376
ent_hh_ids = household_ids[ent_hh]
2377+
ent_ci = np.full(n_ent, clone_idx, dtype=np.int64)
23712378

23722379
ent_takeup = compute_block_takeup_for_entities(
23732380
takeup_var,
23742381
precomputed_rates[info["rate_key"]],
23752382
ent_blocks,
23762383
ent_hh_ids,
2384+
ent_ci,
23772385
)
23782386

23792387
ent_values = (ent_eligible * ent_takeup).astype(np.float32)

policyengine_us_data/tests/test_calibration/test_unified_calibration.py

Lines changed: 99 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -74,44 +74,61 @@ def test_rate_comparison_produces_booleans(self):
7474

7575
class TestBlockSaltedDraws:
7676
"""Verify compute_block_takeup_for_entities produces
77-
reproducible, block-dependent draws."""
77+
reproducible, clone-dependent draws."""
7878

79-
def test_same_block_same_results(self):
80-
blocks = np.array(["370010001001001"] * 500)
81-
d1 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks)
82-
d2 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks)
79+
def test_same_inputs_same_results(self):
80+
n = 500
81+
blocks = np.array(["370010001001001"] * n)
82+
hh_ids = np.arange(n, dtype=np.int64)
83+
ci = np.zeros(n, dtype=np.int64)
84+
d1 = compute_block_takeup_for_entities(
85+
"takes_up_snap_if_eligible", 0.8, blocks, hh_ids, ci
86+
)
87+
d2 = compute_block_takeup_for_entities(
88+
"takes_up_snap_if_eligible", 0.8, blocks, hh_ids, ci
89+
)
8390
np.testing.assert_array_equal(d1, d2)
8491

85-
def test_different_blocks_different_results(self):
92+
def test_different_clone_idx_different_results(self):
8693
n = 500
94+
blocks = np.array(["370010001001001"] * n)
95+
hh_ids = np.arange(n, dtype=np.int64)
96+
ci0 = np.zeros(n, dtype=np.int64)
97+
ci1 = np.ones(n, dtype=np.int64)
8798
d1 = compute_block_takeup_for_entities(
88-
"takes_up_snap_if_eligible",
89-
0.8,
90-
np.array(["370010001001001"] * n),
99+
"takes_up_snap_if_eligible", 0.8, blocks, hh_ids, ci0
91100
)
92101
d2 = compute_block_takeup_for_entities(
93-
"takes_up_snap_if_eligible",
94-
0.8,
95-
np.array(["480010002002002"] * n),
102+
"takes_up_snap_if_eligible", 0.8, blocks, hh_ids, ci1
96103
)
97104
assert not np.array_equal(d1, d2)
98105

99106
def test_different_vars_different_results(self):
100-
blocks = np.array(["370010001001001"] * 500)
101-
d1 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks)
102-
d2 = compute_block_takeup_for_entities("takes_up_aca_if_eligible", 0.8, blocks)
107+
n = 500
108+
blocks = np.array(["370010001001001"] * n)
109+
hh_ids = np.arange(n, dtype=np.int64)
110+
ci = np.zeros(n, dtype=np.int64)
111+
d1 = compute_block_takeup_for_entities(
112+
"takes_up_snap_if_eligible", 0.8, blocks, hh_ids, ci
113+
)
114+
d2 = compute_block_takeup_for_entities(
115+
"takes_up_aca_if_eligible", 0.8, blocks, hh_ids, ci
116+
)
103117
assert not np.array_equal(d1, d2)
104118

105-
def test_hh_salt_differs_from_block_only(self):
106-
blocks = np.array(["370010001001001"] * 500)
107-
hh_ids = np.array([1] * 500)
108-
d_block = compute_block_takeup_for_entities(
109-
"takes_up_snap_if_eligible", 0.8, blocks
119+
def test_different_hh_ids_different_results(self):
120+
n = 500
121+
blocks = np.array(["370010001001001"] * n)
122+
ci = np.zeros(n, dtype=np.int64)
123+
hh_a = np.arange(n, dtype=np.int64)
124+
hh_b = np.arange(n, dtype=np.int64) + 1000
125+
d1 = compute_block_takeup_for_entities(
126+
"takes_up_snap_if_eligible", 0.8, blocks, hh_a, ci
110127
)
111-
d_hh = compute_block_takeup_for_entities(
112-
"takes_up_snap_if_eligible", 0.8, blocks, hh_ids
128+
d2 = compute_block_takeup_for_entities(
129+
"takes_up_snap_if_eligible", 0.8, blocks, hh_b, ci
113130
)
114-
assert not np.array_equal(d_block, d_hh)
131+
assert not np.array_equal(d1, d2)
115132

116133

117134
class TestApplyBlockTakeupToArrays:
@@ -126,6 +143,7 @@ def _make_arrays(self, n_hh, persons_per_hh, tu_per_hh, spm_per_hh):
126143
hh_blocks = np.array(["370010001001001"] * n_hh)
127144
hh_state_fips = np.array([37] * n_hh, dtype=np.int32)
128145
hh_ids = np.arange(n_hh, dtype=np.int64)
146+
hh_clone_indices = np.zeros(n_hh, dtype=np.int64)
129147
entity_hh_indices = {
130148
"person": np.repeat(np.arange(n_hh), persons_per_hh),
131149
"tax_unit": np.repeat(np.arange(n_hh), tu_per_hh),
@@ -140,6 +158,7 @@ def _make_arrays(self, n_hh, persons_per_hh, tu_per_hh, spm_per_hh):
140158
hh_blocks,
141159
hh_state_fips,
142160
hh_ids,
161+
hh_clone_indices,
143162
entity_hh_indices,
144163
entity_counts,
145164
)
@@ -336,38 +355,61 @@ def test_county_fips_length(self):
336355

337356
class TestBlockTakeupSeeding:
338357
"""Verify compute_block_takeup_for_entities is
339-
reproducible and block-dependent."""
358+
reproducible and clone-dependent."""
340359

341360
def test_reproducible(self):
361+
n = 100
342362
blocks = np.array(["010010001001001"] * 50 + ["020010001001001"] * 50)
343-
r1 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks)
344-
r2 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks)
363+
hh_ids = np.arange(n, dtype=np.int64)
364+
ci = np.zeros(n, dtype=np.int64)
365+
r1 = compute_block_takeup_for_entities(
366+
"takes_up_snap_if_eligible", 0.8, blocks, hh_ids, ci
367+
)
368+
r2 = compute_block_takeup_for_entities(
369+
"takes_up_snap_if_eligible", 0.8, blocks, hh_ids, ci
370+
)
345371
np.testing.assert_array_equal(r1, r2)
346372

347-
def test_different_blocks_different_draws(self):
373+
def test_different_blocks_different_rates(self):
374+
"""With state-dependent rates, different blocks yield
375+
different takeup because rate thresholds differ."""
348376
n = 500
349-
blocks_a = np.array(["010010001001001"] * n)
350-
blocks_b = np.array(["020010001001001"] * n)
377+
hh_ids = np.arange(n, dtype=np.int64)
378+
ci = np.zeros(n, dtype=np.int64)
379+
rate_dict = {"AL": 0.9, "AK": 0.3}
351380
r_a = compute_block_takeup_for_entities(
352-
"takes_up_snap_if_eligible", 0.8, blocks_a
381+
"takes_up_snap_if_eligible",
382+
rate_dict,
383+
np.array(["010010001001001"] * n),
384+
hh_ids,
385+
ci,
353386
)
354387
r_b = compute_block_takeup_for_entities(
355-
"takes_up_snap_if_eligible", 0.8, blocks_b
388+
"takes_up_snap_if_eligible",
389+
rate_dict,
390+
np.array(["020010001001001"] * n),
391+
hh_ids,
392+
ci,
356393
)
357394
assert not np.array_equal(r_a, r_b)
358395

359396
def test_returns_booleans(self):
360-
blocks = np.array(["370010001001001"] * 100)
397+
n = 100
398+
blocks = np.array(["370010001001001"] * n)
399+
hh_ids = np.arange(n, dtype=np.int64)
400+
ci = np.zeros(n, dtype=np.int64)
361401
result = compute_block_takeup_for_entities(
362-
"takes_up_snap_if_eligible", 0.8, blocks
402+
"takes_up_snap_if_eligible", 0.8, blocks, hh_ids, ci
363403
)
364404
assert result.dtype == bool
365405

366406
def test_rate_respected(self):
367407
n = 10000
368408
blocks = np.array(["370010001001001"] * n)
409+
hh_ids = np.arange(n, dtype=np.int64)
410+
ci = np.zeros(n, dtype=np.int64)
369411
result = compute_block_takeup_for_entities(
370-
"takes_up_snap_if_eligible", 0.75, blocks
412+
"takes_up_snap_if_eligible", 0.75, blocks, hh_ids, ci
371413
)
372414
frac = result.mean()
373415
assert 0.70 < frac < 0.80
@@ -481,6 +523,7 @@ def test_matrix_and_stacked_identical_draws(self):
481523
"""Both paths must produce identical boolean arrays."""
482524
var = "takes_up_snap_if_eligible"
483525
rate = 0.75
526+
clone_idx = 5
484527

485528
# 2 blocks, 3 households, variable entity counts per HH
486529
# HH0 has 2 entities in block A
@@ -497,20 +540,23 @@ def test_matrix_and_stacked_identical_draws(self):
497540
]
498541
)
499542
hh_ids = np.array([100, 100, 200, 200, 200, 300])
543+
ci = np.full(len(blocks), clone_idx, dtype=np.int64)
500544

501-
# Path 1: compute_block_takeup_for_entities (stacked)
502-
stacked = compute_block_takeup_for_entities(var, rate, blocks, hh_ids)
545+
# Path 1: compute_block_takeup_for_entities
546+
stacked = compute_block_takeup_for_entities(var, rate, blocks, hh_ids, ci)
503547

504-
# Path 2: reproduce matrix builder inline logic
548+
# Path 2: reproduce inline logic with hh_id:clone_idx salt
505549
n = len(blocks)
506550
inline_takeup = np.zeros(n, dtype=bool)
507-
for blk in np.unique(blocks):
508-
bm = blocks == blk
509-
for hh_id in np.unique(hh_ids[bm]):
510-
hh_mask = bm & (hh_ids == hh_id)
511-
rng = seeded_rng(var, salt=f"{blk}:{int(hh_id)}")
512-
draws = rng.random(int(hh_mask.sum()))
513-
inline_takeup[hh_mask] = draws < rate
551+
for hh_id in np.unique(hh_ids):
552+
hh_mask = hh_ids == hh_id
553+
rng = seeded_rng(var, salt=f"{int(hh_id)}:{clone_idx}")
554+
draws = rng.random(int(hh_mask.sum()))
555+
# Rate from block's state FIPS
556+
blk = blocks[hh_mask][0]
557+
sf = int(str(blk)[:2])
558+
r = _resolve_rate(rate, sf)
559+
inline_takeup[hh_mask] = draws < r
514560

515561
np.testing.assert_array_equal(stacked, inline_takeup)
516562

@@ -542,18 +588,22 @@ def test_state_specific_rate_resolved_from_block(self):
542588
n = 5000
543589

544590
blocks_nc = np.array(["370010001001001"] * n)
545-
result_nc = compute_block_takeup_for_entities(var, rate_dict, blocks_nc)
546-
# NC rate=0.9, expect ~90%
591+
hh_ids_nc = np.arange(n, dtype=np.int64)
592+
ci = np.zeros(n, dtype=np.int64)
593+
result_nc = compute_block_takeup_for_entities(
594+
var, rate_dict, blocks_nc, hh_ids_nc, ci
595+
)
547596
frac_nc = result_nc.mean()
548597
assert 0.85 < frac_nc < 0.95, f"NC frac={frac_nc}"
549598

550599
blocks_tx = np.array(["480010002002002"] * n)
551-
result_tx = compute_block_takeup_for_entities(var, rate_dict, blocks_tx)
552-
# TX rate=0.6, expect ~60%
600+
hh_ids_tx = np.arange(n, dtype=np.int64)
601+
result_tx = compute_block_takeup_for_entities(
602+
var, rate_dict, blocks_tx, hh_ids_tx, ci
603+
)
553604
frac_tx = result_tx.mean()
554605
assert 0.55 < frac_tx < 0.65, f"TX frac={frac_tx}"
555606

556-
# Verify _resolve_rate actually gives different rates
557607
assert _resolve_rate(rate_dict, 37) == 0.9
558608
assert _resolve_rate(rate_dict, 48) == 0.6
559609

0 commit comments

Comments
 (0)