Skip to content

Commit c2de134

Browse files
committed
Address SSI disability review findings
1 parent b090a9f commit c2de134

4 files changed

Lines changed: 172 additions & 3 deletions

File tree

policyengine_us_data/datasets/cps/extended_cps.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
ORG_IMPUTED_VARIABLES,
2828
apply_org_domain_constraints,
2929
)
30+
from policyengine_us_data.datasets.sipp import (
31+
SSI_DISABILITY_MODEL_PREDICTORS,
32+
SSI_DISABILITY_MODEL_VARIABLE,
33+
get_ssi_disability_model,
34+
predict_ssi_disability_criteria,
35+
)
3036
from policyengine_us_data.pipeline_metadata import pipeline_node
3137
from policyengine_us_data.pipeline_schema import PipelineNode
3238
from policyengine_us_data.datasets.puf import PUF, PUF_2024
@@ -185,7 +191,7 @@ def _supports_structural_mortgage_inputs() -> bool:
185191
"financial_assistance",
186192
"survivor_benefits",
187193
"disability_benefits",
188-
"meets_ssi_disability_criteria",
194+
SSI_DISABILITY_MODEL_VARIABLE,
189195
"strike_benefits",
190196
"receives_wic",
191197
# SPM variables
@@ -236,6 +242,20 @@ def _supports_structural_mortgage_inputs() -> bool:
236242
"age",
237243
}
238244

245+
_CLONE_REFRESH_STRUCTURAL_ROLE_VARIABLES = {
246+
"is_household_head",
247+
"is_tax_unit_head",
248+
"is_tax_unit_spouse",
249+
"is_tax_unit_dependent",
250+
"is_tax_unit_head_or_spouse",
251+
"is_family_head",
252+
"is_family_spouse",
253+
"is_family_dependent",
254+
"is_spm_unit_head",
255+
"is_spm_unit_spouse",
256+
"is_spm_unit_dependent",
257+
}
258+
239259
# Predictors used for the second-stage CPS-only imputation: demographics
240260
# plus key income variables that were already imputed from PUF data.
241261
CPS_STAGE2_DEMOGRAPHIC_PREDICTORS = [
@@ -308,6 +328,7 @@ def _is_structural_clone_variable(variable: str) -> bool:
308328
or variable in _CLONE_REFRESH_GEOGRAPHY_VARIABLES
309329
or variable in CLONE_ORIGIN_FLAGS.values()
310330
or variable in _CLONE_REFRESH_ANCHOR_VARIABLES
331+
or variable in _CLONE_REFRESH_STRUCTURAL_ROLE_VARIABLES
311332
or variable in _STAGE2_COMPUTED_PREDICTORS
312333
)
313334

@@ -389,6 +410,41 @@ def _build_clone_test_frame(
389410
return X_test[predictors]
390411

391412

413+
def _build_ssi_disability_clone_receiver(
    predictions: pd.DataFrame,
    X_test: pd.DataFrame,
    data: dict,
    time_period: int,
) -> pd.DataFrame:
    """Assemble the SSI disability model's input frame for PUF clone records.

    One column is produced per entry of ``SSI_DISABILITY_MODEL_PREDICTORS``.
    Each column is taken from the first source that can supply it:

    1. ``has_disability_income`` is ``disability_benefits > 0`` when
       ``predictions`` has a ``disability_benefits`` column.
    2. A column of the same name in ``predictions``.
    3. A column of the same name in ``X_test``.
    4. ``_clone_half_person_values(data, predictor, time_period)``, accepted
       only when it is not ``None`` and has one value per ``X_test`` row.
    5. ``is_female`` is the logical negation of ``X_test["is_male"]``.
    6. Otherwise a column of zeros.

    Args:
        predictions: Imputed values for the clone records.
        X_test: Predictor frame for the clone records; its index and length
            define the returned frame.
        data: Dataset mapping forwarded to ``_clone_half_person_values``.
        time_period: Period forwarded to ``_clone_half_person_values``.

    Returns:
        A frame indexed like ``X_test`` with one column per model predictor.
    """
    row_count = len(X_test)

    def _resolve(name):
        # Positive imputed disability benefits stand in for the income flag.
        if name == "has_disability_income" and "disability_benefits" in predictions:
            return predictions["disability_benefits"].to_numpy() > 0
        if name in predictions:
            return predictions[name].to_numpy()
        if name in X_test:
            return X_test[name].to_numpy()

        candidate = _clone_half_person_values(data, name, time_period)
        # Values whose length does not match the clone rows are ignored.
        if candidate is not None and len(candidate) == row_count:
            return candidate

        if name == "is_female" and "is_male" in X_test:
            return ~X_test["is_male"].astype(bool).to_numpy()
        return None

    receiver = pd.DataFrame(index=X_test.index)
    for predictor in SSI_DISABILITY_MODEL_PREDICTORS:
        resolved = _resolve(predictor)
        # Predictors with no available source are filled with zeros.
        receiver[predictor] = np.zeros(row_count) if resolved is None else resolved

    return receiver
446+
447+
392448
def _prepare_knn_matrix(
393449
df: pd.DataFrame,
394450
reference: pd.DataFrame | None = None,
@@ -923,6 +979,18 @@ def _apply_post_processing(predictions, X_test, time_period, data):
923979
"employer_sponsored_insurance_premiums",
924980
] = 0
925981

982+
if SSI_DISABILITY_MODEL_VARIABLE in predictions.columns:
983+
receiver = _build_ssi_disability_clone_receiver(
984+
predictions,
985+
X_test,
986+
data,
987+
time_period,
988+
)
989+
predictions[SSI_DISABILITY_MODEL_VARIABLE] = predict_ssi_disability_criteria(
990+
get_ssi_disability_model(time_period=time_period),
991+
receiver,
992+
)
993+
926994
return predictions
927995

928996

policyengine_us_data/datasets/sipp/sipp.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
]
4848

