Skip to content

Commit 44102e3

Browse files
authored
Merge pull request #651 from PolicyEngine/codex/pr639-cleanup
Clean up reform validation cache and add regression tests
2 parents 9ee1a2c + a71c069 commit 44102e3

5 files changed

Lines changed: 238 additions & 12 deletions

File tree

policyengine_us_data/calibration/unified_matrix_builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from scipy import sparse
2020
from sqlalchemy import create_engine, text
2121

22+
from policyengine_us_data.db.create_database_tables import create_or_replace_views
2223
from policyengine_us_data.storage import STORAGE_FOLDER
2324
from policyengine_us_data.utils.census import STATE_NAME_TO_FIPS
2425
from policyengine_us_data.calibration.calibration_utils import (
@@ -928,6 +929,8 @@ def __init__(
928929
):
929930
self.db_uri = db_uri
930931
self.engine = create_engine(db_uri)
932+
# Existing SQLite checkpoints may carry an older target_overview view.
933+
create_or_replace_views(self.engine)
931934
self.time_period = time_period
932935
self.dataset_path = dataset_path
933936
self._entity_rel_cache = None

policyengine_us_data/calibration/validate_staging.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from policyengine_us_data.calibration.sanity_checks import (
4242
run_sanity_checks,
4343
)
44+
from policyengine_us_data.db.create_database_tables import create_or_replace_views
4445

4546
logger = logging.getLogger(__name__)
4647

@@ -270,27 +271,29 @@ def _build_entity_rel(sim) -> pd.DataFrame:
270271
)
271272

272273

