Skip to content

Commit 580ed77

Browse files
committed
Fix post-reform calibration test compatibility
1 parent c71315b commit 580ed77

6 files changed

Lines changed: 48 additions & 8 deletions

File tree

policyengine_us_data/calibration/unified_matrix_builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from scipy import sparse
2020
from sqlalchemy import create_engine, text
2121

22+
from policyengine_us_data.db.create_database_tables import refresh_sql_views
2223
from policyengine_us_data.storage import STORAGE_FOLDER
2324
from policyengine_us_data.utils.census import STATE_NAME_TO_FIPS
2425
from policyengine_us_data.calibration.calibration_utils import (
@@ -928,6 +929,8 @@ def __init__(
928929
):
929930
self.db_uri = db_uri
930931
self.engine = create_engine(db_uri)
932+
# Existing SQLite checkpoints may carry an older target_overview view.
933+
refresh_sql_views(self.engine)
931934
self.time_period = time_period
932935
self.dataset_path = dataset_path
933936
self._entity_rel_cache = None

policyengine_us_data/calibration/validate_staging.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from policyengine_us_data.calibration.sanity_checks import (
4242
run_sanity_checks,
4343
)
44+
from policyengine_us_data.db.create_database_tables import refresh_sql_views
4445

4546
logger = logging.getLogger(__name__)
4647

@@ -537,6 +538,7 @@ def _validate_single_area(
537538
from sqlalchemy import create_engine as _create_engine
538539

539540
engine = _create_engine(f"sqlite:///{db_path}")
541+
refresh_sql_views(engine)
540542

541543
logger.info("Loading sim from %s", h5_path)
542544
try:
@@ -1015,6 +1017,7 @@ def main(argv=None):
10151017
from policyengine_us import Microsimulation
10161018

10171019
engine = create_engine(f"sqlite:///{args.db_path}")
1020+
refresh_sql_views(engine)
10181021

10191022
all_targets = _query_all_active_targets(engine, args.period)
10201023
logger.info("Loaded %d active targets from DB", len(all_targets))

policyengine_us_data/db/create_database_tables.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,16 @@ def validate_parent_child_constraints(mapper, connection, target: Stratum):
353353
"""
354354

355355

356+
def refresh_sql_views(engine) -> None:
357+
"""Recreate derived SQL views so existing DBs pick up schema changes."""
358+
with engine.connect() as conn:
359+
conn.execute(text("DROP VIEW IF EXISTS target_overview"))
360+
conn.execute(text("DROP VIEW IF EXISTS stratum_domain"))
361+
conn.execute(text(STRATUM_DOMAIN_VIEW))
362+
conn.execute(text(TARGET_OVERVIEW_VIEW))
363+
conn.commit()
364+
365+
356366
def create_validation_triggers(engine) -> None:
357367
"""Create SQL triggers that validate fields against field_valid_values.
358368
@@ -506,11 +516,8 @@ def create_database(
506516
# Create validation triggers
507517
create_validation_triggers(engine)
508518

509-
# Create SQL views
510-
with engine.connect() as conn:
511-
conn.execute(text(STRATUM_DOMAIN_VIEW))
512-
conn.execute(text(TARGET_OVERVIEW_VIEW))
513-
conn.commit()
519+
# Recreate SQL views so existing DB files do not keep stale definitions.
520+
refresh_sql_views(engine)
514521

515522
logger.info(f"Database and tables created successfully at {db_uri}")
516523
return engine

policyengine_us_data/tests/test_calibration/test_unified_calibration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def test_county_var_uses_county_values(self):
459459
person_hh_idx = np.array([0, 1, 2, 3])
460460

461461
builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder)
462-
hh_vars, _ = builder._assemble_clone_values(
462+
hh_vars, _, _ = builder._assemble_clone_values(
463463
state_values,
464464
clone_states,
465465
person_hh_idx,
@@ -499,7 +499,7 @@ def test_non_county_var_uses_state_values(self):
499499
person_hh_idx = np.array([0, 1, 2, 3])
500500

501501
builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder)
502-
hh_vars, _ = builder._assemble_clone_values(
502+
hh_vars, _, _ = builder._assemble_clone_values(
503503
state_values,
504504
clone_states,
505505
person_hh_idx,

policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def test_reform_targets_preserved(self):
211211

212212
def test_inactive_targets_are_excluded(self):
213213
b = self._make_builder(time_period=2024)
214-
df = b._query_targets({"stratum_ids": [1], "variables": ["aca_ptc"]})
214+
df = b._query_targets({"target_ids": [1, 18, 19]})
215215
baseline_rows = df[(df["variable"] == "aca_ptc") & (df["reform_id"] == 0)]
216216
self.assertEqual(len(baseline_rows), 1)
217217
self.assertEqual(int(baseline_rows.iloc[0]["period"]), 2022)
@@ -526,6 +526,7 @@ def test_return_structure_no_takeup(self, mock_msim_cls, mock_gcv):
526526
sim=None,
527527
target_vars={"snap"},
528528
constraint_vars={"income"},
529+
reform_vars=set(),
529530
geography=geo,
530531
rerandomize_takeup=False,
531532
)
@@ -560,6 +561,7 @@ def test_fresh_sim_per_state(self, mock_msim_cls, mock_gcv):
560561
sim=None,
561562
target_vars={"snap"},
562563
constraint_vars=set(),
564+
reform_vars=set(),
563565
geography=geo,
564566
rerandomize_takeup=False,
565567
)
@@ -582,6 +584,7 @@ def test_state_fips_set_correctly(self, mock_msim_cls, mock_gcv):
582584
sim=None,
583585
target_vars={"snap"},
584586
constraint_vars=set(),
587+
reform_vars=set(),
585588
geography=geo,
586589
rerandomize_takeup=False,
587590
)
@@ -617,6 +620,7 @@ def test_takeup_vars_forced_true(self, mock_msim_cls, mock_gcv):
617620
sim=None,
618621
target_vars={"snap"},
619622
constraint_vars=set(),
623+
reform_vars=set(),
620624
geography=geo,
621625
rerandomize_takeup=True,
622626
)
@@ -664,6 +668,7 @@ def test_count_vars_skipped(self, mock_msim_cls, mock_gcv):
664668
sim=None,
665669
target_vars={"snap", "snap_count"},
666670
constraint_vars=set(),
671+
reform_vars=set(),
667672
geography=geo,
668673
rerandomize_takeup=False,
669674
)
@@ -914,6 +919,7 @@ def test_workers_gt1_creates_pool(self, mock_msim_cls, mock_gcv, mock_pool_cls):
914919
sim=None,
915920
target_vars={"snap"},
916921
constraint_vars=set(),
922+
reform_vars=set(),
917923
geography=geo,
918924
rerandomize_takeup=False,
919925
workers=2,
@@ -939,6 +945,7 @@ def test_workers_1_skips_pool(self, mock_msim_cls, mock_gcv):
939945
sim=None,
940946
target_vars={"snap"},
941947
constraint_vars=set(),
948+
reform_vars=set(),
942949
geography=geo,
943950
rerandomize_takeup=False,
944951
workers=1,

