Skip to content

Commit 86ac0bc

Browse files
committed
Format rebased CPS clone imputation changes
1 parent c9809fd commit 86ac0bc

2 files changed

Lines changed: 19 additions & 11 deletions

File tree

policyengine_us_data/datasets/cps/extended_cps.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@
2222

2323
logger = logging.getLogger(__name__)
2424

25+
2526
def _supports_structural_mortgage_inputs() -> bool:
2627
return has_policyengine_us_variables(*STRUCTURAL_MORTGAGE_VARIABLES)
2728

29+
2830
# CPS-only categorical features to donor-impute onto the PUF clone half.
2931
# These drive subgroup analysis and occupation-based logic, so naive donor
3032
# duplication dilutes the relationship between the clone's PUF-imputed
@@ -197,7 +199,9 @@ def _clone_half_person_values(data: dict, variable: str, time_period: int):
197199
continue
198200
entity_half = len(entity_ids) // 2
199201
clone_entity_ids = entity_ids[entity_half:]
200-
clone_person_entity_ids = data[person_entity_id_var][time_period][n_persons_half:]
202+
clone_person_entity_ids = data[person_entity_id_var][time_period][
203+
n_persons_half:
204+
]
201205
value_map = dict(zip(clone_entity_ids, values[entity_half:]))
202206
return np.array([value_map[idx] for idx in clone_person_entity_ids])
203207

@@ -270,7 +274,9 @@ def _impute_clone_cps_features(
270274
CPS_CLONE_FEATURE_PREDICTORS + CPS_CLONE_FEATURE_VARIABLES
271275
)
272276
available_outputs = [
273-
variable for variable in CPS_CLONE_FEATURE_VARIABLES if variable in X_train.columns
277+
variable
278+
for variable in CPS_CLONE_FEATURE_VARIABLES
279+
if variable in X_train.columns
274280
]
275281
if not available_outputs:
276282
n_half = len(data["person_id"][time_period]) // 2
@@ -323,8 +329,12 @@ def _impute_clone_cps_features(
323329
predictions.loc[test_mask, available_outputs] = donor_outputs.to_numpy()
324330

325331
if "detailed_occupation_recode" in predictions:
326-
occupation_codes = predictions["detailed_occupation_recode"].astype(float).to_numpy()
327-
for column, values in _derive_overtime_occupation_inputs(occupation_codes).items():
332+
occupation_codes = (
333+
predictions["detailed_occupation_recode"].astype(float).to_numpy()
334+
)
335+
for column, values in _derive_overtime_occupation_inputs(
336+
occupation_codes
337+
).items():
328338
predictions[column] = values
329339

330340
return predictions

policyengine_us_data/tests/test_extended_cps.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -364,9 +364,7 @@ def calculate_dataframe(self, columns):
364364
"employment_income": {
365365
tp: np.array([20_000, 35_000, 90_000, 150_000], dtype=np.float32)
366366
},
367-
"self_employment_income": {
368-
tp: np.zeros(4, dtype=np.float32)
369-
},
367+
"self_employment_income": {tp: np.zeros(4, dtype=np.float32)},
370368
"social_security": {tp: np.zeros(4, dtype=np.float32)},
371369
"is_tax_unit_head": {tp: np.ones(4, dtype=bool)},
372370
"is_tax_unit_spouse": {tp: np.zeros(4, dtype=bool)},
@@ -387,9 +385,7 @@ def calculate_dataframe(self, columns):
387385
assert result["tax_unit_is_joint"].tolist() == [0, 1]
388386

389387
def test_derive_overtime_occupation_inputs(self):
390-
derived = _derive_overtime_occupation_inputs(
391-
np.array([53, 52, 8, 41, 1, 99])
392-
)
388+
derived = _derive_overtime_occupation_inputs(np.array([53, 52, 8, 41, 1, 99]))
393389

394390
assert derived["has_never_worked"].tolist() == [
395391
True,
@@ -477,7 +473,9 @@ def calculate_dataframe(self, columns):
477473
"is_tax_unit_head": {tp: np.array([1, 0, 1, 0], dtype=bool)},
478474
"is_tax_unit_spouse": {tp: np.zeros(4, dtype=bool)},
479475
"is_tax_unit_dependent": {tp: np.array([0, 1, 0, 1], dtype=bool)},
480-
"employment_income": {tp: np.array([95_000, 0, 97_000, 0], dtype=np.float32)},
476+
"employment_income": {
477+
tp: np.array([95_000, 0, 97_000, 0], dtype=np.float32)
478+
},
481479
"self_employment_income": {tp: np.zeros(4, dtype=np.float32)},
482480
"social_security": {tp: np.zeros(4, dtype=np.float32)},
483481
}

0 commit comments

Comments
 (0)