Skip to content

Commit 39ad696

Browse files
committed
Clean up reform validation cache and add tests
1 parent 9ee1a2c commit 39ad696

3 files changed

Lines changed: 212 additions & 12 deletions

File tree

policyengine_us_data/calibration/validate_staging.py

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -270,27 +270,29 @@ def _build_entity_rel(sim) -> pd.DataFrame:
270270
)
271271

272272

273-
def _get_reform_household_values(
273+
def _get_reform_income_tax_delta(
274274
dataset_path: str,
275275
period: int,
276276
variable: str,
277-
reform_hh_cache: dict,
277+
baseline_income_tax: np.ndarray,
278+
reform_delta_cache: dict,
278279
) -> np.ndarray:
279-
if variable in reform_hh_cache:
280-
return reform_hh_cache[variable]
280+
if variable in reform_delta_cache:
281+
return reform_delta_cache[variable]
281282

282283
from policyengine_us import Microsimulation
283284

284285
reform_sim = Microsimulation(
285286
dataset=dataset_path,
286287
reform=_make_neutralize_variable_reform(variable),
287288
)
288-
reform_hh_cache[variable] = reform_sim.calculate(
289+
reform_income_tax = reform_sim.calculate(
289290
"income_tax",
290291
map_to="household",
291292
period=period,
292293
).values
293-
return reform_hh_cache[variable]
294+
reform_delta_cache[variable] = reform_income_tax - baseline_income_tax
295+
return reform_delta_cache[variable]
294296

295297

296298
def validate_area(
@@ -370,14 +372,14 @@ def validate_area(
370372
map_to="household",
371373
period=period,
372374
).values
373-
if reform_id > 0 and variable not in reform_hh_cache:
374-
reform_income_tax = _get_reform_household_values(
375+
if reform_id > 0:
376+
reform_hh_cache[variable] = _get_reform_income_tax_delta(
375377
dataset_path,
376378
period,
377379
variable,
380+
hh_vars_cache["income_tax"],
378381
reform_hh_cache,
379382
)
380-
reform_hh_cache[variable] = reform_income_tax - hh_vars_cache["income_tax"]
381383

382384
per_hh = _calculate_target_values_standalone(
383385
target_variable=variable,
@@ -670,14 +672,14 @@ def _compute_district_contributions(
670672
map_to="household",
671673
period=period,
672674
).values
673-
if reform_id > 0 and variable not in reform_hh_cache:
674-
reform_income_tax = _get_reform_household_values(
675+
if reform_id > 0:
676+
reform_hh_cache[variable] = _get_reform_income_tax_delta(
675677
district_h5_path,
676678
period,
677679
variable,
680+
hh_vars_cache["income_tax"],
678681
reform_hh_cache,
679682
)
680-
reform_hh_cache[variable] = reform_income_tax - hh_vars_cache["income_tax"]
681683

682684
per_hh = _calculate_target_values_standalone(
683685
target_variable=variable,
Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,56 @@
1+
import sys
2+
from types import SimpleNamespace
3+
from unittest.mock import patch
4+
5+
import numpy as np
6+
7+
from policyengine_us_data.calibration.validate_staging import (
8+
_get_reform_income_tax_delta,
9+
)
10+
11+
12+
class _FakeArrayResult:
13+
def __init__(self, values):
14+
self.values = values
15+
16+
17+
class _FakeMicrosimulation:
18+
def __init__(self, dataset=None, reform=None):
19+
self.dataset = dataset
20+
self.reform = reform
21+
22+
def calculate(self, variable, map_to=None, period=None):
23+
assert variable == "income_tax"
24+
assert map_to == "household"
25+
assert period == 2024
26+
return _FakeArrayResult(np.array([150.0, 260.0], dtype=np.float32))
27+
28+
29+
@patch.dict(
    sys.modules,
    {"policyengine_us": SimpleNamespace(Microsimulation=_FakeMicrosimulation)},
)
def test_get_reform_income_tax_delta_caches_delta():
    """The helper returns and caches reform-minus-baseline income tax."""
    baseline_income_tax = np.array([100.0, 200.0], dtype=np.float32)
    cache = {}

    delta = _get_reform_income_tax_delta(
        dataset_path="fake.h5",
        period=2024,
        variable="salt_deduction",
        baseline_income_tax=baseline_income_tax,
        reform_delta_cache=cache,
    )

    # The fake simulation returns [150, 260]; minus the baseline gives [50, 60].
    np.testing.assert_array_equal(delta, np.array([50.0, 60.0], dtype=np.float32))
    np.testing.assert_array_equal(cache["salt_deduction"], delta)

    # The cached value should remain the delta, not the raw reform income tax.
    # A different baseline is passed on purpose: a cache hit must ignore it.
    cached = _get_reform_income_tax_delta(
        dataset_path="fake.h5",
        period=2024,
        variable="salt_deduction",
        baseline_income_tax=np.array([0.0, 0.0], dtype=np.float32),
        reform_delta_cache=cache,
    )
    np.testing.assert_array_equal(cached, np.array([50.0, 60.0], dtype=np.float32))
Lines changed: 142 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,142 @@
1+
import pandas as pd
2+
from sqlmodel import Session
3+
4+
from policyengine_us_data.db.create_database_tables import (
5+
Stratum,
6+
StratumConstraint,
7+
Target,
8+
create_database,
9+
)
10+
from policyengine_us_data.db.etl_national_targets import (
11+
TAX_EXPENDITURE_REFORM_ID,
12+
load_national_targets,
13+
)
14+
15+
16+
def _make_stratum(session, parent_id=None, notes=None, constraints=None):
    """Create, commit and return a ``Stratum`` for use as test fixture data.

    Args:
        session: Open database session used to add and commit the row.
        parent_id: Value for ``parent_stratum_id``; ``None`` for a root row.
        notes: Free-text notes stored on the stratum.
        constraints: Constraint objects assigned to ``constraints_rel``; an
            empty list is used when omitted.

    Returns:
        The committed ``Stratum``, refreshed from the session.
    """
    stratum = Stratum(parent_stratum_id=parent_id, notes=notes)
    stratum.constraints_rel = constraints or []
    session.add(stratum)
    session.commit()
    # Refresh after commit; callers read ``stratum_id`` from the result.
    session.refresh(stratum)
    return stratum
23+
24+
25+
def test_load_national_targets_deactivates_stale_baseline_rows(tmp_path, monkeypatch):
    """Loading tax-expenditure targets deactivates existing ``reform_id == 0`` rows.

    Two active baseline targets are seeded, ``load_national_targets`` is run
    twice with the same tax-expenditure data, and the test then checks that
    the baseline rows are inactive and that exactly two active rows exist
    under ``TAX_EXPENDITURE_REFORM_ID``.
    """
    calibration_dir = tmp_path / "calibration"
    calibration_dir.mkdir()
    db_uri = f"sqlite:///{calibration_dir / 'policy_data.db'}"
    engine = create_database(db_uri)

    with Session(engine) as session:
        national = _make_stratum(session, notes="United States")
        filer = _make_stratum(
            session,
            parent_id=national.stratum_id,
            notes="United States - Tax Filers",
            constraints=[
                StratumConstraint(
                    constraint_variable="tax_unit_is_filer",
                    operation="==",
                    value="1",
                )
            ],
        )
        itemizer = _make_stratum(
            session,
            parent_id=national.stratum_id,
            notes="United States - Itemizing Tax Filers",
            constraints=[
                StratumConstraint(
                    constraint_variable="tax_unit_is_filer",
                    operation="==",
                    value="1",
                ),
                StratumConstraint(
                    constraint_variable="tax_unit_itemizes",
                    operation="==",
                    value="1",
                ),
            ],
        )

        # Seed active baseline (reform_id=0) rows that the load should retire.
        session.add(
            Target(
                stratum_id=filer.stratum_id,
                variable="qualified_business_income_deduction",
                period=2024,
                value=63.1e9,
                active=True,
                reform_id=0,
            )
        )
        session.add(
            Target(
                stratum_id=itemizer.stratum_id,
                variable="salt_deduction",
                period=2024,
                value=21.247e9,
                active=True,
                reform_id=0,
            )
        )
        session.commit()

    # Redirect the ETL module's storage root to the temporary directory.
    # NOTE(review): presumably the loader then opens
    # tmp_path/calibration/policy_data.db created above - confirm.
    monkeypatch.setattr(
        "policyengine_us_data.db.etl_national_targets.STORAGE_FOLDER",
        tmp_path,
    )

    tax_expenditure_df = pd.DataFrame(
        [
            {
                "variable": "salt_deduction",
                "value": 21.247e9,
                "source": "Joint Committee on Taxation",
                "notes": "SALT deduction tax expenditure",
                "year": 2024,
            },
            {
                "variable": "qualified_business_income_deduction",
                "value": 63.1e9,
                "source": "Joint Committee on Taxation",
                "notes": "QBI deduction tax expenditure",
                "year": 2024,
            },
        ]
    )

    # Run the load twice with identical input; the row-count assertion below
    # (exactly two reform rows) fails if a repeated load duplicates targets.
    load_national_targets(
        direct_targets_df=pd.DataFrame(),
        tax_filer_df=pd.DataFrame(),
        tax_expenditure_df=tax_expenditure_df,
        conditional_targets=[],
    )
    load_national_targets(
        direct_targets_df=pd.DataFrame(),
        tax_filer_df=pd.DataFrame(),
        tax_expenditure_df=tax_expenditure_df,
        conditional_targets=[],
    )

    with Session(engine) as session:
        stale_rows = session.query(Target).filter(Target.reform_id == 0).all()
        assert stale_rows
        assert all(not target.active for target in stale_rows)

        reform_rows = (
            session.query(Target)
            .filter(Target.reform_id == TAX_EXPENDITURE_REFORM_ID)
            .all()
        )
        assert len(reform_rows) == 2
        assert all(target.active for target in reform_rows)
        assert {target.variable for target in reform_rows} == {
            "salt_deduction",
            "qualified_business_income_deduction",
        }
        assert all(
            "Modeled as repeal-based income tax expenditure target"
            in (target.notes or "")
            for target in reform_rows
        )

0 commit comments

Comments
 (0)