|
5 | 5 |
|
6 | 6 | import numpy as np |
7 | 7 | import pandas as pd |
| 8 | +import huggingface_hub |
8 | 9 |
|
| 10 | +from policyengine_us_data.calibration import source_impute |
9 | 11 | from policyengine_us_data.calibration.source_impute import ( |
10 | 12 | ACS_IMPUTED_VARIABLES, |
11 | 13 | ACS_PREDICTORS, |
@@ -332,6 +334,62 @@ def test_impute_acs_exists(self): |
332 | 334 | def test_impute_sipp_exists(self): |
333 | 335 | assert callable(_impute_sipp) |
334 | 336 |
|
| 337 | + def test_calibration_sipp_tip_counts_use_reference_month(self, monkeypatch): |
| 338 | + captured = {} |
| 339 | + |
| 340 | + columns = { |
| 341 | + "SSUID": [1, 1, 1, 2], |
| 342 | + "MONTHCODE": [1, 12, 12, 12], |
| 343 | + "TAGE": [5, 40, 10, 30], |
| 344 | + "TPTOTINC": [1_000.0, 2_000.0, 0.0, 3_000.0], |
| 345 | + "WPFINWGT": [1.0, 1.0, 1.0, 1.0], |
| 346 | + } |
| 347 | + for column in source_impute.SIPP_TIP_AMOUNT_COLUMNS: |
| 348 | + columns[column] = [0.0, 10.0, 0.0, 5.0] |
| 349 | + for column in source_impute.SIPP_TIP_ALLOCATION_COLUMNS: |
| 350 | + columns[column] = [0, 0, 0, 0] |
| 351 | + for column in source_impute.SIPP_JOB_OCCUPATION_COLUMNS: |
| 352 | + columns[column] = [0, 0, 0, 0] |
| 353 | + tip_source = pd.DataFrame(columns) |
| 354 | + |
| 355 | + read_count = {"count": 0} |
| 356 | + |
| 357 | + def fake_read_csv(*args, **kwargs): |
| 358 | + read_count["count"] += 1 |
| 359 | + if read_count["count"] == 1: |
| 360 | + return tip_source.copy() |
| 361 | + raise FileNotFoundError("stop after tip imputation") |
| 362 | + |
| 363 | + class FakeQRF: |
| 364 | + def __init__(self, *args, **kwargs): |
| 365 | + pass |
| 366 | + |
| 367 | + def fit(self, X_train, **kwargs): |
| 368 | + captured["train"] = X_train.copy() |
| 369 | + return self |
| 370 | + |
| 371 | + def predict(self, X_test): |
| 372 | + return pd.DataFrame({"tip_income": np.zeros(len(X_test))}) |
| 373 | + |
| 374 | + monkeypatch.setattr( |
| 375 | + huggingface_hub, |
| 376 | + "hf_hub_download", |
| 377 | + lambda *args, **kwargs: None, |
| 378 | + ) |
| 379 | + monkeypatch.setattr(source_impute.pd, "read_csv", fake_read_csv) |
| 380 | + monkeypatch.setattr(source_impute, "QRF", FakeQRF) |
| 381 | + |
| 382 | + data = _make_data_dict(n_persons=4) |
| 383 | + _impute_sipp( |
| 384 | + data=data, |
| 385 | + state_fips=np.array([1, 1], dtype=np.int32), |
| 386 | + time_period=2024, |
| 387 | + ) |
| 388 | + |
| 389 | + household_one = captured["train"][captured["train"]["household_id"] == 1] |
| 390 | + np.testing.assert_array_equal(household_one["count_under_18"], [1, 1]) |
| 391 | + np.testing.assert_array_equal(household_one["count_under_6"], [0, 0]) |
| 392 | + |
335 | 393 | def test_impute_org_exists(self): |
336 | 394 | assert callable(_impute_org) |
337 | 395 |
|
|
0 commit comments