273-
def _get_reform_household_values(
274+
def _get_reform_income_tax_delta(
274275
dataset_path: str,
275276
period: int,
276277
variable: str,
277-
reform_hh_cache: dict,
278+
baseline_income_tax: np.ndarray,
279+
reform_delta_cache: dict,
278280
) -> np.ndarray:
279-
if variable in reform_hh_cache:
280-
return reform_hh_cache[variable]
281+
if variable in reform_delta_cache:
282+
return reform_delta_cache[variable]
281283

282284
from policyengine_us import Microsimulation
283285

284286
reform_sim = Microsimulation(
285287
dataset=dataset_path,
286288
reform=_make_neutralize_variable_reform(variable),
287289
)
288-
reform_hh_cache[variable] = reform_sim.calculate(
290+
reform_income_tax = reform_sim.calculate(
289291
"income_tax",
290292
map_to="household",
291293
period=period,
292294
).values
293-
return reform_hh_cache[variable]
295+
reform_delta_cache[variable] = reform_income_tax - baseline_income_tax
296+
return reform_delta_cache[variable]
294297

295298

296299
def validate_area(
@@ -370,14 +373,14 @@ def validate_area(
370373
map_to="household",
371374
period=period,
372375
).values
373-
if reform_id > 0 and variable not in reform_hh_cache:
374-
reform_income_tax = _get_reform_household_values(
376+
if reform_id > 0:
377+
reform_hh_cache[variable] = _get_reform_income_tax_delta(
375378
dataset_path,
376379
period,
377380
variable,
381+
hh_vars_cache["income_tax"],
378382
reform_hh_cache,
379383
)
380-
reform_hh_cache[variable] = reform_income_tax - hh_vars_cache["income_tax"]
381384

382385
per_hh = _calculate_target_values_standalone(
383386
target_variable=variable,
@@ -535,6 +538,7 @@ def _validate_single_area(
535538
from sqlalchemy import create_engine as _create_engine
536539

537540
engine = _create_engine(f"sqlite:///{db_path}")
541+
create_or_replace_views(engine)
538542

539543
logger.info("Loading sim from %s", h5_path)
540544
try:
@@ -670,14 +674,14 @@ def _compute_district_contributions(
670674
map_to="household",
671675
period=period,
672676
).values
673-
if reform_id > 0 and variable not in reform_hh_cache:
674-
reform_income_tax = _get_reform_household_values(
677+
if reform_id > 0:
678+
reform_hh_cache[variable] = _get_reform_income_tax_delta(
675679
district_h5_path,
676680
period,
677681
variable,
682+
hh_vars_cache["income_tax"],
678683
reform_hh_cache,
679684
)
680-
reform_hh_cache[variable] = reform_income_tax - hh_vars_cache["income_tax"]
681685

682686
per_hh = _calculate_target_values_standalone(
683687
target_variable=variable,
@@ -1013,6 +1017,7 @@ def main(argv=None):
10131017
from policyengine_us import Microsimulation
10141018

10151019
engine = create_engine(f"sqlite:///{args.db_path}")
1020+
create_or_replace_views(engine)
10161021

10171022
all_targets = _query_all_active_targets(engine, args.period)
10181023
logger.info("Loaded %d active targets from DB", len(all_targets))
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import sys
2+
from types import SimpleNamespace
3+
from unittest.mock import patch
4+
5+
import numpy as np
6+
7+
from policyengine_us_data.calibration.validate_staging import (
8+
_get_reform_income_tax_delta,
9+
)
10+
11+
12+
class _FakeArrayResult:
13+
def __init__(self, values):
14+
self.values = values
15+
16+
17+
class _FakeMicrosimulation:
18+
def __init__(self, dataset=None, reform=None):
19+
self.dataset = dataset
20+
self.reform = reform
21+
22+
def calculate(self, variable, map_to=None, period=None):
23+
assert variable == "income_tax"
24+
assert map_to == "household"
25+
assert period == 2024
26+
return _FakeArrayResult(np.array([150.0, 260.0], dtype=np.float32))
27+
28+
29+
@patch.dict(
30+
sys.modules,
31+
{"policyengine_us": SimpleNamespace(Microsimulation=_FakeMicrosimulation)},
32+
)
33+
def test_get_reform_income_tax_delta_caches_delta():
34+
baseline_income_tax = np.array([100.0, 200.0], dtype=np.float32)
35+
cache = {}
36+
37+
delta = _get_reform_income_tax_delta(
38+
dataset_path="fake.h5",
39+
period=2024,
40+
variable="salt_deduction",
41+
baseline_income_tax=baseline_income_tax,
42+
reform_delta_cache=cache,
43+
)
44+
45+
np.testing.assert_array_equal(delta, np.array([50.0, 60.0], dtype=np.float32))
46+
np.testing.assert_array_equal(cache["salt_deduction"], delta)
47+
48+
# The cached value should remain the delta, not the raw reform income tax.
49+
cached = _get_reform_income_tax_delta(
50+
dataset_path="fake.h5",
51+
period=2024,
52+
variable="salt_deduction",
53+
baseline_income_tax=np.array([0.0, 0.0], dtype=np.float32),
54+
reform_delta_cache=cache,
55+
)
56+
np.testing.assert_array_equal(cached, np.array([50.0, 60.0], dtype=np.float32))
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import pandas as pd
2+
from sqlmodel import Session
3+
4+
from policyengine_us_data.db.create_database_tables import (
5+
Stratum,
6+
StratumConstraint,
7+
Target,
8+
create_database,
9+
)
10+
from policyengine_us_data.db.etl_national_targets import (
11+
TAX_EXPENDITURE_REFORM_ID,
12+
load_national_targets,
13+
)
14+
15+
16+
def _make_stratum(session, parent_id=None, notes=None, constraints=None):
17+
stratum = Stratum(parent_stratum_id=parent_id, notes=notes)
18+
stratum.constraints_rel = constraints or []
19+
session.add(stratum)
20+
session.commit()
21+
session.refresh(stratum)
22+
return stratum
23+
24+
25+
def test_load_national_targets_deactivates_stale_baseline_rows(tmp_path, monkeypatch):
26+
calibration_dir = tmp_path / "calibration"
27+
calibration_dir.mkdir()
28+
db_uri = f"sqlite:///{calibration_dir / 'policy_data.db'}"
29+
engine = create_database(db_uri)
30+
31+
with Session(engine) as session:
32+
national = _make_stratum(session, notes="United States")
33+
filer = _make_stratum(
34+
session,
35+
parent_id=national.stratum_id,
36+
notes="United States - Tax Filers",
37+
constraints=[
38+
StratumConstraint(
39+
constraint_variable="tax_unit_is_filer",
40+
operation="==",
41+
value="1",
42+
)
43+
],
44+
)
45+
itemizer = _make_stratum(
46+
session,
47+
parent_id=national.stratum_id,
48+
notes="United States - Itemizing Tax Filers",
49+
constraints=[
50+
StratumConstraint(
51+
constraint_variable="tax_unit_is_filer",
52+
operation="==",
53+
value="1",
54+
),
55+
StratumConstraint(
56+
constraint_variable="tax_unit_itemizes",
57+
operation="==",
58+
value="1",
59+
),
60+
],
61+
)
62+
63+
session.add(
64+
Target(
65+
stratum_id=filer.stratum_id,
66+
variable="qualified_business_income_deduction",
67+
period=2024,
68+
value=63.1e9,
69+
active=True,
70+
reform_id=0,
71+
)
72+
)
73+
session.add(
74+
Target(
75+
stratum_id=itemizer.stratum_id,
76+
variable="salt_deduction",
77+
period=2024,
78+
value=21.247e9,
79+
active=True,
80+
reform_id=0,
81+
)
82+
)
83+
session.commit()
84+
85+
monkeypatch.setattr(
86+
"policyengine_us_data.db.etl_national_targets.STORAGE_FOLDER",
87+
tmp_path,
88+
)
89+
90+
tax_expenditure_df = pd.DataFrame(
91+
[
92+
{
93+
"variable": "salt_deduction",
94+
"value": 21.247e9,
95+
"source": "Joint Committee on Taxation",
96+
"notes": "SALT deduction tax expenditure",
97+
"year": 2024,
98+
},
99+
{
100+
"variable": "qualified_business_income_deduction",
101+
"value": 63.1e9,
102+
"source": "Joint Committee on Taxation",
103+
"notes": "QBI deduction tax expenditure",
104+
"year": 2024,
105+
},
106+
]
107+
)
108+
109+
load_national_targets(
110+
direct_targets_df=pd.DataFrame(),
111+
tax_filer_df=pd.DataFrame(),
112+
tax_expenditure_df=tax_expenditure_df,
113+
conditional_targets=[],
114+
)
115+
load_national_targets(
116+
direct_targets_df=pd.DataFrame(),
117+
tax_filer_df=pd.DataFrame(),
118+
tax_expenditure_df=tax_expenditure_df,
119+
conditional_targets=[],
120+
)
121+
122+
with Session(engine) as session:
123+
stale_rows = session.query(Target).filter(Target.reform_id == 0).all()
124+
assert stale_rows
125+
assert all(not target.active for target in stale_rows)
126+
127+
reform_rows = (
128+
session.query(Target)
129+
.filter(Target.reform_id == TAX_EXPENDITURE_REFORM_ID)
130+
.all()
131+
)
132+
assert len(reform_rows) == 2
133+
assert all(target.active for target in reform_rows)
134+
assert {target.variable for target in reform_rows} == {
135+
"salt_deduction",
136+
"qualified_business_income_deduction",
137+
}
138+
assert all(
139+
"Modeled as repeal-based income tax expenditure target"
140+
in (target.notes or "")
141+
for target in reform_rows
142+
)

policyengine_us_data/tests/test_schema_views_and_lookups.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@
1111
import tempfile
1212
import unittest
1313

14+
from sqlalchemy import text
1415
from sqlmodel import Session
1516

1617
from policyengine_us_data.db.create_database_tables import (
1718
Stratum,
1819
StratumConstraint,
1920
Target,
21+
create_or_replace_views,
2022
create_database,
2123
)
2224
from policyengine_us_data.utils.db import get_geographic_strata
@@ -399,6 +401,24 @@ def test_reform_id_passthrough(self):
399401
self.assertEqual(len(matches), 1)
400402
self.assertEqual(matches[0][reform_idx], 1)
401403

404+
def test_create_or_replace_views_updates_existing_target_overview(self):
405+
"""Refreshing views updates stale target_overview definitions."""
406+
with self.engine.connect() as conn:
407+
conn.execute(text("DROP VIEW IF EXISTS target_overview"))
408+
conn.execute(
409+
text(
410+
"CREATE VIEW target_overview AS "
411+
"SELECT target_id, stratum_id, variable, value, period, active "
412+
"FROM targets"
413+
)
414+
)
415+
conn.commit()
416+
417+
create_or_replace_views(self.engine)
418+
419+
cols = self._overview_columns()
420+
self.assertIn("reform_id", cols)
421+
402422
# ----------------------------------------------------------------
403423
# get_geographic_strata()
404424
# ----------------------------------------------------------------

0 commit comments

Comments
 (0)