|
10 | 10 | from policyengine_core.data import Dataset |
11 | 11 | from policyengine_core.reforms import Reform |
12 | 12 | from policyengine_us import Microsimulation |
13 | | -from policyengine_us_data.utils import ( |
14 | | - ABSOLUTE_ERROR_SCALE_TARGETS, |
15 | | - build_loss_matrix, |
16 | | - print_reweighting_diagnostics, |
17 | | -) |
| 13 | +from policyengine_us_data.utils import ABSOLUTE_ERROR_SCALE_TARGETS |
18 | 14 | from policyengine_us_data.storage import STORAGE_FOLDER |
19 | 15 |
|
20 | 16 |
|
@@ -69,61 +65,18 @@ def test_sparse_poverty_rate_reasonable(sparse_sim): |
69 | 65 | # ── Reweighting and calibration checks ──────────────────────── |
70 | 66 |
|
71 | 67 |
|
72 | | -@pytest.mark.filterwarnings("ignore:DataFrame is highly fragmented") |
73 | | -@pytest.mark.filterwarnings("ignore:The distutils package is deprecated") |
74 | | -@pytest.mark.filterwarnings( |
75 | | - "ignore:Series.__getitem__ treating keys as positions is deprecated" |
76 | | -) |
77 | | -@pytest.mark.filterwarnings( |
78 | | - "ignore:Setting an item of incompatible dtype is deprecated" |
79 | | -) |
80 | | -@pytest.mark.filterwarnings( |
81 | | - "ignore:Boolean Series key will be reindexed to match DataFrame index." |
82 | | -) |
83 | | -def test_sparse_ecps(sim): |
84 | | - data = sim.dataset.load_dataset() |
85 | | - optimised_weights = data["household_weight"]["2024"] |
86 | | - |
87 | | - bad_targets = [ |
88 | | - "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Head of Household", |
89 | | - "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Head of Household", |
90 | | - "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", |
91 | | - "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", |
92 | | - "nation/irs/count/count/AGI in 10k-15k/taxable/Head of Household", |
93 | | - "nation/irs/count/count/AGI in 15k-20k/taxable/Head of Household", |
94 | | - "nation/irs/count/count/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", |
95 | | - "nation/irs/count/count/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", |
96 | | - "state/RI/adjusted_gross_income/amount/-inf_1", |
97 | | - "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Head of Household", |
98 | | - "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Head of Household", |
99 | | - "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", |
100 | | - "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", |
101 | | - "nation/irs/count/count/AGI in 10k-15k/taxable/Head of Household", |
102 | | - "nation/irs/count/count/AGI in 15k-20k/taxable/Head of Household", |
103 | | - "nation/irs/count/count/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", |
104 | | - "nation/irs/count/count/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", |
105 | | - "state/RI/adjusted_gross_income/amount/-inf_1", |
106 | | - "nation/irs/exempt interest/count/AGI in -inf-inf/taxable/All", |
107 | | - ] |
108 | | - |
109 | | - loss_matrix, targets_array = build_loss_matrix(sim.dataset, 2024) |
110 | | - scaled_zero_target_mask = loss_matrix.columns.isin( |
111 | | - ABSOLUTE_ERROR_SCALE_TARGETS.keys() |
112 | | - ) |
113 | | - zero_mask = np.isclose(targets_array, 0.0, atol=0.1) & (~scaled_zero_target_mask) |
114 | | - bad_mask = loss_matrix.columns.isin(bad_targets) |
115 | | - keep_mask_bool = ~(zero_mask | bad_mask) |
116 | | - keep_idx = np.where(keep_mask_bool)[0] |
117 | | - loss_matrix_clean = loss_matrix.iloc[:, keep_idx] |
118 | | - targets_array_clean = targets_array[keep_idx] |
119 | | - assert loss_matrix_clean.shape[1] == targets_array_clean.size |
120 | | - |
121 | | - percent_within_10 = print_reweighting_diagnostics( |
122 | | - optimised_weights, |
123 | | - loss_matrix_clean, |
124 | | - targets_array_clean, |
125 | | - "Sparse Solutions", |
126 | | - ) |
| 68 | +def test_sparse_ecps(): |
| 69 | + calibration_log = pd.read_csv("calibration_log.csv") |
| 70 | + final_epoch = calibration_log["epoch"].max() |
| 71 | + final_rows = calibration_log[calibration_log["epoch"] == final_epoch].copy() |
| 72 | + |
| 73 | + assert not final_rows.empty, "No final-epoch calibration diagnostics found." |
| 74 | + |
| 75 | + tolerance = 0.10 * final_rows["target"].abs() |
| 76 | + for target_name, scale in ABSOLUTE_ERROR_SCALE_TARGETS.items(): |
| 77 | + tolerance.loc[final_rows["target_name"] == target_name] = 0.10 * scale |
| 78 | + |
| 79 | + percent_within_10 = (final_rows["abs_error"] <= tolerance).mean() * 100 |
127 | 80 | assert percent_within_10 > 60.0 |
128 | 81 |
|
129 | 82 |
|
|
0 commit comments