Skip to content

Commit 204e4fc

Browse files
authored
Merge pull request #1151 from PolicyEngine/codex/fix-puf-clone-support-priors-20260528
Keep PUF clone priors as support weights
2 parents 480de99 + e766abb commit 204e4fc

8 files changed

Lines changed: 221 additions & 45 deletions

File tree

changelog.d/1151.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Reserve a small share of prior weight for zero-weight PUF clone rows (rather than leaving them with near-zero priors) so they stay usable in calibration, and validate that final enhanced CPS weights keep PUF clones above a floor rather than starving them.

policyengine_us_data/datasets/cps/enhanced_cps.py

Lines changed: 72 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from policyengine_us_data.utils import (
88
ABSOLUTE_ERROR_SCALE_TARGETS,
99
HOUSEHOLD_COUNT_TARGET,
10-
PUF_CLONE_HOUSEHOLD_COUNT_TARGET_SHARE,
1110
build_loss_matrix,
1211
get_target_error_normalisation,
1312
get_target_loss_weights,
@@ -44,24 +43,31 @@
4443

4544

4645
HOUSEHOLD_WEIGHT_TOTAL_REL_TOLERANCE = 0.02
47-
PUF_CLONE_HOUSEHOLD_WEIGHT_SHARE_TOLERANCE = 0.10
4846
PERSON_POVERTY_RATE_MIN = 0.05
4947
PERSON_POVERTY_RATE_MAX = 0.25
48+
# PUF clones enter the extended CPS with zero household weight. They are support
49+
# records for calibration, but the earlier bug starved them to ~0 (unusable in
50+
# log-space optimization). Reserve a small but non-trivial share of prior mass
51+
# for them, and validate that final weights keep them above a floor. There is no
52+
# upper cap: the household-count loss target (loss.py) governs how much weight
53+
# clones ultimately carry.
54+
PUF_CLONE_PRIOR_TOTAL_SHARE = 0.05
55+
MIN_PUF_CLONE_HOUSEHOLD_WEIGHT_SHARE_PCT = 5.0
56+
MAX_PUF_CLONE_TAXES_EXCEED_MARKET_INCOME_SHARE_PCT = 25.0
5057

5158

5259
def initialize_weight_priors(
5360
original_weights: np.ndarray,
5461
seed: int = 1456,
5562
epsilon: float = 1e-6,
56-
zero_weight_total_share: float = 0.5,
63+
zero_weight_total_share: float = PUF_CLONE_PRIOR_TOTAL_SHARE,
5764
) -> np.ndarray:
5865
"""Build deterministic positive priors for sparse reweighting.
5966
6067
PUF clone households enter the extended CPS with zero household weight.
61-
Giving those records near-zero priors leaves them effectively unusable in
62-
log-space optimization. When zero-weight rows are present, preserve the
63-
relative distribution of positive survey weights but reserve a fixed share
64-
of the original total household mass for uniform zero-weight-row priors.
68+
Reserve a small but non-trivial share of prior mass for them so they remain
69+
usable in log-space optimization (the earlier bug starved them to ~0). Their
70+
final weight is governed by the household-count loss target, not this prior.
6571
"""
6672

6773
weights = np.asarray(original_weights, dtype=np.float64)
@@ -135,10 +141,14 @@ def validate_clone_household_weight_share(
135141
household_is_puf_clone: np.ndarray,
136142
*,
137143
year: int,
138-
target_share: float = PUF_CLONE_HOUSEHOLD_COUNT_TARGET_SHARE,
139-
abs_tolerance: float = PUF_CLONE_HOUSEHOLD_WEIGHT_SHARE_TOLERANCE,
144+
min_share: float = MIN_PUF_CLONE_HOUSEHOLD_WEIGHT_SHARE_PCT / 100,
140145
) -> float:
141-
"""Validate that PUF-clone households do not dominate final weights."""
146+
"""Validate that PUF-clone households keep a usable share of final weight.
147+
148+
Clones must not be starved below ``min_share`` (the earlier bug left them at
149+
~0, unusable in log-space optimization). There is no upper cap: the
150+
household-count loss target governs how much weight clones ultimately carry.
151+
"""
142152

143153
weights = np.asarray(weights, dtype=np.float64)
144154
household_is_puf_clone = np.asarray(household_is_puf_clone, dtype=bool)
@@ -154,12 +164,11 @@ def validate_clone_household_weight_share(
154164
raise ValueError(f"Year {year}: household_weight total must be positive")
155165

156166
clone_share = float(weights[household_is_puf_clone].sum()) / total
157-
if abs(clone_share - target_share) > abs_tolerance:
167+
if clone_share < min_share:
158168
raise ValueError(
159169
f"Year {year}: PUF-clone household weight share "
160-
f"{clone_share:.2%} differs from target {target_share:.2%} by "
161-
f"{abs(clone_share - target_share):.2%}, exceeding "
162-
f"{abs_tolerance:.2%} tolerance"
170+
f"{clone_share:.2%} is below the {min_share:.2%} floor; clones are "
171+
f"being starved of weight"
163172
)
164173

165174
return clone_share
@@ -201,6 +210,41 @@ def validate_person_poverty_rate(
201210
return poverty_rate
202211

203212

213+
def validate_clone_diagnostics(
214+
diagnostics: dict[str, float],
215+
*,
216+
min_household_weight_share_pct: float = MIN_PUF_CLONE_HOUSEHOLD_WEIGHT_SHARE_PCT,
217+
max_taxes_exceed_market_income_share_pct: float = (
218+
MAX_PUF_CLONE_TAXES_EXCEED_MARKET_INCOME_SHARE_PCT
219+
),
220+
) -> None:
221+
"""Reject enhanced CPS artifacts where PUF support clones are starved.
222+
223+
Enforces a floor on clone household weight share (clones must keep at least
224+
``min_household_weight_share_pct`` of total weight, guarding against the earlier starvation bug) plus a
225+
data-quality bound on clones whose imputed taxes exceed market income. There
226+
is no upper cap on weight share: the household-count loss target governs that.
227+
"""
228+
229+
clone_household_share = diagnostics["clone_household_weight_share_pct"]
230+
if clone_household_share < min_household_weight_share_pct:
231+
raise ValueError(
232+
"PUF clone household weight share "
233+
f"{clone_household_share:.1f}% is below the "
234+
f"{min_household_weight_share_pct:.1f}% floor"
235+
)
236+
237+
taxes_exceed_market_income_share = diagnostics[
238+
"clone_taxes_exceed_market_income_share_pct"
239+
]
240+
if taxes_exceed_market_income_share > max_taxes_exceed_market_income_share_pct:
241+
raise ValueError(
242+
"PUF clone taxes-exceed-market-income share "
243+
f"{taxes_exceed_market_income_share:.1f}% exceeds "
244+
f"{max_taxes_exceed_market_income_share_pct:.1f}%"
245+
)
246+
247+
204248
def _to_numpy(value) -> np.ndarray:
205249
return np.asarray(getattr(value, "values", value))
206250

@@ -351,17 +395,22 @@ def save_clone_diagnostics_report(
351395
end_year: int,
352396
) -> tuple[Path, dict]:
353397
periods = list(range(start_year, end_year + 1))
398+
399+
def build_validated_payload():
400+
period_to_diagnostics = {
401+
period: build_clone_diagnostics_for_saved_dataset(
402+
dataset_cls,
403+
period,
404+
)
405+
for period in periods
406+
}
407+
for diagnostics in period_to_diagnostics.values():
408+
validate_clone_diagnostics(diagnostics)
409+
return build_clone_diagnostics_payload(period_to_diagnostics)
410+
354411
output_path = refresh_clone_diagnostics_report(
355412
dataset_cls.file_path,
356-
lambda: build_clone_diagnostics_payload(
357-
{
358-
period: build_clone_diagnostics_for_saved_dataset(
359-
dataset_cls,
360-
period,
361-
)
362-
for period in periods
363-
}
364-
),
413+
build_validated_payload,
365414
)
366415
diagnostics_payload = json.loads(output_path.read_text())
367416
return output_path, diagnostics_payload

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ classifiers = [
2222
"Programming Language :: Python :: 3.14",
2323
]
2424
dependencies = [
25-
"policyengine-us==1.715.2",
25+
"policyengine-us==1.715.3",
2626
# policyengine-core 3.26.1 is the current 3.26.x runtime and includes the fix for
2727
# PolicyEngine/policyengine-core#482 (user-set ETERNITY inputs lost
2828
# after _invalidate_all_caches) and is required by policyengine-us 1.682.1+.

tests/unit/datasets/test_enhanced_cps_seeding.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Earlier versions used global ``np.random.normal(1, 0.1, ...)`` jitter before
44
``reweight()`` reseeded the optimizer. Current code routes both dense CPS
55
weighting paths through ``initialize_weight_priors()``, which preserves positive
6-
survey weight shape and gives zero-weight clone records deterministic uniform
6+
survey weight shape and gives zero-weight clone records deterministic support
77
prior mass.
88
"""
99

@@ -86,11 +86,13 @@ def test_validate_household_weight_total_rejects_inflated_total():
8686
)
8787

8888

89-
def test_validate_clone_household_weight_share_accepts_target_share():
89+
def test_validate_clone_household_weight_share_accepts_healthy_share():
9090
from policyengine_us_data.datasets.cps.enhanced_cps import (
9191
validate_clone_household_weight_share,
9292
)
9393

94+
# A high clone share is fine: there is no upper cap (the loss target governs
95+
# how much weight clones carry); the guard only enforces a floor.
9496
share = validate_clone_household_weight_share(
9597
np.array([40_000_000.0, 10_000_000.0, 25_000_000.0, 25_000_000.0]),
9698
np.array([False, False, True, True]),
@@ -100,14 +102,15 @@ def test_validate_clone_household_weight_share_accepts_target_share():
100102
assert share == pytest.approx(0.5)
101103

102104

103-
def test_validate_clone_household_weight_share_rejects_clone_dominance():
105+
def test_validate_clone_household_weight_share_rejects_clone_starvation():
104106
from policyengine_us_data.datasets.cps.enhanced_cps import (
105107
validate_clone_household_weight_share,
106108
)
107109

108-
with pytest.raises(ValueError, match="PUF-clone household weight share"):
110+
# Clones starved to ~2.4% of weight (below the 5% floor) must fail.
111+
with pytest.raises(ValueError, match="floor"):
109112
validate_clone_household_weight_share(
110-
np.array([10_000_000.0, 10_000_000.0, 40_000_000.0, 40_000_000.0]),
113+
np.array([80_000_000.0, 80_000_000.0, 2_000_000.0, 2_000_000.0]),
111114
np.array([False, False, True, True]),
112115
year=2024,
113116
)

tests/unit/test_enhanced_cps_clone_diagnostics.py

Lines changed: 99 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,26 @@
99
compute_clone_diagnostics_summary,
1010
clone_diagnostics_path,
1111
initialize_weight_priors,
12+
PUF_CLONE_PRIOR_TOTAL_SHARE,
1213
refresh_clone_diagnostics_report,
1314
save_clone_diagnostics_report,
15+
validate_clone_diagnostics,
1416
)
1517

1618

17-
def test_initialize_weight_priors_gives_zero_weight_records_balanced_mass():
19+
def test_initialize_weight_priors_gives_zero_weight_records_support_mass():
1820
weights = np.array([1_500.0, 0.0, 625.0, 0.0], dtype=np.float64)
1921

2022
priors = initialize_weight_priors(weights, seed=123)
2123

2224
assert np.all(priors > 0)
2325
assert priors.sum() == pytest.approx(weights.sum())
24-
assert priors[[0, 2]].sum() == pytest.approx(weights.sum() / 2)
25-
assert priors[[1, 3]].sum() == pytest.approx(weights.sum() / 2)
26+
assert priors[[1, 3]].sum() == pytest.approx(
27+
weights.sum() * PUF_CLONE_PRIOR_TOTAL_SHARE
28+
)
29+
assert priors[[0, 2]].sum() == pytest.approx(
30+
weights.sum() * (1 - PUF_CLONE_PRIOR_TOTAL_SHARE)
31+
)
2632
assert priors[1] == pytest.approx(priors[3])
2733
assert priors[0] / priors[2] == pytest.approx(weights[0] / weights[2])
2834

@@ -44,6 +50,15 @@ def test_initialize_weight_priors_is_reproducible():
4450
np.testing.assert_allclose(priors_a, priors_b)
4551

4652

53+
def test_initialize_weight_priors_honors_configured_zero_weight_share():
54+
weights = np.array([80.0, 20.0, 0.0, 0.0])
55+
56+
priors = initialize_weight_priors(weights, zero_weight_total_share=0.5)
57+
58+
np.testing.assert_allclose(priors.sum(), 100.0)
59+
np.testing.assert_allclose(priors, np.array([40.0, 10.0, 25.0, 25.0]))
60+
61+
4762
def test_compute_clone_diagnostics_summary():
4863
diagnostics = compute_clone_diagnostics_summary(
4964
household_is_puf_clone=[False, True],
@@ -70,6 +85,49 @@ def test_compute_clone_diagnostics_summary():
7085
)
7186

7287

88+
def test_validate_clone_diagnostics_accepts_support_clone_share():
89+
validate_clone_diagnostics(
90+
{
91+
"clone_household_weight_share_pct": 10.0,
92+
"clone_taxes_exceed_market_income_share_pct": 5.0,
93+
}
94+
)
95+
96+
97+
def test_validate_clone_diagnostics_rejects_clone_starvation():
98+
with pytest.raises(ValueError, match="floor"):
99+
validate_clone_diagnostics(
100+
{
101+
"clone_household_weight_share_pct": 2.0,
102+
"clone_taxes_exceed_market_income_share_pct": 5.0,
103+
}
104+
)
105+
106+
107+
def test_validate_clone_diagnostics_accepts_high_share_no_cap():
108+
# No upper cap on clone weight share (the household-count loss target governs
109+
# it); a high share with healthy tax quality must pass.
110+
validate_clone_diagnostics(
111+
{
112+
"clone_household_weight_share_pct": 81.3,
113+
"clone_taxes_exceed_market_income_share_pct": 5.0,
114+
}
115+
)
116+
117+
118+
def test_validate_clone_diagnostics_rejects_clone_tax_pathology():
119+
with pytest.raises(
120+
ValueError,
121+
match="PUF clone taxes-exceed-market-income share",
122+
):
123+
validate_clone_diagnostics(
124+
{
125+
"clone_household_weight_share_pct": 10.0,
126+
"clone_taxes_exceed_market_income_share_pct": 66.6,
127+
}
128+
)
129+
130+
73131
def test_build_clone_diagnostics_for_simulation_maps_household_weights(
74132
monkeypatch,
75133
):
@@ -201,7 +259,11 @@ class DummyDataset:
201259

202260
monkeypatch.setattr(
203261
"policyengine_us_data.datasets.cps.enhanced_cps.build_clone_diagnostics_for_saved_dataset",
204-
lambda dataset_cls, period: {"clone_person_weight_share_pct": float(period)},
262+
lambda dataset_cls, period: {
263+
"clone_person_weight_share_pct": float(period),
264+
"clone_household_weight_share_pct": 10.0,
265+
"clone_taxes_exceed_market_income_share_pct": 5.0,
266+
},
205267
)
206268

207269
output_path, payload = save_clone_diagnostics_report(
@@ -213,8 +275,39 @@ class DummyDataset:
213275
assert output_path == clone_diagnostics_path(DummyDataset.file_path)
214276
assert payload == {
215277
"periods": {
216-
"2024": {"clone_person_weight_share_pct": 2024.0},
217-
"2025": {"clone_person_weight_share_pct": 2025.0},
278+
"2024": {
279+
"clone_person_weight_share_pct": 2024.0,
280+
"clone_household_weight_share_pct": 10.0,
281+
"clone_taxes_exceed_market_income_share_pct": 5.0,
282+
},
283+
"2025": {
284+
"clone_person_weight_share_pct": 2025.0,
285+
"clone_household_weight_share_pct": 10.0,
286+
"clone_taxes_exceed_market_income_share_pct": 5.0,
287+
},
218288
}
219289
}
220290
assert output_path.exists()
291+
292+
293+
def test_save_clone_diagnostics_report_rejects_bad_clone_payload(tmp_path, monkeypatch):
294+
class DummyDataset:
295+
file_path = tmp_path / "enhanced_cps_2024.h5"
296+
297+
DummyDataset.file_path.write_text("placeholder")
298+
299+
monkeypatch.setattr(
300+
"policyengine_us_data.datasets.cps.enhanced_cps.build_clone_diagnostics_for_saved_dataset",
301+
lambda dataset_cls, period: {
302+
"clone_person_weight_share_pct": 1.0,
303+
"clone_household_weight_share_pct": 2.0,
304+
"clone_taxes_exceed_market_income_share_pct": 5.0,
305+
},
306+
)
307+
308+
with pytest.raises(ValueError, match="PUF clone household weight share"):
309+
save_clone_diagnostics_report(
310+
DummyDataset,
311+
start_year=2024,
312+
end_year=2024,
313+
)

uv.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)