Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions policyengine_us_data/calibration/unified_matrix_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from scipy import sparse
from sqlalchemy import create_engine, text

from policyengine_us_data.db.create_database_tables import create_or_replace_views
from policyengine_us_data.storage import STORAGE_FOLDER
from policyengine_us_data.utils.census import STATE_NAME_TO_FIPS
from policyengine_us_data.calibration.calibration_utils import (
Expand Down Expand Up @@ -928,6 +929,8 @@ def __init__(
):
self.db_uri = db_uri
self.engine = create_engine(db_uri)
# Existing SQLite checkpoints may carry an older target_overview view.
create_or_replace_views(self.engine)
self.time_period = time_period
self.dataset_path = dataset_path
self._entity_rel_cache = None
Expand Down
29 changes: 17 additions & 12 deletions policyengine_us_data/calibration/validate_staging.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from policyengine_us_data.calibration.sanity_checks import (
run_sanity_checks,
)
from policyengine_us_data.db.create_database_tables import create_or_replace_views

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -270,27 +271,29 @@ def _build_entity_rel(sim) -> pd.DataFrame:
)


def _get_reform_household_values(
def _get_reform_income_tax_delta(
dataset_path: str,
period: int,
variable: str,
reform_hh_cache: dict,
baseline_income_tax: np.ndarray,
reform_delta_cache: dict,
) -> np.ndarray:
if variable in reform_hh_cache:
return reform_hh_cache[variable]
if variable in reform_delta_cache:
return reform_delta_cache[variable]

from policyengine_us import Microsimulation

reform_sim = Microsimulation(
dataset=dataset_path,
reform=_make_neutralize_variable_reform(variable),
)
reform_hh_cache[variable] = reform_sim.calculate(
reform_income_tax = reform_sim.calculate(
"income_tax",
map_to="household",
period=period,
).values
return reform_hh_cache[variable]
reform_delta_cache[variable] = reform_income_tax - baseline_income_tax
return reform_delta_cache[variable]


def validate_area(
Expand Down Expand Up @@ -370,14 +373,14 @@ def validate_area(
map_to="household",
period=period,
).values
if reform_id > 0 and variable not in reform_hh_cache:
reform_income_tax = _get_reform_household_values(
if reform_id > 0:
reform_hh_cache[variable] = _get_reform_income_tax_delta(
dataset_path,
period,
variable,
hh_vars_cache["income_tax"],
reform_hh_cache,
)
reform_hh_cache[variable] = reform_income_tax - hh_vars_cache["income_tax"]

per_hh = _calculate_target_values_standalone(
target_variable=variable,
Expand Down Expand Up @@ -535,6 +538,7 @@ def _validate_single_area(
from sqlalchemy import create_engine as _create_engine

engine = _create_engine(f"sqlite:///{db_path}")
create_or_replace_views(engine)

logger.info("Loading sim from %s", h5_path)
try:
Expand Down Expand Up @@ -670,14 +674,14 @@ def _compute_district_contributions(
map_to="household",
period=period,
).values
if reform_id > 0 and variable not in reform_hh_cache:
reform_income_tax = _get_reform_household_values(
if reform_id > 0:
reform_hh_cache[variable] = _get_reform_income_tax_delta(
district_h5_path,
period,
variable,
hh_vars_cache["income_tax"],
reform_hh_cache,
)
reform_hh_cache[variable] = reform_income_tax - hh_vars_cache["income_tax"]

per_hh = _calculate_target_values_standalone(
target_variable=variable,
Expand Down Expand Up @@ -1013,6 +1017,7 @@ def main(argv=None):
from policyengine_us import Microsimulation

engine = create_engine(f"sqlite:///{args.db_path}")
create_or_replace_views(engine)

all_targets = _query_all_active_targets(engine, args.period)
logger.info("Loaded %d active targets from DB", len(all_targets))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import sys
from types import SimpleNamespace
from unittest.mock import patch

import numpy as np

from policyengine_us_data.calibration.validate_staging import (
_get_reform_income_tax_delta,
)


class _FakeArrayResult:
def __init__(self, values):
self.values = values


class _FakeMicrosimulation:
def __init__(self, dataset=None, reform=None):
self.dataset = dataset
self.reform = reform

def calculate(self, variable, map_to=None, period=None):
assert variable == "income_tax"
assert map_to == "household"
assert period == 2024
return _FakeArrayResult(np.array([150.0, 260.0], dtype=np.float32))


@patch.dict(
sys.modules,
{"policyengine_us": SimpleNamespace(Microsimulation=_FakeMicrosimulation)},
)
def test_get_reform_income_tax_delta_caches_delta():
baseline_income_tax = np.array([100.0, 200.0], dtype=np.float32)
cache = {}

delta = _get_reform_income_tax_delta(
dataset_path="fake.h5",
period=2024,
variable="salt_deduction",
baseline_income_tax=baseline_income_tax,
reform_delta_cache=cache,
)

np.testing.assert_array_equal(delta, np.array([50.0, 60.0], dtype=np.float32))
np.testing.assert_array_equal(cache["salt_deduction"], delta)

# The cached value should remain the delta, not the raw reform income tax.
cached = _get_reform_income_tax_delta(
dataset_path="fake.h5",
period=2024,
variable="salt_deduction",
baseline_income_tax=np.array([0.0, 0.0], dtype=np.float32),
reform_delta_cache=cache,
)
np.testing.assert_array_equal(cached, np.array([50.0, 60.0], dtype=np.float32))
142 changes: 142 additions & 0 deletions policyengine_us_data/tests/test_etl_national_targets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import pandas as pd
from sqlmodel import Session

from policyengine_us_data.db.create_database_tables import (
Stratum,
StratumConstraint,
Target,
create_database,
)
from policyengine_us_data.db.etl_national_targets import (
TAX_EXPENDITURE_REFORM_ID,
load_national_targets,
)


def _make_stratum(session, parent_id=None, notes=None, constraints=None):
stratum = Stratum(parent_stratum_id=parent_id, notes=notes)
stratum.constraints_rel = constraints or []
session.add(stratum)
session.commit()
session.refresh(stratum)
return stratum


def test_load_national_targets_deactivates_stale_baseline_rows(tmp_path, monkeypatch):
calibration_dir = tmp_path / "calibration"
calibration_dir.mkdir()
db_uri = f"sqlite:///{calibration_dir / 'policy_data.db'}"
engine = create_database(db_uri)

with Session(engine) as session:
national = _make_stratum(session, notes="United States")
filer = _make_stratum(
session,
parent_id=national.stratum_id,
notes="United States - Tax Filers",
constraints=[
StratumConstraint(
constraint_variable="tax_unit_is_filer",
operation="==",
value="1",
)
],
)
itemizer = _make_stratum(
session,
parent_id=national.stratum_id,
notes="United States - Itemizing Tax Filers",
constraints=[
StratumConstraint(
constraint_variable="tax_unit_is_filer",
operation="==",
value="1",
),
StratumConstraint(
constraint_variable="tax_unit_itemizes",
operation="==",
value="1",
),
],
)

session.add(
Target(
stratum_id=filer.stratum_id,
variable="qualified_business_income_deduction",
period=2024,
value=63.1e9,
active=True,
reform_id=0,
)
)
session.add(
Target(
stratum_id=itemizer.stratum_id,
variable="salt_deduction",
period=2024,
value=21.247e9,
active=True,
reform_id=0,
)
)
session.commit()

monkeypatch.setattr(
"policyengine_us_data.db.etl_national_targets.STORAGE_FOLDER",
tmp_path,
)

tax_expenditure_df = pd.DataFrame(
[
{
"variable": "salt_deduction",
"value": 21.247e9,
"source": "Joint Committee on Taxation",
"notes": "SALT deduction tax expenditure",
"year": 2024,
},
{
"variable": "qualified_business_income_deduction",
"value": 63.1e9,
"source": "Joint Committee on Taxation",
"notes": "QBI deduction tax expenditure",
"year": 2024,
},
]
)

load_national_targets(
direct_targets_df=pd.DataFrame(),
tax_filer_df=pd.DataFrame(),
tax_expenditure_df=tax_expenditure_df,
conditional_targets=[],
)
load_national_targets(
direct_targets_df=pd.DataFrame(),
tax_filer_df=pd.DataFrame(),
tax_expenditure_df=tax_expenditure_df,
conditional_targets=[],
)

with Session(engine) as session:
stale_rows = session.query(Target).filter(Target.reform_id == 0).all()
assert stale_rows
assert all(not target.active for target in stale_rows)

reform_rows = (
session.query(Target)
.filter(Target.reform_id == TAX_EXPENDITURE_REFORM_ID)
.all()
)
assert len(reform_rows) == 2
assert all(target.active for target in reform_rows)
assert {target.variable for target in reform_rows} == {
"salt_deduction",
"qualified_business_income_deduction",
}
assert all(
"Modeled as repeal-based income tax expenditure target"
in (target.notes or "")
for target in reform_rows
)
20 changes: 20 additions & 0 deletions policyengine_us_data/tests/test_schema_views_and_lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
import tempfile
import unittest

from sqlalchemy import text
from sqlmodel import Session

from policyengine_us_data.db.create_database_tables import (
Stratum,
StratumConstraint,
Target,
create_or_replace_views,
create_database,
)
from policyengine_us_data.utils.db import get_geographic_strata
Expand Down Expand Up @@ -399,6 +401,24 @@ def test_reform_id_passthrough(self):
self.assertEqual(len(matches), 1)
self.assertEqual(matches[0][reform_idx], 1)

def test_create_or_replace_views_updates_existing_target_overview(self):
"""Refreshing views updates stale target_overview definitions."""
with self.engine.connect() as conn:
conn.execute(text("DROP VIEW IF EXISTS target_overview"))
conn.execute(
text(
"CREATE VIEW target_overview AS "
"SELECT target_id, stratum_id, variable, value, period, active "
"FROM targets"
)
)
conn.commit()

create_or_replace_views(self.engine)

cols = self._overview_columns()
self.assertIn("reform_id", cols)

# ----------------------------------------------------------------
# get_geographic_strata()
# ----------------------------------------------------------------
Expand Down
Loading