Skip to content

Commit d28c617

Browse files
committed
Preserve SSI disability anchors during imputation
1 parent 6d1f714 commit d28c617

7 files changed

Lines changed: 92 additions & 14 deletions

File tree

policyengine_us_data/calibration/source_impute.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
build_vehicle_training_frame,
4242
get_ssi_disability_model,
4343
predict_ssi_disability_criteria,
44+
preserve_under_65_ssi_disability_criteria,
4445
)
4546

4647
from policyengine_us_data.datasets.org import (
@@ -767,13 +768,16 @@ def _impute_sipp(
767768
ssi_disability_model,
768769
cps_ssi_df,
769770
)
770-
if "ssi_reported" in data:
771-
reported_under_65 = (data["ssi_reported"][time_period] > 0) & (
772-
data["age"][time_period] < 65
773-
)
774-
meets_ssi_disability_criteria = (
775-
meets_ssi_disability_criteria | reported_under_65
776-
)
771+
existing_meets_ssi_disability_criteria = data.get(
772+
SSI_DISABILITY_MODEL_VARIABLE, {}
773+
).get(time_period)
774+
ssi_reported = data.get("ssi_reported", {}).get(time_period)
775+
meets_ssi_disability_criteria = preserve_under_65_ssi_disability_criteria(
776+
meets_ssi_disability_criteria,
777+
age=data["age"][time_period],
778+
ssi_reported=ssi_reported,
779+
existing_meets_ssi_disability_criteria=existing_meets_ssi_disability_criteria,
780+
)
777781
data[SSI_DISABILITY_MODEL_VARIABLE] = {
778782
time_period: meets_ssi_disability_criteria
779783
}

policyengine_us_data/datasets/cps/cps.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2652,6 +2652,7 @@ def add_tips(self, cps: h5py.File):
26522652
SSI_DISABILITY_MODEL_VARIABLE,
26532653
get_ssi_disability_model,
26542654
predict_ssi_disability_criteria,
2655+
preserve_under_65_ssi_disability_criteria,
26552656
)
26562657

26572658
n_persons = len(cps)
@@ -2671,13 +2672,14 @@ def add_tips(self, cps: h5py.File):
26712672
ssi_disability_model,
26722673
cps,
26732674
)
2674-
if "ssi_reported" in existing_data:
2675-
reported_under_65 = (np.asarray(existing_data["ssi_reported"]) > 0) & (
2676-
np.asarray(existing_data["age"]) < 65
2677-
)
2678-
meets_ssi_disability_criteria = (
2679-
meets_ssi_disability_criteria | reported_under_65
2680-
)
2675+
meets_ssi_disability_criteria = preserve_under_65_ssi_disability_criteria(
2676+
meets_ssi_disability_criteria,
2677+
age=existing_data.get("age", np.full(n_persons, 65)),
2678+
ssi_reported=existing_data.get("ssi_reported"),
2679+
existing_meets_ssi_disability_criteria=existing_data.get(
2680+
SSI_DISABILITY_MODEL_VARIABLE
2681+
),
2682+
)
26812683
cps[SSI_DISABILITY_MODEL_VARIABLE] = meets_ssi_disability_criteria
26822684

26832685
from policyengine_us_data.datasets.sipp import get_vehicle_model

policyengine_us_data/datasets/sipp/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
build_ssi_disability_training_frame,
1010
coerce_ssi_disability_predictions,
1111
predict_ssi_disability_criteria,
12+
preserve_under_65_ssi_disability_criteria,
1213
prepare_ssi_disability_receiver,
1314
train_ssi_disability_model,
1415
get_ssi_disability_model,
@@ -28,6 +29,7 @@
2829
"build_ssi_disability_training_frame",
2930
"coerce_ssi_disability_predictions",
3031
"predict_ssi_disability_criteria",
32+
"preserve_under_65_ssi_disability_criteria",
3133
"prepare_ssi_disability_receiver",
3234
"train_ssi_disability_model",
3335
"get_ssi_disability_model",