4949
SSI_DISABILITY_MODEL_VARIABLE = "meets_ssi_disability_criteria"
50+
SSI_DISABILITY_MODEL_VERSION = 3
5051

5152
SSI_DISABILITY_DIFFICULTY_PREDICTORS = [
5253
"difficulty_dressing_or_bathing",
@@ -818,7 +819,7 @@ def train_ssi_disability_model(time_period: int = 2024):
818819

819820
def get_ssi_disability_model(time_period: int = 2024) -> QRF:
820821
"""Get or train the SSI disability criteria imputation model."""
821-
model_path = STORAGE_FOLDER / f"ssi_disability_criteria_v2_{time_period}.pkl"
822+
model_path = _ssi_disability_model_path(time_period)
822823

823824
if not model_path.exists():
824825
model = train_ssi_disability_model(time_period=time_period)
@@ -832,6 +833,13 @@ def get_ssi_disability_model(time_period: int = 2024) -> QRF:
832833
return model
833834

834835

836+
def _ssi_disability_model_path(time_period: int):
    """Return the storage path of the pickled SSI disability model.

    The file name embeds ``SSI_DISABILITY_MODEL_VERSION`` together with
    ``time_period``, so changing the version yields a different path for the
    same period.
    """
    file_name = f"ssi_disability_criteria_v{SSI_DISABILITY_MODEL_VERSION}_{time_period}.pkl"
    return STORAGE_FOLDER / file_name
841+
842+
835843
def build_vehicle_training_frame() -> pd.DataFrame:
836844
"""Build a household-level SIPP frame for vehicle asset imputation."""
837845
hf_hub_download(

tests/unit/datasets/test_sipp_ssi_disability.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
preserve_under_65_ssi_disability_criteria,
1313
prepare_ssi_disability_receiver,
1414
)
15-
from policyengine_us_data.datasets.sipp.sipp import SSI_DISABILITY_COLUMNS
15+
from policyengine_us_data.datasets.sipp.sipp import (
16+
SSI_DISABILITY_COLUMNS,
17+
SSI_DISABILITY_MODEL_VERSION,
18+
_ssi_disability_model_path,
19+
)
1620

1721

1822
def _base_sipp_frame() -> pd.DataFrame:
@@ -95,6 +99,13 @@ def test_ssi_disability_predictors_use_six_comparable_difficulty_items():
9599
assert "is_disabled" not in SSI_DISABILITY_MODEL_PREDICTORS
96100

97101

102+
def test_ssi_disability_model_cache_version_tracks_predictor_schema():
    """Pin the model cache version and the file name derived from it."""
    expected_file_name = "ssi_disability_criteria_v3_2024.pkl"

    assert SSI_DISABILITY_MODEL_VERSION == 3
    assert _ssi_disability_model_path(2024).name == expected_file_name
107+
108+
98109
def test_build_ssi_disability_training_frame_annualizes_ssdi_amount():
99110
frame = _base_sipp_frame().iloc[[2]].copy()
100111
frame["ESSRSN2YN"] = 1

tests/unit/test_extended_cps.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
_load_raw_spm_capped_housing_subsidy,
3030
_apply_post_processing,
3131
_build_clone_test_frame,
32+
_build_ssi_disability_clone_receiver,
3233
_cps_clone_feature_variables_for_data,
3334
_derive_overtime_occupation_inputs,
3435
_impute_clone_cps_features,
@@ -216,6 +217,8 @@ def test_clone_feature_candidates_include_person_level_cps_only_flags(self):
216217
"household_weight": {2024: np.array([1.0, 1.0, 0.0, 0.0])},
217218
"state_fips": {2024: np.array([6, 36, 6, 36])},
218219
"employment_income": {2024: np.array([1.0, 2.0, 3.0, 4.0])},
220+
"is_household_head": {2024: np.array([True, True, True, True])},
221+
"is_tax_unit_head": {2024: np.array([True, False, True, False])},
219222
"is_disabled": {2024: np.array([True, False, True, False])},
220223
"difficulty_hearing": {2024: np.array([False, True, False, True])},
221224
"meets_ssi_disability_criteria": {
@@ -232,6 +235,8 @@ def test_clone_feature_candidates_include_person_level_cps_only_flags(self):
232235
assert "household_weight" not in result
233236
assert "state_fips" not in result
234237
assert "employment_income" not in result
238+
assert "is_household_head" not in result
239+
assert "is_tax_unit_head" not in result
235240
assert "meets_ssi_disability_criteria" not in result
236241

237242
def test_spm_threshold_is_formula_output_not_qrf_imputed(self):
@@ -881,6 +886,83 @@ def test_leaves_data_unchanged_when_pe_us_lacks_llc_inputs(self, monkeypatch):
881886

882887

883888
class TestStage2PostProcessing:
889+
def test_ssi_disability_clone_receiver_uses_stage2_disability_benefits(self):
    """Check which source feeds each receiver column for two clone rows.

    ``has_disability_income`` must follow the sign of the predicted
    ``disability_benefits``, ``is_female`` must be the negation of
    ``is_male``, and ``difficulty_hearing`` must equal the last two entries
    of the four-person dataset array.
    """
    clone_predictions = pd.DataFrame({"disability_benefits": [0.0, 500.0]})
    clone_features = pd.DataFrame(
        {
            "age": [40, 40],
            "is_male": [False, True],
            "employment_income": [0.0, 0.0],
        }
    )
    dataset = {
        "person_id": {2024: np.array([1, 2, 101, 102])},
        "difficulty_hearing": {2024: np.array([False, False, True, False])},
    }

    receiver = _build_ssi_disability_clone_receiver(
        clone_predictions,
        clone_features,
        dataset,
        2024,
    )

    expected_columns = {
        "difficulty_hearing": [True, False],
        "has_disability_income": [False, True],
        "is_female": [True, False],
    }
    for column, expected in expected_columns.items():
        np.testing.assert_array_equal(receiver[column], expected)
913+
914+
def test_post_processing_replaces_generic_ssi_disability_predictions(
    self,
    monkeypatch,
):
    """Check that post-processing overwrites the incoming SSI disability flags.

    The model lookup is patched with a stub that predicts ``True`` for every
    row. The incoming flags ``[False, False, True]`` must come back as
    ``[True, True, False]``.
    NOTE(review): the final ``False`` presumably comes from rules applied in
    ``predict_ssi_disability_criteria`` - confirm against that function.
    """

    class _AllTrueStub:
        def predict(self, X_test):
            flags = np.ones(len(X_test), dtype=bool)
            return pd.DataFrame({"meets_ssi_disability_criteria": flags})

    # The lambda keeps the ``time_period`` parameter name because the patched
    # function is called with that keyword.
    monkeypatch.setattr(
        extended_cps_module,
        "get_ssi_disability_model",
        lambda time_period: _AllTrueStub(),
    )

    clone_predictions = pd.DataFrame(
        {
            "meets_ssi_disability_criteria": [False, False, True],
            "disability_benefits": [0.0, 500.0, 0.0],
        }
    )
    clone_features = pd.DataFrame(
        {
            "age": [40, 40, 40],
            "is_male": [False, True, False],
            "employment_income": [0.0, 0.0, 0.0],
        }
    )
    dataset = {
        "person_id": {2024: np.arange(6)},
        "difficulty_walking_or_climbing_stairs": {
            2024: np.array([False, False, False, True, False, False])
        },
    }

    processed = _apply_post_processing(
        predictions=clone_predictions,
        X_test=clone_features,
        time_period=2024,
        data=dataset,
    )

    expected_flags = np.array([True, True, False])
    np.testing.assert_array_equal(
        processed["meets_ssi_disability_criteria"],
        expected_flags,
    )
965+
884966
def test_splice_replaces_clone_half_ssi_disability_criteria(self, monkeypatch):
885967
import policyengine_us
886968

0 commit comments

Comments
 (0)