Skip to content

Commit 906054a

Browse files
committed
Refresh downloaded policy DB views
1 parent f2a0c2e commit 906054a

3 files changed

Lines changed: 31 additions & 2 deletions

File tree

policyengine_us_data/db/create_database_tables.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import logging
21
import hashlib
2+
import logging
3+
from pathlib import Path
34
from typing import List, Optional
45

56
from sqlalchemy import event, text, UniqueConstraint
@@ -522,5 +523,14 @@ def create_or_replace_views(engine) -> None:
522523
conn.commit()
523524

524525

526+
def refresh_views_for_db_path(db_path: str | Path) -> None:
527+
"""Refresh SQL views for an existing SQLite database file."""
528+
engine = create_engine(f"sqlite:///{Path(db_path)}")
529+
try:
530+
create_or_replace_views(engine)
531+
finally:
532+
engine.dispose()
533+
534+
525535
if __name__ == "__main__":
526536
engine = create_database()

policyengine_us_data/storage/download_private_prerequisites.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import os
22

3-
from policyengine_us_data.utils.huggingface import download
43
from pathlib import Path
4+
from policyengine_us_data.db.create_database_tables import (
5+
refresh_views_for_db_path,
6+
)
7+
from policyengine_us_data.utils.huggingface import download
58

69
FOLDER = Path(__file__).parent
710

@@ -41,3 +44,4 @@
4144
local_folder=FOLDER,
4245
version=None,
4346
)
47+
refresh_views_for_db_path(FOLDER / "calibration" / "policy_data.db")

policyengine_us_data/tests/test_calibration/conftest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,25 @@
11
"""Shared fixtures for local area calibration tests."""
22

33
import pytest
4+
from sqlalchemy import create_engine
45

6+
from policyengine_us_data.db.create_database_tables import (
7+
create_or_replace_views,
8+
)
59
from policyengine_us_data.storage import STORAGE_FOLDER
610

711

12+
@pytest.fixture(scope="session", autouse=True)
13+
def refresh_policy_db_views():
14+
db_path = STORAGE_FOLDER / "calibration" / "policy_data.db"
15+
if db_path.exists():
16+
engine = create_engine(f"sqlite:///{db_path}")
17+
try:
18+
create_or_replace_views(engine)
19+
finally:
20+
engine.dispose()
21+
22+
823
@pytest.fixture(scope="module")
924
def db_uri():
1025
db_path = STORAGE_FOLDER / "calibration" / "policy_data.db"

0 commit comments

Comments
 (0)