policyengine_us_data/datasets/sipp/sipp.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,29 @@ def apply_ssi_disability_signal_screen(
427427
return np.asarray(meets_ssi_disability_criteria, dtype=bool) & disability_signal
428428

429429

430+
def preserve_under_65_ssi_disability_criteria(
431+
meets_ssi_disability_criteria: np.ndarray,
432+
age: np.ndarray,
433+
ssi_reported: np.ndarray | None = None,
434+
existing_meets_ssi_disability_criteria: np.ndarray | None = None,
435+
) -> np.ndarray:
436+
"""Preserve observed under-65 SSI disability criteria anchors."""
437+
result = np.asarray(meets_ssi_disability_criteria, dtype=bool).copy()
438+
under_65 = pd.Series(age).fillna(np.inf).astype(float).lt(65).to_numpy()
439+
440+
if ssi_reported is not None:
441+
reported_ssi = pd.Series(ssi_reported).fillna(0).astype(float).gt(0).to_numpy()
442+
result |= reported_ssi & under_65
443+
444+
if existing_meets_ssi_disability_criteria is not None:
445+
result |= (
446+
_coerce_ssi_disability_signal(existing_meets_ssi_disability_criteria)
447+
& under_65
448+
)
449+
450+
return result
451+
452+
430453
def coerce_ssi_disability_predictions(values) -> np.ndarray:
431454
"""Convert classifier labels to booleans without treating 'False' as true."""
432455
series = pd.Series(values)

tests/integration/test_cps_generation.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def calculate(self, variable_name):
5151
"receives_wic": [False, False],
5252
"hud_income_level": ["VERY_LOW"],
5353
"spm_unit_tenure_type": ["RENTER"],
54+
"is_eligible_for_housing_assistance": [True],
5455
"tax_unit_child_dependents": [0],
5556
"age_head": [40],
5657
}
@@ -223,9 +224,27 @@ def predict(self, X_test, mean_quantile):
223224
}
224225
)
225226

227+
class FakeSsiDisabilityModel:
228+
pass
229+
230+
def fake_predict_ssi_disability_criteria(model, receiver_df):
231+
assert isinstance(model, FakeSsiDisabilityModel)
232+
assert receiver_df["employment_income"].tolist() == [25_000.0, 30_000.0]
233+
return np.array([True, False])
234+
226235
monkeypatch.setattr(sipp_module, "get_tip_model", lambda: FakeTipModel())
227236
monkeypatch.setattr(sipp_module, "get_asset_model", lambda: FakeAssetModel())
228237
monkeypatch.setattr(sipp_module, "get_vehicle_model", lambda: FakeVehicleModel())
238+
monkeypatch.setattr(
239+
sipp_module,
240+
"get_ssi_disability_model",
241+
lambda: FakeSsiDisabilityModel(),
242+
)
243+
monkeypatch.setattr(
244+
sipp_module,
245+
"predict_ssi_disability_criteria",
246+
fake_predict_ssi_disability_criteria,
247+
)
229248

230249
dataset = FakeDataset()
231250
add_tips(
@@ -245,6 +264,10 @@ def predict(self, X_test, mean_quantile):
245264
18_000.0,
246265
7_500.0,
247266
]
267+
assert dataset.saved_dataset["meets_ssi_disability_criteria"].tolist() == [
268+
True,
269+
False,
270+
]
248271

249272

250273
def test_add_rent_requests_person_level_frames(monkeypatch, tmp_path):

tests/unit/calibration/test_source_impute.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
_impute_sipp,
2323
_person_state_fips,
2424
impute_source_variables,
25+
preserve_under_65_ssi_disability_criteria,
2526
)
2627
from policyengine_us_data.datasets.cps.tipped_occupation import (
2728
derive_any_treasury_tipped_occupation_code,
@@ -249,6 +250,17 @@ def test_impute_org_exists(self):
249250
def test_impute_scf_exists(self):
250251
assert callable(_impute_scf)
251252

253+
def test_source_impute_preserves_existing_under_65_ssi_criteria(self):
254+
fake_model_predictions = np.array([False, False, False])
255+
256+
result = preserve_under_65_ssi_disability_criteria(
257+
fake_model_predictions,
258+
age=np.array([40, 64, 70]),
259+
existing_meets_ssi_disability_criteria=np.array([True, False, True]),
260+
)
261+
262+
np.testing.assert_array_equal(result, np.array([True, False, False]))
263+
252264

253265
class TestTippedOccupationHelpers:
254266
def test_derive_any_treasury_tipped_occupation_code(self):

tests/unit/datasets/test_sipp_ssi_disability.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
build_ssi_disability_training_frame,
99
coerce_ssi_disability_predictions,
1010
predict_ssi_disability_criteria,
11+
preserve_under_65_ssi_disability_criteria,
1112
prepare_ssi_disability_receiver,
1213
)
1314

@@ -103,6 +104,17 @@ def test_apply_ssi_disability_signal_screen_treats_missing_as_false():
103104
np.testing.assert_array_equal(result, np.array([False, False, False]))
104105

105106

107+
def test_preserve_under_65_ssi_disability_criteria_keeps_observed_anchors():
108+
result = preserve_under_65_ssi_disability_criteria(
109+
np.array([False, False, False, False]),
110+
age=np.array([40, 64, 70, 30]),
111+
ssi_reported=np.array([0, 100, 100, np.nan]),
112+
existing_meets_ssi_disability_criteria=np.array([True, False, True, np.nan]),
113+
)
114+
115+
np.testing.assert_array_equal(result, np.array([True, True, False, False]))
116+
117+
106118
def test_coerce_ssi_disability_predictions_handles_string_false():
107119
result = coerce_ssi_disability_predictions(
108120
pd.Series(["False", "True", "0", "1", False, True, 0, 1])

0 commit comments

Comments
 (0)