|
16 | 16 | BEA_WAGES_AND_SALARIES_LOSS_WEIGHT, |
17 | 17 | BLS_CE_TOTALS, |
18 | 18 | HARD_CODED_TOTALS, |
| 19 | + HOUSEHOLD_COUNT_LOSS_WEIGHT, |
| 20 | + HOUSEHOLD_COUNT_TARGET, |
19 | 21 | LOW_AGI_INVESTMENT_INCOME_SOI_VARIABLES, |
20 | 22 | SOI_NEGATIVE_AGI_TARGETED_VARIABLES, |
21 | 23 | TRANSFER_BALANCE_TARGETS, |
22 | 24 | _add_bea_state_wage_targets, |
23 | 25 | _add_agi_metric_columns, |
24 | 26 | _add_acs_housing_cost_targets, |
| 27 | + _add_household_count_target, |
25 | 28 | _add_aotc_targets, |
26 | 29 | _add_bls_ce_targets, |
27 | 30 | _add_ctc_targets, |
@@ -167,6 +170,22 @@ def test_bea_nipa_direct_sum_targets_get_higher_loss_weight(): |
167 | 170 | ] |
168 | 171 |
|
169 | 172 |
|
| 173 | +def test_household_count_target_gets_higher_loss_weight(): |
| 174 | + target_names = np.array( |
| 175 | + [ |
| 176 | + HOUSEHOLD_COUNT_TARGET, |
| 177 | + "nation/census/population_by_age/0", |
| 178 | + ] |
| 179 | + ) |
| 180 | + |
| 181 | + weights = get_target_loss_weights(target_names) |
| 182 | + |
| 183 | + assert weights.tolist() == [ |
| 184 | + HOUSEHOLD_COUNT_LOSS_WEIGHT, |
| 185 | + 1.0, |
| 186 | + ] |
| 187 | + |
| 188 | + |
170 | 189 | def test_aca_targets_roll_forward_to_2025(): |
171 | 190 | targets, data_year = _load_aca_spending_and_enrollment_targets(2025) |
172 | 191 |
|
@@ -243,6 +262,17 @@ def __init__(self, values): |
243 | 262 | self.values = np.asarray(values) |
244 | 263 |
|
245 | 264 |
|
| 265 | +class _FakeHouseholdWeightSimulation: |
| 266 | + def __init__(self, weights): |
| 267 | + self.weights = weights |
| 268 | + |
| 269 | + def calculate(self, variable, map_to=None, period=None): |
| 270 | + assert variable == "household_weight" |
| 271 | + assert map_to is None |
| 272 | + assert period is None |
| 273 | + return _FakeArrayResult(self.weights) |
| 274 | + |
| 275 | + |
246 | 276 | class _FakeSimulation: |
247 | 277 | def __init__(self): |
248 | 278 | self.calculate_calls = [] |
@@ -427,6 +457,28 @@ def test_state_agi_targets_are_limited_to_filers(tmp_path, monkeypatch): |
427 | 457 | ) |
428 | 458 |
|
429 | 459 |
|
| 460 | +def test_add_household_count_target_uses_source_weight_total(): |
| 461 | + loss_matrix = pd.DataFrame(index=[101, 102, 103, 104]) |
| 462 | + |
| 463 | + targets, loss_matrix = _add_household_count_target( |
| 464 | + loss_matrix, |
| 465 | + [], |
| 466 | + _FakeHouseholdWeightSimulation([80.0, 20.0, 0.0, 0.0]), |
| 467 | + ) |
| 468 | + |
| 469 | + assert targets == [100.0] |
| 470 | + np.testing.assert_array_equal( |
| 471 | + loss_matrix[HOUSEHOLD_COUNT_TARGET].to_numpy(), |
| 472 | + np.ones(4, dtype=np.float32), |
| 473 | + ) |
| 474 | + |
| 475 | + |
| 476 | +def test_build_loss_matrix_adds_household_count_target_before_reweighting(): |
| 477 | + source = inspect.getsource(build_loss_matrix) |
| 478 | + |
| 479 | + assert "_add_household_count_target" in source |
| 480 | + |
| 481 | + |
430 | 482 | def test_add_ssi_recipient_targets_adds_total_and_age_counts(): |
431 | 483 | targets, loss_matrix = _add_ssi_recipient_targets( |
432 | 484 | pd.DataFrame(), |
|
0 commit comments