Skip to content

Commit dad174b

Browse files
authored
Honor stored tax-benefit model versions at runtime (#220)
* Honor stored model versions at runtime * Validate shared runtime model bundles * Harden shared runtime bundle resolution
1 parent 397bc50 commit dad174b

5 files changed

Lines changed: 296 additions & 23 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Reject shared runtime bundle reuse when the compared database rows point at different model identities, even if their runtime version strings match.

src/policyengine_api/api/analysis.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@
6464
TaxBenefitModel,
6565
TaxBenefitModelVersion,
6666
)
67+
from policyengine_api.runtime_versions import (
68+
resolve_shared_runtime_model_version_from_db,
69+
)
6770
from policyengine_api.services.database import get_session
6871
from policyengine_api.services.model_resolver import (
6972
resolve_country_from_simulation,
@@ -752,7 +755,6 @@ def _run_local_economy_comparison_uk(
752755
from policyengine.core.dynamic import Dynamic as PEDynamic
753756
from policyengine.core.policy import ParameterValue as PEParameterValue
754757
from policyengine.core.policy import Policy as PEPolicy
755-
from policyengine.tax_benefit_models.uk import uk_latest
756758
from policyengine.tax_benefit_models.uk.datasets import PolicyEngineUKDataset
757759

758760
from policyengine_api.models import Policy as DBPolicy
@@ -778,7 +780,11 @@ def _run_local_economy_comparison_uk(
778780
if not dataset:
779781
raise ValueError(f"Dataset {baseline_sim.dataset_id} not found")
780782

781-
pe_model_version = uk_latest
783+
pe_model_version = resolve_shared_runtime_model_version_from_db(
784+
session,
785+
baseline_sim.tax_benefit_model_version_id,
786+
reform_sim.tax_benefit_model_version_id,
787+
)
782788
param_lookup = {p.name: p for p in pe_model_version.parameters}
783789

784790
def build_policy(policy_id):
@@ -937,7 +943,6 @@ def _run_local_economy_comparison_us(
937943
from policyengine.core.dynamic import Dynamic as PEDynamic
938944
from policyengine.core.policy import ParameterValue as PEParameterValue
939945
from policyengine.core.policy import Policy as PEPolicy
940-
from policyengine.tax_benefit_models.us import us_latest
941946
from policyengine.tax_benefit_models.us.datasets import PolicyEngineUSDataset
942947

943948
from policyengine_api.models import Policy as DBPolicy
@@ -963,7 +968,11 @@ def _run_local_economy_comparison_us(
963968
if not dataset:
964969
raise ValueError(f"Dataset {baseline_sim.dataset_id} not found")
965970

966-
pe_model_version = us_latest
971+
pe_model_version = resolve_shared_runtime_model_version_from_db(
972+
session,
973+
baseline_sim.tax_benefit_model_version_id,
974+
reform_sim.tax_benefit_model_version_id,
975+
)
967976
param_lookup = {p.name: p for p in pe_model_version.parameters}
968977

969978
def build_policy(policy_id):

src/policyengine_api/modal/functions/__init__.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
get_database_url,
1616
get_db_session,
1717
)
18+
from policyengine_api.runtime_versions import (
19+
resolve_runtime_model_version_from_db,
20+
resolve_shared_runtime_model_version_from_db,
21+
)
1822

1923
# Required environment variables from each secret
2024
REQUIRED_DB_VARS = ["DATABASE_URL", "SUPABASE_URL", "SUPABASE_KEY"]
@@ -674,12 +678,13 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N
674678

675679
# Import policyengine
676680
from policyengine.core import Simulation as PESimulation
677-
from policyengine.tax_benefit_models.uk import uk_latest
678681
from policyengine.tax_benefit_models.uk.datasets import (
679682
PolicyEngineUKDataset,
680683
)
681684

682-
pe_model_version = uk_latest
685+
pe_model_version = resolve_runtime_model_version_from_db(
686+
session, simulation.tax_benefit_model_version_id
687+
)
683688

684689
# Get policy and dynamic
685690
policy = _get_pe_policy_uk(
@@ -847,12 +852,13 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N
847852

848853
# Import policyengine
849854
from policyengine.core import Simulation as PESimulation
850-
from policyengine.tax_benefit_models.us import us_latest
851855
from policyengine.tax_benefit_models.us.datasets import (
852856
PolicyEngineUSDataset,
853857
)
854858

855-
pe_model_version = us_latest
859+
pe_model_version = resolve_runtime_model_version_from_db(
860+
session, simulation.tax_benefit_model_version_id
861+
)
856862

857863
# Get policy and dynamic
858864
policy = _get_pe_policy_us(
@@ -1008,7 +1014,6 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None:
10081014
ReportStatus,
10091015
Simulation,
10101016
SimulationStatus,
1011-
TaxBenefitModelVersion,
10121017
)
10131018

10141019
with Session(engine) as session:
@@ -1035,12 +1040,6 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None:
10351040
if not dataset:
10361041
raise ValueError(f"Dataset {baseline_sim.dataset_id} not found")
10371042

1038-
# Get model version (unused but keeping for reference)
1039-
_ = session.get(
1040-
TaxBenefitModelVersion,
1041-
baseline_sim.tax_benefit_model_version_id,
1042-
)
1043-
10441043
# Import policyengine
10451044
from policyengine.core import Simulation as PESimulation
10461045
from policyengine.outputs import DecileImpact as PEDecileImpact
@@ -1050,15 +1049,18 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None:
10501049
from policyengine.outputs.aggregate import (
10511050
AggregateType as PEAggregateType,
10521051
)
1053-
from policyengine.tax_benefit_models.uk import uk_latest
10541052
from policyengine.tax_benefit_models.uk.datasets import (
10551053
PolicyEngineUKDataset,
10561054
)
10571055
from policyengine.tax_benefit_models.uk.outputs import (
10581056
ProgrammeStatistics as PEProgrammeStats,
10591057
)
10601058

1061-
pe_model_version = uk_latest
1059+
pe_model_version = resolve_shared_runtime_model_version_from_db(
1060+
session,
1061+
baseline_sim.tax_benefit_model_version_id,
1062+
reform_sim.tax_benefit_model_version_id,
1063+
)
10621064

10631065
# Get policies
10641066
baseline_policy = _get_pe_policy_uk(
@@ -1725,15 +1727,18 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None:
17251727
from policyengine.outputs.aggregate import (
17261728
AggregateType as PEAggregateType,
17271729
)
1728-
from policyengine.tax_benefit_models.us import us_latest
17291730
from policyengine.tax_benefit_models.us.datasets import (
17301731
PolicyEngineUSDataset,
17311732
)
17321733
from policyengine.tax_benefit_models.us.outputs import (
17331734
ProgramStatistics as PEProgramStats,
17341735
)
17351736

1736-
pe_model_version = us_latest
1737+
pe_model_version = resolve_shared_runtime_model_version_from_db(
1738+
session,
1739+
baseline_sim.tax_benefit_model_version_id,
1740+
reform_sim.tax_benefit_model_version_id,
1741+
)
17371742

17381743
# Get policies
17391744
baseline_policy = _get_pe_policy_us(
@@ -2631,7 +2636,6 @@ def compute_aggregate_uk(aggregate_id: str, traceparent: str | None = None) -> N
26312636
from policyengine.core import Simulation as PESimulation
26322637
from policyengine.outputs import Aggregate as PEAggregate
26332638
from policyengine.outputs import AggregateType as PEAggregateType
2634-
from policyengine.tax_benefit_models.uk import uk_latest
26352639
from policyengine.tax_benefit_models.uk.datasets import (
26362640
PolicyEngineUKDataset,
26372641
)
@@ -2692,6 +2696,10 @@ def compute_aggregate_uk(aggregate_id: str, traceparent: str | None = None) -> N
26922696
)
26932697

26942698
# Create policyengine simulation with loaded output
2699+
pe_model_version = resolve_runtime_model_version_from_db(
2700+
session, simulation.tax_benefit_model_version_id
2701+
)
2702+
26952703
with logfire.span("load_output"):
26962704
pe_output_dataset = PolicyEngineUKDataset(
26972705
name=output_dataset.name or "output",
@@ -2703,7 +2711,7 @@ def compute_aggregate_uk(aggregate_id: str, traceparent: str | None = None) -> N
27032711

27042712
pe_sim = PESimulation(
27052713
dataset=pe_output_dataset, # Use output as dataset
2706-
tax_benefit_model_version=uk_latest,
2714+
tax_benefit_model_version=pe_model_version,
27072715
)
27082716
pe_sim.output_dataset = pe_output_dataset
27092717

@@ -2787,7 +2795,6 @@ def compute_aggregate_us(aggregate_id: str, traceparent: str | None = None) -> N
27872795
from policyengine.core import Simulation as PESimulation
27882796
from policyengine.outputs import Aggregate as PEAggregate
27892797
from policyengine.outputs import AggregateType as PEAggregateType
2790-
from policyengine.tax_benefit_models.us import us_latest
27912798
from policyengine.tax_benefit_models.us.datasets import (
27922799
PolicyEngineUSDataset,
27932800
)
@@ -2841,6 +2848,10 @@ def compute_aggregate_us(aggregate_id: str, traceparent: str | None = None) -> N
28412848
storage_bucket,
28422849
)
28432850

2851+
pe_model_version = resolve_runtime_model_version_from_db(
2852+
session, simulation.tax_benefit_model_version_id
2853+
)
2854+
28442855
with logfire.span("load_output"):
28452856
pe_output_dataset = PolicyEngineUSDataset(
28462857
name=output_dataset.name or "output",
@@ -2852,7 +2863,7 @@ def compute_aggregate_us(aggregate_id: str, traceparent: str | None = None) -> N
28522863

28532864
pe_sim = PESimulation(
28542865
dataset=pe_output_dataset,
2855-
tax_benefit_model_version=us_latest,
2866+
tax_benefit_model_version=pe_model_version,
28562867
)
28572868
pe_sim.output_dataset = pe_output_dataset
28582869

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""Helpers for resolving the deployed PolicyEngine runtime bundle."""
2+
3+
from importlib import import_module
4+
from uuid import UUID
5+
6+
from sqlmodel import Session
7+
8+
9+
def _normalize_model_name(model_name: str) -> str:
10+
return model_name.replace("_", "-").lower()
11+
12+
13+
def _load_runtime_model_version(model_name: str):
14+
normalized_name = _normalize_model_name(model_name)
15+
16+
if normalized_name == "policyengine-uk":
17+
return import_module("policyengine.tax_benefit_models.uk").uk_latest
18+
if normalized_name == "policyengine-us":
19+
return import_module("policyengine.tax_benefit_models.us").us_latest
20+
21+
raise ValueError(f"Unsupported tax-benefit model '{model_name}'")
22+
23+
24+
def resolve_runtime_model_version_from_db(
25+
session: Session,
26+
tax_benefit_model_version_id: UUID,
27+
):
28+
"""Resolve the deployed policyengine model version for a stored DB row.
29+
30+
The current deployment only has one runtime bundle per country. If the
31+
stored DB version does not match the deployed runtime bundle, fail clearly
32+
instead of silently executing against `*_latest`.
33+
"""
34+
from policyengine_api.models import TaxBenefitModel, TaxBenefitModelVersion
35+
36+
db_version = session.get(TaxBenefitModelVersion, tax_benefit_model_version_id)
37+
if db_version is None:
38+
raise ValueError(
39+
f"Tax-benefit model version {tax_benefit_model_version_id} not found"
40+
)
41+
42+
db_model = session.get(TaxBenefitModel, db_version.model_id)
43+
if db_model is None:
44+
raise ValueError(f"Tax-benefit model {db_version.model_id} not found")
45+
46+
runtime_model_version = _load_runtime_model_version(db_model.name)
47+
runtime_version = getattr(runtime_model_version, "version", None)
48+
49+
if runtime_version != db_version.version:
50+
raise ValueError(
51+
"Stored tax-benefit model version "
52+
f"{db_model.name}@{db_version.version} does not match the deployed "
53+
f"runtime bundle {db_model.name}@{runtime_version}. "
54+
"Re-seed this environment with the deployed bundle or re-run the "
55+
"analysis against the currently deployed version."
56+
)
57+
58+
return runtime_model_version
59+
60+
61+
def _resolve_runtime_model_entry_from_db(
62+
session: Session,
63+
tax_benefit_model_version_id: UUID,
64+
):
65+
from policyengine_api.models import TaxBenefitModel, TaxBenefitModelVersion
66+
67+
db_version = session.get(TaxBenefitModelVersion, tax_benefit_model_version_id)
68+
if db_version is None:
69+
raise ValueError(
70+
f"Tax-benefit model version {tax_benefit_model_version_id} not found"
71+
)
72+
73+
db_model = session.get(TaxBenefitModel, db_version.model_id)
74+
if db_model is None:
75+
raise ValueError(f"Tax-benefit model {db_version.model_id} not found")
76+
77+
runtime_model_version = resolve_runtime_model_version_from_db(
78+
session, tax_benefit_model_version_id
79+
)
80+
return _normalize_model_name(db_model.name), runtime_model_version
81+
82+
83+
def resolve_shared_runtime_model_version_from_db(
84+
session: Session,
85+
*tax_benefit_model_version_ids: UUID,
86+
):
87+
"""Resolve one deployed runtime model version shared across DB version rows."""
88+
if not tax_benefit_model_version_ids:
89+
raise ValueError("At least one tax-benefit model version ID is required")
90+
91+
resolved_entries = [
92+
_resolve_runtime_model_entry_from_db(session, version_id)
93+
for version_id in tax_benefit_model_version_ids
94+
]
95+
first_model_name, first_version = resolved_entries[0]
96+
97+
for model_name, runtime_model_version in resolved_entries[1:]:
98+
if model_name != first_model_name:
99+
raise ValueError(
100+
"All simulations in a comparison must use the same tax-benefit "
101+
f"model. Got {first_model_name} and {model_name}."
102+
)
103+
if runtime_model_version.version != first_version.version:
104+
raise ValueError(
105+
"All simulations in a comparison must use the same deployed "
106+
f"runtime bundle. Got {first_version.version} and "
107+
f"{runtime_model_version.version}."
108+
)
109+
110+
return first_version

0 commit comments

Comments
 (0)