Skip to content

Commit f89bc1a

Browse files
committed
Address source imputation review findings
1 parent 89908fd commit f89bc1a

9 files changed

Lines changed: 139 additions & 13 deletions

File tree

policyengine_us_data/calibration/source_impute.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
from policyengine_us_data.pipeline_metadata import pipeline_node
8282
from policyengine_us_data.pipeline_schema import PipelineNode
8383
from policyengine_us_data.utils.source_quality import (
84+
require_columns_present,
8485
target_observed_source_masks,
8586
)
8687

@@ -532,11 +533,20 @@ def _impute_acs(
532533
acs_df["state_fips"] = acs.calculate("state_fips", map_to="person").values.astype(
533534
np.float32
534535
)
536+
required_acs_flags = [
537+
column
538+
for columns in ACS_TARGET_ALLOCATION_COLUMNS.values()
539+
for column in columns
540+
]
535541
with h5py.File(ACS_2022.file_path, "r") as acs_h5:
542+
require_columns_present(
543+
acs_h5,
544+
required_acs_flags,
545+
source_name="ACS_2022 artifact",
546+
)
536547
for flag_columns in ACS_TARGET_ALLOCATION_COLUMNS.values():
537548
for flag_column in flag_columns:
538-
if flag_column in acs_h5:
539-
acs_df[flag_column] = np.asarray(acs_h5[flag_column], dtype=bool)
549+
acs_df[flag_column] = np.asarray(acs_h5[flag_column], dtype=bool)
540550

541551
train_df = acs_df[acs_df.is_household_head].copy()
542552
train_df = _encode_tenure_type(train_df)
@@ -654,6 +664,8 @@ def _impute_sipp(
654664
sipp_df["treasury_tipped_occupation_code"]
655665
)
656666

667+
if "MONTHCODE" in sipp_df:
668+
sipp_df = sipp_df[sipp_df["MONTHCODE"] == 12].copy()
657669
sipp_df["is_under_18"] = sipp_df.TAGE < 18
658670
sipp_df["is_under_6"] = sipp_df.TAGE < 6
659671
sipp_df["count_under_18"] = (
@@ -662,8 +674,6 @@ def _impute_sipp(
662674
sipp_df["count_under_6"] = (
663675
sipp_df.groupby("SSUID")["is_under_6"].sum().loc[sipp_df.SSUID.values].values
664676
)
665-
if "MONTHCODE" in sipp_df:
666-
sipp_df = sipp_df[sipp_df["MONTHCODE"] == 12].copy()
667677

668678
tip_target_filters = target_observed_source_masks(
669679
sipp_df,

policyengine_us_data/datasets/cps/cps.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@
7272
)
7373
from policyengine_us_data.pipeline_metadata import pipeline_node
7474
from policyengine_us_data.pipeline_schema import PipelineNode
75-
from policyengine_us_data.utils.source_quality import target_observed_source_masks
75+
from policyengine_us_data.utils.source_quality import (
76+
require_columns_present,
77+
target_observed_source_masks,
78+
)
7679

7780
ACS_RENT_TARGET_ALLOCATION_COLUMNS = {
7881
"rent": ["rent_is_allocated"],
@@ -415,14 +418,23 @@ def add_rent(self, cps: h5py.File, person: DataFrame, household: DataFrame):
415418
# H5; for CPS we use the in-memory dict (already populated upstream in
416419
# add_id_variables). Remove both overrides once pyproject.toml's
417420
# policyengine-core upper bound is lifted.
421+
required_acs_flags = [
422+
column
423+
for columns in ACS_RENT_TARGET_ALLOCATION_COLUMNS.values()
424+
for column in columns
425+
]
418426
with h5py.File(ACS_2022.file_path, "r") as acs_h5:
419427
train_df["is_household_head"] = np.asarray(
420428
acs_h5["is_household_head"], dtype=bool
421429
)
430+
require_columns_present(
431+
acs_h5,
432+
required_acs_flags,
433+
source_name="ACS_2022 artifact",
434+
)
422435
for flag_columns in ACS_RENT_TARGET_ALLOCATION_COLUMNS.values():
423436
for flag_column in flag_columns:
424-
if flag_column in acs_h5:
425-
train_df[flag_column] = np.asarray(acs_h5[flag_column], dtype=bool)
437+
train_df[flag_column] = np.asarray(acs_h5[flag_column], dtype=bool)
426438
train_df.tenure_type = train_df.tenure_type.map(
427439
{
428440
"OWNED_OUTRIGHT": "OWNED_WITH_MORTGAGE",

policyengine_us_data/utils/source_quality.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import logging
6-
from collections.abc import Mapping, Sequence
6+
from collections.abc import Container, Mapping, Sequence
77

88
import pandas as pd
99

@@ -17,6 +17,24 @@ def sipp_allocation_flag_for(source_column: str) -> str:
1717
return f"A{source_column[1:]}"
1818

1919

20+
def require_columns_present(
21+
available_columns: Container[str],
22+
required_columns: Sequence[str],
23+
*,
24+
source_name: str,
25+
) -> None:
26+
"""Raise if required donor-source provenance columns are unavailable."""
27+
missing_columns = sorted(
28+
{column for column in required_columns if column not in available_columns}
29+
)
30+
if missing_columns:
31+
raise KeyError(
32+
f"{source_name} is missing required source-quality columns: "
33+
f"{', '.join(missing_columns)}. Regenerate the donor artifact with "
34+
"allocation flag columns before fitting source imputations."
35+
)
36+
37+
2038
def observed_source_mask(
2139
df: pd.DataFrame,
2240
*,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ dependencies = [
3232
"tqdm>=4.60.0",
3333
"microdf_python>=1.2.1",
3434
"setuptools>=60",
35-
"microimpute>=2.0.5",
35+
"microimpute>=2.1.0",
3636
"pip-system-certs>=3.0",
3737
"google-cloud-storage>=2.0.0",
3838
"google-auth>=2.0.0",

tests/unit/calibration/test_source_impute.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66
import numpy as np
77
import pandas as pd
8+
import huggingface_hub
89

10+
from policyengine_us_data.calibration import source_impute
911
from policyengine_us_data.calibration.source_impute import (
1012
ACS_IMPUTED_VARIABLES,
1113
ACS_PREDICTORS,
@@ -332,6 +334,62 @@ def test_impute_acs_exists(self):
332334
def test_impute_sipp_exists(self):
333335
assert callable(_impute_sipp)
334336

337+
def test_calibration_sipp_tip_counts_use_reference_month(self, monkeypatch):
338+
captured = {}
339+
340+
columns = {
341+
"SSUID": [1, 1, 1, 2],
342+
"MONTHCODE": [1, 12, 12, 12],
343+
"TAGE": [5, 40, 10, 30],
344+
"TPTOTINC": [1_000.0, 2_000.0, 0.0, 3_000.0],
345+
"WPFINWGT": [1.0, 1.0, 1.0, 1.0],
346+
}
347+
for column in source_impute.SIPP_TIP_AMOUNT_COLUMNS:
348+
columns[column] = [0.0, 10.0, 0.0, 5.0]
349+
for column in source_impute.SIPP_TIP_ALLOCATION_COLUMNS:
350+
columns[column] = [0, 0, 0, 0]
351+
for column in source_impute.SIPP_JOB_OCCUPATION_COLUMNS:
352+
columns[column] = [0, 0, 0, 0]
353+
tip_source = pd.DataFrame(columns)
354+
355+
read_count = {"count": 0}
356+
357+
def fake_read_csv(*args, **kwargs):
358+
read_count["count"] += 1
359+
if read_count["count"] == 1:
360+
return tip_source.copy()
361+
raise FileNotFoundError("stop after tip imputation")
362+
363+
class FakeQRF:
364+
def __init__(self, *args, **kwargs):
365+
pass
366+
367+
def fit(self, X_train, **kwargs):
368+
captured["train"] = X_train.copy()
369+
return self
370+
371+
def predict(self, X_test):
372+
return pd.DataFrame({"tip_income": np.zeros(len(X_test))})
373+
374+
monkeypatch.setattr(
375+
huggingface_hub,
376+
"hf_hub_download",
377+
lambda *args, **kwargs: None,
378+
)
379+
monkeypatch.setattr(source_impute.pd, "read_csv", fake_read_csv)
380+
monkeypatch.setattr(source_impute, "QRF", FakeQRF)
381+
382+
data = _make_data_dict(n_persons=4)
383+
_impute_sipp(
384+
data=data,
385+
state_fips=np.array([1, 1], dtype=np.int32),
386+
time_period=2024,
387+
)
388+
389+
household_one = captured["train"][captured["train"]["household_id"] == 1]
390+
np.testing.assert_array_equal(household_one["count_under_18"], [1, 1])
391+
np.testing.assert_array_equal(household_one["count_under_6"], [0, 0])
392+
335393
def test_impute_org_exists(self):
336394
assert callable(_impute_org)
337395

tests/unit/datasets/test_cps_file_handles.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,8 @@ def recording_hdfstore(path, mode="a", *args, **kwargs):
390390
acs_fixture_path = tmp_path / "acs_fixture.h5"
391391
with h5py.File(acs_fixture_path, "w") as acs_fixture:
392392
acs_fixture["is_household_head"] = np.ones(10_000, dtype=bool)
393+
acs_fixture["rent_is_allocated"] = np.zeros(10_000, dtype=bool)
394+
acs_fixture["real_estate_taxes_is_allocated"] = np.zeros(10_000, dtype=bool)
393395

394396
real_h5py_file = cps_module.h5py.File
395397
opened_h5_paths = []

tests/unit/datasets/test_rng_seeding.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def test_select_random_subset_uses_local_generator_only():
130130

131131
def test_sipp_training_samples_use_seeded_rng():
132132
"""N5: the weighted resample for tip and asset training frames
133-
must come from a seeded Generator, not the global ``np.random``."""
133+
must come from a deterministic sampler, not the global ``np.random``."""
134134
src = SIPP_SOURCE.read_text()
135135
assert "seeded_rng(" in src, "sipp.py must import/use seeded_rng()"
136136
tree = ast.parse(src)
@@ -149,8 +149,9 @@ def test_sipp_training_samples_use_seeded_rng():
149149
assert "np.random.choice" not in fn_src, (
150150
f"{fn_name} must not use np.random.choice (use a seeded_rng Generator)"
151151
)
152-
assert "seeded_rng(" in fn_src, (
153-
f"{fn_name} must derive its resampler from a seeded generator"
152+
assert "seeded_rng(" in fn_src or "max_train_samples=" in fn_src, (
153+
f"{fn_name} must derive its resampler from a seeded generator or "
154+
"delegate capped sampling to microimpute's deterministic QRF sampler"
154155
)
155156

156157

tests/unit/test_source_quality.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from policyengine_us_data.utils.source_quality import (
55
observed_source_mask,
6+
require_columns_present,
67
sipp_allocation_flag_for,
78
target_observed_source_masks,
89
)
@@ -14,6 +15,30 @@ def test_sipp_allocation_flag_for_source_column():
1415
assert sipp_allocation_flag_for("TJB1_TXAMT") == "AJB1_TXAMT"
1516

1617

18+
def test_require_columns_present_accepts_available_columns():
19+
require_columns_present(
20+
{"rent_is_allocated", "real_estate_taxes_is_allocated"},
21+
["rent_is_allocated", "real_estate_taxes_is_allocated"],
22+
source_name="ACS",
23+
)
24+
25+
26+
def test_require_columns_present_raises_for_missing_columns():
27+
try:
28+
require_columns_present(
29+
{"rent_is_allocated"},
30+
["rent_is_allocated", "real_estate_taxes_is_allocated"],
31+
source_name="ACS",
32+
)
33+
except KeyError as error:
34+
message = str(error)
35+
else:
36+
raise AssertionError("Expected missing source-quality columns to fail")
37+
38+
assert "real_estate_taxes_is_allocated" in message
39+
assert "Regenerate the donor artifact" in message
40+
41+
1742
def test_observed_source_mask_excludes_nonzero_allocation_flags():
1843
df = pd.DataFrame(
1944
{

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)