Skip to content

Commit d327734

Browse files
authored
Add household count target to ECPS calibration (#1149)
1 parent e181981 commit d327734

9 files changed

Lines changed: 115 additions & 5 deletions

File tree

changelog.d/1149.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Constrain ECPS calibration to the source household count so PUF clone reweighting cannot inflate total household weight.

policyengine_us_data/datasets/cps/enhanced_cps.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pandas as pd
77
from policyengine_us_data.utils import (
88
ABSOLUTE_ERROR_SCALE_TARGETS,
9+
HOUSEHOLD_COUNT_TARGET,
910
build_loss_matrix,
1011
get_target_error_normalisation,
1112
get_target_loss_weights,
@@ -669,6 +670,11 @@ def generate(self):
669670
del loss_matrix, targets_array
670671
gc.collect()
671672
assert loss_matrix_clean.shape[1] == targets_array_clean.size
673+
if HOUSEHOLD_COUNT_TARGET not in loss_matrix_clean.columns:
674+
raise ValueError(
675+
f"{HOUSEHOLD_COUNT_TARGET} missing from EnhancedCPS "
676+
"calibration targets"
677+
)
672678

673679
loss_matrix_clean = loss_matrix_clean.astype(np.float32)
674680

policyengine_us_data/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
__all__ = [
1212
"ABSOLUTE_ERROR_SCALE_TARGETS",
1313
"HardConcrete",
14+
"HOUSEHOLD_COUNT_TARGET",
1415
"build_loss_matrix",
1516
"get_target_error_normalisation",
1617
"get_target_loss_weights",

policyengine_us_data/utils/loss.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@
8282

8383
BEA_NIPA_DIRECT_SUM_LOSS_WEIGHT = 1_000.0
8484
BEA_WAGES_AND_SALARIES_LOSS_WEIGHT = 1_000.0
85+
HOUSEHOLD_COUNT_TARGET = "nation/source/household_count"
86+
HOUSEHOLD_COUNT_LOSS_WEIGHT = 1_000.0
8587

8688
CBO_INCOME_BY_SOURCE_TARGETS = [
8789
("irs_employment_income", "employment_income"),
@@ -1199,6 +1201,31 @@ def _add_transfer_balance_targets(loss_matrix, targets_list, sim, time_period):
11991201
return targets_list, loss_matrix
12001202

12011203

1204+
def _add_household_count_target(loss_matrix, targets_list, sim):
    """Constrain total household weight to the source survey total.

    Adds an all-ones ``HOUSEHOLD_COUNT_TARGET`` column to ``loss_matrix`` and
    appends the sum of the simulation's ``household_weight`` values to
    ``targets_list``. Both arguments are modified in place and also returned.

    Args:
        loss_matrix: DataFrame whose row count must equal the number of
            household weights.
        targets_list: List of target values; the new target is appended.
        sim: Object whose ``calculate("household_weight")`` result exposes
            the weights through ``.values``.

    Returns:
        The ``(targets_list, loss_matrix)`` pair.

    Raises:
        ValueError: If ``loss_matrix`` and the weights differ in length, or
            if the weight total is not a positive finite number.
    """
    household_weights = sim.calculate("household_weight").values
    n_households = len(household_weights)
    if len(loss_matrix) != n_households:
        raise ValueError(
            "Household count target length mismatch: "
            f"loss matrix has {len(loss_matrix)} rows but household_weight has "
            f"{n_households} values"
        )

    # Accumulate in float64: a float32 array summed in its own dtype cannot
    # represent totals above 2**24 exactly, which would shift the target.
    target = float(np.sum(household_weights, dtype=np.float64))
    if not np.isfinite(target) or target <= 0:
        raise ValueError(
            "Household count target must have positive finite source weight total"
        )

    # Every household contributes its full weight to the count, hence ones.
    loss_matrix[HOUSEHOLD_COUNT_TARGET] = np.ones(n_households, dtype=np.float32)
    targets_list.append(target)
    return targets_list, loss_matrix
1227+
1228+
12021229
def get_target_error_normalisation(target_names, targets_array):
12031230
"""Return numerator shifts and denominators for target loss scaling."""
12041231
target_names = np.asarray(target_names)
@@ -1227,6 +1254,7 @@ def get_target_loss_weights(target_names):
12271254
) | np.char.startswith(target_names, "state/bea/wages_and_salaries/")
12281255
weights[is_bea_direct_sum_target] = BEA_NIPA_DIRECT_SUM_LOSS_WEIGHT
12291256
weights[is_bea_wage_target] = BEA_WAGES_AND_SALARIES_LOSS_WEIGHT
1257+
weights[target_names == HOUSEHOLD_COUNT_TARGET] = HOUSEHOLD_COUNT_LOSS_WEIGHT
12301258
return weights
12311259

12321260

@@ -1360,6 +1388,12 @@ def build_loss_matrix(dataset: type, time_period):
13601388
hh_id = sim.calculate("household_id").values
13611389
loss_matrix = loss_matrix.loc[hh_id]
13621390

1391+
targets_array, loss_matrix = _add_household_count_target(
1392+
loss_matrix,
1393+
targets_array,
1394+
sim,
1395+
)
1396+
13631397
# Census single-year age population projections
13641398

13651399
populations = pd.read_csv(CALIBRATION_FOLDER / "np2023_d5_mid.csv")

policyengine_us_data/utils/national_target_parity.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,8 @@ def _legacy_reason(target_name: str) -> str:
675675
return "legacy_cms_aca_spending_target_not_in_target_db"
676676
if target_name.startswith("nation/accounting/"):
677677
return "legacy_accounting_balance_target_not_in_target_db"
678+
if target_name == "nation/source/household_count":
679+
return "legacy_source_household_count_target_not_in_target_db"
678680
if target_name.startswith("nation/irs/negative_household_market_income_"):
679681
return "legacy_negative_market_income_target_not_in_target_db"
680682
if target_name == "nation/census/infants":

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ classifiers = [
2222
"Programming Language :: Python :: 3.14",
2323
]
2424
dependencies = [
25-
"policyengine-us==1.709.1",
25+
"policyengine-us==1.711.0",
2626
# policyengine-core 3.26.1 is the current 3.26.x runtime and includes the fix for
2727
# PolicyEngine/policyengine-core#482 (user-set ETERNITY inputs lost
2828
# after _invalidate_all_caches) and is required by policyengine-us 1.682.1+.

tests/unit/calibration/test_loss_targets.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@
1616
BEA_WAGES_AND_SALARIES_LOSS_WEIGHT,
1717
BLS_CE_TOTALS,
1818
HARD_CODED_TOTALS,
19+
HOUSEHOLD_COUNT_LOSS_WEIGHT,
20+
HOUSEHOLD_COUNT_TARGET,
1921
LOW_AGI_INVESTMENT_INCOME_SOI_VARIABLES,
2022
SOI_NEGATIVE_AGI_TARGETED_VARIABLES,
2123
TRANSFER_BALANCE_TARGETS,
2224
_add_bea_state_wage_targets,
2325
_add_agi_metric_columns,
2426
_add_acs_housing_cost_targets,
27+
_add_household_count_target,
2528
_add_aotc_targets,
2629
_add_bls_ce_targets,
2730
_add_ctc_targets,
@@ -167,6 +170,22 @@ def test_bea_nipa_direct_sum_targets_get_higher_loss_weight():
167170
]
168171

169172

173+
def test_household_count_target_gets_higher_loss_weight():
    """The household count target is up-weighted; the other target stays at 1."""
    ordinary_target = "nation/census/population_by_age/0"
    names = np.array([HOUSEHOLD_COUNT_TARGET, ordinary_target])

    observed = get_target_loss_weights(names).tolist()

    expected = [HOUSEHOLD_COUNT_LOSS_WEIGHT, 1.0]
    assert observed == expected
187+
188+
170189
def test_aca_targets_roll_forward_to_2025():
171190
targets, data_year = _load_aca_spending_and_enrollment_targets(2025)
172191

@@ -243,6 +262,17 @@ def __init__(self, values):
243262
self.values = np.asarray(values)
244263

245264

265+
class _FakeHouseholdWeightSimulation:
    """Simulation stand-in that only answers ``household_weight`` requests."""

    def __init__(self, weights):
        # Weights handed back, wrapped, by every ``calculate`` call.
        self.weights = weights

    def calculate(self, variable, map_to=None, period=None):
        """Return the stored weights wrapped in ``_FakeArrayResult``.

        Asserts that the caller asked for ``household_weight`` and passed
        neither ``map_to`` nor ``period``.
        """
        assert variable == "household_weight"
        assert map_to is None
        assert period is None
        wrapped = _FakeArrayResult(self.weights)
        return wrapped
274+
275+
246276
class _FakeSimulation:
247277
def __init__(self):
248278
self.calculate_calls = []
@@ -427,6 +457,28 @@ def test_state_agi_targets_are_limited_to_filers(tmp_path, monkeypatch):
427457
)
428458

429459

460+
def test_add_household_count_target_uses_source_weight_total():
    """The target equals the summed source weights; the new column is all ones."""
    empty_matrix = pd.DataFrame(index=[101, 102, 103, 104])
    fake_sim = _FakeHouseholdWeightSimulation([80.0, 20.0, 0.0, 0.0])

    targets, updated_matrix = _add_household_count_target(empty_matrix, [], fake_sim)

    assert targets == [100.0]
    household_column = updated_matrix[HOUSEHOLD_COUNT_TARGET].to_numpy()
    np.testing.assert_array_equal(household_column, np.ones(4, dtype=np.float32))
474+
475+
476+
def test_build_loss_matrix_adds_household_count_target_before_reweighting():
    """The source of ``build_loss_matrix`` must mention the household count helper."""
    helper_name = "_add_household_count_target"

    assert helper_name in inspect.getsource(build_loss_matrix)
480+
481+
430482
def test_add_ssi_recipient_targets_adds_total_and_age_counts():
431483
targets, loss_matrix = _add_ssi_recipient_targets(
432484
pd.DataFrame(),

tests/unit/datasets/test_enhanced_cps_seeding.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,17 @@ def test_enhanced_cps_sources_use_deterministic_weight_priors():
4242

4343
assert "np.random.normal" not in source
4444
assert source.count("initialize_weight_priors(original_weights.values)") == 2
45+
46+
47+
def test_initialize_weight_priors_preserves_source_weight_total():
    """Priors built with ``zero_weight_total_share=0.5`` still sum to 100."""
    from policyengine_us_data.datasets.cps.enhanced_cps import (
        initialize_weight_priors,
    )

    source_weights = np.array([80.0, 20.0, 0.0, 0.0])
    expected_priors = np.array([40.0, 10.0, 25.0, 25.0])

    priors = initialize_weight_priors(source_weights, zero_weight_total_share=0.5)

    np.testing.assert_allclose(priors.sum(), 100.0)
    np.testing.assert_allclose(priors, expected_priors)

uv.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)