Skip to content

Commit 278abbf

Browse files
committed
Fix post-reform calibration test compatibility
1 parent 8a95916 commit 278abbf

4 files changed

Lines changed: 38 additions & 5 deletions

File tree

policyengine_us_data/calibration/unified_matrix_builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from scipy import sparse
2020
from sqlalchemy import create_engine, text
2121

22+
from policyengine_us_data.db.create_database_tables import refresh_sql_views
2223
from policyengine_us_data.storage import STORAGE_FOLDER
2324
from policyengine_us_data.utils.census import STATE_NAME_TO_FIPS
2425
from policyengine_us_data.calibration.calibration_utils import (
@@ -928,6 +929,8 @@ def __init__(
928929
):
929930
self.db_uri = db_uri
930931
self.engine = create_engine(db_uri)
932+
# Existing SQLite checkpoints may carry an older target_overview view.
933+
refresh_sql_views(self.engine)
931934
self.time_period = time_period
932935
self.dataset_path = dataset_path
933936
self._entity_rel_cache = None

policyengine_us_data/calibration/validate_staging.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from policyengine_us_data.calibration.sanity_checks import (
4242
run_sanity_checks,
4343
)
44+
from policyengine_us_data.db.create_database_tables import refresh_sql_views
4445

4546
logger = logging.getLogger(__name__)
4647

@@ -537,6 +538,7 @@ def _validate_single_area(
537538
from sqlalchemy import create_engine as _create_engine
538539

539540
engine = _create_engine(f"sqlite:///{db_path}")
541+
refresh_sql_views(engine)
540542

541543
logger.info("Loading sim from %s", h5_path)
542544
try:
@@ -1015,6 +1017,7 @@ def main(argv=None):
10151017
from policyengine_us import Microsimulation
10161018

10171019
engine = create_engine(f"sqlite:///{args.db_path}")
1020+
refresh_sql_views(engine)
10181021

10191022
all_targets = _query_all_active_targets(engine, args.period)
10201023
logger.info("Loaded %d active targets from DB", len(all_targets))

policyengine_us_data/db/create_database_tables.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,16 @@ def validate_parent_child_constraints(mapper, connection, target: Stratum):
353353
"""
354354

355355

356+
def refresh_sql_views(engine) -> None:
357+
"""Recreate derived SQL views so existing DBs pick up schema changes."""
358+
with engine.connect() as conn:
359+
conn.execute(text("DROP VIEW IF EXISTS target_overview"))
360+
conn.execute(text("DROP VIEW IF EXISTS stratum_domain"))
361+
conn.execute(text(STRATUM_DOMAIN_VIEW))
362+
conn.execute(text(TARGET_OVERVIEW_VIEW))
363+
conn.commit()
364+
365+
356366
def create_validation_triggers(engine) -> None:
357367
"""Create SQL triggers that validate fields against field_valid_values.
358368
@@ -506,11 +516,8 @@ def create_database(
506516
# Create validation triggers
507517
create_validation_triggers(engine)
508518

509-
# Create SQL views
510-
with engine.connect() as conn:
511-
conn.execute(text(STRATUM_DOMAIN_VIEW))
512-
conn.execute(text(TARGET_OVERVIEW_VIEW))
513-
conn.commit()
519+
# Recreate SQL views so existing DB files do not keep stale definitions.
520+
refresh_sql_views(engine)
514521

515522
logger.info(f"Database and tables created successfully at {db_uri}")
516523
return engine

policyengine_us_data/tests/test_schema_views_and_lookups.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
import tempfile
1212
import unittest
1313

14+
from sqlalchemy import text
1415
from sqlmodel import Session
1516

1617
from policyengine_us_data.db.create_database_tables import (
1718
Stratum,
1819
StratumConstraint,
1920
Target,
2021
create_database,
22+
refresh_sql_views,
2123
)
2224
from policyengine_us_data.utils.db import get_geographic_strata
2325
from policyengine_us_data.calibration.calibration_utils import (
@@ -399,6 +401,24 @@ def test_reform_id_passthrough(self):
399401
self.assertEqual(len(matches), 1)
400402
self.assertEqual(matches[0][reform_idx], 1)
401403

404+
def test_refresh_sql_views_updates_existing_target_overview(self):
405+
"""Refreshing views updates stale target_overview definitions."""
406+
with self.engine.connect() as conn:
407+
conn.execute(text("DROP VIEW IF EXISTS target_overview"))
408+
conn.execute(
409+
text(
410+
"CREATE VIEW target_overview AS "
411+
"SELECT target_id, stratum_id, variable, value, period, active "
412+
"FROM targets"
413+
)
414+
)
415+
conn.commit()
416+
417+
refresh_sql_views(self.engine)
418+
419+
cols = self._overview_columns()
420+
self.assertIn("reform_id", cols)
421+
402422
# ----------------------------------------------------------------
403423
# get_geographic_strata()
404424
# ----------------------------------------------------------------

0 commit comments

Comments
 (0)