Skip to content

Commit e4b28d0

Browse files
authored
Use calibration log for sparse validation (#935)
1 parent 7d51051 commit e4b28d0

2 files changed

Lines changed: 14 additions & 60 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
Use saved calibration diagnostics for the sparse enhanced CPS validation gate instead of rebuilding the full loss matrix.

validation/stage_1/test_sparse_enhanced_cps.py

Lines changed: 13 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -10,11 +10,7 @@
1010
from policyengine_core.data import Dataset
1111
from policyengine_core.reforms import Reform
1212
from policyengine_us import Microsimulation
13-
from policyengine_us_data.utils import (
14-
ABSOLUTE_ERROR_SCALE_TARGETS,
15-
build_loss_matrix,
16-
print_reweighting_diagnostics,
17-
)
13+
from policyengine_us_data.utils import ABSOLUTE_ERROR_SCALE_TARGETS
1814
from policyengine_us_data.storage import STORAGE_FOLDER
1915

2016

@@ -69,61 +65,18 @@ def test_sparse_poverty_rate_reasonable(sparse_sim):
6965
# ── Reweighting and calibration checks ────────────────────────
7066

7167

72-
@pytest.mark.filterwarnings("ignore:DataFrame is highly fragmented")
73-
@pytest.mark.filterwarnings("ignore:The distutils package is deprecated")
74-
@pytest.mark.filterwarnings(
75-
"ignore:Series.__getitem__ treating keys as positions is deprecated"
76-
)
77-
@pytest.mark.filterwarnings(
78-
"ignore:Setting an item of incompatible dtype is deprecated"
79-
)
80-
@pytest.mark.filterwarnings(
81-
"ignore:Boolean Series key will be reindexed to match DataFrame index."
82-
)
83-
def test_sparse_ecps(sim):
84-
data = sim.dataset.load_dataset()
85-
optimised_weights = data["household_weight"]["2024"]
86-
87-
bad_targets = [
88-
"nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Head of Household",
89-
"nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Head of Household",
90-
"nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse",
91-
"nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse",
92-
"nation/irs/count/count/AGI in 10k-15k/taxable/Head of Household",
93-
"nation/irs/count/count/AGI in 15k-20k/taxable/Head of Household",
94-
"nation/irs/count/count/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse",
95-
"nation/irs/count/count/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse",
96-
"state/RI/adjusted_gross_income/amount/-inf_1",
97-
"nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Head of Household",
98-
"nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Head of Household",
99-
"nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse",
100-
"nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse",
101-
"nation/irs/count/count/AGI in 10k-15k/taxable/Head of Household",
102-
"nation/irs/count/count/AGI in 15k-20k/taxable/Head of Household",
103-
"nation/irs/count/count/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse",
104-
"nation/irs/count/count/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse",
105-
"state/RI/adjusted_gross_income/amount/-inf_1",
106-
"nation/irs/exempt interest/count/AGI in -inf-inf/taxable/All",
107-
]
108-
109-
loss_matrix, targets_array = build_loss_matrix(sim.dataset, 2024)
110-
scaled_zero_target_mask = loss_matrix.columns.isin(
111-
ABSOLUTE_ERROR_SCALE_TARGETS.keys()
112-
)
113-
zero_mask = np.isclose(targets_array, 0.0, atol=0.1) & (~scaled_zero_target_mask)
114-
bad_mask = loss_matrix.columns.isin(bad_targets)
115-
keep_mask_bool = ~(zero_mask | bad_mask)
116-
keep_idx = np.where(keep_mask_bool)[0]
117-
loss_matrix_clean = loss_matrix.iloc[:, keep_idx]
118-
targets_array_clean = targets_array[keep_idx]
119-
assert loss_matrix_clean.shape[1] == targets_array_clean.size
120-
121-
percent_within_10 = print_reweighting_diagnostics(
122-
optimised_weights,
123-
loss_matrix_clean,
124-
targets_array_clean,
125-
"Sparse Solutions",
126-
)
68+
def test_sparse_ecps():
    """Check that most calibration targets are hit within 10% at the final epoch.

    Reads ``calibration_log.csv`` (path relative to the current working
    directory), keeps only the rows belonging to the highest ``epoch`` value,
    and requires that more than 60% of those rows have an ``abs_error`` no
    larger than their tolerance.

    The tolerance is 10% of the absolute ``target`` value, except for rows
    whose ``target_name`` is a key of ``ABSOLUTE_ERROR_SCALE_TARGETS``; those
    use 10% of the scale mapped to that name instead.
    """
    log = pd.read_csv("calibration_log.csv")
    last_epoch = log["epoch"].max()
    final_rows = log.loc[log["epoch"] == last_epoch].copy()

    assert not final_rows.empty, "No final-epoch calibration diagnostics found."

    # Default: relative tolerance of 10% of each target's magnitude.
    tolerance = 0.10 * final_rows["target"].abs()

    # Override with an absolute tolerance for the explicitly scaled targets.
    for scaled_name, scale in ABSOLUTE_ERROR_SCALE_TARGETS.items():
        is_scaled = final_rows["target_name"] == scaled_name
        tolerance.loc[is_scaled] = 0.10 * scale

    within_tolerance = final_rows["abs_error"] <= tolerance
    percent_within_10 = within_tolerance.mean() * 100
    assert percent_within_10 > 60.0
12881

12982

0 commit comments

Comments
 (0)