policyengine_us_data/tests/test_schema_views_and_lookups.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
import tempfile
1212
import unittest
1313

14+
from sqlalchemy import text
1415
from sqlmodel import Session
1516

1617
from policyengine_us_data.db.create_database_tables import (
1718
Stratum,
1819
StratumConstraint,
1920
Target,
2021
create_database,
22+
refresh_sql_views,
2123
)
2224
from policyengine_us_data.utils.db import get_geographic_strata
2325
from policyengine_us_data.calibration.calibration_utils import (
@@ -399,6 +401,24 @@ def test_reform_id_passthrough(self):
399401
self.assertEqual(len(matches), 1)
400402
self.assertEqual(matches[0][reform_idx], 1)
401403

404+
def test_refresh_sql_views_updates_existing_target_overview(self):
405+
"""Refreshing views updates stale target_overview definitions."""
406+
with self.engine.connect() as conn:
407+
conn.execute(text("DROP VIEW IF EXISTS target_overview"))
408+
conn.execute(
409+
text(
410+
"CREATE VIEW target_overview AS "
411+
"SELECT target_id, stratum_id, variable, value, period, active "
412+
"FROM targets"
413+
)
414+
)
415+
conn.commit()
416+
417+
refresh_sql_views(self.engine)
418+
419+
cols = self._overview_columns()
420+
self.assertIn("reform_id", cols)
421+
402422
# ----------------------------------------------------------------
403423
# get_geographic_strata()
404424
# ----------------------------------------------------------------

0 commit comments

Comments
 (0)