From 01d5cd35fd4322a4276889722453fcf9d7011c6e Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 21 May 2026 21:13:19 +0200 Subject: [PATCH 01/23] chore: upgrade simulation runtime to PolicyEngine v4.10 --- .../fixtures/test_simulation_api_contracts.py | 86 +++++++++++++++ .../pyproject.toml | 4 +- .../src/modal/app.py | 3 +- .../tests/test_simulation_api_contracts.py | 102 ++++++++++++++++++ 4 files changed, 192 insertions(+), 3 deletions(-) create mode 100644 projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py create mode 100644 projects/policyengine-api-simulation/tests/test_simulation_api_contracts.py diff --git a/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py b/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py new file mode 100644 index 000000000..feaa1aee6 --- /dev/null +++ b/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py @@ -0,0 +1,86 @@ +"""Fixtures for simulation API contract tests.""" + +CURRENT_SINGLE_YEAR_MACRO_KEYS = { + "model_version", + "data_version", + "budget", + "detailed_budget", + "decile", + "inequality", + "poverty", + "poverty_by_gender", + "poverty_by_race", + "intra_decile", + "wealth_decile", + "intra_wealth_decile", + "labor_supply_response", + "constituency_impact", + "local_authority_impact", + "congressional_district_impact", + "cliff_impact", +} + +CURRENT_REQUIRED_BUDGET_KEYS = { + "budgetary_impact", + "tax_revenue_impact", + "state_tax_revenue_impact", + "benefit_spending_impact", + "households", + "baseline_net_income", +} + +CURRENT_SINGLE_YEAR_MACRO_RESULT = { + "model_version": "1.691.3", + "data_version": "1.110.12", + "budget": { + "budgetary_impact": 300.0, + "tax_revenue_impact": 500.0, + "state_tax_revenue_impact": 125.0, + "benefit_spending_impact": 200.0, + "households": 2.0, + "baseline_net_income": 1000.0, + }, + "detailed_budget": { + "income_tax": { + "baseline": 1000.0, + "reform": 1100.0, + 
"difference": 100.0, + } + }, + "decile": { + "relative": {"1": 0.01}, + "average": {"1": 10.0}, + }, + "inequality": { + "baseline": {"gini": 0.3}, + "reform": {"gini": 0.29}, + }, + "poverty": { + "baseline": {"all": 0.1}, + "reform": {"all": 0.09}, + }, + "poverty_by_gender": { + "baseline": {"male": 0.1, "female": 0.11}, + "reform": {"male": 0.09, "female": 0.1}, + }, + "poverty_by_race": None, + "intra_decile": { + "relative": {"1": {"1": 0.01}}, + "average": {"1": {"1": 10.0}}, + }, + "wealth_decile": None, + "intra_wealth_decile": None, + "labor_supply_response": { + "substitution_lsr": 0.0, + "income_lsr": 0.0, + "relative_lsr": {}, + "total_change": 0.0, + "revenue_change": 0.0, + "decile": {}, + "hours": {"baseline": 0.0, "reform": 0.0, "change": 0.0}, + }, + "constituency_impact": None, + "local_authority_impact": None, + "congressional_district_impact": None, + "cliff_impact": None, +} diff --git a/projects/policyengine-api-simulation/pyproject.toml b/projects/policyengine-api-simulation/pyproject.toml index 84a91564a..bd3357056 100644 --- a/projects/policyengine-api-simulation/pyproject.toml +++ b/projects/policyengine-api-simulation/pyproject.toml @@ -16,8 +16,8 @@ dependencies = [ "pydantic-settings (>=2.7.1,<3.0.0)", "opentelemetry-instrumentation-fastapi (>=0.51b0,<0.52)", "policyengine-fastapi", - "policyengine==0.13.0", - "policyengine-core>=3.23.5", + "policyengine==4.10.0", + "policyengine-core==3.26.1", "policyengine-uk==2.88.20", "policyengine-us==1.702.0", "tables>=3.10.2", diff --git a/projects/policyengine-api-simulation/src/modal/app.py b/projects/policyengine-api-simulation/src/modal/app.py index dd4a0b5fc..30a9a71b4 100644 --- a/projects/policyengine-api-simulation/src/modal/app.py +++ b/projects/policyengine-api-simulation/src/modal/app.py @@ -49,7 +49,8 @@ def get_app_name(us_version: str, uk_version: str) -> str: .pip_install( f"policyengine-us=={US_VERSION}", f"policyengine-uk=={UK_VERSION}", - "policyengine==0.13.0", + 
"policyengine==4.10.0", + "policyengine-core==3.26.1", "tables>=3.10.2", "logfire", ) diff --git a/projects/policyengine-api-simulation/tests/test_simulation_api_contracts.py b/projects/policyengine-api-simulation/tests/test_simulation_api_contracts.py new file mode 100644 index 000000000..dd8099171 --- /dev/null +++ b/projects/policyengine-api-simulation/tests/test_simulation_api_contracts.py @@ -0,0 +1,102 @@ +"""Contract tests for simulation API response shapes.""" + +from src.modal.gateway.generate_openapi import create_openapi_app +from src.modal.gateway.models import ( + BudgetWindowAnnualImpact, + BudgetWindowResult, + BudgetWindowTotals, + JobStatusResponse, +) + +from fixtures.test_simulation_api_contracts import ( + CURRENT_REQUIRED_BUDGET_KEYS, + CURRENT_SINGLE_YEAR_MACRO_KEYS, + CURRENT_SINGLE_YEAR_MACRO_RESULT, +) + + +def test_job_status_result_preserves_current_single_year_macro_dict_contract(): + response = JobStatusResponse( + status="complete", + result=CURRENT_SINGLE_YEAR_MACRO_RESULT, + ) + + assert response.result is not None + assert set(response.result) == CURRENT_SINGLE_YEAR_MACRO_KEYS + assert set(response.result["budget"]) == CURRENT_REQUIRED_BUDGET_KEYS + assert ( + response.model_dump(mode="json")["result"] == CURRENT_SINGLE_YEAR_MACRO_RESULT + ) + + +def test_openapi_keeps_job_status_result_as_unstructured_dict(): + spec = create_openapi_app().openapi() + schemas = spec["components"]["schemas"] + + assert "SingleYearMacroOutput" not in schemas + result_schema = schemas["JobStatusResponse"]["properties"]["result"] + assert result_schema == { + "anyOf": [ + { + "additionalProperties": True, + "type": "object", + }, + { + "type": "null", + }, + ], + "title": "Result", + } + + +def test_budget_window_result_keeps_compact_public_contract(): + result = BudgetWindowResult( + startYear="2026", + endYear="2027", + windowSize=2, + annualImpacts=[ + BudgetWindowAnnualImpact( + year="2026", + taxRevenueImpact=10.0, + federalTaxRevenueImpact=7.0, + 
stateTaxRevenueImpact=3.0, + benefitSpendingImpact=2.0, + budgetaryImpact=8.0, + ) + ], + totals=BudgetWindowTotals( + taxRevenueImpact=10.0, + federalTaxRevenueImpact=7.0, + stateTaxRevenueImpact=3.0, + benefitSpendingImpact=2.0, + budgetaryImpact=8.0, + ), + ) + + dumped = result.model_dump(mode="json") + assert set(dumped) == { + "kind", + "startYear", + "endYear", + "windowSize", + "annualImpacts", + "totals", + } + assert dumped["kind"] == "budgetWindow" + assert dumped["totals"]["year"] == "Total" + assert "outputsByYear" not in dumped + + +def test_openapi_keeps_budget_window_result_compact(): + spec = create_openapi_app().openapi() + budget_window_schema = spec["components"]["schemas"]["BudgetWindowResult"] + + assert set(budget_window_schema["properties"]) == { + "kind", + "startYear", + "endYear", + "windowSize", + "annualImpacts", + "totals", + } + assert "outputsByYear" not in budget_window_schema["properties"] From 36d3be0951a16d25086ecbaf344e6b6945070388 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 20 May 2026 20:02:24 +0200 Subject: [PATCH 02/23] feat: adapt simulation worker to PolicyEngine v4 outputs --- .../fixtures/test_simulation_api_contracts.py | 51 ++- .../test_simulation_output_adapter.py | 169 +++++++ .../src/modal/gateway/endpoints.py | 20 +- .../src/modal/simulation.py | 414 +++++++++++++++++- .../src/modal/simulation_output_adapter.py | 318 ++++++++++++++ .../policyengine_api_simulation/simulation.py | 25 +- .../tests/gateway/test_endpoints.py | 14 +- .../tests/gateway/test_models.py | 8 +- .../tests/test_simulation_output_adapter.py | 148 +++++++ 9 files changed, 1103 insertions(+), 64 deletions(-) create mode 100644 projects/policyengine-api-simulation/fixtures/test_simulation_output_adapter.py create mode 100644 projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py create mode 100644 projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py diff --git 
a/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py b/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py index feaa1aee6..28e2dd1ab 100644 --- a/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py +++ b/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py @@ -31,7 +31,7 @@ CURRENT_SINGLE_YEAR_MACRO_RESULT = { "model_version": "1.691.3", - "data_version": "1.110.12", + "data_version": "1.115.3", "budget": { "budgetary_impact": 300.0, "tax_revenue_impact": 500.0, @@ -56,17 +56,52 @@ "reform": {"gini": 0.29}, }, "poverty": { - "baseline": {"all": 0.1}, - "reform": {"all": 0.09}, + "poverty": { + "adult": {"baseline": 0.1, "reform": 0.09}, + "all": {"baseline": 0.1, "reform": 0.09}, + "child": {"baseline": 0.12, "reform": 0.1}, + "senior": {"baseline": 0.08, "reform": 0.07}, + }, + "deep_poverty": { + "adult": {"baseline": 0.03, "reform": 0.02}, + "all": {"baseline": 0.03, "reform": 0.02}, + "child": {"baseline": 0.04, "reform": 0.03}, + "senior": {"baseline": 0.02, "reform": 0.01}, + }, }, "poverty_by_gender": { - "baseline": {"male": 0.1, "female": 0.11}, - "reform": {"male": 0.09, "female": 0.1}, + "poverty": { + "male": {"baseline": 0.1, "reform": 0.09}, + "female": {"baseline": 0.11, "reform": 0.1}, + }, + "deep_poverty": { + "male": {"baseline": 0.03, "reform": 0.02}, + "female": {"baseline": 0.04, "reform": 0.03}, + }, + }, + "poverty_by_race": { + "poverty": { + "black": {"baseline": 0.12, "reform": 0.11}, + "hispanic": {"baseline": 0.13, "reform": 0.12}, + "other": {"baseline": 0.1, "reform": 0.09}, + "white": {"baseline": 0.08, "reform": 0.07}, + }, }, - "poverty_by_race": None, "intra_decile": { - "relative": {"1": {"1": 0.01}}, - "average": {"1": {"1": 10.0}}, + "all": { + "Gain less than 5%": 0.2, + "Gain more than 5%": 0.1, + "Lose less than 5%": 0.1, + "Lose more than 5%": 0.0, + "No change": 0.6, + }, + "deciles": { + "Gain less than 5%": [0.2], 
+ "Gain more than 5%": [0.1], + "Lose less than 5%": [0.1], + "Lose more than 5%": [0.0], + "No change": [0.6], + }, }, "wealth_decile": None, "intra_wealth_decile": None, diff --git a/projects/policyengine-api-simulation/fixtures/test_simulation_output_adapter.py b/projects/policyengine-api-simulation/fixtures/test_simulation_output_adapter.py new file mode 100644 index 000000000..b593a0be4 --- /dev/null +++ b/projects/policyengine-api-simulation/fixtures/test_simulation_output_adapter.py @@ -0,0 +1,169 @@ +"""Fixtures for PolicyEngine v4 output adapter tests.""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pandas as pd + + +class FakeCollection: + def __init__(self, records): + self.dataframe = pd.DataFrame(records) + + +class FakeModelOutput: + def __init__(self, payload): + self.payload = payload + + def model_dump(self, *, mode): + assert mode == "json" + return self.payload + + +def fake_analysis(): + return SimpleNamespace( + program_statistics=FakeCollection( + [ + { + "program_name": "income_tax", + "baseline_total": 100.0, + "reform_total": 125.0, + "change": 25.0, + } + ] + ), + decile_impacts=FakeCollection( + [ + { + "decile": 2, + "absolute_change": 20.0, + "relative_change": 0.02, + }, + { + "decile": 1, + "absolute_change": 10.0, + "relative_change": 0.01, + }, + ] + ), + wealth_decile_impacts=FakeCollection( + [ + { + "decile": 1, + "absolute_change": 30.0, + "relative_change": 0.03, + } + ] + ), + intra_wealth_decile_impacts=FakeCollection( + [ + { + "decile": 1, + "lose_more_than_5pct": 0.1, + "lose_less_than_5pct": 0.2, + "no_change": 0.3, + "gain_less_than_5pct": 0.4, + "gain_more_than_5pct": 0.5, + } + ] + ), + baseline_poverty=FakeCollection( + [{"poverty_type": "spm", "filter_group": None, "rate": 0.10}] + ), + reform_poverty=FakeCollection( + [{"poverty_type": "spm", "filter_group": None, "rate": 0.09}] + ), + baseline_inequality=SimpleNamespace( + gini=0.40, + top_10_share=0.30, + top_1_share=0.10, + 
), + reform_inequality=SimpleNamespace( + gini=0.39, + top_10_share=0.29, + top_1_share=0.09, + ), + labor_supply_response=FakeModelOutput( + { + "substitution_lsr": 0.0, + "income_lsr": 0.0, + "relative_lsr": {"income": 0.0, "substitution": 0.0}, + "total_change": 0.0, + "revenue_change": 0.0, + "decile": { + "average": {"income": {}, "substitution": {}}, + "relative": {"income": {}, "substitution": {}}, + }, + "hours": { + "baseline": 0.0, + "reform": 0.0, + "change": 0.0, + "income_effect": 0.0, + "substitution_effect": 0.0, + }, + } + ), + ) + + +INTRA_DECILE_COLLECTION = FakeCollection( + [ + { + "decile": 1, + "lose_more_than_5pct": 0.1, + "lose_less_than_5pct": 0.2, + "no_change": 0.3, + "gain_less_than_5pct": 0.4, + "gain_more_than_5pct": 0.5, + }, + { + "decile": 2, + "lose_more_than_5pct": 0.0, + "lose_less_than_5pct": 0.1, + "no_change": 0.6, + "gain_less_than_5pct": 0.2, + "gain_more_than_5pct": 0.1, + }, + ] +) + +BASELINE_POVERTY_BY_AGE = FakeCollection( + [ + {"poverty_type": "spm", "filter_group": "child", "rate": 0.12}, + {"poverty_type": "spm_deep", "filter_group": "child", "rate": 0.04}, + ] +) +REFORM_POVERTY_BY_AGE = FakeCollection( + [ + {"poverty_type": "spm", "filter_group": "child", "rate": 0.11}, + {"poverty_type": "spm_deep", "filter_group": "child", "rate": 0.03}, + ] +) +BASELINE_POVERTY_BY_GENDER = FakeCollection( + [{"poverty_type": "spm", "filter_group": "male", "rate": 0.08}] +) +REFORM_POVERTY_BY_GENDER = FakeCollection( + [{"poverty_type": "spm", "filter_group": "male", "rate": 0.07}] +) +BASELINE_POVERTY_BY_RACE = FakeCollection( + [ + {"poverty_type": "spm", "filter_group": "white", "rate": 0.06}, + {"poverty_type": "spm_deep", "filter_group": "white", "rate": 0.02}, + ] +) +REFORM_POVERTY_BY_RACE = FakeCollection( + [ + {"poverty_type": "spm", "filter_group": "white", "rate": 0.05}, + {"poverty_type": "spm_deep", "filter_group": "white", "rate": 0.01}, + ] +) + +BUDGET = { + "tax_revenue_impact": 100.0, + 
"state_tax_revenue_impact": 20.0, + "benefit_spending_impact": 30.0, + "budgetary_impact": 70.0, + "households": 2.0, + "baseline_net_income": 1000.0, +} diff --git a/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py b/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py index be42b648e..a3cc84f49 100644 --- a/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py +++ b/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py @@ -41,18 +41,18 @@ JOB_METADATA_DICT_NAME = "simulation-api-job-metadata" DATASET_URIS = { "us": { - "enhanced_cps": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12", - "enhanced_cps_2024": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12", - "cps": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.110.12", - "cps_2023": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.110.12", - "pooled_cps": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.110.12", - "pooled_3_year_cps_2023": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.110.12", + "enhanced_cps": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3", + "enhanced_cps_2024": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3", + "cps": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.115.3", + "cps_2023": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.115.3", + "pooled_cps": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.115.3", + "pooled_3_year_cps_2023": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.115.3", }, "uk": { - "enhanced_frs": "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3", - "enhanced_frs_2023_24": "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3", - "frs": "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5@1.40.3", - "frs_2023_24": 
"hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5@1.40.3", + "enhanced_frs": "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.55.5", + "enhanced_frs_2023_24": "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.55.5", + "frs": "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5@1.55.5", + "frs_2023_24": "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5@1.55.5", }, } diff --git a/projects/policyengine-api-simulation/src/modal/simulation.py b/projects/policyengine-api-simulation/src/modal/simulation.py index 549ea127c..0d3b6adfc 100644 --- a/projects/policyengine-api-simulation/src/modal/simulation.py +++ b/projects/policyengine-api-simulation/src/modal/simulation.py @@ -1,8 +1,9 @@ """ Simulation implementation - pure logic with snapshotted imports. -This module has policyengine imports at module level so they are -captured in Modal's image snapshot. No Modal dependencies here. +This module avoids importing policyengine at module level so the worker can +load the requested country module without triggering cross-country imports. +No Modal dependencies here. 
""" import contextlib @@ -10,15 +11,36 @@ import logging import os import tempfile -from typing import Iterator - -# Module-level imports - these are SNAPSHOTTED at image build time -from policyengine.simulation import Simulation, SimulationOptions +from importlib import import_module +from typing import Any, Iterator +from src.modal.simulation_output_adapter import adapt_analysis_to_legacy_macro_output from src.modal.telemetry import split_internal_payload logger = logging.getLogger(__name__) +os.environ.setdefault("POLICYENGINE_SKIP_COUNTRY_IMPORTS", "1") + +DEFAULT_YEAR = 2026 +DATASET_ALIASES = { + "us": { + "enhanced_cps": "enhanced_cps_2024", + "enhanced_cps_2024": "enhanced_cps_2024", + "gs://policyengine-us-data/enhanced_cps_2024.h5": "enhanced_cps_2024", + "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5": "enhanced_cps_2024", + "cps_small": "cps_small_2024", + "cps_small_2024": "cps_small_2024", + }, + "uk": { + "enhanced_frs": "enhanced_frs_2023_24", + "enhanced_frs_2023_24": "enhanced_frs_2023_24", + "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5": "enhanced_frs_2023_24", + "frs": "frs_2023_24", + "frs_2023_24": "frs_2023_24", + "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5": "frs_2023_24", + }, +} + def _normalize_credentials_blob(creds_json: str) -> str: """Return the raw JSON blob, decoding the outer escape if present. @@ -107,7 +129,7 @@ def run_simulation_impl(params: dict) -> dict: Execute economic simulation. Pure implementation with no Modal dependencies. - Accepts SimulationOptions as a dict and returns EconomyComparison as a dict. + Accepts the gateway simulation payload and returns the legacy macro result dict. """ # Set up GCP credentials if needed. The credentials temp file is # cleaned up on exit so we never leave signed JSON material on disk. 
@@ -115,6 +137,340 @@ def run_simulation_impl(params: dict) -> dict: return _run_simulation_impl_core(params) +def _parse_year(params: dict[str, Any]) -> int: + value = params.get("time_period") or params.get("year") or DEFAULT_YEAR + return int(value) + + +def _normalise_period_key(period_key: Any) -> str: + """Convert legacy ``start.stop`` period keys to v4 effective dates.""" + text = str(period_key) + parts = text.split(".") + if len(parts) > 1 and len(parts[0]) == 10: + return parts[0] + return text + + +def _normalise_policy(policy: dict[str, Any] | None) -> dict[str, Any] | None: + if not policy: + return None + + normalised: dict[str, Any] = {} + for parameter, value in policy.items(): + if isinstance(value, dict): + normalised[parameter] = { + _normalise_period_key(period): period_value + for period, period_value in value.items() + } + else: + normalised[parameter] = value + return normalised + + +def _resolve_dataset_name(country: str, requested_data: str | None) -> str: + if requested_data is None: + return "enhanced_cps_2024" if country == "us" else "enhanced_frs_2023_24" + + requested_without_revision = requested_data.split("@", maxsplit=1)[0] + return DATASET_ALIASES.get(country, {}).get( + requested_without_revision, requested_data + ) + + +def _microframe_like(frame, weights: str): + from microdf import MicroDataFrame + + return MicroDataFrame(frame.copy(), weights=weights) + + +def _person_group_column(person, entity: str) -> str: + prefixed = f"person_{entity}_id" + if prefixed in person.columns: + return prefixed + return f"{entity}_id" + + +def _subsample_us_dataset(dataset, subsample: int | None): + if not subsample: + return dataset + + from policyengine.tax_benefit_models.us.datasets import ( + PolicyEngineUSDataset, + USYearData, + ) + + dataset.load() + data = dataset.data + household = data.household.head(int(subsample)).copy() + household_ids = set(household["household_id"]) + + person_household_col = _person_group_column(data.person, 
"household") + person = data.person[data.person[person_household_col].isin(household_ids)].copy() + + def group_subset(entity: str): + person_col = _person_group_column(person, entity) + entity_id_col = f"{entity}_id" + ids = set(person[person_col]) + frame = getattr(data, entity) + return frame[frame[entity_id_col].isin(ids)].copy() + + subset_data = USYearData( + person=_microframe_like(person, "person_weight"), + marital_unit=_microframe_like( + group_subset("marital_unit"), "marital_unit_weight" + ), + family=_microframe_like(group_subset("family"), "family_weight"), + spm_unit=_microframe_like(group_subset("spm_unit"), "spm_unit_weight"), + tax_unit=_microframe_like(group_subset("tax_unit"), "tax_unit_weight"), + household=_microframe_like(household, "household_weight"), + ) + subset_path = os.path.join( + os.environ.get("POLICYENGINE_DATA_FOLDER", "/tmp/policyengine-data"), + f"{dataset.id}_subsample_{subsample}.h5", + ) + return PolicyEngineUSDataset( + id=f"{dataset.id}_subsample_{subsample}", + name=f"{dataset.name} subsample {subsample}", + description=dataset.description, + filepath=subset_path, + year=dataset.year, + is_output_dataset=dataset.is_output_dataset, + metadata=getattr(dataset, "metadata", {}), + metadata_filepath=getattr(dataset, "metadata_filepath", None), + data=subset_data, + ) + + +def _country_module(country: str): + country = country.lower() + if country not in {"us", "uk"}: + raise ValueError(f"Unsupported country: {country}") + + return import_module(f"policyengine.tax_benefit_models.{country}") + + +def _load_dataset(params: dict[str, Any]): + country = params.get("country", "us").lower() + year = _parse_year(params) + country_module = _country_module(country) + dataset_name = _resolve_dataset_name(country, params.get("data")) + datasets = country_module.ensure_datasets( + datasets=[dataset_name], + years=[year], + data_folder=os.environ.get( + "POLICYENGINE_DATA_FOLDER", "/tmp/policyengine-data" + ), + ) + dataset = 
next(iter(datasets.values())) + if country == "us": + return _subsample_us_dataset(dataset, params.get("subsample")) + return dataset + + +def _build_simulation( + params: dict[str, Any], + *, + dataset, + policy: dict[str, Any] | None, +): + from policyengine.core import Simulation + + country_module = _country_module(params.get("country", "us")) + return Simulation( + dataset=dataset, + tax_benefit_model_version=country_module.model, + policy=policy, + ) + + +def _entity_data(simulation, entity: str): + if simulation.output_dataset is None or simulation.output_dataset.data is None: + simulation.ensure() + return getattr(simulation.output_dataset.data, entity) + + +def _sum_output_variable(simulation, variable: str, entity: str) -> float: + data = _entity_data(simulation, entity) + if variable in data.columns: + return float(data[variable].sum()) + + from policyengine.outputs import Aggregate, AggregateType + + output = Aggregate( + simulation=simulation, + variable=variable, + entity=entity, + aggregate_type=AggregateType.SUM, + ) + output.run() + return float(output.result) + + +def _try_sum_output_variable(simulation, variable: str, entity: str) -> float: + try: + return _sum_output_variable(simulation, variable, entity) + except Exception: + logger.warning("Unable to calculate sum for %s", variable, exc_info=True) + return 0.0 + + +def _change_output_variable(baseline, reform, variable: str, entity: str) -> float: + baseline_data = _entity_data(baseline, entity) + reform_data = _entity_data(reform, entity) + if variable in baseline_data.columns and variable in reform_data.columns: + return float((reform_data[variable] - baseline_data[variable]).sum()) + + from policyengine.outputs import ChangeAggregate, ChangeAggregateType + + output = ChangeAggregate( + baseline_simulation=baseline, + reform_simulation=reform, + variable=variable, + entity=entity, + aggregate_type=ChangeAggregateType.SUM, + ) + output.run() + return float(output.result) + + +def 
_try_change_output_variable(baseline, reform, variable: str, entity: str) -> float: + try: + return _change_output_variable(baseline, reform, variable, entity) + except Exception: + logger.warning("Unable to calculate change for %s", variable, exc_info=True) + return 0.0 + + +def _budget_result(country: str, baseline, reform) -> dict[str, float]: + tax_revenue_impact = _try_change_output_variable( + baseline, reform, "household_tax", entity="household" + ) + benefit_spending_impact = _try_change_output_variable( + baseline, reform, "household_benefits", entity="household" + ) + state_tax_revenue_impact = ( + _try_change_output_variable( + baseline, + reform, + "household_state_income_tax", + entity="household", + ) + if country == "us" + else 0.0 + ) + + return { + "tax_revenue_impact": tax_revenue_impact, + "state_tax_revenue_impact": state_tax_revenue_impact, + "benefit_spending_impact": benefit_spending_impact, + "budgetary_impact": tax_revenue_impact - benefit_spending_impact, + "households": _try_sum_output_variable( + baseline, "household_weight", entity="household" + ), + "baseline_net_income": _try_sum_output_variable( + baseline, "household_net_income", entity="household" + ), + } + + +def _poverty_module_function(name: str): + module = import_module("policyengine.outputs.poverty") + return getattr(module, name) + + +def _try_compute_output(label: str, fn, *args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception: + logger.warning("Unable to calculate %s", label, exc_info=True) + return None + + +def _additional_poverty_outputs(country: str, baseline, reform) -> dict[str, Any]: + prefix = "us" if country == "us" else "uk" + output = { + "baseline_poverty_by_age": _try_compute_output( + "baseline poverty by age", + _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), + baseline, + ), + "reform_poverty_by_age": _try_compute_output( + "reform poverty by age", + _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), + reform, 
+ ), + "baseline_poverty_by_gender": _try_compute_output( + "baseline poverty by gender", + _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), + baseline, + ), + "reform_poverty_by_gender": _try_compute_output( + "reform poverty by gender", + _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), + reform, + ), + "baseline_poverty_by_race": None, + "reform_poverty_by_race": None, + } + if country == "us": + output["baseline_poverty_by_race"] = _try_compute_output( + "baseline poverty by race", + _poverty_module_function("calculate_us_poverty_by_race"), + baseline, + ) + output["reform_poverty_by_race"] = _try_compute_output( + "reform poverty by race", + _poverty_module_function("calculate_us_poverty_by_race"), + reform, + ) + return output + + +def _intra_decile_output(baseline, reform): + from policyengine.outputs.intra_decile_impact import compute_intra_decile_impacts + + return _try_compute_output( + "intra-decile impacts", + compute_intra_decile_impacts, + baseline, + reform, + income_variable="household_net_income", + entity="household", + ) + + +def _congressional_district_impact(country: str, baseline, reform): + if country != "us": + return None + + from policyengine.outputs.congressional_district_impact import ( + compute_us_congressional_district_impacts, + ) + + impact = _try_compute_output( + "congressional district impacts", + compute_us_congressional_district_impacts, + baseline, + reform, + ) + return getattr(impact, "district_results", None) if impact is not None else None + + +def _model_version(country_module) -> str: + return str(getattr(country_module.model, "version", "")) + + +def _data_version(params: dict[str, Any], dataset) -> str: + if params.get("data_version"): + return str(params["data_version"]) + metadata = getattr(dataset, "metadata", {}) or {} + for key in ("data_version", "version"): + value = metadata.get(key) + if value is not None: + return str(value) + return "" + + def 
_run_simulation_impl_core(params: dict) -> dict: simulation_params, telemetry, metadata = split_internal_payload(params) @@ -127,17 +483,41 @@ def _run_simulation_impl_core(params: dict) -> dict: if metadata: logger.info("Received simulation metadata keys: %s", sorted(metadata)) - # Validate and create simulation options - options = SimulationOptions.model_validate(simulation_params) - logger.info("Initialising simulation from input") + country = simulation_params.get("country", "us").lower() + country_module = _country_module(country) + dataset = _load_dataset(simulation_params) + baseline_policy = _normalise_policy(simulation_params.get("baseline")) + reform_policy = _normalise_policy(simulation_params.get("reform")) - # Create simulation instance - simulation = Simulation(**options.model_dump()) - logger.info("Calculating comparison") + logger.info("Initialising baseline and reform simulations") + baseline = _build_simulation( + simulation_params, + dataset=dataset, + policy=baseline_policy, + ) + reform = _build_simulation( + simulation_params, + dataset=dataset, + policy=reform_policy, + ) - # Run the economy comparison calculation - result = simulation.calculate_economy_comparison() + logger.info("Calculating economic impact") + analysis = country_module.economic_impact_analysis(baseline, reform) + budget = _budget_result(country, baseline, reform) + poverty_outputs = _additional_poverty_outputs(country, baseline, reform) + intra_decile = _intra_decile_output(baseline, reform) + congressional_district_impact = _congressional_district_impact( + country, baseline, reform + ) logger.info("Comparison complete") - # Use mode='json' to ensure numpy arrays are converted to lists - return result.model_dump(mode="json") + return adapt_analysis_to_legacy_macro_output( + country=country, + model_version=_model_version(country_module), + data_version=_data_version(simulation_params, dataset), + budget=budget, + analysis=analysis, + intra_decile=intra_decile, + 
congressional_district_impact=congressional_district_impact, + **poverty_outputs, + ) diff --git a/projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py b/projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py new file mode 100644 index 000000000..865dbc7f4 --- /dev/null +++ b/projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py @@ -0,0 +1,318 @@ +"""Adapt PolicyEngine v4 macro outputs to the existing simulation API shape.""" + +from __future__ import annotations + +import math +from collections.abc import Iterable, Mapping +from typing import Any + + +INTRA_DECILE_COLUMNS = { + "Lose more than 5%": "lose_more_than_5pct", + "Lose less than 5%": "lose_less_than_5pct", + "No change": "no_change", + "Gain less than 5%": "gain_less_than_5pct", + "Gain more than 5%": "gain_more_than_5pct", +} + +US_POVERTY_TYPES = { + "spm": "poverty", + "spm_deep": "deep_poverty", +} + +UK_POVERTY_TYPES = { + "relative_bhc": "poverty", + "absolute_bhc": "deep_poverty", +} + + +def _number(value: Any, default: float = 0.0) -> float: + if value is None: + return default + try: + result = float(value) + except (TypeError, ValueError): + return default + if math.isnan(result) or math.isinf(result): + return default + return result + + +def _collection_records(collection: Any) -> list[dict[str, Any]]: + if collection is None: + return [] + dataframe = getattr(collection, "dataframe", None) + if dataframe is not None: + return list(dataframe.to_dict("records")) + if isinstance(collection, list): + return [dict(item) for item in collection if isinstance(item, Mapping)] + return [] + + +def _output_model_dump(value: Any) -> Any: + if value is None: + return None + if hasattr(value, "model_dump"): + return value.model_dump(mode="json") + if isinstance(value, Mapping): + return dict(value) + return value + + +def _detailed_budget(collection: Any) -> dict[str, dict[str, float]]: + detailed_budget: dict[str, dict[str, float]] = {} + 
for row in _collection_records(collection): + program_name = row.get("program_name") + if not program_name: + continue + baseline = _number(row.get("baseline_total")) + reform = _number(row.get("reform_total")) + detailed_budget[str(program_name)] = { + "baseline": baseline, + "reform": reform, + "difference": _number(row.get("change"), reform - baseline), + } + return detailed_budget + + +def _decile_impact(collection: Any) -> dict[str, dict[str, float]]: + average: dict[str, float] = {} + relative: dict[str, float] = {} + for row in sorted( + _collection_records(collection), + key=lambda item: _number(item.get("decile")), + ): + decile = int(_number(row.get("decile"))) + if decile <= 0: + continue + key = str(decile) + average[key] = _number(row.get("absolute_change")) + relative[key] = _number(row.get("relative_change")) + return {"average": average, "relative": relative} + + +def _empty_intra_decile() -> dict[str, Any]: + return { + "deciles": {label: [] for label in INTRA_DECILE_COLUMNS}, + "all": {label: 0.0 for label in INTRA_DECILE_COLUMNS}, + } + + +def _intra_decile_impact(collection: Any) -> dict[str, Any]: + result = _empty_intra_decile() + rows = [ + row + for row in sorted( + _collection_records(collection), + key=lambda item: _number(item.get("decile")), + ) + if int(_number(row.get("decile"))) > 0 + ] + + for label, column in INTRA_DECILE_COLUMNS.items(): + values = [_number(row.get(column)) for row in rows] + result["deciles"][label] = values + result["all"][label] = sum(values) / len(values) if values else 0.0 + return result + + +def _empty_age_poverty() -> dict[str, dict[str, float]]: + return { + "child": {"baseline": 0.0, "reform": 0.0}, + "adult": {"baseline": 0.0, "reform": 0.0}, + "senior": {"baseline": 0.0, "reform": 0.0}, + "all": {"baseline": 0.0, "reform": 0.0}, + } + + +def _empty_gender_poverty() -> dict[str, dict[str, float]]: + return { + "male": {"baseline": 0.0, "reform": 0.0}, + "female": {"baseline": 0.0, "reform": 0.0}, + } + + 
+def _empty_race_poverty() -> dict[str, dict[str, float]]: + return { + "white": {"baseline": 0.0, "reform": 0.0}, + "black": {"baseline": 0.0, "reform": 0.0}, + "hispanic": {"baseline": 0.0, "reform": 0.0}, + "other": {"baseline": 0.0, "reform": 0.0}, + } + + +def _poverty_type(country: str, row: Mapping[str, Any]) -> str | None: + poverty_type = str(row.get("poverty_type") or "").lower() + if country == "us": + return US_POVERTY_TYPES.get(poverty_type) + return UK_POVERTY_TYPES.get(poverty_type) + + +def _fill_poverty_block( + *, + country: str, + output: dict[str, dict[str, dict[str, float]]], + baseline_records: Iterable[Mapping[str, Any]], + reform_records: Iterable[Mapping[str, Any]], + default_group: str, +) -> None: + for side, records in (("baseline", baseline_records), ("reform", reform_records)): + for row in records: + poverty_type = _poverty_type(country, row) + if poverty_type is None: + continue + if poverty_type not in output: + continue + group = str(row.get("filter_group") or default_group).lower() + if group not in output[poverty_type]: + continue + output[poverty_type][group][side] = _number(row.get("rate")) + + +def _poverty_impact( + country: str, + *, + baseline: Any, + reform: Any, + baseline_by_age: Any, + reform_by_age: Any, +) -> dict[str, dict[str, dict[str, float]]]: + result = {"poverty": _empty_age_poverty(), "deep_poverty": _empty_age_poverty()} + _fill_poverty_block( + country=country, + output=result, + baseline_records=_collection_records(baseline), + reform_records=_collection_records(reform), + default_group="all", + ) + _fill_poverty_block( + country=country, + output=result, + baseline_records=_collection_records(baseline_by_age), + reform_records=_collection_records(reform_by_age), + default_group="all", + ) + return result + + +def _poverty_by_gender( + country: str, + *, + baseline_by_gender: Any, + reform_by_gender: Any, +) -> dict[str, dict[str, dict[str, float]]]: + result = { + "poverty": _empty_gender_poverty(), + 
"deep_poverty": _empty_gender_poverty(), + } + _fill_poverty_block( + country=country, + output=result, + baseline_records=_collection_records(baseline_by_gender), + reform_records=_collection_records(reform_by_gender), + default_group="all", + ) + return result + + +def _poverty_by_race( + *, + baseline_by_race: Any, + reform_by_race: Any, +) -> dict[str, dict[str, dict[str, float]]]: + result = {"poverty": _empty_race_poverty()} + _fill_poverty_block( + country="us", + output=result, + baseline_records=_collection_records(baseline_by_race), + reform_records=_collection_records(reform_by_race), + default_group="all", + ) + return result + + +def _inequality_impact(baseline: Any, reform: Any) -> dict[str, Any]: + return { + "gini": { + "baseline": _number(getattr(baseline, "gini", None)), + "reform": _number(getattr(reform, "gini", None)), + }, + "top_10_pct_share": { + "baseline": _number(getattr(baseline, "top_10_share", None)), + "reform": _number(getattr(reform, "top_10_share", None)), + }, + "top_1_pct_share": { + "baseline": _number(getattr(baseline, "top_1_share", None)), + "reform": _number(getattr(reform, "top_1_share", None)), + }, + } + + +def adapt_analysis_to_legacy_macro_output( + *, + country: str, + model_version: str, + data_version: str, + budget: dict[str, float], + analysis: Any, + baseline_poverty_by_age: Any = None, + reform_poverty_by_age: Any = None, + baseline_poverty_by_gender: Any = None, + reform_poverty_by_gender: Any = None, + baseline_poverty_by_race: Any = None, + reform_poverty_by_race: Any = None, + intra_decile: Any = None, + congressional_district_impact: Any = None, + constituency_impact: Any = None, + local_authority_impact: Any = None, +) -> dict[str, Any]: + """Return the legacy single-year macro result expected by API callers.""" + country = country.lower() + wealth_decile = getattr(analysis, "wealth_decile_impacts", None) + intra_wealth_decile = getattr(analysis, "intra_wealth_decile_impacts", None) + + return { + 
"model_version": model_version, + "data_version": data_version, + "budget": budget, + "detailed_budget": _detailed_budget( + getattr(analysis, "program_statistics", None) + ), + "decile": _decile_impact(getattr(analysis, "decile_impacts", None)), + "inequality": _inequality_impact( + getattr(analysis, "baseline_inequality", None), + getattr(analysis, "reform_inequality", None), + ), + "poverty": _poverty_impact( + country, + baseline=getattr(analysis, "baseline_poverty", None), + reform=getattr(analysis, "reform_poverty", None), + baseline_by_age=baseline_poverty_by_age, + reform_by_age=reform_poverty_by_age, + ), + "poverty_by_gender": _poverty_by_gender( + country, + baseline_by_gender=baseline_poverty_by_gender, + reform_by_gender=reform_poverty_by_gender, + ), + "poverty_by_race": ( + _poverty_by_race( + baseline_by_race=baseline_poverty_by_race, + reform_by_race=reform_poverty_by_race, + ) + if country == "us" + else None + ), + "intra_decile": _intra_decile_impact(intra_decile), + "wealth_decile": _decile_impact(wealth_decile) if country == "uk" else None, + "intra_wealth_decile": ( + _intra_decile_impact(intra_wealth_decile) if country == "uk" else None + ), + "labor_supply_response": _output_model_dump( + getattr(analysis, "labor_supply_response", None) + ), + "constituency_impact": constituency_impact, + "local_authority_impact": local_authority_impact, + "congressional_district_impact": congressional_district_impact, + "cliff_impact": None, + } diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation.py index cc81f6be4..9b8567f66 100644 --- a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation.py +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation.py @@ -1,31 +1,20 @@ -from typing import Annotated -import os -from fastapi import APIRouter, Depends, HTTPException -from 
pydantic import BaseModel -from policyengine.simulation import SimulationOptions, Simulation -from policyengine.outputs.macro.comparison.calculate_economy_comparison import ( - EconomyComparison, -) -from pathlib import Path import logging +from fastapi import APIRouter + +from src.modal.simulation import run_simulation_impl + logger = logging.getLogger(__file__) def create_router(): router = APIRouter() - @router.post("/simulate/economy/comparison", response_model=EconomyComparison) - async def simulate(parameters: SimulationOptions) -> EconomyComparison: - model = SimulationOptions.model_validate(parameters) - logger.info("Initialising simulation from input") - simulation = Simulation(**model.model_dump()) + @router.post("/simulate/economy/comparison", response_model=dict) + async def simulate(parameters: dict) -> dict: logger.info("Calculating comparison") - result = ( - simulation.calculate_economy_comparison() # pyright: ignore [reportAttributeAccessIssue] - ) + result = run_simulation_impl(parameters) logger.info("Comparison complete") - return result return router diff --git a/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py b/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py index 506641600..58fc135c6 100644 --- a/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py +++ b/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py @@ -250,7 +250,7 @@ def test__given_submission_with_data__then_returns_resolved_bundle_metadata( "country": "us", "scope": "macro", "reform": {}, - "data": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12", + "data": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3", } # When @@ -262,7 +262,7 @@ def test__given_submission_with_data__then_returns_resolved_bundle_metadata( assert data["resolved_app_name"] == "policyengine-simulation-us1-500-0-uk2-66-0" assert data["policyengine_bundle"] == { "model_version": "1.500.0", - 
"dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12", + "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3", } def test__given_submission_with_alias_data__then_bundle_dataset_stays_unresolved( @@ -286,7 +286,7 @@ def test__given_submission_with_alias_data__then_bundle_dataset_stays_unresolved data = response.json() assert ( data["policyengine_bundle"]["dataset"] - == "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12" + == "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3" ) def test__given_submission_with_uk_alias_data__then_bundle_dataset_is_versioned_uri( @@ -310,7 +310,7 @@ def test__given_submission_with_uk_alias_data__then_bundle_dataset_is_versioned_ data = response.json() assert ( data["policyengine_bundle"]["dataset"] - == "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3" + == "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.55.5" ) def test__given_submission_with_runtime_bundle__then_accepts_internal_provenance( @@ -341,7 +341,7 @@ def test__given_submission_with_runtime_bundle__then_accepts_internal_provenance assert data["policyengine_bundle"] == { "model_version": "1.500.0", "data_version": "1.78.2", - "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12", + "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3", } assert mock_modal["func"].last_payload["data_version"] == "1.78.2" assert "_runtime_bundle" not in mock_modal["func"].last_payload @@ -388,7 +388,7 @@ def test__given_submitted_job__then_job_status_includes_bundle_metadata( "country": "us", "scope": "macro", "reform": {}, - "data": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12", + "data": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3", }, ) @@ -403,7 +403,7 @@ def test__given_submitted_job__then_job_status_includes_bundle_metadata( assert 
data["resolved_app_name"] == "policyengine-simulation-us1-500-0-uk2-66-0" assert data["policyengine_bundle"] == { "model_version": "1.500.0", - "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12", + "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3", } def test__given_submitted_job_with_telemetry__then_polling_echoes_run_id( diff --git a/projects/policyengine-api-simulation/tests/gateway/test_models.py b/projects/policyengine-api-simulation/tests/gateway/test_models.py index 500058949..72e4d96c8 100644 --- a/projects/policyengine-api-simulation/tests/gateway/test_models.py +++ b/projects/policyengine-api-simulation/tests/gateway/test_models.py @@ -293,7 +293,7 @@ def test_job_submit_response_creates_with_all_fields(self): "model_version": "1.459.0", "policyengine_version": None, "data_version": None, - "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12", + "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3", }, } @@ -312,7 +312,7 @@ def test_job_submit_response_creates_with_all_fields(self): assert response.policyengine_bundle.model_version == "1.459.0" assert response.policyengine_bundle.policyengine_version is None assert response.policyengine_bundle.dataset == ( - "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12" + "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3" ) @@ -376,7 +376,7 @@ def test_job_status_response_accepts_bundle_metadata(self): "model_version": "1.459.0", "policyengine_version": None, "data_version": None, - "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12", + "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3", }, ) @@ -385,7 +385,7 @@ def test_job_status_response_accepts_bundle_metadata(self): ) assert response.policyengine_bundle is not None assert response.policyengine_bundle.dataset == ( - 
"hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12" + "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3" ) diff --git a/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py b/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py new file mode 100644 index 000000000..e260bd18c --- /dev/null +++ b/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py @@ -0,0 +1,148 @@ +"""Tests for translating PolicyEngine v4 outputs into API-v2 macro results.""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pandas as pd + +from fixtures.test_simulation_api_contracts import CURRENT_SINGLE_YEAR_MACRO_KEYS +from fixtures.test_simulation_output_adapter import ( + BASELINE_POVERTY_BY_AGE, + BASELINE_POVERTY_BY_GENDER, + BASELINE_POVERTY_BY_RACE, + BUDGET, + INTRA_DECILE_COLLECTION, + REFORM_POVERTY_BY_AGE, + REFORM_POVERTY_BY_GENDER, + REFORM_POVERTY_BY_RACE, + fake_analysis, +) +from src.modal.simulation import _budget_result, _normalise_policy +from src.modal.simulation_output_adapter import adapt_analysis_to_legacy_macro_output + + +def test_adapter_returns_existing_single_year_macro_shape(): + output = adapt_analysis_to_legacy_macro_output( + country="us", + model_version="1.691.12", + data_version="1.115.3", + budget=BUDGET, + analysis=fake_analysis(), + baseline_poverty_by_age=BASELINE_POVERTY_BY_AGE, + reform_poverty_by_age=REFORM_POVERTY_BY_AGE, + baseline_poverty_by_gender=BASELINE_POVERTY_BY_GENDER, + reform_poverty_by_gender=REFORM_POVERTY_BY_GENDER, + baseline_poverty_by_race=BASELINE_POVERTY_BY_RACE, + reform_poverty_by_race=REFORM_POVERTY_BY_RACE, + intra_decile=INTRA_DECILE_COLLECTION, + congressional_district_impact=[{"district_geoid": 101}], + ) + + assert set(output) == CURRENT_SINGLE_YEAR_MACRO_KEYS + assert output["model_version"] == "1.691.12" + assert output["data_version"] == "1.115.3" + assert output["budget"] == BUDGET 
+ assert output["detailed_budget"] == { + "income_tax": {"baseline": 100.0, "reform": 125.0, "difference": 25.0} + } + assert output["decile"] == { + "average": {"1": 10.0, "2": 20.0}, + "relative": {"1": 0.01, "2": 0.02}, + } + assert output["intra_decile"]["deciles"]["Gain more than 5%"] == [0.5, 0.1] + assert output["intra_decile"]["all"]["Gain more than 5%"] == 0.3 + assert output["poverty"]["poverty"]["all"] == { + "baseline": 0.10, + "reform": 0.09, + } + assert output["poverty"]["poverty"]["child"] == { + "baseline": 0.12, + "reform": 0.11, + } + assert output["poverty_by_gender"]["poverty"]["male"] == { + "baseline": 0.08, + "reform": 0.07, + } + assert output["poverty_by_race"]["poverty"]["white"] == { + "baseline": 0.06, + "reform": 0.05, + } + assert output["wealth_decile"] is None + assert output["intra_wealth_decile"] is None + assert output["congressional_district_impact"] == [{"district_geoid": 101}] + + +def test_adapter_maps_uk_wealth_outputs_and_omits_us_only_race(): + output = adapt_analysis_to_legacy_macro_output( + country="uk", + model_version="2.88.14", + data_version="1.55.5", + budget={**BUDGET, "state_tax_revenue_impact": 0.0}, + analysis=fake_analysis(), + intra_decile=INTRA_DECILE_COLLECTION, + ) + + assert output["poverty_by_race"] is None + assert output["wealth_decile"] == { + "average": {"1": 30.0}, + "relative": {"1": 0.03}, + } + assert output["intra_wealth_decile"]["deciles"]["Lose more than 5%"] == [0.1] + + +def test_normalise_policy_converts_legacy_period_range_keys(): + assert _normalise_policy({"gov.test.parameter": {"2026-01-01.2100-12-31": 1}}) == { + "gov.test.parameter": {"2026-01-01": 1} + } + + +class _FakeOutputDataset: + def __init__(self, household): + self.data = SimpleNamespace(household=household) + + +class _FakeSimulation: + def __init__(self, household): + self.output_dataset = _FakeOutputDataset(household) + + def ensure(self): + raise AssertionError("test data is already materialized") + + +def 
test_budget_result_uses_materialized_household_columns_and_uk_state_tax_zero(): + baseline = _FakeSimulation( + pd.DataFrame( + { + "household_weight": [1.0, 2.0], + "household_net_income": [100.0, 200.0], + "household_tax": [20.0, 40.0], + "household_benefits": [5.0, 10.0], + "household_state_income_tax": [2.0, 3.0], + } + ) + ) + reform = _FakeSimulation( + pd.DataFrame( + { + "household_weight": [1.0, 2.0], + "household_net_income": [120.0, 210.0], + "household_tax": [25.0, 50.0], + "household_benefits": [4.0, 8.0], + "household_state_income_tax": [4.0, 6.0], + } + ) + ) + + us_budget = _budget_result("us", baseline, reform) + uk_budget = _budget_result("uk", baseline, reform) + + assert us_budget == { + "tax_revenue_impact": 15.0, + "state_tax_revenue_impact": 5.0, + "benefit_spending_impact": -3.0, + "budgetary_impact": 18.0, + "households": 3.0, + "baseline_net_income": 300.0, + } + assert uk_budget["state_tax_revenue_impact"] == 0.0 From ba38c2180865910c2ff761d7d0c8130d5f9550cf Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 21 May 2026 21:35:22 +0200 Subject: [PATCH 03/23] fix: align simulation API with PolicyEngine v4 outputs --- .../fixtures/test_simulation_api_contracts.py | 4 +- .../src/modal/gateway/endpoints.py | 20 +- .../src/modal/simulation.py | 46 +++- .../tests/gateway/test_endpoints.py | 14 +- .../tests/gateway/test_models.py | 69 ++---- .../tests/test_modal_scripts.py | 12 +- .../tests/test_simulation_output_adapter.py | 71 +++++- projects/policyengine-api-simulation/uv.lock | 206 ++++++++---------- 8 files changed, 246 insertions(+), 196 deletions(-) diff --git a/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py b/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py index 28e2dd1ab..7e900e79c 100644 --- a/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py +++ b/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py @@ -30,8 +30,8 
@@ } CURRENT_SINGLE_YEAR_MACRO_RESULT = { - "model_version": "1.691.3", - "data_version": "1.115.3", + "model_version": "1.702.0", + "data_version": "1.115.5", "budget": { "budgetary_impact": 300.0, "tax_revenue_impact": 500.0, diff --git a/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py b/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py index a3cc84f49..eb382a41a 100644 --- a/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py +++ b/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py @@ -41,18 +41,18 @@ JOB_METADATA_DICT_NAME = "simulation-api-job-metadata" DATASET_URIS = { "us": { - "enhanced_cps": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3", - "enhanced_cps_2024": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3", - "cps": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.115.3", - "cps_2023": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.115.3", - "pooled_cps": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.115.3", - "pooled_3_year_cps_2023": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.115.3", + "enhanced_cps": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5", + "enhanced_cps_2024": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5", + "cps": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.115.5", + "cps_2023": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.115.5", + "pooled_cps": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.115.5", + "pooled_3_year_cps_2023": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.115.5", }, "uk": { - "enhanced_frs": "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.55.5", - "enhanced_frs_2023_24": "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.55.5", - "frs": 
"hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5@1.55.5", - "frs_2023_24": "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5@1.55.5", + "enhanced_frs": "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.55.10", + "enhanced_frs_2023_24": "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.55.10", + "frs": "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5@1.55.10", + "frs_2023_24": "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5@1.55.10", }, } diff --git a/projects/policyengine-api-simulation/src/modal/simulation.py b/projects/policyengine-api-simulation/src/modal/simulation.py index 0d3b6adfc..58cdb7c68 100644 --- a/projects/policyengine-api-simulation/src/modal/simulation.py +++ b/projects/policyengine-api-simulation/src/modal/simulation.py @@ -373,11 +373,15 @@ def _budget_result(country: str, baseline, reform) -> dict[str, float]: } -def _poverty_module_function(name: str): - module = import_module("policyengine.outputs.poverty") +def _output_module_function(module_name: str, name: str): + module = import_module(f"policyengine.outputs.{module_name}") return getattr(module, name) +def _poverty_module_function(name: str): + return _output_module_function("poverty", name) + + def _try_compute_output(label: str, fn, *args, **kwargs): try: return fn(*args, **kwargs) @@ -456,6 +460,40 @@ def _congressional_district_impact(country: str, baseline, reform): return getattr(impact, "district_results", None) if impact is not None else None +def _uk_constituency_impact(country: str, baseline, reform): + if country != "uk": + return None + + impact = _try_compute_output( + "constituency impacts", + _output_module_function( + "constituency_impact", "compute_uk_constituency_impacts" + ), + baseline, + reform, + ) + if impact is None: + return None + return getattr(impact, "constituency_results", None) + + +def _uk_local_authority_impact(country: str, baseline, reform): + if 
country != "uk": + return None + + impact = _try_compute_output( + "local authority impacts", + _output_module_function( + "local_authority_impact", "compute_uk_local_authority_impacts" + ), + baseline, + reform, + ) + if impact is None: + return None + return getattr(impact, "local_authority_results", None) + + def _model_version(country_module) -> str: return str(getattr(country_module.model, "version", "")) @@ -509,6 +547,8 @@ def _run_simulation_impl_core(params: dict) -> dict: congressional_district_impact = _congressional_district_impact( country, baseline, reform ) + constituency_impact = _uk_constituency_impact(country, baseline, reform) + local_authority_impact = _uk_local_authority_impact(country, baseline, reform) logger.info("Comparison complete") return adapt_analysis_to_legacy_macro_output( @@ -519,5 +559,7 @@ def _run_simulation_impl_core(params: dict) -> dict: analysis=analysis, intra_decile=intra_decile, congressional_district_impact=congressional_district_impact, + constituency_impact=constituency_impact, + local_authority_impact=local_authority_impact, **poverty_outputs, ) diff --git a/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py b/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py index 58fc135c6..e2fe70611 100644 --- a/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py +++ b/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py @@ -250,7 +250,7 @@ def test__given_submission_with_data__then_returns_resolved_bundle_metadata( "country": "us", "scope": "macro", "reform": {}, - "data": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3", + "data": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5", } # When @@ -262,7 +262,7 @@ def test__given_submission_with_data__then_returns_resolved_bundle_metadata( assert data["resolved_app_name"] == "policyengine-simulation-us1-500-0-uk2-66-0" assert data["policyengine_bundle"] == { "model_version": 
"1.500.0", - "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3", + "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5", } def test__given_submission_with_alias_data__then_bundle_dataset_stays_unresolved( @@ -286,7 +286,7 @@ def test__given_submission_with_alias_data__then_bundle_dataset_stays_unresolved data = response.json() assert ( data["policyengine_bundle"]["dataset"] - == "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3" + == "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5" ) def test__given_submission_with_uk_alias_data__then_bundle_dataset_is_versioned_uri( @@ -310,7 +310,7 @@ def test__given_submission_with_uk_alias_data__then_bundle_dataset_is_versioned_ data = response.json() assert ( data["policyengine_bundle"]["dataset"] - == "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.55.5" + == "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.55.10" ) def test__given_submission_with_runtime_bundle__then_accepts_internal_provenance( @@ -341,7 +341,7 @@ def test__given_submission_with_runtime_bundle__then_accepts_internal_provenance assert data["policyengine_bundle"] == { "model_version": "1.500.0", "data_version": "1.78.2", - "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3", + "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5", } assert mock_modal["func"].last_payload["data_version"] == "1.78.2" assert "_runtime_bundle" not in mock_modal["func"].last_payload @@ -388,7 +388,7 @@ def test__given_submitted_job__then_job_status_includes_bundle_metadata( "country": "us", "scope": "macro", "reform": {}, - "data": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3", + "data": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5", }, ) @@ -403,7 +403,7 @@ def test__given_submitted_job__then_job_status_includes_bundle_metadata( 
assert data["resolved_app_name"] == "policyengine-simulation-us1-500-0-uk2-66-0" assert data["policyengine_bundle"] == { "model_version": "1.500.0", - "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3", + "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5", } def test__given_submitted_job_with_telemetry__then_polling_echoes_run_id( diff --git a/projects/policyengine-api-simulation/tests/gateway/test_models.py b/projects/policyengine-api-simulation/tests/gateway/test_models.py index 72e4d96c8..9bb56ecf9 100644 --- a/projects/policyengine-api-simulation/tests/gateway/test_models.py +++ b/projects/policyengine-api-simulation/tests/gateway/test_models.py @@ -1,5 +1,7 @@ """Tests for gateway Pydantic models.""" +import json + import pytest from pydantic import ValidationError @@ -11,11 +13,12 @@ BudgetWindowBatchSubmitResponse, BudgetWindowResult, BudgetWindowTotals, + JobStatusResponse, + JobSubmitResponse, + MAX_GATEWAY_REQUEST_BYTES, PingRequest, PingResponse, SimulationRequest, - JobSubmitResponse, - JobStatusResponse, ) @@ -113,6 +116,16 @@ def test_ping_response_serializes_correctly(self): class TestSimulationRequest: """Tests for SimulationRequest model.""" + @staticmethod + def _payload_with_encoded_size(target_size: int) -> dict: + payload = {"country": "us", "reform": {"mock.parameter": {"2024-01-01": ""}}} + base_size = len(json.dumps(payload, default=str)) + payload["reform"]["mock.parameter"]["2024-01-01"] = "x" * ( + target_size - base_size + ) + assert len(json.dumps(payload, default=str)) == target_size + return payload + def test_simulation_request_requires_country(self): """ Given no country @@ -200,54 +213,18 @@ def test_simulation_request_rejects_oversized_payload(self): def test_simulation_request_accepts_payload_just_below_256kb(self): """256 KB boundary (#450): a payload just below the cap must be - accepted. 
We build a reform dict whose JSON encoding is ~256_100 - bytes before adding the ``country`` key, then prune until just under - the 262_144 byte limit.""" - import json - - from src.modal.gateway.models import MAX_GATEWAY_REQUEST_BYTES - - # Each entry in this shape encodes to roughly 40 bytes of JSON. We - # build a generous reform then trim the last few entries until the - # total is safely under the cap. - reform = {f"mock.parameter[{i}]": {"2024-01-01": i} for i in range(6_000)} - payload = {"country": "us", "reform": reform} - - while len(json.dumps(payload, default=str)) > MAX_GATEWAY_REQUEST_BYTES - 200: - reform.popitem() - - encoded_size = len(json.dumps(payload, default=str)) - assert encoded_size < MAX_GATEWAY_REQUEST_BYTES, ( - f"Test setup produced {encoded_size} bytes, " - f"expected < {MAX_GATEWAY_REQUEST_BYTES}" - ) + accepted.""" + payload = self._payload_with_encoded_size(MAX_GATEWAY_REQUEST_BYTES - 1) # Must not raise — this is the just-under-cap happy path. request = SimulationRequest(**payload) assert request.country == "us" - assert len(request.reform) == len(reform) + assert request.reform == payload["reform"] def test_simulation_request_rejects_payload_just_above_256kb(self): """The cap is strict: a payload that crosses 262_144 bytes by even a few bytes should be rejected with a ``too large`` ValidationError.""" - import json - - from src.modal.gateway.models import MAX_GATEWAY_REQUEST_BYTES - - # Generate enough entries to definitely exceed the cap, then trim to - # just a few bytes above it. - reform = {f"mock.parameter[{i}]": {"2024-01-01": i} for i in range(12_000)} - payload = {"country": "us", "reform": reform} - - # Trim down to just above the cap (within 200 bytes). 
- while len(json.dumps(payload, default=str)) > MAX_GATEWAY_REQUEST_BYTES + 200: - reform.popitem() - - encoded_size = len(json.dumps(payload, default=str)) - assert encoded_size > MAX_GATEWAY_REQUEST_BYTES, ( - f"Test setup produced {encoded_size} bytes, " - f"expected > {MAX_GATEWAY_REQUEST_BYTES}" - ) + payload = self._payload_with_encoded_size(MAX_GATEWAY_REQUEST_BYTES + 1) with pytest.raises(ValidationError, match="too large"): SimulationRequest(**payload) @@ -293,7 +270,7 @@ def test_job_submit_response_creates_with_all_fields(self): "model_version": "1.459.0", "policyengine_version": None, "data_version": None, - "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3", + "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5", }, } @@ -312,7 +289,7 @@ def test_job_submit_response_creates_with_all_fields(self): assert response.policyengine_bundle.model_version == "1.459.0" assert response.policyengine_bundle.policyengine_version is None assert response.policyengine_bundle.dataset == ( - "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3" + "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5" ) @@ -376,7 +353,7 @@ def test_job_status_response_accepts_bundle_metadata(self): "model_version": "1.459.0", "policyengine_version": None, "data_version": None, - "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3", + "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5", }, ) @@ -385,7 +362,7 @@ def test_job_status_response_accepts_bundle_metadata(self): ) assert response.policyengine_bundle is not None assert response.policyengine_bundle.dataset == ( - "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.3" + "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5" ) diff --git a/projects/policyengine-api-simulation/tests/test_modal_scripts.py 
b/projects/policyengine-api-simulation/tests/test_modal_scripts.py index 648273285..31ef338ed 100644 --- a/projects/policyengine-api-simulation/tests/test_modal_scripts.py +++ b/projects/policyengine-api-simulation/tests/test_modal_scripts.py @@ -446,7 +446,7 @@ def test_exports_us_and_uk_model_versions_to_integration_tests(self, tmp_path): fake_bin.mkdir() fake_uv = fake_bin / "uv" fake_uv.write_text( - '#!/bin/bash\n' + "#!/bin/bash\n" 'printf "%s|base=%s|us=%s|uk=%s\\n" "$*" ' '"${simulation_integ_test_base_url:-}" ' '"${simulation_integ_test_us_model_version:-}" ' @@ -473,7 +473,7 @@ def test_exports_us_and_uk_model_versions_to_integration_tests(self, tmp_path): "prod", "https://example.com", "1.690.7", - "2.88.14", + "2.88.20", ], capture_output=True, text=True, @@ -486,7 +486,7 @@ def test_exports_us_and_uk_model_versions_to_integration_tests(self, tmp_path): assert "run pytest tests/simulation/ -v -m not beta_only" in log assert "base=https://example.com" in log assert "us=1.690.7" in log - assert "uk=2.88.14" in log + assert "uk=2.88.20" in log class TestAllScriptsHaveShebang: @@ -513,6 +513,6 @@ def test_all_scripts_have_valid_syntax(self, all_modal_scripts): capture_output=True, text=True, ) - assert result.returncode == 0, ( - f"{script.name} has syntax errors: {result.stderr}" - ) + assert ( + result.returncode == 0 + ), f"{script.name} has syntax errors: {result.stderr}" diff --git a/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py b/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py index e260bd18c..dc4f3ed54 100644 --- a/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py +++ b/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py @@ -18,15 +18,20 @@ REFORM_POVERTY_BY_RACE, fake_analysis, ) -from src.modal.simulation import _budget_result, _normalise_policy +from src.modal.simulation import ( + _budget_result, + _normalise_policy, + 
_uk_constituency_impact, + _uk_local_authority_impact, +) from src.modal.simulation_output_adapter import adapt_analysis_to_legacy_macro_output def test_adapter_returns_existing_single_year_macro_shape(): output = adapt_analysis_to_legacy_macro_output( country="us", - model_version="1.691.12", - data_version="1.115.3", + model_version="1.702.0", + data_version="1.115.5", budget=BUDGET, analysis=fake_analysis(), baseline_poverty_by_age=BASELINE_POVERTY_BY_AGE, @@ -40,8 +45,8 @@ def test_adapter_returns_existing_single_year_macro_shape(): ) assert set(output) == CURRENT_SINGLE_YEAR_MACRO_KEYS - assert output["model_version"] == "1.691.12" - assert output["data_version"] == "1.115.3" + assert output["model_version"] == "1.702.0" + assert output["data_version"] == "1.115.5" assert output["budget"] == BUDGET assert output["detailed_budget"] == { "income_tax": {"baseline": 100.0, "reform": 125.0, "difference": 25.0} @@ -76,11 +81,13 @@ def test_adapter_returns_existing_single_year_macro_shape(): def test_adapter_maps_uk_wealth_outputs_and_omits_us_only_race(): output = adapt_analysis_to_legacy_macro_output( country="uk", - model_version="2.88.14", - data_version="1.55.5", + model_version="2.88.20", + data_version="1.55.10", budget={**BUDGET, "state_tax_revenue_impact": 0.0}, analysis=fake_analysis(), intra_decile=INTRA_DECILE_COLLECTION, + constituency_impact=[{"constituency_code": "E14000530"}], + local_authority_impact=[{"local_authority_code": "E06000001"}], ) assert output["poverty_by_race"] is None @@ -89,6 +96,8 @@ def test_adapter_maps_uk_wealth_outputs_and_omits_us_only_race(): "relative": {"1": 0.03}, } assert output["intra_wealth_decile"]["deciles"]["Lose more than 5%"] == [0.1] + assert output["constituency_impact"] == [{"constituency_code": "E14000530"}] + assert output["local_authority_impact"] == [{"local_authority_code": "E06000001"}] def test_normalise_policy_converts_legacy_period_range_keys(): @@ -146,3 +155,51 @@ def 
test_budget_result_uses_materialized_household_columns_and_uk_state_tax_zero "baseline_net_income": 300.0, } assert uk_budget["state_tax_revenue_impact"] == 0.0 + + +def test_uk_constituency_impact_uses_policyengine_output_function(monkeypatch): + baseline = object() + reform = object() + expected = [{"constituency_code": "E14000530"}] + + def fake_output_module_function(module_name, name): + assert module_name == "constituency_impact" + assert name == "compute_uk_constituency_impacts" + + def compute(baseline_simulation, reform_simulation): + assert baseline_simulation is baseline + assert reform_simulation is reform + return SimpleNamespace(constituency_results=expected) + + return compute + + monkeypatch.setattr( + "src.modal.simulation._output_module_function", fake_output_module_function + ) + + assert _uk_constituency_impact("uk", baseline, reform) == expected + assert _uk_constituency_impact("us", baseline, reform) is None + + +def test_uk_local_authority_impact_uses_policyengine_output_function(monkeypatch): + baseline = object() + reform = object() + expected = [{"local_authority_code": "E06000001"}] + + def fake_output_module_function(module_name, name): + assert module_name == "local_authority_impact" + assert name == "compute_uk_local_authority_impacts" + + def compute(baseline_simulation, reform_simulation): + assert baseline_simulation is baseline + assert reform_simulation is reform + return SimpleNamespace(local_authority_results=expected) + + return compute + + monkeypatch.setattr( + "src.modal.simulation._output_module_function", fake_output_module_function + ) + + assert _uk_local_authority_impact("uk", baseline, reform) == expected + assert _uk_local_authority_impact("us", baseline, reform) is None diff --git a/projects/policyengine-api-simulation/uv.lock b/projects/policyengine-api-simulation/uv.lock index 0f679e01b..3623dbd7b 100644 --- a/projects/policyengine-api-simulation/uv.lock +++ b/projects/policyengine-api-simulation/uv.lock @@ -161,15 
+161,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6e/e2/d5b09cec0383381026c41fd071ae6a9342dfd70d0584aeae672e77dda82f/blosc2-4.1.2-cp313-cp313-win_amd64.whl", hash = "sha256:a72cc1fdc74744723092ccb63d03cf49c64f911450d2c9296182ce7bcda45d04", size = 3147727, upload-time = "2026-03-03T11:04:57.506Z" }, ] -[[package]] -name = "caugetch" -version = "0.0.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a3/ec/519cb37e3e58e23a5b02a74049128f6e701ccd8892b0cebecf701fac6177/caugetch-0.0.1.tar.gz", hash = "sha256:6f6ddb3b928fa272071b02aabb3342941cd99992f27413ba8c189eb4dc3e33b0", size = 2071, upload-time = "2019-10-15T22:39:49.315Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/70/33/64fee4626ec943c2d0c4eee31c784dab8452dfe014916190730880d4ea62/caugetch-0.0.1-py3-none-any.whl", hash = "sha256:ee743dcbb513409cd24cfc42435418073683ba2f4bb7ee9f8440088a47d59277", size = 3439, upload-time = "2019-10-15T22:39:47.122Z" }, -] - [[package]] name = "cbor2" version = "5.9.0" @@ -267,15 +258,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl", hash = "sha256:a2bf429bb3033c89fa4936ffb35d5cb471e3719e1f3c8a7c3fff0b8314305613", size = 110502, upload-time = "2026-04-22T15:11:25.044Z" }, ] -[[package]] -name = "clipboard" -version = "0.0.4" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyperclip" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/8a/38/17f3885713d0f39994563029942b1d31c93d4e56d80da505abfbfb3a3bc4/clipboard-0.0.4.tar.gz", hash = "sha256:a72a78e9c9bf68da1c3f29ee022417d13ec9e3824b511559fd2b702b1dd5b817", size = 1713, upload-time = "2014-05-22T12:49:08.683Z" } - [[package]] name = "colorama" version = "0.4.6" @@ -384,15 +366,6 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/84/d0/205d54408c08b13550c733c4b85429e7ead111c7f0014309637425520a9a/deprecated-1.3.1-py2.py3-none-any.whl", hash = "sha256:597bfef186b6f60181535a29fbe44865ce137a5079f295b479886c82729d5f3f", size = 11298, upload-time = "2025-10-30T08:19:00.758Z" }, ] -[[package]] -name = "diskcache" -version = "5.6.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/3f/21/1c1ffc1a039ddcc459db43cc108658f32c57d271d7289a2794e401d0fdb6/diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc", size = 67916, upload-time = "2023-08-31T06:12:00.316Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3f/27/4570e78fc0bf5ea0ca45eb1de3818a23787af9b390c0b0a0033a1b8236f9/diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19", size = 45550, upload-time = "2023-08-31T06:11:58.822Z" }, -] - [[package]] name = "dnspython" version = "2.8.0" @@ -588,21 +561,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d5/1f/5f4a3cd9e4440e9d9bc78ad0a91a1c8d46b4d429d5239ebe6793c9fe5c41/fsspec-2026.3.0-py3-none-any.whl", hash = "sha256:d2ceafaad1b3457968ed14efa28798162f1638dbb5d2a6868a2db002a5ee39a4", size = 202595, upload-time = "2026-03-27T19:11:13.595Z" }, ] -[[package]] -name = "getpass4" -version = "0.0.14.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "caugetch" }, - { name = "clipboard" }, - { name = "colorama" }, - { name = "pyperclip" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a2/f9/312f84afc384f693d02eb4ff7306a7268577a8b808aa08f0124c9abba683/getpass4-0.0.14.1.tar.gz", hash = "sha256:80aa4e3a665f2eccc6cda3ee22125eeb5c6338e91c40c4fd010b3c94c7aa4d3a", size = 5078, upload-time = "2021-11-28T17:08:47.276Z" } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/0f/d3/ea114aba31f76418b2162e811793cde2e822c9d9ea8ca98d67f9e1f1bde6/getpass4-0.0.14.1-py3-none-any.whl", hash = "sha256:6642c11fb99db1bec90b963e863ec71cdb0b8888000f5089c6377bfbf833f8a9", size = 8683, upload-time = "2021-11-28T17:08:45.468Z" }, -] - [[package]] name = "google-api-core" version = "2.30.3" @@ -638,19 +596,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/73/76/d241a5c927433420507215df6cac1b1fa4ac0ba7a794df42a84326c68da8/google_auth-2.49.2-py3-none-any.whl", hash = "sha256:c2720924dfc82dedb962c9f52cabb2ab16714fd0a6a707e40561d217574ed6d5", size = 240638, upload-time = "2026-04-10T00:41:14.501Z" }, ] -[[package]] -name = "google-cloud-core" -version = "2.5.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core" }, - { name = "google-auth" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/dc/24/6ca08b0a03c7b0c620427503ab00353a4ae806b848b93bcea18b6b76fde6/google_cloud_core-2.5.1.tar.gz", hash = "sha256:3dc94bdec9d05a31d9f355045ed0f369fbc0d8c665076c734f065d729800f811", size = 36078, upload-time = "2026-03-30T22:50:08.057Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/73/d9/5bb050cb32826466aa9b25f79e2ca2879fe66cb76782d4ed798dd7506151/google_cloud_core-2.5.1-py3-none-any.whl", hash = "sha256:ea62cdf502c20e3e14be8a32c05ed02113d7bef454e40ff3fab6fe1ec9f1f4e7", size = 29452, upload-time = "2026-03-30T22:48:31.567Z" }, -] - [[package]] name = "google-cloud-monitoring" version = "2.30.0" @@ -667,23 +612,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ad/c8/666c21c470b9d6fd62ac9ee74dc265419975228f9b16f8ad72ec22e8d98b/google_cloud_monitoring-2.30.0-py3-none-any.whl", hash = "sha256:2729f3b88a4798b7757b1d9d31b6cb562bb3544e8173765e4e5cd44d8685b1ed", size = 391367, upload-time = "2026-03-26T22:15:04.088Z" }, ] -[[package]] -name = "google-cloud-storage" -version = "3.10.1" -source = { registry = "https://pypi.org/simple" 
} -dependencies = [ - { name = "google-api-core" }, - { name = "google-auth" }, - { name = "google-cloud-core" }, - { name = "google-crc32c" }, - { name = "google-resumable-media" }, - { name = "requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/4c/47/205eb8e9a1739b5345843e5a425775cbdc472cc38e7eda082ba5b8d02450/google_cloud_storage-3.10.1.tar.gz", hash = "sha256:97db9aa4460727982040edd2bd13ff3d5e2260b5331ad22895802da1fc2a5286", size = 17309950, upload-time = "2026-03-23T09:35:23.409Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ad/ff/ca9ab2417fa913d75aae38bf40bf856bb2749a604b2e0f701b37cfcd23cc/google_cloud_storage-3.10.1-py3-none-any.whl", hash = "sha256:a72f656759b7b99bda700f901adcb3425a828d4a29f911bc26b3ea79c5b1217f", size = 324453, upload-time = "2026-03-23T09:35:21.368Z" }, -] - [[package]] name = "google-cloud-trace" version = "1.19.0" @@ -700,31 +628,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/91/0090acafa7d2caf1bf0d7222d42935e118164a539f9f9a00a814afa63fa1/google_cloud_trace-1.19.0-py3-none-any.whl", hash = "sha256:59604c4c775c40af31b367df6bada0af34518cc35ac8cfedecd43898a120c51d", size = 108454, upload-time = "2026-03-26T22:14:32.631Z" }, ] -[[package]] -name = "google-crc32c" -version = "1.8.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/03/41/4b9c02f99e4c5fb477122cd5437403b552873f014616ac1d19ac8221a58d/google_crc32c-1.8.0.tar.gz", hash = "sha256:a428e25fb7691024de47fecfbff7ff957214da51eddded0da0ae0e0f03a2cf79", size = 14192, upload-time = "2025-12-16T00:35:25.142Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d1/db/000f15b41724589b0e7bc24bc7a8967898d8d3bc8caf64c513d91ef1f6c0/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:3ebb04528e83b2634857f43f9bb8ef5b2bbe7f10f140daeb01b58f972d04736b", size = 31297, upload-time = "2025-12-16T00:23:20.709Z" }, - { url = 
"https://files.pythonhosted.org/packages/d7/0d/8ebed0c39c53a7e838e2a486da8abb0e52de135f1b376ae2f0b160eb4c1a/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:450dc98429d3e33ed2926fc99ee81001928d63460f8538f21a5d6060912a8e27", size = 30867, upload-time = "2025-12-16T00:43:14.628Z" }, - { url = "https://files.pythonhosted.org/packages/ce/42/b468aec74a0354b34c8cbf748db20d6e350a68a2b0912e128cabee49806c/google_crc32c-1.8.0-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3b9776774b24ba76831609ffbabce8cdf6fa2bd5e9df37b594221c7e333a81fa", size = 33344, upload-time = "2025-12-16T00:40:24.742Z" }, - { url = "https://files.pythonhosted.org/packages/1c/e8/b33784d6fc77fb5062a8a7854e43e1e618b87d5ddf610a88025e4de6226e/google_crc32c-1.8.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:89c17d53d75562edfff86679244830599ee0a48efc216200691de8b02ab6b2b8", size = 33694, upload-time = "2025-12-16T00:40:25.505Z" }, - { url = "https://files.pythonhosted.org/packages/92/b1/d3cbd4d988afb3d8e4db94ca953df429ed6db7282ed0e700d25e6c7bfc8d/google_crc32c-1.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:57a50a9035b75643996fbf224d6661e386c7162d1dfdab9bc4ca790947d1007f", size = 34435, upload-time = "2025-12-16T00:35:22.107Z" }, -] - -[[package]] -name = "google-resumable-media" -version = "2.8.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-crc32c" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/3f/d1/b1ea14b93b6b78f57fc580125de44e9f593ab88dd2460f1a8a8d18f74754/google_resumable_media-2.8.2.tar.gz", hash = "sha256:f3354a182ebd193ae3f42e3ef95e6c9b10f128320de23ac7637236713b1acd70", size = 2164510, upload-time = "2026-03-30T23:34:25.369Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/5e/f8/50bfaf4658431ff9de45c5c3935af7ab01157a4903c603cd0eee6e78e087/google_resumable_media-2.8.2-py3-none-any.whl", hash = 
"sha256:82b6d8ccd11765268cdd2a2123f417ec806b8eef3000a9a38dfe3033da5fb220", size = 81511, upload-time = "2026-03-30T23:34:09.671Z" }, -] - [[package]] name = "googleapis-common-protos" version = "1.74.0" @@ -1048,6 +951,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/73/04df8a6fa66d43a9fd45c30f283cc4afff17da671886e451d52af60bdc7e/jsonpickle-4.1.1-py3-none-any.whl", hash = "sha256:bb141da6057898aa2438ff268362b126826c812a1721e31cf08a6e142910dc91", size = 47125, upload-time = "2025-06-02T20:36:08.647Z" }, ] +[[package]] +name = "jsonschema" +version = "4.26.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "jsonschema-specifications" }, + { name = "referencing" }, + { name = "rpds-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b3/fc/e067678238fa451312d4c62bf6e6cf5ec56375422aee02f9cb5f909b3047/jsonschema-4.26.0.tar.gz", hash = "sha256:0c26707e2efad8aa1bfc5b7ce170f3fccc2e4918ff85989ba9ffa9facb2be326", size = 366583, upload-time = "2026-01-07T13:41:07.246Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/69/90/f63fb5873511e014207a475e2bb4e8b2e570d655b00ac19a9a0ca0a385ee/jsonschema-4.26.0-py3-none-any.whl", hash = "sha256:d489f15263b8d200f8387e64b4c3a75f06629559fb73deb8fdfb525f2dab50ce", size = 90630, upload-time = "2026-01-07T13:41:05.306Z" }, +] + +[[package]] +name = "jsonschema-specifications" +version = "2025.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "referencing" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/19/74/a633ee74eb36c44aa6d1095e7cc5569bebf04342ee146178e2d36600708b/jsonschema_specifications-2025.9.1.tar.gz", hash = "sha256:b540987f239e745613c7a9176f3edb72b832a4ac465cf02712288397832b5e8d", size = 32855, upload-time = "2025-09-08T01:34:59.186Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl", hash = "sha256:98802fee3a11ee76ecaca44429fda8a41bff98b00a0f2838151b113f210cc6fe", size = 18437, upload-time = "2025-09-08T01:34:57.871Z" }, +] + [[package]] name = "logfire" version = "4.6.0" @@ -1681,19 +1611,22 @@ wheels = [ [[package]] name = "policyengine" -version = "0.13.0" +version = "4.10.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "diskcache" }, - { name = "getpass4" }, - { name = "google-cloud-storage" }, + { name = "h5py" }, + { name = "jsonschema" }, { name = "microdf-python" }, - { name = "policyengine-core" }, - { name = "policyengine-uk" }, - { name = "policyengine-us" }, + { name = "packaging" }, + { name = "pandas" }, + { name = "psutil" }, { name = "pydantic" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/10/27/59ca969ab71647d526f6a7553f93cb61a1853d3f4fcc88552f08292be8a2/policyengine-4.10.0.tar.gz", hash = "sha256:68f634d107bd3ac81427364b03203a7d80407599cae9d13ff44231001436daa6", size = 571499, upload-time = "2026-05-21T19:04:09.051Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ee/d2/b4d11bd59e87da0376255779da7c0fc8afc1141cb33e8891a1fa5ce2c7e5/policyengine-4.10.0-py3-none-any.whl", hash = "sha256:db7454a3bf9cbc791ed8b8ccb7d9d5dcb8f4a08f93bb6c2fb3cf920d1dcce7c2", size = 189098, upload-time = "2026-05-21T19:04:07.296Z" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a0/f3/eeea7dab690e46cd91533691eef41097f1c2e9eb729d4f70408b865c750e/policyengine-0.13.0.tar.gz", hash = "sha256:15cf9f0ff0801c8cf12fbaeabe5e03e3cc2822cd3436b08553cf2ef0e00673ba", size = 230501, upload-time = "2026-04-08T13:41:59.62Z" } [[package]] name = "policyengine-core" @@ -1799,8 +1732,8 @@ requires-dist = [ { name = "openapi-python-client", marker = "extra == 'build'", specifier = ">=0.21.6" }, { name = 
"opentelemetry-instrumentation-fastapi", specifier = ">=0.51b0,<0.52" }, { name = "opentelemetry-instrumentation-sqlalchemy", specifier = ">=0.51b0,<0.52" }, - { name = "policyengine", specifier = "==0.13.0" }, - { name = "policyengine-core", specifier = ">=3.23.5" }, + { name = "policyengine", specifier = "==4.10.0" }, + { name = "policyengine-core", specifier = "==3.26.1" }, { name = "policyengine-fastapi", editable = "../../libs/policyengine-fastapi" }, { name = "policyengine-uk", specifier = "==2.88.20" }, { name = "policyengine-us", specifier = "==1.702.0" }, @@ -2072,15 +2005,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/7a/8dd906bd22e79e47397a61742927f6747fe93242ef86645ee9092e610244/pyjwt-2.12.1-py3-none-any.whl", hash = "sha256:28ca37c070cad8ba8cd9790cd940535d40274d22f80ab87f3ac6a713e6e8454c", size = 29726, upload-time = "2026-03-13T19:27:35.677Z" }, ] -[[package]] -name = "pyperclip" -version = "1.11.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e8/52/d87eba7cb129b81563019d1679026e7a112ef76855d6159d24754dbd2a51/pyperclip-1.11.0.tar.gz", hash = "sha256:244035963e4428530d9e3a6101a1ef97209c6825edab1567beac148ccc1db1b6", size = 12185, upload-time = "2025-09-26T14:40:37.245Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/df/80/fc9d01d5ed37ba4c42ca2b55b4339ae6e200b456be3a1aaddf4a9fa99b8c/pyperclip-1.11.0-py3-none-any.whl", hash = "sha256:299403e9ff44581cb9ba2ffeed69c7aa96a008622ad0c46cb575ca75b5b84273", size = 11063, upload-time = "2025-09-26T14:40:36.069Z" }, -] - [[package]] name = "pyright" version = "1.1.409" @@ -2221,6 +2145,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/73/e8/2bdf3ca2090f68bb3d75b44da7bbc71843b19c9f2b9cb9b0f4ab7a5a4329/pyyaml-6.0.3-cp313-cp313-win_arm64.whl", hash = "sha256:5498cd1645aa724a7c71c8f378eb29ebe23da2fc0d7a08071d89469bf1d2defb", size = 140246, upload-time = "2025-09-25T21:32:34.663Z" }, ] +[[package]] 
+name = "referencing" +version = "0.37.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "rpds-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/22/f5/df4e9027acead3ecc63e50fe1e36aca1523e1719559c499951bb4b53188f/referencing-0.37.0.tar.gz", hash = "sha256:44aefc3142c5b842538163acb373e24cce6632bd54bdb01b21ad5863489f50d8", size = 78036, upload-time = "2025-10-13T15:30:48.871Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/58/ca301544e1fa93ed4f80d724bf5b194f6e4b945841c5bfd555878eea9fcb/referencing-0.37.0-py3-none-any.whl", hash = "sha256:381329a9f99628c9069361716891d34ad94af76e461dcb0335825aecc7692231", size = 26766, upload-time = "2025-10-13T15:30:47.625Z" }, +] + [[package]] name = "requests" version = "2.33.1" @@ -2286,6 +2223,43 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/aa/41/e26a075cab83debe41a42661262f606166157df84e0e02e2d904d134c0d8/rignore-0.7.6-cp313-cp313-win_arm64.whl", hash = "sha256:e47443de9b12fe569889bdbe020abe0e0b667516ee2ab435443f6d0869bd2804", size = 656184, upload-time = "2025-11-05T21:41:27.396Z" }, ] +[[package]] +name = "rpds-py" +version = "0.30.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/af/3f2f423103f1113b36230496629986e0ef7e199d2aa8392452b484b38ced/rpds_py-0.30.0.tar.gz", hash = "sha256:dd8ff7cf90014af0c0f787eea34794ebf6415242ee1d6fa91eaba725cc441e84", size = 69469, upload-time = "2025-11-30T20:24:38.837Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ed/dc/d61221eb88ff410de3c49143407f6f3147acf2538c86f2ab7ce65ae7d5f9/rpds_py-0.30.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:f83424d738204d9770830d35290ff3273fbb02b41f919870479fab14b9d303b2", size = 374887, upload-time = "2025-11-30T20:22:41.812Z" }, + { url = 
"https://files.pythonhosted.org/packages/fd/32/55fb50ae104061dbc564ef15cc43c013dc4a9f4527a1f4d99baddf56fe5f/rpds_py-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e7536cd91353c5273434b4e003cbda89034d67e7710eab8761fd918ec6c69cf8", size = 358904, upload-time = "2025-11-30T20:22:43.479Z" }, + { url = "https://files.pythonhosted.org/packages/58/70/faed8186300e3b9bdd138d0273109784eea2396c68458ed580f885dfe7ad/rpds_py-0.30.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2771c6c15973347f50fece41fc447c054b7ac2ae0502388ce3b6738cd366e3d4", size = 389945, upload-time = "2025-11-30T20:22:44.819Z" }, + { url = "https://files.pythonhosted.org/packages/bd/a8/073cac3ed2c6387df38f71296d002ab43496a96b92c823e76f46b8af0543/rpds_py-0.30.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0a59119fc6e3f460315fe9d08149f8102aa322299deaa5cab5b40092345c2136", size = 407783, upload-time = "2025-11-30T20:22:46.103Z" }, + { url = "https://files.pythonhosted.org/packages/77/57/5999eb8c58671f1c11eba084115e77a8899d6e694d2a18f69f0ba471ec8b/rpds_py-0.30.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:76fec018282b4ead0364022e3c54b60bf368b9d926877957a8624b58419169b7", size = 515021, upload-time = "2025-11-30T20:22:47.458Z" }, + { url = "https://files.pythonhosted.org/packages/e0/af/5ab4833eadc36c0a8ed2bc5c0de0493c04f6c06de223170bd0798ff98ced/rpds_py-0.30.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:692bef75a5525db97318e8cd061542b5a79812d711ea03dbc1f6f8dbb0c5f0d2", size = 414589, upload-time = "2025-11-30T20:22:48.872Z" }, + { url = "https://files.pythonhosted.org/packages/b7/de/f7192e12b21b9e9a68a6d0f249b4af3fdcdff8418be0767a627564afa1f1/rpds_py-0.30.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9027da1ce107104c50c81383cae773ef5c24d296dd11c99e2629dbd7967a20c6", size = 394025, upload-time = "2025-11-30T20:22:50.196Z" }, + { url = 
"https://files.pythonhosted.org/packages/91/c4/fc70cd0249496493500e7cc2de87504f5aa6509de1e88623431fec76d4b6/rpds_py-0.30.0-cp313-cp313-manylinux_2_31_riscv64.whl", hash = "sha256:9cf69cdda1f5968a30a359aba2f7f9aa648a9ce4b580d6826437f2b291cfc86e", size = 408895, upload-time = "2025-11-30T20:22:51.87Z" }, + { url = "https://files.pythonhosted.org/packages/58/95/d9275b05ab96556fefff73a385813eb66032e4c99f411d0795372d9abcea/rpds_py-0.30.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a4796a717bf12b9da9d3ad002519a86063dcac8988b030e405704ef7d74d2d9d", size = 422799, upload-time = "2025-11-30T20:22:53.341Z" }, + { url = "https://files.pythonhosted.org/packages/06/c1/3088fc04b6624eb12a57eb814f0d4997a44b0d208d6cace713033ff1a6ba/rpds_py-0.30.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5d4c2aa7c50ad4728a094ebd5eb46c452e9cb7edbfdb18f9e1221f597a73e1e7", size = 572731, upload-time = "2025-11-30T20:22:54.778Z" }, + { url = "https://files.pythonhosted.org/packages/d8/42/c612a833183b39774e8ac8fecae81263a68b9583ee343db33ab571a7ce55/rpds_py-0.30.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ba81a9203d07805435eb06f536d95a266c21e5b2dfbf6517748ca40c98d19e31", size = 599027, upload-time = "2025-11-30T20:22:56.212Z" }, + { url = "https://files.pythonhosted.org/packages/5f/60/525a50f45b01d70005403ae0e25f43c0384369ad24ffe46e8d9068b50086/rpds_py-0.30.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:945dccface01af02675628334f7cf49c2af4c1c904748efc5cf7bbdf0b579f95", size = 563020, upload-time = "2025-11-30T20:22:58.2Z" }, + { url = "https://files.pythonhosted.org/packages/0b/5d/47c4655e9bcd5ca907148535c10e7d489044243cc9941c16ed7cd53be91d/rpds_py-0.30.0-cp313-cp313-win32.whl", hash = "sha256:b40fb160a2db369a194cb27943582b38f79fc4887291417685f3ad693c5a1d5d", size = 223139, upload-time = "2025-11-30T20:23:00.209Z" }, + { url = 
"https://files.pythonhosted.org/packages/f2/e1/485132437d20aa4d3e1d8b3fb5a5e65aa8139f1e097080c2a8443201742c/rpds_py-0.30.0-cp313-cp313-win_amd64.whl", hash = "sha256:806f36b1b605e2d6a72716f321f20036b9489d29c51c91f4dd29a3e3afb73b15", size = 240224, upload-time = "2025-11-30T20:23:02.008Z" }, + { url = "https://files.pythonhosted.org/packages/24/95/ffd128ed1146a153d928617b0ef673960130be0009c77d8fbf0abe306713/rpds_py-0.30.0-cp313-cp313-win_arm64.whl", hash = "sha256:d96c2086587c7c30d44f31f42eae4eac89b60dabbac18c7669be3700f13c3ce1", size = 230645, upload-time = "2025-11-30T20:23:03.43Z" }, + { url = "https://files.pythonhosted.org/packages/ff/1b/b10de890a0def2a319a2626334a7f0ae388215eb60914dbac8a3bae54435/rpds_py-0.30.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:eb0b93f2e5c2189ee831ee43f156ed34e2a89a78a66b98cadad955972548be5a", size = 364443, upload-time = "2025-11-30T20:23:04.878Z" }, + { url = "https://files.pythonhosted.org/packages/0d/bf/27e39f5971dc4f305a4fb9c672ca06f290f7c4e261c568f3dea16a410d47/rpds_py-0.30.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:922e10f31f303c7c920da8981051ff6d8c1a56207dbdf330d9047f6d30b70e5e", size = 353375, upload-time = "2025-11-30T20:23:06.342Z" }, + { url = "https://files.pythonhosted.org/packages/40/58/442ada3bba6e8e6615fc00483135c14a7538d2ffac30e2d933ccf6852232/rpds_py-0.30.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cdc62c8286ba9bf7f47befdcea13ea0e26bf294bda99758fd90535cbaf408000", size = 383850, upload-time = "2025-11-30T20:23:07.825Z" }, + { url = "https://files.pythonhosted.org/packages/14/14/f59b0127409a33c6ef6f5c1ebd5ad8e32d7861c9c7adfa9a624fc3889f6c/rpds_py-0.30.0-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:47f9a91efc418b54fb8190a6b4aa7813a23fb79c51f4bb84e418f5476c38b8db", size = 392812, upload-time = "2025-11-30T20:23:09.228Z" }, + { url = 
"https://files.pythonhosted.org/packages/b3/66/e0be3e162ac299b3a22527e8913767d869e6cc75c46bd844aa43fb81ab62/rpds_py-0.30.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1f3587eb9b17f3789ad50824084fa6f81921bbf9a795826570bda82cb3ed91f2", size = 517841, upload-time = "2025-11-30T20:23:11.186Z" }, + { url = "https://files.pythonhosted.org/packages/3d/55/fa3b9cf31d0c963ecf1ba777f7cf4b2a2c976795ac430d24a1f43d25a6ba/rpds_py-0.30.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39c02563fc592411c2c61d26b6c5fe1e51eaa44a75aa2c8735ca88b0d9599daa", size = 408149, upload-time = "2025-11-30T20:23:12.864Z" }, + { url = "https://files.pythonhosted.org/packages/60/ca/780cf3b1a32b18c0f05c441958d3758f02544f1d613abf9488cd78876378/rpds_py-0.30.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51a1234d8febafdfd33a42d97da7a43f5dcb120c1060e352a3fbc0c6d36e2083", size = 383843, upload-time = "2025-11-30T20:23:14.638Z" }, + { url = "https://files.pythonhosted.org/packages/82/86/d5f2e04f2aa6247c613da0c1dd87fcd08fa17107e858193566048a1e2f0a/rpds_py-0.30.0-cp313-cp313t-manylinux_2_31_riscv64.whl", hash = "sha256:eb2c4071ab598733724c08221091e8d80e89064cd472819285a9ab0f24bcedb9", size = 396507, upload-time = "2025-11-30T20:23:16.105Z" }, + { url = "https://files.pythonhosted.org/packages/4b/9a/453255d2f769fe44e07ea9785c8347edaf867f7026872e76c1ad9f7bed92/rpds_py-0.30.0-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6bdfdb946967d816e6adf9a3d8201bfad269c67efe6cefd7093ef959683c8de0", size = 414949, upload-time = "2025-11-30T20:23:17.539Z" }, + { url = "https://files.pythonhosted.org/packages/a3/31/622a86cdc0c45d6df0e9ccb6becdba5074735e7033c20e401a6d9d0e2ca0/rpds_py-0.30.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c77afbd5f5250bf27bf516c7c4a016813eb2d3e116139aed0096940c5982da94", size = 565790, upload-time = "2025-11-30T20:23:19.029Z" }, + { url = 
"https://files.pythonhosted.org/packages/1c/5d/15bbf0fb4a3f58a3b1c67855ec1efcc4ceaef4e86644665fff03e1b66d8d/rpds_py-0.30.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:61046904275472a76c8c90c9ccee9013d70a6d0f73eecefd38c1ae7c39045a08", size = 590217, upload-time = "2025-11-30T20:23:20.885Z" }, + { url = "https://files.pythonhosted.org/packages/6d/61/21b8c41f68e60c8cc3b2e25644f0e3681926020f11d06ab0b78e3c6bbff1/rpds_py-0.30.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4c5f36a861bc4b7da6516dbdf302c55313afa09b81931e8280361a4f6c9a2d27", size = 555806, upload-time = "2025-11-30T20:23:22.488Z" }, + { url = "https://files.pythonhosted.org/packages/f9/39/7e067bb06c31de48de3eb200f9fc7c58982a4d3db44b07e73963e10d3be9/rpds_py-0.30.0-cp313-cp313t-win32.whl", hash = "sha256:3d4a69de7a3e50ffc214ae16d79d8fbb0922972da0356dcf4d0fdca2878559c6", size = 211341, upload-time = "2025-11-30T20:23:24.449Z" }, + { url = "https://files.pythonhosted.org/packages/0a/4d/222ef0b46443cf4cf46764d9c630f3fe4abaa7245be9417e56e9f52b8f65/rpds_py-0.30.0-cp313-cp313t-win_amd64.whl", hash = "sha256:f14fc5df50a716f7ece6a80b6c78bb35ea2ca47c499e422aa4463455dd96d56d", size = 225768, upload-time = "2025-11-30T20:23:25.908Z" }, +] + [[package]] name = "ruamel-yaml" version = "0.19.1" From 3b145e950a5ab3c37f9b0517c48ec150534b4887 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 21 May 2026 21:51:35 +0200 Subject: [PATCH 04/23] docs: add API v2 AI guidance harness --- .claude/skills/github-prs.md | 26 ++++++++++++++ .claude/skills/project-setup/SKILL.md | 50 +++++++++++++++++++++++++++ .gitignore | 3 +- AGENTS.md | 40 +++++++++++++++++++++ CLAUDE.md | 37 ++++++++++++++++++++ README.md | 11 +++--- 6 files changed, 161 insertions(+), 6 deletions(-) create mode 100644 .claude/skills/github-prs.md create mode 100644 .claude/skills/project-setup/SKILL.md create mode 100644 AGENTS.md create mode 100644 CLAUDE.md diff --git a/.claude/skills/github-prs.md b/.claude/skills/github-prs.md new file 
mode 100644 index 000000000..94e262b72 --- /dev/null +++ b/.claude/skills/github-prs.md @@ -0,0 +1,26 @@ +# GitHub PR Workflow + +Use this guidance before opening, replacing, or marking ready any pull request. + +## Required Flow + +1. Open a GitHub issue for the work unless the user explicitly says not to. +2. Create or update a feature branch based on `origin/main`. +3. Run formatting and the most relevant tests for the changed surface. +4. Push the branch. +5. Open a same-repository draft PR. +6. Put `Fixes #ISSUE_NUMBER` as the first line of the PR description. +7. Add `Summary` and `Testing` sections below the `Fixes #...` line. + +## Suggested Commands + +```bash +git fetch origin main +git rebase origin/main +git push -u origin "$(git branch --show-current)" +gh issue create --repo PolicyEngine/policyengine-api-v2 +gh pr create --draft --repo PolicyEngine/policyengine-api-v2 \ + --head "$(git branch --show-current)" --base main +``` + +If a check was not run, note that explicitly in the PR body. diff --git a/.claude/skills/project-setup/SKILL.md b/.claude/skills/project-setup/SKILL.md new file mode 100644 index 000000000..f06dbbc2e --- /dev/null +++ b/.claude/skills/project-setup/SKILL.md @@ -0,0 +1,50 @@ +--- +name: Project Setup +description: > + Use when the user asks about local development setup, installing + dependencies, running tests, generating clients, deploying, or understanding + the PolicyEngine API v2 repository layout. +version: 0.1.0 +--- + +# Project Setup + +PolicyEngine API v2 is a monorepo with service projects, shared libraries, and +deployment configuration. 
+ +## Common Commands + +```bash +make setup +make up +make logs +make down +make format +make check +make test +make test-complete +./scripts/generate-clients.sh +``` + +## Service-Scoped Simulation Commands + +```bash +cd projects/policyengine-api-simulation +uv sync --extra test +uv run pytest tests/ -v +``` + +## Project Layout + +| Path | Purpose | +| ---- | ------- | +| `projects/policyengine-api-simulation/` | Simulation API gateway and Modal worker | +| `projects/policyengine-apis-integ/` | Generated-client integration tests | +| `libs/policyengine-fastapi/` | Shared FastAPI utilities | +| `deployment/` | Docker Compose and Terraform deployment configuration | +| `.github/workflows/pr.yml` | Pull request checks | +| `.github/workflows/modal-deploy.yml` | Main-branch Modal deployment | + +## PR Preparation + +Before opening a PR, read `.claude/skills/github-prs.md` and `AGENTS.md`. diff --git a/.gitignore b/.gitignore index c9f5d6c52..5740d9515 100644 --- a/.gitignore +++ b/.gitignore @@ -11,7 +11,6 @@ artifacts/ .bootstrap_settings apply.tfvars backend.tfvars -CLAUDE.md .vscode deployment/terraform/*/auto.tfvars -.claude-plan \ No newline at end of file +.claude-plan diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..5f61cc962 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,40 @@ +# PolicyEngine API v2 Agent Guide + +This repository is the PolicyEngine API v2 monorepo. It contains service +projects under `projects/`, shared libraries under `libs/`, deployment +configuration under `deployment/`, and GitHub Actions under `.github/`. + +## Development + +- Use Python 3.13 and `uv`. +- Prefer repo Makefile targets when they match the task: + - `make format` formats Python source with Ruff. + - `make check` runs Ruff and Pyright over source directories. + - `make test` runs service unit tests in Docker. + - `make test-complete` runs unit and local integration tests. 
+- For service-scoped work, run focused `uv run pytest ...` commands from the + relevant project directory when that is faster and sufficient. +- Regenerate generated clients with `./scripts/generate-clients.sh` when API + schemas change. + +## Pull Requests + +- Do not commit directly to `main`. +- For non-trivial work, open a GitHub issue before opening a PR. +- Open same-repository draft PRs by default. +- The first line of every draft PR description must be + `Fixes #ISSUE_NUMBER`. +- Include concise `Summary` and `Testing` sections after the `Fixes #...` + line. +- Before opening a PR, run formatting and the most relevant tests for the + changed surface. If you cannot run expected checks, say so in the PR body. +- Use `gh` for GitHub operations when possible so repository Actions and + permissions behave consistently. + +## Repository Notes + +- The simulation service lives in `projects/policyengine-api-simulation`. +- API integration tests live in `projects/policyengine-apis-integ`. +- PR CI runs simulation unit tests, Ruff format checks, Docker build, and local + integration tests. +- There is currently no repository-wide changelog fragment requirement. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..1363e2b5f --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,37 @@ +# PolicyEngine API v2 + +FastAPI and Modal-based API infrastructure for PolicyEngine services. + +Claude and other AI coding agents should follow the repository guidance in +`AGENTS.md`. In particular: + +- open an issue before non-trivial PRs, +- open same-repository draft PRs by default, +- start each draft PR description with `Fixes #ISSUE_NUMBER`, +- run formatting and relevant tests before opening the PR. 
+ +## Common Commands + +```bash +make format +make check +make test +make test-complete +./scripts/generate-clients.sh +``` + +For simulation-service-only work: + +```bash +cd projects/policyengine-api-simulation +uv sync --extra test +uv run pytest tests/ -v +``` + +## Key Paths + +- `projects/policyengine-api-simulation/` — simulation gateway and Modal worker +- `projects/policyengine-apis-integ/` — generated-client integration tests +- `libs/policyengine-fastapi/` — shared FastAPI utilities +- `.github/workflows/pr.yml` — PR checks +- `.github/workflows/modal-deploy.yml` — main-branch Modal deployment diff --git a/README.md b/README.md index 6f6f75ce1..a6f952f85 100644 --- a/README.md +++ b/README.md @@ -170,7 +170,10 @@ Configure GitHub environments with these variables: ## Contributing 1. Create a feature branch -2. Make changes and test locally -3. Ensure `make test-complete` passes -4. Open a PR with a clear description -5. Wait for CI checks to pass \ No newline at end of file +2. Open a GitHub issue for non-trivial work +3. Make changes and test locally +4. Ensure `make test-complete` passes when feasible +5. Open a same-repository draft PR with `Fixes #ISSUE_NUMBER` as the first line + of the description +6. Include a clear summary and testing notes +7. 
Wait for CI checks to pass From 7fdecced4bff98453ef6d458a8c43bff831492ec Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 21 May 2026 22:08:58 +0200 Subject: [PATCH 05/23] docs: switch to model-agnostic AI harness --- .claude/skills/github-prs.md | 26 ----------- .claude/skills/project-setup/SKILL.md | 50 -------------------- .github/CONTRIBUTING.md | 66 +++++++++++++++++++++++++++ .github/copilot-instructions.md | 11 +++++ .github/pull_request_template.md | 14 ++++++ AGENTS.md | 63 ++++++++++++++++--------- CLAUDE.md | 46 ++++++++----------- Makefile | 21 ++++++++- README.md | 4 ++ docs/engineering/skills/README.md | 15 ++++++ docs/engineering/skills/github-prs.md | 39 ++++++++++++++++ docs/engineering/skills/testing.md | 53 +++++++++++++++++++++ 12 files changed, 281 insertions(+), 127 deletions(-) delete mode 100644 .claude/skills/github-prs.md delete mode 100644 .claude/skills/project-setup/SKILL.md create mode 100644 .github/CONTRIBUTING.md create mode 100644 .github/copilot-instructions.md create mode 100644 .github/pull_request_template.md create mode 100644 docs/engineering/skills/README.md create mode 100644 docs/engineering/skills/github-prs.md create mode 100644 docs/engineering/skills/testing.md diff --git a/.claude/skills/github-prs.md b/.claude/skills/github-prs.md deleted file mode 100644 index 94e262b72..000000000 --- a/.claude/skills/github-prs.md +++ /dev/null @@ -1,26 +0,0 @@ -# GitHub PR Workflow - -Use this guidance before opening, replacing, or marking ready any pull request. - -## Required Flow - -1. Open a GitHub issue for the work unless the user explicitly says not to. -2. Create or update a feature branch based on `origin/main`. -3. Run formatting and the most relevant tests for the changed surface. -4. Push the branch. -5. Open a same-repository draft PR. -6. Put `Fixes #ISSUE_NUMBER` as the first line of the PR description. -7. Add `Summary` and `Testing` sections below the `Fixes #...` line. 
- -## Suggested Commands - -```bash -git fetch origin main -git rebase origin/main -git push -u origin "$(git branch --show-current)" -gh issue create --repo PolicyEngine/policyengine-api-v2 -gh pr create --draft --repo PolicyEngine/policyengine-api-v2 \ - --head "$(git branch --show-current)" --base main -``` - -If a check was not run, note that explicitly in the PR body. diff --git a/.claude/skills/project-setup/SKILL.md b/.claude/skills/project-setup/SKILL.md deleted file mode 100644 index f06dbbc2e..000000000 --- a/.claude/skills/project-setup/SKILL.md +++ /dev/null @@ -1,50 +0,0 @@ ---- -name: Project Setup -description: > - Use when the user asks about local development setup, installing - dependencies, running tests, generating clients, deploying, or understanding - the PolicyEngine API v2 repository layout. -version: 0.1.0 ---- - -# Project Setup - -PolicyEngine API v2 is a monorepo with service projects, shared libraries, and -deployment configuration. - -## Common Commands - -```bash -make setup -make up -make logs -make down -make format -make check -make test -make test-complete -./scripts/generate-clients.sh -``` - -## Service-Scoped Simulation Commands - -```bash -cd projects/policyengine-api-simulation -uv sync --extra test -uv run pytest tests/ -v -``` - -## Project Layout - -| Path | Purpose | -| ---- | ------- | -| `projects/policyengine-api-simulation/` | Simulation API gateway and Modal worker | -| `projects/policyengine-apis-integ/` | Generated-client integration tests | -| `libs/policyengine-fastapi/` | Shared FastAPI utilities | -| `deployment/` | Docker Compose and Terraform deployment configuration | -| `.github/workflows/pr.yml` | Pull request checks | -| `.github/workflows/modal-deploy.yml` | Main-branch Modal deployment | - -## PR Preparation - -Before opening a PR, read `.claude/skills/github-prs.md` and `AGENTS.md`. 
diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md new file mode 100644 index 000000000..9f0f75ea2 --- /dev/null +++ b/.github/CONTRIBUTING.md @@ -0,0 +1,66 @@ +# Contributing to policyengine-api-v2 + +See the shared PolicyEngine contribution guide for cross-repo conventions. This +file covers policyengine-api-v2 specifics. + +## Commands + +```bash +make format # format Python source with Ruff +make check # Ruff check + Pyright over source directories +make test # service unit tests in Docker +make test-complete # unit tests + local integration tests +make push-pr-branch # push current branch to origin with tracking +./scripts/generate-clients.sh +``` + +For simulation-service-only checks: + +```bash +cd projects/policyengine-api-simulation +uv sync --extra test +uv run pytest tests/ -v +``` + +## Test Organisation + +- Service unit tests live under each service's `tests/` directory, for example + `projects/policyengine-api-simulation/tests/`. +- Generated-client integration tests live under + `projects/policyengine-apis-integ/tests/`. +- Unit tests should mock Modal, GCP, Hugging Face, and other network seams. +- Integration tests should use generated clients and clearly state any required + service, GCP, Modal, or staging dependency. + +## Opening PRs + +Always create branches on the canonical repository, not a fork. The convenience +target: + +```bash +make push-pr-branch +``` + +pushes the current branch to `origin` with the correct tracking so +`gh pr create` works. Then create and verify the PR explicitly: + +```bash +gh pr create --draft --repo PolicyEngine/policyengine-api-v2 --head "$(git branch --show-current)" --base main +gh pr view --repo PolicyEngine/policyengine-api-v2 --json isDraft,headRepositoryOwner,headRepository +``` + +Before opening the PR, open or identify a GitHub issue for the work. The first +line of the PR description must be `Fixes #ISSUE_NUMBER`. 
+ +The PR is valid only if it is a draft and the head repository is +`PolicyEngine/policyengine-api-v2`. If you cannot push to that repository, ask +for access instead of opening a fork PR. + +## Repo-Specific Anti-Patterns + +- Do not open PRs from personal forks. +- Do not add `[codex]`, `[claude]`, `[copilot]`, or other agent labels to PR + titles. +- Do not hand-edit generated clients without documenting why regeneration is + not appropriate. +- Do not skip integration/client checks when changing public API schemas. diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 000000000..928809583 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,11 @@ +# Copilot Instructions + +Follow the repository's canonical engineering skills under +`docs/engineering/skills/`. + +For tests, read `docs/engineering/skills/testing.md` before adding, moving, or +reviewing test files. Do not duplicate or override that testing guidance here. + +For pull requests, read `docs/engineering/skills/github-prs.md` before opening, +replacing, or sharing a PR. This repository only accepts same-repository PRs +from branches in `PolicyEngine/policyengine-api-v2`; never create fork PRs. diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 000000000..2b8f85442 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,14 @@ +Fixes # + +## Same-Repository Draft PR Check + +- [ ] This PR is a draft. +- [ ] This PR is from a branch in `PolicyEngine/policyengine-api-v2`, not a personal fork. + +## Summary + +- + +## Testing + +- diff --git a/AGENTS.md b/AGENTS.md index 5f61cc962..158546dab 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,35 +1,54 @@ -# PolicyEngine API v2 Agent Guide +# Agent Instructions -This repository is the PolicyEngine API v2 monorepo. 
It contains service -projects under `projects/`, shared libraries under `libs/`, deployment -configuration under `deployment/`, and GitHub Actions under `.github/`. +These instructions apply repository-wide. + +## Skills System + +Canonical AI-facing engineering skills live under `docs/engineering/skills/`. +Use those files as the source of truth across Codex, Claude, Copilot, and other +AI tools. + +Before opening, replacing, or sharing any pull request, read +`docs/engineering/skills/github-prs.md`. + +When adding, moving, or reviewing tests, read +`docs/engineering/skills/testing.md`. ## Development - Use Python 3.13 and `uv`. -- Prefer repo Makefile targets when they match the task: - - `make format` formats Python source with Ruff. - - `make check` runs Ruff and Pyright over source directories. - - `make test` runs service unit tests in Docker. - - `make test-complete` runs unit and local integration tests. +- Prefer Makefile targets when they match the task. - For service-scoped work, run focused `uv run pytest ...` commands from the relevant project directory when that is faster and sufficient. - Regenerate generated clients with `./scripts/generate-clients.sh` when API schemas change. -## Pull Requests - -- Do not commit directly to `main`. -- For non-trivial work, open a GitHub issue before opening a PR. -- Open same-repository draft PRs by default. -- The first line of every draft PR description must be - `Fixes #ISSUE_NUMBER`. -- Include concise `Summary` and `Testing` sections after the `Fixes #...` - line. -- Before opening a PR, run formatting and the most relevant tests for the - changed surface. If you cannot run expected checks, say so in the PR body. -- Use `gh` for GitHub operations when possible so repository Actions and - permissions behave consistently. +## GitHub PRs + +Never open `policyengine-api-v2` PRs from forks. CI and deployment checks are +designed for same-repository branches. 
+ +Before creating or sharing any PR, all developers and agents must: + +1. Confirm the canonical repository is reachable: + `gh repo view PolicyEngine/policyengine-api-v2 --json nameWithOwner`. +2. Open a GitHub issue for the work, or verify that an appropriate issue + already exists. +3. Put `Fixes #ISSUE_NUMBER` as the first line of the PR description, using the + issue number from the issue created or found in the previous step. +4. Push the branch to the canonical repository, for example: + `make push-pr-branch`. +5. Create the PR as a draft from that same repository, for example: + `gh pr create --draft --repo PolicyEngine/policyengine-api-v2 --head "$(git branch --show-current)" --base main`. +6. Verify the PR is draft and the head repository is canonical before reporting + it: + `gh pr view --repo PolicyEngine/policyengine-api-v2 --json isDraft,headRepositoryOwner,headRepository`. + +The PR is valid only if `isDraft` is `true` and the head repository is +`PolicyEngine/policyengine-api-v2`. If you cannot push to the canonical +repository, stop and ask for access. Do not create a fork PR as a fallback. If +you accidentally create one, immediately close it and replace it with a +same-repository draft PR. ## Repository Notes diff --git a/CLAUDE.md b/CLAUDE.md index 1363e2b5f..8a6908e88 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,37 +1,27 @@ -# PolicyEngine API v2 +# Claude Instructions -FastAPI and Modal-based API infrastructure for PolicyEngine services. +These instructions apply repository-wide. -Claude and other AI coding agents should follow the repository guidance in -`AGENTS.md`. In particular: +## Canonical Guidance -- open an issue before non-trivial PRs, -- open same-repository draft PRs by default, -- start each draft PR description with `Fixes #ISSUE_NUMBER`, -- run formatting and relevant tests before opening the PR. +Repository-wide AI-facing engineering guidance lives in `AGENTS.md`. +Canonical skills live under `docs/engineering/skills/`. 
-## Common Commands +Use those files as the source of truth. This file is a Claude adapter and should +stay thin; do not duplicate detailed testing, CI, formatting, or architecture +rules here. -```bash -make format -make check -make test -make test-complete -./scripts/generate-clients.sh -``` +## Required Skill Lookup -For simulation-service-only work: +Before opening, replacing, or sharing a PR, read +`docs/engineering/skills/github-prs.md`. -```bash -cd projects/policyengine-api-simulation -uv sync --extra test -uv run pytest tests/ -v -``` +When adding, moving, or reviewing tests, read +`docs/engineering/skills/testing.md`. -## Key Paths +## Safety Boundaries -- `projects/policyengine-api-simulation/` — simulation gateway and Modal worker -- `projects/policyengine-apis-integ/` — generated-client integration tests -- `libs/policyengine-fastapi/` — shared FastAPI utilities -- `.github/workflows/pr.yml` — PR checks -- `.github/workflows/modal-deploy.yml` — main-branch Modal deployment +Do not commit directly to `main`. + +Do not make production deployment changes or publish clients unless the user +explicitly asks for that operation. 
diff --git a/Makefile b/Makefile index 092ec74ac..64b5e64e2 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Simplified Makefile using docker-compose -.PHONY: help dev up down build test deploy clean logs format check terraform-deploy +.PHONY: help dev up down build test deploy clean logs format check terraform-deploy push-pr-branch # Load environment variables if .env exists ifneq (,$(wildcard deployment/.env)) @@ -35,6 +35,7 @@ help: @echo " make clean - Clean up containers and volumes" @echo " make format - Format Python code with ruff" @echo " make check - Run code quality checks" + @echo " make push-pr-branch - Push current branch to origin for PR creation" # Initialize GCP (enables APIs, creates bucket, etc) init-gcp: check-deploy-env @@ -331,6 +332,24 @@ check: fi \ done +push-pr-branch: + @BRANCH=$$(git branch --show-current); \ + if [ -z "$$BRANCH" ]; then \ + echo "Unable to determine current branch"; \ + exit 1; \ + fi; \ + if [ "$$BRANCH" = "main" ]; then \ + echo "Refusing to open a PR from main"; \ + exit 1; \ + fi; \ + REMOTE_URL=$$(git remote get-url origin 2>/dev/null || true); \ + case "$$REMOTE_URL" in \ + *PolicyEngine/policyengine-api-v2* ) ;; \ + * ) echo "Missing canonical origin remote PolicyEngine/policyengine-api-v2"; exit 1 ;; \ + esac; \ + git push -u origin HEAD:$$BRANCH; \ + echo "Create the PR with: gh pr create --draft --repo PolicyEngine/policyengine-api-v2 --head $$BRANCH --base main" + # Integration tests integ-test: cd projects/policyengine-apis-integ && uv run pytest diff --git a/README.md b/README.md index a6f952f85..d82b98627 100644 --- a/README.md +++ b/README.md @@ -169,6 +169,10 @@ Configure GitHub environments with these variables: ## Contributing +See [AGENTS.md](AGENTS.md) and +[docs/engineering/skills/github-prs.md](docs/engineering/skills/github-prs.md) +for the canonical same-repository draft PR workflow. + 1. Create a feature branch 2. Open a GitHub issue for non-trivial work 3. 
Make changes and test locally diff --git a/docs/engineering/skills/README.md b/docs/engineering/skills/README.md new file mode 100644 index 000000000..7e4db0bdc --- /dev/null +++ b/docs/engineering/skills/README.md @@ -0,0 +1,15 @@ +# Engineering Skills + +This directory is the canonical source for AI-facing engineering rules. + +Tool-specific instruction files such as `AGENTS.md`, `CLAUDE.md`, and +`.github/copilot-instructions.md` should point here instead of duplicating +implementation-specific guidance. When a rule changes, update the skill here +first, then keep adapters thin. + +Current skills: + +- `github-prs.md`: same-repository draft PR workflow, issue linkage, PR head + verification, and title conventions. +- `testing.md`: test placement, dependency boundaries, and expected validation + commands. diff --git a/docs/engineering/skills/github-prs.md b/docs/engineering/skills/github-prs.md new file mode 100644 index 000000000..7dd115ac4 --- /dev/null +++ b/docs/engineering/skills/github-prs.md @@ -0,0 +1,39 @@ +# GitHub PRs + +These rules apply to every developer and AI agent opening pull requests in this +repository. + +## Same-Repository Draft PRs Only + +Open PRs from branches in `PolicyEngine/policyengine-api-v2`, not from personal +forks. CI and deployment checks are designed around same-repository branches. + +Before creating or sharing a PR: + +1. Confirm the canonical repository is reachable: + `gh repo view PolicyEngine/policyengine-api-v2 --json nameWithOwner`. +2. Open a GitHub issue for the work, or verify that an appropriate issue + already exists. +3. Put `Fixes #ISSUE_NUMBER` as the first line of the PR description, using the + issue number from the issue created or found in the previous step. +4. Run formatting and the most relevant tests for the changed surface. If an + expected check cannot be run, say so in the PR body. +5. Push the current branch to the canonical repository: + `make push-pr-branch`. +6. 
Create the PR as a draft from that same repository: + `gh pr create --draft --repo PolicyEngine/policyengine-api-v2 --head "$(git branch --show-current)" --base main`. +7. Verify the PR is draft and the head repository is canonical: + `gh pr view --repo PolicyEngine/policyengine-api-v2 --json isDraft,headRepositoryOwner,headRepository`. +8. Leave the PR as draft unless a maintainer explicitly asks for it to be + marked ready for review. + +The PR is valid only if `isDraft` is `true` and the head repository is +`PolicyEngine/policyengine-api-v2`. If you cannot push to the canonical +repository, stop and ask for access. Do not create a fork PR as a fallback. If +you accidentally create one, close it immediately and replace it with a +same-repository draft PR. + +## PR Title + +Do not add `[codex]`, `[claude]`, `[copilot]`, or other agent labels to PR +titles. Use a plain descriptive title. diff --git a/docs/engineering/skills/testing.md b/docs/engineering/skills/testing.md new file mode 100644 index 000000000..a47a4c43d --- /dev/null +++ b/docs/engineering/skills/testing.md @@ -0,0 +1,53 @@ +# Testing Skill + +Use this skill whenever adding, moving, or reviewing tests. + +## Canonical Layout + +- Service unit tests live under each service's `tests/` directory, for example + `projects/policyengine-api-simulation/tests/`. +- Generated-client integration tests live under + `projects/policyengine-apis-integ/tests/`. +- Put reusable test helpers in local fixture modules or support modules near + the tests that use them. +- Avoid importing helpers across unrelated test lanes. Move shared helpers to a + neutral support module when needed. + +## Dependency Boundaries + +- Unit tests should not require real network credentials, Modal, Hugging Face, + GCP, or deployed services. Mock those seams. +- Integration tests may require services, generated clients, or deployed + environments, but should be explicit about those requirements and skip or mark + cleanly when unavailable. 
+- When changing public API schemas, regenerate clients and run the relevant + generated-client integration tests. + +## Common Commands + +Run the narrowest meaningful checks during development, then broader checks +before opening or updating a PR when feasible: + +```bash +make format +make check +make test +make test-complete +``` + +Simulation-service focused checks: + +```bash +cd projects/policyengine-api-simulation +uv sync --extra test +uv run pytest tests/ -v +``` + +Integration checks: + +```bash +./scripts/generate-clients.sh +cd projects/policyengine-apis-integ +uv sync --extra test +uv run pytest tests/ -v -m "not requires_gcp and not beta_only" +``` From c0d57b99bf76b7f4ea7ac60bcad4eb58c0a42f5a Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 21 May 2026 22:30:43 +0200 Subject: [PATCH 06/23] refactor: schematize simulation macro outputs --- .../src/modal/simulation_macro_output.py | 107 +++++++ .../src/modal/simulation_output_adapter.py | 285 ++++++++++++------ .../tests/test_simulation_api_contracts.py | 5 + .../tests/test_simulation_output_adapter.py | 49 ++- 4 files changed, 354 insertions(+), 92 deletions(-) create mode 100644 projects/policyengine-api-simulation/src/modal/simulation_macro_output.py diff --git a/projects/policyengine-api-simulation/src/modal/simulation_macro_output.py b/projects/policyengine-api-simulation/src/modal/simulation_macro_output.py new file mode 100644 index 000000000..343d5caac --- /dev/null +++ b/projects/policyengine-api-simulation/src/modal/simulation_macro_output.py @@ -0,0 +1,107 @@ +"""Internal schemas for the simulation API single-year macro output. + +These models define the legacy dictionary contract the simulation API returns +without exposing that schema through the gateway OpenAPI surface. The gateway +still treats job results as unstructured dictionaries for older callers. 
+""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, ConfigDict + + +class MacroOutputModel(BaseModel): + """Base model for internal macro output schemas.""" + + model_config = ConfigDict(extra="forbid") + + +class BudgetaryOutput(MacroOutputModel): + tax_revenue_impact: float + state_tax_revenue_impact: float + benefit_spending_impact: float + budgetary_impact: float + households: float + baseline_net_income: float + + +class DetailedBudgetProgramOutput(MacroOutputModel): + baseline: float + reform: float + difference: float + + +class DecileOutput(MacroOutputModel): + average: dict[str, float] + relative: dict[str, float] + + +class IntraDecileOutput(MacroOutputModel): + deciles: dict[str, list[float]] + all: dict[str, float] + + +class BaselineReformValue(MacroOutputModel): + baseline: float + reform: float + + +class AgePovertyOutput(MacroOutputModel): + child: BaselineReformValue + adult: BaselineReformValue + senior: BaselineReformValue + all: BaselineReformValue + + +class GenderPovertyOutput(MacroOutputModel): + male: BaselineReformValue + female: BaselineReformValue + + +class RacePovertyOutput(MacroOutputModel): + white: BaselineReformValue + black: BaselineReformValue + hispanic: BaselineReformValue + other: BaselineReformValue + + +class PovertyOutput(MacroOutputModel): + poverty: AgePovertyOutput + deep_poverty: AgePovertyOutput + + +class PovertyByGenderOutput(MacroOutputModel): + poverty: GenderPovertyOutput + deep_poverty: GenderPovertyOutput + + +class PovertyByRaceOutput(MacroOutputModel): + poverty: RacePovertyOutput + + +class InequalityOutput(MacroOutputModel): + gini: BaselineReformValue + top_10_pct_share: BaselineReformValue + top_1_pct_share: BaselineReformValue + + +class SingleYearMacroOutput(MacroOutputModel): + model_version: str + data_version: str + budget: BudgetaryOutput + detailed_budget: dict[str, DetailedBudgetProgramOutput] + decile: DecileOutput + inequality: InequalityOutput + 
poverty: PovertyOutput + poverty_by_gender: PovertyByGenderOutput + poverty_by_race: PovertyByRaceOutput | None + intra_decile: IntraDecileOutput + wealth_decile: DecileOutput | None + intra_wealth_decile: IntraDecileOutput | None + labor_supply_response: dict[str, Any] | None + constituency_impact: list[dict[str, Any]] | None + local_authority_impact: list[dict[str, Any]] | None + congressional_district_impact: list[dict[str, Any]] | None + cliff_impact: None = None diff --git a/projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py b/projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py index 865dbc7f4..e27ded681 100644 --- a/projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py +++ b/projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py @@ -6,6 +6,21 @@ from collections.abc import Iterable, Mapping from typing import Any +from src.modal.simulation_macro_output import ( + AgePovertyOutput, + BaselineReformValue, + BudgetaryOutput, + DecileOutput, + DetailedBudgetProgramOutput, + GenderPovertyOutput, + InequalityOutput, + IntraDecileOutput, + PovertyByGenderOutput, + PovertyByRaceOutput, + PovertyOutput, + RacePovertyOutput, + SingleYearMacroOutput, +) INTRA_DECILE_COLUMNS = { "Lose more than 5%": "lose_more_than_5pct", @@ -56,26 +71,48 @@ def _output_model_dump(value: Any) -> Any: return value.model_dump(mode="json") if isinstance(value, Mapping): return dict(value) - return value + return None + + +def _records_or_none(value: Any) -> list[dict[str, Any]] | None: + records = _output_model_dump(value) + if isinstance(records, list): + return [dict(item) for item in records if isinstance(item, Mapping)] + if isinstance(value, list): + return [dict(item) for item in value if isinstance(item, Mapping)] + return None + + +def build_budgetary_output(budget: Mapping[str, Any]) -> BudgetaryOutput: + return BudgetaryOutput( + tax_revenue_impact=_number(budget.get("tax_revenue_impact")), 
+ state_tax_revenue_impact=_number(budget.get("state_tax_revenue_impact")), + benefit_spending_impact=_number(budget.get("benefit_spending_impact")), + budgetary_impact=_number(budget.get("budgetary_impact")), + households=_number(budget.get("households")), + baseline_net_income=_number(budget.get("baseline_net_income")), + ) -def _detailed_budget(collection: Any) -> dict[str, dict[str, float]]: - detailed_budget: dict[str, dict[str, float]] = {} +def build_detailed_budget_output( + collection: Any, +) -> dict[str, DetailedBudgetProgramOutput]: + detailed_budget: dict[str, DetailedBudgetProgramOutput] = {} for row in _collection_records(collection): program_name = row.get("program_name") if not program_name: continue baseline = _number(row.get("baseline_total")) reform = _number(row.get("reform_total")) - detailed_budget[str(program_name)] = { - "baseline": baseline, - "reform": reform, - "difference": _number(row.get("change"), reform - baseline), - } + detailed_budget[str(program_name)] = DetailedBudgetProgramOutput( + baseline=baseline, + reform=reform, + difference=_number(row.get("change"), reform - baseline), + ) return detailed_budget -def _decile_impact(collection: Any) -> dict[str, dict[str, float]]: +def build_decile_output(collection: Any) -> DecileOutput: average: dict[str, float] = {} relative: dict[str, float] = {} for row in sorted( @@ -88,18 +125,12 @@ def _decile_impact(collection: Any) -> dict[str, dict[str, float]]: key = str(decile) average[key] = _number(row.get("absolute_change")) relative[key] = _number(row.get("relative_change")) - return {"average": average, "relative": relative} + return DecileOutput(average=average, relative=relative) -def _empty_intra_decile() -> dict[str, Any]: - return { - "deciles": {label: [] for label in INTRA_DECILE_COLUMNS}, - "all": {label: 0.0 for label in INTRA_DECILE_COLUMNS}, - } - - -def _intra_decile_impact(collection: Any) -> dict[str, Any]: - result = _empty_intra_decile() +def 
build_intra_decile_output(collection: Any) -> IntraDecileOutput: + deciles: dict[str, list[float]] = {label: [] for label in INTRA_DECILE_COLUMNS} + all_values: dict[str, float] = {label: 0.0 for label in INTRA_DECILE_COLUMNS} rows = [ row for row in sorted( @@ -111,33 +142,28 @@ def _intra_decile_impact(collection: Any) -> dict[str, Any]: for label, column in INTRA_DECILE_COLUMNS.items(): values = [_number(row.get(column)) for row in rows] - result["deciles"][label] = values - result["all"][label] = sum(values) / len(values) if values else 0.0 - return result + deciles[label] = values + all_values[label] = sum(values) / len(values) if values else 0.0 + return IntraDecileOutput(deciles=deciles, all=all_values) -def _empty_age_poverty() -> dict[str, dict[str, float]]: - return { - "child": {"baseline": 0.0, "reform": 0.0}, - "adult": {"baseline": 0.0, "reform": 0.0}, - "senior": {"baseline": 0.0, "reform": 0.0}, - "all": {"baseline": 0.0, "reform": 0.0}, - } +def _empty_baseline_reform_value() -> dict[str, float]: + return {"baseline": 0.0, "reform": 0.0} -def _empty_gender_poverty() -> dict[str, dict[str, float]]: +def _empty_age_poverty() -> dict[str, dict[str, float]]: return { - "male": {"baseline": 0.0, "reform": 0.0}, - "female": {"baseline": 0.0, "reform": 0.0}, + "child": _empty_baseline_reform_value(), + "adult": _empty_baseline_reform_value(), + "senior": _empty_baseline_reform_value(), + "all": _empty_baseline_reform_value(), } -def _empty_race_poverty() -> dict[str, dict[str, float]]: +def _empty_gender_poverty() -> dict[str, dict[str, float]]: return { - "white": {"baseline": 0.0, "reform": 0.0}, - "black": {"baseline": 0.0, "reform": 0.0}, - "hispanic": {"baseline": 0.0, "reform": 0.0}, - "other": {"baseline": 0.0, "reform": 0.0}, + "male": _empty_baseline_reform_value(), + "female": _empty_baseline_reform_value(), } @@ -169,14 +195,41 @@ def _fill_poverty_block( output[poverty_type][group][side] = _number(row.get("rate")) -def _poverty_impact( +def 
_age_poverty_output(values: dict[str, dict[str, float]]) -> AgePovertyOutput: + return AgePovertyOutput( + child=BaselineReformValue(**values["child"]), + adult=BaselineReformValue(**values["adult"]), + senior=BaselineReformValue(**values["senior"]), + all=BaselineReformValue(**values["all"]), + ) + + +def _gender_poverty_output( + values: dict[str, dict[str, float]], +) -> GenderPovertyOutput: + return GenderPovertyOutput( + male=BaselineReformValue(**values["male"]), + female=BaselineReformValue(**values["female"]), + ) + + +def _race_poverty_output(values: dict[str, dict[str, float]]) -> RacePovertyOutput: + return RacePovertyOutput( + white=BaselineReformValue(**values["white"]), + black=BaselineReformValue(**values["black"]), + hispanic=BaselineReformValue(**values["hispanic"]), + other=BaselineReformValue(**values["other"]), + ) + + +def build_poverty_output( country: str, *, baseline: Any, reform: Any, baseline_by_age: Any, reform_by_age: Any, -) -> dict[str, dict[str, dict[str, float]]]: +) -> PovertyOutput: result = {"poverty": _empty_age_poverty(), "deep_poverty": _empty_age_poverty()} _fill_poverty_block( country=country, @@ -192,15 +245,18 @@ def _poverty_impact( reform_records=_collection_records(reform_by_age), default_group="all", ) - return result + return PovertyOutput( + poverty=_age_poverty_output(result["poverty"]), + deep_poverty=_age_poverty_output(result["deep_poverty"]), + ) -def _poverty_by_gender( +def build_poverty_by_gender_output( country: str, *, baseline_by_gender: Any, reform_by_gender: Any, -) -> dict[str, dict[str, dict[str, float]]]: +) -> PovertyByGenderOutput: result = { "poverty": _empty_gender_poverty(), "deep_poverty": _empty_gender_poverty(), @@ -212,15 +268,25 @@ def _poverty_by_gender( reform_records=_collection_records(reform_by_gender), default_group="all", ) - return result + return PovertyByGenderOutput( + poverty=_gender_poverty_output(result["poverty"]), + deep_poverty=_gender_poverty_output(result["deep_poverty"]), 
+ ) -def _poverty_by_race( +def build_poverty_by_race_output( *, baseline_by_race: Any, reform_by_race: Any, -) -> dict[str, dict[str, dict[str, float]]]: - result = {"poverty": _empty_race_poverty()} +) -> PovertyByRaceOutput: + result = { + "poverty": { + "white": _empty_baseline_reform_value(), + "black": _empty_baseline_reform_value(), + "hispanic": _empty_baseline_reform_value(), + "other": _empty_baseline_reform_value(), + } + } _fill_poverty_block( country="us", output=result, @@ -228,32 +294,37 @@ def _poverty_by_race( reform_records=_collection_records(reform_by_race), default_group="all", ) - return result + return PovertyByRaceOutput(poverty=_race_poverty_output(result["poverty"])) -def _inequality_impact(baseline: Any, reform: Any) -> dict[str, Any]: - return { - "gini": { - "baseline": _number(getattr(baseline, "gini", None)), - "reform": _number(getattr(reform, "gini", None)), - }, - "top_10_pct_share": { - "baseline": _number(getattr(baseline, "top_10_share", None)), - "reform": _number(getattr(reform, "top_10_share", None)), - }, - "top_1_pct_share": { - "baseline": _number(getattr(baseline, "top_1_share", None)), - "reform": _number(getattr(reform, "top_1_share", None)), - }, - } +def build_inequality_output(baseline: Any, reform: Any) -> InequalityOutput: + return InequalityOutput( + gini=BaselineReformValue( + baseline=_number(getattr(baseline, "gini", None)), + reform=_number(getattr(reform, "gini", None)), + ), + top_10_pct_share=BaselineReformValue( + baseline=_number(getattr(baseline, "top_10_share", None)), + reform=_number(getattr(reform, "top_10_share", None)), + ), + top_1_pct_share=BaselineReformValue( + baseline=_number(getattr(baseline, "top_1_share", None)), + reform=_number(getattr(reform, "top_1_share", None)), + ), + ) -def adapt_analysis_to_legacy_macro_output( +def build_labor_supply_response_output(analysis: Any) -> dict[str, Any] | None: + output = _output_model_dump(getattr(analysis, "labor_supply_response", None)) + return 
output if isinstance(output, dict) else None + + +def build_single_year_macro_output( *, country: str, model_version: str, data_version: str, - budget: dict[str, float], + budget: Mapping[str, Any], analysis: Any, baseline_poverty_by_age: Any = None, reform_poverty_by_age: Any = None, @@ -265,54 +336,90 @@ def adapt_analysis_to_legacy_macro_output( congressional_district_impact: Any = None, constituency_impact: Any = None, local_authority_impact: Any = None, -) -> dict[str, Any]: - """Return the legacy single-year macro result expected by API callers.""" +) -> SingleYearMacroOutput: + """Build the schema-first single-year macro output.""" country = country.lower() wealth_decile = getattr(analysis, "wealth_decile_impacts", None) intra_wealth_decile = getattr(analysis, "intra_wealth_decile_impacts", None) - return { - "model_version": model_version, - "data_version": data_version, - "budget": budget, - "detailed_budget": _detailed_budget( + return SingleYearMacroOutput( + model_version=model_version, + data_version=data_version, + budget=build_budgetary_output(budget), + detailed_budget=build_detailed_budget_output( getattr(analysis, "program_statistics", None) ), - "decile": _decile_impact(getattr(analysis, "decile_impacts", None)), - "inequality": _inequality_impact( + decile=build_decile_output(getattr(analysis, "decile_impacts", None)), + inequality=build_inequality_output( getattr(analysis, "baseline_inequality", None), getattr(analysis, "reform_inequality", None), ), - "poverty": _poverty_impact( + poverty=build_poverty_output( country, baseline=getattr(analysis, "baseline_poverty", None), reform=getattr(analysis, "reform_poverty", None), baseline_by_age=baseline_poverty_by_age, reform_by_age=reform_poverty_by_age, ), - "poverty_by_gender": _poverty_by_gender( + poverty_by_gender=build_poverty_by_gender_output( country, baseline_by_gender=baseline_poverty_by_gender, reform_by_gender=reform_poverty_by_gender, ), - "poverty_by_race": ( - _poverty_by_race( + 
poverty_by_race=( + build_poverty_by_race_output( baseline_by_race=baseline_poverty_by_race, reform_by_race=reform_poverty_by_race, ) if country == "us" else None ), - "intra_decile": _intra_decile_impact(intra_decile), - "wealth_decile": _decile_impact(wealth_decile) if country == "uk" else None, - "intra_wealth_decile": ( - _intra_decile_impact(intra_wealth_decile) if country == "uk" else None + intra_decile=build_intra_decile_output(intra_decile), + wealth_decile=build_decile_output(wealth_decile) if country == "uk" else None, + intra_wealth_decile=( + build_intra_decile_output(intra_wealth_decile) if country == "uk" else None ), - "labor_supply_response": _output_model_dump( - getattr(analysis, "labor_supply_response", None) - ), - "constituency_impact": constituency_impact, - "local_authority_impact": local_authority_impact, - "congressional_district_impact": congressional_district_impact, - "cliff_impact": None, - } + labor_supply_response=build_labor_supply_response_output(analysis), + constituency_impact=_records_or_none(constituency_impact), + local_authority_impact=_records_or_none(local_authority_impact), + congressional_district_impact=_records_or_none(congressional_district_impact), + cliff_impact=None, + ) + + +def adapt_analysis_to_legacy_macro_output( + *, + country: str, + model_version: str, + data_version: str, + budget: dict[str, float], + analysis: Any, + baseline_poverty_by_age: Any = None, + reform_poverty_by_age: Any = None, + baseline_poverty_by_gender: Any = None, + reform_poverty_by_gender: Any = None, + baseline_poverty_by_race: Any = None, + reform_poverty_by_race: Any = None, + intra_decile: Any = None, + congressional_district_impact: Any = None, + constituency_impact: Any = None, + local_authority_impact: Any = None, +) -> dict[str, Any]: + """Return the legacy single-year macro result expected by API callers.""" + return build_single_year_macro_output( + country=country, + model_version=model_version, + data_version=data_version, + 
budget=budget, + analysis=analysis, + baseline_poverty_by_age=baseline_poverty_by_age, + reform_poverty_by_age=reform_poverty_by_age, + baseline_poverty_by_gender=baseline_poverty_by_gender, + reform_poverty_by_gender=reform_poverty_by_gender, + baseline_poverty_by_race=baseline_poverty_by_race, + reform_poverty_by_race=reform_poverty_by_race, + intra_decile=intra_decile, + congressional_district_impact=congressional_district_impact, + constituency_impact=constituency_impact, + local_authority_impact=local_authority_impact, + ).model_dump(mode="json") diff --git a/projects/policyengine-api-simulation/tests/test_simulation_api_contracts.py b/projects/policyengine-api-simulation/tests/test_simulation_api_contracts.py index dd8099171..57465e610 100644 --- a/projects/policyengine-api-simulation/tests/test_simulation_api_contracts.py +++ b/projects/policyengine-api-simulation/tests/test_simulation_api_contracts.py @@ -7,6 +7,7 @@ BudgetWindowTotals, JobStatusResponse, ) +from src.modal.simulation_macro_output import SingleYearMacroOutput from fixtures.test_simulation_api_contracts import ( CURRENT_REQUIRED_BUDGET_KEYS, @@ -29,6 +30,10 @@ def test_job_status_result_preserves_current_single_year_macro_dict_contract(): ) +def test_internal_single_year_macro_schema_matches_current_public_keys(): + assert set(SingleYearMacroOutput.model_fields) == CURRENT_SINGLE_YEAR_MACRO_KEYS + + def test_openapi_keeps_job_status_result_as_unstructured_dict(): spec = create_openapi_app().openapi() schemas = spec["components"]["schemas"] diff --git a/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py b/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py index dc4f3ed54..0b71ba754 100644 --- a/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py +++ b/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py @@ -24,11 +24,21 @@ _uk_constituency_impact, _uk_local_authority_impact, ) -from 
src.modal.simulation_output_adapter import adapt_analysis_to_legacy_macro_output +from src.modal.simulation_macro_output import ( + BudgetaryOutput, + DecileOutput, + IntraDecileOutput, + PovertyOutput, + SingleYearMacroOutput, +) +from src.modal.simulation_output_adapter import ( + adapt_analysis_to_legacy_macro_output, + build_single_year_macro_output, +) -def test_adapter_returns_existing_single_year_macro_shape(): - output = adapt_analysis_to_legacy_macro_output( +def _build_schema_output() -> SingleYearMacroOutput: + return build_single_year_macro_output( country="us", model_version="1.702.0", data_version="1.115.5", @@ -44,6 +54,39 @@ def test_adapter_returns_existing_single_year_macro_shape(): congressional_district_impact=[{"district_geoid": 101}], ) + +def test_builder_returns_schema_modules_before_legacy_dict_dump(): + output = _build_schema_output() + + assert isinstance(output, SingleYearMacroOutput) + assert isinstance(output.budget, BudgetaryOutput) + assert isinstance(output.decile, DecileOutput) + assert isinstance(output.intra_decile, IntraDecileOutput) + assert isinstance(output.poverty, PovertyOutput) + assert output.wealth_decile is None + assert output.congressional_district_impact == [{"district_geoid": 101}] + + legacy_output = adapt_analysis_to_legacy_macro_output( + country="us", + model_version="1.702.0", + data_version="1.115.5", + budget=BUDGET, + analysis=fake_analysis(), + baseline_poverty_by_age=BASELINE_POVERTY_BY_AGE, + reform_poverty_by_age=REFORM_POVERTY_BY_AGE, + baseline_poverty_by_gender=BASELINE_POVERTY_BY_GENDER, + reform_poverty_by_gender=REFORM_POVERTY_BY_GENDER, + baseline_poverty_by_race=BASELINE_POVERTY_BY_RACE, + reform_poverty_by_race=REFORM_POVERTY_BY_RACE, + intra_decile=INTRA_DECILE_COLLECTION, + congressional_district_impact=[{"district_geoid": 101}], + ) + assert output.model_dump(mode="json") == legacy_output + + +def test_adapter_returns_existing_single_year_macro_shape(): + output = 
_build_schema_output().model_dump(mode="json") + assert set(output) == CURRENT_SINGLE_YEAR_MACRO_KEYS assert output["model_version"] == "1.702.0" assert output["data_version"] == "1.115.5" From 6b0d0fd88f7c04d2ce646ba56b1c3167166ff218 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 21 May 2026 23:00:02 +0200 Subject: [PATCH 07/23] fix: derive simulation versions from policyengine bundle --- .github/scripts/modal-deploy-app.sh | 1 + .github/scripts/modal-extract-versions.sh | 13 +- .github/scripts/update-country-package.sh | 52 ++++---- .github/workflows/modal-deploy.reusable.yml | 8 ++ .../fixtures/test_simulation_api_contracts.py | 2 +- .../pyproject.toml | 2 +- .../src/modal/app.py | 13 +- .../src/modal/gateway/endpoints.py | 35 ++---- .../src/modal/release_bundle.py | 114 ++++++++++++++++++ .../src/modal/simulation.py | 35 ++---- .../modal/utils/extract_bundle_versions.py | 41 +++++++ .../tests/gateway/test_endpoints.py | 70 +++++++---- .../test_country_package_update_scripts.py | 7 -- .../tests/test_modal_scripts.py | 11 +- .../test_policyengine_dependency_source.py | 24 ++++ .../tests/test_release_bundle.py | 49 ++++++++ .../tests/test_simulation_output_adapter.py | 6 +- projects/policyengine-api-simulation/uv.lock | 8 +- 18 files changed, 358 insertions(+), 133 deletions(-) create mode 100644 projects/policyengine-api-simulation/src/modal/release_bundle.py create mode 100644 projects/policyengine-api-simulation/src/modal/utils/extract_bundle_versions.py create mode 100644 projects/policyengine-api-simulation/tests/test_release_bundle.py diff --git a/.github/scripts/modal-deploy-app.sh b/.github/scripts/modal-deploy-app.sh index 48f1cc872..f8581752b 100755 --- a/.github/scripts/modal-deploy-app.sh +++ b/.github/scripts/modal-deploy-app.sh @@ -2,6 +2,7 @@ # Deploy simulation API to Modal # Usage: ./modal-deploy-app.sh # Required env vars: POLICYENGINE_US_VERSION, POLICYENGINE_UK_VERSION +# These should come from the bundled policyengine.py release 
manifest. # # Deploys two apps: # 1. policyengine-simulation-gateway - Stable gateway with fixed URL diff --git a/.github/scripts/modal-extract-versions.sh b/.github/scripts/modal-extract-versions.sh index 9d82c890f..8fd71185c 100755 --- a/.github/scripts/modal-extract-versions.sh +++ b/.github/scripts/modal-extract-versions.sh @@ -1,7 +1,9 @@ #!/bin/bash -# Extract policyengine-us and policyengine-uk versions from uv.lock +# Extract policyengine-us, policyengine-us-data, policyengine-uk, and +# policyengine-uk-data versions from the bundled policyengine.py manifests. # Usage: ./modal-extract-versions.sh -# Outputs: Sets us_version and uk_version in GITHUB_OUTPUT +# Outputs: Sets policyengine_version, us_version, us_data_version, uk_version, +# and uk_data_version in GITHUB_OUTPUT set -euo pipefail @@ -9,9 +11,4 @@ PROJECT_DIR="${1:-.}" cd "$PROJECT_DIR" -US_VERSION=$(grep -A1 'name = "policyengine-us"' uv.lock | grep version | head -1 | sed 's/.*"\(.*\)".*/\1/') -UK_VERSION=$(grep -A1 'name = "policyengine-uk"' uv.lock | grep version | head -1 | sed 's/.*"\(.*\)".*/\1/') - -echo "us_version=$US_VERSION" >> "$GITHUB_OUTPUT" -echo "uk_version=$UK_VERSION" >> "$GITHUB_OUTPUT" -echo "Deploying with policyengine-us=$US_VERSION, policyengine-uk=$UK_VERSION" +uv run python -m src.modal.utils.extract_bundle_versions diff --git a/.github/scripts/update-country-package.sh b/.github/scripts/update-country-package.sh index 780cd229f..1fd4c9d19 100644 --- a/.github/scripts/update-country-package.sh +++ b/.github/scripts/update-country-package.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash # -# Check PyPI for a newer country package, update the simulation project pins, -# and open a version-specific PR. +# Sync a country package pin to the version bundled by policyengine.py and +# open a version-specific PR when the local pin is off-bundle. 
# # Usage: # .github/scripts/update-country-package.sh policyengine-us [--dry-run] @@ -9,7 +9,7 @@ # # Optional environment: # PROJECT_DIR Project containing pyproject.toml and uv.lock. -# LATEST_OVERRIDE Version to use instead of querying PyPI, for local checks. +# LATEST_OVERRIDE Version to use instead of querying the bundle, for local checks. # DRY_RUN=1 Report planned changes without editing files or opening a PR. set -euo pipefail @@ -23,13 +23,9 @@ fi case "$PACKAGE" in policyengine-us) DISPLAY_NAME="PolicyEngine US" - CONSTANT_NAME="US_VERSION" - ENV_NAME="POLICYENGINE_US_VERSION" ;; policyengine-uk) DISPLAY_NAME="PolicyEngine UK" - CONSTANT_NAME="UK_VERSION" - ENV_NAME="POLICYENGINE_UK_VERSION" ;; *) echo "ERROR: Unsupported package '${PACKAGE}'." >&2 @@ -42,7 +38,6 @@ PROJECT_DIR="${PROJECT_DIR:-projects/policyengine-api-simulation}" PROJECT_PATH="${ROOT_DIR}/${PROJECT_DIR}" PYPROJECT="${PROJECT_PATH}/pyproject.toml" LOCKFILE="${PROJECT_PATH}/uv.lock" -MODAL_APP="${PROJECT_PATH}/src/modal/app.py" create_pr_body_file() { local changelog @@ -72,7 +67,7 @@ create_pr_body_file() { echo "$pr_body_file" } -if [[ ! -f "$PYPROJECT" || ! -f "$LOCKFILE" || ! -f "$MODAL_APP" ]]; then +if [[ ! -f "$PYPROJECT" || ! -f "$LOCKFILE" ]]; then echo "ERROR: Expected simulation project files were not found under ${PROJECT_DIR}." 
>&2 exit 1 fi @@ -94,9 +89,23 @@ PY if [[ -n "${LATEST_OVERRIDE:-}" ]]; then LATEST="$LATEST_OVERRIDE" else - LATEST=$(curl -fsSL "https://pypi.org/pypi/${PACKAGE}/json" | python3 -c 'import json, sys; print(json.load(sys.stdin)["info"]["version"])') + LATEST=$( + cd "$PROJECT_PATH" + uv run python - "$PACKAGE" <<'PY' +import sys + +from src.modal.release_bundle import get_country_release_bundle + +package = sys.argv[1] +country_by_package = { + "policyengine-us": "us", + "policyengine-uk": "uk", +} +print(get_country_release_bundle(country_by_package[package]).model_version) +PY + ) if [[ -z "$LATEST" ]]; then - echo "ERROR: Could not fetch latest version for ${PACKAGE} from PyPI." >&2 + echo "ERROR: Could not resolve bundled version for ${PACKAGE}." >&2 exit 1 fi fi @@ -107,7 +116,7 @@ if [[ -z "$LATEST" ]]; then fi echo "Current locked version: ${PACKAGE}==${CURRENT}" -echo "Latest PyPI version: ${PACKAGE}==${LATEST}" +echo "Bundled .py version: ${PACKAGE}==${LATEST}" if [[ "$CURRENT" == "$LATEST" ]]; then echo "Already up to date. Nothing to do." 
@@ -125,7 +134,6 @@ if [[ "$DRY_RUN" == "1" ]]; then echo "Dry run: would create ${BRANCH} and update:" echo " ${PROJECT_DIR}/pyproject.toml" echo " ${PROJECT_DIR}/uv.lock" - echo " ${PROJECT_DIR}/src/modal/app.py" exit 0 fi @@ -155,12 +163,11 @@ git config user.name "github-actions[bot]" git config user.email "github-actions[bot]@users.noreply.github.com" git checkout -b "$BRANCH" -python3 - "$PYPROJECT" "$MODAL_APP" "$PACKAGE" "$CURRENT" "$LATEST" "$CONSTANT_NAME" "$ENV_NAME" <<'PY' -import re +python3 - "$PYPROJECT" "$PACKAGE" "$CURRENT" "$LATEST" <<'PY' import sys from pathlib import Path -pyproject_path, modal_app_path, package, current, latest, constant, env_name = sys.argv[1:] +pyproject_path, package, current, latest = sys.argv[1:] pyproject = Path(pyproject_path) pyproject_text = pyproject.read_text(encoding="utf-8") @@ -169,15 +176,6 @@ new_pin = f'"{package}=={latest}"' if old_pin not in pyproject_text: raise SystemExit(f"Could not find {old_pin} in {pyproject}") pyproject.write_text(pyproject_text.replace(old_pin, new_pin), encoding="utf-8") - -modal_app = Path(modal_app_path) -modal_text = modal_app.read_text(encoding="utf-8") -pattern = rf'{constant} = os\.environ\.get\("{env_name}", "[^"]+"\)' -replacement = f'{constant} = os.environ.get("{env_name}", "{latest}")' -updated, count = re.subn(pattern, replacement, modal_text, count=1) -if count != 1: - raise SystemExit(f"Could not update {constant} in {modal_app}") -modal_app.write_text(updated, encoding="utf-8") PY ( @@ -185,14 +183,14 @@ PY uv lock --upgrade-package "$PACKAGE" ) -if git diff --quiet -- "$PYPROJECT" "$LOCKFILE" "$MODAL_APP"; then +if git diff --quiet -- "$PYPROJECT" "$LOCKFILE"; then echo "No changes after update. Nothing to do." 
exit 0 fi PR_BODY_FILE="$(create_pr_body_file)" -git add "$PYPROJECT" "$LOCKFILE" "$MODAL_APP" +git add "$PYPROJECT" "$LOCKFILE" git commit -m "chore(deps): update ${PACKAGE} to ${LATEST}" git push -u origin "$BRANCH" diff --git a/.github/workflows/modal-deploy.reusable.yml b/.github/workflows/modal-deploy.reusable.yml index ef2de18d0..1fa823c17 100644 --- a/.github/workflows/modal-deploy.reusable.yml +++ b/.github/workflows/modal-deploy.reusable.yml @@ -18,9 +18,15 @@ on: us_version: description: 'The deployed policyengine-us package version' value: ${{ jobs.deploy.outputs.us_version }} + us_data_version: + description: 'The bundled policyengine-us-data version' + value: ${{ jobs.deploy.outputs.us_data_version }} uk_version: description: 'The deployed policyengine-uk package version' value: ${{ jobs.deploy.outputs.uk_version }} + uk_data_version: + description: 'The bundled policyengine-uk-data version' + value: ${{ jobs.deploy.outputs.uk_data_version }} jobs: deploy: @@ -30,7 +36,9 @@ jobs: outputs: simulation_api_url: ${{ steps.get-url.outputs.simulation_api_url }} us_version: ${{ steps.versions.outputs.us_version }} + us_data_version: ${{ steps.versions.outputs.us_data_version }} uk_version: ${{ steps.versions.outputs.uk_version }} + uk_data_version: ${{ steps.versions.outputs.uk_data_version }} steps: - name: Checkout repo diff --git a/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py b/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py index 7e900e79c..ba3df0675 100644 --- a/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py +++ b/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py @@ -30,7 +30,7 @@ } CURRENT_SINGLE_YEAR_MACRO_RESULT = { - "model_version": "1.702.0", + "model_version": "1.700.0", "data_version": "1.115.5", "budget": { "budgetary_impact": 300.0, diff --git a/projects/policyengine-api-simulation/pyproject.toml 
b/projects/policyengine-api-simulation/pyproject.toml index bd3357056..2190556c8 100644 --- a/projects/policyengine-api-simulation/pyproject.toml +++ b/projects/policyengine-api-simulation/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "policyengine==4.10.0", "policyengine-core==3.26.1", "policyengine-uk==2.88.20", - "policyengine-us==1.702.0", + "policyengine-us==1.700.0", "tables>=3.10.2", "modal>=0.73.0", "logfire>=3.0.0", diff --git a/projects/policyengine-api-simulation/src/modal/app.py b/projects/policyengine-api-simulation/src/modal/app.py index 30a9a71b4..2056f14b6 100644 --- a/projects/policyengine-api-simulation/src/modal/app.py +++ b/projects/policyengine-api-simulation/src/modal/app.py @@ -12,10 +12,15 @@ from src.modal._image_setup import snapshot_models from src.modal.logging_redaction import redact_params_for_logging - -# Get versions from environment or use defaults -US_VERSION = os.environ.get("POLICYENGINE_US_VERSION", "1.702.0") -UK_VERSION = os.environ.get("POLICYENGINE_UK_VERSION", "2.88.20") +from src.modal.release_bundle import get_bundled_country_model_version + +# Get versions from environment or the bundled policyengine.py release manifest. 
+US_VERSION = os.environ.get( + "POLICYENGINE_US_VERSION" +) or get_bundled_country_model_version("us") +UK_VERSION = os.environ.get( + "POLICYENGINE_UK_VERSION" +) or get_bundled_country_model_version("uk") def get_app_name(us_version: str, uk_version: str) -> str: diff --git a/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py b/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py index eb382a41a..9db5a5442 100644 --- a/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py +++ b/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py @@ -34,27 +34,15 @@ failed_job_response, running_job_response, ) +from src.modal.release_bundle import ( + get_country_release_bundle, + resolve_bundle_dataset_uri, +) logger = logging.getLogger(__name__) router = APIRouter() JOB_METADATA_DICT_NAME = "simulation-api-job-metadata" -DATASET_URIS = { - "us": { - "enhanced_cps": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5", - "enhanced_cps_2024": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5", - "cps": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.115.5", - "cps_2023": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.115.5", - "pooled_cps": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.115.5", - "pooled_3_year_cps_2023": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.115.5", - }, - "uk": { - "enhanced_frs": "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.55.10", - "enhanced_frs_2023_24": "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.55.10", - "frs": "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5@1.55.10", - "frs_2023_24": "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5@1.55.10", - }, -} def _job_metadata_store(): @@ -82,16 +70,17 @@ def _is_modal_job_not_found(exc: BaseException) -> bool: def _build_policyengine_bundle( country: str, 
resolved_version: str, payload: dict ) -> PolicyEngineBundle: + bundle = get_country_release_bundle(country) dataset = payload.get("data") - if isinstance(dataset, str) and "://" in dataset: - resolved_dataset = dataset - elif isinstance(dataset, str): - resolved_dataset = DATASET_URIS.get(country.lower(), {}).get(dataset, dataset) - else: - resolved_dataset = None + resolved_dataset = ( + resolve_bundle_dataset_uri(country, dataset) + if dataset is None or isinstance(dataset, str) + else None + ) return PolicyEngineBundle( model_version=resolved_version, - data_version=payload.get("data_version"), + policyengine_version=bundle.policyengine_version, + data_version=payload.get("data_version") or bundle.data_version, dataset=resolved_dataset, ) diff --git a/projects/policyengine-api-simulation/src/modal/release_bundle.py b/projects/policyengine-api-simulation/src/modal/release_bundle.py new file mode 100644 index 000000000..f99dfadba --- /dev/null +++ b/projects/policyengine-api-simulation/src/modal/release_bundle.py @@ -0,0 +1,114 @@ +"""Helpers for using the bundled policyengine.py release manifests. + +The simulation API deploys separate versioned worker apps, but the country +package and data artifact versions must come from the policyengine.py bundle +manifest so model/data compatibility stays explicit. 
+""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from functools import lru_cache +from typing import Mapping + +os.environ.setdefault("POLICYENGINE_SKIP_COUNTRY_IMPORTS", "1") + +SUPPORTED_COUNTRIES = frozenset({"us", "uk"}) + +DATASET_ALIASES: dict[str, dict[str, str]] = { + "us": { + "enhanced_cps": "enhanced_cps_2024", + "enhanced_cps_2024": "enhanced_cps_2024", + "gs://policyengine-us-data/enhanced_cps_2024.h5": "enhanced_cps_2024", + "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5": "enhanced_cps_2024", + "cps_small": "cps_small_2024", + "cps_small_2024": "cps_small_2024", + }, + "uk": { + "enhanced_frs": "enhanced_frs_2023_24", + "enhanced_frs_2023_24": "enhanced_frs_2023_24", + "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5": "enhanced_frs_2023_24", + "frs": "frs_2023_24", + "frs_2023_24": "frs_2023_24", + "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5": "frs_2023_24", + }, +} + + +@dataclass(frozen=True) +class CountryReleaseBundle: + country: str + policyengine_version: str + model_package_name: str + model_version: str + data_package_name: str + data_version: str + default_dataset: str + default_dataset_uri: str + dataset_uris: Mapping[str, str] + + +def _normalise_country(country: str) -> str: + country = country.lower() + if country not in SUPPORTED_COUNTRIES: + raise ValueError(f"Unsupported country: {country}") + return country + + +def _artifact_revision(data_package) -> str: + return data_package.release_manifest_revision or data_package.version + + +@lru_cache +def get_country_release_bundle(country: str) -> CountryReleaseBundle: + """Return package and dataset versions from policyengine.py's manifest.""" + + country = _normalise_country(country) + from policyengine.provenance.manifest import build_hf_uri, get_release_manifest + + manifest = get_release_manifest(country) + dataset_uris = { + name: build_hf_uri( + repo_id=manifest.data_package.repo_id, + 
path_in_repo=reference.path, + revision=reference.revision or _artifact_revision(manifest.data_package), + ) + for name, reference in manifest.datasets.items() + } + + return CountryReleaseBundle( + country=country, + policyengine_version=manifest.policyengine_version, + model_package_name=manifest.model_package.name, + model_version=manifest.model_package.version, + data_package_name=manifest.data_package.name, + data_version=manifest.data_package.version, + default_dataset=manifest.default_dataset, + default_dataset_uri=manifest.default_dataset_uri, + dataset_uris=dataset_uris, + ) + + +def get_bundled_country_model_version(country: str) -> str: + return get_country_release_bundle(country).model_version + + +def resolve_bundle_dataset_name(country: str, requested_data: str | None) -> str: + bundle = get_country_release_bundle(country) + if requested_data is None: + return bundle.default_dataset + + requested_without_revision = requested_data.split("@", maxsplit=1)[0] + aliased = DATASET_ALIASES.get(bundle.country, {}).get( + requested_without_revision, requested_data + ) + return aliased + + +def resolve_bundle_dataset_uri(country: str, requested_data: str | None) -> str: + bundle = get_country_release_bundle(country) + dataset_name = resolve_bundle_dataset_name(country, requested_data) + if "://" in dataset_name: + return dataset_name + return bundle.dataset_uris.get(dataset_name, dataset_name) diff --git a/projects/policyengine-api-simulation/src/modal/simulation.py b/projects/policyengine-api-simulation/src/modal/simulation.py index 58cdb7c68..c27fdbe1c 100644 --- a/projects/policyengine-api-simulation/src/modal/simulation.py +++ b/projects/policyengine-api-simulation/src/modal/simulation.py @@ -14,6 +14,10 @@ from importlib import import_module from typing import Any, Iterator +from src.modal.release_bundle import ( + get_country_release_bundle, + resolve_bundle_dataset_name, +) from src.modal.simulation_output_adapter import 
adapt_analysis_to_legacy_macro_output from src.modal.telemetry import split_internal_payload @@ -22,24 +26,6 @@ os.environ.setdefault("POLICYENGINE_SKIP_COUNTRY_IMPORTS", "1") DEFAULT_YEAR = 2026 -DATASET_ALIASES = { - "us": { - "enhanced_cps": "enhanced_cps_2024", - "enhanced_cps_2024": "enhanced_cps_2024", - "gs://policyengine-us-data/enhanced_cps_2024.h5": "enhanced_cps_2024", - "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5": "enhanced_cps_2024", - "cps_small": "cps_small_2024", - "cps_small_2024": "cps_small_2024", - }, - "uk": { - "enhanced_frs": "enhanced_frs_2023_24", - "enhanced_frs_2023_24": "enhanced_frs_2023_24", - "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5": "enhanced_frs_2023_24", - "frs": "frs_2023_24", - "frs_2023_24": "frs_2023_24", - "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5": "frs_2023_24", - }, -} def _normalize_credentials_blob(creds_json: str) -> str: @@ -168,13 +154,7 @@ def _normalise_policy(policy: dict[str, Any] | None) -> dict[str, Any] | None: def _resolve_dataset_name(country: str, requested_data: str | None) -> str: - if requested_data is None: - return "enhanced_cps_2024" if country == "us" else "enhanced_frs_2023_24" - - requested_without_revision = requested_data.split("@", maxsplit=1)[0] - return DATASET_ALIASES.get(country, {}).get( - requested_without_revision, requested_data - ) + return resolve_bundle_dataset_name(country, requested_data) def _microframe_like(frame, weights: str): @@ -501,6 +481,11 @@ def _model_version(country_module) -> str: def _data_version(params: dict[str, Any], dataset) -> str: if params.get("data_version"): return str(params["data_version"]) + country = params.get("country", "us").lower() + try: + return get_country_release_bundle(country).data_version + except ValueError: + pass metadata = getattr(dataset, "metadata", {}) or {} for key in ("data_version", "version"): value = metadata.get(key) diff --git 
a/projects/policyengine-api-simulation/src/modal/utils/extract_bundle_versions.py b/projects/policyengine-api-simulation/src/modal/utils/extract_bundle_versions.py new file mode 100644 index 000000000..bf3d91086 --- /dev/null +++ b/projects/policyengine-api-simulation/src/modal/utils/extract_bundle_versions.py @@ -0,0 +1,41 @@ +"""Print policyengine.py bundle versions for deployment scripts.""" + +from __future__ import annotations + +import os +from pathlib import Path + +from src.modal.release_bundle import get_country_release_bundle + + +def main() -> None: + us_bundle = get_country_release_bundle("us") + uk_bundle = get_country_release_bundle("uk") + + outputs = { + "policyengine_version": us_bundle.policyengine_version, + "us_version": us_bundle.model_version, + "us_data_version": us_bundle.data_version, + "uk_version": uk_bundle.model_version, + "uk_data_version": uk_bundle.data_version, + } + + github_output = os.environ.get("GITHUB_OUTPUT") + if github_output: + output_path = Path(github_output) + with output_path.open("a", encoding="utf-8") as file: + for key, value in outputs.items(): + file.write(f"{key}={value}\n") + + print( + "Deploying with policyengine.py bundle " + f"{outputs['policyengine_version']}: " + f"policyengine-us={outputs['us_version']}, " + f"policyengine-us-data={outputs['us_data_version']}, " + f"policyengine-uk={outputs['uk_version']}, " + f"policyengine-uk-data={outputs['uk_data_version']}" + ) + + +if __name__ == "__main__": + main() diff --git a/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py b/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py index e2fe70611..393fdfc5a 100644 --- a/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py +++ b/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py @@ -8,6 +8,27 @@ import pytest from fastapi.testclient import TestClient +from src.modal.release_bundle import ( + get_country_release_bundle, + 
resolve_bundle_dataset_uri, +) + + +def expected_bundle( + country: str, + model_version: str, + *, + dataset: str | None = None, + data_version: str | None = None, +) -> dict[str, str]: + bundle = get_country_release_bundle(country) + return { + "model_version": model_version, + "policyengine_version": bundle.policyengine_version, + "data_version": data_version or bundle.data_version, + "dataset": resolve_bundle_dataset_uri(country, dataset), + } + class TestGetAppName: """Tests for the get_app_name helper function.""" @@ -260,12 +281,13 @@ def test__given_submission_with_data__then_returns_resolved_bundle_metadata( assert response.status_code == 200 data = response.json() assert data["resolved_app_name"] == "policyengine-simulation-us1-500-0-uk2-66-0" - assert data["policyengine_bundle"] == { - "model_version": "1.500.0", - "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5", - } + assert data["policyengine_bundle"] == expected_bundle( + "us", + "1.500.0", + dataset="hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5", + ) - def test__given_submission_with_alias_data__then_bundle_dataset_stays_unresolved( + def test__given_submission_with_alias_data__then_bundle_dataset_uses_manifest_uri( self, mock_modal, client: TestClient ): mock_modal["dicts"]["simulation-api-us-versions"] = { @@ -284,9 +306,8 @@ def test__given_submission_with_alias_data__then_bundle_dataset_stays_unresolved assert response.status_code == 200 data = response.json() - assert ( - data["policyengine_bundle"]["dataset"] - == "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5" + assert data["policyengine_bundle"]["dataset"] == resolve_bundle_dataset_uri( + "us", "enhanced_cps_2024" ) def test__given_submission_with_uk_alias_data__then_bundle_dataset_is_versioned_uri( @@ -308,9 +329,8 @@ def test__given_submission_with_uk_alias_data__then_bundle_dataset_is_versioned_ assert response.status_code == 200 data = response.json() - assert ( 
- data["policyengine_bundle"]["dataset"] - == "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.55.10" + assert data["policyengine_bundle"]["dataset"] == resolve_bundle_dataset_uri( + "uk", "enhanced_frs" ) def test__given_submission_with_runtime_bundle__then_accepts_internal_provenance( @@ -338,11 +358,12 @@ def test__given_submission_with_runtime_bundle__then_accepts_internal_provenance assert response.status_code == 200 data = response.json() - assert data["policyengine_bundle"] == { - "model_version": "1.500.0", - "data_version": "1.78.2", - "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5", - } + assert data["policyengine_bundle"] == expected_bundle( + "us", + "1.500.0", + dataset="enhanced_cps_2024", + data_version="1.78.2", + ) assert mock_modal["func"].last_payload["data_version"] == "1.78.2" assert "_runtime_bundle" not in mock_modal["func"].last_payload assert "_metadata" not in mock_modal["func"].last_payload @@ -401,10 +422,11 @@ def test__given_submitted_job__then_job_status_includes_bundle_metadata( assert data["status"] == "complete" assert "run_id" not in data assert data["resolved_app_name"] == "policyengine-simulation-us1-500-0-uk2-66-0" - assert data["policyengine_bundle"] == { - "model_version": "1.500.0", - "dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5", - } + assert data["policyengine_bundle"] == expected_bundle( + "us", + "1.500.0", + dataset="hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5", + ) def test__given_submitted_job_with_telemetry__then_polling_echoes_run_id( self, mock_modal, client: TestClient @@ -607,9 +629,7 @@ def test__given_budget_window_submission__then_returns_parent_batch_job_id( "country": "us", "version": "1.500.0", "resolved_app_name": "policyengine-simulation-us1-500-0-uk2-66-0", - "policyengine_bundle": { - "model_version": "1.500.0", - }, + "policyengine_bundle": expected_bundle("us", "1.500.0"), } def 
test__given_budget_window_submission__then_initial_poll_returns_seed_state( @@ -654,9 +674,7 @@ def test__given_budget_window_submission__then_initial_poll_returns_seed_state( "result": None, "error": None, "resolved_app_name": "policyengine-simulation-us1-500-0-uk2-66-0", - "policyengine_bundle": { - "model_version": "1.500.0", - }, + "policyengine_bundle": expected_bundle("us", "1.500.0"), "run_id": "batch-run-123", } diff --git a/projects/policyengine-api-simulation/tests/test_country_package_update_scripts.py b/projects/policyengine-api-simulation/tests/test_country_package_update_scripts.py index c035d791f..a011390be 100644 --- a/projects/policyengine-api-simulation/tests/test_country_package_update_scripts.py +++ b/projects/policyengine-api-simulation/tests/test_country_package_update_scripts.py @@ -65,7 +65,6 @@ def test_update_country_package_dry_run_reports_planned_changes_without_editing( assert "Dry run: would create auto/update-policyengine-us-1.1.0" in result.stdout assert "simulation/pyproject.toml" in result.stdout assert "simulation/uv.lock" in result.stdout - assert "simulation/src/modal/app.py" in result.stdout assert pyproject.read_text(encoding="utf-8") == original_pyproject @@ -160,13 +159,7 @@ def test_update_country_package_updates_files_and_opens_pr( pyproject_text = (fake_repo / "simulation" / "pyproject.toml").read_text( encoding="utf-8" ) - modal_text = (fake_repo / "simulation" / "src" / "modal" / "app.py").read_text( - encoding="utf-8" - ) assert "policyengine-us==1.1.0" in pyproject_text - assert ( - 'US_VERSION = os.environ.get("POLICYENGINE_US_VERSION", "1.1.0")' in modal_text - ) assert "lock --upgrade-package policyengine-us" in uv_log.read_text( encoding="utf-8" ) diff --git a/projects/policyengine-api-simulation/tests/test_modal_scripts.py b/projects/policyengine-api-simulation/tests/test_modal_scripts.py index 31ef338ed..4911ef4ab 100644 --- a/projects/policyengine-api-simulation/tests/test_modal_scripts.py +++ 
b/projects/policyengine-api-simulation/tests/test_modal_scripts.py @@ -32,8 +32,8 @@ def test_script_is_executable_or_can_be_run_with_bash(self): ) assert result.returncode == 0, f"Syntax error in script: {result.stderr}" - def test_extracts_versions_from_uv_lock(self, temp_github_output): - """Should extract policyengine-us and policyengine-uk versions from uv.lock.""" + def test_extracts_versions_from_policyengine_bundle(self, temp_github_output): + """Should extract model and data versions from policyengine.py's bundle.""" project_dir = REPO_ROOT / "projects" / "policyengine-api-simulation" if not (project_dir / "uv.lock").exists(): @@ -54,8 +54,11 @@ def test_extracts_versions_from_uv_lock(self, temp_github_output): with open(temp_github_output) as f: output = f.read() - assert "us_version=" in output, "us_version not found in output" - assert "uk_version=" in output, "uk_version not found in output" + assert "policyengine_version=" in output + assert "us_version=" in output + assert "us_data_version=" in output + assert "uk_version=" in output + assert "uk_data_version=" in output class TestModalHealthCheck: diff --git a/projects/policyengine-api-simulation/tests/test_policyengine_dependency_source.py b/projects/policyengine-api-simulation/tests/test_policyengine_dependency_source.py index bef75549d..ebb6e6fae 100644 --- a/projects/policyengine-api-simulation/tests/test_policyengine_dependency_source.py +++ b/projects/policyengine-api-simulation/tests/test_policyengine_dependency_source.py @@ -8,6 +8,10 @@ PYPROJECT_PATH = REPO_ROOT / "pyproject.toml" MODAL_APP_PATH = REPO_ROOT / "src" / "modal" / "app.py" POLICYENGINE_DEPENDENCY_PREFIX = "policyengine==" +COUNTRY_PACKAGES = { + "us": "policyengine-us", + "uk": "policyengine-uk", +} def _load_toml(path: Path) -> dict: @@ -22,6 +26,14 @@ def _get_pyproject_policyengine_dependency(pyproject: dict) -> str: ) +def _get_dependency_pin(pyproject: dict, package: str) -> str: + dependencies = 
pyproject["project"]["dependencies"] + prefix = f"{package}==" + return next( + dep.removeprefix(prefix) for dep in dependencies if dep.startswith(prefix) + ) + + def _get_modal_policyengine_dependency(modal_source: str) -> str: match = re.search( r'"(policyengine==[^"]+)"', @@ -38,3 +50,15 @@ def test_policyengine_dependency_version_is_pinned_consistently(): assert pyproject_dependency.startswith(POLICYENGINE_DEPENDENCY_PREFIX) assert modal_dependency == pyproject_dependency + + +def test_country_package_pins_match_policyengine_bundle(): + from src.modal.release_bundle import get_country_release_bundle + + pyproject = _load_toml(PYPROJECT_PATH) + + for country, package in COUNTRY_PACKAGES.items(): + assert ( + _get_dependency_pin(pyproject, package) + == get_country_release_bundle(country).model_version + ) diff --git a/projects/policyengine-api-simulation/tests/test_release_bundle.py b/projects/policyengine-api-simulation/tests/test_release_bundle.py new file mode 100644 index 000000000..3a390041e --- /dev/null +++ b/projects/policyengine-api-simulation/tests/test_release_bundle.py @@ -0,0 +1,49 @@ +"""Tests for policyengine.py release bundle helpers.""" + +from src.modal.release_bundle import ( + get_country_release_bundle, + resolve_bundle_dataset_name, + resolve_bundle_dataset_uri, +) + + +def test_country_release_bundle_exposes_model_and_data_versions(): + us_bundle = get_country_release_bundle("us") + uk_bundle = get_country_release_bundle("uk") + + assert us_bundle.model_package_name == "policyengine-us" + assert us_bundle.model_version + assert us_bundle.data_package_name == "policyengine-us-data" + assert us_bundle.data_version + assert uk_bundle.model_package_name == "policyengine-uk" + assert uk_bundle.model_version + assert uk_bundle.data_package_name == "policyengine-uk-data" + assert uk_bundle.data_version + + +def test_resolve_bundle_dataset_name_uses_manifest_default(): + assert ( + resolve_bundle_dataset_name("us", None) + == 
get_country_release_bundle("us").default_dataset + ) + assert ( + resolve_bundle_dataset_name("uk", None) + == get_country_release_bundle("uk").default_dataset + ) + + +def test_resolve_bundle_dataset_uri_maps_known_aliases_to_manifest_uris(): + assert ( + resolve_bundle_dataset_uri("us", "enhanced_cps") + == get_country_release_bundle("us").default_dataset_uri + ) + assert ( + resolve_bundle_dataset_uri("uk", "enhanced_frs") + == get_country_release_bundle("uk").default_dataset_uri + ) + + +def test_resolve_bundle_dataset_uri_preserves_unmanaged_unknown_values(): + assert resolve_bundle_dataset_uri("us", "custom_dataset_label") == ( + "custom_dataset_label" + ) diff --git a/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py b/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py index 0b71ba754..c83fdae43 100644 --- a/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py +++ b/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py @@ -40,7 +40,7 @@ def _build_schema_output() -> SingleYearMacroOutput: return build_single_year_macro_output( country="us", - model_version="1.702.0", + model_version="1.700.0", data_version="1.115.5", budget=BUDGET, analysis=fake_analysis(), @@ -68,7 +68,7 @@ def test_builder_returns_schema_modules_before_legacy_dict_dump(): legacy_output = adapt_analysis_to_legacy_macro_output( country="us", - model_version="1.702.0", + model_version="1.700.0", data_version="1.115.5", budget=BUDGET, analysis=fake_analysis(), @@ -88,7 +88,7 @@ def test_adapter_returns_existing_single_year_macro_shape(): output = _build_schema_output().model_dump(mode="json") assert set(output) == CURRENT_SINGLE_YEAR_MACRO_KEYS - assert output["model_version"] == "1.702.0" + assert output["model_version"] == "1.700.0" assert output["data_version"] == "1.115.5" assert output["budget"] == BUDGET assert output["detailed_budget"] == { diff --git 
a/projects/policyengine-api-simulation/uv.lock b/projects/policyengine-api-simulation/uv.lock index 3623dbd7b..82207be8d 100644 --- a/projects/policyengine-api-simulation/uv.lock +++ b/projects/policyengine-api-simulation/uv.lock @@ -1736,7 +1736,7 @@ requires-dist = [ { name = "policyengine-core", specifier = "==3.26.1" }, { name = "policyengine-fastapi", editable = "../../libs/policyengine-fastapi" }, { name = "policyengine-uk", specifier = "==2.88.20" }, - { name = "policyengine-us", specifier = "==1.702.0" }, + { name = "policyengine-us", specifier = "==1.700.0" }, { name = "pydantic-settings", specifier = ">=2.7.1,<3.0.0" }, { name = "pyright", marker = "extra == 'build'", specifier = ">=1.1.401" }, { name = "pytest", marker = "extra == 'test'", specifier = ">=8.3.4" }, @@ -1763,7 +1763,7 @@ wheels = [ [[package]] name = "policyengine-us" -version = "1.702.0" +version = "1.700.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "microdf-python" }, @@ -1773,9 +1773,9 @@ dependencies = [ { name = "tables" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/43/7e/d3095e6dde387cb56eb2dd0543cdc0b0f7670446d3b6ea45468165d60d1f/policyengine_us-1.702.0.tar.gz", hash = "sha256:689526d444c98681d517247d5308e795e02f24c65423295232ab347e61cac981", size = 9876039, upload-time = "2026-05-21T14:56:36.133Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0a/70/767fddeeb827e96e0f9a499c25f76c642767a129be259162eaf5b5954eb1/policyengine_us-1.700.0.tar.gz", hash = "sha256:63a7b1a2b8a0c903b6d704e8095f6880024e8fa93ea405912820a42589e90add", size = 9862854, upload-time = "2026-05-20T00:07:27.748Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/95/1d/67cde50bf6401c5c3ab95ff8f4036876422fa6fc72481425f3f3c7eb3177/policyengine_us-1.702.0-py3-none-any.whl", hash = "sha256:83d787337760587dbfcfe6bc2ae59afb53d2baa5827cb535776ff7147561a72f", size = 10649615, upload-time = "2026-05-21T14:56:33.349Z" }, + { url = 
"https://files.pythonhosted.org/packages/49/e9/2837a0d98e99efaf4d82aade276eee6eeff419df614863f08e3512961d2d/policyengine_us-1.700.0-py3-none-any.whl", hash = "sha256:7633d8aefcaf02d7628f841bc56750606f1d7fe409ff3ae7b0ef7e364a88e945", size = 10614505, upload-time = "2026-05-20T00:07:23.699Z" }, ] [[package]] From f1b2d08ad57af9016810ec1dbb869c3a1075fc14 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 21 May 2026 23:11:39 +0200 Subject: [PATCH 08/23] chore: update policyengine bundle as single stream --- .../scripts/check-country-package-updates.py | 136 ----------------- ...kage.sh => update-policyengine-package.sh} | 138 ++++++++++-------- ...es.yml => update-policyengine-package.yml} | 10 +- ...st_policyengine_package_update_scripts.py} | 84 ++++------- .../modal/utils/extract_bundle_versions.py | 14 +- ...st_policyengine_package_update_scripts.py} | 98 ++++--------- 6 files changed, 148 insertions(+), 332 deletions(-) delete mode 100644 .github/scripts/check-country-package-updates.py rename .github/scripts/{update-country-package.sh => update-policyengine-package.sh} (50%) mode change 100644 => 100755 rename .github/workflows/{update-country-packages.yml => update-policyengine-package.yml} (78%) rename projects/policyengine-api-simulation/fixtures/{test_country_package_update_scripts.py => test_policyengine_package_update_scripts.py} (61%) rename projects/policyengine-api-simulation/tests/{test_country_package_update_scripts.py => test_policyengine_package_update_scripts.py} (51%) diff --git a/.github/scripts/check-country-package-updates.py b/.github/scripts/check-country-package-updates.py deleted file mode 100644 index 87ebf31c3..000000000 --- a/.github/scripts/check-country-package-updates.py +++ /dev/null @@ -1,136 +0,0 @@ -#!/usr/bin/env python3 -"""Format country package changelog entries between two versions.""" - -from __future__ import annotations - -import argparse -import re -import sys -import urllib.error -import urllib.request - -REPO_MAP = 
{ - "policyengine-us": "PolicyEngine/policyengine-us", - "policyengine-uk": "PolicyEngine/policyengine-uk", -} - - -def fetch_changelog(package: str) -> str | None: - repo = REPO_MAP.get(package) - if repo is None: - return None - - for branch in ("main", "master"): - url = f"https://raw.githubusercontent.com/{repo}/{branch}/CHANGELOG.md" - try: - with urllib.request.urlopen(url, timeout=30) as response: - if response.status == 200: - return response.read().decode("utf-8") - except (TimeoutError, urllib.error.URLError): - continue - - return None - - -def parse_version(version: str) -> tuple[int, int, int]: - parts = tuple(int(part) for part in version.split(".")) - if len(parts) != 3: - raise ValueError(f"Expected a semantic version, got {version!r}") - return parts - - -def parse_changelog(text: str) -> list[dict[str, object]]: - entries: list[dict[str, object]] = [] - current_entry: dict[str, object] | None = None - current_category: str | None = None - - for line in text.splitlines(): - version_match = re.match(r"^##\s+\[?(\d+\.\d+\.\d+)\]?", line) - if version_match: - current_entry = {"version": version_match.group(1), "changes": {}} - entries.append(current_entry) - current_category = None - continue - - if current_entry is None: - continue - - category_match = re.match(r"^###\s+(.+)", line) - if category_match: - current_category = category_match.group(1).strip().lower() - continue - - item_match = re.match(r"^-\s+(.+)", line) - if item_match and current_category: - changes = current_entry["changes"] - assert isinstance(changes, dict) - changes.setdefault(current_category, []) - changes[current_category].append(item_match.group(1)) - - return entries - - -def get_changes_between( - changelog: list[dict[str, object]], old_version: str, new_version: str -) -> list[dict[str, object]]: - old_v = parse_version(old_version) - new_v = parse_version(new_version) - entries = [] - for entry in changelog: - version = entry.get("version") - if isinstance(version, str) 
and old_v < parse_version(version) <= new_v: - entries.append(entry) - return entries - - -def format_changes(entries: list[dict[str, object]]) -> str: - preferred_order = ("added", "changed", "fixed", "removed", "deprecated") - buckets: dict[str, list[str]] = {category: [] for category in preferred_order} - extra_buckets: dict[str, list[str]] = {} - - for entry in entries: - changes = entry.get("changes", {}) - if not isinstance(changes, dict): - continue - for category, items in changes.items(): - if not isinstance(category, str) or not isinstance(items, list): - continue - target = buckets if category in buckets else extra_buckets - target.setdefault(category, []) - target[category].extend(str(item) for item in items) - - sections = [] - for category in (*preferred_order, *sorted(extra_buckets)): - items = buckets.get(category) or extra_buckets.get(category) or [] - if items: - body = "\n".join(f"- {item}" for item in items) - sections.append(f"### {category.capitalize()}\n{body}") - - return "\n\n".join(sections) - - -def main() -> int: - parser = argparse.ArgumentParser() - parser.add_argument("--package", required=True) - parser.add_argument("--old-version", required=True) - parser.add_argument("--new-version", required=True) - args = parser.parse_args() - - changelog_text = fetch_changelog(args.package) - if changelog_text is None: - print(f"Could not fetch changelog for {args.package}.", file=sys.stderr) - return 0 - - changes = get_changes_between( - parse_changelog(changelog_text), args.old_version, args.new_version - ) - if not changes: - print("No changelog entries found between these versions.") - return 0 - - print(format_changes(changes)) - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/.github/scripts/update-country-package.sh b/.github/scripts/update-policyengine-package.sh old mode 100644 new mode 100755 similarity index 50% rename from .github/scripts/update-country-package.sh rename to 
.github/scripts/update-policyengine-package.sh index 1fd4c9d19..8a8a52c97 --- a/.github/scripts/update-country-package.sh +++ b/.github/scripts/update-policyengine-package.sh @@ -1,38 +1,29 @@ #!/usr/bin/env bash # -# Sync a country package pin to the version bundled by policyengine.py and -# open a version-specific PR when the local pin is off-bundle. +# Check PyPI for a newer policyengine.py package, update the simulation project +# pin, sync country package pins to that policyengine.py bundle, and open one +# bundle-level PR. # # Usage: -# .github/scripts/update-country-package.sh policyengine-us [--dry-run] -# .github/scripts/update-country-package.sh policyengine-uk [--dry-run] +# .github/scripts/update-policyengine-package.sh [--dry-run] # # Optional environment: # PROJECT_DIR Project containing pyproject.toml and uv.lock. -# LATEST_OVERRIDE Version to use instead of querying the bundle, for local checks. +# LATEST_OVERRIDE policyengine version to use instead of querying PyPI. # DRY_RUN=1 Report planned changes without editing files or opening a PR. set -euo pipefail -PACKAGE="${1:?Usage: update-country-package.sh [--dry-run]}" DRY_RUN="${DRY_RUN:-0}" -if [[ "${2:-}" == "--dry-run" ]]; then +if [[ "${1:-}" == "--dry-run" ]]; then DRY_RUN=1 +elif [[ -n "${1:-}" ]]; then + echo "ERROR: Unsupported argument '${1}'." >&2 + echo "Usage: update-policyengine-package.sh [--dry-run]" >&2 + exit 1 fi -case "$PACKAGE" in - policyengine-us) - DISPLAY_NAME="PolicyEngine US" - ;; - policyengine-uk) - DISPLAY_NAME="PolicyEngine UK" - ;; - *) - echo "ERROR: Unsupported package '${PACKAGE}'." 
>&2 - exit 1 - ;; -esac - +PACKAGE="policyengine" ROOT_DIR="$(git rev-parse --show-toplevel)" PROJECT_DIR="${PROJECT_DIR:-projects/policyengine-api-simulation}" PROJECT_PATH="${ROOT_DIR}/${PROJECT_DIR}" @@ -40,25 +31,19 @@ PYPROJECT="${PROJECT_PATH}/pyproject.toml" LOCKFILE="${PROJECT_PATH}/uv.lock" create_pr_body_file() { - local changelog local pr_body_file - changelog=$(python3 "${ROOT_DIR}/.github/scripts/check-country-package-updates.py" \ - --package "$PACKAGE" \ - --old-version "$CURRENT" \ - --new-version "$LATEST" 2>/dev/null || true) - pr_body_file="$(mktemp)" { echo "## Summary" echo - echo "Update ${DISPLAY_NAME} from ${CURRENT} to ${LATEST} in the simulation API runtime." - if [[ -n "$changelog" ]]; then - echo - echo "## What changed (${CURRENT} -> ${LATEST})" - echo - echo "$changelog" - fi + echo "Update policyengine.py from ${CURRENT} to ${LATEST} in the simulation API runtime." + echo + echo "This also syncs country package pins to the versions bundled by policyengine.py ${LATEST}:" + echo "- policyengine-us: ${BUNDLED_US_VERSION:-resolved from bundle during update}" + echo "- policyengine-uk: ${BUNDLED_UK_VERSION:-resolved from bundle during update}" + echo + echo "Country data package versions remain manifest-derived at runtime/deploy time rather than independently pinned here." echo echo "---" echo "Generated automatically by GitHub Actions." @@ -72,16 +57,16 @@ if [[ ! -f "$PYPROJECT" || ! 
-f "$LOCKFILE" ]]; then exit 1 fi -CURRENT=$(python3 - "$LOCKFILE" "$PACKAGE" <<'PY' +CURRENT=$(python3 - "$PYPROJECT" "$PACKAGE" <<'PY' import re import sys +from pathlib import Path -lockfile, package = sys.argv[1:] -text = open(lockfile, encoding="utf-8").read() -pattern = rf'\[\[package\]\]\s+name = "{re.escape(package)}"\s+version = "([^"]+)"' -match = re.search(pattern, text) +pyproject, package = sys.argv[1:] +text = Path(pyproject).read_text(encoding="utf-8") +match = re.search(rf'"{re.escape(package)}==([^"]+)"', text) if not match: - raise SystemExit(f"Package {package!r} not found in {lockfile}") + raise SystemExit(f"Package {package!r} not found in {pyproject}") print(match.group(1)) PY ) @@ -89,23 +74,9 @@ PY if [[ -n "${LATEST_OVERRIDE:-}" ]]; then LATEST="$LATEST_OVERRIDE" else - LATEST=$( - cd "$PROJECT_PATH" - uv run python - "$PACKAGE" <<'PY' -import sys - -from src.modal.release_bundle import get_country_release_bundle - -package = sys.argv[1] -country_by_package = { - "policyengine-us": "us", - "policyengine-uk": "uk", -} -print(get_country_release_bundle(country_by_package[package]).model_version) -PY - ) + LATEST=$(curl -fsSL "https://pypi.org/pypi/${PACKAGE}/json" | python3 -c 'import json, sys; print(json.load(sys.stdin)["info"]["version"])') if [[ -z "$LATEST" ]]; then - echo "ERROR: Could not resolve bundled version for ${PACKAGE}." >&2 + echo "ERROR: Could not fetch latest version for ${PACKAGE} from PyPI." >&2 exit 1 fi fi @@ -115,15 +86,15 @@ if [[ -z "$LATEST" ]]; then exit 1 fi -echo "Current locked version: ${PACKAGE}==${CURRENT}" -echo "Bundled .py version: ${PACKAGE}==${LATEST}" +echo "Current pinned version: ${PACKAGE}==${CURRENT}" +echo "Latest PyPI version: ${PACKAGE}==${LATEST}" if [[ "$CURRENT" == "$LATEST" ]]; then echo "Already up to date. Nothing to do." 
exit 0 fi -BRANCH="auto/update-${PACKAGE}-${LATEST}" +BRANCH="auto/update-policyengine-${LATEST}" echo "Update available: ${CURRENT} -> ${LATEST}" if [[ "$DRY_RUN" == "1" ]]; then @@ -153,7 +124,7 @@ if git ls-remote --exit-code --heads origin "$BRANCH" >/dev/null 2>&1; then gh pr create \ --base main \ --head "$BRANCH" \ - --title "chore(deps): update ${PACKAGE} to ${LATEST}" \ + --title "chore(deps): update policyengine to ${LATEST}" \ --body-file "$PR_BODY_FILE" echo "PR created for existing branch ${BRANCH}" exit 0 @@ -183,6 +154,49 @@ PY uv lock --upgrade-package "$PACKAGE" ) +BUNDLE_OUTPUT=$( + cd "$PROJECT_PATH" + uv run python -m src.modal.utils.extract_bundle_versions --shell +) +BUNDLED_US_VERSION=$(printf '%s\n' "$BUNDLE_OUTPUT" | awk -F= '$1 == "us_version" {print $2}') +BUNDLED_UK_VERSION=$(printf '%s\n' "$BUNDLE_OUTPUT" | awk -F= '$1 == "uk_version" {print $2}') + +if [[ -z "$BUNDLED_US_VERSION" || -z "$BUNDLED_UK_VERSION" ]]; then + echo "ERROR: Could not resolve bundled country package versions." >&2 + echo "$BUNDLE_OUTPUT" >&2 + exit 1 +fi + +echo "Bundled country pins:" +echo " policyengine-us==${BUNDLED_US_VERSION}" +echo " policyengine-uk==${BUNDLED_UK_VERSION}" + +python3 - "$PYPROJECT" "$BUNDLED_US_VERSION" "$BUNDLED_UK_VERSION" <<'PY' +import re +import sys +from pathlib import Path + +pyproject_path, us_version, uk_version = sys.argv[1:] +pyproject = Path(pyproject_path) +text = pyproject.read_text(encoding="utf-8") +pins = { + "policyengine-us": us_version, + "policyengine-uk": uk_version, +} +for package, version in pins.items(): + pattern = rf'"{re.escape(package)}==[^"]+"' + replacement = f'"{package}=={version}"' + text, count = re.subn(pattern, replacement, text, count=1) + if count != 1: + raise SystemExit(f"Could not update {package} in {pyproject}") +pyproject.write_text(text, encoding="utf-8") +PY + +( + cd "$PROJECT_PATH" + uv lock +) + if git diff --quiet -- "$PYPROJECT" "$LOCKFILE"; then echo "No changes after update. 
Nothing to do." exit 0 @@ -191,12 +205,12 @@ fi PR_BODY_FILE="$(create_pr_body_file)" git add "$PYPROJECT" "$LOCKFILE" -git commit -m "chore(deps): update ${PACKAGE} to ${LATEST}" +git commit -m "chore(deps): update policyengine to ${LATEST}" git push -u origin "$BRANCH" gh pr create \ --base main \ - --title "chore(deps): update ${PACKAGE} to ${LATEST}" \ + --title "chore(deps): update policyengine to ${LATEST}" \ --body-file "$PR_BODY_FILE" -echo "PR created for ${PACKAGE} ${CURRENT} -> ${LATEST}" +echo "PR created for policyengine ${CURRENT} -> ${LATEST}" diff --git a/.github/workflows/update-country-packages.yml b/.github/workflows/update-policyengine-package.yml similarity index 78% rename from .github/workflows/update-country-packages.yml rename to .github/workflows/update-policyengine-package.yml index 581eec077..f432f6a8a 100644 --- a/.github/workflows/update-country-packages.yml +++ b/.github/workflows/update-policyengine-package.yml @@ -1,4 +1,4 @@ -name: Update country packages +name: Update policyengine package on: schedule: @@ -11,13 +11,9 @@ permissions: jobs: update: - name: Update ${{ matrix.package }} + name: Update policyengine.py bundle runs-on: ubuntu-latest if: github.repository == 'PolicyEngine/policyengine-api-v2' - strategy: - matrix: - package: [policyengine-us, policyengine-uk] - fail-fast: false steps: - name: Generate GitHub App token @@ -46,4 +42,4 @@ jobs: - name: Check for update and open PR env: GH_TOKEN: ${{ steps.app-token.outputs.token }} - run: bash .github/scripts/update-country-package.sh ${{ matrix.package }} + run: bash .github/scripts/update-policyengine-package.sh diff --git a/projects/policyengine-api-simulation/fixtures/test_country_package_update_scripts.py b/projects/policyengine-api-simulation/fixtures/test_policyengine_package_update_scripts.py similarity index 61% rename from projects/policyengine-api-simulation/fixtures/test_country_package_update_scripts.py rename to 
projects/policyengine-api-simulation/fixtures/test_policyengine_package_update_scripts.py index 08a68a7e7..d0d8ef946 100644 --- a/projects/policyengine-api-simulation/fixtures/test_country_package_update_scripts.py +++ b/projects/policyengine-api-simulation/fixtures/test_policyengine_package_update_scripts.py @@ -1,63 +1,28 @@ -"""Fixtures and helpers for country package updater script tests.""" +"""Fixtures and helpers for policyengine package updater script tests.""" from __future__ import annotations -import importlib.util import os import subprocess from pathlib import Path -from types import ModuleType import pytest from fixtures.test_modal_scripts import REPO_ROOT, SCRIPTS_DIR - -SCRIPT = SCRIPTS_DIR / "update-country-package.sh" -CHANGELOG_SCRIPT = SCRIPTS_DIR / "check-country-package-updates.py" -SAMPLE_CHANGELOG = """ -# Changelog - -## 1.2.2 -### Added -- New variable - -### Fixed -- Important bug fix - -## [1.2.1] -### Changed -- Existing calculation changed - -## 1.2.0 -### Added -- Old change -""" - - -@pytest.fixture -def changelog_module() -> ModuleType: - spec = importlib.util.spec_from_file_location( - "check_country_package_updates", CHANGELOG_SCRIPT - ) - assert spec is not None - assert spec.loader is not None - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module +SCRIPT = SCRIPTS_DIR / "update-policyengine-package.sh" @pytest.fixture def fake_repo(tmp_path: Path) -> Path: project = tmp_path / "simulation" - modal_dir = project / "src" / "modal" - modal_dir.mkdir(parents=True) + project.mkdir(parents=True) (project / "pyproject.toml").write_text( "\n".join( [ "[project]", - 'dependencies = ["policyengine-us==1.0.0", "policyengine-uk==2.0.0"]', + 'dependencies = ["policyengine==4.0.0", "policyengine-us==1.0.0", "policyengine-uk==2.0.0"]', ] ), encoding="utf-8", @@ -65,6 +30,10 @@ def fake_repo(tmp_path: Path) -> Path: (project / "uv.lock").write_text( "\n".join( [ + "[[package]]", + 'name = 
"policyengine"', + 'version = "4.0.0"', + "", "[[package]]", 'name = "policyengine-us"', 'version = "1.0.0"', @@ -76,23 +45,6 @@ def fake_repo(tmp_path: Path) -> Path: ), encoding="utf-8", ) - (modal_dir / "app.py").write_text( - "\n".join( - [ - "import os", - 'US_VERSION = os.environ.get("POLICYENGINE_US_VERSION", "1.0.0")', - 'UK_VERSION = os.environ.get("POLICYENGINE_UK_VERSION", "2.0.0")', - ] - ), - encoding="utf-8", - ) - - helper_dir = tmp_path / ".github" / "scripts" - helper_dir.mkdir(parents=True) - (helper_dir / "check-country-package-updates.py").write_text( - '#!/usr/bin/env python3\nprint("### Added\\n- Example upstream change")\n', - encoding="utf-8", - ) return tmp_path @@ -168,18 +120,34 @@ def install_fake_gh(fake_bin: Path, *, log: Path, open_pr: str = "") -> None: ) -def install_fake_uv(fake_bin: Path, *, log: Path) -> None: +def install_fake_uv( + fake_bin: Path, + *, + log: Path, + bundled_us_version: str = "1.1.0", + bundled_uk_version: str = "2.1.0", +) -> None: write_executable( fake_bin / "uv", f"""#!/usr/bin/env bash set -euo pipefail printf 'uv %s\\n' "$*" >> "{log}" + +if [[ "$1" == "run" && "$2" == "python" && "$3" == "-m" && "$4" == "src.modal.utils.extract_bundle_versions" ]]; then + echo "policyengine_version=4.1.0" + echo "us_version={bundled_us_version}" + echo "us_data_version=1.10.0" + echo "uk_version={bundled_uk_version}" + echo "uk_data_version=1.20.0" + exit 0 +fi + exit 0 """, ) -def updater_env(fake_bin: Path, fake_repo: Path, **extra: str) -> dict[str, str]: +def updater_env(fake_bin: Path, **extra: str) -> dict[str, str]: env = os.environ.copy() env.update( { diff --git a/projects/policyengine-api-simulation/src/modal/utils/extract_bundle_versions.py b/projects/policyengine-api-simulation/src/modal/utils/extract_bundle_versions.py index bf3d91086..830eb0654 100644 --- a/projects/policyengine-api-simulation/src/modal/utils/extract_bundle_versions.py +++ 
b/projects/policyengine-api-simulation/src/modal/utils/extract_bundle_versions.py @@ -3,16 +3,17 @@ from __future__ import annotations import os +import sys from pathlib import Path from src.modal.release_bundle import get_country_release_bundle -def main() -> None: +def _bundle_outputs() -> dict[str, str]: us_bundle = get_country_release_bundle("us") uk_bundle = get_country_release_bundle("uk") - outputs = { + return { "policyengine_version": us_bundle.policyengine_version, "us_version": us_bundle.model_version, "us_data_version": us_bundle.data_version, @@ -20,6 +21,15 @@ def main() -> None: "uk_data_version": uk_bundle.data_version, } + +def main() -> None: + outputs = _bundle_outputs() + + if "--shell" in sys.argv[1:]: + for key, value in outputs.items(): + print(f"{key}={value}") + return + github_output = os.environ.get("GITHUB_OUTPUT") if github_output: output_path = Path(github_output) diff --git a/projects/policyengine-api-simulation/tests/test_country_package_update_scripts.py b/projects/policyengine-api-simulation/tests/test_policyengine_package_update_scripts.py similarity index 51% rename from projects/policyengine-api-simulation/tests/test_country_package_update_scripts.py rename to projects/policyengine-api-simulation/tests/test_policyengine_package_update_scripts.py index a011390be..6849c19f6 100644 --- a/projects/policyengine-api-simulation/tests/test_country_package_update_scripts.py +++ b/projects/policyengine-api-simulation/tests/test_policyengine_package_update_scripts.py @@ -1,15 +1,11 @@ -"""Unit tests for country package updater scripts.""" +"""Unit tests for policyengine package updater scripts.""" from __future__ import annotations import subprocess from pathlib import Path -from types import ModuleType -import pytest - -from fixtures.test_country_package_update_scripts import ( - SAMPLE_CHANGELOG, +from fixtures.test_policyengine_package_update_scripts import ( SCRIPT, install_fake_gh, install_fake_git, @@ -18,10 +14,10 @@ updater_env, ) 
-pytest_plugins = ("fixtures.test_country_package_update_scripts",) +pytest_plugins = ("fixtures.test_policyengine_package_update_scripts",) -def test_update_country_package_script_has_valid_bash_syntax() -> None: +def test_update_policyengine_package_script_has_valid_bash_syntax() -> None: result = subprocess.run( ["bash", "-n", str(SCRIPT)], capture_output=True, @@ -31,22 +27,22 @@ def test_update_country_package_script_has_valid_bash_syntax() -> None: assert result.returncode == 0, result.stderr -def test_update_country_package_rejects_unknown_package( +def test_update_policyengine_package_rejects_unknown_argument( fake_bin: Path, fake_repo: Path, tmp_path: Path ) -> None: git_log = tmp_path / "git.log" install_fake_git(fake_bin, root=fake_repo, log=git_log) result = run_updater( - "policyengine-ca", - env=updater_env(fake_bin, fake_repo, LATEST_OVERRIDE="1.1.0"), + "policyengine-us", + env=updater_env(fake_bin, LATEST_OVERRIDE="4.1.0"), ) assert result.returncode != 0 - assert "Unsupported package 'policyengine-ca'" in result.stderr + assert "Unsupported argument 'policyengine-us'" in result.stderr -def test_update_country_package_dry_run_reports_planned_changes_without_editing( +def test_update_policyengine_package_dry_run_reports_planned_changes_without_editing( fake_bin: Path, fake_repo: Path, tmp_path: Path ) -> None: git_log = tmp_path / "git.log" @@ -55,20 +51,19 @@ def test_update_country_package_dry_run_reports_planned_changes_without_editing( original_pyproject = pyproject.read_text(encoding="utf-8") result = run_updater( - "policyengine-us", "--dry-run", - env=updater_env(fake_bin, fake_repo, LATEST_OVERRIDE="1.1.0"), + env=updater_env(fake_bin, LATEST_OVERRIDE="4.1.0"), ) assert result.returncode == 0, result.stderr - assert "Update available: 1.0.0 -> 1.1.0" in result.stdout - assert "Dry run: would create auto/update-policyengine-us-1.1.0" in result.stdout + assert "Update available: 4.0.0 -> 4.1.0" in result.stdout + assert "Dry run: would create 
auto/update-policyengine-4.1.0" in result.stdout assert "simulation/pyproject.toml" in result.stdout assert "simulation/uv.lock" in result.stdout assert pyproject.read_text(encoding="utf-8") == original_pyproject -def test_update_country_package_dry_run_reports_existing_branch_recovery( +def test_update_policyengine_package_dry_run_reports_existing_branch_recovery( fake_bin: Path, fake_repo: Path, tmp_path: Path ) -> None: git_log = tmp_path / "git.log" @@ -80,19 +75,18 @@ def test_update_country_package_dry_run_reports_existing_branch_recovery( ) result = run_updater( - "policyengine-us", "--dry-run", - env=updater_env(fake_bin, fake_repo, LATEST_OVERRIDE="1.1.0"), + env=updater_env(fake_bin, LATEST_OVERRIDE="4.1.0"), ) assert result.returncode == 0, result.stderr assert ( - "remote branch 'auto/update-policyengine-us-1.1.0' already exists; " + "remote branch 'auto/update-policyengine-4.1.0' already exists; " "would ensure a PR exists for it." ) in result.stdout -def test_update_country_package_skips_when_open_pr_exists( +def test_update_policyengine_package_skips_when_open_pr_exists( fake_bin: Path, fake_repo: Path, tmp_path: Path ) -> None: git_log = tmp_path / "git.log" @@ -101,18 +95,15 @@ def test_update_country_package_skips_when_open_pr_exists( install_fake_gh(fake_bin, log=gh_log, open_pr="123") result = run_updater( - "policyengine-us", - env=updater_env(fake_bin, fake_repo, LATEST_OVERRIDE="1.1.0"), + env=updater_env(fake_bin, LATEST_OVERRIDE="4.1.0"), ) assert result.returncode == 0, result.stderr - assert ( - "PR #123 already exists for auto/update-policyengine-us-1.1.0" in result.stdout - ) + assert "PR #123 already exists for auto/update-policyengine-4.1.0" in result.stdout assert "pr create" not in gh_log.read_text(encoding="utf-8") -def test_update_country_package_opens_pr_for_existing_branch_without_open_pr( +def test_update_policyengine_package_opens_pr_for_existing_branch_without_open_pr( fake_bin: Path, fake_repo: Path, tmp_path: Path ) -> 
None: git_log = tmp_path / "git.log" @@ -126,8 +117,7 @@ def test_update_country_package_opens_pr_for_existing_branch_without_open_pr( install_fake_gh(fake_bin, log=gh_log) result = run_updater( - "policyengine-us", - env=updater_env(fake_bin, fake_repo, LATEST_OVERRIDE="1.1.0"), + env=updater_env(fake_bin, LATEST_OVERRIDE="4.1.0"), ) assert result.returncode == 0, result.stderr @@ -135,10 +125,10 @@ def test_update_country_package_opens_pr_for_existing_branch_without_open_pr( gh_calls = gh_log.read_text(encoding="utf-8") assert "pr list" in gh_calls assert "pr create" in gh_calls - assert "--head auto/update-policyengine-us-1.1.0" in gh_calls + assert "--head auto/update-policyengine-4.1.0" in gh_calls -def test_update_country_package_updates_files_and_opens_pr( +def test_update_policyengine_package_updates_py_and_bundled_country_pins( fake_bin: Path, fake_repo: Path, tmp_path: Path ) -> None: git_log = tmp_path / "git.log" @@ -149,49 +139,23 @@ def test_update_country_package_updates_files_and_opens_pr( install_fake_uv(fake_bin, log=uv_log) result = run_updater( - "policyengine-us", - env=updater_env(fake_bin, fake_repo, LATEST_OVERRIDE="1.1.0"), + env=updater_env(fake_bin, LATEST_OVERRIDE="4.1.0"), ) assert result.returncode == 0, result.stderr - assert "PR created for policyengine-us 1.0.0 -> 1.1.0" in result.stdout + assert "PR created for policyengine 4.0.0 -> 4.1.0" in result.stdout pyproject_text = (fake_repo / "simulation" / "pyproject.toml").read_text( encoding="utf-8" ) + assert "policyengine==4.1.0" in pyproject_text assert "policyengine-us==1.1.0" in pyproject_text - assert "lock --upgrade-package policyengine-us" in uv_log.read_text( - encoding="utf-8" - ) - assert "checkout -b auto/update-policyengine-us-1.1.0" in git_log.read_text( + assert "policyengine-uk==2.1.0" in pyproject_text + uv_calls = uv_log.read_text(encoding="utf-8") + assert "lock --upgrade-package policyengine" in uv_calls + assert "run python -m 
src.modal.utils.extract_bundle_versions --shell" in uv_calls + assert "uv lock" in uv_calls + assert "checkout -b auto/update-policyengine-4.1.0" in git_log.read_text( encoding="utf-8" ) assert "pr create" in gh_log.read_text(encoding="utf-8") - - -def test_parse_changelog_collects_versioned_category_items( - changelog_module: ModuleType, -) -> None: - parsed = changelog_module.parse_changelog(SAMPLE_CHANGELOG) - changes = changelog_module.get_changes_between(parsed, "1.2.0", "1.2.2") - formatted = changelog_module.format_changes(changes) - - assert "### Added\n- New variable" in formatted - assert "### Changed\n- Existing calculation changed" in formatted - assert "### Fixed\n- Important bug fix" in formatted - assert "Old change" not in formatted - - -def test_parse_version_requires_three_numeric_parts( - changelog_module: ModuleType, -) -> None: - assert changelog_module.parse_version("1.2.3") == (1, 2, 3) - - with pytest.raises(ValueError, match="Expected a semantic version"): - changelog_module.parse_version("1.2") - - -def test_fetch_changelog_returns_none_for_unknown_package( - changelog_module: ModuleType, -) -> None: - assert changelog_module.fetch_changelog("policyengine-ca") is None From 94d7f5c2e6644a7d20d7139e1ff2bd3ccabe7e53 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 21 May 2026 23:17:16 +0200 Subject: [PATCH 09/23] chore: rename policyengine update workflow file --- ...te-policyengine-package.yml => check-policyengine-updates.yml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename .github/workflows/{update-policyengine-package.yml => check-policyengine-updates.yml} (100%) diff --git a/.github/workflows/update-policyengine-package.yml b/.github/workflows/check-policyengine-updates.yml similarity index 100% rename from .github/workflows/update-policyengine-package.yml rename to .github/workflows/check-policyengine-updates.yml From 9b067a576237515c57dece2897ea346e7a840228 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 22 May 
2026 14:41:35 +0200 Subject: [PATCH 10/23] refactor: use typed macro output modules --- .../src/modal/simulation.py | 166 ++++++++++++------ .../src/modal/simulation_macro_output.py | 45 ++++- .../src/modal/simulation_output_adapter.py | 62 +++++-- .../tests/test_simulation_output_adapter.py | 17 +- 4 files changed, 211 insertions(+), 79 deletions(-) diff --git a/projects/policyengine-api-simulation/src/modal/simulation.py b/projects/policyengine-api-simulation/src/modal/simulation.py index c27fdbe1c..53c79cbcb 100644 --- a/projects/policyengine-api-simulation/src/modal/simulation.py +++ b/projects/policyengine-api-simulation/src/modal/simulation.py @@ -18,7 +18,24 @@ get_country_release_bundle, resolve_bundle_dataset_name, ) -from src.modal.simulation_output_adapter import adapt_analysis_to_legacy_macro_output +from src.modal.simulation_macro_output import ( + BudgetaryImpact, + GeographicImpactOutput, + IntraDecileOutput, + PovertyModuleOutputs, + SingleYearMacroOutput, +) +from src.modal.simulation_output_adapter import ( + build_decile_output, + build_detailed_budget_output, + build_geographic_impact_output, + build_inequality_output, + build_intra_decile_output, + build_labor_supply_response_output, + build_poverty_by_gender_output, + build_poverty_by_race_output, + build_poverty_output, +) from src.modal.telemetry import split_internal_payload logger = logging.getLogger(__name__) @@ -321,7 +338,7 @@ def _try_change_output_variable(baseline, reform, variable: str, entity: str) -> return 0.0 -def _budget_result(country: str, baseline, reform) -> dict[str, float]: +def _budget_result(country: str, baseline, reform) -> BudgetaryImpact: tax_revenue_impact = _try_change_output_variable( baseline, reform, "household_tax", entity="household" ) @@ -339,18 +356,18 @@ def _budget_result(country: str, baseline, reform) -> dict[str, float]: else 0.0 ) - return { - "tax_revenue_impact": tax_revenue_impact, - "state_tax_revenue_impact": state_tax_revenue_impact, - 
"benefit_spending_impact": benefit_spending_impact, - "budgetary_impact": tax_revenue_impact - benefit_spending_impact, - "households": _try_sum_output_variable( + return BudgetaryImpact( + tax_revenue_impact=tax_revenue_impact, + state_tax_revenue_impact=state_tax_revenue_impact, + benefit_spending_impact=benefit_spending_impact, + budgetary_impact=tax_revenue_impact - benefit_spending_impact, + households=_try_sum_output_variable( baseline, "household_weight", entity="household" ), - "baseline_net_income": _try_sum_output_variable( + baseline_net_income=_try_sum_output_variable( baseline, "household_net_income", entity="household" ), - } + ) def _output_module_function(module_name: str, name: str): @@ -370,50 +387,69 @@ def _try_compute_output(label: str, fn, *args, **kwargs): return None -def _additional_poverty_outputs(country: str, baseline, reform) -> dict[str, Any]: +def _poverty_outputs(country: str, baseline, reform, analysis) -> PovertyModuleOutputs: prefix = "us" if country == "us" else "uk" - output = { - "baseline_poverty_by_age": _try_compute_output( - "baseline poverty by age", - _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), - baseline, - ), - "reform_poverty_by_age": _try_compute_output( - "reform poverty by age", - _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), - reform, - ), - "baseline_poverty_by_gender": _try_compute_output( - "baseline poverty by gender", - _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), - baseline, - ), - "reform_poverty_by_gender": _try_compute_output( - "reform poverty by gender", - _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), - reform, - ), - "baseline_poverty_by_race": None, - "reform_poverty_by_race": None, - } + baseline_poverty_by_age = _try_compute_output( + "baseline poverty by age", + _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), + baseline, + ) + reform_poverty_by_age = _try_compute_output( + "reform poverty by age", + 
_poverty_module_function(f"calculate_{prefix}_poverty_by_age"), + reform, + ) + baseline_poverty_by_gender = _try_compute_output( + "baseline poverty by gender", + _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), + baseline, + ) + reform_poverty_by_gender = _try_compute_output( + "reform poverty by gender", + _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), + reform, + ) + baseline_poverty_by_race = None + reform_poverty_by_race = None if country == "us": - output["baseline_poverty_by_race"] = _try_compute_output( + baseline_poverty_by_race = _try_compute_output( "baseline poverty by race", _poverty_module_function("calculate_us_poverty_by_race"), baseline, ) - output["reform_poverty_by_race"] = _try_compute_output( + reform_poverty_by_race = _try_compute_output( "reform poverty by race", _poverty_module_function("calculate_us_poverty_by_race"), reform, ) - return output + return PovertyModuleOutputs( + poverty=build_poverty_output( + country, + baseline=getattr(analysis, "baseline_poverty", None), + reform=getattr(analysis, "reform_poverty", None), + baseline_by_age=baseline_poverty_by_age, + reform_by_age=reform_poverty_by_age, + ), + poverty_by_gender=build_poverty_by_gender_output( + country, + baseline_by_gender=baseline_poverty_by_gender, + reform_by_gender=reform_poverty_by_gender, + ), + poverty_by_race=( + build_poverty_by_race_output( + baseline_by_race=baseline_poverty_by_race, + reform_by_race=reform_poverty_by_race, + ) + if country == "us" + else None + ), + ) -def _intra_decile_output(baseline, reform): +def _intra_decile_output(baseline, reform) -> IntraDecileOutput: from policyengine.outputs.intra_decile_impact import compute_intra_decile_impacts - return _try_compute_output( + collection = _try_compute_output( "intra-decile impacts", compute_intra_decile_impacts, baseline, @@ -421,9 +457,12 @@ def _intra_decile_output(baseline, reform): income_variable="household_net_income", entity="household", ) + return 
build_intra_decile_output(collection) -def _congressional_district_impact(country: str, baseline, reform): +def _congressional_district_impact( + country: str, baseline, reform +) -> GeographicImpactOutput | None: if country != "us": return None @@ -437,10 +476,14 @@ def _congressional_district_impact(country: str, baseline, reform): baseline, reform, ) - return getattr(impact, "district_results", None) if impact is not None else None + return build_geographic_impact_output( + getattr(impact, "district_results", None) if impact is not None else None + ) -def _uk_constituency_impact(country: str, baseline, reform): +def _uk_constituency_impact( + country: str, baseline, reform +) -> GeographicImpactOutput | None: if country != "uk": return None @@ -454,10 +497,12 @@ def _uk_constituency_impact(country: str, baseline, reform): ) if impact is None: return None - return getattr(impact, "constituency_results", None) + return build_geographic_impact_output(getattr(impact, "constituency_results", None)) -def _uk_local_authority_impact(country: str, baseline, reform): +def _uk_local_authority_impact( + country: str, baseline, reform +) -> GeographicImpactOutput | None: if country != "uk": return None @@ -471,7 +516,9 @@ def _uk_local_authority_impact(country: str, baseline, reform): ) if impact is None: return None - return getattr(impact, "local_authority_results", None) + return build_geographic_impact_output( + getattr(impact, "local_authority_results", None) + ) def _model_version(country_module) -> str: @@ -527,7 +574,7 @@ def _run_simulation_impl_core(params: dict) -> dict: logger.info("Calculating economic impact") analysis = country_module.economic_impact_analysis(baseline, reform) budget = _budget_result(country, baseline, reform) - poverty_outputs = _additional_poverty_outputs(country, baseline, reform) + poverty_outputs = _poverty_outputs(country, baseline, reform, analysis) intra_decile = _intra_decile_output(baseline, reform) congressional_district_impact = 
_congressional_district_impact( country, baseline, reform @@ -536,15 +583,32 @@ def _run_simulation_impl_core(params: dict) -> dict: local_authority_impact = _uk_local_authority_impact(country, baseline, reform) logger.info("Comparison complete") - return adapt_analysis_to_legacy_macro_output( - country=country, + wealth_decile = getattr(analysis, "wealth_decile_impacts", None) + intra_wealth_decile = getattr(analysis, "intra_wealth_decile_impacts", None) + output = SingleYearMacroOutput( model_version=_model_version(country_module), data_version=_data_version(simulation_params, dataset), budget=budget, - analysis=analysis, + detailed_budget=build_detailed_budget_output( + getattr(analysis, "program_statistics", None) + ), + decile=build_decile_output(getattr(analysis, "decile_impacts", None)), + inequality=build_inequality_output( + getattr(analysis, "baseline_inequality", None), + getattr(analysis, "reform_inequality", None), + ), + poverty=poverty_outputs.poverty, + poverty_by_gender=poverty_outputs.poverty_by_gender, + poverty_by_race=poverty_outputs.poverty_by_race, intra_decile=intra_decile, + wealth_decile=build_decile_output(wealth_decile) if country == "uk" else None, + intra_wealth_decile=( + build_intra_decile_output(intra_wealth_decile) if country == "uk" else None + ), + labor_supply_response=build_labor_supply_response_output(analysis), congressional_district_impact=congressional_district_impact, constituency_impact=constituency_impact, local_authority_impact=local_authority_impact, - **poverty_outputs, + cliff_impact=None, ) + return output.model_dump(mode="json") diff --git a/projects/policyengine-api-simulation/src/modal/simulation_macro_output.py b/projects/policyengine-api-simulation/src/modal/simulation_macro_output.py index 343d5caac..7ef198b71 100644 --- a/projects/policyengine-api-simulation/src/modal/simulation_macro_output.py +++ b/projects/policyengine-api-simulation/src/modal/simulation_macro_output.py @@ -7,9 +7,11 @@ from __future__ 
import annotations -from typing import Any +from typing import Any, Generic, TypeVar -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, RootModel + +T = TypeVar("T") class MacroOutputModel(BaseModel): @@ -18,7 +20,11 @@ class MacroOutputModel(BaseModel): model_config = ConfigDict(extra="forbid") -class BudgetaryOutput(MacroOutputModel): +class MacroRootModel(RootModel[T], Generic[T]): + """Base model for internal root schemas that dump to dict/list values.""" + + +class BudgetaryImpact(MacroOutputModel): tax_revenue_impact: float state_tax_revenue_impact: float benefit_spending_impact: float @@ -27,12 +33,19 @@ class BudgetaryOutput(MacroOutputModel): baseline_net_income: float +BudgetaryOutput = BudgetaryImpact + + class DetailedBudgetProgramOutput(MacroOutputModel): baseline: float reform: float difference: float +class DetailedBudgetOutput(MacroRootModel[dict[str, DetailedBudgetProgramOutput]]): + pass + + class DecileOutput(MacroOutputModel): average: dict[str, float] relative: dict[str, float] @@ -81,17 +94,31 @@ class PovertyByRaceOutput(MacroOutputModel): poverty: RacePovertyOutput +class PovertyModuleOutputs(MacroOutputModel): + poverty: PovertyOutput + poverty_by_gender: PovertyByGenderOutput + poverty_by_race: PovertyByRaceOutput | None + + class InequalityOutput(MacroOutputModel): gini: BaselineReformValue top_10_pct_share: BaselineReformValue top_1_pct_share: BaselineReformValue +class LaborSupplyResponseOutput(MacroRootModel[dict[str, Any]]): + pass + + +class GeographicImpactOutput(MacroRootModel[list[dict[str, Any]]]): + pass + + class SingleYearMacroOutput(MacroOutputModel): model_version: str data_version: str - budget: BudgetaryOutput - detailed_budget: dict[str, DetailedBudgetProgramOutput] + budget: BudgetaryImpact + detailed_budget: DetailedBudgetOutput decile: DecileOutput inequality: InequalityOutput poverty: PovertyOutput @@ -100,8 +127,8 @@ class SingleYearMacroOutput(MacroOutputModel): intra_decile: 
IntraDecileOutput wealth_decile: DecileOutput | None intra_wealth_decile: IntraDecileOutput | None - labor_supply_response: dict[str, Any] | None - constituency_impact: list[dict[str, Any]] | None - local_authority_impact: list[dict[str, Any]] | None - congressional_district_impact: list[dict[str, Any]] | None + labor_supply_response: LaborSupplyResponseOutput | None + constituency_impact: GeographicImpactOutput | None + local_authority_impact: GeographicImpactOutput | None + congressional_district_impact: GeographicImpactOutput | None cliff_impact: None = None diff --git a/projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py b/projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py index e27ded681..38cee97a3 100644 --- a/projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py +++ b/projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py @@ -9,12 +9,16 @@ from src.modal.simulation_macro_output import ( AgePovertyOutput, BaselineReformValue, + BudgetaryImpact, BudgetaryOutput, DecileOutput, + DetailedBudgetOutput, DetailedBudgetProgramOutput, + GeographicImpactOutput, GenderPovertyOutput, InequalityOutput, IntraDecileOutput, + LaborSupplyResponseOutput, PovertyByGenderOutput, PovertyByRaceOutput, PovertyOutput, @@ -74,17 +78,27 @@ def _output_model_dump(value: Any) -> Any: return None -def _records_or_none(value: Any) -> list[dict[str, Any]] | None: +def build_geographic_impact_output(value: Any) -> GeographicImpactOutput | None: + if isinstance(value, GeographicImpactOutput): + return value records = _output_model_dump(value) if isinstance(records, list): - return [dict(item) for item in records if isinstance(item, Mapping)] + return GeographicImpactOutput( + [dict(item) for item in records if isinstance(item, Mapping)] + ) if isinstance(value, list): - return [dict(item) for item in value if isinstance(item, Mapping)] + return GeographicImpactOutput( + [dict(item) for item in value 
if isinstance(item, Mapping)] + ) return None -def build_budgetary_output(budget: Mapping[str, Any]) -> BudgetaryOutput: - return BudgetaryOutput( +def build_budgetary_output( + budget: Mapping[str, Any] | BudgetaryImpact, +) -> BudgetaryOutput: + if isinstance(budget, BudgetaryImpact): + return budget + return BudgetaryImpact( tax_revenue_impact=_number(budget.get("tax_revenue_impact")), state_tax_revenue_impact=_number(budget.get("state_tax_revenue_impact")), benefit_spending_impact=_number(budget.get("benefit_spending_impact")), @@ -96,7 +110,9 @@ def build_budgetary_output(budget: Mapping[str, Any]) -> BudgetaryOutput: def build_detailed_budget_output( collection: Any, -) -> dict[str, DetailedBudgetProgramOutput]: +) -> DetailedBudgetOutput: + if isinstance(collection, DetailedBudgetOutput): + return collection detailed_budget: dict[str, DetailedBudgetProgramOutput] = {} for row in _collection_records(collection): program_name = row.get("program_name") @@ -109,10 +125,12 @@ def build_detailed_budget_output( reform=reform, difference=_number(row.get("change"), reform - baseline), ) - return detailed_budget + return DetailedBudgetOutput(detailed_budget) def build_decile_output(collection: Any) -> DecileOutput: + if isinstance(collection, DecileOutput): + return collection average: dict[str, float] = {} relative: dict[str, float] = {} for row in sorted( @@ -129,6 +147,8 @@ def build_decile_output(collection: Any) -> DecileOutput: def build_intra_decile_output(collection: Any) -> IntraDecileOutput: + if isinstance(collection, IntraDecileOutput): + return collection deciles: dict[str, list[float]] = {label: [] for label in INTRA_DECILE_COLUMNS} all_values: dict[str, float] = {label: 0.0 for label in INTRA_DECILE_COLUMNS} rows = [ @@ -230,6 +250,8 @@ def build_poverty_output( baseline_by_age: Any, reform_by_age: Any, ) -> PovertyOutput: + if isinstance(baseline, PovertyOutput): + return baseline result = {"poverty": _empty_age_poverty(), "deep_poverty": 
_empty_age_poverty()} _fill_poverty_block( country=country, @@ -257,6 +279,8 @@ def build_poverty_by_gender_output( baseline_by_gender: Any, reform_by_gender: Any, ) -> PovertyByGenderOutput: + if isinstance(baseline_by_gender, PovertyByGenderOutput): + return baseline_by_gender result = { "poverty": _empty_gender_poverty(), "deep_poverty": _empty_gender_poverty(), @@ -279,6 +303,8 @@ def build_poverty_by_race_output( baseline_by_race: Any, reform_by_race: Any, ) -> PovertyByRaceOutput: + if isinstance(baseline_by_race, PovertyByRaceOutput): + return baseline_by_race result = { "poverty": { "white": _empty_baseline_reform_value(), @@ -298,6 +324,8 @@ def build_poverty_by_race_output( def build_inequality_output(baseline: Any, reform: Any) -> InequalityOutput: + if isinstance(baseline, InequalityOutput): + return baseline return InequalityOutput( gini=BaselineReformValue( baseline=_number(getattr(baseline, "gini", None)), @@ -314,9 +342,13 @@ def build_inequality_output(baseline: Any, reform: Any) -> InequalityOutput: ) -def build_labor_supply_response_output(analysis: Any) -> dict[str, Any] | None: +def build_labor_supply_response_output( + analysis: Any, +) -> LaborSupplyResponseOutput | None: + if isinstance(analysis, LaborSupplyResponseOutput): + return analysis output = _output_model_dump(getattr(analysis, "labor_supply_response", None)) - return output if isinstance(output, dict) else None + return LaborSupplyResponseOutput(output) if isinstance(output, dict) else None def build_single_year_macro_output( @@ -324,7 +356,7 @@ def build_single_year_macro_output( country: str, model_version: str, data_version: str, - budget: Mapping[str, Any], + budget: Mapping[str, Any] | BudgetaryImpact, analysis: Any, baseline_poverty_by_age: Any = None, reform_poverty_by_age: Any = None, @@ -380,9 +412,11 @@ def build_single_year_macro_output( build_intra_decile_output(intra_wealth_decile) if country == "uk" else None ), 
labor_supply_response=build_labor_supply_response_output(analysis), - constituency_impact=_records_or_none(constituency_impact), - local_authority_impact=_records_or_none(local_authority_impact), - congressional_district_impact=_records_or_none(congressional_district_impact), + constituency_impact=build_geographic_impact_output(constituency_impact), + local_authority_impact=build_geographic_impact_output(local_authority_impact), + congressional_district_impact=build_geographic_impact_output( + congressional_district_impact + ), cliff_impact=None, ) @@ -392,7 +426,7 @@ def adapt_analysis_to_legacy_macro_output( country: str, model_version: str, data_version: str, - budget: dict[str, float], + budget: Mapping[str, Any] | BudgetaryImpact, analysis: Any, baseline_poverty_by_age: Any = None, reform_poverty_by_age: Any = None, diff --git a/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py b/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py index c83fdae43..d007979bd 100644 --- a/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py +++ b/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py @@ -25,8 +25,11 @@ _uk_local_authority_impact, ) from src.modal.simulation_macro_output import ( + BudgetaryImpact, BudgetaryOutput, DecileOutput, + DetailedBudgetOutput, + GeographicImpactOutput, IntraDecileOutput, PovertyOutput, SingleYearMacroOutput, @@ -60,11 +63,14 @@ def test_builder_returns_schema_modules_before_legacy_dict_dump(): assert isinstance(output, SingleYearMacroOutput) assert isinstance(output.budget, BudgetaryOutput) + assert isinstance(output.budget, BudgetaryImpact) + assert isinstance(output.detailed_budget, DetailedBudgetOutput) assert isinstance(output.decile, DecileOutput) assert isinstance(output.intra_decile, IntraDecileOutput) assert isinstance(output.poverty, PovertyOutput) + assert isinstance(output.congressional_district_impact, GeographicImpactOutput) 
assert output.wealth_decile is None - assert output.congressional_district_impact == [{"district_geoid": 101}] + assert output.congressional_district_impact.root == [{"district_geoid": 101}] legacy_output = adapt_analysis_to_legacy_macro_output( country="us", @@ -189,7 +195,8 @@ def test_budget_result_uses_materialized_household_columns_and_uk_state_tax_zero us_budget = _budget_result("us", baseline, reform) uk_budget = _budget_result("uk", baseline, reform) - assert us_budget == { + assert isinstance(us_budget, BudgetaryImpact) + assert us_budget.model_dump(mode="json") == { "tax_revenue_impact": 15.0, "state_tax_revenue_impact": 5.0, "benefit_spending_impact": -3.0, @@ -197,7 +204,7 @@ def test_budget_result_uses_materialized_household_columns_and_uk_state_tax_zero "households": 3.0, "baseline_net_income": 300.0, } - assert uk_budget["state_tax_revenue_impact"] == 0.0 + assert uk_budget.state_tax_revenue_impact == 0.0 def test_uk_constituency_impact_uses_policyengine_output_function(monkeypatch): @@ -220,7 +227,7 @@ def compute(baseline_simulation, reform_simulation): "src.modal.simulation._output_module_function", fake_output_module_function ) - assert _uk_constituency_impact("uk", baseline, reform) == expected + assert _uk_constituency_impact("uk", baseline, reform).root == expected assert _uk_constituency_impact("us", baseline, reform) is None @@ -244,5 +251,5 @@ def compute(baseline_simulation, reform_simulation): "src.modal.simulation._output_module_function", fake_output_module_function ) - assert _uk_local_authority_impact("uk", baseline, reform) == expected + assert _uk_local_authority_impact("uk", baseline, reform).root == expected assert _uk_local_authority_impact("us", baseline, reform) is None From 59437e678ff490ad2726591d2b07e3bc01f7acd4 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 22 May 2026 14:52:09 +0200 Subject: [PATCH 11/23] refactor: centralize simulation output building --- .../src/modal/simulation.py | 453 ++++++++++-------- 
.../src/modal/simulation_output_adapter.py | 183 +++++-- .../tests/test_simulation_output_adapter.py | 65 ++- 3 files changed, 433 insertions(+), 268 deletions(-) diff --git a/projects/policyengine-api-simulation/src/modal/simulation.py b/projects/policyengine-api-simulation/src/modal/simulation.py index 53c79cbcb..75150dfa4 100644 --- a/projects/policyengine-api-simulation/src/modal/simulation.py +++ b/projects/policyengine-api-simulation/src/modal/simulation.py @@ -11,6 +11,7 @@ import logging import os import tempfile +from dataclasses import dataclass from importlib import import_module from typing import Any, Iterator @@ -20,8 +21,12 @@ ) from src.modal.simulation_macro_output import ( BudgetaryImpact, + DecileOutput, + DetailedBudgetOutput, GeographicImpactOutput, + InequalityOutput, IntraDecileOutput, + LaborSupplyResponseOutput, PovertyModuleOutputs, SingleYearMacroOutput, ) @@ -338,38 +343,6 @@ def _try_change_output_variable(baseline, reform, variable: str, entity: str) -> return 0.0 -def _budget_result(country: str, baseline, reform) -> BudgetaryImpact: - tax_revenue_impact = _try_change_output_variable( - baseline, reform, "household_tax", entity="household" - ) - benefit_spending_impact = _try_change_output_variable( - baseline, reform, "household_benefits", entity="household" - ) - state_tax_revenue_impact = ( - _try_change_output_variable( - baseline, - reform, - "household_state_income_tax", - entity="household", - ) - if country == "us" - else 0.0 - ) - - return BudgetaryImpact( - tax_revenue_impact=tax_revenue_impact, - state_tax_revenue_impact=state_tax_revenue_impact, - benefit_spending_impact=benefit_spending_impact, - budgetary_impact=tax_revenue_impact - benefit_spending_impact, - households=_try_sum_output_variable( - baseline, "household_weight", entity="household" - ), - baseline_net_income=_try_sum_output_variable( - baseline, "household_net_income", entity="household" - ), - ) - - def _output_module_function(module_name: str, name: str): 
module = import_module(f"policyengine.outputs.{module_name}") return getattr(module, name) @@ -387,158 +360,254 @@ def _try_compute_output(label: str, fn, *args, **kwargs): return None -def _poverty_outputs(country: str, baseline, reform, analysis) -> PovertyModuleOutputs: - prefix = "us" if country == "us" else "uk" - baseline_poverty_by_age = _try_compute_output( - "baseline poverty by age", - _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), - baseline, - ) - reform_poverty_by_age = _try_compute_output( - "reform poverty by age", - _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), - reform, - ) - baseline_poverty_by_gender = _try_compute_output( - "baseline poverty by gender", - _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), - baseline, - ) - reform_poverty_by_gender = _try_compute_output( - "reform poverty by gender", - _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), - reform, - ) - baseline_poverty_by_race = None - reform_poverty_by_race = None - if country == "us": - baseline_poverty_by_race = _try_compute_output( - "baseline poverty by race", - _poverty_module_function("calculate_us_poverty_by_race"), - baseline, - ) - reform_poverty_by_race = _try_compute_output( - "reform poverty by race", - _poverty_module_function("calculate_us_poverty_by_race"), - reform, +@dataclass +class SimulationMacroOutputBuilder: + country: str + simulation_params: dict[str, Any] + country_module: Any + dataset: Any + baseline: Any + reform: Any + analysis: Any + + def __post_init__(self) -> None: + self.country = self.country.lower() + + def build(self) -> SingleYearMacroOutput: + poverty_outputs = self._build_poverty_outputs() + wealth_decile = getattr(self.analysis, "wealth_decile_impacts", None) + intra_wealth_decile = getattr( + self.analysis, "intra_wealth_decile_impacts", None ) - return PovertyModuleOutputs( - poverty=build_poverty_output( - country, - baseline=getattr(analysis, "baseline_poverty", None), 
- reform=getattr(analysis, "reform_poverty", None), - baseline_by_age=baseline_poverty_by_age, - reform_by_age=reform_poverty_by_age, - ), - poverty_by_gender=build_poverty_by_gender_output( - country, - baseline_by_gender=baseline_poverty_by_gender, - reform_by_gender=reform_poverty_by_gender, - ), - poverty_by_race=( - build_poverty_by_race_output( - baseline_by_race=baseline_poverty_by_race, - reform_by_race=reform_poverty_by_race, - ) - if country == "us" - else None - ), - ) - - -def _intra_decile_output(baseline, reform) -> IntraDecileOutput: - from policyengine.outputs.intra_decile_impact import compute_intra_decile_impacts - collection = _try_compute_output( - "intra-decile impacts", - compute_intra_decile_impacts, - baseline, - reform, - income_variable="household_net_income", - entity="household", - ) - return build_intra_decile_output(collection) + return SingleYearMacroOutput( + model_version=self._model_version(), + data_version=self._data_version(), + budget=self._build_budgetary_impact(), + detailed_budget=self._build_detailed_budget(), + decile=self._build_decile(), + inequality=self._build_inequality(), + poverty=poverty_outputs.poverty, + poverty_by_gender=poverty_outputs.poverty_by_gender, + poverty_by_race=poverty_outputs.poverty_by_race, + intra_decile=self._build_intra_decile_output(), + wealth_decile=self._build_wealth_decile(wealth_decile), + intra_wealth_decile=self._build_intra_wealth_decile(intra_wealth_decile), + labor_supply_response=self._build_labor_supply_response(), + congressional_district_impact=(self._build_congressional_district_impact()), + constituency_impact=self._build_uk_constituency_impact(), + local_authority_impact=self._build_uk_local_authority_impact(), + cliff_impact=None, + ) + def serialize(self) -> dict[str, Any]: + return self.build().model_dump(mode="json") -def _congressional_district_impact( - country: str, baseline, reform -) -> GeographicImpactOutput | None: - if country != "us": - return None + def 
_build_detailed_budget(self) -> DetailedBudgetOutput: + return build_detailed_budget_output( + getattr(self.analysis, "program_statistics", None) + ) - from policyengine.outputs.congressional_district_impact import ( - compute_us_congressional_district_impacts, - ) + def _build_decile(self) -> DecileOutput: + return build_decile_output(getattr(self.analysis, "decile_impacts", None)) - impact = _try_compute_output( - "congressional district impacts", - compute_us_congressional_district_impacts, - baseline, - reform, - ) - return build_geographic_impact_output( - getattr(impact, "district_results", None) if impact is not None else None - ) + def _build_inequality(self) -> InequalityOutput: + return build_inequality_output( + getattr(self.analysis, "baseline_inequality", None), + getattr(self.analysis, "reform_inequality", None), + ) + def _build_budgetary_impact(self) -> BudgetaryImpact: + tax_revenue_impact = _try_change_output_variable( + self.baseline, self.reform, "household_tax", entity="household" + ) + benefit_spending_impact = _try_change_output_variable( + self.baseline, self.reform, "household_benefits", entity="household" + ) + state_tax_revenue_impact = ( + _try_change_output_variable( + self.baseline, + self.reform, + "household_state_income_tax", + entity="household", + ) + if self.country == "us" + else 0.0 + ) -def _uk_constituency_impact( - country: str, baseline, reform -) -> GeographicImpactOutput | None: - if country != "uk": - return None + return BudgetaryImpact( + tax_revenue_impact=tax_revenue_impact, + state_tax_revenue_impact=state_tax_revenue_impact, + benefit_spending_impact=benefit_spending_impact, + budgetary_impact=tax_revenue_impact - benefit_spending_impact, + households=_try_sum_output_variable( + self.baseline, "household_weight", entity="household" + ), + baseline_net_income=_try_sum_output_variable( + self.baseline, "household_net_income", entity="household" + ), + ) - impact = _try_compute_output( - "constituency impacts", - 
_output_module_function( - "constituency_impact", "compute_uk_constituency_impacts" - ), - baseline, - reform, - ) - if impact is None: - return None - return build_geographic_impact_output(getattr(impact, "constituency_results", None)) + def _build_poverty_outputs(self) -> PovertyModuleOutputs: + prefix = "us" if self.country == "us" else "uk" + baseline_poverty_by_age = _try_compute_output( + "baseline poverty by age", + _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), + self.baseline, + ) + reform_poverty_by_age = _try_compute_output( + "reform poverty by age", + _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), + self.reform, + ) + baseline_poverty_by_gender = _try_compute_output( + "baseline poverty by gender", + _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), + self.baseline, + ) + reform_poverty_by_gender = _try_compute_output( + "reform poverty by gender", + _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), + self.reform, + ) + baseline_poverty_by_race = None + reform_poverty_by_race = None + if self.country == "us": + baseline_poverty_by_race = _try_compute_output( + "baseline poverty by race", + _poverty_module_function("calculate_us_poverty_by_race"), + self.baseline, + ) + reform_poverty_by_race = _try_compute_output( + "reform poverty by race", + _poverty_module_function("calculate_us_poverty_by_race"), + self.reform, + ) + return PovertyModuleOutputs( + poverty=build_poverty_output( + self.country, + baseline=getattr(self.analysis, "baseline_poverty", None), + reform=getattr(self.analysis, "reform_poverty", None), + baseline_by_age=baseline_poverty_by_age, + reform_by_age=reform_poverty_by_age, + ), + poverty_by_gender=build_poverty_by_gender_output( + self.country, + baseline_by_gender=baseline_poverty_by_gender, + reform_by_gender=reform_poverty_by_gender, + ), + poverty_by_race=( + build_poverty_by_race_output( + baseline_by_race=baseline_poverty_by_race, + 
reform_by_race=reform_poverty_by_race, + ) + if self.country == "us" + else None + ), + ) + def _build_intra_decile_output(self) -> IntraDecileOutput: + from policyengine.outputs.intra_decile_impact import ( + compute_intra_decile_impacts, + ) -def _uk_local_authority_impact( - country: str, baseline, reform -) -> GeographicImpactOutput | None: - if country != "uk": - return None + collection = _try_compute_output( + "intra-decile impacts", + compute_intra_decile_impacts, + self.baseline, + self.reform, + income_variable="household_net_income", + entity="household", + ) + return build_intra_decile_output(collection) + + def _build_wealth_decile(self, wealth_decile) -> DecileOutput | None: + if self.country != "uk": + return None + return build_decile_output(wealth_decile) + + def _build_intra_wealth_decile( + self, intra_wealth_decile + ) -> IntraDecileOutput | None: + if self.country != "uk": + return None + return build_intra_decile_output(intra_wealth_decile) + + def _build_labor_supply_response(self) -> LaborSupplyResponseOutput | None: + return build_labor_supply_response_output(self.analysis) + + def _build_congressional_district_impact( + self, + ) -> GeographicImpactOutput | None: + if self.country != "us": + return None + + from policyengine.outputs.congressional_district_impact import ( + compute_us_congressional_district_impacts, + ) - impact = _try_compute_output( - "local authority impacts", - _output_module_function( - "local_authority_impact", "compute_uk_local_authority_impacts" - ), - baseline, - reform, - ) - if impact is None: - return None - return build_geographic_impact_output( - getattr(impact, "local_authority_results", None) - ) + impact = _try_compute_output( + "congressional district impacts", + compute_us_congressional_district_impacts, + self.baseline, + self.reform, + ) + return build_geographic_impact_output( + getattr(impact, "district_results", None) if impact is not None else None + ) + def _build_uk_constituency_impact(self) -> 
GeographicImpactOutput | None: + if self.country != "uk": + return None + + impact = _try_compute_output( + "constituency impacts", + _output_module_function( + "constituency_impact", "compute_uk_constituency_impacts" + ), + self.baseline, + self.reform, + ) + if impact is None: + return None + return build_geographic_impact_output( + getattr(impact, "constituency_results", None) + ) -def _model_version(country_module) -> str: - return str(getattr(country_module.model, "version", "")) + def _build_uk_local_authority_impact(self) -> GeographicImpactOutput | None: + if self.country != "uk": + return None + + impact = _try_compute_output( + "local authority impacts", + _output_module_function( + "local_authority_impact", "compute_uk_local_authority_impacts" + ), + self.baseline, + self.reform, + ) + if impact is None: + return None + return build_geographic_impact_output( + getattr(impact, "local_authority_results", None) + ) + def _model_version(self) -> str: + return str(getattr(self.country_module.model, "version", "")) -def _data_version(params: dict[str, Any], dataset) -> str: - if params.get("data_version"): - return str(params["data_version"]) - country = params.get("country", "us").lower() - try: - return get_country_release_bundle(country).data_version - except ValueError: - pass - metadata = getattr(dataset, "metadata", {}) or {} - for key in ("data_version", "version"): - value = metadata.get(key) - if value is not None: - return str(value) - return "" + def _data_version(self) -> str: + if self.simulation_params.get("data_version"): + return str(self.simulation_params["data_version"]) + try: + return get_country_release_bundle(self.country).data_version + except ValueError: + pass + metadata = getattr(self.dataset, "metadata", {}) or {} + for key in ("data_version", "version"): + value = metadata.get(key) + if value is not None: + return str(value) + return "" def _run_simulation_impl_core(params: dict) -> dict: @@ -573,42 +642,14 @@ def 
_run_simulation_impl_core(params: dict) -> dict: logger.info("Calculating economic impact") analysis = country_module.economic_impact_analysis(baseline, reform) - budget = _budget_result(country, baseline, reform) - poverty_outputs = _poverty_outputs(country, baseline, reform, analysis) - intra_decile = _intra_decile_output(baseline, reform) - congressional_district_impact = _congressional_district_impact( - country, baseline, reform - ) - constituency_impact = _uk_constituency_impact(country, baseline, reform) - local_authority_impact = _uk_local_authority_impact(country, baseline, reform) + output = SimulationMacroOutputBuilder( + country=country, + simulation_params=simulation_params, + country_module=country_module, + dataset=dataset, + baseline=baseline, + reform=reform, + analysis=analysis, + ).serialize() logger.info("Comparison complete") - - wealth_decile = getattr(analysis, "wealth_decile_impacts", None) - intra_wealth_decile = getattr(analysis, "intra_wealth_decile_impacts", None) - output = SingleYearMacroOutput( - model_version=_model_version(country_module), - data_version=_data_version(simulation_params, dataset), - budget=budget, - detailed_budget=build_detailed_budget_output( - getattr(analysis, "program_statistics", None) - ), - decile=build_decile_output(getattr(analysis, "decile_impacts", None)), - inequality=build_inequality_output( - getattr(analysis, "baseline_inequality", None), - getattr(analysis, "reform_inequality", None), - ), - poverty=poverty_outputs.poverty, - poverty_by_gender=poverty_outputs.poverty_by_gender, - poverty_by_race=poverty_outputs.poverty_by_race, - intra_decile=intra_decile, - wealth_decile=build_decile_output(wealth_decile) if country == "uk" else None, - intra_wealth_decile=( - build_intra_decile_output(intra_wealth_decile) if country == "uk" else None - ), - labor_supply_response=build_labor_supply_response_output(analysis), - congressional_district_impact=congressional_district_impact, - 
constituency_impact=constituency_impact, - local_authority_impact=local_authority_impact, - cliff_impact=None, - ) - return output.model_dump(mode="json") + return output diff --git a/projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py b/projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py index 38cee97a3..7f1ad1625 100644 --- a/projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py +++ b/projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py @@ -4,6 +4,7 @@ import math from collections.abc import Iterable, Mapping +from dataclasses import dataclass from typing import Any from src.modal.simulation_macro_output import ( @@ -351,6 +352,122 @@ def build_labor_supply_response_output( return LaborSupplyResponseOutput(output) if isinstance(output, dict) else None +@dataclass +class SingleYearMacroOutputBuilder: + country: str + model_version: str + data_version: str + budget: Mapping[str, Any] | BudgetaryImpact + analysis: Any + baseline_poverty_by_age: Any = None + reform_poverty_by_age: Any = None + baseline_poverty_by_gender: Any = None + reform_poverty_by_gender: Any = None + baseline_poverty_by_race: Any = None + reform_poverty_by_race: Any = None + intra_decile: Any = None + congressional_district_impact: Any = None + constituency_impact: Any = None + local_authority_impact: Any = None + + def __post_init__(self) -> None: + self.country = self.country.lower() + + def build(self) -> SingleYearMacroOutput: + return SingleYearMacroOutput( + model_version=self.model_version, + data_version=self.data_version, + budget=self._build_budgetary_impact(), + detailed_budget=self._build_detailed_budget(), + decile=self._build_decile(), + inequality=self._build_inequality(), + poverty=self._build_poverty(), + poverty_by_gender=self._build_poverty_by_gender(), + poverty_by_race=self._build_poverty_by_race(), + intra_decile=self._build_intra_decile(), + 
wealth_decile=self._build_wealth_decile(), + intra_wealth_decile=self._build_intra_wealth_decile(), + labor_supply_response=self._build_labor_supply_response(), + constituency_impact=self._build_constituency_impact(), + local_authority_impact=self._build_local_authority_impact(), + congressional_district_impact=self._build_congressional_district_impact(), + cliff_impact=None, + ) + + def serialize(self) -> dict[str, Any]: + return self.build().model_dump(mode="json") + + def _build_budgetary_impact(self) -> BudgetaryImpact: + return build_budgetary_output(self.budget) + + def _build_detailed_budget(self) -> DetailedBudgetOutput: + return build_detailed_budget_output( + getattr(self.analysis, "program_statistics", None) + ) + + def _build_decile(self) -> DecileOutput: + return build_decile_output(getattr(self.analysis, "decile_impacts", None)) + + def _build_inequality(self) -> InequalityOutput: + return build_inequality_output( + getattr(self.analysis, "baseline_inequality", None), + getattr(self.analysis, "reform_inequality", None), + ) + + def _build_poverty(self) -> PovertyOutput: + return build_poverty_output( + self.country, + baseline=getattr(self.analysis, "baseline_poverty", None), + reform=getattr(self.analysis, "reform_poverty", None), + baseline_by_age=self.baseline_poverty_by_age, + reform_by_age=self.reform_poverty_by_age, + ) + + def _build_poverty_by_gender(self) -> PovertyByGenderOutput: + return build_poverty_by_gender_output( + self.country, + baseline_by_gender=self.baseline_poverty_by_gender, + reform_by_gender=self.reform_poverty_by_gender, + ) + + def _build_poverty_by_race(self) -> PovertyByRaceOutput | None: + if self.country != "us": + return None + return build_poverty_by_race_output( + baseline_by_race=self.baseline_poverty_by_race, + reform_by_race=self.reform_poverty_by_race, + ) + + def _build_intra_decile(self) -> IntraDecileOutput: + return build_intra_decile_output(self.intra_decile) + + def _build_wealth_decile(self) -> 
DecileOutput | None: + if self.country != "uk": + return None + return build_decile_output( + getattr(self.analysis, "wealth_decile_impacts", None) + ) + + def _build_intra_wealth_decile(self) -> IntraDecileOutput | None: + if self.country != "uk": + return None + return build_intra_decile_output( + getattr(self.analysis, "intra_wealth_decile_impacts", None) + ) + + def _build_labor_supply_response(self) -> LaborSupplyResponseOutput | None: + return build_labor_supply_response_output(self.analysis) + + def _build_constituency_impact(self) -> GeographicImpactOutput | None: + return build_geographic_impact_output(self.constituency_impact) + + def _build_local_authority_impact(self) -> GeographicImpactOutput | None: + return build_geographic_impact_output(self.local_authority_impact) + + def _build_congressional_district_impact(self) -> GeographicImpactOutput | None: + return build_geographic_impact_output(self.congressional_district_impact) + + def build_single_year_macro_output( *, country: str, @@ -370,55 +487,23 @@ def build_single_year_macro_output( local_authority_impact: Any = None, ) -> SingleYearMacroOutput: """Build the schema-first single-year macro output.""" - country = country.lower() - wealth_decile = getattr(analysis, "wealth_decile_impacts", None) - intra_wealth_decile = getattr(analysis, "intra_wealth_decile_impacts", None) - - return SingleYearMacroOutput( + return SingleYearMacroOutputBuilder( + country=country, model_version=model_version, data_version=data_version, - budget=build_budgetary_output(budget), - detailed_budget=build_detailed_budget_output( - getattr(analysis, "program_statistics", None) - ), - decile=build_decile_output(getattr(analysis, "decile_impacts", None)), - inequality=build_inequality_output( - getattr(analysis, "baseline_inequality", None), - getattr(analysis, "reform_inequality", None), - ), - poverty=build_poverty_output( - country, - baseline=getattr(analysis, "baseline_poverty", None), - reform=getattr(analysis, 
"reform_poverty", None), - baseline_by_age=baseline_poverty_by_age, - reform_by_age=reform_poverty_by_age, - ), - poverty_by_gender=build_poverty_by_gender_output( - country, - baseline_by_gender=baseline_poverty_by_gender, - reform_by_gender=reform_poverty_by_gender, - ), - poverty_by_race=( - build_poverty_by_race_output( - baseline_by_race=baseline_poverty_by_race, - reform_by_race=reform_poverty_by_race, - ) - if country == "us" - else None - ), - intra_decile=build_intra_decile_output(intra_decile), - wealth_decile=build_decile_output(wealth_decile) if country == "uk" else None, - intra_wealth_decile=( - build_intra_decile_output(intra_wealth_decile) if country == "uk" else None - ), - labor_supply_response=build_labor_supply_response_output(analysis), - constituency_impact=build_geographic_impact_output(constituency_impact), - local_authority_impact=build_geographic_impact_output(local_authority_impact), - congressional_district_impact=build_geographic_impact_output( - congressional_district_impact - ), - cliff_impact=None, - ) + budget=budget, + analysis=analysis, + baseline_poverty_by_age=baseline_poverty_by_age, + reform_poverty_by_age=reform_poverty_by_age, + baseline_poverty_by_gender=baseline_poverty_by_gender, + reform_poverty_by_gender=reform_poverty_by_gender, + baseline_poverty_by_race=baseline_poverty_by_race, + reform_poverty_by_race=reform_poverty_by_race, + intra_decile=intra_decile, + congressional_district_impact=congressional_district_impact, + constituency_impact=constituency_impact, + local_authority_impact=local_authority_impact, + ).build() def adapt_analysis_to_legacy_macro_output( @@ -440,7 +525,7 @@ def adapt_analysis_to_legacy_macro_output( local_authority_impact: Any = None, ) -> dict[str, Any]: """Return the legacy single-year macro result expected by API callers.""" - return build_single_year_macro_output( + return SingleYearMacroOutputBuilder( country=country, model_version=model_version, data_version=data_version, @@ -456,4 
+541,4 @@ def adapt_analysis_to_legacy_macro_output( congressional_district_impact=congressional_district_impact, constituency_impact=constituency_impact, local_authority_impact=local_authority_impact, - ).model_dump(mode="json") + ).serialize() diff --git a/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py b/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py index d007979bd..d3ad36cc6 100644 --- a/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py +++ b/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py @@ -19,10 +19,8 @@ fake_analysis, ) from src.modal.simulation import ( - _budget_result, + SimulationMacroOutputBuilder, _normalise_policy, - _uk_constituency_impact, - _uk_local_authority_impact, ) from src.modal.simulation_macro_output import ( BudgetaryImpact, @@ -35,13 +33,13 @@ SingleYearMacroOutput, ) from src.modal.simulation_output_adapter import ( + SingleYearMacroOutputBuilder, adapt_analysis_to_legacy_macro_output, - build_single_year_macro_output, ) def _build_schema_output() -> SingleYearMacroOutput: - return build_single_year_macro_output( + return SingleYearMacroOutputBuilder( country="us", model_version="1.700.0", data_version="1.115.5", @@ -55,7 +53,7 @@ def _build_schema_output() -> SingleYearMacroOutput: reform_poverty_by_race=REFORM_POVERTY_BY_RACE, intra_decile=INTRA_DECILE_COLLECTION, congressional_district_impact=[{"district_geoid": 101}], - ) + ).build() def test_builder_returns_schema_modules_before_legacy_dict_dump(): @@ -155,6 +153,23 @@ def test_normalise_policy_converts_legacy_period_range_keys(): } +def _simulation_output_builder( + country: str, + baseline, + reform, + analysis=None, +) -> SimulationMacroOutputBuilder: + return SimulationMacroOutputBuilder( + country=country, + simulation_params={"country": country}, + country_module=SimpleNamespace(model=SimpleNamespace(version="test")), + dataset=SimpleNamespace(metadata={}), + 
baseline=baseline, + reform=reform, + analysis=analysis or fake_analysis(), + ) + + class _FakeOutputDataset: def __init__(self, household): self.data = SimpleNamespace(household=household) @@ -168,7 +183,7 @@ def ensure(self): raise AssertionError("test data is already materialized") -def test_budget_result_uses_materialized_household_columns_and_uk_state_tax_zero(): +def test_builder_budgetary_impact_uses_materialized_columns_and_uk_state_tax_zero(): baseline = _FakeSimulation( pd.DataFrame( { @@ -192,8 +207,12 @@ def test_budget_result_uses_materialized_household_columns_and_uk_state_tax_zero ) ) - us_budget = _budget_result("us", baseline, reform) - uk_budget = _budget_result("uk", baseline, reform) + us_budget = _simulation_output_builder( + "us", baseline, reform + )._build_budgetary_impact() + uk_budget = _simulation_output_builder( + "uk", baseline, reform + )._build_budgetary_impact() assert isinstance(us_budget, BudgetaryImpact) assert us_budget.model_dump(mode="json") == { @@ -227,8 +246,18 @@ def compute(baseline_simulation, reform_simulation): "src.modal.simulation._output_module_function", fake_output_module_function ) - assert _uk_constituency_impact("uk", baseline, reform).root == expected - assert _uk_constituency_impact("us", baseline, reform) is None + assert ( + _simulation_output_builder("uk", baseline, reform) + ._build_uk_constituency_impact() + .root + == expected + ) + assert ( + _simulation_output_builder( + "us", baseline, reform + )._build_uk_constituency_impact() + is None + ) def test_uk_local_authority_impact_uses_policyengine_output_function(monkeypatch): @@ -251,5 +280,15 @@ def compute(baseline_simulation, reform_simulation): "src.modal.simulation._output_module_function", fake_output_module_function ) - assert _uk_local_authority_impact("uk", baseline, reform).root == expected - assert _uk_local_authority_impact("us", baseline, reform) is None + assert ( + _simulation_output_builder("uk", baseline, reform) + 
._build_uk_local_authority_impact() + .root + == expected + ) + assert ( + _simulation_output_builder( + "us", baseline, reform + )._build_uk_local_authority_impact() + is None + ) From 06b8bdf1fa30f3411a014c611df8b9d3de22b946 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 22 May 2026 14:59:18 +0200 Subject: [PATCH 12/23] refactor: move simulation output builder module --- .../src/modal/simulation.py | 354 +---------------- .../modal/simulation_macro_output_builder.py | 359 ++++++++++++++++++ .../tests/test_simulation_output_adapter.py | 12 +- 3 files changed, 367 insertions(+), 358 deletions(-) create mode 100644 projects/policyengine-api-simulation/src/modal/simulation_macro_output_builder.py diff --git a/projects/policyengine-api-simulation/src/modal/simulation.py b/projects/policyengine-api-simulation/src/modal/simulation.py index 75150dfa4..89307b909 100644 --- a/projects/policyengine-api-simulation/src/modal/simulation.py +++ b/projects/policyengine-api-simulation/src/modal/simulation.py @@ -11,36 +11,11 @@ import logging import os import tempfile -from dataclasses import dataclass from importlib import import_module from typing import Any, Iterator -from src.modal.release_bundle import ( - get_country_release_bundle, - resolve_bundle_dataset_name, -) -from src.modal.simulation_macro_output import ( - BudgetaryImpact, - DecileOutput, - DetailedBudgetOutput, - GeographicImpactOutput, - InequalityOutput, - IntraDecileOutput, - LaborSupplyResponseOutput, - PovertyModuleOutputs, - SingleYearMacroOutput, -) -from src.modal.simulation_output_adapter import ( - build_decile_output, - build_detailed_budget_output, - build_geographic_impact_output, - build_inequality_output, - build_intra_decile_output, - build_labor_supply_response_output, - build_poverty_by_gender_output, - build_poverty_by_race_output, - build_poverty_output, -) +from src.modal.release_bundle import resolve_bundle_dataset_name +from src.modal.simulation_macro_output_builder import 
SimulationMacroOutputBuilder from src.modal.telemetry import split_internal_payload logger = logging.getLogger(__name__) @@ -285,331 +260,6 @@ def _build_simulation( ) -def _entity_data(simulation, entity: str): - if simulation.output_dataset is None or simulation.output_dataset.data is None: - simulation.ensure() - return getattr(simulation.output_dataset.data, entity) - - -def _sum_output_variable(simulation, variable: str, entity: str) -> float: - data = _entity_data(simulation, entity) - if variable in data.columns: - return float(data[variable].sum()) - - from policyengine.outputs import Aggregate, AggregateType - - output = Aggregate( - simulation=simulation, - variable=variable, - entity=entity, - aggregate_type=AggregateType.SUM, - ) - output.run() - return float(output.result) - - -def _try_sum_output_variable(simulation, variable: str, entity: str) -> float: - try: - return _sum_output_variable(simulation, variable, entity) - except Exception: - logger.warning("Unable to calculate sum for %s", variable, exc_info=True) - return 0.0 - - -def _change_output_variable(baseline, reform, variable: str, entity: str) -> float: - baseline_data = _entity_data(baseline, entity) - reform_data = _entity_data(reform, entity) - if variable in baseline_data.columns and variable in reform_data.columns: - return float((reform_data[variable] - baseline_data[variable]).sum()) - - from policyengine.outputs import ChangeAggregate, ChangeAggregateType - - output = ChangeAggregate( - baseline_simulation=baseline, - reform_simulation=reform, - variable=variable, - entity=entity, - aggregate_type=ChangeAggregateType.SUM, - ) - output.run() - return float(output.result) - - -def _try_change_output_variable(baseline, reform, variable: str, entity: str) -> float: - try: - return _change_output_variable(baseline, reform, variable, entity) - except Exception: - logger.warning("Unable to calculate change for %s", variable, exc_info=True) - return 0.0 - - -def 
_output_module_function(module_name: str, name: str): - module = import_module(f"policyengine.outputs.{module_name}") - return getattr(module, name) - - -def _poverty_module_function(name: str): - return _output_module_function("poverty", name) - - -def _try_compute_output(label: str, fn, *args, **kwargs): - try: - return fn(*args, **kwargs) - except Exception: - logger.warning("Unable to calculate %s", label, exc_info=True) - return None - - -@dataclass -class SimulationMacroOutputBuilder: - country: str - simulation_params: dict[str, Any] - country_module: Any - dataset: Any - baseline: Any - reform: Any - analysis: Any - - def __post_init__(self) -> None: - self.country = self.country.lower() - - def build(self) -> SingleYearMacroOutput: - poverty_outputs = self._build_poverty_outputs() - wealth_decile = getattr(self.analysis, "wealth_decile_impacts", None) - intra_wealth_decile = getattr( - self.analysis, "intra_wealth_decile_impacts", None - ) - - return SingleYearMacroOutput( - model_version=self._model_version(), - data_version=self._data_version(), - budget=self._build_budgetary_impact(), - detailed_budget=self._build_detailed_budget(), - decile=self._build_decile(), - inequality=self._build_inequality(), - poverty=poverty_outputs.poverty, - poverty_by_gender=poverty_outputs.poverty_by_gender, - poverty_by_race=poverty_outputs.poverty_by_race, - intra_decile=self._build_intra_decile_output(), - wealth_decile=self._build_wealth_decile(wealth_decile), - intra_wealth_decile=self._build_intra_wealth_decile(intra_wealth_decile), - labor_supply_response=self._build_labor_supply_response(), - congressional_district_impact=(self._build_congressional_district_impact()), - constituency_impact=self._build_uk_constituency_impact(), - local_authority_impact=self._build_uk_local_authority_impact(), - cliff_impact=None, - ) - - def serialize(self) -> dict[str, Any]: - return self.build().model_dump(mode="json") - - def _build_detailed_budget(self) -> DetailedBudgetOutput: 
- return build_detailed_budget_output( - getattr(self.analysis, "program_statistics", None) - ) - - def _build_decile(self) -> DecileOutput: - return build_decile_output(getattr(self.analysis, "decile_impacts", None)) - - def _build_inequality(self) -> InequalityOutput: - return build_inequality_output( - getattr(self.analysis, "baseline_inequality", None), - getattr(self.analysis, "reform_inequality", None), - ) - - def _build_budgetary_impact(self) -> BudgetaryImpact: - tax_revenue_impact = _try_change_output_variable( - self.baseline, self.reform, "household_tax", entity="household" - ) - benefit_spending_impact = _try_change_output_variable( - self.baseline, self.reform, "household_benefits", entity="household" - ) - state_tax_revenue_impact = ( - _try_change_output_variable( - self.baseline, - self.reform, - "household_state_income_tax", - entity="household", - ) - if self.country == "us" - else 0.0 - ) - - return BudgetaryImpact( - tax_revenue_impact=tax_revenue_impact, - state_tax_revenue_impact=state_tax_revenue_impact, - benefit_spending_impact=benefit_spending_impact, - budgetary_impact=tax_revenue_impact - benefit_spending_impact, - households=_try_sum_output_variable( - self.baseline, "household_weight", entity="household" - ), - baseline_net_income=_try_sum_output_variable( - self.baseline, "household_net_income", entity="household" - ), - ) - - def _build_poverty_outputs(self) -> PovertyModuleOutputs: - prefix = "us" if self.country == "us" else "uk" - baseline_poverty_by_age = _try_compute_output( - "baseline poverty by age", - _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), - self.baseline, - ) - reform_poverty_by_age = _try_compute_output( - "reform poverty by age", - _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), - self.reform, - ) - baseline_poverty_by_gender = _try_compute_output( - "baseline poverty by gender", - _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), - self.baseline, - ) - 
reform_poverty_by_gender = _try_compute_output( - "reform poverty by gender", - _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), - self.reform, - ) - baseline_poverty_by_race = None - reform_poverty_by_race = None - if self.country == "us": - baseline_poverty_by_race = _try_compute_output( - "baseline poverty by race", - _poverty_module_function("calculate_us_poverty_by_race"), - self.baseline, - ) - reform_poverty_by_race = _try_compute_output( - "reform poverty by race", - _poverty_module_function("calculate_us_poverty_by_race"), - self.reform, - ) - return PovertyModuleOutputs( - poverty=build_poverty_output( - self.country, - baseline=getattr(self.analysis, "baseline_poverty", None), - reform=getattr(self.analysis, "reform_poverty", None), - baseline_by_age=baseline_poverty_by_age, - reform_by_age=reform_poverty_by_age, - ), - poverty_by_gender=build_poverty_by_gender_output( - self.country, - baseline_by_gender=baseline_poverty_by_gender, - reform_by_gender=reform_poverty_by_gender, - ), - poverty_by_race=( - build_poverty_by_race_output( - baseline_by_race=baseline_poverty_by_race, - reform_by_race=reform_poverty_by_race, - ) - if self.country == "us" - else None - ), - ) - - def _build_intra_decile_output(self) -> IntraDecileOutput: - from policyengine.outputs.intra_decile_impact import ( - compute_intra_decile_impacts, - ) - - collection = _try_compute_output( - "intra-decile impacts", - compute_intra_decile_impacts, - self.baseline, - self.reform, - income_variable="household_net_income", - entity="household", - ) - return build_intra_decile_output(collection) - - def _build_wealth_decile(self, wealth_decile) -> DecileOutput | None: - if self.country != "uk": - return None - return build_decile_output(wealth_decile) - - def _build_intra_wealth_decile( - self, intra_wealth_decile - ) -> IntraDecileOutput | None: - if self.country != "uk": - return None - return build_intra_decile_output(intra_wealth_decile) - - def 
_build_labor_supply_response(self) -> LaborSupplyResponseOutput | None: - return build_labor_supply_response_output(self.analysis) - - def _build_congressional_district_impact( - self, - ) -> GeographicImpactOutput | None: - if self.country != "us": - return None - - from policyengine.outputs.congressional_district_impact import ( - compute_us_congressional_district_impacts, - ) - - impact = _try_compute_output( - "congressional district impacts", - compute_us_congressional_district_impacts, - self.baseline, - self.reform, - ) - return build_geographic_impact_output( - getattr(impact, "district_results", None) if impact is not None else None - ) - - def _build_uk_constituency_impact(self) -> GeographicImpactOutput | None: - if self.country != "uk": - return None - - impact = _try_compute_output( - "constituency impacts", - _output_module_function( - "constituency_impact", "compute_uk_constituency_impacts" - ), - self.baseline, - self.reform, - ) - if impact is None: - return None - return build_geographic_impact_output( - getattr(impact, "constituency_results", None) - ) - - def _build_uk_local_authority_impact(self) -> GeographicImpactOutput | None: - if self.country != "uk": - return None - - impact = _try_compute_output( - "local authority impacts", - _output_module_function( - "local_authority_impact", "compute_uk_local_authority_impacts" - ), - self.baseline, - self.reform, - ) - if impact is None: - return None - return build_geographic_impact_output( - getattr(impact, "local_authority_results", None) - ) - - def _model_version(self) -> str: - return str(getattr(self.country_module.model, "version", "")) - - def _data_version(self) -> str: - if self.simulation_params.get("data_version"): - return str(self.simulation_params["data_version"]) - try: - return get_country_release_bundle(self.country).data_version - except ValueError: - pass - metadata = getattr(self.dataset, "metadata", {}) or {} - for key in ("data_version", "version"): - value = 
metadata.get(key) - if value is not None: - return str(value) - return "" - - def _run_simulation_impl_core(params: dict) -> dict: simulation_params, telemetry, metadata = split_internal_payload(params) diff --git a/projects/policyengine-api-simulation/src/modal/simulation_macro_output_builder.py b/projects/policyengine-api-simulation/src/modal/simulation_macro_output_builder.py new file mode 100644 index 000000000..5f0ed34f2 --- /dev/null +++ b/projects/policyengine-api-simulation/src/modal/simulation_macro_output_builder.py @@ -0,0 +1,359 @@ +"""Build and serialize the runtime simulation macro output.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from importlib import import_module +from typing import Any + +from src.modal.release_bundle import get_country_release_bundle +from src.modal.simulation_macro_output import ( + BudgetaryImpact, + DecileOutput, + DetailedBudgetOutput, + GeographicImpactOutput, + InequalityOutput, + IntraDecileOutput, + LaborSupplyResponseOutput, + PovertyModuleOutputs, + SingleYearMacroOutput, +) +from src.modal.simulation_output_adapter import ( + build_decile_output, + build_detailed_budget_output, + build_geographic_impact_output, + build_inequality_output, + build_intra_decile_output, + build_labor_supply_response_output, + build_poverty_by_gender_output, + build_poverty_by_race_output, + build_poverty_output, +) + +logger = logging.getLogger(__name__) + + +def _entity_data(simulation, entity: str): + if simulation.output_dataset is None or simulation.output_dataset.data is None: + simulation.ensure() + return getattr(simulation.output_dataset.data, entity) + + +def _sum_output_variable(simulation, variable: str, entity: str) -> float: + data = _entity_data(simulation, entity) + if variable in data.columns: + return float(data[variable].sum()) + + from policyengine.outputs import Aggregate, AggregateType + + output = Aggregate( + simulation=simulation, + variable=variable, + 
entity=entity, + aggregate_type=AggregateType.SUM, + ) + output.run() + return float(output.result) + + +def _try_sum_output_variable(simulation, variable: str, entity: str) -> float: + try: + return _sum_output_variable(simulation, variable, entity) + except Exception: + logger.warning("Unable to calculate sum for %s", variable, exc_info=True) + return 0.0 + + +def _change_output_variable(baseline, reform, variable: str, entity: str) -> float: + baseline_data = _entity_data(baseline, entity) + reform_data = _entity_data(reform, entity) + if variable in baseline_data.columns and variable in reform_data.columns: + return float((reform_data[variable] - baseline_data[variable]).sum()) + + from policyengine.outputs import ChangeAggregate, ChangeAggregateType + + output = ChangeAggregate( + baseline_simulation=baseline, + reform_simulation=reform, + variable=variable, + entity=entity, + aggregate_type=ChangeAggregateType.SUM, + ) + output.run() + return float(output.result) + + +def _try_change_output_variable(baseline, reform, variable: str, entity: str) -> float: + try: + return _change_output_variable(baseline, reform, variable, entity) + except Exception: + logger.warning("Unable to calculate change for %s", variable, exc_info=True) + return 0.0 + + +def _output_module_function(module_name: str, name: str): + module = import_module(f"policyengine.outputs.{module_name}") + return getattr(module, name) + + +def _poverty_module_function(name: str): + return _output_module_function("poverty", name) + + +def _try_compute_output(label: str, fn, *args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception: + logger.warning("Unable to calculate %s", label, exc_info=True) + return None + + +@dataclass +class SimulationMacroOutputBuilder: + country: str + simulation_params: dict[str, Any] + country_module: Any + dataset: Any + baseline: Any + reform: Any + analysis: Any + + def __post_init__(self) -> None: + self.country = self.country.lower() + + def build(self) 
-> SingleYearMacroOutput: + poverty_outputs = self._build_poverty_outputs() + wealth_decile = getattr(self.analysis, "wealth_decile_impacts", None) + intra_wealth_decile = getattr( + self.analysis, "intra_wealth_decile_impacts", None + ) + + return SingleYearMacroOutput( + model_version=self._model_version(), + data_version=self._data_version(), + budget=self._build_budgetary_impact(), + detailed_budget=self._build_detailed_budget(), + decile=self._build_decile(), + inequality=self._build_inequality(), + poverty=poverty_outputs.poverty, + poverty_by_gender=poverty_outputs.poverty_by_gender, + poverty_by_race=poverty_outputs.poverty_by_race, + intra_decile=self._build_intra_decile_output(), + wealth_decile=self._build_wealth_decile(wealth_decile), + intra_wealth_decile=self._build_intra_wealth_decile(intra_wealth_decile), + labor_supply_response=self._build_labor_supply_response(), + congressional_district_impact=(self._build_congressional_district_impact()), + constituency_impact=self._build_uk_constituency_impact(), + local_authority_impact=self._build_uk_local_authority_impact(), + cliff_impact=None, + ) + + def serialize(self) -> dict[str, Any]: + return self.build().model_dump(mode="json") + + def _build_detailed_budget(self) -> DetailedBudgetOutput: + return build_detailed_budget_output( + getattr(self.analysis, "program_statistics", None) + ) + + def _build_decile(self) -> DecileOutput: + return build_decile_output(getattr(self.analysis, "decile_impacts", None)) + + def _build_inequality(self) -> InequalityOutput: + return build_inequality_output( + getattr(self.analysis, "baseline_inequality", None), + getattr(self.analysis, "reform_inequality", None), + ) + + def _build_budgetary_impact(self) -> BudgetaryImpact: + tax_revenue_impact = _try_change_output_variable( + self.baseline, self.reform, "household_tax", entity="household" + ) + benefit_spending_impact = _try_change_output_variable( + self.baseline, self.reform, "household_benefits", entity="household" 
+ ) + state_tax_revenue_impact = ( + _try_change_output_variable( + self.baseline, + self.reform, + "household_state_income_tax", + entity="household", + ) + if self.country == "us" + else 0.0 + ) + + return BudgetaryImpact( + tax_revenue_impact=tax_revenue_impact, + state_tax_revenue_impact=state_tax_revenue_impact, + benefit_spending_impact=benefit_spending_impact, + budgetary_impact=tax_revenue_impact - benefit_spending_impact, + households=_try_sum_output_variable( + self.baseline, "household_weight", entity="household" + ), + baseline_net_income=_try_sum_output_variable( + self.baseline, "household_net_income", entity="household" + ), + ) + + def _build_poverty_outputs(self) -> PovertyModuleOutputs: + prefix = "us" if self.country == "us" else "uk" + baseline_poverty_by_age = _try_compute_output( + "baseline poverty by age", + _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), + self.baseline, + ) + reform_poverty_by_age = _try_compute_output( + "reform poverty by age", + _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), + self.reform, + ) + baseline_poverty_by_gender = _try_compute_output( + "baseline poverty by gender", + _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), + self.baseline, + ) + reform_poverty_by_gender = _try_compute_output( + "reform poverty by gender", + _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), + self.reform, + ) + baseline_poverty_by_race = None + reform_poverty_by_race = None + if self.country == "us": + baseline_poverty_by_race = _try_compute_output( + "baseline poverty by race", + _poverty_module_function("calculate_us_poverty_by_race"), + self.baseline, + ) + reform_poverty_by_race = _try_compute_output( + "reform poverty by race", + _poverty_module_function("calculate_us_poverty_by_race"), + self.reform, + ) + return PovertyModuleOutputs( + poverty=build_poverty_output( + self.country, + baseline=getattr(self.analysis, "baseline_poverty", None), + 
reform=getattr(self.analysis, "reform_poverty", None), + baseline_by_age=baseline_poverty_by_age, + reform_by_age=reform_poverty_by_age, + ), + poverty_by_gender=build_poverty_by_gender_output( + self.country, + baseline_by_gender=baseline_poverty_by_gender, + reform_by_gender=reform_poverty_by_gender, + ), + poverty_by_race=( + build_poverty_by_race_output( + baseline_by_race=baseline_poverty_by_race, + reform_by_race=reform_poverty_by_race, + ) + if self.country == "us" + else None + ), + ) + + def _build_intra_decile_output(self) -> IntraDecileOutput: + from policyengine.outputs.intra_decile_impact import ( + compute_intra_decile_impacts, + ) + + collection = _try_compute_output( + "intra-decile impacts", + compute_intra_decile_impacts, + self.baseline, + self.reform, + income_variable="household_net_income", + entity="household", + ) + return build_intra_decile_output(collection) + + def _build_wealth_decile(self, wealth_decile) -> DecileOutput | None: + if self.country != "uk": + return None + return build_decile_output(wealth_decile) + + def _build_intra_wealth_decile( + self, intra_wealth_decile + ) -> IntraDecileOutput | None: + if self.country != "uk": + return None + return build_intra_decile_output(intra_wealth_decile) + + def _build_labor_supply_response(self) -> LaborSupplyResponseOutput | None: + return build_labor_supply_response_output(self.analysis) + + def _build_congressional_district_impact( + self, + ) -> GeographicImpactOutput | None: + if self.country != "us": + return None + + from policyengine.outputs.congressional_district_impact import ( + compute_us_congressional_district_impacts, + ) + + impact = _try_compute_output( + "congressional district impacts", + compute_us_congressional_district_impacts, + self.baseline, + self.reform, + ) + return build_geographic_impact_output( + getattr(impact, "district_results", None) if impact is not None else None + ) + + def _build_uk_constituency_impact(self) -> GeographicImpactOutput | None: + if 
self.country != "uk": + return None + + impact = _try_compute_output( + "constituency impacts", + _output_module_function( + "constituency_impact", "compute_uk_constituency_impacts" + ), + self.baseline, + self.reform, + ) + if impact is None: + return None + return build_geographic_impact_output( + getattr(impact, "constituency_results", None) + ) + + def _build_uk_local_authority_impact(self) -> GeographicImpactOutput | None: + if self.country != "uk": + return None + + impact = _try_compute_output( + "local authority impacts", + _output_module_function( + "local_authority_impact", "compute_uk_local_authority_impacts" + ), + self.baseline, + self.reform, + ) + if impact is None: + return None + return build_geographic_impact_output( + getattr(impact, "local_authority_results", None) + ) + + def _model_version(self) -> str: + return str(getattr(self.country_module.model, "version", "")) + + def _data_version(self) -> str: + if self.simulation_params.get("data_version"): + return str(self.simulation_params["data_version"]) + try: + return get_country_release_bundle(self.country).data_version + except ValueError: + pass + metadata = getattr(self.dataset, "metadata", {}) or {} + for key in ("data_version", "version"): + value = metadata.get(key) + if value is not None: + return str(value) + return "" diff --git a/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py b/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py index d3ad36cc6..a094ed385 100644 --- a/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py +++ b/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py @@ -18,10 +18,8 @@ REFORM_POVERTY_BY_RACE, fake_analysis, ) -from src.modal.simulation import ( - SimulationMacroOutputBuilder, - _normalise_policy, -) +from src.modal.simulation import _normalise_policy +from src.modal.simulation_macro_output_builder import SimulationMacroOutputBuilder from 
src.modal.simulation_macro_output import ( BudgetaryImpact, BudgetaryOutput, @@ -243,7 +241,8 @@ def compute(baseline_simulation, reform_simulation): return compute monkeypatch.setattr( - "src.modal.simulation._output_module_function", fake_output_module_function + "src.modal.simulation_macro_output_builder._output_module_function", + fake_output_module_function, ) assert ( @@ -277,7 +276,8 @@ def compute(baseline_simulation, reform_simulation): return compute monkeypatch.setattr( - "src.modal.simulation._output_module_function", fake_output_module_function + "src.modal.simulation_macro_output_builder._output_module_function", + fake_output_module_function, ) assert ( From dcf865a8e7c144a4d2014027579f9bd91b284315 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 22 May 2026 15:36:03 +0200 Subject: [PATCH 13/23] refactor: fold output adapter into builder --- ...r.py => test_simulation_output_builder.py} | 0 .../src/modal/simulation.py | 6 +- .../modal/simulation_macro_output_builder.py | 359 ---------- .../src/modal/simulation_output_adapter.py | 544 --------------- .../src/modal/simulation_output_builder.py | 658 ++++++++++++++++++ ...r.py => test_simulation_output_builder.py} | 264 ++++--- 6 files changed, 836 insertions(+), 995 deletions(-) rename projects/policyengine-api-simulation/fixtures/{test_simulation_output_adapter.py => test_simulation_output_builder.py} (100%) delete mode 100644 projects/policyengine-api-simulation/src/modal/simulation_macro_output_builder.py delete mode 100644 projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py create mode 100644 projects/policyengine-api-simulation/src/modal/simulation_output_builder.py rename projects/policyengine-api-simulation/tests/{test_simulation_output_adapter.py => test_simulation_output_builder.py} (56%) diff --git a/projects/policyengine-api-simulation/fixtures/test_simulation_output_adapter.py b/projects/policyengine-api-simulation/fixtures/test_simulation_output_builder.py 
similarity index 100% rename from projects/policyengine-api-simulation/fixtures/test_simulation_output_adapter.py rename to projects/policyengine-api-simulation/fixtures/test_simulation_output_builder.py diff --git a/projects/policyengine-api-simulation/src/modal/simulation.py b/projects/policyengine-api-simulation/src/modal/simulation.py index 89307b909..c0527cda6 100644 --- a/projects/policyengine-api-simulation/src/modal/simulation.py +++ b/projects/policyengine-api-simulation/src/modal/simulation.py @@ -15,7 +15,7 @@ from typing import Any, Iterator from src.modal.release_bundle import resolve_bundle_dataset_name -from src.modal.simulation_macro_output_builder import SimulationMacroOutputBuilder +from src.modal.simulation_output_builder import SimulationOutputBuilder from src.modal.telemetry import split_internal_payload logger = logging.getLogger(__name__) @@ -291,15 +291,13 @@ def _run_simulation_impl_core(params: dict) -> dict: ) logger.info("Calculating economic impact") - analysis = country_module.economic_impact_analysis(baseline, reform) - output = SimulationMacroOutputBuilder( + output = SimulationOutputBuilder( country=country, simulation_params=simulation_params, country_module=country_module, dataset=dataset, baseline=baseline, reform=reform, - analysis=analysis, ).serialize() logger.info("Comparison complete") return output diff --git a/projects/policyengine-api-simulation/src/modal/simulation_macro_output_builder.py b/projects/policyengine-api-simulation/src/modal/simulation_macro_output_builder.py deleted file mode 100644 index 5f0ed34f2..000000000 --- a/projects/policyengine-api-simulation/src/modal/simulation_macro_output_builder.py +++ /dev/null @@ -1,359 +0,0 @@ -"""Build and serialize the runtime simulation macro output.""" - -from __future__ import annotations - -import logging -from dataclasses import dataclass -from importlib import import_module -from typing import Any - -from src.modal.release_bundle import get_country_release_bundle 
-from src.modal.simulation_macro_output import ( - BudgetaryImpact, - DecileOutput, - DetailedBudgetOutput, - GeographicImpactOutput, - InequalityOutput, - IntraDecileOutput, - LaborSupplyResponseOutput, - PovertyModuleOutputs, - SingleYearMacroOutput, -) -from src.modal.simulation_output_adapter import ( - build_decile_output, - build_detailed_budget_output, - build_geographic_impact_output, - build_inequality_output, - build_intra_decile_output, - build_labor_supply_response_output, - build_poverty_by_gender_output, - build_poverty_by_race_output, - build_poverty_output, -) - -logger = logging.getLogger(__name__) - - -def _entity_data(simulation, entity: str): - if simulation.output_dataset is None or simulation.output_dataset.data is None: - simulation.ensure() - return getattr(simulation.output_dataset.data, entity) - - -def _sum_output_variable(simulation, variable: str, entity: str) -> float: - data = _entity_data(simulation, entity) - if variable in data.columns: - return float(data[variable].sum()) - - from policyengine.outputs import Aggregate, AggregateType - - output = Aggregate( - simulation=simulation, - variable=variable, - entity=entity, - aggregate_type=AggregateType.SUM, - ) - output.run() - return float(output.result) - - -def _try_sum_output_variable(simulation, variable: str, entity: str) -> float: - try: - return _sum_output_variable(simulation, variable, entity) - except Exception: - logger.warning("Unable to calculate sum for %s", variable, exc_info=True) - return 0.0 - - -def _change_output_variable(baseline, reform, variable: str, entity: str) -> float: - baseline_data = _entity_data(baseline, entity) - reform_data = _entity_data(reform, entity) - if variable in baseline_data.columns and variable in reform_data.columns: - return float((reform_data[variable] - baseline_data[variable]).sum()) - - from policyengine.outputs import ChangeAggregate, ChangeAggregateType - - output = ChangeAggregate( - baseline_simulation=baseline, - 
reform_simulation=reform, - variable=variable, - entity=entity, - aggregate_type=ChangeAggregateType.SUM, - ) - output.run() - return float(output.result) - - -def _try_change_output_variable(baseline, reform, variable: str, entity: str) -> float: - try: - return _change_output_variable(baseline, reform, variable, entity) - except Exception: - logger.warning("Unable to calculate change for %s", variable, exc_info=True) - return 0.0 - - -def _output_module_function(module_name: str, name: str): - module = import_module(f"policyengine.outputs.{module_name}") - return getattr(module, name) - - -def _poverty_module_function(name: str): - return _output_module_function("poverty", name) - - -def _try_compute_output(label: str, fn, *args, **kwargs): - try: - return fn(*args, **kwargs) - except Exception: - logger.warning("Unable to calculate %s", label, exc_info=True) - return None - - -@dataclass -class SimulationMacroOutputBuilder: - country: str - simulation_params: dict[str, Any] - country_module: Any - dataset: Any - baseline: Any - reform: Any - analysis: Any - - def __post_init__(self) -> None: - self.country = self.country.lower() - - def build(self) -> SingleYearMacroOutput: - poverty_outputs = self._build_poverty_outputs() - wealth_decile = getattr(self.analysis, "wealth_decile_impacts", None) - intra_wealth_decile = getattr( - self.analysis, "intra_wealth_decile_impacts", None - ) - - return SingleYearMacroOutput( - model_version=self._model_version(), - data_version=self._data_version(), - budget=self._build_budgetary_impact(), - detailed_budget=self._build_detailed_budget(), - decile=self._build_decile(), - inequality=self._build_inequality(), - poverty=poverty_outputs.poverty, - poverty_by_gender=poverty_outputs.poverty_by_gender, - poverty_by_race=poverty_outputs.poverty_by_race, - intra_decile=self._build_intra_decile_output(), - wealth_decile=self._build_wealth_decile(wealth_decile), - 
intra_wealth_decile=self._build_intra_wealth_decile(intra_wealth_decile), - labor_supply_response=self._build_labor_supply_response(), - congressional_district_impact=(self._build_congressional_district_impact()), - constituency_impact=self._build_uk_constituency_impact(), - local_authority_impact=self._build_uk_local_authority_impact(), - cliff_impact=None, - ) - - def serialize(self) -> dict[str, Any]: - return self.build().model_dump(mode="json") - - def _build_detailed_budget(self) -> DetailedBudgetOutput: - return build_detailed_budget_output( - getattr(self.analysis, "program_statistics", None) - ) - - def _build_decile(self) -> DecileOutput: - return build_decile_output(getattr(self.analysis, "decile_impacts", None)) - - def _build_inequality(self) -> InequalityOutput: - return build_inequality_output( - getattr(self.analysis, "baseline_inequality", None), - getattr(self.analysis, "reform_inequality", None), - ) - - def _build_budgetary_impact(self) -> BudgetaryImpact: - tax_revenue_impact = _try_change_output_variable( - self.baseline, self.reform, "household_tax", entity="household" - ) - benefit_spending_impact = _try_change_output_variable( - self.baseline, self.reform, "household_benefits", entity="household" - ) - state_tax_revenue_impact = ( - _try_change_output_variable( - self.baseline, - self.reform, - "household_state_income_tax", - entity="household", - ) - if self.country == "us" - else 0.0 - ) - - return BudgetaryImpact( - tax_revenue_impact=tax_revenue_impact, - state_tax_revenue_impact=state_tax_revenue_impact, - benefit_spending_impact=benefit_spending_impact, - budgetary_impact=tax_revenue_impact - benefit_spending_impact, - households=_try_sum_output_variable( - self.baseline, "household_weight", entity="household" - ), - baseline_net_income=_try_sum_output_variable( - self.baseline, "household_net_income", entity="household" - ), - ) - - def _build_poverty_outputs(self) -> PovertyModuleOutputs: - prefix = "us" if self.country == "us" else 
"uk" - baseline_poverty_by_age = _try_compute_output( - "baseline poverty by age", - _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), - self.baseline, - ) - reform_poverty_by_age = _try_compute_output( - "reform poverty by age", - _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), - self.reform, - ) - baseline_poverty_by_gender = _try_compute_output( - "baseline poverty by gender", - _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), - self.baseline, - ) - reform_poverty_by_gender = _try_compute_output( - "reform poverty by gender", - _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), - self.reform, - ) - baseline_poverty_by_race = None - reform_poverty_by_race = None - if self.country == "us": - baseline_poverty_by_race = _try_compute_output( - "baseline poverty by race", - _poverty_module_function("calculate_us_poverty_by_race"), - self.baseline, - ) - reform_poverty_by_race = _try_compute_output( - "reform poverty by race", - _poverty_module_function("calculate_us_poverty_by_race"), - self.reform, - ) - return PovertyModuleOutputs( - poverty=build_poverty_output( - self.country, - baseline=getattr(self.analysis, "baseline_poverty", None), - reform=getattr(self.analysis, "reform_poverty", None), - baseline_by_age=baseline_poverty_by_age, - reform_by_age=reform_poverty_by_age, - ), - poverty_by_gender=build_poverty_by_gender_output( - self.country, - baseline_by_gender=baseline_poverty_by_gender, - reform_by_gender=reform_poverty_by_gender, - ), - poverty_by_race=( - build_poverty_by_race_output( - baseline_by_race=baseline_poverty_by_race, - reform_by_race=reform_poverty_by_race, - ) - if self.country == "us" - else None - ), - ) - - def _build_intra_decile_output(self) -> IntraDecileOutput: - from policyengine.outputs.intra_decile_impact import ( - compute_intra_decile_impacts, - ) - - collection = _try_compute_output( - "intra-decile impacts", - compute_intra_decile_impacts, - self.baseline, - 
self.reform, - income_variable="household_net_income", - entity="household", - ) - return build_intra_decile_output(collection) - - def _build_wealth_decile(self, wealth_decile) -> DecileOutput | None: - if self.country != "uk": - return None - return build_decile_output(wealth_decile) - - def _build_intra_wealth_decile( - self, intra_wealth_decile - ) -> IntraDecileOutput | None: - if self.country != "uk": - return None - return build_intra_decile_output(intra_wealth_decile) - - def _build_labor_supply_response(self) -> LaborSupplyResponseOutput | None: - return build_labor_supply_response_output(self.analysis) - - def _build_congressional_district_impact( - self, - ) -> GeographicImpactOutput | None: - if self.country != "us": - return None - - from policyengine.outputs.congressional_district_impact import ( - compute_us_congressional_district_impacts, - ) - - impact = _try_compute_output( - "congressional district impacts", - compute_us_congressional_district_impacts, - self.baseline, - self.reform, - ) - return build_geographic_impact_output( - getattr(impact, "district_results", None) if impact is not None else None - ) - - def _build_uk_constituency_impact(self) -> GeographicImpactOutput | None: - if self.country != "uk": - return None - - impact = _try_compute_output( - "constituency impacts", - _output_module_function( - "constituency_impact", "compute_uk_constituency_impacts" - ), - self.baseline, - self.reform, - ) - if impact is None: - return None - return build_geographic_impact_output( - getattr(impact, "constituency_results", None) - ) - - def _build_uk_local_authority_impact(self) -> GeographicImpactOutput | None: - if self.country != "uk": - return None - - impact = _try_compute_output( - "local authority impacts", - _output_module_function( - "local_authority_impact", "compute_uk_local_authority_impacts" - ), - self.baseline, - self.reform, - ) - if impact is None: - return None - return build_geographic_impact_output( - getattr(impact, 
"local_authority_results", None) - ) - - def _model_version(self) -> str: - return str(getattr(self.country_module.model, "version", "")) - - def _data_version(self) -> str: - if self.simulation_params.get("data_version"): - return str(self.simulation_params["data_version"]) - try: - return get_country_release_bundle(self.country).data_version - except ValueError: - pass - metadata = getattr(self.dataset, "metadata", {}) or {} - for key in ("data_version", "version"): - value = metadata.get(key) - if value is not None: - return str(value) - return "" diff --git a/projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py b/projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py deleted file mode 100644 index 7f1ad1625..000000000 --- a/projects/policyengine-api-simulation/src/modal/simulation_output_adapter.py +++ /dev/null @@ -1,544 +0,0 @@ -"""Adapt PolicyEngine v4 macro outputs to the existing simulation API shape.""" - -from __future__ import annotations - -import math -from collections.abc import Iterable, Mapping -from dataclasses import dataclass -from typing import Any - -from src.modal.simulation_macro_output import ( - AgePovertyOutput, - BaselineReformValue, - BudgetaryImpact, - BudgetaryOutput, - DecileOutput, - DetailedBudgetOutput, - DetailedBudgetProgramOutput, - GeographicImpactOutput, - GenderPovertyOutput, - InequalityOutput, - IntraDecileOutput, - LaborSupplyResponseOutput, - PovertyByGenderOutput, - PovertyByRaceOutput, - PovertyOutput, - RacePovertyOutput, - SingleYearMacroOutput, -) - -INTRA_DECILE_COLUMNS = { - "Lose more than 5%": "lose_more_than_5pct", - "Lose less than 5%": "lose_less_than_5pct", - "No change": "no_change", - "Gain less than 5%": "gain_less_than_5pct", - "Gain more than 5%": "gain_more_than_5pct", -} - -US_POVERTY_TYPES = { - "spm": "poverty", - "spm_deep": "deep_poverty", -} - -UK_POVERTY_TYPES = { - "relative_bhc": "poverty", - "absolute_bhc": "deep_poverty", -} - - -def _number(value: 
Any, default: float = 0.0) -> float: - if value is None: - return default - try: - result = float(value) - except (TypeError, ValueError): - return default - if math.isnan(result) or math.isinf(result): - return default - return result - - -def _collection_records(collection: Any) -> list[dict[str, Any]]: - if collection is None: - return [] - dataframe = getattr(collection, "dataframe", None) - if dataframe is not None: - return list(dataframe.to_dict("records")) - if isinstance(collection, list): - return [dict(item) for item in collection if isinstance(item, Mapping)] - return [] - - -def _output_model_dump(value: Any) -> Any: - if value is None: - return None - if hasattr(value, "model_dump"): - return value.model_dump(mode="json") - if isinstance(value, Mapping): - return dict(value) - return None - - -def build_geographic_impact_output(value: Any) -> GeographicImpactOutput | None: - if isinstance(value, GeographicImpactOutput): - return value - records = _output_model_dump(value) - if isinstance(records, list): - return GeographicImpactOutput( - [dict(item) for item in records if isinstance(item, Mapping)] - ) - if isinstance(value, list): - return GeographicImpactOutput( - [dict(item) for item in value if isinstance(item, Mapping)] - ) - return None - - -def build_budgetary_output( - budget: Mapping[str, Any] | BudgetaryImpact, -) -> BudgetaryOutput: - if isinstance(budget, BudgetaryImpact): - return budget - return BudgetaryImpact( - tax_revenue_impact=_number(budget.get("tax_revenue_impact")), - state_tax_revenue_impact=_number(budget.get("state_tax_revenue_impact")), - benefit_spending_impact=_number(budget.get("benefit_spending_impact")), - budgetary_impact=_number(budget.get("budgetary_impact")), - households=_number(budget.get("households")), - baseline_net_income=_number(budget.get("baseline_net_income")), - ) - - -def build_detailed_budget_output( - collection: Any, -) -> DetailedBudgetOutput: - if isinstance(collection, DetailedBudgetOutput): - 
return collection - detailed_budget: dict[str, DetailedBudgetProgramOutput] = {} - for row in _collection_records(collection): - program_name = row.get("program_name") - if not program_name: - continue - baseline = _number(row.get("baseline_total")) - reform = _number(row.get("reform_total")) - detailed_budget[str(program_name)] = DetailedBudgetProgramOutput( - baseline=baseline, - reform=reform, - difference=_number(row.get("change"), reform - baseline), - ) - return DetailedBudgetOutput(detailed_budget) - - -def build_decile_output(collection: Any) -> DecileOutput: - if isinstance(collection, DecileOutput): - return collection - average: dict[str, float] = {} - relative: dict[str, float] = {} - for row in sorted( - _collection_records(collection), - key=lambda item: _number(item.get("decile")), - ): - decile = int(_number(row.get("decile"))) - if decile <= 0: - continue - key = str(decile) - average[key] = _number(row.get("absolute_change")) - relative[key] = _number(row.get("relative_change")) - return DecileOutput(average=average, relative=relative) - - -def build_intra_decile_output(collection: Any) -> IntraDecileOutput: - if isinstance(collection, IntraDecileOutput): - return collection - deciles: dict[str, list[float]] = {label: [] for label in INTRA_DECILE_COLUMNS} - all_values: dict[str, float] = {label: 0.0 for label in INTRA_DECILE_COLUMNS} - rows = [ - row - for row in sorted( - _collection_records(collection), - key=lambda item: _number(item.get("decile")), - ) - if int(_number(row.get("decile"))) > 0 - ] - - for label, column in INTRA_DECILE_COLUMNS.items(): - values = [_number(row.get(column)) for row in rows] - deciles[label] = values - all_values[label] = sum(values) / len(values) if values else 0.0 - return IntraDecileOutput(deciles=deciles, all=all_values) - - -def _empty_baseline_reform_value() -> dict[str, float]: - return {"baseline": 0.0, "reform": 0.0} - - -def _empty_age_poverty() -> dict[str, dict[str, float]]: - return { - "child": 
_empty_baseline_reform_value(), - "adult": _empty_baseline_reform_value(), - "senior": _empty_baseline_reform_value(), - "all": _empty_baseline_reform_value(), - } - - -def _empty_gender_poverty() -> dict[str, dict[str, float]]: - return { - "male": _empty_baseline_reform_value(), - "female": _empty_baseline_reform_value(), - } - - -def _poverty_type(country: str, row: Mapping[str, Any]) -> str | None: - poverty_type = str(row.get("poverty_type") or "").lower() - if country == "us": - return US_POVERTY_TYPES.get(poverty_type) - return UK_POVERTY_TYPES.get(poverty_type) - - -def _fill_poverty_block( - *, - country: str, - output: dict[str, dict[str, dict[str, float]]], - baseline_records: Iterable[Mapping[str, Any]], - reform_records: Iterable[Mapping[str, Any]], - default_group: str, -) -> None: - for side, records in (("baseline", baseline_records), ("reform", reform_records)): - for row in records: - poverty_type = _poverty_type(country, row) - if poverty_type is None: - continue - if poverty_type not in output: - continue - group = str(row.get("filter_group") or default_group).lower() - if group not in output[poverty_type]: - continue - output[poverty_type][group][side] = _number(row.get("rate")) - - -def _age_poverty_output(values: dict[str, dict[str, float]]) -> AgePovertyOutput: - return AgePovertyOutput( - child=BaselineReformValue(**values["child"]), - adult=BaselineReformValue(**values["adult"]), - senior=BaselineReformValue(**values["senior"]), - all=BaselineReformValue(**values["all"]), - ) - - -def _gender_poverty_output( - values: dict[str, dict[str, float]], -) -> GenderPovertyOutput: - return GenderPovertyOutput( - male=BaselineReformValue(**values["male"]), - female=BaselineReformValue(**values["female"]), - ) - - -def _race_poverty_output(values: dict[str, dict[str, float]]) -> RacePovertyOutput: - return RacePovertyOutput( - white=BaselineReformValue(**values["white"]), - black=BaselineReformValue(**values["black"]), - 
hispanic=BaselineReformValue(**values["hispanic"]), - other=BaselineReformValue(**values["other"]), - ) - - -def build_poverty_output( - country: str, - *, - baseline: Any, - reform: Any, - baseline_by_age: Any, - reform_by_age: Any, -) -> PovertyOutput: - if isinstance(baseline, PovertyOutput): - return baseline - result = {"poverty": _empty_age_poverty(), "deep_poverty": _empty_age_poverty()} - _fill_poverty_block( - country=country, - output=result, - baseline_records=_collection_records(baseline), - reform_records=_collection_records(reform), - default_group="all", - ) - _fill_poverty_block( - country=country, - output=result, - baseline_records=_collection_records(baseline_by_age), - reform_records=_collection_records(reform_by_age), - default_group="all", - ) - return PovertyOutput( - poverty=_age_poverty_output(result["poverty"]), - deep_poverty=_age_poverty_output(result["deep_poverty"]), - ) - - -def build_poverty_by_gender_output( - country: str, - *, - baseline_by_gender: Any, - reform_by_gender: Any, -) -> PovertyByGenderOutput: - if isinstance(baseline_by_gender, PovertyByGenderOutput): - return baseline_by_gender - result = { - "poverty": _empty_gender_poverty(), - "deep_poverty": _empty_gender_poverty(), - } - _fill_poverty_block( - country=country, - output=result, - baseline_records=_collection_records(baseline_by_gender), - reform_records=_collection_records(reform_by_gender), - default_group="all", - ) - return PovertyByGenderOutput( - poverty=_gender_poverty_output(result["poverty"]), - deep_poverty=_gender_poverty_output(result["deep_poverty"]), - ) - - -def build_poverty_by_race_output( - *, - baseline_by_race: Any, - reform_by_race: Any, -) -> PovertyByRaceOutput: - if isinstance(baseline_by_race, PovertyByRaceOutput): - return baseline_by_race - result = { - "poverty": { - "white": _empty_baseline_reform_value(), - "black": _empty_baseline_reform_value(), - "hispanic": _empty_baseline_reform_value(), - "other": 
_empty_baseline_reform_value(), - } - } - _fill_poverty_block( - country="us", - output=result, - baseline_records=_collection_records(baseline_by_race), - reform_records=_collection_records(reform_by_race), - default_group="all", - ) - return PovertyByRaceOutput(poverty=_race_poverty_output(result["poverty"])) - - -def build_inequality_output(baseline: Any, reform: Any) -> InequalityOutput: - if isinstance(baseline, InequalityOutput): - return baseline - return InequalityOutput( - gini=BaselineReformValue( - baseline=_number(getattr(baseline, "gini", None)), - reform=_number(getattr(reform, "gini", None)), - ), - top_10_pct_share=BaselineReformValue( - baseline=_number(getattr(baseline, "top_10_share", None)), - reform=_number(getattr(reform, "top_10_share", None)), - ), - top_1_pct_share=BaselineReformValue( - baseline=_number(getattr(baseline, "top_1_share", None)), - reform=_number(getattr(reform, "top_1_share", None)), - ), - ) - - -def build_labor_supply_response_output( - analysis: Any, -) -> LaborSupplyResponseOutput | None: - if isinstance(analysis, LaborSupplyResponseOutput): - return analysis - output = _output_model_dump(getattr(analysis, "labor_supply_response", None)) - return LaborSupplyResponseOutput(output) if isinstance(output, dict) else None - - -@dataclass -class SingleYearMacroOutputBuilder: - country: str - model_version: str - data_version: str - budget: Mapping[str, Any] | BudgetaryImpact - analysis: Any - baseline_poverty_by_age: Any = None - reform_poverty_by_age: Any = None - baseline_poverty_by_gender: Any = None - reform_poverty_by_gender: Any = None - baseline_poverty_by_race: Any = None - reform_poverty_by_race: Any = None - intra_decile: Any = None - congressional_district_impact: Any = None - constituency_impact: Any = None - local_authority_impact: Any = None - - def __post_init__(self) -> None: - self.country = self.country.lower() - - def build(self) -> SingleYearMacroOutput: - return SingleYearMacroOutput( - 
model_version=self.model_version, - data_version=self.data_version, - budget=self._build_budgetary_impact(), - detailed_budget=self._build_detailed_budget(), - decile=self._build_decile(), - inequality=self._build_inequality(), - poverty=self._build_poverty(), - poverty_by_gender=self._build_poverty_by_gender(), - poverty_by_race=self._build_poverty_by_race(), - intra_decile=self._build_intra_decile(), - wealth_decile=self._build_wealth_decile(), - intra_wealth_decile=self._build_intra_wealth_decile(), - labor_supply_response=self._build_labor_supply_response(), - constituency_impact=self._build_constituency_impact(), - local_authority_impact=self._build_local_authority_impact(), - congressional_district_impact=self._build_congressional_district_impact(), - cliff_impact=None, - ) - - def serialize(self) -> dict[str, Any]: - return self.build().model_dump(mode="json") - - def _build_budgetary_impact(self) -> BudgetaryImpact: - return build_budgetary_output(self.budget) - - def _build_detailed_budget(self) -> DetailedBudgetOutput: - return build_detailed_budget_output( - getattr(self.analysis, "program_statistics", None) - ) - - def _build_decile(self) -> DecileOutput: - return build_decile_output(getattr(self.analysis, "decile_impacts", None)) - - def _build_inequality(self) -> InequalityOutput: - return build_inequality_output( - getattr(self.analysis, "baseline_inequality", None), - getattr(self.analysis, "reform_inequality", None), - ) - - def _build_poverty(self) -> PovertyOutput: - return build_poverty_output( - self.country, - baseline=getattr(self.analysis, "baseline_poverty", None), - reform=getattr(self.analysis, "reform_poverty", None), - baseline_by_age=self.baseline_poverty_by_age, - reform_by_age=self.reform_poverty_by_age, - ) - - def _build_poverty_by_gender(self) -> PovertyByGenderOutput: - return build_poverty_by_gender_output( - self.country, - baseline_by_gender=self.baseline_poverty_by_gender, - reform_by_gender=self.reform_poverty_by_gender, - ) 
- - def _build_poverty_by_race(self) -> PovertyByRaceOutput | None: - if self.country != "us": - return None - return build_poverty_by_race_output( - baseline_by_race=self.baseline_poverty_by_race, - reform_by_race=self.reform_poverty_by_race, - ) - - def _build_intra_decile(self) -> IntraDecileOutput: - return build_intra_decile_output(self.intra_decile) - - def _build_wealth_decile(self) -> DecileOutput | None: - if self.country != "uk": - return None - return build_decile_output( - getattr(self.analysis, "wealth_decile_impacts", None) - ) - - def _build_intra_wealth_decile(self) -> IntraDecileOutput | None: - if self.country != "uk": - return None - return build_intra_decile_output( - getattr(self.analysis, "intra_wealth_decile_impacts", None) - ) - - def _build_labor_supply_response(self) -> LaborSupplyResponseOutput | None: - return build_labor_supply_response_output(self.analysis) - - def _build_constituency_impact(self) -> GeographicImpactOutput | None: - return build_geographic_impact_output(self.constituency_impact) - - def _build_local_authority_impact(self) -> GeographicImpactOutput | None: - return build_geographic_impact_output(self.local_authority_impact) - - def _build_congressional_district_impact(self) -> GeographicImpactOutput | None: - return build_geographic_impact_output(self.congressional_district_impact) - - -def build_single_year_macro_output( - *, - country: str, - model_version: str, - data_version: str, - budget: Mapping[str, Any] | BudgetaryImpact, - analysis: Any, - baseline_poverty_by_age: Any = None, - reform_poverty_by_age: Any = None, - baseline_poverty_by_gender: Any = None, - reform_poverty_by_gender: Any = None, - baseline_poverty_by_race: Any = None, - reform_poverty_by_race: Any = None, - intra_decile: Any = None, - congressional_district_impact: Any = None, - constituency_impact: Any = None, - local_authority_impact: Any = None, -) -> SingleYearMacroOutput: - """Build the schema-first single-year macro output.""" - return 
SingleYearMacroOutputBuilder( - country=country, - model_version=model_version, - data_version=data_version, - budget=budget, - analysis=analysis, - baseline_poverty_by_age=baseline_poverty_by_age, - reform_poverty_by_age=reform_poverty_by_age, - baseline_poverty_by_gender=baseline_poverty_by_gender, - reform_poverty_by_gender=reform_poverty_by_gender, - baseline_poverty_by_race=baseline_poverty_by_race, - reform_poverty_by_race=reform_poverty_by_race, - intra_decile=intra_decile, - congressional_district_impact=congressional_district_impact, - constituency_impact=constituency_impact, - local_authority_impact=local_authority_impact, - ).build() - - -def adapt_analysis_to_legacy_macro_output( - *, - country: str, - model_version: str, - data_version: str, - budget: Mapping[str, Any] | BudgetaryImpact, - analysis: Any, - baseline_poverty_by_age: Any = None, - reform_poverty_by_age: Any = None, - baseline_poverty_by_gender: Any = None, - reform_poverty_by_gender: Any = None, - baseline_poverty_by_race: Any = None, - reform_poverty_by_race: Any = None, - intra_decile: Any = None, - congressional_district_impact: Any = None, - constituency_impact: Any = None, - local_authority_impact: Any = None, -) -> dict[str, Any]: - """Return the legacy single-year macro result expected by API callers.""" - return SingleYearMacroOutputBuilder( - country=country, - model_version=model_version, - data_version=data_version, - budget=budget, - analysis=analysis, - baseline_poverty_by_age=baseline_poverty_by_age, - reform_poverty_by_age=reform_poverty_by_age, - baseline_poverty_by_gender=baseline_poverty_by_gender, - reform_poverty_by_gender=reform_poverty_by_gender, - baseline_poverty_by_race=baseline_poverty_by_race, - reform_poverty_by_race=reform_poverty_by_race, - intra_decile=intra_decile, - congressional_district_impact=congressional_district_impact, - constituency_impact=constituency_impact, - local_authority_impact=local_authority_impact, - ).serialize() diff --git 
a/projects/policyengine-api-simulation/src/modal/simulation_output_builder.py b/projects/policyengine-api-simulation/src/modal/simulation_output_builder.py new file mode 100644 index 000000000..e7ef45a0e --- /dev/null +++ b/projects/policyengine-api-simulation/src/modal/simulation_output_builder.py @@ -0,0 +1,658 @@ +"""Build and serialize the runtime simulation macro output.""" + +from __future__ import annotations + +import logging +import math +from collections.abc import Iterable, Mapping +from dataclasses import dataclass, field +from importlib import import_module +from typing import Any + +from src.modal.release_bundle import get_country_release_bundle +from src.modal.simulation_macro_output import ( + AgePovertyOutput, + BaselineReformValue, + BudgetaryImpact, + DecileOutput, + DetailedBudgetOutput, + DetailedBudgetProgramOutput, + GeographicImpactOutput, + GenderPovertyOutput, + InequalityOutput, + IntraDecileOutput, + LaborSupplyResponseOutput, + PovertyModuleOutputs, + PovertyByGenderOutput, + PovertyByRaceOutput, + PovertyOutput, + RacePovertyOutput, + SingleYearMacroOutput, +) + +logger = logging.getLogger(__name__) + +INTRA_DECILE_COLUMNS = { + "Lose more than 5%": "lose_more_than_5pct", + "Lose less than 5%": "lose_less_than_5pct", + "No change": "no_change", + "Gain less than 5%": "gain_less_than_5pct", + "Gain more than 5%": "gain_more_than_5pct", +} + +US_POVERTY_TYPES = { + "spm": "poverty", + "spm_deep": "deep_poverty", +} + +UK_POVERTY_TYPES = { + "relative_bhc": "poverty", + "absolute_bhc": "deep_poverty", +} + + +def _number(value: Any, default: float = 0.0) -> float: + if value is None: + return default + try: + result = float(value) + except (TypeError, ValueError): + return default + if math.isnan(result) or math.isinf(result): + return default + return result + + +def _collection_records(collection: Any) -> list[dict[str, Any]]: + if collection is None: + return [] + dataframe = getattr(collection, "dataframe", None) + if dataframe is not 
None: + return list(dataframe.to_dict("records")) + if isinstance(collection, list): + return [dict(item) for item in collection if isinstance(item, Mapping)] + return [] + + +def _output_model_dump(value: Any) -> Any: + if value is None: + return None + if hasattr(value, "model_dump"): + return value.model_dump(mode="json") + if isinstance(value, Mapping): + return dict(value) + return None + + +def _empty_baseline_reform_value() -> dict[str, float]: + return {"baseline": 0.0, "reform": 0.0} + + +def _empty_age_poverty() -> dict[str, dict[str, float]]: + return { + "child": _empty_baseline_reform_value(), + "adult": _empty_baseline_reform_value(), + "senior": _empty_baseline_reform_value(), + "all": _empty_baseline_reform_value(), + } + + +def _empty_gender_poverty() -> dict[str, dict[str, float]]: + return { + "male": _empty_baseline_reform_value(), + "female": _empty_baseline_reform_value(), + } + + +def _poverty_type(country: str, row: Mapping[str, Any]) -> str | None: + poverty_type = str(row.get("poverty_type") or "").lower() + if country == "us": + return US_POVERTY_TYPES.get(poverty_type) + return UK_POVERTY_TYPES.get(poverty_type) + + +def _fill_poverty_block( + *, + country: str, + output: dict[str, dict[str, dict[str, float]]], + baseline_records: Iterable[Mapping[str, Any]], + reform_records: Iterable[Mapping[str, Any]], + default_group: str, +) -> None: + for side, records in (("baseline", baseline_records), ("reform", reform_records)): + for row in records: + poverty_type = _poverty_type(country, row) + if poverty_type is None: + continue + if poverty_type not in output: + continue + group = str(row.get("filter_group") or default_group).lower() + if group not in output[poverty_type]: + continue + output[poverty_type][group][side] = _number(row.get("rate")) + + +def _age_poverty_output(values: dict[str, dict[str, float]]) -> AgePovertyOutput: + return AgePovertyOutput( + child=BaselineReformValue(**values["child"]), + 
adult=BaselineReformValue(**values["adult"]), + senior=BaselineReformValue(**values["senior"]), + all=BaselineReformValue(**values["all"]), + ) + + +def _gender_poverty_output( + values: dict[str, dict[str, float]], +) -> GenderPovertyOutput: + return GenderPovertyOutput( + male=BaselineReformValue(**values["male"]), + female=BaselineReformValue(**values["female"]), + ) + + +def _race_poverty_output(values: dict[str, dict[str, float]]) -> RacePovertyOutput: + return RacePovertyOutput( + white=BaselineReformValue(**values["white"]), + black=BaselineReformValue(**values["black"]), + hispanic=BaselineReformValue(**values["hispanic"]), + other=BaselineReformValue(**values["other"]), + ) + + +def _entity_data(simulation, entity: str): + if simulation.output_dataset is None or simulation.output_dataset.data is None: + simulation.ensure() + return getattr(simulation.output_dataset.data, entity) + + +def _sum_output_variable(simulation, variable: str, entity: str) -> float: + data = _entity_data(simulation, entity) + if variable in data.columns: + return float(data[variable].sum()) + + from policyengine.outputs import Aggregate, AggregateType + + output = Aggregate( + simulation=simulation, + variable=variable, + entity=entity, + aggregate_type=AggregateType.SUM, + ) + output.run() + return float(output.result) + + +def _try_sum_output_variable(simulation, variable: str, entity: str) -> float: + try: + return _sum_output_variable(simulation, variable, entity) + except Exception: + logger.warning("Unable to calculate sum for %s", variable, exc_info=True) + return 0.0 + + +def _change_output_variable(baseline, reform, variable: str, entity: str) -> float: + baseline_data = _entity_data(baseline, entity) + reform_data = _entity_data(reform, entity) + if variable in baseline_data.columns and variable in reform_data.columns: + return float((reform_data[variable] - baseline_data[variable]).sum()) + + from policyengine.outputs import ChangeAggregate, ChangeAggregateType + + 
output = ChangeAggregate( + baseline_simulation=baseline, + reform_simulation=reform, + variable=variable, + entity=entity, + aggregate_type=ChangeAggregateType.SUM, + ) + output.run() + return float(output.result) + + +def _try_change_output_variable(baseline, reform, variable: str, entity: str) -> float: + try: + return _change_output_variable(baseline, reform, variable, entity) + except Exception: + logger.warning("Unable to calculate change for %s", variable, exc_info=True) + return 0.0 + + +def _output_module_function(module_name: str, name: str): + module = import_module(f"policyengine.outputs.{module_name}") + return getattr(module, name) + + +def _poverty_module_function(name: str): + return _output_module_function("poverty", name) + + +def _try_compute_output(label: str, fn, *args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception: + logger.warning("Unable to calculate %s", label, exc_info=True) + return None + + +@dataclass +class SimulationOutputBuilder: + country: str + simulation_params: dict[str, Any] + country_module: Any + dataset: Any + baseline: Any + reform: Any + _analysis: Any = field(default=None, init=False) + + def __post_init__(self) -> None: + self.country = self.country.lower() + + @property + def analysis(self) -> Any: + if self._analysis is None: + self._analysis = self.country_module.economic_impact_analysis( + self.baseline, self.reform + ) + return self._analysis + + def build(self) -> SingleYearMacroOutput: + poverty_outputs = self._build_poverty_outputs() + wealth_decile = getattr(self.analysis, "wealth_decile_impacts", None) + intra_wealth_decile = getattr( + self.analysis, "intra_wealth_decile_impacts", None + ) + + return SingleYearMacroOutput( + model_version=self._model_version(), + data_version=self._data_version(), + budget=self._build_budgetary_impact(), + detailed_budget=self._build_detailed_budget(), + decile=self._build_decile(), + inequality=self._build_inequality(), + poverty=poverty_outputs.poverty, + 
poverty_by_gender=poverty_outputs.poverty_by_gender, + poverty_by_race=poverty_outputs.poverty_by_race, + intra_decile=self._build_intra_decile_output(), + wealth_decile=self._build_wealth_decile(wealth_decile), + intra_wealth_decile=self._build_intra_wealth_decile(intra_wealth_decile), + labor_supply_response=self._build_labor_supply_response(), + congressional_district_impact=(self._build_congressional_district_impact()), + constituency_impact=self._build_uk_constituency_impact(), + local_authority_impact=self._build_uk_local_authority_impact(), + cliff_impact=None, + ) + + def serialize(self) -> dict[str, Any]: + return self.build().model_dump(mode="json") + + def _build_detailed_budget(self) -> DetailedBudgetOutput: + collection = getattr(self.analysis, "program_statistics", None) + if isinstance(collection, DetailedBudgetOutput): + return collection + detailed_budget: dict[str, DetailedBudgetProgramOutput] = {} + for row in _collection_records(collection): + program_name = row.get("program_name") + if not program_name: + continue + baseline = _number(row.get("baseline_total")) + reform = _number(row.get("reform_total")) + detailed_budget[str(program_name)] = DetailedBudgetProgramOutput( + baseline=baseline, + reform=reform, + difference=_number(row.get("change"), reform - baseline), + ) + return DetailedBudgetOutput(detailed_budget) + + def _build_decile(self) -> DecileOutput: + return self._build_decile_output(getattr(self.analysis, "decile_impacts", None)) + + def _build_inequality(self) -> InequalityOutput: + baseline = getattr(self.analysis, "baseline_inequality", None) + reform = getattr(self.analysis, "reform_inequality", None) + if isinstance(baseline, InequalityOutput): + return baseline + return InequalityOutput( + gini=BaselineReformValue( + baseline=_number(getattr(baseline, "gini", None)), + reform=_number(getattr(reform, "gini", None)), + ), + top_10_pct_share=BaselineReformValue( + baseline=_number(getattr(baseline, "top_10_share", None)), + 
reform=_number(getattr(reform, "top_10_share", None)), + ), + top_1_pct_share=BaselineReformValue( + baseline=_number(getattr(baseline, "top_1_share", None)), + reform=_number(getattr(reform, "top_1_share", None)), + ), + ) + + def _build_budgetary_impact(self) -> BudgetaryImpact: + tax_revenue_impact = _try_change_output_variable( + self.baseline, self.reform, "household_tax", entity="household" + ) + benefit_spending_impact = _try_change_output_variable( + self.baseline, self.reform, "household_benefits", entity="household" + ) + state_tax_revenue_impact = ( + _try_change_output_variable( + self.baseline, + self.reform, + "household_state_income_tax", + entity="household", + ) + if self.country == "us" + else 0.0 + ) + + return BudgetaryImpact( + tax_revenue_impact=tax_revenue_impact, + state_tax_revenue_impact=state_tax_revenue_impact, + benefit_spending_impact=benefit_spending_impact, + budgetary_impact=tax_revenue_impact - benefit_spending_impact, + households=_try_sum_output_variable( + self.baseline, "household_weight", entity="household" + ), + baseline_net_income=_try_sum_output_variable( + self.baseline, "household_net_income", entity="household" + ), + ) + + def _build_poverty_outputs(self) -> PovertyModuleOutputs: + prefix = "us" if self.country == "us" else "uk" + baseline_poverty_by_age = _try_compute_output( + "baseline poverty by age", + _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), + self.baseline, + ) + reform_poverty_by_age = _try_compute_output( + "reform poverty by age", + _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), + self.reform, + ) + baseline_poverty_by_gender = _try_compute_output( + "baseline poverty by gender", + _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), + self.baseline, + ) + reform_poverty_by_gender = _try_compute_output( + "reform poverty by gender", + _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), + self.reform, + ) + baseline_poverty_by_race = None + 
reform_poverty_by_race = None + if self.country == "us": + baseline_poverty_by_race = _try_compute_output( + "baseline poverty by race", + _poverty_module_function("calculate_us_poverty_by_race"), + self.baseline, + ) + reform_poverty_by_race = _try_compute_output( + "reform poverty by race", + _poverty_module_function("calculate_us_poverty_by_race"), + self.reform, + ) + return PovertyModuleOutputs( + poverty=self._build_poverty_output( + baseline=getattr(self.analysis, "baseline_poverty", None), + reform=getattr(self.analysis, "reform_poverty", None), + baseline_by_age=baseline_poverty_by_age, + reform_by_age=reform_poverty_by_age, + ), + poverty_by_gender=self._build_poverty_by_gender_output( + baseline_by_gender=baseline_poverty_by_gender, + reform_by_gender=reform_poverty_by_gender, + ), + poverty_by_race=( + self._build_poverty_by_race_output( + baseline_by_race=baseline_poverty_by_race, + reform_by_race=reform_poverty_by_race, + ) + if self.country == "us" + else None + ), + ) + + def _build_intra_decile_output(self) -> IntraDecileOutput: + from policyengine.outputs.intra_decile_impact import ( + compute_intra_decile_impacts, + ) + + collection = _try_compute_output( + "intra-decile impacts", + compute_intra_decile_impacts, + self.baseline, + self.reform, + income_variable="household_net_income", + entity="household", + ) + return self._build_intra_decile_output_from_collection(collection) + + def _build_wealth_decile(self, wealth_decile) -> DecileOutput | None: + if self.country != "uk": + return None + return self._build_decile_output(wealth_decile) + + def _build_intra_wealth_decile( + self, intra_wealth_decile + ) -> IntraDecileOutput | None: + if self.country != "uk": + return None + return self._build_intra_decile_output_from_collection(intra_wealth_decile) + + def _build_labor_supply_response(self) -> LaborSupplyResponseOutput | None: + labor_supply_response = getattr(self.analysis, "labor_supply_response", None) + if isinstance(labor_supply_response, 
LaborSupplyResponseOutput): + return labor_supply_response + output = _output_model_dump(labor_supply_response) + return LaborSupplyResponseOutput(output) if isinstance(output, dict) else None + + def _build_geographic_impact_output( + self, value: Any + ) -> GeographicImpactOutput | None: + if isinstance(value, GeographicImpactOutput): + return value + records = _output_model_dump(value) + if isinstance(records, list): + return GeographicImpactOutput( + [dict(item) for item in records if isinstance(item, Mapping)] + ) + if isinstance(value, list): + return GeographicImpactOutput( + [dict(item) for item in value if isinstance(item, Mapping)] + ) + return None + + def _build_decile_output(self, collection: Any) -> DecileOutput: + if isinstance(collection, DecileOutput): + return collection + average: dict[str, float] = {} + relative: dict[str, float] = {} + for row in sorted( + _collection_records(collection), + key=lambda item: _number(item.get("decile")), + ): + decile = int(_number(row.get("decile"))) + if decile <= 0: + continue + key = str(decile) + average[key] = _number(row.get("absolute_change")) + relative[key] = _number(row.get("relative_change")) + return DecileOutput(average=average, relative=relative) + + def _build_intra_decile_output_from_collection( + self, collection: Any + ) -> IntraDecileOutput: + if isinstance(collection, IntraDecileOutput): + return collection + deciles: dict[str, list[float]] = {label: [] for label in INTRA_DECILE_COLUMNS} + all_values: dict[str, float] = {label: 0.0 for label in INTRA_DECILE_COLUMNS} + rows = [ + row + for row in sorted( + _collection_records(collection), + key=lambda item: _number(item.get("decile")), + ) + if int(_number(row.get("decile"))) > 0 + ] + + for label, column in INTRA_DECILE_COLUMNS.items(): + values = [_number(row.get(column)) for row in rows] + deciles[label] = values + all_values[label] = sum(values) / len(values) if values else 0.0 + return IntraDecileOutput(deciles=deciles, all=all_values) + 
+ def _build_poverty_output( + self, + *, + baseline: Any, + reform: Any, + baseline_by_age: Any, + reform_by_age: Any, + ) -> PovertyOutput: + if isinstance(baseline, PovertyOutput): + return baseline + result = { + "poverty": _empty_age_poverty(), + "deep_poverty": _empty_age_poverty(), + } + _fill_poverty_block( + country=self.country, + output=result, + baseline_records=_collection_records(baseline), + reform_records=_collection_records(reform), + default_group="all", + ) + _fill_poverty_block( + country=self.country, + output=result, + baseline_records=_collection_records(baseline_by_age), + reform_records=_collection_records(reform_by_age), + default_group="all", + ) + return PovertyOutput( + poverty=_age_poverty_output(result["poverty"]), + deep_poverty=_age_poverty_output(result["deep_poverty"]), + ) + + def _build_poverty_by_gender_output( + self, + *, + baseline_by_gender: Any, + reform_by_gender: Any, + ) -> PovertyByGenderOutput: + if isinstance(baseline_by_gender, PovertyByGenderOutput): + return baseline_by_gender + result = { + "poverty": _empty_gender_poverty(), + "deep_poverty": _empty_gender_poverty(), + } + _fill_poverty_block( + country=self.country, + output=result, + baseline_records=_collection_records(baseline_by_gender), + reform_records=_collection_records(reform_by_gender), + default_group="all", + ) + return PovertyByGenderOutput( + poverty=_gender_poverty_output(result["poverty"]), + deep_poverty=_gender_poverty_output(result["deep_poverty"]), + ) + + def _build_poverty_by_race_output( + self, + *, + baseline_by_race: Any, + reform_by_race: Any, + ) -> PovertyByRaceOutput: + if isinstance(baseline_by_race, PovertyByRaceOutput): + return baseline_by_race + result = { + "poverty": { + "white": _empty_baseline_reform_value(), + "black": _empty_baseline_reform_value(), + "hispanic": _empty_baseline_reform_value(), + "other": _empty_baseline_reform_value(), + } + } + _fill_poverty_block( + country="us", + output=result, + 
baseline_records=_collection_records(baseline_by_race), + reform_records=_collection_records(reform_by_race), + default_group="all", + ) + return PovertyByRaceOutput(poverty=_race_poverty_output(result["poverty"])) + + def _build_congressional_district_impact( + self, + ) -> GeographicImpactOutput | None: + if self.country != "us": + return None + + from policyengine.outputs.congressional_district_impact import ( + compute_us_congressional_district_impacts, + ) + + impact = _try_compute_output( + "congressional district impacts", + compute_us_congressional_district_impacts, + self.baseline, + self.reform, + ) + return self._build_geographic_impact_output( + getattr(impact, "district_results", None) if impact is not None else None + ) + + def _build_uk_constituency_impact(self) -> GeographicImpactOutput | None: + if self.country != "uk": + return None + + impact = _try_compute_output( + "constituency impacts", + _output_module_function( + "constituency_impact", "compute_uk_constituency_impacts" + ), + self.baseline, + self.reform, + ) + if impact is None: + return None + return self._build_geographic_impact_output( + getattr(impact, "constituency_results", None) + ) + + def _build_uk_local_authority_impact(self) -> GeographicImpactOutput | None: + if self.country != "uk": + return None + + impact = _try_compute_output( + "local authority impacts", + _output_module_function( + "local_authority_impact", "compute_uk_local_authority_impacts" + ), + self.baseline, + self.reform, + ) + if impact is None: + return None + return self._build_geographic_impact_output( + getattr(impact, "local_authority_results", None) + ) + + def _model_version(self) -> str: + return str(getattr(self.country_module.model, "version", "")) + + def _data_version(self) -> str: + if self.simulation_params.get("data_version"): + return str(self.simulation_params["data_version"]) + try: + return get_country_release_bundle(self.country).data_version + except ValueError: + pass + metadata = 
getattr(self.dataset, "metadata", {}) or {} + for key in ("data_version", "version"): + value = metadata.get(key) + if value is not None: + return str(value) + return "" diff --git a/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py b/projects/policyengine-api-simulation/tests/test_simulation_output_builder.py similarity index 56% rename from projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py rename to projects/policyengine-api-simulation/tests/test_simulation_output_builder.py index a094ed385..9f0fb3ef6 100644 --- a/projects/policyengine-api-simulation/tests/test_simulation_output_adapter.py +++ b/projects/policyengine-api-simulation/tests/test_simulation_output_builder.py @@ -1,4 +1,4 @@ -"""Tests for translating PolicyEngine v4 outputs into API-v2 macro results.""" +"""Tests for building PolicyEngine v4 outputs into API-v2 macro results.""" from __future__ import annotations @@ -7,11 +7,10 @@ import pandas as pd from fixtures.test_simulation_api_contracts import CURRENT_SINGLE_YEAR_MACRO_KEYS -from fixtures.test_simulation_output_adapter import ( +from fixtures.test_simulation_output_builder import ( BASELINE_POVERTY_BY_AGE, BASELINE_POVERTY_BY_GENDER, BASELINE_POVERTY_BY_RACE, - BUDGET, INTRA_DECILE_COLLECTION, REFORM_POVERTY_BY_AGE, REFORM_POVERTY_BY_GENDER, @@ -19,7 +18,6 @@ fake_analysis, ) from src.modal.simulation import _normalise_policy -from src.modal.simulation_macro_output_builder import SimulationMacroOutputBuilder from src.modal.simulation_macro_output import ( BudgetaryImpact, BudgetaryOutput, @@ -30,32 +28,147 @@ PovertyOutput, SingleYearMacroOutput, ) -from src.modal.simulation_output_adapter import ( - SingleYearMacroOutputBuilder, - adapt_analysis_to_legacy_macro_output, -) +from src.modal.simulation_output_builder import SimulationOutputBuilder -def _build_schema_output() -> SingleYearMacroOutput: - return SingleYearMacroOutputBuilder( - country="us", - model_version="1.700.0", - 
data_version="1.115.5", - budget=BUDGET, - analysis=fake_analysis(), - baseline_poverty_by_age=BASELINE_POVERTY_BY_AGE, - reform_poverty_by_age=REFORM_POVERTY_BY_AGE, - baseline_poverty_by_gender=BASELINE_POVERTY_BY_GENDER, - reform_poverty_by_gender=REFORM_POVERTY_BY_GENDER, - baseline_poverty_by_race=BASELINE_POVERTY_BY_RACE, - reform_poverty_by_race=REFORM_POVERTY_BY_RACE, - intra_decile=INTRA_DECILE_COLLECTION, - congressional_district_impact=[{"district_geoid": 101}], - ).build() - - -def test_builder_returns_schema_modules_before_legacy_dict_dump(): - output = _build_schema_output() +class _FakeOutputDataset: + def __init__(self, household): + self.data = SimpleNamespace(household=household) + + +class _FakeSimulation: + def __init__(self, household): + self.output_dataset = _FakeOutputDataset(household) + + def ensure(self): + raise AssertionError("test data is already materialized") + + +def _macro_baseline_reform(): + baseline = _FakeSimulation( + pd.DataFrame( + { + "household_weight": [1.0, 1.0], + "household_net_income": [400.0, 600.0], + "household_tax": [50.0, 50.0], + "household_benefits": [20.0, 30.0], + "household_state_income_tax": [5.0, 5.0], + } + ) + ) + reform = _FakeSimulation( + pd.DataFrame( + { + "household_weight": [1.0, 1.0], + "household_net_income": [410.0, 620.0], + "household_tax": [100.0, 100.0], + "household_benefits": [35.0, 45.0], + "household_state_income_tax": [15.0, 15.0], + } + ) + ) + return baseline, reform + + +def _simulation_output_builder( + country: str, + baseline, + reform, + analysis=None, +) -> SimulationOutputBuilder: + analysis = analysis or fake_analysis() + country_module = SimpleNamespace( + model=SimpleNamespace(version="1.700.0" if country == "us" else "2.88.20"), + economic_impact_analysis=lambda baseline_simulation, reform_simulation: analysis, + ) + return SimulationOutputBuilder( + country=country, + simulation_params={ + "country": country, + "data_version": "1.115.5" if country == "us" else "1.55.10", 
+ }, + country_module=country_module, + dataset=SimpleNamespace(metadata={}), + baseline=baseline, + reform=reform, + ) + + +def _stub_policyengine_output_calls(monkeypatch, baseline, reform) -> None: + def fake_poverty_module_function(name): + def compute(simulation): + if "by_age" in name: + return ( + BASELINE_POVERTY_BY_AGE + if simulation is baseline + else REFORM_POVERTY_BY_AGE + ) + if "by_gender" in name: + return ( + BASELINE_POVERTY_BY_GENDER + if simulation is baseline + else REFORM_POVERTY_BY_GENDER + ) + if "by_race" in name: + return ( + BASELINE_POVERTY_BY_RACE + if simulation is baseline + else REFORM_POVERTY_BY_RACE + ) + raise AssertionError(f"Unexpected poverty output: {name}") + + return compute + + monkeypatch.setattr( + "src.modal.simulation_output_builder._poverty_module_function", + fake_poverty_module_function, + ) + monkeypatch.setattr( + SimulationOutputBuilder, + "_build_intra_decile_output", + lambda self: self._build_intra_decile_output_from_collection( + INTRA_DECILE_COLLECTION + ), + ) + monkeypatch.setattr( + SimulationOutputBuilder, + "_build_congressional_district_impact", + lambda self: ( + self._build_geographic_impact_output([{"district_geoid": 101}]) + if self.country == "us" + else None + ), + ) + monkeypatch.setattr( + SimulationOutputBuilder, + "_build_uk_constituency_impact", + lambda self: ( + self._build_geographic_impact_output([{"constituency_code": "E14000530"}]) + if self.country == "uk" + else None + ), + ) + monkeypatch.setattr( + SimulationOutputBuilder, + "_build_uk_local_authority_impact", + lambda self: ( + self._build_geographic_impact_output( + [{"local_authority_code": "E06000001"}] + ) + if self.country == "uk" + else None + ), + ) + + +def _build_schema_output(monkeypatch, *, country: str = "us") -> SingleYearMacroOutput: + baseline, reform = _macro_baseline_reform() + _stub_policyengine_output_calls(monkeypatch, baseline, reform) + return _simulation_output_builder(country, baseline, reform).build() + + 
+def test_builder_returns_schema_modules_before_legacy_dict_dump(monkeypatch): + output = _build_schema_output(monkeypatch) assert isinstance(output, SingleYearMacroOutput) assert isinstance(output.budget, BudgetaryOutput) @@ -68,31 +181,21 @@ def test_builder_returns_schema_modules_before_legacy_dict_dump(): assert output.wealth_decile is None assert output.congressional_district_impact.root == [{"district_geoid": 101}] - legacy_output = adapt_analysis_to_legacy_macro_output( - country="us", - model_version="1.700.0", - data_version="1.115.5", - budget=BUDGET, - analysis=fake_analysis(), - baseline_poverty_by_age=BASELINE_POVERTY_BY_AGE, - reform_poverty_by_age=REFORM_POVERTY_BY_AGE, - baseline_poverty_by_gender=BASELINE_POVERTY_BY_GENDER, - reform_poverty_by_gender=REFORM_POVERTY_BY_GENDER, - baseline_poverty_by_race=BASELINE_POVERTY_BY_RACE, - reform_poverty_by_race=REFORM_POVERTY_BY_RACE, - intra_decile=INTRA_DECILE_COLLECTION, - congressional_district_impact=[{"district_geoid": 101}], - ) - assert output.model_dump(mode="json") == legacy_output - -def test_adapter_returns_existing_single_year_macro_shape(): - output = _build_schema_output().model_dump(mode="json") +def test_builder_returns_existing_single_year_macro_shape(monkeypatch): + output = _build_schema_output(monkeypatch).model_dump(mode="json") assert set(output) == CURRENT_SINGLE_YEAR_MACRO_KEYS assert output["model_version"] == "1.700.0" assert output["data_version"] == "1.115.5" - assert output["budget"] == BUDGET + assert output["budget"] == { + "tax_revenue_impact": 100.0, + "state_tax_revenue_impact": 20.0, + "benefit_spending_impact": 30.0, + "budgetary_impact": 70.0, + "households": 2.0, + "baseline_net_income": 1000.0, + } assert output["detailed_budget"] == { "income_tax": {"baseline": 100.0, "reform": 125.0, "difference": 25.0} } @@ -123,17 +226,8 @@ def test_adapter_returns_existing_single_year_macro_shape(): assert output["congressional_district_impact"] == [{"district_geoid": 101}] -def 
test_adapter_maps_uk_wealth_outputs_and_omits_us_only_race(): - output = adapt_analysis_to_legacy_macro_output( - country="uk", - model_version="2.88.20", - data_version="1.55.10", - budget={**BUDGET, "state_tax_revenue_impact": 0.0}, - analysis=fake_analysis(), - intra_decile=INTRA_DECILE_COLLECTION, - constituency_impact=[{"constituency_code": "E14000530"}], - local_authority_impact=[{"local_authority_code": "E06000001"}], - ) +def test_builder_maps_uk_wealth_outputs_and_omits_us_only_race(monkeypatch): + output = _build_schema_output(monkeypatch, country="uk").model_dump(mode="json") assert output["poverty_by_race"] is None assert output["wealth_decile"] == { @@ -145,40 +239,34 @@ def test_adapter_maps_uk_wealth_outputs_and_omits_us_only_race(): assert output["local_authority_impact"] == [{"local_authority_code": "E06000001"}] -def test_normalise_policy_converts_legacy_period_range_keys(): - assert _normalise_policy({"gov.test.parameter": {"2026-01-01.2100-12-31": 1}}) == { - "gov.test.parameter": {"2026-01-01": 1} - } - - -def _simulation_output_builder( - country: str, - baseline, - reform, - analysis=None, -) -> SimulationMacroOutputBuilder: - return SimulationMacroOutputBuilder( - country=country, - simulation_params={"country": country}, - country_module=SimpleNamespace(model=SimpleNamespace(version="test")), +def test_builder_calls_policyengine_economic_impact_analysis(): + baseline, reform = _macro_baseline_reform() + analysis = fake_analysis() + calls = [] + country_module = SimpleNamespace( + model=SimpleNamespace(version="1.700.0"), + economic_impact_analysis=lambda baseline_simulation, reform_simulation: ( + calls.append((baseline_simulation, reform_simulation)) or analysis + ), + ) + builder = SimulationOutputBuilder( + country="us", + simulation_params={"country": "us", "data_version": "1.115.5"}, + country_module=country_module, dataset=SimpleNamespace(metadata={}), baseline=baseline, reform=reform, - analysis=analysis or fake_analysis(), ) + 
assert builder.analysis is analysis + assert builder.analysis is analysis + assert calls == [(baseline, reform)] -class _FakeOutputDataset: - def __init__(self, household): - self.data = SimpleNamespace(household=household) - -class _FakeSimulation: - def __init__(self, household): - self.output_dataset = _FakeOutputDataset(household) - - def ensure(self): - raise AssertionError("test data is already materialized") +def test_normalise_policy_converts_legacy_period_range_keys(): + assert _normalise_policy({"gov.test.parameter": {"2026-01-01.2100-12-31": 1}}) == { + "gov.test.parameter": {"2026-01-01": 1} + } def test_builder_budgetary_impact_uses_materialized_columns_and_uk_state_tax_zero(): @@ -241,7 +329,7 @@ def compute(baseline_simulation, reform_simulation): return compute monkeypatch.setattr( - "src.modal.simulation_macro_output_builder._output_module_function", + "src.modal.simulation_output_builder._output_module_function", fake_output_module_function, ) @@ -276,7 +364,7 @@ def compute(baseline_simulation, reform_simulation): return compute monkeypatch.setattr( - "src.modal.simulation_macro_output_builder._output_module_function", + "src.modal.simulation_output_builder._output_module_function", fake_output_module_function, ) From c84ea3b990376d8044391ff2068b481b2030844f Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 22 May 2026 16:17:05 +0200 Subject: [PATCH 14/23] fix: preserve simulation API contracts --- .../fixtures/test_simulation_api_contracts.py | 88 +++++++------ .../src/modal/app.py | 12 +- .../src/modal/dependency_pins.py | 25 ++++ .../src/modal/gateway/endpoints.py | 11 +- .../src/modal/release_bundle.py | 11 +- .../src/modal/simulation_output_builder.py | 10 +- .../compat_models.py | 32 +++++ .../policyengine_api_simulation/simulation.py | 14 +- .../tests/gateway/test_endpoints.py | 19 ++- .../test_policyengine_dependency_source.py | 44 +++++-- .../tests/test_release_bundle.py | 30 +++++ .../tests/test_simulation_api_contracts.py | 6 + 
.../tests/test_simulation_output_builder.py | 123 ++++++++++++------ .../test_standalone_simulation_contract.py | 41 ++++++ 14 files changed, 342 insertions(+), 124 deletions(-) create mode 100644 projects/policyengine-api-simulation/src/modal/dependency_pins.py create mode 100644 projects/policyengine-api-simulation/src/policyengine_api_simulation/compat_models.py create mode 100644 projects/policyengine-api-simulation/tests/test_standalone_simulation_contract.py diff --git a/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py b/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py index ba3df0675..eb6489de9 100644 --- a/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py +++ b/projects/policyengine-api-simulation/fixtures/test_simulation_api_contracts.py @@ -33,74 +33,75 @@ "model_version": "1.700.0", "data_version": "1.115.5", "budget": { - "budgetary_impact": 300.0, - "tax_revenue_impact": 500.0, - "state_tax_revenue_impact": 125.0, - "benefit_spending_impact": 200.0, + "budgetary_impact": 70.0, + "tax_revenue_impact": 100.0, + "state_tax_revenue_impact": 20.0, + "benefit_spending_impact": 30.0, "households": 2.0, "baseline_net_income": 1000.0, }, "detailed_budget": { "income_tax": { - "baseline": 1000.0, - "reform": 1100.0, - "difference": 100.0, + "baseline": 100.0, + "reform": 125.0, + "difference": 25.0, } }, "decile": { - "relative": {"1": 0.01}, - "average": {"1": 10.0}, + "relative": {"1": 0.01, "2": 0.02}, + "average": {"1": 10.0, "2": 20.0}, }, "inequality": { - "baseline": {"gini": 0.3}, - "reform": {"gini": 0.29}, + "gini": {"baseline": 0.4, "reform": 0.39}, + "top_10_pct_share": {"baseline": 0.3, "reform": 0.29}, + "top_1_pct_share": {"baseline": 0.1, "reform": 0.09}, }, "poverty": { "poverty": { - "adult": {"baseline": 0.1, "reform": 0.09}, + "adult": {"baseline": 0.0, "reform": 0.0}, "all": {"baseline": 0.1, "reform": 0.09}, - "child": {"baseline": 0.12, "reform": 0.1}, 
- "senior": {"baseline": 0.08, "reform": 0.07}, + "child": {"baseline": 0.12, "reform": 0.11}, + "senior": {"baseline": 0.0, "reform": 0.0}, }, "deep_poverty": { - "adult": {"baseline": 0.03, "reform": 0.02}, - "all": {"baseline": 0.03, "reform": 0.02}, + "adult": {"baseline": 0.0, "reform": 0.0}, + "all": {"baseline": 0.0, "reform": 0.0}, "child": {"baseline": 0.04, "reform": 0.03}, - "senior": {"baseline": 0.02, "reform": 0.01}, + "senior": {"baseline": 0.0, "reform": 0.0}, }, }, "poverty_by_gender": { "poverty": { - "male": {"baseline": 0.1, "reform": 0.09}, - "female": {"baseline": 0.11, "reform": 0.1}, + "male": {"baseline": 0.08, "reform": 0.07}, + "female": {"baseline": 0.0, "reform": 0.0}, }, "deep_poverty": { - "male": {"baseline": 0.03, "reform": 0.02}, - "female": {"baseline": 0.04, "reform": 0.03}, + "male": {"baseline": 0.0, "reform": 0.0}, + "female": {"baseline": 0.0, "reform": 0.0}, }, }, "poverty_by_race": { "poverty": { - "black": {"baseline": 0.12, "reform": 0.11}, - "hispanic": {"baseline": 0.13, "reform": 0.12}, - "other": {"baseline": 0.1, "reform": 0.09}, - "white": {"baseline": 0.08, "reform": 0.07}, + "black": {"baseline": 0.0, "reform": 0.0}, + "hispanic": {"baseline": 0.0, "reform": 0.0}, + "other": {"baseline": 0.0, "reform": 0.0}, + "white": {"baseline": 0.06, "reform": 0.05}, }, }, "intra_decile": { "all": { - "Gain less than 5%": 0.2, - "Gain more than 5%": 0.1, - "Lose less than 5%": 0.1, - "Lose more than 5%": 0.0, - "No change": 0.6, + "Gain less than 5%": 0.30000000000000004, + "Gain more than 5%": 0.3, + "Lose less than 5%": 0.15000000000000002, + "Lose more than 5%": 0.05, + "No change": 0.44999999999999996, }, "deciles": { - "Gain less than 5%": [0.2], - "Gain more than 5%": [0.1], - "Lose less than 5%": [0.1], - "Lose more than 5%": [0.0], - "No change": [0.6], + "Gain less than 5%": [0.4, 0.2], + "Gain more than 5%": [0.5, 0.1], + "Lose less than 5%": [0.2, 0.1], + "Lose more than 5%": [0.1, 0.0], + "No change": [0.3, 0.6], 
}, }, "wealth_decile": None, @@ -108,14 +109,23 @@ "labor_supply_response": { "substitution_lsr": 0.0, "income_lsr": 0.0, - "relative_lsr": {}, + "relative_lsr": {"income": 0.0, "substitution": 0.0}, "total_change": 0.0, "revenue_change": 0.0, - "decile": {}, - "hours": {"baseline": 0.0, "reform": 0.0, "change": 0.0}, + "decile": { + "average": {"income": {}, "substitution": {}}, + "relative": {"income": {}, "substitution": {}}, + }, + "hours": { + "baseline": 0.0, + "reform": 0.0, + "change": 0.0, + "income_effect": 0.0, + "substitution_effect": 0.0, + }, }, "constituency_impact": None, "local_authority_impact": None, - "congressional_district_impact": None, + "congressional_district_impact": [{"district_geoid": 101}], "cliff_impact": None, } diff --git a/projects/policyengine-api-simulation/src/modal/app.py b/projects/policyengine-api-simulation/src/modal/app.py index 2056f14b6..5831d6c86 100644 --- a/projects/policyengine-api-simulation/src/modal/app.py +++ b/projects/policyengine-api-simulation/src/modal/app.py @@ -11,9 +11,17 @@ import os from src.modal._image_setup import snapshot_models +from src.modal.dependency_pins import project_dependency_pin from src.modal.logging_redaction import redact_params_for_logging from src.modal.release_bundle import get_bundled_country_model_version +POLICYENGINE_VERSION = os.environ.get("POLICYENGINE_VERSION") or project_dependency_pin( + "policyengine" +) +POLICYENGINE_CORE_VERSION = os.environ.get( + "POLICYENGINE_CORE_VERSION" +) or project_dependency_pin("policyengine-core") + # Get versions from environment or the bundled policyengine.py release manifest. 
US_VERSION = os.environ.get( "POLICYENGINE_US_VERSION" @@ -54,8 +62,8 @@ def get_app_name(us_version: str, uk_version: str) -> str: .pip_install( f"policyengine-us=={US_VERSION}", f"policyengine-uk=={UK_VERSION}", - "policyengine==4.10.0", - "policyengine-core==3.26.1", + f"policyengine=={POLICYENGINE_VERSION}", + f"policyengine-core=={POLICYENGINE_CORE_VERSION}", "tables>=3.10.2", "logfire", ) diff --git a/projects/policyengine-api-simulation/src/modal/dependency_pins.py b/projects/policyengine-api-simulation/src/modal/dependency_pins.py new file mode 100644 index 000000000..9d99bd613 --- /dev/null +++ b/projects/policyengine-api-simulation/src/modal/dependency_pins.py @@ -0,0 +1,25 @@ +"""Helpers for reading pinned project dependencies.""" + +from __future__ import annotations + +import tomllib +from functools import lru_cache +from pathlib import Path + + +PROJECT_DIR = Path(__file__).resolve().parents[2] +PYPROJECT_PATH = PROJECT_DIR / "pyproject.toml" + + +@lru_cache +def _project_dependencies() -> tuple[str, ...]: + pyproject = tomllib.loads(PYPROJECT_PATH.read_text(encoding="utf-8")) + return tuple(pyproject["project"]["dependencies"]) + + +def project_dependency_pin(package: str) -> str: + prefix = f"{package}==" + for dependency in _project_dependencies(): + if dependency.startswith(prefix): + return dependency.removeprefix(prefix) + raise ValueError(f"Dependency {package!r} is not pinned in {PYPROJECT_PATH}") diff --git a/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py b/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py index 9db5a5442..35dc800a3 100644 --- a/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py +++ b/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py @@ -34,10 +34,7 @@ failed_job_response, running_job_response, ) -from src.modal.release_bundle import ( - get_country_release_bundle, - resolve_bundle_dataset_uri, -) +from src.modal.release_bundle import 
resolve_bundle_dataset_uri logger = logging.getLogger(__name__) @@ -70,17 +67,15 @@ def _is_modal_job_not_found(exc: BaseException) -> bool: def _build_policyengine_bundle( country: str, resolved_version: str, payload: dict ) -> PolicyEngineBundle: - bundle = get_country_release_bundle(country) dataset = payload.get("data") resolved_dataset = ( resolve_bundle_dataset_uri(country, dataset) - if dataset is None or isinstance(dataset, str) + if isinstance(dataset, str) else None ) return PolicyEngineBundle( model_version=resolved_version, - policyengine_version=bundle.policyengine_version, - data_version=payload.get("data_version") or bundle.data_version, + data_version=payload.get("data_version"), dataset=resolved_dataset, ) diff --git a/projects/policyengine-api-simulation/src/modal/release_bundle.py b/projects/policyengine-api-simulation/src/modal/release_bundle.py index f99dfadba..ef81a37f4 100644 --- a/projects/policyengine-api-simulation/src/modal/release_bundle.py +++ b/projects/policyengine-api-simulation/src/modal/release_bundle.py @@ -20,18 +20,18 @@ "us": { "enhanced_cps": "enhanced_cps_2024", "enhanced_cps_2024": "enhanced_cps_2024", - "gs://policyengine-us-data/enhanced_cps_2024.h5": "enhanced_cps_2024", - "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5": "enhanced_cps_2024", "cps_small": "cps_small_2024", "cps_small_2024": "cps_small_2024", + "cps": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.110.12", + "cps_2023": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.110.12", + "pooled_cps": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.110.12", + "pooled_3_year_cps_2023": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.110.12", }, "uk": { "enhanced_frs": "enhanced_frs_2023_24", "enhanced_frs_2023_24": "enhanced_frs_2023_24", - "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5": "enhanced_frs_2023_24", "frs": "frs_2023_24", "frs_2023_24": "frs_2023_24", - 
"hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5": "frs_2023_24", }, } @@ -99,6 +99,9 @@ def resolve_bundle_dataset_name(country: str, requested_data: str | None) -> str if requested_data is None: return bundle.default_dataset + if "://" in requested_data or "@" in requested_data: + return requested_data + requested_without_revision = requested_data.split("@", maxsplit=1)[0] aliased = DATASET_ALIASES.get(bundle.country, {}).get( requested_without_revision, requested_data diff --git a/projects/policyengine-api-simulation/src/modal/simulation_output_builder.py b/projects/policyengine-api-simulation/src/modal/simulation_output_builder.py index e7ef45a0e..2b4c59aaa 100644 --- a/projects/policyengine-api-simulation/src/modal/simulation_output_builder.py +++ b/projects/policyengine-api-simulation/src/modal/simulation_output_builder.py @@ -327,14 +327,14 @@ def _build_inequality(self) -> InequalityOutput: ) def _build_budgetary_impact(self) -> BudgetaryImpact: - tax_revenue_impact = _try_change_output_variable( + tax_revenue_impact = _change_output_variable( self.baseline, self.reform, "household_tax", entity="household" ) - benefit_spending_impact = _try_change_output_variable( + benefit_spending_impact = _change_output_variable( self.baseline, self.reform, "household_benefits", entity="household" ) state_tax_revenue_impact = ( - _try_change_output_variable( + _change_output_variable( self.baseline, self.reform, "household_state_income_tax", @@ -349,10 +349,10 @@ def _build_budgetary_impact(self) -> BudgetaryImpact: state_tax_revenue_impact=state_tax_revenue_impact, benefit_spending_impact=benefit_spending_impact, budgetary_impact=tax_revenue_impact - benefit_spending_impact, - households=_try_sum_output_variable( + households=_sum_output_variable( self.baseline, "household_weight", entity="household" ), - baseline_net_income=_try_sum_output_variable( + baseline_net_income=_sum_output_variable( self.baseline, "household_net_income", entity="household" ), ) 
diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/compat_models.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/compat_models.py new file mode 100644 index 000000000..e69112943 --- /dev/null +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/compat_models.py @@ -0,0 +1,32 @@ +"""Compatibility schemas for the live synchronous simulation surface.""" + +from __future__ import annotations + +from typing import Any, Optional + +from pydantic import BaseModel, ConfigDict + +from src.modal.simulation_macro_output import SingleYearMacroOutput + + +class SimulationOptions(BaseModel): + """Legacy request schema name kept for generated clients.""" + + country: str + scope: Optional[str] = None + data: Optional[str] = None + time_period: Optional[str | int] = None + reform: Optional[dict[str, Any]] = None + baseline: Optional[dict[str, Any]] = None + region: Optional[str] = None + subsample: Optional[int] = None + title: Optional[str] = None + include_cliffs: Optional[bool] = None + model_version: Optional[str] = None + data_version: Optional[str] = None + + model_config = ConfigDict(extra="forbid") + + +class EconomyComparison(SingleYearMacroOutput): + """Legacy response schema name kept for generated clients.""" diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation.py index 9b8567f66..5856c9460 100644 --- a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation.py +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation.py @@ -3,6 +3,10 @@ from fastapi import APIRouter from src.modal.simulation import run_simulation_impl +from policyengine_api_simulation.compat_models import ( + EconomyComparison, + SimulationOptions, +) logger = logging.getLogger(__file__) @@ -10,11 +14,13 @@ def create_router(): router = APIRouter() - 
@router.post("/simulate/economy/comparison", response_model=dict) - async def simulate(parameters: dict) -> dict: + @router.post("/simulate/economy/comparison", response_model=EconomyComparison) + async def simulate(parameters: SimulationOptions) -> EconomyComparison: logger.info("Calculating comparison") - result = run_simulation_impl(parameters) + result = run_simulation_impl( + parameters.model_dump(mode="json", exclude_none=True) + ) logger.info("Comparison complete") - return result + return EconomyComparison.model_validate(result) return router diff --git a/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py b/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py index 393fdfc5a..9faab2d6f 100644 --- a/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py +++ b/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py @@ -8,10 +8,7 @@ import pytest from fastapi.testclient import TestClient -from src.modal.release_bundle import ( - get_country_release_bundle, - resolve_bundle_dataset_uri, -) +from src.modal.release_bundle import resolve_bundle_dataset_uri def expected_bundle( @@ -20,14 +17,16 @@ def expected_bundle( *, dataset: str | None = None, data_version: str | None = None, -) -> dict[str, str]: - bundle = get_country_release_bundle(country) - return { +) -> dict[str, str | None]: + bundle: dict[str, str | None] = { "model_version": model_version, - "policyengine_version": bundle.policyengine_version, - "data_version": data_version or bundle.data_version, - "dataset": resolve_bundle_dataset_uri(country, dataset), + "dataset": resolve_bundle_dataset_uri(country, dataset) + if dataset is not None + else None, } + if data_version is not None: + bundle["data_version"] = data_version + return {key: value for key, value in bundle.items() if value is not None} class TestGetAppName: diff --git a/projects/policyengine-api-simulation/tests/test_policyengine_dependency_source.py 
b/projects/policyengine-api-simulation/tests/test_policyengine_dependency_source.py index ebb6e6fae..2927a7b92 100644 --- a/projects/policyengine-api-simulation/tests/test_policyengine_dependency_source.py +++ b/projects/policyengine-api-simulation/tests/test_policyengine_dependency_source.py @@ -1,6 +1,5 @@ """Regression tests for the policyengine dependency version configuration.""" -import re import tomllib from pathlib import Path @@ -8,6 +7,7 @@ PYPROJECT_PATH = REPO_ROOT / "pyproject.toml" MODAL_APP_PATH = REPO_ROOT / "src" / "modal" / "app.py" POLICYENGINE_DEPENDENCY_PREFIX = "policyengine==" +POLICYENGINE_CORE_DEPENDENCY_PREFIX = "policyengine-core==" COUNTRY_PACKAGES = { "us": "policyengine-us", "uk": "policyengine-uk", @@ -26,30 +26,50 @@ def _get_pyproject_policyengine_dependency(pyproject: dict) -> str: ) -def _get_dependency_pin(pyproject: dict, package: str) -> str: +def _get_pyproject_policyengine_core_dependency(pyproject: dict) -> str: dependencies = pyproject["project"]["dependencies"] - prefix = f"{package}==" return next( - dep.removeprefix(prefix) for dep in dependencies if dep.startswith(prefix) + dep + for dep in dependencies + if dep.startswith(POLICYENGINE_CORE_DEPENDENCY_PREFIX) ) -def _get_modal_policyengine_dependency(modal_source: str) -> str: - match = re.search( - r'"(policyengine==[^"]+)"', - modal_source, +def _get_dependency_pin(pyproject: dict, package: str) -> str: + dependencies = pyproject["project"]["dependencies"] + prefix = f"{package}==" + return next( + dep.removeprefix(prefix) for dep in dependencies if dep.startswith(prefix) ) - assert match is not None, "Modal app should install a pinned policyengine version" - return match.group(1) def test_policyengine_dependency_version_is_pinned_consistently(): + from src.modal.dependency_pins import project_dependency_pin + pyproject = _load_toml(PYPROJECT_PATH) pyproject_dependency = _get_pyproject_policyengine_dependency(pyproject) - modal_dependency = 
_get_modal_policyengine_dependency(MODAL_APP_PATH.read_text()) + pyproject_core_dependency = _get_pyproject_policyengine_core_dependency(pyproject) assert pyproject_dependency.startswith(POLICYENGINE_DEPENDENCY_PREFIX) - assert modal_dependency == pyproject_dependency + assert pyproject_core_dependency.startswith(POLICYENGINE_CORE_DEPENDENCY_PREFIX) + assert ( + f"policyengine=={project_dependency_pin('policyengine')}" + == pyproject_dependency + ) + assert ( + f"policyengine-core=={project_dependency_pin('policyengine-core')}" + == pyproject_core_dependency + ) + + +def test_modal_app_reads_policyengine_pins_from_pyproject(): + modal_source = MODAL_APP_PATH.read_text(encoding="utf-8") + + assert '"policyengine==4.10.0"' not in modal_source + assert '"policyengine-core==3.26.1"' not in modal_source + assert "project_dependency_pin" in modal_source + assert '"policyengine"' in modal_source + assert '"policyengine-core"' in modal_source def test_country_package_pins_match_policyengine_bundle(): diff --git a/projects/policyengine-api-simulation/tests/test_release_bundle.py b/projects/policyengine-api-simulation/tests/test_release_bundle.py index 3a390041e..4257c6293 100644 --- a/projects/policyengine-api-simulation/tests/test_release_bundle.py +++ b/projects/policyengine-api-simulation/tests/test_release_bundle.py @@ -43,6 +43,36 @@ def test_resolve_bundle_dataset_uri_maps_known_aliases_to_manifest_uris(): ) +def test_resolve_bundle_dataset_uri_preserves_explicit_dataset_uri_and_revision(): + uri = "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12" + + assert resolve_bundle_dataset_name("us", uri) == uri + assert resolve_bundle_dataset_uri("us", uri) == uri + + +def test_resolve_bundle_dataset_uri_preserves_explicit_logical_revision(): + dataset = "enhanced_cps_2024@1.110.12" + + assert resolve_bundle_dataset_name("us", dataset) == dataset + assert resolve_bundle_dataset_uri("us", dataset) == dataset + + +def 
test_resolve_bundle_dataset_uri_preserves_explicit_gcs_uri(): + uri = "gs://policyengine-us-data/enhanced_cps_2024.h5" + + assert resolve_bundle_dataset_name("us", uri) == uri + assert resolve_bundle_dataset_uri("us", uri) == uri + + +def test_resolve_bundle_dataset_uri_supports_legacy_us_aliases(): + assert resolve_bundle_dataset_uri("us", "cps") == ( + "hf://policyengine/policyengine-us-data/cps_2023.h5@1.110.12" + ) + assert resolve_bundle_dataset_uri("us", "pooled_cps") == ( + "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.110.12" + ) + + def test_resolve_bundle_dataset_uri_preserves_unmanaged_unknown_values(): assert resolve_bundle_dataset_uri("us", "custom_dataset_label") == ( "custom_dataset_label" diff --git a/projects/policyengine-api-simulation/tests/test_simulation_api_contracts.py b/projects/policyengine-api-simulation/tests/test_simulation_api_contracts.py index 57465e610..d6418bc82 100644 --- a/projects/policyengine-api-simulation/tests/test_simulation_api_contracts.py +++ b/projects/policyengine-api-simulation/tests/test_simulation_api_contracts.py @@ -34,6 +34,12 @@ def test_internal_single_year_macro_schema_matches_current_public_keys(): assert set(SingleYearMacroOutput.model_fields) == CURRENT_SINGLE_YEAR_MACRO_KEYS +def test_internal_single_year_macro_schema_serializes_current_public_contract(): + output = SingleYearMacroOutput.model_validate(CURRENT_SINGLE_YEAR_MACRO_RESULT) + + assert output.model_dump(mode="json") == CURRENT_SINGLE_YEAR_MACRO_RESULT + + def test_openapi_keeps_job_status_result_as_unstructured_dict(): spec = create_openapi_app().openapi() schemas = spec["components"]["schemas"] diff --git a/projects/policyengine-api-simulation/tests/test_simulation_output_builder.py b/projects/policyengine-api-simulation/tests/test_simulation_output_builder.py index 9f0fb3ef6..c5558a396 100644 --- a/projects/policyengine-api-simulation/tests/test_simulation_output_builder.py +++ 
b/projects/policyengine-api-simulation/tests/test_simulation_output_builder.py @@ -5,8 +5,12 @@ from types import SimpleNamespace import pandas as pd +import pytest -from fixtures.test_simulation_api_contracts import CURRENT_SINGLE_YEAR_MACRO_KEYS +from fixtures.test_simulation_api_contracts import ( + CURRENT_SINGLE_YEAR_MACRO_KEYS, + CURRENT_SINGLE_YEAR_MACRO_RESULT, +) from fixtures.test_simulation_output_builder import ( BASELINE_POVERTY_BY_AGE, BASELINE_POVERTY_BY_GENDER, @@ -18,6 +22,7 @@ fake_analysis, ) from src.modal.simulation import _normalise_policy +from src.modal.simulation import _run_simulation_impl_core from src.modal.simulation_macro_output import ( BudgetaryImpact, BudgetaryOutput, @@ -79,7 +84,9 @@ def _simulation_output_builder( analysis = analysis or fake_analysis() country_module = SimpleNamespace( model=SimpleNamespace(version="1.700.0" if country == "us" else "2.88.20"), - economic_impact_analysis=lambda baseline_simulation, reform_simulation: analysis, + economic_impact_analysis=lambda baseline_simulation, reform_simulation: ( + analysis + ), ) return SimulationOutputBuilder( country=country, @@ -186,44 +193,7 @@ def test_builder_returns_existing_single_year_macro_shape(monkeypatch): output = _build_schema_output(monkeypatch).model_dump(mode="json") assert set(output) == CURRENT_SINGLE_YEAR_MACRO_KEYS - assert output["model_version"] == "1.700.0" - assert output["data_version"] == "1.115.5" - assert output["budget"] == { - "tax_revenue_impact": 100.0, - "state_tax_revenue_impact": 20.0, - "benefit_spending_impact": 30.0, - "budgetary_impact": 70.0, - "households": 2.0, - "baseline_net_income": 1000.0, - } - assert output["detailed_budget"] == { - "income_tax": {"baseline": 100.0, "reform": 125.0, "difference": 25.0} - } - assert output["decile"] == { - "average": {"1": 10.0, "2": 20.0}, - "relative": {"1": 0.01, "2": 0.02}, - } - assert output["intra_decile"]["deciles"]["Gain more than 5%"] == [0.5, 0.1] - assert 
output["intra_decile"]["all"]["Gain more than 5%"] == 0.3 - assert output["poverty"]["poverty"]["all"] == { - "baseline": 0.10, - "reform": 0.09, - } - assert output["poverty"]["poverty"]["child"] == { - "baseline": 0.12, - "reform": 0.11, - } - assert output["poverty_by_gender"]["poverty"]["male"] == { - "baseline": 0.08, - "reform": 0.07, - } - assert output["poverty_by_race"]["poverty"]["white"] == { - "baseline": 0.06, - "reform": 0.05, - } - assert output["wealth_decile"] is None - assert output["intra_wealth_decile"] is None - assert output["congressional_district_impact"] == [{"district_geoid": 101}] + assert output == CURRENT_SINGLE_YEAR_MACRO_RESULT def test_builder_maps_uk_wealth_outputs_and_omits_us_only_race(monkeypatch): @@ -269,6 +239,64 @@ def test_normalise_policy_converts_legacy_period_range_keys(): } +def test_run_simulation_impl_core_builds_and_serializes_macro_output(monkeypatch): + dataset = object() + country_module = SimpleNamespace(model=SimpleNamespace(version="1.700.0")) + baseline_simulation = object() + reform_simulation = object() + build_calls = [] + builder_calls = [] + + def fake_country_module(country): + assert country == "us" + return country_module + + def fake_build_simulation(params, *, dataset, policy): + build_calls.append((params, dataset, policy)) + return baseline_simulation if len(build_calls) == 1 else reform_simulation + + class FakeSimulationOutputBuilder: + def __init__(self, **kwargs): + builder_calls.append(kwargs) + + def serialize(self): + return CURRENT_SINGLE_YEAR_MACRO_RESULT + + monkeypatch.setattr("src.modal.simulation._country_module", fake_country_module) + monkeypatch.setattr("src.modal.simulation._load_dataset", lambda params: dataset) + monkeypatch.setattr("src.modal.simulation._build_simulation", fake_build_simulation) + monkeypatch.setattr( + "src.modal.simulation.SimulationOutputBuilder", + FakeSimulationOutputBuilder, + ) + + result = _run_simulation_impl_core( + { + "country": "us", + "baseline": 
{"gov.test.parameter": {"2026-01-01.2100-12-31": 1}}, + "reform": {"gov.test.parameter": {"2026-01-01.2100-12-31": 2}}, + } + ) + + assert result == CURRENT_SINGLE_YEAR_MACRO_RESULT + assert build_calls[0][2] == {"gov.test.parameter": {"2026-01-01": 1}} + assert build_calls[1][2] == {"gov.test.parameter": {"2026-01-01": 2}} + assert builder_calls == [ + { + "country": "us", + "simulation_params": { + "country": "us", + "baseline": {"gov.test.parameter": {"2026-01-01.2100-12-31": 1}}, + "reform": {"gov.test.parameter": {"2026-01-01.2100-12-31": 2}}, + }, + "country_module": country_module, + "dataset": dataset, + "baseline": baseline_simulation, + "reform": reform_simulation, + } + ] + + def test_builder_budgetary_impact_uses_materialized_columns_and_uk_state_tax_zero(): baseline = _FakeSimulation( pd.DataFrame( @@ -312,6 +340,21 @@ def test_builder_budgetary_impact_uses_materialized_columns_and_uk_state_tax_zer assert uk_budget.state_tax_revenue_impact == 0.0 +def test_builder_budgetary_impact_propagates_required_calculation_errors(monkeypatch): + baseline, reform = _macro_baseline_reform() + + def fail_change_output_variable(*args, **kwargs): + raise RuntimeError("household_tax missing") + + monkeypatch.setattr( + "src.modal.simulation_output_builder._change_output_variable", + fail_change_output_variable, + ) + + with pytest.raises(RuntimeError, match="household_tax missing"): + _simulation_output_builder("us", baseline, reform)._build_budgetary_impact() + + def test_uk_constituency_impact_uses_policyengine_output_function(monkeypatch): baseline = object() reform = object() diff --git a/projects/policyengine-api-simulation/tests/test_standalone_simulation_contract.py b/projects/policyengine-api-simulation/tests/test_standalone_simulation_contract.py new file mode 100644 index 000000000..cc17db53c --- /dev/null +++ b/projects/policyengine-api-simulation/tests/test_standalone_simulation_contract.py @@ -0,0 +1,41 @@ +"""Contract tests for the live synchronous 
simulation FastAPI app.""" + +from fastapi.testclient import TestClient + +from fixtures.test_simulation_api_contracts import CURRENT_SINGLE_YEAR_MACRO_RESULT +from policyengine_api_simulation.main import app + + +def test_standalone_simulation_openapi_keeps_legacy_schema_names(): + spec = app.openapi() + route = spec["paths"]["/simulate/economy/comparison"]["post"] + + assert route["requestBody"]["content"]["application/json"]["schema"] == { + "$ref": "#/components/schemas/SimulationOptions" + } + assert route["responses"]["200"]["content"]["application/json"]["schema"] == { + "$ref": "#/components/schemas/EconomyComparison" + } + assert ( + "telemetry" + not in spec["components"]["schemas"]["SimulationOptions"]["properties"] + ) + + +def test_standalone_simulation_route_returns_legacy_macro_contract(monkeypatch): + def fake_run_simulation_impl(params): + assert params == {"country": "us", "reform": {}} + return CURRENT_SINGLE_YEAR_MACRO_RESULT + + monkeypatch.setattr( + "policyengine_api_simulation.simulation.run_simulation_impl", + fake_run_simulation_impl, + ) + + response = TestClient(app).post( + "/simulate/economy/comparison", + json={"country": "us", "reform": {}}, + ) + + assert response.status_code == 200 + assert response.json() == CURRENT_SINGLE_YEAR_MACRO_RESULT From 10af71336b7d168b92b03bb73cc1df0560cffb27 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 28 May 2026 14:15:57 +0200 Subject: [PATCH 15/23] fix: package simulation runtime for integration --- .github/scripts/modal-deploy-app.sh | 11 +- .github/workflows/modal-deploy.reusable.yml | 7 +- .../fixtures/gateway/test_endpoints.py | 83 ++- .../pyproject.toml | 2 +- .../src/modal/app.py | 25 +- .../src/modal/gateway/app.py | 6 +- .../src/modal/gateway/endpoints.py | 142 +++- .../src/modal/gateway/models.py | 2 +- .../src/modal/hf_dataset.py | 3 + .../src/modal/release_bundle.py | 118 +--- .../src/modal/simulation.py | 304 +------- .../src/modal/simulation_macro_output.py | 135 +--- 
.../src/modal/simulation_output_builder.py | 659 +---------------- .../src/modal/telemetry.py | 48 +- .../modal/utils/extract_bundle_versions.py | 2 +- .../modal/utils/update_version_registry.py | 97 ++- .../compat_models.py | 2 +- .../policyengine_api_simulation/hf_dataset.py | 150 ++++ .../src/policyengine_api_simulation/main.py | 2 - .../release_bundle.py | 138 ++++ .../policyengine_api_simulation/simulation.py | 2 +- .../simulation_macro_output.py | 134 ++++ .../simulation_output_builder.py | 666 ++++++++++++++++++ .../simulation_runtime.py | 513 ++++++++++++++ .../policyengine_api_simulation/telemetry.py | 47 ++ .../tests/gateway/test_auth.py | 2 +- .../tests/gateway/test_budget_window_state.py | 16 +- .../tests/gateway/test_endpoints.py | 159 ++++- .../tests/gateway/test_models.py | 12 +- .../tests/gateway/test_package_imports.py | 20 + .../tests/test_budget_window_batch.py | 14 +- .../tests/test_budget_window_context.py | 4 +- .../tests/test_budget_window_scheduler.py | 2 +- .../tests/test_gcp_credentials.py | 4 +- .../tests/test_hf_dataset.py | 104 +++ .../tests/test_modal_telemetry.py | 2 +- .../test_policyengine_dependency_source.py | 10 +- .../tests/test_release_bundle.py | 31 +- .../tests/test_simulation_api_contracts.py | 2 +- .../tests/test_simulation_output_builder.py | 198 +++++- .../test_standalone_simulation_contract.py | 23 + .../tests/test_update_version_registry.py | 50 +- 42 files changed, 2557 insertions(+), 1394 deletions(-) create mode 100644 projects/policyengine-api-simulation/src/modal/hf_dataset.py create mode 100644 projects/policyengine-api-simulation/src/policyengine_api_simulation/hf_dataset.py create mode 100644 projects/policyengine-api-simulation/src/policyengine_api_simulation/release_bundle.py create mode 100644 projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_macro_output.py create mode 100644 projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_builder.py create 
mode 100644 projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_runtime.py create mode 100644 projects/policyengine-api-simulation/src/policyengine_api_simulation/telemetry.py create mode 100644 projects/policyengine-api-simulation/tests/test_hf_dataset.py diff --git a/.github/scripts/modal-deploy-app.sh b/.github/scripts/modal-deploy-app.sh index f8581752b..a6be04952 100755 --- a/.github/scripts/modal-deploy-app.sh +++ b/.github/scripts/modal-deploy-app.sh @@ -1,24 +1,24 @@ #!/bin/bash # Deploy simulation API to Modal # Usage: ./modal-deploy-app.sh -# Required env vars: POLICYENGINE_US_VERSION, POLICYENGINE_UK_VERSION +# Required env vars: POLICYENGINE_VERSION, POLICYENGINE_US_VERSION, POLICYENGINE_UK_VERSION # These should come from the bundled policyengine.py release manifest. # # Deploys two apps: # 1. policyengine-simulation-gateway - Stable gateway with fixed URL -# 2. policyengine-simulation-us{X}-uk{Y} - Versioned simulation app +# 2. policyengine-simulation-py{X} - Versioned simulation app set -euo pipefail MODAL_ENV="${1:?Modal environment required}" # Generate versioned simulation app name (dots replaced with dashes for URL safety) -US_VERSION_SAFE="${POLICYENGINE_US_VERSION//./-}" -UK_VERSION_SAFE="${POLICYENGINE_UK_VERSION//./-}" -SIMULATION_APP_NAME="policyengine-simulation-us${US_VERSION_SAFE}-uk${UK_VERSION_SAFE}" +POLICYENGINE_VERSION_SAFE="${POLICYENGINE_VERSION//./-}" +SIMULATION_APP_NAME="policyengine-simulation-py${POLICYENGINE_VERSION_SAFE}" echo "========================================" echo "Deploying to Modal environment: $MODAL_ENV" +echo " policyengine.py version: ${POLICYENGINE_VERSION}" echo " US version: ${POLICYENGINE_US_VERSION}" echo " UK version: ${POLICYENGINE_UK_VERSION}" echo "========================================" @@ -41,6 +41,7 @@ echo "" echo "Step 3: Updating version registries..." 
uv run python -m src.modal.utils.update_version_registry \ --app-name "$SIMULATION_APP_NAME" \ + --policyengine-version "${POLICYENGINE_VERSION}" \ --us-version "${POLICYENGINE_US_VERSION}" \ --uk-version "${POLICYENGINE_UK_VERSION}" \ --environment "$MODAL_ENV" diff --git a/.github/workflows/modal-deploy.reusable.yml b/.github/workflows/modal-deploy.reusable.yml index 1fa823c17..82f490367 100644 --- a/.github/workflows/modal-deploy.reusable.yml +++ b/.github/workflows/modal-deploy.reusable.yml @@ -15,6 +15,9 @@ on: simulation_api_url: description: 'The deployed simulation API URL' value: ${{ jobs.deploy.outputs.simulation_api_url }} + policyengine_version: + description: 'The deployed policyengine.py package version' + value: ${{ jobs.deploy.outputs.policyengine_version }} us_version: description: 'The deployed policyengine-us package version' value: ${{ jobs.deploy.outputs.us_version }} @@ -35,6 +38,7 @@ jobs: environment: ${{ inputs.environment }} outputs: simulation_api_url: ${{ steps.get-url.outputs.simulation_api_url }} + policyengine_version: ${{ steps.versions.outputs.policyengine_version }} us_version: ${{ steps.versions.outputs.us_version }} us_data_version: ${{ steps.versions.outputs.us_data_version }} uk_version: ${{ steps.versions.outputs.uk_version }} @@ -82,9 +86,10 @@ jobs: env: MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} + POLICYENGINE_VERSION: ${{ steps.versions.outputs.policyengine_version }} POLICYENGINE_US_VERSION: ${{ steps.versions.outputs.us_version }} POLICYENGINE_UK_VERSION: ${{ steps.versions.outputs.uk_version }} - run: ../../.github/scripts/modal-deploy-app.sh "${{ inputs.modal_environment }}" src/modal/app.py + run: ../../.github/scripts/modal-deploy-app.sh "${{ inputs.modal_environment }}" - name: Get deployed URL id: get-url diff --git a/projects/policyengine-api-simulation/fixtures/gateway/test_endpoints.py 
b/projects/policyengine-api-simulation/fixtures/gateway/test_endpoints.py index 855762d2f..12fb044cc 100644 --- a/projects/policyengine-api-simulation/fixtures/gateway/test_endpoints.py +++ b/projects/policyengine-api-simulation/fixtures/gateway/test_endpoints.py @@ -2,6 +2,69 @@ import pytest +TEST_APP_RELEASE_BUNDLE = { + "app_name": "policyengine-simulation-py4-10-0", + "policyengine_version": "4.10.0", + "us": { + "model_version": "1.500.0", + "data_version": "1.110.12", + "default_dataset": "enhanced_cps_2024", + "default_dataset_uri": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12", + "dataset_uris": { + "enhanced_cps_2024": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12", + "cps_2023": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.110.12", + "pooled_3_year_cps_2023": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.110.12", + }, + "dataset_aliases": { + "enhanced_cps": "enhanced_cps_2024", + "enhanced_cps_2024": "enhanced_cps_2024", + "cps": "cps_2023", + "cps_2023": "cps_2023", + "pooled_cps": "pooled_3_year_cps_2023", + "pooled_3_year_cps_2023": "pooled_3_year_cps_2023", + }, + }, + "uk": { + "model_version": "2.66.0", + "data_version": "1.40.3", + "default_dataset": "enhanced_frs_2023_24", + "default_dataset_uri": "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3", + "dataset_uris": { + "enhanced_frs_2023_24": "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3", + "frs_2023_24": "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5@1.40.3", + }, + "dataset_aliases": { + "enhanced_frs": "enhanced_frs_2023_24", + "enhanced_frs_2023_24": "enhanced_frs_2023_24", + "frs": "frs_2023_24", + "frs_2023_24": "frs_2023_24", + }, + }, +} + +TEST_APP_NAMES = ( + "policyengine-simulation-py4-10-0", + "policyengine-simulation-py3-9-0", +) + + +def resolve_test_dataset_uri(country: str, dataset: str | None) -> str | None: + if dataset 
is None: + return None + if "://" in dataset: + return dataset + country_bundle = TEST_APP_RELEASE_BUNDLE[country] + dataset_name, revision = ( + dataset.rsplit("@", maxsplit=1) if "@" in dataset else (dataset, None) + ) + dataset_name = country_bundle["dataset_aliases"].get(dataset_name, dataset_name) + dataset_uri = country_bundle["dataset_uris"].get(dataset_name, dataset_name) + if revision is not None and dataset_uri == dataset_name: + return dataset + if revision is not None and dataset_uri.startswith("hf://"): + dataset_uri = f"{dataset_uri.rsplit('@', maxsplit=1)[0]}@{revision}" + return dataset_uri + class MockDict: """Mock for Modal.Dict to simulate version registry.""" @@ -107,7 +170,11 @@ def mock_modal(monkeypatch): from src.modal.gateway import endpoints mock_func = MockFunction() - mock_dicts = {} + mock_dicts = { + "simulation-api-app-release-bundles": { + app_name: TEST_APP_RELEASE_BUNDLE for app_name in TEST_APP_NAMES + } + } MockFunctionCall.registry = {} MockFunctionCall.from_id_errors = {} @@ -134,6 +201,20 @@ class MockModal: monkeypatch.setattr(endpoints, "modal", MockModal) monkeypatch.setattr(budget_window_state, "modal", MockModal) + monkeypatch.setattr( + endpoints, + "with_hf_revision", + lambda dataset_uri, revision: ( + f"{dataset_uri.rsplit('@', maxsplit=1)[0]}@{revision}" + if dataset_uri.startswith("hf://") + else dataset_uri + ), + ) + monkeypatch.setattr( + endpoints, + "validate_hf_dataset_uri", + lambda dataset_uri: dataset_uri, + ) return { "func": mock_func, diff --git a/projects/policyengine-api-simulation/pyproject.toml b/projects/policyengine-api-simulation/pyproject.toml index 2190556c8..6eb4aadde 100644 --- a/projects/policyengine-api-simulation/pyproject.toml +++ b/projects/policyengine-api-simulation/pyproject.toml @@ -9,7 +9,7 @@ readme = "README.md" authors = [ {name = "PolicyEngine", email = "hello@policyengine.org"}, ] -license = {file = "../../LICENSE"} +license = "AGPL-3.0-only" requires-python = ">=3.13,<3.14" 
dependencies = [ "opentelemetry-instrumentation-sqlalchemy (>=0.51b0,<0.52)", diff --git a/projects/policyengine-api-simulation/src/modal/app.py b/projects/policyengine-api-simulation/src/modal/app.py index 5831d6c86..51fa9901f 100644 --- a/projects/policyengine-api-simulation/src/modal/app.py +++ b/projects/policyengine-api-simulation/src/modal/app.py @@ -2,7 +2,7 @@ PolicyEngine Simulation - Versioned Modal App This app contains the heavy simulation workload with snapshotted models. -Each deployment creates a versioned app (e.g., policyengine-simulation-us1-459-0-uk2-65-9). +Each deployment creates a versioned app (e.g., policyengine-simulation-py4-10-0). The gateway app (policyengine-simulation-gateway) routes requests to these versioned apps. """ @@ -13,7 +13,7 @@ from src.modal._image_setup import snapshot_models from src.modal.dependency_pins import project_dependency_pin from src.modal.logging_redaction import redact_params_for_logging -from src.modal.release_bundle import get_bundled_country_model_version +from policyengine_api_simulation.release_bundle import get_bundled_country_model_version POLICYENGINE_VERSION = os.environ.get("POLICYENGINE_VERSION") or project_dependency_pin( "policyengine" @@ -31,20 +31,19 @@ ) or get_bundled_country_model_version("uk") -def get_app_name(us_version: str, uk_version: str) -> str: +def get_app_name(policyengine_version: str) -> str: """ - Generate versioned app name from package versions. + Generate versioned app name from the policyengine.py package version. Replaces dots with dashes for URL safety. 
- Example: us1.459.0, uk2.65.9 -> policyengine-simulation-us1-459-0-uk2-65-9 + Example: 4.10.0 -> policyengine-simulation-py4-10-0 """ - us_safe = us_version.replace(".", "-") - uk_safe = uk_version.replace(".", "-") - return f"policyengine-simulation-us{us_safe}-uk{uk_safe}" + policyengine_safe = policyengine_version.replace(".", "-") + return f"policyengine-simulation-py{policyengine_safe}" # App name can be overridden via environment variable, otherwise generated from versions -APP_NAME = os.environ.get("MODAL_APP_NAME", get_app_name(US_VERSION, UK_VERSION)) +APP_NAME = os.environ.get("MODAL_APP_NAME", get_app_name(POLICYENGINE_VERSION)) # App definition with versioned name app = modal.App(APP_NAME) @@ -67,7 +66,11 @@ def get_app_name(us_version: str, uk_version: str) -> str: "tables>=3.10.2", "logfire", ) - .add_local_python_source("src.modal", copy=True) + .add_local_python_source( + "src.modal", + "policyengine_api_simulation", + copy=True, + ) .run_function(snapshot_models) ) @@ -106,7 +109,7 @@ def run_simulation(params: dict) -> dict: """ import logfire - from src.modal.simulation import run_simulation_impl + from policyengine_api_simulation.simulation_runtime import run_simulation_impl configure_logfire() diff --git a/projects/policyengine-api-simulation/src/modal/gateway/app.py b/projects/policyengine-api-simulation/src/modal/gateway/app.py index a861cc02a..ec57b5dff 100644 --- a/projects/policyengine-api-simulation/src/modal/gateway/app.py +++ b/projects/policyengine-api-simulation/src/modal/gateway/app.py @@ -26,7 +26,11 @@ # the auth module at runtime here. 
"cryptography>=41.0.0", ) - .add_local_python_source("src.modal", copy=True) + .add_local_python_source( + "src.modal", + "policyengine_api_simulation", + copy=True, + ) .add_local_python_source("policyengine_fastapi", copy=True) ) diff --git a/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py b/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py index 35dc800a3..b972badbb 100644 --- a/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py +++ b/projects/policyengine-api-simulation/src/modal/gateway/endpoints.py @@ -34,18 +34,121 @@ failed_job_response, running_job_response, ) -from src.modal.release_bundle import resolve_bundle_dataset_uri +from policyengine_api_simulation.hf_dataset import ( + HuggingFaceDatasetReferenceError, + validate_hf_dataset_uri, + with_hf_revision, +) logger = logging.getLogger(__name__) router = APIRouter() JOB_METADATA_DICT_NAME = "simulation-api-job-metadata" +APP_RELEASE_BUNDLES_DICT_NAME = "simulation-api-app-release-bundles" def _job_metadata_store(): return modal.Dict.from_name(JOB_METADATA_DICT_NAME, create_if_missing=True) +def _app_release_bundle_store(): + return modal.Dict.from_name(APP_RELEASE_BUNDLES_DICT_NAME, create_if_missing=True) + + +def _app_release_bundle(app_name: str) -> dict: + bundle = _app_release_bundle_store().get(app_name) + return bundle if isinstance(bundle, dict) else {} + + +def _split_requested_revision(requested_data: str) -> tuple[str, str | None]: + if "@" not in requested_data: + return requested_data, None + dataset_name, revision = requested_data.rsplit("@", maxsplit=1) + if not dataset_name or not revision: + raise ValueError(f"Invalid dataset revision reference: {requested_data}") + return dataset_name, revision + + +def _select_dataset_revision( + *, + requested_revision: str | None, + requested_data_version: str | None, +) -> str | None: + if ( + requested_revision is not None + and requested_data_version is not None + and requested_revision != 
requested_data_version + ): + raise ValueError( + "Conflicting dataset revisions: " + f"data requests {requested_revision!r} but data_version is " + f"{requested_data_version!r}" + ) + return requested_revision or requested_data_version + + +def _resolve_dataset_uri_from_app_bundle( + *, + app_bundle: dict, + country: str, + requested_data: str | None, + requested_data_version: str | None = None, +) -> str | None: + if requested_data is None: + if requested_data_version is None: + return None + country_bundle = app_bundle.get(country.lower()) + if not isinstance(country_bundle, dict): + return None + default_uri = country_bundle.get("default_dataset_uri") + if not isinstance(default_uri, str): + return None + return with_hf_revision(default_uri, requested_data_version) + + requested_without_revision, requested_revision = _split_requested_revision( + requested_data + ) + revision = _select_dataset_revision( + requested_revision=requested_revision, + requested_data_version=requested_data_version, + ) + + if "://" in requested_without_revision: + if requested_without_revision.startswith("hf://"): + return ( + with_hf_revision(requested_without_revision, revision) + if revision is not None + else validate_hf_dataset_uri(requested_data) + ) + return requested_data + + country_bundle = app_bundle.get(country.lower()) + if not isinstance(country_bundle, dict): + return requested_data + + aliases = country_bundle.get("dataset_aliases") + if not isinstance(aliases, dict): + aliases = {} + dataset_name = aliases.get(requested_without_revision, requested_without_revision) + + if "://" in dataset_name: + return ( + with_hf_revision(dataset_name, revision) + if revision is not None + else dataset_name + ) + + dataset_uris = country_bundle.get("dataset_uris") + if not isinstance(dataset_uris, dict): + return requested_data + dataset_uri = dataset_uris.get(dataset_name) + if not isinstance(dataset_uri, str): + return requested_data + return ( + with_hf_revision(dataset_uri, 
revision) if revision is not None else dataset_uri + ) + + def _modal_exception_class(name: str): exception_module = getattr(modal, "exception", None) if exception_module is None: @@ -65,17 +168,20 @@ def _is_modal_job_not_found(exc: BaseException) -> bool: def _build_policyengine_bundle( - country: str, resolved_version: str, payload: dict + country: str, resolved_version: str, app_name: str, payload: dict ) -> PolicyEngineBundle: + app_bundle = _app_release_bundle(app_name) dataset = payload.get("data") - resolved_dataset = ( - resolve_bundle_dataset_uri(country, dataset) - if isinstance(dataset, str) - else None + data_version = payload.get("data_version") + resolved_dataset = _resolve_dataset_uri_from_app_bundle( + app_bundle=app_bundle, + country=country, + requested_data=dataset if isinstance(dataset, str) else None, + requested_data_version=data_version if isinstance(data_version, str) else None, ) return PolicyEngineBundle( model_version=resolved_version, - data_version=payload.get("data_version"), + data_version=data_version, dataset=resolved_dataset, ) @@ -171,6 +277,13 @@ async def submit_simulation(request: SimulationRequest): if request.telemetry is not None: payload["_telemetry"] = request.telemetry.model_dump(mode="json") + try: + bundle = _build_policyengine_bundle( + request.country, resolved_version, app_name, payload + ) + except (ValueError, HuggingFaceDatasetReferenceError) as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + logger.info( "Routing %s:%s to app %s (run_id=%s)", request.country, @@ -185,7 +298,6 @@ async def submit_simulation(request: SimulationRequest): # Spawn the job (returns immediately) call = sim_func.spawn(payload) - bundle = _build_policyengine_bundle(request.country, resolved_version, payload) job_metadata = _serialize_job_metadata(app_name, bundle, run_id) _job_metadata_store()[call.object_id] = job_metadata @@ -216,11 +328,15 @@ async def submit_budget_window_batch(request: 
BudgetWindowBatchRequest): except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) - bundle = _build_policyengine_bundle( - request.country, - resolved_version, - request.model_dump(mode="json"), - ) + try: + bundle = _build_policyengine_bundle( + request.country, + resolved_version, + app_name, + request.model_dump(mode="json"), + ) + except (ValueError, HuggingFaceDatasetReferenceError) as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc payload = _build_budget_window_parent_payload( request, resolved_version=resolved_version, diff --git a/projects/policyengine-api-simulation/src/modal/gateway/models.py b/projects/policyengine-api-simulation/src/modal/gateway/models.py index a94e98e51..47a235369 100644 --- a/projects/policyengine-api-simulation/src/modal/gateway/models.py +++ b/projects/policyengine-api-simulation/src/modal/gateway/models.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from src.modal.telemetry import TelemetryEnvelope +from policyengine_api_simulation.telemetry import TelemetryEnvelope # Hard cap on request body size (bytes). 
SimulationOptions + telemetry + any diff --git a/projects/policyengine-api-simulation/src/modal/hf_dataset.py b/projects/policyengine-api-simulation/src/modal/hf_dataset.py new file mode 100644 index 000000000..836616831 --- /dev/null +++ b/projects/policyengine-api-simulation/src/modal/hf_dataset.py @@ -0,0 +1,3 @@ +"""Compatibility shim for packaged simulation helpers.""" + +from policyengine_api_simulation.hf_dataset import * # noqa: F403 diff --git a/projects/policyengine-api-simulation/src/modal/release_bundle.py b/projects/policyengine-api-simulation/src/modal/release_bundle.py index ef81a37f4..c3022a52b 100644 --- a/projects/policyengine-api-simulation/src/modal/release_bundle.py +++ b/projects/policyengine-api-simulation/src/modal/release_bundle.py @@ -1,117 +1,3 @@ -"""Helpers for using the bundled policyengine.py release manifests. +"""Compatibility shim for packaged simulation helpers.""" -The simulation API deploys separate versioned worker apps, but the country -package and data artifact versions must come from the policyengine.py bundle -manifest so model/data compatibility stays explicit. 
-""" - -from __future__ import annotations - -import os -from dataclasses import dataclass -from functools import lru_cache -from typing import Mapping - -os.environ.setdefault("POLICYENGINE_SKIP_COUNTRY_IMPORTS", "1") - -SUPPORTED_COUNTRIES = frozenset({"us", "uk"}) - -DATASET_ALIASES: dict[str, dict[str, str]] = { - "us": { - "enhanced_cps": "enhanced_cps_2024", - "enhanced_cps_2024": "enhanced_cps_2024", - "cps_small": "cps_small_2024", - "cps_small_2024": "cps_small_2024", - "cps": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.110.12", - "cps_2023": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.110.12", - "pooled_cps": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.110.12", - "pooled_3_year_cps_2023": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.110.12", - }, - "uk": { - "enhanced_frs": "enhanced_frs_2023_24", - "enhanced_frs_2023_24": "enhanced_frs_2023_24", - "frs": "frs_2023_24", - "frs_2023_24": "frs_2023_24", - }, -} - - -@dataclass(frozen=True) -class CountryReleaseBundle: - country: str - policyengine_version: str - model_package_name: str - model_version: str - data_package_name: str - data_version: str - default_dataset: str - default_dataset_uri: str - dataset_uris: Mapping[str, str] - - -def _normalise_country(country: str) -> str: - country = country.lower() - if country not in SUPPORTED_COUNTRIES: - raise ValueError(f"Unsupported country: {country}") - return country - - -def _artifact_revision(data_package) -> str: - return data_package.release_manifest_revision or data_package.version - - -@lru_cache -def get_country_release_bundle(country: str) -> CountryReleaseBundle: - """Return package and dataset versions from policyengine.py's manifest.""" - - country = _normalise_country(country) - from policyengine.provenance.manifest import build_hf_uri, get_release_manifest - - manifest = get_release_manifest(country) - dataset_uris = { - name: build_hf_uri( - 
repo_id=manifest.data_package.repo_id, - path_in_repo=reference.path, - revision=reference.revision or _artifact_revision(manifest.data_package), - ) - for name, reference in manifest.datasets.items() - } - - return CountryReleaseBundle( - country=country, - policyengine_version=manifest.policyengine_version, - model_package_name=manifest.model_package.name, - model_version=manifest.model_package.version, - data_package_name=manifest.data_package.name, - data_version=manifest.data_package.version, - default_dataset=manifest.default_dataset, - default_dataset_uri=manifest.default_dataset_uri, - dataset_uris=dataset_uris, - ) - - -def get_bundled_country_model_version(country: str) -> str: - return get_country_release_bundle(country).model_version - - -def resolve_bundle_dataset_name(country: str, requested_data: str | None) -> str: - bundle = get_country_release_bundle(country) - if requested_data is None: - return bundle.default_dataset - - if "://" in requested_data or "@" in requested_data: - return requested_data - - requested_without_revision = requested_data.split("@", maxsplit=1)[0] - aliased = DATASET_ALIASES.get(bundle.country, {}).get( - requested_without_revision, requested_data - ) - return aliased - - -def resolve_bundle_dataset_uri(country: str, requested_data: str | None) -> str: - bundle = get_country_release_bundle(country) - dataset_name = resolve_bundle_dataset_name(country, requested_data) - if "://" in dataset_name: - return dataset_name - return bundle.dataset_uris.get(dataset_name, dataset_name) +from policyengine_api_simulation.release_bundle import * # noqa: F403 diff --git a/projects/policyengine-api-simulation/src/modal/simulation.py b/projects/policyengine-api-simulation/src/modal/simulation.py index c0527cda6..f6197e79b 100644 --- a/projects/policyengine-api-simulation/src/modal/simulation.py +++ b/projects/policyengine-api-simulation/src/modal/simulation.py @@ -1,303 +1,3 @@ -""" -Simulation implementation - pure logic with snapshotted 
imports. +"""Compatibility shim for packaged simulation helpers.""" -This module avoids importing policyengine at module level so the worker can -load the requested country module without triggering cross-country imports. -No Modal dependencies here. -""" - -import contextlib -import json -import logging -import os -import tempfile -from importlib import import_module -from typing import Any, Iterator - -from src.modal.release_bundle import resolve_bundle_dataset_name -from src.modal.simulation_output_builder import SimulationOutputBuilder -from src.modal.telemetry import split_internal_payload - -logger = logging.getLogger(__name__) - -os.environ.setdefault("POLICYENGINE_SKIP_COUNTRY_IMPORTS", "1") - -DEFAULT_YEAR = 2026 - - -def _normalize_credentials_blob(creds_json: str) -> str: - """Return the raw JSON blob, decoding the outer escape if present. - - The upstream Modal secret sometimes stores the credentials payload - double-encoded (the entire JSON object is wrapped in quotes with - backslash-escaped interior quotes). Historically we always attempted - the unescape as a fallback which could accidentally parse an already - clean blob. Only unwrap when the payload looks wrapped.""" - - try: - json.loads(creds_json) - except json.JSONDecodeError: - looks_escaped = creds_json.lstrip().startswith('"') or '\\"' in creds_json - if looks_escaped: - return json.loads(f'"{creds_json}"') - raise - return creds_json - - -@contextlib.contextmanager -def setup_gcp_credentials() -> Iterator[None]: - """ - Set up GCP credentials from environment variable. - - Modal secrets are injected as environment variables. The GCP library - expects GOOGLE_APPLICATION_CREDENTIALS to point to a file path. If - credentials JSON is provided, write it to a temp file that's deleted - on exit. 
This runs as a context manager to guarantee cleanup even if - the caller raises mid-simulation; the previous fire-and-forget - ``tempfile.mkstemp`` path leaked credential material on disk every - time a container served a request. - """ - # Log available GCP-related env vars for debugging - gcp_vars = { - k: v[:50] + "..." if len(v) > 50 else v - for k, v in os.environ.items() - if "GOOGLE" in k or "GCP" in k or "CREDENTIAL" in k - } - logger.info(f"GCP-related env vars: {list(gcp_vars.keys())}") - - # Check if credentials are already set as a file path - if os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"): - logger.info("GOOGLE_APPLICATION_CREDENTIALS already set") - yield - return - - # Check for credentials JSON in various env var names - creds_json = ( - os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_JSON") - or os.environ.get("GCP_CREDENTIALS_JSON") - or os.environ.get("GOOGLE_CREDENTIALS") - or os.environ.get("SERVICE_ACCOUNT_JSON") - ) - - if not creds_json: - logger.warning("No GCP credentials found in environment variables") - yield - return - - normalized = _normalize_credentials_blob(creds_json) - - # ``NamedTemporaryFile(delete=True)`` removes the file when the context - # exits (either normally or via exception). We restore any prior value - # of ``GOOGLE_APPLICATION_CREDENTIALS`` so a retry in the same - # container doesn't silently pick up a path that no longer exists. 
- previous = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=True - ) as creds_file: - creds_file.write(normalized) - creds_file.flush() - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = creds_file.name - logger.info(f"GCP credentials written to {creds_file.name}") - try: - yield - finally: - if previous is None: - os.environ.pop("GOOGLE_APPLICATION_CREDENTIALS", None) - else: - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = previous - - -def run_simulation_impl(params: dict) -> dict: - """ - Execute economic simulation. - - Pure implementation with no Modal dependencies. - Accepts the gateway simulation payload and returns the legacy macro result dict. - """ - # Set up GCP credentials if needed. The credentials temp file is - # cleaned up on exit so we never leave signed JSON material on disk. - with setup_gcp_credentials(): - return _run_simulation_impl_core(params) - - -def _parse_year(params: dict[str, Any]) -> int: - value = params.get("time_period") or params.get("year") or DEFAULT_YEAR - return int(value) - - -def _normalise_period_key(period_key: Any) -> str: - """Convert legacy ``start.stop`` period keys to v4 effective dates.""" - text = str(period_key) - parts = text.split(".") - if len(parts) > 1 and len(parts[0]) == 10: - return parts[0] - return text - - -def _normalise_policy(policy: dict[str, Any] | None) -> dict[str, Any] | None: - if not policy: - return None - - normalised: dict[str, Any] = {} - for parameter, value in policy.items(): - if isinstance(value, dict): - normalised[parameter] = { - _normalise_period_key(period): period_value - for period, period_value in value.items() - } - else: - normalised[parameter] = value - return normalised - - -def _resolve_dataset_name(country: str, requested_data: str | None) -> str: - return resolve_bundle_dataset_name(country, requested_data) - - -def _microframe_like(frame, weights: str): - from microdf import MicroDataFrame - - return 
MicroDataFrame(frame.copy(), weights=weights) - - -def _person_group_column(person, entity: str) -> str: - prefixed = f"person_{entity}_id" - if prefixed in person.columns: - return prefixed - return f"{entity}_id" - - -def _subsample_us_dataset(dataset, subsample: int | None): - if not subsample: - return dataset - - from policyengine.tax_benefit_models.us.datasets import ( - PolicyEngineUSDataset, - USYearData, - ) - - dataset.load() - data = dataset.data - household = data.household.head(int(subsample)).copy() - household_ids = set(household["household_id"]) - - person_household_col = _person_group_column(data.person, "household") - person = data.person[data.person[person_household_col].isin(household_ids)].copy() - - def group_subset(entity: str): - person_col = _person_group_column(person, entity) - entity_id_col = f"{entity}_id" - ids = set(person[person_col]) - frame = getattr(data, entity) - return frame[frame[entity_id_col].isin(ids)].copy() - - subset_data = USYearData( - person=_microframe_like(person, "person_weight"), - marital_unit=_microframe_like( - group_subset("marital_unit"), "marital_unit_weight" - ), - family=_microframe_like(group_subset("family"), "family_weight"), - spm_unit=_microframe_like(group_subset("spm_unit"), "spm_unit_weight"), - tax_unit=_microframe_like(group_subset("tax_unit"), "tax_unit_weight"), - household=_microframe_like(household, "household_weight"), - ) - subset_path = os.path.join( - os.environ.get("POLICYENGINE_DATA_FOLDER", "/tmp/policyengine-data"), - f"{dataset.id}_subsample_{subsample}.h5", - ) - return PolicyEngineUSDataset( - id=f"{dataset.id}_subsample_{subsample}", - name=f"{dataset.name} subsample {subsample}", - description=dataset.description, - filepath=subset_path, - year=dataset.year, - is_output_dataset=dataset.is_output_dataset, - metadata=getattr(dataset, "metadata", {}), - metadata_filepath=getattr(dataset, "metadata_filepath", None), - data=subset_data, - ) - - -def _country_module(country: str): - 
country = country.lower() - if country not in {"us", "uk"}: - raise ValueError(f"Unsupported country: {country}") - - return import_module(f"policyengine.tax_benefit_models.{country}") - - -def _load_dataset(params: dict[str, Any]): - country = params.get("country", "us").lower() - year = _parse_year(params) - country_module = _country_module(country) - dataset_name = _resolve_dataset_name(country, params.get("data")) - datasets = country_module.ensure_datasets( - datasets=[dataset_name], - years=[year], - data_folder=os.environ.get( - "POLICYENGINE_DATA_FOLDER", "/tmp/policyengine-data" - ), - ) - dataset = next(iter(datasets.values())) - if country == "us": - return _subsample_us_dataset(dataset, params.get("subsample")) - return dataset - - -def _build_simulation( - params: dict[str, Any], - *, - dataset, - policy: dict[str, Any] | None, -): - from policyengine.core import Simulation - - country_module = _country_module(params.get("country", "us")) - return Simulation( - dataset=dataset, - tax_benefit_model_version=country_module.model, - policy=policy, - ) - - -def _run_simulation_impl_core(params: dict) -> dict: - simulation_params, telemetry, metadata = split_internal_payload(params) - - logger.info( - "Starting simulation for country=%s run_id=%s process_id=%s", - simulation_params.get("country", "unknown"), - getattr(telemetry, "run_id", None), - getattr(telemetry, "process_id", None), - ) - if metadata: - logger.info("Received simulation metadata keys: %s", sorted(metadata)) - - country = simulation_params.get("country", "us").lower() - country_module = _country_module(country) - dataset = _load_dataset(simulation_params) - baseline_policy = _normalise_policy(simulation_params.get("baseline")) - reform_policy = _normalise_policy(simulation_params.get("reform")) - - logger.info("Initialising baseline and reform simulations") - baseline = _build_simulation( - simulation_params, - dataset=dataset, - policy=baseline_policy, - ) - reform = _build_simulation( - 
simulation_params, - dataset=dataset, - policy=reform_policy, - ) - - logger.info("Calculating economic impact") - output = SimulationOutputBuilder( - country=country, - simulation_params=simulation_params, - country_module=country_module, - dataset=dataset, - baseline=baseline, - reform=reform, - ).serialize() - logger.info("Comparison complete") - return output +from policyengine_api_simulation.simulation_runtime import * # noqa: F403 diff --git a/projects/policyengine-api-simulation/src/modal/simulation_macro_output.py b/projects/policyengine-api-simulation/src/modal/simulation_macro_output.py index 7ef198b71..e409982bb 100644 --- a/projects/policyengine-api-simulation/src/modal/simulation_macro_output.py +++ b/projects/policyengine-api-simulation/src/modal/simulation_macro_output.py @@ -1,134 +1,3 @@ -"""Internal schemas for the simulation API single-year macro output. +"""Compatibility shim for packaged simulation helpers.""" -These models define the legacy dictionary contract the simulation API returns -without exposing that schema through the gateway OpenAPI surface. The gateway -still treats job results as unstructured dictionaries for older callers. 
-""" - -from __future__ import annotations - -from typing import Any, Generic, TypeVar - -from pydantic import BaseModel, ConfigDict, RootModel - -T = TypeVar("T") - - -class MacroOutputModel(BaseModel): - """Base model for internal macro output schemas.""" - - model_config = ConfigDict(extra="forbid") - - -class MacroRootModel(RootModel[T], Generic[T]): - """Base model for internal root schemas that dump to dict/list values.""" - - -class BudgetaryImpact(MacroOutputModel): - tax_revenue_impact: float - state_tax_revenue_impact: float - benefit_spending_impact: float - budgetary_impact: float - households: float - baseline_net_income: float - - -BudgetaryOutput = BudgetaryImpact - - -class DetailedBudgetProgramOutput(MacroOutputModel): - baseline: float - reform: float - difference: float - - -class DetailedBudgetOutput(MacroRootModel[dict[str, DetailedBudgetProgramOutput]]): - pass - - -class DecileOutput(MacroOutputModel): - average: dict[str, float] - relative: dict[str, float] - - -class IntraDecileOutput(MacroOutputModel): - deciles: dict[str, list[float]] - all: dict[str, float] - - -class BaselineReformValue(MacroOutputModel): - baseline: float - reform: float - - -class AgePovertyOutput(MacroOutputModel): - child: BaselineReformValue - adult: BaselineReformValue - senior: BaselineReformValue - all: BaselineReformValue - - -class GenderPovertyOutput(MacroOutputModel): - male: BaselineReformValue - female: BaselineReformValue - - -class RacePovertyOutput(MacroOutputModel): - white: BaselineReformValue - black: BaselineReformValue - hispanic: BaselineReformValue - other: BaselineReformValue - - -class PovertyOutput(MacroOutputModel): - poverty: AgePovertyOutput - deep_poverty: AgePovertyOutput - - -class PovertyByGenderOutput(MacroOutputModel): - poverty: GenderPovertyOutput - deep_poverty: GenderPovertyOutput - - -class PovertyByRaceOutput(MacroOutputModel): - poverty: RacePovertyOutput - - -class PovertyModuleOutputs(MacroOutputModel): - poverty: 
PovertyOutput - poverty_by_gender: PovertyByGenderOutput - poverty_by_race: PovertyByRaceOutput | None - - -class InequalityOutput(MacroOutputModel): - gini: BaselineReformValue - top_10_pct_share: BaselineReformValue - top_1_pct_share: BaselineReformValue - - -class LaborSupplyResponseOutput(MacroRootModel[dict[str, Any]]): - pass - - -class GeographicImpactOutput(MacroRootModel[list[dict[str, Any]]]): - pass - - -class SingleYearMacroOutput(MacroOutputModel): - model_version: str - data_version: str - budget: BudgetaryImpact - detailed_budget: DetailedBudgetOutput - decile: DecileOutput - inequality: InequalityOutput - poverty: PovertyOutput - poverty_by_gender: PovertyByGenderOutput - poverty_by_race: PovertyByRaceOutput | None - intra_decile: IntraDecileOutput - wealth_decile: DecileOutput | None - intra_wealth_decile: IntraDecileOutput | None - labor_supply_response: LaborSupplyResponseOutput | None - constituency_impact: GeographicImpactOutput | None - local_authority_impact: GeographicImpactOutput | None - congressional_district_impact: GeographicImpactOutput | None - cliff_impact: None = None +from policyengine_api_simulation.simulation_macro_output import * # noqa: F403 diff --git a/projects/policyengine-api-simulation/src/modal/simulation_output_builder.py b/projects/policyengine-api-simulation/src/modal/simulation_output_builder.py index 2b4c59aaa..8334f2b38 100644 --- a/projects/policyengine-api-simulation/src/modal/simulation_output_builder.py +++ b/projects/policyengine-api-simulation/src/modal/simulation_output_builder.py @@ -1,658 +1,3 @@ -"""Build and serialize the runtime simulation macro output.""" +"""Compatibility shim for packaged simulation helpers.""" -from __future__ import annotations - -import logging -import math -from collections.abc import Iterable, Mapping -from dataclasses import dataclass, field -from importlib import import_module -from typing import Any - -from src.modal.release_bundle import get_country_release_bundle -from 
src.modal.simulation_macro_output import ( - AgePovertyOutput, - BaselineReformValue, - BudgetaryImpact, - DecileOutput, - DetailedBudgetOutput, - DetailedBudgetProgramOutput, - GeographicImpactOutput, - GenderPovertyOutput, - InequalityOutput, - IntraDecileOutput, - LaborSupplyResponseOutput, - PovertyModuleOutputs, - PovertyByGenderOutput, - PovertyByRaceOutput, - PovertyOutput, - RacePovertyOutput, - SingleYearMacroOutput, -) - -logger = logging.getLogger(__name__) - -INTRA_DECILE_COLUMNS = { - "Lose more than 5%": "lose_more_than_5pct", - "Lose less than 5%": "lose_less_than_5pct", - "No change": "no_change", - "Gain less than 5%": "gain_less_than_5pct", - "Gain more than 5%": "gain_more_than_5pct", -} - -US_POVERTY_TYPES = { - "spm": "poverty", - "spm_deep": "deep_poverty", -} - -UK_POVERTY_TYPES = { - "relative_bhc": "poverty", - "absolute_bhc": "deep_poverty", -} - - -def _number(value: Any, default: float = 0.0) -> float: - if value is None: - return default - try: - result = float(value) - except (TypeError, ValueError): - return default - if math.isnan(result) or math.isinf(result): - return default - return result - - -def _collection_records(collection: Any) -> list[dict[str, Any]]: - if collection is None: - return [] - dataframe = getattr(collection, "dataframe", None) - if dataframe is not None: - return list(dataframe.to_dict("records")) - if isinstance(collection, list): - return [dict(item) for item in collection if isinstance(item, Mapping)] - return [] - - -def _output_model_dump(value: Any) -> Any: - if value is None: - return None - if hasattr(value, "model_dump"): - return value.model_dump(mode="json") - if isinstance(value, Mapping): - return dict(value) - return None - - -def _empty_baseline_reform_value() -> dict[str, float]: - return {"baseline": 0.0, "reform": 0.0} - - -def _empty_age_poverty() -> dict[str, dict[str, float]]: - return { - "child": _empty_baseline_reform_value(), - "adult": _empty_baseline_reform_value(), - "senior": 
_empty_baseline_reform_value(), - "all": _empty_baseline_reform_value(), - } - - -def _empty_gender_poverty() -> dict[str, dict[str, float]]: - return { - "male": _empty_baseline_reform_value(), - "female": _empty_baseline_reform_value(), - } - - -def _poverty_type(country: str, row: Mapping[str, Any]) -> str | None: - poverty_type = str(row.get("poverty_type") or "").lower() - if country == "us": - return US_POVERTY_TYPES.get(poverty_type) - return UK_POVERTY_TYPES.get(poverty_type) - - -def _fill_poverty_block( - *, - country: str, - output: dict[str, dict[str, dict[str, float]]], - baseline_records: Iterable[Mapping[str, Any]], - reform_records: Iterable[Mapping[str, Any]], - default_group: str, -) -> None: - for side, records in (("baseline", baseline_records), ("reform", reform_records)): - for row in records: - poverty_type = _poverty_type(country, row) - if poverty_type is None: - continue - if poverty_type not in output: - continue - group = str(row.get("filter_group") or default_group).lower() - if group not in output[poverty_type]: - continue - output[poverty_type][group][side] = _number(row.get("rate")) - - -def _age_poverty_output(values: dict[str, dict[str, float]]) -> AgePovertyOutput: - return AgePovertyOutput( - child=BaselineReformValue(**values["child"]), - adult=BaselineReformValue(**values["adult"]), - senior=BaselineReformValue(**values["senior"]), - all=BaselineReformValue(**values["all"]), - ) - - -def _gender_poverty_output( - values: dict[str, dict[str, float]], -) -> GenderPovertyOutput: - return GenderPovertyOutput( - male=BaselineReformValue(**values["male"]), - female=BaselineReformValue(**values["female"]), - ) - - -def _race_poverty_output(values: dict[str, dict[str, float]]) -> RacePovertyOutput: - return RacePovertyOutput( - white=BaselineReformValue(**values["white"]), - black=BaselineReformValue(**values["black"]), - hispanic=BaselineReformValue(**values["hispanic"]), - other=BaselineReformValue(**values["other"]), - ) - - -def 
_entity_data(simulation, entity: str): - if simulation.output_dataset is None or simulation.output_dataset.data is None: - simulation.ensure() - return getattr(simulation.output_dataset.data, entity) - - -def _sum_output_variable(simulation, variable: str, entity: str) -> float: - data = _entity_data(simulation, entity) - if variable in data.columns: - return float(data[variable].sum()) - - from policyengine.outputs import Aggregate, AggregateType - - output = Aggregate( - simulation=simulation, - variable=variable, - entity=entity, - aggregate_type=AggregateType.SUM, - ) - output.run() - return float(output.result) - - -def _try_sum_output_variable(simulation, variable: str, entity: str) -> float: - try: - return _sum_output_variable(simulation, variable, entity) - except Exception: - logger.warning("Unable to calculate sum for %s", variable, exc_info=True) - return 0.0 - - -def _change_output_variable(baseline, reform, variable: str, entity: str) -> float: - baseline_data = _entity_data(baseline, entity) - reform_data = _entity_data(reform, entity) - if variable in baseline_data.columns and variable in reform_data.columns: - return float((reform_data[variable] - baseline_data[variable]).sum()) - - from policyengine.outputs import ChangeAggregate, ChangeAggregateType - - output = ChangeAggregate( - baseline_simulation=baseline, - reform_simulation=reform, - variable=variable, - entity=entity, - aggregate_type=ChangeAggregateType.SUM, - ) - output.run() - return float(output.result) - - -def _try_change_output_variable(baseline, reform, variable: str, entity: str) -> float: - try: - return _change_output_variable(baseline, reform, variable, entity) - except Exception: - logger.warning("Unable to calculate change for %s", variable, exc_info=True) - return 0.0 - - -def _output_module_function(module_name: str, name: str): - module = import_module(f"policyengine.outputs.{module_name}") - return getattr(module, name) - - -def _poverty_module_function(name: str): - 
return _output_module_function("poverty", name) - - -def _try_compute_output(label: str, fn, *args, **kwargs): - try: - return fn(*args, **kwargs) - except Exception: - logger.warning("Unable to calculate %s", label, exc_info=True) - return None - - -@dataclass -class SimulationOutputBuilder: - country: str - simulation_params: dict[str, Any] - country_module: Any - dataset: Any - baseline: Any - reform: Any - _analysis: Any = field(default=None, init=False) - - def __post_init__(self) -> None: - self.country = self.country.lower() - - @property - def analysis(self) -> Any: - if self._analysis is None: - self._analysis = self.country_module.economic_impact_analysis( - self.baseline, self.reform - ) - return self._analysis - - def build(self) -> SingleYearMacroOutput: - poverty_outputs = self._build_poverty_outputs() - wealth_decile = getattr(self.analysis, "wealth_decile_impacts", None) - intra_wealth_decile = getattr( - self.analysis, "intra_wealth_decile_impacts", None - ) - - return SingleYearMacroOutput( - model_version=self._model_version(), - data_version=self._data_version(), - budget=self._build_budgetary_impact(), - detailed_budget=self._build_detailed_budget(), - decile=self._build_decile(), - inequality=self._build_inequality(), - poverty=poverty_outputs.poverty, - poverty_by_gender=poverty_outputs.poverty_by_gender, - poverty_by_race=poverty_outputs.poverty_by_race, - intra_decile=self._build_intra_decile_output(), - wealth_decile=self._build_wealth_decile(wealth_decile), - intra_wealth_decile=self._build_intra_wealth_decile(intra_wealth_decile), - labor_supply_response=self._build_labor_supply_response(), - congressional_district_impact=(self._build_congressional_district_impact()), - constituency_impact=self._build_uk_constituency_impact(), - local_authority_impact=self._build_uk_local_authority_impact(), - cliff_impact=None, - ) - - def serialize(self) -> dict[str, Any]: - return self.build().model_dump(mode="json") - - def 
_build_detailed_budget(self) -> DetailedBudgetOutput: - collection = getattr(self.analysis, "program_statistics", None) - if isinstance(collection, DetailedBudgetOutput): - return collection - detailed_budget: dict[str, DetailedBudgetProgramOutput] = {} - for row in _collection_records(collection): - program_name = row.get("program_name") - if not program_name: - continue - baseline = _number(row.get("baseline_total")) - reform = _number(row.get("reform_total")) - detailed_budget[str(program_name)] = DetailedBudgetProgramOutput( - baseline=baseline, - reform=reform, - difference=_number(row.get("change"), reform - baseline), - ) - return DetailedBudgetOutput(detailed_budget) - - def _build_decile(self) -> DecileOutput: - return self._build_decile_output(getattr(self.analysis, "decile_impacts", None)) - - def _build_inequality(self) -> InequalityOutput: - baseline = getattr(self.analysis, "baseline_inequality", None) - reform = getattr(self.analysis, "reform_inequality", None) - if isinstance(baseline, InequalityOutput): - return baseline - return InequalityOutput( - gini=BaselineReformValue( - baseline=_number(getattr(baseline, "gini", None)), - reform=_number(getattr(reform, "gini", None)), - ), - top_10_pct_share=BaselineReformValue( - baseline=_number(getattr(baseline, "top_10_share", None)), - reform=_number(getattr(reform, "top_10_share", None)), - ), - top_1_pct_share=BaselineReformValue( - baseline=_number(getattr(baseline, "top_1_share", None)), - reform=_number(getattr(reform, "top_1_share", None)), - ), - ) - - def _build_budgetary_impact(self) -> BudgetaryImpact: - tax_revenue_impact = _change_output_variable( - self.baseline, self.reform, "household_tax", entity="household" - ) - benefit_spending_impact = _change_output_variable( - self.baseline, self.reform, "household_benefits", entity="household" - ) - state_tax_revenue_impact = ( - _change_output_variable( - self.baseline, - self.reform, - "household_state_income_tax", - entity="household", - ) - if 
self.country == "us" - else 0.0 - ) - - return BudgetaryImpact( - tax_revenue_impact=tax_revenue_impact, - state_tax_revenue_impact=state_tax_revenue_impact, - benefit_spending_impact=benefit_spending_impact, - budgetary_impact=tax_revenue_impact - benefit_spending_impact, - households=_sum_output_variable( - self.baseline, "household_weight", entity="household" - ), - baseline_net_income=_sum_output_variable( - self.baseline, "household_net_income", entity="household" - ), - ) - - def _build_poverty_outputs(self) -> PovertyModuleOutputs: - prefix = "us" if self.country == "us" else "uk" - baseline_poverty_by_age = _try_compute_output( - "baseline poverty by age", - _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), - self.baseline, - ) - reform_poverty_by_age = _try_compute_output( - "reform poverty by age", - _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), - self.reform, - ) - baseline_poverty_by_gender = _try_compute_output( - "baseline poverty by gender", - _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), - self.baseline, - ) - reform_poverty_by_gender = _try_compute_output( - "reform poverty by gender", - _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), - self.reform, - ) - baseline_poverty_by_race = None - reform_poverty_by_race = None - if self.country == "us": - baseline_poverty_by_race = _try_compute_output( - "baseline poverty by race", - _poverty_module_function("calculate_us_poverty_by_race"), - self.baseline, - ) - reform_poverty_by_race = _try_compute_output( - "reform poverty by race", - _poverty_module_function("calculate_us_poverty_by_race"), - self.reform, - ) - return PovertyModuleOutputs( - poverty=self._build_poverty_output( - baseline=getattr(self.analysis, "baseline_poverty", None), - reform=getattr(self.analysis, "reform_poverty", None), - baseline_by_age=baseline_poverty_by_age, - reform_by_age=reform_poverty_by_age, - ), - 
poverty_by_gender=self._build_poverty_by_gender_output( - baseline_by_gender=baseline_poverty_by_gender, - reform_by_gender=reform_poverty_by_gender, - ), - poverty_by_race=( - self._build_poverty_by_race_output( - baseline_by_race=baseline_poverty_by_race, - reform_by_race=reform_poverty_by_race, - ) - if self.country == "us" - else None - ), - ) - - def _build_intra_decile_output(self) -> IntraDecileOutput: - from policyengine.outputs.intra_decile_impact import ( - compute_intra_decile_impacts, - ) - - collection = _try_compute_output( - "intra-decile impacts", - compute_intra_decile_impacts, - self.baseline, - self.reform, - income_variable="household_net_income", - entity="household", - ) - return self._build_intra_decile_output_from_collection(collection) - - def _build_wealth_decile(self, wealth_decile) -> DecileOutput | None: - if self.country != "uk": - return None - return self._build_decile_output(wealth_decile) - - def _build_intra_wealth_decile( - self, intra_wealth_decile - ) -> IntraDecileOutput | None: - if self.country != "uk": - return None - return self._build_intra_decile_output_from_collection(intra_wealth_decile) - - def _build_labor_supply_response(self) -> LaborSupplyResponseOutput | None: - labor_supply_response = getattr(self.analysis, "labor_supply_response", None) - if isinstance(labor_supply_response, LaborSupplyResponseOutput): - return labor_supply_response - output = _output_model_dump(labor_supply_response) - return LaborSupplyResponseOutput(output) if isinstance(output, dict) else None - - def _build_geographic_impact_output( - self, value: Any - ) -> GeographicImpactOutput | None: - if isinstance(value, GeographicImpactOutput): - return value - records = _output_model_dump(value) - if isinstance(records, list): - return GeographicImpactOutput( - [dict(item) for item in records if isinstance(item, Mapping)] - ) - if isinstance(value, list): - return GeographicImpactOutput( - [dict(item) for item in value if isinstance(item, 
Mapping)] - ) - return None - - def _build_decile_output(self, collection: Any) -> DecileOutput: - if isinstance(collection, DecileOutput): - return collection - average: dict[str, float] = {} - relative: dict[str, float] = {} - for row in sorted( - _collection_records(collection), - key=lambda item: _number(item.get("decile")), - ): - decile = int(_number(row.get("decile"))) - if decile <= 0: - continue - key = str(decile) - average[key] = _number(row.get("absolute_change")) - relative[key] = _number(row.get("relative_change")) - return DecileOutput(average=average, relative=relative) - - def _build_intra_decile_output_from_collection( - self, collection: Any - ) -> IntraDecileOutput: - if isinstance(collection, IntraDecileOutput): - return collection - deciles: dict[str, list[float]] = {label: [] for label in INTRA_DECILE_COLUMNS} - all_values: dict[str, float] = {label: 0.0 for label in INTRA_DECILE_COLUMNS} - rows = [ - row - for row in sorted( - _collection_records(collection), - key=lambda item: _number(item.get("decile")), - ) - if int(_number(row.get("decile"))) > 0 - ] - - for label, column in INTRA_DECILE_COLUMNS.items(): - values = [_number(row.get(column)) for row in rows] - deciles[label] = values - all_values[label] = sum(values) / len(values) if values else 0.0 - return IntraDecileOutput(deciles=deciles, all=all_values) - - def _build_poverty_output( - self, - *, - baseline: Any, - reform: Any, - baseline_by_age: Any, - reform_by_age: Any, - ) -> PovertyOutput: - if isinstance(baseline, PovertyOutput): - return baseline - result = { - "poverty": _empty_age_poverty(), - "deep_poverty": _empty_age_poverty(), - } - _fill_poverty_block( - country=self.country, - output=result, - baseline_records=_collection_records(baseline), - reform_records=_collection_records(reform), - default_group="all", - ) - _fill_poverty_block( - country=self.country, - output=result, - baseline_records=_collection_records(baseline_by_age), - 
reform_records=_collection_records(reform_by_age), - default_group="all", - ) - return PovertyOutput( - poverty=_age_poverty_output(result["poverty"]), - deep_poverty=_age_poverty_output(result["deep_poverty"]), - ) - - def _build_poverty_by_gender_output( - self, - *, - baseline_by_gender: Any, - reform_by_gender: Any, - ) -> PovertyByGenderOutput: - if isinstance(baseline_by_gender, PovertyByGenderOutput): - return baseline_by_gender - result = { - "poverty": _empty_gender_poverty(), - "deep_poverty": _empty_gender_poverty(), - } - _fill_poverty_block( - country=self.country, - output=result, - baseline_records=_collection_records(baseline_by_gender), - reform_records=_collection_records(reform_by_gender), - default_group="all", - ) - return PovertyByGenderOutput( - poverty=_gender_poverty_output(result["poverty"]), - deep_poverty=_gender_poverty_output(result["deep_poverty"]), - ) - - def _build_poverty_by_race_output( - self, - *, - baseline_by_race: Any, - reform_by_race: Any, - ) -> PovertyByRaceOutput: - if isinstance(baseline_by_race, PovertyByRaceOutput): - return baseline_by_race - result = { - "poverty": { - "white": _empty_baseline_reform_value(), - "black": _empty_baseline_reform_value(), - "hispanic": _empty_baseline_reform_value(), - "other": _empty_baseline_reform_value(), - } - } - _fill_poverty_block( - country="us", - output=result, - baseline_records=_collection_records(baseline_by_race), - reform_records=_collection_records(reform_by_race), - default_group="all", - ) - return PovertyByRaceOutput(poverty=_race_poverty_output(result["poverty"])) - - def _build_congressional_district_impact( - self, - ) -> GeographicImpactOutput | None: - if self.country != "us": - return None - - from policyengine.outputs.congressional_district_impact import ( - compute_us_congressional_district_impacts, - ) - - impact = _try_compute_output( - "congressional district impacts", - compute_us_congressional_district_impacts, - self.baseline, - self.reform, - ) - 
return self._build_geographic_impact_output( - getattr(impact, "district_results", None) if impact is not None else None - ) - - def _build_uk_constituency_impact(self) -> GeographicImpactOutput | None: - if self.country != "uk": - return None - - impact = _try_compute_output( - "constituency impacts", - _output_module_function( - "constituency_impact", "compute_uk_constituency_impacts" - ), - self.baseline, - self.reform, - ) - if impact is None: - return None - return self._build_geographic_impact_output( - getattr(impact, "constituency_results", None) - ) - - def _build_uk_local_authority_impact(self) -> GeographicImpactOutput | None: - if self.country != "uk": - return None - - impact = _try_compute_output( - "local authority impacts", - _output_module_function( - "local_authority_impact", "compute_uk_local_authority_impacts" - ), - self.baseline, - self.reform, - ) - if impact is None: - return None - return self._build_geographic_impact_output( - getattr(impact, "local_authority_results", None) - ) - - def _model_version(self) -> str: - return str(getattr(self.country_module.model, "version", "")) - - def _data_version(self) -> str: - if self.simulation_params.get("data_version"): - return str(self.simulation_params["data_version"]) - try: - return get_country_release_bundle(self.country).data_version - except ValueError: - pass - metadata = getattr(self.dataset, "metadata", {}) or {} - for key in ("data_version", "version"): - value = metadata.get(key) - if value is not None: - return str(value) - return "" +from policyengine_api_simulation.simulation_output_builder import * # noqa: F403 diff --git a/projects/policyengine-api-simulation/src/modal/telemetry.py b/projects/policyengine-api-simulation/src/modal/telemetry.py index dccb418eb..ea120cb12 100644 --- a/projects/policyengine-api-simulation/src/modal/telemetry.py +++ b/projects/policyengine-api-simulation/src/modal/telemetry.py @@ -1,47 +1,3 @@ -""" -Internal telemetry helpers for Modal request 
passthrough. -""" +"""Compatibility shim for packaged simulation helpers.""" -from __future__ import annotations - -from datetime import datetime -from typing import Any, Literal - -from pydantic import BaseModel, ConfigDict - - -CaptureMode = Literal["disabled", "failures", "threshold", "sampled", "always"] - - -class TelemetryEnvelope(BaseModel): - """Minimal shared telemetry payload shape for gateway and worker code.""" - - run_id: str - process_id: str | None = None - request_id: str | None = None - traceparent: str | None = None - requested_at: datetime | None = None - simulation_kind: str | None = None - geography_code: str | None = None - geography_type: str | None = None - config_hash: str | None = None - capture_mode: CaptureMode = "disabled" - - model_config = ConfigDict(extra="forbid") - - -def split_internal_payload( - params: dict[str, Any], -) -> tuple[dict[str, Any], TelemetryEnvelope | None, dict[str, Any] | None]: - """Strip internal passthrough fields before SimulationOptions validation.""" - - simulation_params = dict(params) - raw_telemetry = simulation_params.pop("_telemetry", None) - raw_metadata = simulation_params.pop("_metadata", None) - - telemetry = None - if raw_telemetry is not None: - telemetry = TelemetryEnvelope.model_validate(raw_telemetry) - - metadata = raw_metadata if isinstance(raw_metadata, dict) else None - return simulation_params, telemetry, metadata +from policyengine_api_simulation.telemetry import * # noqa: F403 diff --git a/projects/policyengine-api-simulation/src/modal/utils/extract_bundle_versions.py b/projects/policyengine-api-simulation/src/modal/utils/extract_bundle_versions.py index 830eb0654..48f02eba3 100644 --- a/projects/policyengine-api-simulation/src/modal/utils/extract_bundle_versions.py +++ b/projects/policyengine-api-simulation/src/modal/utils/extract_bundle_versions.py @@ -6,7 +6,7 @@ import sys from pathlib import Path -from src.modal.release_bundle import get_country_release_bundle +from 
policyengine_api_simulation.release_bundle import get_country_release_bundle def _bundle_outputs() -> dict[str, str]: diff --git a/projects/policyengine-api-simulation/src/modal/utils/update_version_registry.py b/projects/policyengine-api-simulation/src/modal/utils/update_version_registry.py index 9797887bb..cc91e5ba7 100644 --- a/projects/policyengine-api-simulation/src/modal/utils/update_version_registry.py +++ b/projects/policyengine-api-simulation/src/modal/utils/update_version_registry.py @@ -1,8 +1,10 @@ """ Update Modal version registries after deployment. -Each deployment creates a versioned app (e.g., policyengine-simulation-us1-459-0-uk2-65-9). -This script updates the version dicts to map package versions to app names. +Each deployment creates a versioned policyengine.py app (e.g., +policyengine-simulation-py4-10-0). This script updates the version dicts to map +the policyengine.py version and the bundled country package versions to that +app name. The dicts allow the gateway to route requests for specific versions to the correct app. Multiple versions can coexist - old deployments remain accessible via their version numbers. @@ -12,7 +14,8 @@ Usage: uv run python -m src.modal.utils.update_version_registry \ - --app-name policyengine-simulation-us1-459-0-uk2-65-9 \ + --app-name policyengine-simulation-py4-10-0 \ + --policyengine-version 4.10.0 \ --us-version 1.459.0 \ --uk-version 2.65.9 \ --environment staging @@ -22,6 +25,11 @@ import modal from packaging.version import InvalidVersion, Version +POLICYENGINE_VERSION_DICT_NAME = "simulation-api-policyengine-versions" +US_VERSION_DICT_NAME = "simulation-api-us-versions" +UK_VERSION_DICT_NAME = "simulation-api-uk-versions" +APP_RELEASE_BUNDLES_DICT_NAME = "simulation-api-app-release-bundles" + def _is_newer_version(candidate: str, current: str | None) -> bool: """Return True when ``candidate`` should replace ``current`` as 'latest'. 
@@ -108,6 +116,59 @@ def update_version_dict( ) +def _country_bundle_metadata(country: str) -> dict: + from policyengine_api_simulation.release_bundle import ( + DATASET_ALIASES, + get_country_release_bundle, + ) + + bundle = get_country_release_bundle(country) + return { + "country": bundle.country, + "model_package_name": bundle.model_package_name, + "model_version": bundle.model_version, + "data_package_name": bundle.data_package_name, + "data_version": bundle.data_version, + "default_dataset": bundle.default_dataset, + "default_dataset_uri": bundle.default_dataset_uri, + "dataset_uris": dict(bundle.dataset_uris), + "dataset_aliases": dict(DATASET_ALIASES.get(bundle.country, {})), + } + + +def build_app_release_bundle_metadata( + *, + app_name: str, + policyengine_version: str, +) -> dict: + return { + "app_name": app_name, + "policyengine_version": policyengine_version, + "us": _country_bundle_metadata("us"), + "uk": _country_bundle_metadata("uk"), + } + + +def put_app_release_bundle_metadata( + *, + environment: str, + app_name: str, + policyengine_version: str, +) -> None: + bundle_store = modal.Dict.from_name( + APP_RELEASE_BUNDLES_DICT_NAME, + environment_name=environment, + create_if_missing=True, + ) + metadata = build_app_release_bundle_metadata( + app_name=app_name, + policyengine_version=policyengine_version, + ) + bundle_store[app_name] = metadata + bundle_store[policyengine_version] = metadata + print(f" {APP_RELEASE_BUNDLES_DICT_NAME}[{app_name}]: updated") + + def main(): parser = argparse.ArgumentParser( description="Update version registries after Modal deployment" @@ -115,7 +176,12 @@ def main(): parser.add_argument( "--app-name", required=True, - help="Versioned app name (e.g., policyengine-simulation-us1-459-0-uk2-65-9)", + help="Versioned app name (e.g., policyengine-simulation-py4-10-0)", + ) + parser.add_argument( + "--policyengine-version", + required=True, + help="policyengine.py package version (e.g., 4.10.0)", ) parser.add_argument( 
"--us-version", @@ -144,14 +210,25 @@ def main(): print(f"Updating version registries in Modal environment: {args.environment}") print(f" App name: {args.app_name}") + print(f" policyengine.py version: {args.policyengine_version}") print(f" US version: {args.us_version}") print(f" UK version: {args.uk_version}") print() + print("policyengine.py version registry:") + update_version_dict( + POLICYENGINE_VERSION_DICT_NAME, + args.environment, + args.policyengine_version, + args.app_name, + force_latest=args.force_latest, + ) + print() + # Update US registry print("US version registry:") update_version_dict( - "simulation-api-us-versions", + US_VERSION_DICT_NAME, args.environment, args.us_version, args.app_name, @@ -162,7 +239,7 @@ def main(): # Update UK registry print("UK version registry:") update_version_dict( - "simulation-api-uk-versions", + UK_VERSION_DICT_NAME, args.environment, args.uk_version, args.app_name, @@ -170,6 +247,14 @@ def main(): ) print() + print("App release bundle metadata:") + put_app_release_bundle_metadata( + environment=args.environment, + app_name=args.app_name, + policyengine_version=args.policyengine_version, + ) + print() + print("Version registries updated successfully.") diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/compat_models.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/compat_models.py index e69112943..4678466ad 100644 --- a/projects/policyengine-api-simulation/src/policyengine_api_simulation/compat_models.py +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/compat_models.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict -from src.modal.simulation_macro_output import SingleYearMacroOutput +from policyengine_api_simulation.simulation_macro_output import SingleYearMacroOutput class SimulationOptions(BaseModel): diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/hf_dataset.py 
b/projects/policyengine-api-simulation/src/policyengine_api_simulation/hf_dataset.py new file mode 100644 index 000000000..407803240 --- /dev/null +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/hf_dataset.py @@ -0,0 +1,150 @@ +"""Hugging Face dataset reference helpers. + +The gateway image is intentionally small and does not install +``huggingface_hub``. These helpers use the same REST endpoint as +``HfApi.dataset_info(repo_id, revision=...)`` so both gateway and worker +code can validate explicit dataset revisions without adding another runtime +dependency to the gateway. +""" + +from __future__ import annotations + +import json +import os +from dataclasses import dataclass +from functools import lru_cache +from typing import Any +from urllib.error import HTTPError, URLError +from urllib.parse import quote +from urllib.request import Request, urlopen + +HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co").rstrip("/") +HF_REQUEST_TIMEOUT_SECONDS = 30 +HF_TOKEN_ENV_VARS = ( + "HF_TOKEN", + "HUGGING_FACE_HUB_TOKEN", + "HUGGINGFACE_HUB_TOKEN", + "HUGGINGFACE_TOKEN", +) + + +class HuggingFaceDatasetReferenceError(ValueError): + """Raised when a Hugging Face dataset reference is invalid.""" + + +@dataclass(frozen=True) +class HFDatasetReference: + repo_id: str + path: str + revision: str | None + + +def _hf_token() -> str | None: + for env_name in HF_TOKEN_ENV_VARS: + value = os.environ.get(env_name) + if value: + return value + return None + + +def parse_hf_dataset_uri(dataset_uri: str) -> HFDatasetReference | None: + """Parse an ``hf://`` dataset artifact URI. + + PolicyEngine release manifests use ``hf://org/repo/path@revision``. The + Hub API needs ``org/repo`` and ``revision`` separately, while path + validation needs the artifact path within the repo. 
+ """ + + if not dataset_uri.startswith("hf://"): + return None + + without_scheme = dataset_uri.removeprefix("hf://") + path_with_repo, revision = ( + without_scheme.rsplit("@", maxsplit=1) + if "@" in without_scheme + else (without_scheme, None) + ) + parts = path_with_repo.split("/", maxsplit=2) + if len(parts) != 3 or not all(parts): + raise HuggingFaceDatasetReferenceError( + f"Invalid Hugging Face dataset URI: {dataset_uri!r}" + ) + return HFDatasetReference( + repo_id=f"{parts[0]}/{parts[1]}", + path=parts[2], + revision=revision, + ) + + +@lru_cache +def _fetch_hf_dataset_revision( + repo_id: str, + revision: str, + token: str | None, +) -> dict[str, Any]: + url = ( + f"{HF_ENDPOINT}/api/datasets/" + f"{quote(repo_id, safe='/')}/revision/{quote(revision, safe='')}" + ) + headers = {"Accept": "application/json"} + if token is not None: + headers["Authorization"] = f"Bearer {token}" + + request = Request(url, headers=headers) + try: + with urlopen(request, timeout=HF_REQUEST_TIMEOUT_SECONDS) as response: + return json.loads(response.read().decode("utf-8")) + except HTTPError as exc: + detail = exc.reason or f"HTTP {exc.code}" + raise HuggingFaceDatasetReferenceError( + f"Hugging Face dataset revision {repo_id}@{revision} was not found: " + f"{detail}" + ) from exc + except (OSError, URLError, json.JSONDecodeError) as exc: + raise HuggingFaceDatasetReferenceError( + f"Unable to validate Hugging Face dataset revision " + f"{repo_id}@{revision}: {exc}" + ) from exc + + +def _siblings_contain_path(payload: dict[str, Any], path: str) -> bool | None: + siblings = payload.get("siblings") + if not isinstance(siblings, list): + return None + + seen_file_listing = False + for sibling in siblings: + if not isinstance(sibling, dict): + continue + name = sibling.get("rfilename") or sibling.get("path") + if isinstance(name, str): + seen_file_listing = True + if name == path: + return True + return False if seen_file_listing else None + + +def 
validate_hf_dataset_uri(dataset_uri: str) -> str: + """Validate an explicit ``hf://`` dataset URI if it pins a revision.""" + + parsed = parse_hf_dataset_uri(dataset_uri) + if parsed is None or parsed.revision is None: + return dataset_uri + + payload = _fetch_hf_dataset_revision(parsed.repo_id, parsed.revision, _hf_token()) + contains_path = _siblings_contain_path(payload, parsed.path) + if contains_path is False: + raise HuggingFaceDatasetReferenceError( + f"Hugging Face dataset revision {parsed.repo_id}@{parsed.revision} " + f"does not contain artifact {parsed.path!r}" + ) + return dataset_uri + + +def with_hf_revision(dataset_uri: str, revision: str) -> str: + """Return ``dataset_uri`` pinned to ``revision`` and validate it on the Hub.""" + + if not dataset_uri.startswith("hf://"): + return dataset_uri + without_revision = dataset_uri.rsplit("@", maxsplit=1)[0] + return validate_hf_dataset_uri(f"{without_revision}@{revision}") diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/main.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/main.py index 31a7eb82a..0759fe8cd 100644 --- a/projects/policyengine-api-simulation/src/policyengine_api_simulation/main.py +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/main.py @@ -1,11 +1,9 @@ from contextlib import asynccontextmanager -from typing import Any from fastapi import FastAPI from .settings import get_settings, Environment from policyengine_fastapi.opentelemetry import ( GCPLoggingInstrumentor, FastAPIEnhancedInstrumenter, - export_ot_to_console, export_ot_to_gcp, ) from policyengine_fastapi.exit import exit diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/release_bundle.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/release_bundle.py new file mode 100644 index 000000000..af30fb066 --- /dev/null +++ 
b/projects/policyengine-api-simulation/src/policyengine_api_simulation/release_bundle.py @@ -0,0 +1,138 @@ +"""Helpers for using the bundled policyengine.py release manifests. + +The simulation API deploys separate versioned worker apps, but the country +package and data artifact versions must come from the policyengine.py bundle +manifest so model/data compatibility stays explicit. +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from functools import lru_cache +from typing import Mapping + +from policyengine_api_simulation.hf_dataset import with_hf_revision + +os.environ.setdefault("POLICYENGINE_SKIP_COUNTRY_IMPORTS", "1") + +SUPPORTED_COUNTRIES = frozenset({"us", "uk"}) + +DATASET_ALIASES: dict[str, dict[str, str]] = { + "us": { + "enhanced_cps": "enhanced_cps_2024", + "enhanced_cps_2024": "enhanced_cps_2024", + "cps_small": "cps_small_2024", + "cps_small_2024": "cps_small_2024", + "cps": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.110.12", + "cps_2023": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.110.12", + "pooled_cps": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.110.12", + "pooled_3_year_cps_2023": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.110.12", + }, + "uk": { + "enhanced_frs": "enhanced_frs_2023_24", + "enhanced_frs_2023_24": "enhanced_frs_2023_24", + "frs": "frs_2023_24", + "frs_2023_24": "frs_2023_24", + }, +} + + +@dataclass(frozen=True) +class CountryReleaseBundle: + country: str + policyengine_version: str + model_package_name: str + model_version: str + data_package_name: str + data_version: str + default_dataset: str + default_dataset_uri: str + dataset_uris: Mapping[str, str] + + +def _normalise_country(country: str) -> str: + country = country.lower() + if country not in SUPPORTED_COUNTRIES: + raise ValueError(f"Unsupported country: {country}") + return country + + +def _artifact_revision(data_package) -> str: + return 
data_package.release_manifest_revision or data_package.version + + +@lru_cache +def get_country_release_bundle(country: str) -> CountryReleaseBundle: + """Return package and dataset versions from policyengine.py's manifest.""" + + country = _normalise_country(country) + from policyengine.provenance.manifest import build_hf_uri, get_release_manifest + + manifest = get_release_manifest(country) + dataset_uris = { + name: build_hf_uri( + repo_id=manifest.data_package.repo_id, + path_in_repo=reference.path, + revision=reference.revision or _artifact_revision(manifest.data_package), + ) + for name, reference in manifest.datasets.items() + } + + return CountryReleaseBundle( + country=country, + policyengine_version=manifest.policyengine_version, + model_package_name=manifest.model_package.name, + model_version=manifest.model_package.version, + data_package_name=manifest.data_package.name, + data_version=manifest.data_package.version, + default_dataset=manifest.default_dataset, + default_dataset_uri=manifest.default_dataset_uri, + dataset_uris=dataset_uris, + ) + + +def get_bundled_country_model_version(country: str) -> str: + return get_country_release_bundle(country).model_version + + +def _split_requested_revision(requested_data: str) -> tuple[str, str | None]: + if "@" not in requested_data: + return requested_data, None + dataset_name, revision = requested_data.rsplit("@", maxsplit=1) + if not dataset_name or not revision: + raise ValueError(f"Invalid dataset revision reference: {requested_data}") + return dataset_name, revision + + +def resolve_bundle_dataset_name(country: str, requested_data: str | None) -> str: + bundle = get_country_release_bundle(country) + if requested_data is None: + return bundle.default_dataset + + if "://" in requested_data: + return requested_data + + requested_without_revision, revision = _split_requested_revision(requested_data) + aliased = DATASET_ALIASES.get(bundle.country, {}).get( + requested_without_revision, requested_data + ) + if 
revision is not None: + if "://" in aliased: + return with_hf_revision(aliased, revision) + uri = bundle.dataset_uris.get(aliased) + if uri is None: + raise ValueError( + "Unknown dataset revision reference " + f"{requested_data!r} for country {bundle.country!r}" + ) + return with_hf_revision(uri, revision) + return aliased + + +def resolve_bundle_dataset_uri(country: str, requested_data: str | None) -> str: + bundle = get_country_release_bundle(country) + dataset_name = resolve_bundle_dataset_name(country, requested_data) + if "://" in dataset_name: + return dataset_name + return bundle.dataset_uris.get(dataset_name, dataset_name) diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation.py index 5856c9460..722d207cf 100644 --- a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation.py +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation.py @@ -2,7 +2,7 @@ from fastapi import APIRouter -from src.modal.simulation import run_simulation_impl +from policyengine_api_simulation.simulation_runtime import run_simulation_impl from policyengine_api_simulation.compat_models import ( EconomyComparison, SimulationOptions, diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_macro_output.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_macro_output.py new file mode 100644 index 000000000..7ef198b71 --- /dev/null +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_macro_output.py @@ -0,0 +1,134 @@ +"""Internal schemas for the simulation API single-year macro output. + +These models define the legacy dictionary contract the simulation API returns +without exposing that schema through the gateway OpenAPI surface. The gateway +still treats job results as unstructured dictionaries for older callers. 
+""" + +from __future__ import annotations + +from typing import Any, Generic, TypeVar + +from pydantic import BaseModel, ConfigDict, RootModel + +T = TypeVar("T") + + +class MacroOutputModel(BaseModel): + """Base model for internal macro output schemas.""" + + model_config = ConfigDict(extra="forbid") + + +class MacroRootModel(RootModel[T], Generic[T]): + """Base model for internal root schemas that dump to dict/list values.""" + + +class BudgetaryImpact(MacroOutputModel): + tax_revenue_impact: float + state_tax_revenue_impact: float + benefit_spending_impact: float + budgetary_impact: float + households: float + baseline_net_income: float + + +BudgetaryOutput = BudgetaryImpact + + +class DetailedBudgetProgramOutput(MacroOutputModel): + baseline: float + reform: float + difference: float + + +class DetailedBudgetOutput(MacroRootModel[dict[str, DetailedBudgetProgramOutput]]): + pass + + +class DecileOutput(MacroOutputModel): + average: dict[str, float] + relative: dict[str, float] + + +class IntraDecileOutput(MacroOutputModel): + deciles: dict[str, list[float]] + all: dict[str, float] + + +class BaselineReformValue(MacroOutputModel): + baseline: float + reform: float + + +class AgePovertyOutput(MacroOutputModel): + child: BaselineReformValue + adult: BaselineReformValue + senior: BaselineReformValue + all: BaselineReformValue + + +class GenderPovertyOutput(MacroOutputModel): + male: BaselineReformValue + female: BaselineReformValue + + +class RacePovertyOutput(MacroOutputModel): + white: BaselineReformValue + black: BaselineReformValue + hispanic: BaselineReformValue + other: BaselineReformValue + + +class PovertyOutput(MacroOutputModel): + poverty: AgePovertyOutput + deep_poverty: AgePovertyOutput + + +class PovertyByGenderOutput(MacroOutputModel): + poverty: GenderPovertyOutput + deep_poverty: GenderPovertyOutput + + +class PovertyByRaceOutput(MacroOutputModel): + poverty: RacePovertyOutput + + +class PovertyModuleOutputs(MacroOutputModel): + poverty: 
PovertyOutput + poverty_by_gender: PovertyByGenderOutput + poverty_by_race: PovertyByRaceOutput | None + + +class InequalityOutput(MacroOutputModel): + gini: BaselineReformValue + top_10_pct_share: BaselineReformValue + top_1_pct_share: BaselineReformValue + + +class LaborSupplyResponseOutput(MacroRootModel[dict[str, Any]]): + pass + + +class GeographicImpactOutput(MacroRootModel[list[dict[str, Any]]]): + pass + + +class SingleYearMacroOutput(MacroOutputModel): + model_version: str + data_version: str + budget: BudgetaryImpact + detailed_budget: DetailedBudgetOutput + decile: DecileOutput + inequality: InequalityOutput + poverty: PovertyOutput + poverty_by_gender: PovertyByGenderOutput + poverty_by_race: PovertyByRaceOutput | None + intra_decile: IntraDecileOutput + wealth_decile: DecileOutput | None + intra_wealth_decile: IntraDecileOutput | None + labor_supply_response: LaborSupplyResponseOutput | None + constituency_impact: GeographicImpactOutput | None + local_authority_impact: GeographicImpactOutput | None + congressional_district_impact: GeographicImpactOutput | None + cliff_impact: None = None diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_builder.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_builder.py new file mode 100644 index 000000000..1b60ab5e8 --- /dev/null +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_builder.py @@ -0,0 +1,666 @@ +"""Build and serialize the runtime simulation macro output.""" + +from __future__ import annotations + +import logging +import math +from collections.abc import Iterable, Mapping +from dataclasses import dataclass, field +from importlib import import_module +from typing import Any + +from policyengine_api_simulation.release_bundle import get_country_release_bundle +from policyengine_api_simulation.simulation_macro_output import ( + AgePovertyOutput, + BaselineReformValue, + 
BudgetaryImpact, + DecileOutput, + DetailedBudgetOutput, + DetailedBudgetProgramOutput, + GeographicImpactOutput, + GenderPovertyOutput, + InequalityOutput, + IntraDecileOutput, + LaborSupplyResponseOutput, + PovertyModuleOutputs, + PovertyByGenderOutput, + PovertyByRaceOutput, + PovertyOutput, + RacePovertyOutput, + SingleYearMacroOutput, +) + +logger = logging.getLogger(__name__) + +INTRA_DECILE_COLUMNS = { + "Lose more than 5%": "lose_more_than_5pct", + "Lose less than 5%": "lose_less_than_5pct", + "No change": "no_change", + "Gain less than 5%": "gain_less_than_5pct", + "Gain more than 5%": "gain_more_than_5pct", +} + +US_POVERTY_TYPES = { + "spm": "poverty", + "spm_deep": "deep_poverty", +} + +UK_POVERTY_TYPES = { + "relative_bhc": "poverty", + "absolute_bhc": "deep_poverty", +} + + +def _number(value: Any, default: float = 0.0) -> float: + if value is None: + return default + try: + result = float(value) + except (TypeError, ValueError): + return default + if math.isnan(result) or math.isinf(result): + return default + return result + + +def _collection_records(collection: Any) -> list[dict[str, Any]]: + if collection is None: + return [] + dataframe = getattr(collection, "dataframe", None) + if dataframe is not None: + return list(dataframe.to_dict("records")) + if isinstance(collection, list): + return [dict(item) for item in collection if isinstance(item, Mapping)] + return [] + + +def _output_model_dump(value: Any) -> Any: + if value is None: + return None + if hasattr(value, "model_dump"): + return value.model_dump(mode="json") + if isinstance(value, Mapping): + return dict(value) + return None + + +def _empty_baseline_reform_value() -> dict[str, float]: + return {"baseline": 0.0, "reform": 0.0} + + +def _empty_age_poverty() -> dict[str, dict[str, float]]: + return { + "child": _empty_baseline_reform_value(), + "adult": _empty_baseline_reform_value(), + "senior": _empty_baseline_reform_value(), + "all": _empty_baseline_reform_value(), + } + + +def 
def _fill_poverty_block(
    *,
    country: str,
    output: dict[str, dict[str, dict[str, float]]],
    baseline_records: Iterable[Mapping[str, Any]],
    reform_records: Iterable[Mapping[str, Any]],
    default_group: str,
) -> None:
    """Copy poverty rates from result rows into the nested *output* mapping.

    *output* is mutated in place and is laid out as
    ``output[poverty_type][group]["baseline" | "reform"]``. A row is used
    only when its poverty type (translated by ``_poverty_type``) and its
    lower-cased ``filter_group`` both already exist in *output*; every other
    row is ignored, so the pre-seeded values survive.

    Args:
        country: Country code handed to ``_poverty_type``.
        output: Pre-seeded result mapping to update.
        baseline_records: Rows for the baseline simulation.
        reform_records: Rows for the reform simulation.
        default_group: Group name used when a row has no ``filter_group``.
    """
    record_sets = {"baseline": baseline_records, "reform": reform_records}
    for side, records in record_sets.items():
        for row in records:
            category = _poverty_type(country, row)
            if category is None or category not in output:
                continue
            groups = output[category]
            group = str(row.get("filter_group") or default_group).lower()
            if group in groups:
                groups[group][side] = _number(row.get("rate"))
def _change_output_variable(baseline, reform, variable: str, entity: str) -> float:
    """Return the summed reform-minus-baseline change in *variable*.

    When both simulations already carry *variable* as a column of their
    *entity* output table (fetched through ``_entity_data``), the change is
    taken directly from those columns. Otherwise the calculation is
    delegated to policyengine's ``ChangeAggregate`` with the SUM aggregate.
    """
    before = _entity_data(baseline, entity)
    after = _entity_data(reform, entity)
    precomputed = variable in before.columns and variable in after.columns
    if not precomputed:
        # Imported lazily so the module can load without pulling in policyengine.
        from policyengine.outputs import ChangeAggregate, ChangeAggregateType

        aggregate = ChangeAggregate(
            baseline_simulation=baseline,
            reform_simulation=reform,
            variable=variable,
            entity=entity,
            aggregate_type=ChangeAggregateType.SUM,
        )
        aggregate.run()
        return float(aggregate.result)

    delta = after[variable] - before[variable]
    return float(delta.sum())
def build(self) -> SingleYearMacroOutput:
    """Assemble every section of the macro result into one validated model.

    Poverty outputs are computed first, then the wealth-decile collections
    are read off the analysis, and finally each remaining section is built
    in schema order. ``cliff_impact`` is always ``None``.
    """
    poverty_outputs = self._build_poverty_outputs()
    wealth_rows = getattr(self.analysis, "wealth_decile_impacts", None)
    intra_wealth_rows = getattr(self.analysis, "intra_wealth_decile_impacts", None)

    # A dict literal evaluates its values top to bottom, so the section
    # builders run in the order listed here.
    sections = {
        "model_version": self._model_version(),
        "data_version": self._data_version(),
        "budget": self._build_budgetary_impact(),
        "detailed_budget": self._build_detailed_budget(),
        "decile": self._build_decile(),
        "inequality": self._build_inequality(),
        "poverty": poverty_outputs.poverty,
        "poverty_by_gender": poverty_outputs.poverty_by_gender,
        "poverty_by_race": poverty_outputs.poverty_by_race,
        "intra_decile": self._build_intra_decile_output(),
        "wealth_decile": self._build_wealth_decile(wealth_rows),
        "intra_wealth_decile": self._build_intra_wealth_decile(intra_wealth_rows),
        "labor_supply_response": self._build_labor_supply_response(),
        "congressional_district_impact": self._build_congressional_district_impact(),
        "constituency_impact": self._build_uk_constituency_impact(),
        "local_authority_impact": self._build_uk_local_authority_impact(),
        "cliff_impact": None,
    }
    return SingleYearMacroOutput(**sections)
def _build_inequality(self) -> InequalityOutput:
    """Build the baseline/reform inequality comparison.

    If the analysis already exposes an ``InequalityOutput`` as
    ``baseline_inequality`` it is returned untouched. Otherwise the
    ``gini``, ``top_10_share`` and ``top_1_share`` attributes of the
    baseline and reform inequality objects are read, with missing or
    non-finite values becoming 0.0 through ``_number``.
    """
    before = getattr(self.analysis, "baseline_inequality", None)
    after = getattr(self.analysis, "reform_inequality", None)
    if isinstance(before, InequalityOutput):
        return before

    def compare(attribute: str) -> BaselineReformValue:
        # Pair up the same statistic from both simulations.
        return BaselineReformValue(
            baseline=_number(getattr(before, attribute, None)),
            reform=_number(getattr(after, attribute, None)),
        )

    return InequalityOutput(
        gini=compare("gini"),
        top_10_pct_share=compare("top_10_share"),
        top_1_pct_share=compare("top_1_share"),
    )
def _build_poverty_outputs(self) -> PovertyModuleOutputs:
    """Compute the age, gender and (US only) race poverty breakdowns.

    Each breakdown is calculated for the baseline and then the reform
    simulation through ``_try_compute_output``, so a failing calculation is
    logged and contributes ``None`` instead of aborting the build. The
    calculator itself is looked up by name in ``policyengine.outputs.poverty``
    right before each call; that lookup is not covered by the try wrapper.
    """
    is_us = self.country == "us"
    prefix = "us" if is_us else "uk"

    def pair(function_name: str, baseline_label: str, reform_label: str):
        # Baseline first, then reform; the calculator is resolved per call.
        runs = ((baseline_label, self.baseline), (reform_label, self.reform))
        return tuple(
            _try_compute_output(
                label,
                _poverty_module_function(function_name),
                simulation,
            )
            for label, simulation in runs
        )

    age_baseline, age_reform = pair(
        f"calculate_{prefix}_poverty_by_age",
        "baseline poverty by age",
        "reform poverty by age",
    )
    gender_baseline, gender_reform = pair(
        f"calculate_{prefix}_poverty_by_gender",
        "baseline poverty by gender",
        "reform poverty by gender",
    )
    race_baseline = race_reform = None
    if is_us:
        race_baseline, race_reform = pair(
            "calculate_us_poverty_by_race",
            "baseline poverty by race",
            "reform poverty by race",
        )

    return PovertyModuleOutputs(
        poverty=self._build_poverty_output(
            baseline=getattr(self.analysis, "baseline_poverty", None),
            reform=getattr(self.analysis, "reform_poverty", None),
            baseline_by_age=age_baseline,
            reform_by_age=age_reform,
        ),
        poverty_by_gender=self._build_poverty_by_gender_output(
            baseline_by_gender=gender_baseline,
            reform_by_gender=gender_reform,
        ),
        poverty_by_race=(
            self._build_poverty_by_race_output(
                baseline_by_race=race_baseline,
                reform_by_race=race_reform,
            )
            if is_us
            else None
        ),
    )
def _build_intra_decile_output_from_collection(
    self, collection: Any
) -> IntraDecileOutput:
    """Convert intra-decile result rows into an ``IntraDecileOutput``.

    An existing ``IntraDecileOutput`` is passed straight through. Otherwise
    rows whose ``decile`` truncates to a positive integer are kept and
    ordered by decile. For every label in ``INTRA_DECILE_COLUMNS`` the
    per-decile values are listed, and ``all`` holds their unweighted mean
    (0.0 when there are no rows).
    """
    if isinstance(collection, IntraDecileOutput):
        return collection

    def decile_of(record: Mapping[str, Any]) -> float:
        return _number(record.get("decile"))

    # Filtering before a stable sort yields the same order as sorting first.
    ranked = sorted(
        (
            record
            for record in _collection_records(collection)
            if int(decile_of(record)) > 0
        ),
        key=decile_of,
    )

    shares_by_decile: dict[str, list[float]] = {}
    overall: dict[str, float] = {}
    for label, column in INTRA_DECILE_COLUMNS.items():
        shares = [_number(record.get(column)) for record in ranked]
        shares_by_decile[label] = shares
        overall[label] = sum(shares) / len(shares) if shares else 0.0
    return IntraDecileOutput(deciles=shares_by_decile, all=overall)
def _build_poverty_by_race_output(
    self,
    *,
    baseline_by_race: Any,
    reform_by_race: Any,
) -> PovertyByRaceOutput:
    """Build the poverty-by-race section from baseline and reform rows.

    An existing ``PovertyByRaceOutput`` passed as *baseline_by_race* is
    returned unchanged. Otherwise the white/black/hispanic/other groups
    start at 0.0 and are overwritten by matching rows; poverty types are
    always translated with the US mapping.
    """
    if isinstance(baseline_by_race, PovertyByRaceOutput):
        return baseline_by_race

    rates = {
        race: _empty_baseline_reform_value()
        for race in ("white", "black", "hispanic", "other")
    }
    _fill_poverty_block(
        country="us",
        output={"poverty": rates},
        baseline_records=_collection_records(baseline_by_race),
        reform_records=_collection_records(reform_by_race),
        default_group="all",
    )
    return PovertyByRaceOutput(poverty=_race_poverty_output(rates))
None else None + ) + + def _build_uk_constituency_impact(self) -> GeographicImpactOutput | None: + if self.country != "uk": + return None + + impact = _try_compute_output( + "constituency impacts", + _output_module_function( + "constituency_impact", "compute_uk_constituency_impacts" + ), + self.baseline, + self.reform, + ) + if impact is None: + return None + return self._build_geographic_impact_output( + getattr(impact, "constituency_results", None) + ) + + def _build_uk_local_authority_impact(self) -> GeographicImpactOutput | None: + if self.country != "uk": + return None + + impact = _try_compute_output( + "local authority impacts", + _output_module_function( + "local_authority_impact", "compute_uk_local_authority_impacts" + ), + self.baseline, + self.reform, + ) + if impact is None: + return None + return self._build_geographic_impact_output( + getattr(impact, "local_authority_results", None) + ) + + def _model_version(self) -> str: + return str(getattr(self.country_module.model, "version", "")) + + def _data_version(self) -> str: + if self.resolved_data_version: + return str(self.resolved_data_version) + data = self.simulation_params.get("data") + if isinstance(data, str) and "@" in data: + revision = data.rsplit("@", maxsplit=1)[1] + if revision: + return revision + if self.simulation_params.get("data_version"): + return str(self.simulation_params["data_version"]) + metadata = getattr(self.dataset, "metadata", {}) or {} + for key in ("data_version", "version"): + value = metadata.get(key) + if value is not None: + return str(value) + try: + return get_country_release_bundle(self.country).data_version + except ValueError: + pass + return "" diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_runtime.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_runtime.py new file mode 100644 index 000000000..e6088921b --- /dev/null +++ 
b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_runtime.py @@ -0,0 +1,513 @@ +""" +Simulation implementation - pure logic with snapshotted imports. + +This module avoids importing policyengine at module level so the worker can +load the requested country module without triggering cross-country imports. +No Modal dependencies here. +""" + +import contextlib +import json +import logging +import os +import tempfile +from dataclasses import dataclass +from importlib import import_module +from typing import Any, Iterator + +from policyengine_api_simulation.hf_dataset import with_hf_revision +from policyengine_api_simulation.release_bundle import ( + get_country_release_bundle, + resolve_bundle_dataset_name, + resolve_bundle_dataset_uri, +) +from policyengine_api_simulation.simulation_output_builder import ( + SimulationOutputBuilder, +) +from policyengine_api_simulation.telemetry import split_internal_payload + +logger = logging.getLogger(__name__) + +os.environ.setdefault("POLICYENGINE_SKIP_COUNTRY_IMPORTS", "1") + +DEFAULT_YEAR = 2026 + + +@dataclass(frozen=True) +class RegionResolution: + code: str + dataset_reference: str | None = None + scoping_strategy: Any | None = None + + +def _normalize_credentials_blob(creds_json: str) -> str: + """Return the raw JSON blob, decoding the outer escape if present. + + The upstream Modal secret sometimes stores the credentials payload + double-encoded (the entire JSON object is wrapped in quotes with + backslash-escaped interior quotes). Historically we always attempted + the unescape as a fallback which could accidentally parse an already + clean blob. 
Only unwrap when the payload looks wrapped.""" + + try: + json.loads(creds_json) + except json.JSONDecodeError: + looks_escaped = creds_json.lstrip().startswith('"') or '\\"' in creds_json + if looks_escaped: + return json.loads(f'"{creds_json}"') + raise + return creds_json + + +@contextlib.contextmanager +def setup_gcp_credentials() -> Iterator[None]: + """ + Set up GCP credentials from environment variable. + + Modal secrets are injected as environment variables. The GCP library + expects GOOGLE_APPLICATION_CREDENTIALS to point to a file path. If + credentials JSON is provided, write it to a temp file that's deleted + on exit. This runs as a context manager to guarantee cleanup even if + the caller raises mid-simulation; the previous fire-and-forget + ``tempfile.mkstemp`` path leaked credential material on disk every + time a container served a request. + """ + # Log available GCP-related env vars for debugging + gcp_vars = { + k: v[:50] + "..." if len(v) > 50 else v + for k, v in os.environ.items() + if "GOOGLE" in k or "GCP" in k or "CREDENTIAL" in k + } + logger.info(f"GCP-related env vars: {list(gcp_vars.keys())}") + + # Check if credentials are already set as a file path + if os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"): + logger.info("GOOGLE_APPLICATION_CREDENTIALS already set") + yield + return + + # Check for credentials JSON in various env var names + creds_json = ( + os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_JSON") + or os.environ.get("GCP_CREDENTIALS_JSON") + or os.environ.get("GOOGLE_CREDENTIALS") + or os.environ.get("SERVICE_ACCOUNT_JSON") + ) + + if not creds_json: + logger.warning("No GCP credentials found in environment variables") + yield + return + + normalized = _normalize_credentials_blob(creds_json) + + # ``NamedTemporaryFile(delete=True)`` removes the file when the context + # exits (either normally or via exception). 
We restore any prior value + # of ``GOOGLE_APPLICATION_CREDENTIALS`` so a retry in the same + # container doesn't silently pick up a path that no longer exists. + previous = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=True + ) as creds_file: + creds_file.write(normalized) + creds_file.flush() + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = creds_file.name + logger.info(f"GCP credentials written to {creds_file.name}") + try: + yield + finally: + if previous is None: + os.environ.pop("GOOGLE_APPLICATION_CREDENTIALS", None) + else: + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = previous + + +def run_simulation_impl(params: dict) -> dict: + """ + Execute economic simulation. + + Pure implementation with no Modal dependencies. + Accepts the gateway simulation payload and returns the legacy macro result dict. + """ + # Set up GCP credentials if needed. The credentials temp file is + # cleaned up on exit so we never leave signed JSON material on disk. 
+ with setup_gcp_credentials(): + return _run_simulation_impl_core(params) + + +def _parse_year(params: dict[str, Any]) -> int: + value = params.get("time_period") or params.get("year") or DEFAULT_YEAR + return int(value) + + +def _normalise_period_key(period_key: Any) -> str: + """Convert legacy ``start.stop`` period keys to v4 effective dates.""" + text = str(period_key) + parts = text.split(".") + if len(parts) > 1 and len(parts[0]) == 10: + return parts[0] + return text + + +def _normalise_policy(policy: dict[str, Any] | None) -> dict[str, Any] | None: + if not policy: + return None + + normalised: dict[str, Any] = {} + for parameter, value in policy.items(): + if isinstance(value, dict): + normalised[parameter] = { + _normalise_period_key(period): period_value + for period, period_value in value.items() + } + else: + normalised[parameter] = value + return normalised + + +def _resolve_dataset_name(country: str, requested_data: str | None) -> str: + return resolve_bundle_dataset_name(country, requested_data) + + +def _split_requested_revision(requested_data: str) -> tuple[str, str | None]: + if "@" not in requested_data: + return requested_data, None + dataset_name, revision = requested_data.rsplit("@", maxsplit=1) + if not dataset_name or not revision: + raise ValueError(f"Invalid dataset revision reference: {requested_data}") + return dataset_name, revision + + +def _requested_data_version(params: dict[str, Any]) -> str | None: + data_version = params.get("data_version") + if data_version is not None: + return str(data_version) + + data = params.get("data") + if isinstance(data, str) and "@" in data: + _, revision = _split_requested_revision(data) + return revision + return None + + +def _resolve_dataset_reference(country: str, params: dict[str, Any]) -> str: + requested_data = params.get("data") + requested_data = requested_data if isinstance(requested_data, str) else None + requested_data_version = _requested_data_version(params) + + if 
requested_data_version is None: + return _resolve_dataset_name(country, requested_data) + + if requested_data is None: + dataset_uri = get_country_release_bundle(country).default_dataset_uri + else: + dataset_without_revision, data_revision = _split_requested_revision( + requested_data + ) + if data_revision is not None and data_revision != requested_data_version: + raise ValueError( + "Conflicting dataset revisions: " + f"data requests {data_revision!r} but data_version is " + f"{requested_data_version!r}" + ) + dataset_uri = resolve_bundle_dataset_uri(country, dataset_without_revision) + + return with_hf_revision(dataset_uri, requested_data_version) + + +def _normalise_region_code(country: str, region: Any) -> str: + if region is None or str(region).strip() == "": + return country + + raw = str(region).strip() + if raw.lower() in {"us", "uk"}: + return raw.lower() + + if "/" not in raw: + if country == "us" and len(raw) == 2: + return f"state/{raw.lower()}" + if country == "uk": + return f"country/{raw.lower().replace(' ', '_')}" + return raw + + prefix, value = raw.split("/", maxsplit=1) + prefix = prefix.lower() + value = value.strip() + if prefix == "state": + value = value.lower() + elif prefix == "country": + value = value.lower().replace(" ", "_") + elif prefix in {"congressional_district", "place"}: + value = value.upper() + elif prefix == "local_authority": + value = value.upper() + return f"{prefix}/{value}" + + +def _build_uk_weight_replacement_region(region_code: str): + if "/" not in region_code: + return None + + prefix, value = region_code.split("/", maxsplit=1) + if prefix not in {"constituency", "local_authority"}: + return None + + from policyengine.core.region import Region + from policyengine.core.scoping_strategy import WeightReplacementStrategy + from policyengine.data.uk_geography_assets import ( + CONSTITUENCY_ASSET_SPEC, + LOCAL_AUTHORITY_ASSET_SPEC, + ) + + asset_spec = ( + CONSTITUENCY_ASSET_SPEC + if prefix == "constituency" + else 
def _region_parent_dataset_reference(
    country_module,
    country: str,
    region,
    params: dict[str, Any],
) -> str:
    """Resolve the dataset to load for a region that has no dataset of its own.

    If *region* names a parent whose model entry has a string
    ``dataset_path``, that path is used, pinned to the requested data
    version when one was requested. In every other case the country-level
    dataset reference is resolved from *params*.
    """
    parent_code = getattr(region, "parent_code", None)
    parent_path = None
    if parent_code:
        parent_region = country_module.model.get_region(parent_code)
        parent_path = getattr(parent_region, "dataset_path", None)

    if not isinstance(parent_path, str):
        return _resolve_dataset_reference(country, params)

    version = _requested_data_version(params)
    if version is None:
        return parent_path
    return with_hf_revision(parent_path, version)
+ params, + ) + + return RegionResolution( + code=region_code, + dataset_reference=dataset_reference, + scoping_strategy=getattr(region, "scoping_strategy", None), + ) + + +def _microframe_like(frame, weights: str): + from microdf import MicroDataFrame + + return MicroDataFrame(frame.copy(), weights=weights) + + +def _person_group_column(person, entity: str) -> str: + prefixed = f"person_{entity}_id" + if prefixed in person.columns: + return prefixed + return f"{entity}_id" + + +def _subsample_us_dataset(dataset, subsample: int | None): + if not subsample: + return dataset + + from policyengine.tax_benefit_models.us.datasets import ( + PolicyEngineUSDataset, + USYearData, + ) + + dataset.load() + data = dataset.data + household = data.household.head(int(subsample)).copy() + household_ids = set(household["household_id"]) + + person_household_col = _person_group_column(data.person, "household") + person = data.person[data.person[person_household_col].isin(household_ids)].copy() + + def group_subset(entity: str): + person_col = _person_group_column(person, entity) + entity_id_col = f"{entity}_id" + ids = set(person[person_col]) + frame = getattr(data, entity) + return frame[frame[entity_id_col].isin(ids)].copy() + + subset_data = USYearData( + person=_microframe_like(person, "person_weight"), + marital_unit=_microframe_like( + group_subset("marital_unit"), "marital_unit_weight" + ), + family=_microframe_like(group_subset("family"), "family_weight"), + spm_unit=_microframe_like(group_subset("spm_unit"), "spm_unit_weight"), + tax_unit=_microframe_like(group_subset("tax_unit"), "tax_unit_weight"), + household=_microframe_like(household, "household_weight"), + ) + subset_path = os.path.join( + os.environ.get("POLICYENGINE_DATA_FOLDER", "/tmp/policyengine-data"), + f"{dataset.id}_subsample_{subsample}.h5", + ) + return PolicyEngineUSDataset( + id=f"{dataset.id}_subsample_{subsample}", + name=f"{dataset.name} subsample {subsample}", + description=dataset.description, + 
filepath=subset_path, + year=dataset.year, + is_output_dataset=dataset.is_output_dataset, + metadata=getattr(dataset, "metadata", {}), + metadata_filepath=getattr(dataset, "metadata_filepath", None), + data=subset_data, + ) + + +def _country_module(country: str): + country = country.lower() + if country not in {"us", "uk"}: + raise ValueError(f"Unsupported country: {country}") + + return import_module(f"policyengine.tax_benefit_models.{country}") + + +def _load_dataset( + params: dict[str, Any], + *, + country_module=None, + region_resolution: RegionResolution | None = None, +): + country = params.get("country", "us").lower() + year = _parse_year(params) + country_module = country_module or _country_module(country) + dataset_name = ( + region_resolution.dataset_reference + if region_resolution is not None and region_resolution.dataset_reference + else _resolve_dataset_reference(country, params) + ) + datasets = country_module.ensure_datasets( + datasets=[dataset_name], + years=[year], + data_folder=os.environ.get( + "POLICYENGINE_DATA_FOLDER", "/tmp/policyengine-data" + ), + ) + dataset = next(iter(datasets.values())) + if country == "us": + return _subsample_us_dataset(dataset, params.get("subsample")) + return dataset + + +def _build_simulation( + params: dict[str, Any], + *, + dataset, + policy: dict[str, Any] | None, + scoping_strategy=None, +): + from policyengine.core import Simulation + + country_module = _country_module(params.get("country", "us")) + return Simulation( + dataset=dataset, + tax_benefit_model_version=country_module.model, + policy=policy, + scoping_strategy=scoping_strategy, + ) + + +def _run_simulation_impl_core(params: dict) -> dict: + simulation_params, telemetry, metadata = split_internal_payload(params) + + logger.info( + "Starting simulation for country=%s run_id=%s process_id=%s", + simulation_params.get("country", "unknown"), + getattr(telemetry, "run_id", None), + getattr(telemetry, "process_id", None), + ) + if metadata: + 
logger.info("Received simulation metadata keys: %s", sorted(metadata)) + + country = simulation_params.get("country", "us").lower() + country_module = _country_module(country) + region_resolution = _resolve_region( + country_module=country_module, + country=country, + params=simulation_params, + ) + dataset = _load_dataset( + simulation_params, + country_module=country_module, + region_resolution=region_resolution, + ) + baseline_policy = _normalise_policy(simulation_params.get("baseline")) + reform_policy = _normalise_policy(simulation_params.get("reform")) + + logger.info("Initialising baseline and reform simulations") + baseline = _build_simulation( + simulation_params, + dataset=dataset, + policy=baseline_policy, + scoping_strategy=region_resolution.scoping_strategy, + ) + reform = _build_simulation( + simulation_params, + dataset=dataset, + policy=reform_policy, + scoping_strategy=region_resolution.scoping_strategy, + ) + + logger.info("Calculating economic impact") + output = SimulationOutputBuilder( + country=country, + simulation_params=simulation_params, + country_module=country_module, + dataset=dataset, + baseline=baseline, + reform=reform, + resolved_data_version=_requested_data_version(simulation_params), + ).serialize() + logger.info("Comparison complete") + return output diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/telemetry.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/telemetry.py new file mode 100644 index 000000000..dccb418eb --- /dev/null +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/telemetry.py @@ -0,0 +1,47 @@ +""" +Internal telemetry helpers for Modal request passthrough. 
+""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict + + +CaptureMode = Literal["disabled", "failures", "threshold", "sampled", "always"] + + +class TelemetryEnvelope(BaseModel): + """Minimal shared telemetry payload shape for gateway and worker code.""" + + run_id: str + process_id: str | None = None + request_id: str | None = None + traceparent: str | None = None + requested_at: datetime | None = None + simulation_kind: str | None = None + geography_code: str | None = None + geography_type: str | None = None + config_hash: str | None = None + capture_mode: CaptureMode = "disabled" + + model_config = ConfigDict(extra="forbid") + + +def split_internal_payload( + params: dict[str, Any], +) -> tuple[dict[str, Any], TelemetryEnvelope | None, dict[str, Any] | None]: + """Strip internal passthrough fields before SimulationOptions validation.""" + + simulation_params = dict(params) + raw_telemetry = simulation_params.pop("_telemetry", None) + raw_metadata = simulation_params.pop("_metadata", None) + + telemetry = None + if raw_telemetry is not None: + telemetry = TelemetryEnvelope.model_validate(raw_telemetry) + + metadata = raw_metadata if isinstance(raw_metadata, dict) else None + return simulation_params, telemetry, metadata diff --git a/projects/policyengine-api-simulation/tests/gateway/test_auth.py b/projects/policyengine-api-simulation/tests/gateway/test_auth.py index 3dd038b87..0a6e4b3be 100644 --- a/projects/policyengine-api-simulation/tests/gateway/test_auth.py +++ b/projects/policyengine-api-simulation/tests/gateway/test_auth.py @@ -219,7 +219,7 @@ def test__given_dependency_override__then_gated_endpoint_returns_200( mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } response = client.post( diff --git 
a/projects/policyengine-api-simulation/tests/gateway/test_budget_window_state.py b/projects/policyengine-api-simulation/tests/gateway/test_budget_window_state.py index 7b9936f41..1058146c2 100644 --- a/projects/policyengine-api-simulation/tests/gateway/test_budget_window_state.py +++ b/projects/policyengine-api-simulation/tests/gateway/test_budget_window_state.py @@ -42,7 +42,7 @@ def test_create_initial_batch_state_builds_queued_years_and_run_id(): batch_job_id="fc-parent-123", request=request, resolved_version="1.500.0", - resolved_app_name="policyengine-simulation-us1-500-0-uk2-66-0", + resolved_app_name="policyengine-simulation-py4-10-0", bundle=PolicyEngineBundle(model_version="1.500.0"), ) @@ -70,7 +70,7 @@ def test_build_batch_status_response_computes_progress_from_completed_years(): batch_job_id="fc-parent-123", request=request, resolved_version="1.500.0", - resolved_app_name="policyengine-simulation-us1-500-0-uk2-66-0", + resolved_app_name="policyengine-simulation-py4-10-0", bundle=PolicyEngineBundle(model_version="1.500.0"), ) state.completed_years = ["2026", "2027"] @@ -100,7 +100,7 @@ def test_batch_state_round_trips_through_modal_dict(mock_modal): batch_job_id="fc-parent-123", request=request, resolved_version="1.500.0", - resolved_app_name="policyengine-simulation-us1-500-0-uk2-66-0", + resolved_app_name="policyengine-simulation-py4-10-0", bundle=PolicyEngineBundle(model_version="1.500.0"), ) put_batch_job_state(state) @@ -126,7 +126,7 @@ def test_state_transition_helpers_track_completion_path(): batch_job_id="fc-parent-123", request=request, resolved_version="1.500.0", - resolved_app_name="policyengine-simulation-us1-500-0-uk2-66-0", + resolved_app_name="policyengine-simulation-py4-10-0", bundle=PolicyEngineBundle(model_version="1.500.0"), ) @@ -188,7 +188,7 @@ def test_state_transition_helpers_track_failed_child(): batch_job_id="fc-parent-123", request=request, resolved_version="1.500.0", - 
resolved_app_name="policyengine-simulation-us1-500-0-uk2-66-0", + resolved_app_name="policyengine-simulation-py4-10-0", bundle=PolicyEngineBundle(model_version="1.500.0"), ) @@ -216,7 +216,7 @@ def test_mark_child_completed_handles_missing_child_jobs_entry(caplog): batch_job_id="fc-parent-123", request=request, resolved_version="1.500.0", - resolved_app_name="policyengine-simulation-us1-500-0-uk2-66-0", + resolved_app_name="policyengine-simulation-py4-10-0", bundle=PolicyEngineBundle(model_version="1.500.0"), ) # Simulate the crash-recovery path where the running/completed lists @@ -254,7 +254,7 @@ def test_mark_child_failed_handles_missing_child_jobs_entry(caplog): batch_job_id="fc-parent-123", request=request, resolved_version="1.500.0", - resolved_app_name="policyengine-simulation-us1-500-0-uk2-66-0", + resolved_app_name="policyengine-simulation-py4-10-0", bundle=PolicyEngineBundle(model_version="1.500.0"), ) state.running_years = ["2026"] @@ -279,7 +279,7 @@ def test_mark_batch_failed_cancels_any_remaining_running_children(): batch_job_id="fc-parent-123", request=request, resolved_version="1.500.0", - resolved_app_name="policyengine-simulation-us1-500-0-uk2-66-0", + resolved_app_name="policyengine-simulation-py4-10-0", bundle=PolicyEngineBundle(model_version="1.500.0"), ) diff --git a/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py b/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py index 9faab2d6f..4653112c1 100644 --- a/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py +++ b/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py @@ -8,7 +8,8 @@ import pytest from fastapi.testclient import TestClient -from src.modal.release_bundle import resolve_bundle_dataset_uri +from fixtures.gateway.test_endpoints import resolve_test_dataset_uri +from policyengine_api_simulation.hf_dataset import HuggingFaceDatasetReferenceError def expected_bundle( @@ -18,11 +19,18 @@ def expected_bundle( dataset: 
str | None = None, data_version: str | None = None, ) -> dict[str, str | None]: + resolved_dataset = resolve_test_dataset_uri(country, dataset) + if ( + data_version is not None + and resolved_dataset is not None + and resolved_dataset.startswith("hf://") + ): + resolved_dataset = ( + f"{resolved_dataset.rsplit('@', maxsplit=1)[0]}@{data_version}" + ) bundle: dict[str, str | None] = { "model_version": model_version, - "dataset": resolve_bundle_dataset_uri(country, dataset) - if dataset is not None - else None, + "dataset": resolved_dataset, } if data_version is not None: bundle["data_version"] = data_version @@ -43,7 +51,7 @@ def test__given_us_country_no_version__then_returns_latest_app(self, mock_modal) # Given mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } # When @@ -51,7 +59,7 @@ def test__given_us_country_no_version__then_returns_latest_app(self, mock_modal) # Then assert resolved_version == "1.500.0" - assert app_name == "policyengine-simulation-us1-500-0-uk2-66-0" + assert app_name == "policyengine-simulation-py4-10-0" def test__given_us_country_with_version__then_returns_specified_app( self, mock_modal @@ -65,7 +73,7 @@ def test__given_us_country_with_version__then_returns_specified_app( # Given mock_modal["dicts"]["simulation-api-us-versions"] = { - "1.459.0": "policyengine-simulation-us1-459-0-uk2-65-9" + "1.459.0": "policyengine-simulation-py3-9-0" } # When @@ -73,7 +81,7 @@ def test__given_us_country_with_version__then_returns_specified_app( # Then assert resolved_version == "1.459.0" - assert app_name == "policyengine-simulation-us1-459-0-uk2-65-9" + assert app_name == "policyengine-simulation-py3-9-0" def test__given_uk_country__then_uses_uk_version_dict(self, mock_modal): """ @@ -86,7 +94,7 @@ def test__given_uk_country__then_uses_uk_version_dict(self, mock_modal): # Given 
mock_modal["dicts"]["simulation-api-uk-versions"] = { "latest": "2.66.0", - "2.66.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "2.66.0": "policyengine-simulation-py4-10-0", } # When @@ -94,7 +102,7 @@ def test__given_uk_country__then_uses_uk_version_dict(self, mock_modal): # Then assert resolved_version == "2.66.0" - assert app_name == "policyengine-simulation-us1-500-0-uk2-66-0" + assert app_name == "policyengine-simulation-py4-10-0" def test__given_invalid_country__then_raises_value_error(self): """ @@ -138,7 +146,7 @@ def test__given_regular_data_value__then_routes_to_run_simulation( # Given mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } request_body = { @@ -154,7 +162,7 @@ def test__given_regular_data_value__then_routes_to_run_simulation( # Then assert response.status_code == 200 assert mock_modal["func"].last_from_name_call == ( - "policyengine-simulation-us1-500-0-uk2-66-0", + "policyengine-simulation-py4-10-0", "run_simulation", ) @@ -169,7 +177,7 @@ def test__given_no_data_value__then_routes_to_run_simulation( # Given mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } request_body = { @@ -184,7 +192,7 @@ def test__given_no_data_value__then_routes_to_run_simulation( # Then assert response.status_code == 200 assert mock_modal["func"].last_from_name_call == ( - "policyengine-simulation-us1-500-0-uk2-66-0", + "policyengine-simulation-py4-10-0", "run_simulation", ) assert "time_period" not in mock_modal["func"].last_payload @@ -201,7 +209,7 @@ def test__given_submission__then_returns_job_id_and_poll_url( # Given mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": 
"policyengine-simulation-py4-10-0", } request_body = { @@ -231,7 +239,7 @@ def test__given_submission_with_telemetry__then_preserves_run_id( """ mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } request_body = { @@ -263,7 +271,7 @@ def test__given_submission_with_data__then_returns_resolved_bundle_metadata( # Given mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } request_body = { @@ -279,7 +287,7 @@ def test__given_submission_with_data__then_returns_resolved_bundle_metadata( # Then assert response.status_code == 200 data = response.json() - assert data["resolved_app_name"] == "policyengine-simulation-us1-500-0-uk2-66-0" + assert data["resolved_app_name"] == "policyengine-simulation-py4-10-0" assert data["policyengine_bundle"] == expected_bundle( "us", "1.500.0", @@ -291,7 +299,7 @@ def test__given_submission_with_alias_data__then_bundle_dataset_uses_manifest_ur ): mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } request_body = { @@ -305,16 +313,91 @@ def test__given_submission_with_alias_data__then_bundle_dataset_uses_manifest_ur assert response.status_code == 200 data = response.json() - assert data["policyengine_bundle"]["dataset"] == resolve_bundle_dataset_uri( + assert data["policyengine_bundle"]["dataset"] == resolve_test_dataset_uri( "us", "enhanced_cps_2024" ) + def test__given_submission_with_logical_revision__then_bundle_dataset_uses_revision( + self, mock_modal, client: TestClient + ): + mock_modal["dicts"]["simulation-api-us-versions"] = { + "latest": "1.500.0", + "1.500.0": "policyengine-simulation-py4-10-0", + } + + response = client.post( + 
"/simulate/economy/comparison", + json={ + "country": "us", + "scope": "macro", + "reform": {}, + "data": "enhanced_cps_2024@1.77.0", + }, + ) + + assert response.status_code == 200 + assert response.json()["policyengine_bundle"]["dataset"] == ( + "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0" + ) + + def test__given_submission_with_conflicting_data_versions__then_returns_400( + self, mock_modal, client: TestClient + ): + mock_modal["dicts"]["simulation-api-us-versions"] = { + "latest": "1.500.0", + "1.500.0": "policyengine-simulation-py4-10-0", + } + + response = client.post( + "/simulate/economy/comparison", + json={ + "country": "us", + "scope": "macro", + "reform": {}, + "data": "enhanced_cps_2024@1.77.0", + "data_version": "1.78.2", + }, + ) + + assert response.status_code == 400 + assert mock_modal["func"].last_payload is None + + def test__given_submission_with_invalid_hf_revision__then_returns_400_before_spawn( + self, mock_modal, client: TestClient, monkeypatch + ): + mock_modal["dicts"]["simulation-api-us-versions"] = { + "latest": "1.500.0", + "1.500.0": "policyengine-simulation-py4-10-0", + } + + def reject_revision(dataset_uri, revision): + raise HuggingFaceDatasetReferenceError("revision missing") + + monkeypatch.setattr( + "src.modal.gateway.endpoints.with_hf_revision", + reject_revision, + ) + + response = client.post( + "/simulate/economy/comparison", + json={ + "country": "us", + "scope": "macro", + "reform": {}, + "data": "enhanced_cps_2024@does-not-exist", + }, + ) + + assert response.status_code == 400 + assert response.json()["detail"] == "revision missing" + assert mock_modal["func"].last_payload is None + def test__given_submission_with_uk_alias_data__then_bundle_dataset_is_versioned_uri( self, mock_modal, client: TestClient ): mock_modal["dicts"]["simulation-api-uk-versions"] = { "latest": "2.66.0", - "2.66.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "2.66.0": "policyengine-simulation-py4-10-0", } request_body 
= { @@ -328,7 +411,7 @@ def test__given_submission_with_uk_alias_data__then_bundle_dataset_is_versioned_ assert response.status_code == 200 data = response.json() - assert data["policyengine_bundle"]["dataset"] == resolve_bundle_dataset_uri( + assert data["policyengine_bundle"]["dataset"] == resolve_test_dataset_uri( "uk", "enhanced_frs" ) @@ -337,7 +420,7 @@ def test__given_submission_with_runtime_bundle__then_accepts_internal_provenance ): mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } request_body = { @@ -372,7 +455,7 @@ def test__given_submission_with_unknown_alias_data__then_bundle_dataset_is_prese ): mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } request_body = { @@ -399,7 +482,7 @@ def test__given_submitted_job__then_job_status_includes_bundle_metadata( # Given mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } submit_response = client.post( @@ -420,7 +503,7 @@ def test__given_submitted_job__then_job_status_includes_bundle_metadata( data = response.json() assert data["status"] == "complete" assert "run_id" not in data - assert data["resolved_app_name"] == "policyengine-simulation-us1-500-0-uk2-66-0" + assert data["resolved_app_name"] == "policyengine-simulation-py4-10-0" assert data["policyengine_bundle"] == expected_bundle( "us", "1.500.0", @@ -432,7 +515,7 @@ def test__given_submitted_job_with_telemetry__then_polling_echoes_run_id( ): mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } submit_response = client.post( @@ 
-490,7 +573,7 @@ def test__given_running_job__then_polling_returns_202( ): mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } submit_response = client.post( @@ -514,7 +597,7 @@ def test__given_expired_modal_output__then_polling_returns_404( ): mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } submit_response = client.post( @@ -542,7 +625,7 @@ def test__given_modal_call_not_found__then_polling_returns_404( ): mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } submit_response = client.post( @@ -568,7 +651,7 @@ def test__given_worker_error__then_polling_returns_redacted_500( ): mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } submit_response = client.post( @@ -600,7 +683,7 @@ def test__given_budget_window_submission__then_returns_parent_batch_job_id( ): mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } response = client.post( @@ -618,7 +701,7 @@ def test__given_budget_window_submission__then_returns_parent_batch_job_id( assert response.status_code == 200 assert mock_modal["func"].last_from_name_call == ( - "policyengine-simulation-us1-500-0-uk2-66-0", + "policyengine-simulation-py4-10-0", "run_budget_window_batch", ) assert response.json() == { @@ -627,7 +710,7 @@ def test__given_budget_window_submission__then_returns_parent_batch_job_id( "poll_url": "/budget-window-jobs/mock-batch-job-id-123", "country": 
"us", "version": "1.500.0", - "resolved_app_name": "policyengine-simulation-us1-500-0-uk2-66-0", + "resolved_app_name": "policyengine-simulation-py4-10-0", "policyengine_bundle": expected_bundle("us", "1.500.0"), } @@ -636,7 +719,7 @@ def test__given_budget_window_submission__then_initial_poll_returns_seed_state( ): mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } submit_response = client.post( @@ -672,7 +755,7 @@ def test__given_budget_window_submission__then_initial_poll_returns_seed_state( "child_jobs": {}, "result": None, "error": None, - "resolved_app_name": "policyengine-simulation-us1-500-0-uk2-66-0", + "resolved_app_name": "policyengine-simulation-py4-10-0", "policyengine_bundle": expected_bundle("us", "1.500.0"), "run_id": "batch-run-123", } @@ -697,7 +780,7 @@ def test__given_batch_state__then_poll_returns_completed_response( region="us", version="1.500.0", target="general", - resolved_app_name="policyengine-simulation-us1-500-0-uk2-66-0", + resolved_app_name="policyengine-simulation-py4-10-0", policyengine_bundle=PolicyEngineBundle(model_version="1.500.0"), start_year="2026", window_size=2, @@ -880,7 +963,7 @@ def test__given_parent_call_raises__then_failure_persists_across_polls( mock_modal["dicts"]["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } submit_response = client.post( diff --git a/projects/policyengine-api-simulation/tests/gateway/test_models.py b/projects/policyengine-api-simulation/tests/gateway/test_models.py index 9bb56ecf9..d7747cb64 100644 --- a/projects/policyengine-api-simulation/tests/gateway/test_models.py +++ b/projects/policyengine-api-simulation/tests/gateway/test_models.py @@ -265,7 +265,7 @@ def test_job_submit_response_creates_with_all_fields(self): "poll_url": "/jobs/fc-abc123", 
"country": "us", "version": "1.459.0", - "resolved_app_name": "policyengine-simulation-us1-459-0-uk2-65-9", + "resolved_app_name": "policyengine-simulation-py3-9-0", "policyengine_bundle": { "model_version": "1.459.0", "policyengine_version": None, @@ -284,7 +284,7 @@ def test_job_submit_response_creates_with_all_fields(self): assert response.country == "us" assert response.version == "1.459.0" assert ( - response.resolved_app_name == "policyengine-simulation-us1-459-0-uk2-65-9" + response.resolved_app_name == "policyengine-simulation-py3-9-0" ) assert response.policyengine_bundle.model_version == "1.459.0" assert response.policyengine_bundle.policyengine_version is None @@ -348,7 +348,7 @@ def test_job_status_response_accepts_bundle_metadata(self): response = JobStatusResponse( status="complete", result={"budget": {"total": 1000000}}, - resolved_app_name="policyengine-simulation-us1-459-0-uk2-65-9", + resolved_app_name="policyengine-simulation-py3-9-0", policyengine_bundle={ "model_version": "1.459.0", "policyengine_version": None, @@ -358,7 +358,7 @@ def test_job_status_response_accepts_bundle_metadata(self): ) assert ( - response.resolved_app_name == "policyengine-simulation-us1-459-0-uk2-65-9" + response.resolved_app_name == "policyengine-simulation-py3-9-0" ) assert response.policyengine_bundle is not None assert response.policyengine_bundle.dataset == ( @@ -483,7 +483,7 @@ def test_budget_window_batch_submit_response_serializes_correctly(self): poll_url="/budget-window-jobs/bw-123", country="us", version="1.500.0", - resolved_app_name="policyengine-simulation-us1-500-0-uk2-66-0", + resolved_app_name="policyengine-simulation-py4-10-0", policyengine_bundle={ "model_version": "1.500.0", "dataset": "default", @@ -497,7 +497,7 @@ def test_budget_window_batch_submit_response_serializes_correctly(self): "poll_url": "/budget-window-jobs/bw-123", "country": "us", "version": "1.500.0", - "resolved_app_name": "policyengine-simulation-us1-500-0-uk2-66-0", + 
"resolved_app_name": "policyengine-simulation-py4-10-0", "policyengine_bundle": { "model_version": "1.500.0", "policyengine_version": None, diff --git a/projects/policyengine-api-simulation/tests/gateway/test_package_imports.py b/projects/policyengine-api-simulation/tests/gateway/test_package_imports.py index 3e40c4eec..efff842c7 100644 --- a/projects/policyengine-api-simulation/tests/gateway/test_package_imports.py +++ b/projects/policyengine-api-simulation/tests/gateway/test_package_imports.py @@ -1,4 +1,5 @@ import sys +import importlib def test_gateway_models_import_does_not_import_fastapi_endpoints( @@ -9,3 +10,22 @@ def test_gateway_models_import_does_not_import_fastapi_endpoints( assert gateway_import_module_names.endpoints not in sys.modules assert gateway_import_module_names.fastapi not in sys.modules + + +def test_gateway_endpoints_import_does_not_import_policyengine_bundle( + isolated_gateway_model_import_modules, + gateway_import_module_names, +): + release_bundle_module = sys.modules.pop( + "policyengine_api_simulation.release_bundle", None + ) + + try: + importlib.import_module(gateway_import_module_names.endpoints) + + assert "policyengine_api_simulation.release_bundle" not in sys.modules + finally: + if release_bundle_module is not None: + sys.modules["policyengine_api_simulation.release_bundle"] = ( + release_bundle_module + ) diff --git a/projects/policyengine-api-simulation/tests/test_budget_window_batch.py b/projects/policyengine-api-simulation/tests/test_budget_window_batch.py index 49549b4e7..37214426f 100644 --- a/projects/policyengine-api-simulation/tests/test_budget_window_batch.py +++ b/projects/policyengine-api-simulation/tests/test_budget_window_batch.py @@ -163,7 +163,7 @@ def _build_parent_payload(*, window_size: int = 3): payload["_telemetry"] = request.telemetry.model_dump(mode="json") payload["_metadata"] = { "resolved_version": "1.500.0", - "resolved_app_name": "policyengine-simulation-us1-500-0-uk2-66-0", + "resolved_app_name": 
"policyengine-simulation-py4-10-0", "policyengine_bundle": PolicyEngineBundle(model_version="1.500.0").model_dump( mode="json" ), @@ -176,7 +176,7 @@ def _seed_parent_batch(request: BudgetWindowBatchRequest, batch_job_id: str): batch_job_id=batch_job_id, request=request, resolved_version="1.500.0", - resolved_app_name="policyengine-simulation-us1-500-0-uk2-66-0", + resolved_app_name="policyengine-simulation-py4-10-0", bundle=PolicyEngineBundle(model_version="1.500.0"), ) state_module.put_batch_job_seed(seed) @@ -227,7 +227,7 @@ def test_run_budget_window_batch_impl_completes_and_respects_max_parallel( call_registry=mock_batch_modal["call_registry"], ) mock_batch_modal["functions"][ - ("policyengine-simulation-us1-500-0-uk2-66-0", "run_simulation") + ("policyengine-simulation-py4-10-0", "run_simulation") ] = run_simulation result = run_budget_window_batch_impl(payload) @@ -269,7 +269,7 @@ def test_run_budget_window_batch_impl_marks_failure(mock_batch_modal): call_registry=mock_batch_modal["call_registry"], ) mock_batch_modal["functions"][ - ("policyengine-simulation-us1-500-0-uk2-66-0", "run_simulation") + ("policyengine-simulation-py4-10-0", "run_simulation") ] = run_simulation result = run_budget_window_batch_impl(payload) @@ -324,7 +324,7 @@ def test_scheduler_sleep_exponentially_backs_off_then_resets_on_progress( call_registry=mock_batch_modal["call_registry"], ) mock_batch_modal["functions"][ - ("policyengine-simulation-us1-500-0-uk2-66-0", "run_simulation") + ("policyengine-simulation-py4-10-0", "run_simulation") ] = run_simulation sleeps: list[float] = [] @@ -335,7 +335,7 @@ def test_scheduler_sleep_exponentially_backs_off_then_resets_on_progress( batch_job_id=mock_batch_modal["parent_call_id"], request=request, resolved_version="1.500.0", - resolved_app_name="policyengine-simulation-us1-500-0-uk2-66-0", + resolved_app_name="policyengine-simulation-py4-10-0", bundle=PolicyEngineBundle(model_version="1.500.0"), raw_params=payload, ), @@ -373,7 +373,7 @@ def 
test_run_budget_window_batch_impl_fails_on_malformed_child_result( call_registry=mock_batch_modal["call_registry"], ) mock_batch_modal["functions"][ - ("policyengine-simulation-us1-500-0-uk2-66-0", "run_simulation") + ("policyengine-simulation-py4-10-0", "run_simulation") ] = run_simulation result = run_budget_window_batch_impl(payload) diff --git a/projects/policyengine-api-simulation/tests/test_budget_window_context.py b/projects/policyengine-api-simulation/tests/test_budget_window_context.py index fbb086a6e..9df03e1e7 100644 --- a/projects/policyengine-api-simulation/tests/test_budget_window_context.py +++ b/projects/policyengine-api-simulation/tests/test_budget_window_context.py @@ -27,7 +27,7 @@ def _build_parent_payload(): payload["_telemetry"] = request.telemetry.model_dump(mode="json") payload["_metadata"] = { "resolved_version": "1.500.0", - "resolved_app_name": "policyengine-simulation-us1-500-0-uk2-66-0", + "resolved_app_name": "policyengine-simulation-py4-10-0", "policyengine_bundle": PolicyEngineBundle(model_version="1.500.0").model_dump( mode="json" ), @@ -49,7 +49,7 @@ def test_build_batch_context_extracts_request_and_metadata(): assert context.request.telemetry is not None assert context.request.telemetry.run_id == "batch-run-123" assert context.resolved_version == "1.500.0" - assert context.resolved_app_name == "policyengine-simulation-us1-500-0-uk2-66-0" + assert context.resolved_app_name == "policyengine-simulation-py4-10-0" assert context.bundle == PolicyEngineBundle(model_version="1.500.0") diff --git a/projects/policyengine-api-simulation/tests/test_budget_window_scheduler.py b/projects/policyengine-api-simulation/tests/test_budget_window_scheduler.py index 7ac9422aa..49bcc43c2 100644 --- a/projects/policyengine-api-simulation/tests/test_budget_window_scheduler.py +++ b/projects/policyengine-api-simulation/tests/test_budget_window_scheduler.py @@ -137,7 +137,7 @@ def budget_window_semi_integration_client( runtime = SemiIntegrationRuntime() 
runtime.dicts["simulation-api-us-versions"] = { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } class MockModalDict: diff --git a/projects/policyengine-api-simulation/tests/test_gcp_credentials.py b/projects/policyengine-api-simulation/tests/test_gcp_credentials.py index 371c2fe2f..074cf0abf 100644 --- a/projects/policyengine-api-simulation/tests/test_gcp_credentials.py +++ b/projects/policyengine-api-simulation/tests/test_gcp_credentials.py @@ -1,4 +1,4 @@ -"""Tests for GCP credentials setup in ``src.modal.simulation``.""" +"""Tests for GCP credentials setup in ``policyengine_api_simulation.simulation_runtime``.""" from __future__ import annotations @@ -7,7 +7,7 @@ import pytest -from src.modal.simulation import ( +from policyengine_api_simulation.simulation_runtime import ( _normalize_credentials_blob, setup_gcp_credentials, ) diff --git a/projects/policyengine-api-simulation/tests/test_hf_dataset.py b/projects/policyengine-api-simulation/tests/test_hf_dataset.py new file mode 100644 index 000000000..a343654e7 --- /dev/null +++ b/projects/policyengine-api-simulation/tests/test_hf_dataset.py @@ -0,0 +1,104 @@ +"""Tests for Hugging Face dataset revision validation helpers.""" + +from __future__ import annotations + +import json + +import pytest + +import policyengine_api_simulation.hf_dataset as hf_dataset +from policyengine_api_simulation.hf_dataset import ( + HuggingFaceDatasetReferenceError, + parse_hf_dataset_uri, + validate_hf_dataset_uri, + with_hf_revision, +) + + +class _FakeResponse: + def __init__(self, payload: dict): + self.payload = payload + + def __enter__(self): + return self + + def __exit__(self, *args): + return None + + def read(self): + return json.dumps(self.payload).encode("utf-8") + + +def test_parse_hf_dataset_uri_extracts_repo_path_and_revision(): + parsed = parse_hf_dataset_uri( + "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0" + ) + + 
assert parsed is not None + assert parsed.repo_id == "policyengine/policyengine-us-data" + assert parsed.path == "enhanced_cps_2024.h5" + assert parsed.revision == "1.77.0" + + +def test_fetch_hf_dataset_revision_uses_dataset_revision_api(monkeypatch): + hf_dataset._fetch_hf_dataset_revision.cache_clear() + seen = {} + + def fake_urlopen(request, timeout): + seen["url"] = request.full_url + seen["headers"] = dict(request.header_items()) + seen["timeout"] = timeout + return _FakeResponse({"sha": "abc123", "siblings": []}) + + monkeypatch.setattr(hf_dataset, "urlopen", fake_urlopen) + + payload = hf_dataset._fetch_hf_dataset_revision( + "policyengine/policyengine-us-data", + "1.77.0", + "hf-token", + ) + + assert payload == {"sha": "abc123", "siblings": []} + assert seen["url"] == ( + "https://huggingface.co/api/datasets/" + "policyengine/policyengine-us-data/revision/1.77.0" + ) + assert seen["headers"]["Authorization"] == "Bearer hf-token" + assert seen["timeout"] == hf_dataset.HF_REQUEST_TIMEOUT_SECONDS + + +def test_validate_hf_dataset_uri_rejects_revision_missing_artifact(monkeypatch): + monkeypatch.setattr( + hf_dataset, + "_fetch_hf_dataset_revision", + lambda repo_id, revision, token: {"siblings": [{"rfilename": "other_file.h5"}]}, + ) + + with pytest.raises( + HuggingFaceDatasetReferenceError, + match="does not contain artifact", + ): + validate_hf_dataset_uri( + "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0" + ) + + +def test_with_hf_revision_validates_and_preserves_requested_revision(monkeypatch): + calls = [] + + def fake_validate(dataset_uri): + calls.append(dataset_uri) + return dataset_uri + + monkeypatch.setattr(hf_dataset, "validate_hf_dataset_uri", fake_validate) + + assert ( + with_hf_revision( + "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12", + "1.77.0", + ) + == "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0" + ) + assert calls == [ + 
"hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0" + ] diff --git a/projects/policyengine-api-simulation/tests/test_modal_telemetry.py b/projects/policyengine-api-simulation/tests/test_modal_telemetry.py index e65831727..59155537d 100644 --- a/projects/policyengine-api-simulation/tests/test_modal_telemetry.py +++ b/projects/policyengine-api-simulation/tests/test_modal_telemetry.py @@ -1,4 +1,4 @@ -from src.modal.telemetry import TelemetryEnvelope, split_internal_payload +from policyengine_api_simulation.telemetry import TelemetryEnvelope, split_internal_payload def test_split_internal_payload__removes_internal_fields(): diff --git a/projects/policyengine-api-simulation/tests/test_policyengine_dependency_source.py b/projects/policyengine-api-simulation/tests/test_policyengine_dependency_source.py index 2927a7b92..d6c69c480 100644 --- a/projects/policyengine-api-simulation/tests/test_policyengine_dependency_source.py +++ b/projects/policyengine-api-simulation/tests/test_policyengine_dependency_source.py @@ -72,8 +72,16 @@ def test_modal_app_reads_policyengine_pins_from_pyproject(): assert '"policyengine-core"' in modal_source +def test_modal_app_name_is_keyed_to_policyengine_py_version(): + modal_source = MODAL_APP_PATH.read_text(encoding="utf-8") + + assert "def get_app_name(policyengine_version: str)" in modal_source + assert "policyengine-simulation-py" in modal_source + assert "policyengine-simulation-us" not in modal_source + + def test_country_package_pins_match_policyengine_bundle(): - from src.modal.release_bundle import get_country_release_bundle + from policyengine_api_simulation.release_bundle import get_country_release_bundle pyproject = _load_toml(PYPROJECT_PATH) diff --git a/projects/policyengine-api-simulation/tests/test_release_bundle.py b/projects/policyengine-api-simulation/tests/test_release_bundle.py index 4257c6293..0fed12ae0 100644 --- a/projects/policyengine-api-simulation/tests/test_release_bundle.py +++ 
b/projects/policyengine-api-simulation/tests/test_release_bundle.py @@ -1,12 +1,26 @@ """Tests for policyengine.py release bundle helpers.""" -from src.modal.release_bundle import ( +import pytest + +from policyengine_api_simulation.release_bundle import ( get_country_release_bundle, resolve_bundle_dataset_name, resolve_bundle_dataset_uri, ) +@pytest.fixture(autouse=True) +def stub_hf_revision_validation(monkeypatch): + monkeypatch.setattr( + "policyengine_api_simulation.release_bundle.with_hf_revision", + lambda dataset_uri, revision: ( + f"{dataset_uri.rsplit('@', maxsplit=1)[0]}@{revision}" + if dataset_uri.startswith("hf://") + else dataset_uri + ), + ) + + def test_country_release_bundle_exposes_model_and_data_versions(): us_bundle = get_country_release_bundle("us") uk_bundle = get_country_release_bundle("uk") @@ -50,11 +64,15 @@ def test_resolve_bundle_dataset_uri_preserves_explicit_dataset_uri_and_revision( assert resolve_bundle_dataset_uri("us", uri) == uri -def test_resolve_bundle_dataset_uri_preserves_explicit_logical_revision(): +def test_resolve_bundle_dataset_uri_maps_explicit_logical_revision_to_hf_uri(): dataset = "enhanced_cps_2024@1.110.12" - assert resolve_bundle_dataset_name("us", dataset) == dataset - assert resolve_bundle_dataset_uri("us", dataset) == dataset + assert resolve_bundle_dataset_name("us", dataset).startswith( + "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12" + ) + assert resolve_bundle_dataset_uri("us", dataset).startswith( + "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12" + ) def test_resolve_bundle_dataset_uri_preserves_explicit_gcs_uri(): @@ -77,3 +95,8 @@ def test_resolve_bundle_dataset_uri_preserves_unmanaged_unknown_values(): assert resolve_bundle_dataset_uri("us", "custom_dataset_label") == ( "custom_dataset_label" ) + + +def test_resolve_bundle_dataset_uri_rejects_unknown_logical_revision(): + with pytest.raises(ValueError, match="Unknown dataset revision reference"): + 
resolve_bundle_dataset_uri("us", "custom_dataset_label@1.0.0") diff --git a/projects/policyengine-api-simulation/tests/test_simulation_api_contracts.py b/projects/policyengine-api-simulation/tests/test_simulation_api_contracts.py index d6418bc82..777bdd86c 100644 --- a/projects/policyengine-api-simulation/tests/test_simulation_api_contracts.py +++ b/projects/policyengine-api-simulation/tests/test_simulation_api_contracts.py @@ -7,7 +7,7 @@ BudgetWindowTotals, JobStatusResponse, ) -from src.modal.simulation_macro_output import SingleYearMacroOutput +from policyengine_api_simulation.simulation_macro_output import SingleYearMacroOutput from fixtures.test_simulation_api_contracts import ( CURRENT_REQUIRED_BUDGET_KEYS, diff --git a/projects/policyengine-api-simulation/tests/test_simulation_output_builder.py b/projects/policyengine-api-simulation/tests/test_simulation_output_builder.py index c5558a396..31066f354 100644 --- a/projects/policyengine-api-simulation/tests/test_simulation_output_builder.py +++ b/projects/policyengine-api-simulation/tests/test_simulation_output_builder.py @@ -21,9 +21,12 @@ REFORM_POVERTY_BY_RACE, fake_analysis, ) -from src.modal.simulation import _normalise_policy -from src.modal.simulation import _run_simulation_impl_core -from src.modal.simulation_macro_output import ( +from policyengine_api_simulation.simulation_runtime import RegionResolution +from policyengine_api_simulation.simulation_runtime import _normalise_policy +from policyengine_api_simulation.simulation_runtime import _resolve_dataset_reference +from policyengine_api_simulation.simulation_runtime import _resolve_region +from policyengine_api_simulation.simulation_runtime import _run_simulation_impl_core +from policyengine_api_simulation.simulation_macro_output import ( BudgetaryImpact, BudgetaryOutput, DecileOutput, @@ -33,7 +36,9 @@ PovertyOutput, SingleYearMacroOutput, ) -from src.modal.simulation_output_builder import SimulationOutputBuilder +from 
policyengine_api_simulation.simulation_output_builder import ( + SimulationOutputBuilder, +) class _FakeOutputDataset: @@ -127,7 +132,7 @@ def compute(simulation): return compute monkeypatch.setattr( - "src.modal.simulation_output_builder._poverty_module_function", + "policyengine_api_simulation.simulation_output_builder._poverty_module_function", fake_poverty_module_function, ) monkeypatch.setattr( @@ -251,8 +256,8 @@ def fake_country_module(country): assert country == "us" return country_module - def fake_build_simulation(params, *, dataset, policy): - build_calls.append((params, dataset, policy)) + def fake_build_simulation(params, *, dataset, policy, scoping_strategy=None): + build_calls.append((params, dataset, policy, scoping_strategy)) return baseline_simulation if len(build_calls) == 1 else reform_simulation class FakeSimulationOutputBuilder: @@ -262,11 +267,24 @@ def __init__(self, **kwargs): def serialize(self): return CURRENT_SINGLE_YEAR_MACRO_RESULT - monkeypatch.setattr("src.modal.simulation._country_module", fake_country_module) - monkeypatch.setattr("src.modal.simulation._load_dataset", lambda params: dataset) - monkeypatch.setattr("src.modal.simulation._build_simulation", fake_build_simulation) monkeypatch.setattr( - "src.modal.simulation.SimulationOutputBuilder", + "policyengine_api_simulation.simulation_runtime._country_module", + fake_country_module, + ) + monkeypatch.setattr( + "policyengine_api_simulation.simulation_runtime._resolve_region", + lambda **kwargs: RegionResolution(code="us", dataset_reference="dataset"), + ) + monkeypatch.setattr( + "policyengine_api_simulation.simulation_runtime._load_dataset", + lambda params, country_module, region_resolution: dataset, + ) + monkeypatch.setattr( + "policyengine_api_simulation.simulation_runtime._build_simulation", + fake_build_simulation, + ) + monkeypatch.setattr( + "policyengine_api_simulation.simulation_runtime.SimulationOutputBuilder", FakeSimulationOutputBuilder, ) @@ -281,6 +299,8 @@ def 
serialize(self): assert result == CURRENT_SINGLE_YEAR_MACRO_RESULT assert build_calls[0][2] == {"gov.test.parameter": {"2026-01-01": 1}} assert build_calls[1][2] == {"gov.test.parameter": {"2026-01-01": 2}} + assert build_calls[0][3] is None + assert build_calls[1][3] is None assert builder_calls == [ { "country": "us", @@ -293,10 +313,160 @@ def serialize(self): "dataset": dataset, "baseline": baseline_simulation, "reform": reform_simulation, + "resolved_data_version": None, } ] +def test_resolve_region_uses_dedicated_region_dataset_with_requested_version( + monkeypatch, +): + monkeypatch.setattr( + "policyengine_api_simulation.simulation_runtime.with_hf_revision", + lambda dataset_uri, revision: ( + f"{dataset_uri.rsplit('@', maxsplit=1)[0]}@{revision}" + ), + ) + state = SimpleNamespace( + dataset_path="hf://policyengine/policyengine-us-data/states/CA.h5@1.110.12", + scoping_strategy=None, + parent_code="us", + ) + country_module = SimpleNamespace( + model=SimpleNamespace( + get_region=lambda code: state if code == "state/ca" else None + ) + ) + + resolution = _resolve_region( + country_module=country_module, + country="us", + params={"region": "state/ca", "data_version": "1.77.0"}, + ) + + assert resolution.code == "state/ca" + assert resolution.dataset_reference == ( + "hf://policyengine/policyengine-us-data/states/CA.h5@1.77.0" + ) + assert resolution.scoping_strategy is None + + +def test_resolve_dataset_reference_applies_data_version_to_logical_dataset( + monkeypatch, +): + monkeypatch.setattr( + "policyengine_api_simulation.simulation_runtime.resolve_bundle_dataset_uri", + lambda country, requested_data: ( + "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12" + ), + ) + monkeypatch.setattr( + "policyengine_api_simulation.simulation_runtime.with_hf_revision", + lambda dataset_uri, revision: ( + f"{dataset_uri.rsplit('@', maxsplit=1)[0]}@{revision}" + ), + ) + + assert ( + _resolve_dataset_reference( + "us", + {"data": "enhanced_cps_2024", 
"data_version": "1.77.0"}, + ) + == "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0" + ) + + +def test_resolve_region_scopes_us_place_from_parent_state_dataset(monkeypatch): + scoping_strategy = object() + place = SimpleNamespace( + dataset_path=None, + scoping_strategy=scoping_strategy, + parent_code="state/ca", + ) + state = SimpleNamespace( + dataset_path="hf://policyengine/policyengine-us-data/states/CA.h5@1.110.12", + scoping_strategy=None, + parent_code="us", + ) + regions = {"place/CA-57000": place, "state/ca": state} + country_module = SimpleNamespace( + model=SimpleNamespace(get_region=lambda code: regions.get(code)) + ) + + resolution = _resolve_region( + country_module=country_module, + country="us", + params={"region": "place/ca-57000"}, + ) + + assert resolution.code == "place/CA-57000" + assert resolution.dataset_reference == ( + "hf://policyengine/policyengine-us-data/states/CA.h5@1.110.12" + ) + assert resolution.scoping_strategy is scoping_strategy + + +def test_resolve_region_scopes_uk_country_from_national_dataset(): + scoping_strategy = object() + england = SimpleNamespace( + dataset_path=None, + scoping_strategy=scoping_strategy, + parent_code="uk", + ) + uk = SimpleNamespace( + dataset_path="hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3", + scoping_strategy=None, + parent_code=None, + ) + regions = {"country/england": england, "uk": uk} + country_module = SimpleNamespace( + model=SimpleNamespace(get_region=lambda code: regions.get(code)) + ) + + resolution = _resolve_region( + country_module=country_module, + country="uk", + params={"region": "England"}, + ) + + assert resolution.code == "country/england" + assert resolution.dataset_reference == ( + "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3" + ) + assert resolution.scoping_strategy is scoping_strategy + + +def test_builder_data_version_prefers_resolved_revision_then_dataset_metadata(): + baseline, reform = 
_macro_baseline_reform() + country_module = SimpleNamespace( + model=SimpleNamespace(version="1.700.0"), + economic_impact_analysis=lambda baseline_simulation, reform_simulation: ( + fake_analysis() + ), + ) + + resolved_builder = SimulationOutputBuilder( + country="us", + simulation_params={"country": "us"}, + country_module=country_module, + dataset=SimpleNamespace(metadata={"version": "metadata-version"}), + baseline=baseline, + reform=reform, + resolved_data_version="1.77.0", + ) + metadata_builder = SimulationOutputBuilder( + country="us", + simulation_params={"country": "us"}, + country_module=country_module, + dataset=SimpleNamespace(metadata={"version": "metadata-version"}), + baseline=baseline, + reform=reform, + ) + + assert resolved_builder._data_version() == "1.77.0" + assert metadata_builder._data_version() == "metadata-version" + + def test_builder_budgetary_impact_uses_materialized_columns_and_uk_state_tax_zero(): baseline = _FakeSimulation( pd.DataFrame( @@ -347,7 +517,7 @@ def fail_change_output_variable(*args, **kwargs): raise RuntimeError("household_tax missing") monkeypatch.setattr( - "src.modal.simulation_output_builder._change_output_variable", + "policyengine_api_simulation.simulation_output_builder._change_output_variable", fail_change_output_variable, ) @@ -372,7 +542,7 @@ def compute(baseline_simulation, reform_simulation): return compute monkeypatch.setattr( - "src.modal.simulation_output_builder._output_module_function", + "policyengine_api_simulation.simulation_output_builder._output_module_function", fake_output_module_function, ) @@ -407,7 +577,7 @@ def compute(baseline_simulation, reform_simulation): return compute monkeypatch.setattr( - "src.modal.simulation_output_builder._output_module_function", + "policyengine_api_simulation.simulation_output_builder._output_module_function", fake_output_module_function, ) diff --git a/projects/policyengine-api-simulation/tests/test_standalone_simulation_contract.py 
b/projects/policyengine-api-simulation/tests/test_standalone_simulation_contract.py index cc17db53c..ca182f704 100644 --- a/projects/policyengine-api-simulation/tests/test_standalone_simulation_contract.py +++ b/projects/policyengine-api-simulation/tests/test_standalone_simulation_contract.py @@ -1,11 +1,34 @@ """Contract tests for the live synchronous simulation FastAPI app.""" +from importlib import import_module +from pathlib import Path + from fastapi.testclient import TestClient from fixtures.test_simulation_api_contracts import CURRENT_SINGLE_YEAR_MACRO_RESULT from policyengine_api_simulation.main import app +PACKAGED_RUNTIME_MODULES = ( + "policyengine_api_simulation.compat_models", + "policyengine_api_simulation.hf_dataset", + "policyengine_api_simulation.release_bundle", + "policyengine_api_simulation.simulation", + "policyengine_api_simulation.simulation_macro_output", + "policyengine_api_simulation.simulation_output_builder", + "policyengine_api_simulation.simulation_runtime", + "policyengine_api_simulation.telemetry", +) + + +def test_standalone_package_runtime_does_not_import_unpackaged_modal_source(): + for module_name in PACKAGED_RUNTIME_MODULES: + module = import_module(module_name) + source = Path(module.__file__).read_text(encoding="utf-8") + + assert "src.modal" not in source + + def test_standalone_simulation_openapi_keeps_legacy_schema_names(): spec = app.openapi() route = spec["paths"]["/simulate/economy/comparison"]["post"] diff --git a/projects/policyengine-api-simulation/tests/test_update_version_registry.py b/projects/policyengine-api-simulation/tests/test_update_version_registry.py index 2cc588602..c08714a1e 100644 --- a/projects/policyengine-api-simulation/tests/test_update_version_registry.py +++ b/projects/policyengine-api-simulation/tests/test_update_version_registry.py @@ -64,7 +64,7 @@ def test_update_version_dict__keeps_latest_when_incoming_older(patched_modal): stores["main/simulation-api-us-versions"] = FakeDict( { "latest": 
"1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } ) @@ -72,12 +72,12 @@ def test_update_version_dict__keeps_latest_when_incoming_older(patched_modal): "simulation-api-us-versions", "main", "1.400.0", - "policyengine-simulation-us1-400-0-uk2-66-0", + "policyengine-simulation-py3-8-0", ) snapshot = stores["main/simulation-api-us-versions"].snapshot() assert snapshot["latest"] == "1.500.0" - assert snapshot["1.400.0"] == "policyengine-simulation-us1-400-0-uk2-66-0" + assert snapshot["1.400.0"] == "policyengine-simulation-py3-8-0" def test_update_version_dict__advances_latest_when_incoming_newer(patched_modal): @@ -85,7 +85,7 @@ def test_update_version_dict__advances_latest_when_incoming_newer(patched_modal) stores["main/simulation-api-us-versions"] = FakeDict( { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } ) @@ -93,7 +93,7 @@ def test_update_version_dict__advances_latest_when_incoming_newer(patched_modal) "simulation-api-us-versions", "main", "1.601.2", - "policyengine-simulation-us1-601-2-uk2-70-0", + "policyengine-simulation-py4-11-0", ) snapshot = stores["main/simulation-api-us-versions"].snapshot() @@ -105,7 +105,7 @@ def test_update_version_dict__force_latest_allows_downgrade(patched_modal): stores["main/simulation-api-us-versions"] = FakeDict( { "latest": "1.500.0", - "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0", + "1.500.0": "policyengine-simulation-py4-10-0", } ) @@ -113,7 +113,7 @@ def test_update_version_dict__force_latest_allows_downgrade(patched_modal): "simulation-api-us-versions", "main", "1.400.0", - "policyengine-simulation-us1-400-0-uk2-66-0", + "policyengine-simulation-py3-8-0", force_latest=True, ) @@ -128,9 +128,41 @@ def test_update_version_dict__new_registry_sets_latest_even_without_force( "simulation-api-uk-versions", "staging", "2.66.0", - 
"policyengine-simulation-us1-500-0-uk2-66-0", + "policyengine-simulation-py4-10-0", ) snapshot = patched_modal["staging/simulation-api-uk-versions"].snapshot() assert snapshot["latest"] == "2.66.0" - assert snapshot["2.66.0"] == "policyengine-simulation-us1-500-0-uk2-66-0" + assert snapshot["2.66.0"] == "policyengine-simulation-py4-10-0" + + +def test_put_app_release_bundle_metadata_records_app_and_py_version_aliases( + patched_modal, + monkeypatch, +): + def fake_country_bundle_metadata(country: str) -> dict: + return { + "country": country, + "model_version": "1.0.0" if country == "us" else "2.0.0", + "data_version": "3.0.0" if country == "us" else "4.0.0", + "dataset_uris": {"default": f"hf://datasets/policyengine/{country}"}, + "dataset_aliases": {"alias": "default"}, + } + + monkeypatch.setattr( + registry, "_country_bundle_metadata", fake_country_bundle_metadata + ) + + registry.put_app_release_bundle_metadata( + environment="main", + app_name="policyengine-simulation-py4-10-0", + policyengine_version="4.10.0", + ) + + snapshot = patched_modal[ + "main/simulation-api-app-release-bundles" + ].snapshot() + metadata = snapshot["policyengine-simulation-py4-10-0"] + assert snapshot["4.10.0"] == metadata + assert metadata["policyengine_version"] == "4.10.0" + assert metadata["us"]["dataset_aliases"] == {"alias": "default"} From e665b625c1de37d44004157639b7b8cf0faacf33 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 28 May 2026 14:45:36 +0200 Subject: [PATCH 16/23] feat: enable sim API cliff impacts --- .../pyproject.toml | 2 +- .../src/modal/gateway/models.py | 5 + .../simulation_macro_output.py | 12 +- .../simulation_output_builder.py | 23 +++- .../tests/gateway/test_endpoints.py | 46 +++++++ .../tests/gateway/test_models.py | 10 ++ .../tests/test_simulation_output_builder.py | 117 +++++++++++++++--- .../test_standalone_simulation_contract.py | 22 ++++ projects/policyengine-api-simulation/uv.lock | 8 +- 9 files changed, 223 insertions(+), 22 deletions(-) 
diff --git a/projects/policyengine-api-simulation/pyproject.toml b/projects/policyengine-api-simulation/pyproject.toml index 6eb4aadde..2f8df2fce 100644 --- a/projects/policyengine-api-simulation/pyproject.toml +++ b/projects/policyengine-api-simulation/pyproject.toml @@ -16,7 +16,7 @@ dependencies = [ "pydantic-settings (>=2.7.1,<3.0.0)", "opentelemetry-instrumentation-fastapi (>=0.51b0,<0.52)", "policyengine-fastapi", - "policyengine==4.10.0", + "policyengine==4.12.0", "policyengine-core==3.26.1", "policyengine-uk==2.88.20", "policyengine-us==1.700.0", diff --git a/projects/policyengine-api-simulation/src/modal/gateway/models.py b/projects/policyengine-api-simulation/src/modal/gateway/models.py index 47a235369..40f268fd2 100644 --- a/projects/policyengine-api-simulation/src/modal/gateway/models.py +++ b/projects/policyengine-api-simulation/src/modal/gateway/models.py @@ -184,6 +184,11 @@ def validate_end_year(self) -> "BudgetWindowBatchRequest": raise ValueError( f"budget-window end_year must be {self.MAX_END_YEAR} or earlier" ) + if self.include_cliffs is True: + raise ValueError( + "budget-window cliff impacts are not supported; use the single-year " + "simulation endpoint with include_cliffs=true" + ) return self diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_macro_output.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_macro_output.py index 7ef198b71..b0afe98a7 100644 --- a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_macro_output.py +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_macro_output.py @@ -110,6 +110,16 @@ class LaborSupplyResponseOutput(MacroRootModel[dict[str, Any]]): pass +class CliffImpactInSimulation(MacroOutputModel): + cliff_gap: float + cliff_share: float + + +class CliffImpactOutput(MacroOutputModel): + baseline: CliffImpactInSimulation + reform: CliffImpactInSimulation + + class 
GeographicImpactOutput(MacroRootModel[list[dict[str, Any]]]): pass @@ -131,4 +141,4 @@ class SingleYearMacroOutput(MacroOutputModel): constituency_impact: GeographicImpactOutput | None local_authority_impact: GeographicImpactOutput | None congressional_district_impact: GeographicImpactOutput | None - cliff_impact: None = None + cliff_impact: CliffImpactOutput | None = None diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_builder.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_builder.py index 1b60ab5e8..8a2c2d7d9 100644 --- a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_builder.py +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_builder.py @@ -14,6 +14,8 @@ AgePovertyOutput, BaselineReformValue, BudgetaryImpact, + CliffImpactInSimulation, + CliffImpactOutput, DecileOutput, DetailedBudgetOutput, DetailedBudgetProgramOutput, @@ -252,10 +254,15 @@ def __post_init__(self) -> None: def analysis(self) -> Any: if self._analysis is None: self._analysis = self.country_module.economic_impact_analysis( - self.baseline, self.reform + self.baseline, + self.reform, + include_cliff_impacts=self._include_cliff_impacts(), ) return self._analysis + def _include_cliff_impacts(self) -> bool: + return self.simulation_params.get("include_cliffs") is True + def build(self) -> SingleYearMacroOutput: poverty_outputs = self._build_poverty_outputs() wealth_decile = getattr(self.analysis, "wealth_decile_impacts", None) @@ -280,7 +287,7 @@ def build(self) -> SingleYearMacroOutput: congressional_district_impact=(self._build_congressional_district_impact()), constituency_impact=self._build_uk_constituency_impact(), local_authority_impact=self._build_uk_local_authority_impact(), - cliff_impact=None, + cliff_impact=self._build_cliff_impact(), ) def serialize(self) -> dict[str, Any]: @@ -448,6 +455,18 @@ def 
_build_labor_supply_response(self) -> LaborSupplyResponseOutput | None: output = _output_model_dump(labor_supply_response) return LaborSupplyResponseOutput(output) if isinstance(output, dict) else None + def _build_cliff_impact(self) -> CliffImpactOutput | None: + cliff_impact = getattr(self.analysis, "cliff_impact", None) + if isinstance(cliff_impact, CliffImpactOutput): + return cliff_impact + output = _output_model_dump(cliff_impact) + if not isinstance(output, Mapping): + return None + return CliffImpactOutput( + baseline=CliffImpactInSimulation(**output["baseline"]), + reform=CliffImpactInSimulation(**output["reform"]), + ) + def _build_geographic_impact_output( self, value: Any ) -> GeographicImpactOutput | None: diff --git a/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py b/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py index 4653112c1..566c6257d 100644 --- a/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py +++ b/projects/policyengine-api-simulation/tests/gateway/test_endpoints.py @@ -229,6 +229,27 @@ def test__given_submission__then_returns_job_id_and_poll_url( assert data["poll_url"] == "/jobs/mock-job-id-123" assert data["status"] == "submitted" + def test__given_submission_with_include_cliffs__then_forwards_worker_flag( + self, mock_modal, client: TestClient + ): + mock_modal["dicts"]["simulation-api-us-versions"] = { + "latest": "1.500.0", + "1.500.0": "policyengine-simulation-py4-10-0", + } + + response = client.post( + "/simulate/economy/comparison", + json={ + "country": "us", + "scope": "macro", + "reform": {}, + "include_cliffs": True, + }, + ) + + assert response.status_code == 200 + assert mock_modal["func"].last_payload["include_cliffs"] is True + def test__given_submission_with_telemetry__then_preserves_run_id( self, mock_modal, client: TestClient ): @@ -714,6 +735,31 @@ def test__given_budget_window_submission__then_returns_parent_batch_job_id( "policyengine_bundle": 
expected_bundle("us", "1.500.0"), } + def test__given_budget_window_include_cliffs__then_returns_422( + self, mock_modal, client: TestClient + ): + mock_modal["dicts"]["simulation-api-us-versions"] = { + "latest": "1.500.0", + "1.500.0": "policyengine-simulation-py4-10-0", + } + + response = client.post( + "/simulate/economy/budget-window", + json={ + "country": "us", + "region": "us", + "scope": "macro", + "reform": {}, + "start_year": "2026", + "window_size": 3, + "include_cliffs": True, + }, + ) + + assert response.status_code == 422 + assert "cliff impacts are not supported" in response.text + assert mock_modal["func"].last_payload is None + def test__given_budget_window_submission__then_initial_poll_returns_seed_state( self, mock_modal, client: TestClient ): diff --git a/projects/policyengine-api-simulation/tests/gateway/test_models.py b/projects/policyengine-api-simulation/tests/gateway/test_models.py index d7747cb64..a8d6717fb 100644 --- a/projects/policyengine-api-simulation/tests/gateway/test_models.py +++ b/projects/policyengine-api-simulation/tests/gateway/test_models.py @@ -429,6 +429,16 @@ def test_budget_window_batch_request_rejects_non_general_target(self): target="cliff", ) + def test_budget_window_batch_request_rejects_include_cliffs(self): + with pytest.raises(ValidationError, match="cliff impacts are not supported"): + BudgetWindowBatchRequest( + country="us", + region="us", + start_year="2026", + window_size=3, + include_cliffs=True, + ) + def test_budget_window_batch_request_rejects_max_parallel_above_active_limit(self): with pytest.raises(ValidationError): BudgetWindowBatchRequest( diff --git a/projects/policyengine-api-simulation/tests/test_simulation_output_builder.py b/projects/policyengine-api-simulation/tests/test_simulation_output_builder.py index 31066f354..550032a7b 100644 --- a/projects/policyengine-api-simulation/tests/test_simulation_output_builder.py +++ b/projects/policyengine-api-simulation/tests/test_simulation_output_builder.py 
@@ -19,6 +19,7 @@ REFORM_POVERTY_BY_AGE, REFORM_POVERTY_BY_GENDER, REFORM_POVERTY_BY_RACE, + FakeModelOutput, fake_analysis, ) from policyengine_api_simulation.simulation_runtime import RegionResolution @@ -85,20 +86,32 @@ def _simulation_output_builder( baseline, reform, analysis=None, + include_cliffs: bool | None = None, ) -> SimulationOutputBuilder: analysis = analysis or fake_analysis() + + def economic_impact_analysis( + baseline_simulation, + reform_simulation, + *, + include_cliff_impacts=False, + ): + return analysis + + simulation_params = { + "country": country, + "data_version": "1.115.5" if country == "us" else "1.55.10", + } + if include_cliffs is not None: + simulation_params["include_cliffs"] = include_cliffs + country_module = SimpleNamespace( model=SimpleNamespace(version="1.700.0" if country == "us" else "2.88.20"), - economic_impact_analysis=lambda baseline_simulation, reform_simulation: ( - analysis - ), + economic_impact_analysis=economic_impact_analysis, ) return SimulationOutputBuilder( country=country, - simulation_params={ - "country": country, - "data_version": "1.115.5" if country == "us" else "1.55.10", - }, + simulation_params=simulation_params, country_module=country_module, dataset=SimpleNamespace(metadata={}), baseline=baseline, @@ -218,11 +231,19 @@ def test_builder_calls_policyengine_economic_impact_analysis(): baseline, reform = _macro_baseline_reform() analysis = fake_analysis() calls = [] + + def economic_impact_analysis( + baseline_simulation, + reform_simulation, + *, + include_cliff_impacts=False, + ): + calls.append((baseline_simulation, reform_simulation, include_cliff_impacts)) + return analysis + country_module = SimpleNamespace( model=SimpleNamespace(version="1.700.0"), - economic_impact_analysis=lambda baseline_simulation, reform_simulation: ( - calls.append((baseline_simulation, reform_simulation)) or analysis - ), + economic_impact_analysis=economic_impact_analysis, ) builder = SimulationOutputBuilder( country="us", 
@@ -235,7 +256,68 @@ def test_builder_calls_policyengine_economic_impact_analysis(): assert builder.analysis is analysis assert builder.analysis is analysis - assert calls == [(baseline, reform)] + assert calls == [(baseline, reform, False)] + + +def test_builder_passes_include_cliffs_to_policyengine_economic_impact_analysis(): + baseline, reform = _macro_baseline_reform() + analysis = fake_analysis() + calls = [] + + def economic_impact_analysis( + baseline_simulation, + reform_simulation, + *, + include_cliff_impacts=False, + ): + calls.append(include_cliff_impacts) + return analysis + + country_module = SimpleNamespace( + model=SimpleNamespace(version="1.700.0"), + economic_impact_analysis=economic_impact_analysis, + ) + builder = SimulationOutputBuilder( + country="us", + simulation_params={ + "country": "us", + "data_version": "1.115.5", + "include_cliffs": True, + }, + country_module=country_module, + dataset=SimpleNamespace(metadata={}), + baseline=baseline, + reform=reform, + ) + + assert builder.analysis is analysis + assert calls == [True] + + +def test_builder_serializes_cliff_impact_when_requested(monkeypatch): + baseline, reform = _macro_baseline_reform() + _stub_policyengine_output_calls(monkeypatch, baseline, reform) + analysis = fake_analysis() + analysis.cliff_impact = FakeModelOutput( + { + "baseline": {"cliff_gap": 10.0, "cliff_share": 0.25}, + "reform": {"cliff_gap": 20.0, "cliff_share": 0.5}, + } + ) + + output = _simulation_output_builder( + "us", + baseline, + reform, + analysis=analysis, + include_cliffs=True, + ).build() + + assert output.cliff_impact is not None + assert output.model_dump(mode="json")["cliff_impact"] == { + "baseline": {"cliff_gap": 10.0, "cliff_share": 0.25}, + "reform": {"cliff_gap": 20.0, "cliff_share": 0.5}, + } def test_normalise_policy_converts_legacy_period_range_keys(): @@ -438,11 +520,18 @@ def test_resolve_region_scopes_uk_country_from_national_dataset(): def 
test_builder_data_version_prefers_resolved_revision_then_dataset_metadata(): baseline, reform = _macro_baseline_reform() + + def economic_impact_analysis( + baseline_simulation, + reform_simulation, + *, + include_cliff_impacts=False, + ): + return fake_analysis() + country_module = SimpleNamespace( model=SimpleNamespace(version="1.700.0"), - economic_impact_analysis=lambda baseline_simulation, reform_simulation: ( - fake_analysis() - ), + economic_impact_analysis=economic_impact_analysis, ) resolved_builder = SimulationOutputBuilder( diff --git a/projects/policyengine-api-simulation/tests/test_standalone_simulation_contract.py b/projects/policyengine-api-simulation/tests/test_standalone_simulation_contract.py index ca182f704..f764c0a97 100644 --- a/projects/policyengine-api-simulation/tests/test_standalone_simulation_contract.py +++ b/projects/policyengine-api-simulation/tests/test_standalone_simulation_contract.py @@ -62,3 +62,25 @@ def fake_run_simulation_impl(params): assert response.status_code == 200 assert response.json() == CURRENT_SINGLE_YEAR_MACRO_RESULT + + +def test_standalone_simulation_route_forwards_include_cliffs(monkeypatch): + def fake_run_simulation_impl(params): + assert params == { + "country": "us", + "reform": {}, + "include_cliffs": True, + } + return CURRENT_SINGLE_YEAR_MACRO_RESULT + + monkeypatch.setattr( + "policyengine_api_simulation.simulation.run_simulation_impl", + fake_run_simulation_impl, + ) + + response = TestClient(app).post( + "/simulate/economy/comparison", + json={"country": "us", "reform": {}, "include_cliffs": True}, + ) + + assert response.status_code == 200 diff --git a/projects/policyengine-api-simulation/uv.lock b/projects/policyengine-api-simulation/uv.lock index 82207be8d..e227a79b9 100644 --- a/projects/policyengine-api-simulation/uv.lock +++ b/projects/policyengine-api-simulation/uv.lock @@ -1611,7 +1611,7 @@ wheels = [ [[package]] name = "policyengine" -version = "4.10.0" +version = "4.12.0" source = { registry = 
"https://pypi.org/simple" } dependencies = [ { name = "h5py" }, @@ -1623,9 +1623,9 @@ dependencies = [ { name = "pydantic" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/10/27/59ca969ab71647d526f6a7553f93cb61a1853d3f4fcc88552f08292be8a2/policyengine-4.10.0.tar.gz", hash = "sha256:68f634d107bd3ac81427364b03203a7d80407599cae9d13ff44231001436daa6", size = 571499, upload-time = "2026-05-21T19:04:09.051Z" } +sdist = { url = "https://files.pythonhosted.org/packages/29/0a/9a9e2262ee6b152ece050e59efdfd1654a7d76716b3f02c9e9ddc5c7de29/policyengine-4.12.0.tar.gz", hash = "sha256:dbf5505fb1739883be7475300eabd6a85eefaf04faed882fcfd878729a796717", size = 635887, upload-time = "2026-05-28T12:37:43.195Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ee/d2/b4d11bd59e87da0376255779da7c0fc8afc1141cb33e8891a1fa5ce2c7e5/policyengine-4.10.0-py3-none-any.whl", hash = "sha256:db7454a3bf9cbc791ed8b8ccb7d9d5dcb8f4a08f93bb6c2fb3cf920d1dcce7c2", size = 189098, upload-time = "2026-05-21T19:04:07.296Z" }, + { url = "https://files.pythonhosted.org/packages/1c/54/45550717d035b98efb491c969fc33fa6554c39ef9ae7a518f2dd18c485cd/policyengine-4.12.0-py3-none-any.whl", hash = "sha256:92a28865929b895df1dc900e2a92337bbe81767b1939c9d7bfee892c2dbdf526", size = 190733, upload-time = "2026-05-28T12:37:41.716Z" }, ] [[package]] @@ -1732,7 +1732,7 @@ requires-dist = [ { name = "openapi-python-client", marker = "extra == 'build'", specifier = ">=0.21.6" }, { name = "opentelemetry-instrumentation-fastapi", specifier = ">=0.51b0,<0.52" }, { name = "opentelemetry-instrumentation-sqlalchemy", specifier = ">=0.51b0,<0.52" }, - { name = "policyengine", specifier = "==4.10.0" }, + { name = "policyengine", specifier = "==4.12.0" }, { name = "policyengine-core", specifier = "==3.26.1" }, { name = "policyengine-fastapi", editable = "../../libs/policyengine-fastapi" }, { name = "policyengine-uk", specifier = "==2.88.20" }, From a2c997ba79c414c3c90f8f481c250da76c3e656d 
Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 28 May 2026 15:41:18 +0200 Subject: [PATCH 17/23] chore: type version registry metadata --- .../modal/utils/update_version_registry.py | 24 +++++++++++++++++-- .../tests/test_update_version_registry.py | 14 ++++++++++- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/projects/policyengine-api-simulation/src/modal/utils/update_version_registry.py b/projects/policyengine-api-simulation/src/modal/utils/update_version_registry.py index cc91e5ba7..4687bd2d0 100644 --- a/projects/policyengine-api-simulation/src/modal/utils/update_version_registry.py +++ b/projects/policyengine-api-simulation/src/modal/utils/update_version_registry.py @@ -24,6 +24,7 @@ import argparse import modal from packaging.version import InvalidVersion, Version +from typing import TypedDict POLICYENGINE_VERSION_DICT_NAME = "simulation-api-policyengine-versions" US_VERSION_DICT_NAME = "simulation-api-us-versions" @@ -31,6 +32,25 @@ APP_RELEASE_BUNDLES_DICT_NAME = "simulation-api-app-release-bundles" +class CountryBundleMetadata(TypedDict): + country: str + model_package_name: str + model_version: str + data_package_name: str + data_version: str + default_dataset: str + default_dataset_uri: str + dataset_uris: dict[str, str] + dataset_aliases: dict[str, str] + + +class AppReleaseBundleMetadata(TypedDict): + app_name: str + policyengine_version: str + us: CountryBundleMetadata + uk: CountryBundleMetadata + + def _is_newer_version(candidate: str, current: str | None) -> bool: """Return True when ``candidate`` should replace ``current`` as 'latest'. 
@@ -116,7 +136,7 @@ def update_version_dict( ) -def _country_bundle_metadata(country: str) -> dict: +def _country_bundle_metadata(country: str) -> CountryBundleMetadata: from policyengine_api_simulation.release_bundle import ( DATASET_ALIASES, get_country_release_bundle, @@ -140,7 +160,7 @@ def build_app_release_bundle_metadata( *, app_name: str, policyengine_version: str, -) -> dict: +) -> AppReleaseBundleMetadata: return { "app_name": app_name, "policyengine_version": policyengine_version, diff --git a/projects/policyengine-api-simulation/tests/test_update_version_registry.py b/projects/policyengine-api-simulation/tests/test_update_version_registry.py index c08714a1e..1b4a3acf7 100644 --- a/projects/policyengine-api-simulation/tests/test_update_version_registry.py +++ b/projects/policyengine-api-simulation/tests/test_update_version_registry.py @@ -140,11 +140,23 @@ def test_put_app_release_bundle_metadata_records_app_and_py_version_aliases( patched_modal, monkeypatch, ): - def fake_country_bundle_metadata(country: str) -> dict: + def fake_country_bundle_metadata( + country: str, + ) -> registry.CountryBundleMetadata: return { "country": country, + "model_package_name": ( + "policyengine-us" if country == "us" else "policyengine-uk" + ), "model_version": "1.0.0" if country == "us" else "2.0.0", + "data_package_name": ( + "policyengine-us-data" + if country == "us" + else "policyengine-uk-data" + ), "data_version": "3.0.0" if country == "us" else "4.0.0", + "default_dataset": "default", + "default_dataset_uri": f"hf://datasets/policyengine/{country}/default", "dataset_uris": {"default": f"hf://datasets/policyengine/{country}"}, "dataset_aliases": {"alias": "default"}, } From cd6cf1a9376804557769a7808e67f34d4b8ac88d Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 28 May 2026 15:59:26 +0200 Subject: [PATCH 18/23] chore: remove simulation compatibility shims --- projects/policyengine-api-simulation/src/modal/hf_dataset.py | 3 --- 
.../policyengine-api-simulation/src/modal/release_bundle.py | 3 --- projects/policyengine-api-simulation/src/modal/simulation.py | 3 --- .../src/modal/simulation_macro_output.py | 3 --- .../src/modal/simulation_output_builder.py | 3 --- projects/policyengine-api-simulation/src/modal/telemetry.py | 3 --- 6 files changed, 18 deletions(-) delete mode 100644 projects/policyengine-api-simulation/src/modal/hf_dataset.py delete mode 100644 projects/policyengine-api-simulation/src/modal/release_bundle.py delete mode 100644 projects/policyengine-api-simulation/src/modal/simulation.py delete mode 100644 projects/policyengine-api-simulation/src/modal/simulation_macro_output.py delete mode 100644 projects/policyengine-api-simulation/src/modal/simulation_output_builder.py delete mode 100644 projects/policyengine-api-simulation/src/modal/telemetry.py diff --git a/projects/policyengine-api-simulation/src/modal/hf_dataset.py b/projects/policyengine-api-simulation/src/modal/hf_dataset.py deleted file mode 100644 index 836616831..000000000 --- a/projects/policyengine-api-simulation/src/modal/hf_dataset.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Compatibility shim for packaged simulation helpers.""" - -from policyengine_api_simulation.hf_dataset import * # noqa: F403 diff --git a/projects/policyengine-api-simulation/src/modal/release_bundle.py b/projects/policyengine-api-simulation/src/modal/release_bundle.py deleted file mode 100644 index c3022a52b..000000000 --- a/projects/policyengine-api-simulation/src/modal/release_bundle.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Compatibility shim for packaged simulation helpers.""" - -from policyengine_api_simulation.release_bundle import * # noqa: F403 diff --git a/projects/policyengine-api-simulation/src/modal/simulation.py b/projects/policyengine-api-simulation/src/modal/simulation.py deleted file mode 100644 index f6197e79b..000000000 --- a/projects/policyengine-api-simulation/src/modal/simulation.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Compatibility 
shim for packaged simulation helpers.""" - -from policyengine_api_simulation.simulation_runtime import * # noqa: F403 diff --git a/projects/policyengine-api-simulation/src/modal/simulation_macro_output.py b/projects/policyengine-api-simulation/src/modal/simulation_macro_output.py deleted file mode 100644 index e409982bb..000000000 --- a/projects/policyengine-api-simulation/src/modal/simulation_macro_output.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Compatibility shim for packaged simulation helpers.""" - -from policyengine_api_simulation.simulation_macro_output import * # noqa: F403 diff --git a/projects/policyengine-api-simulation/src/modal/simulation_output_builder.py b/projects/policyengine-api-simulation/src/modal/simulation_output_builder.py deleted file mode 100644 index 8334f2b38..000000000 --- a/projects/policyengine-api-simulation/src/modal/simulation_output_builder.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Compatibility shim for packaged simulation helpers.""" - -from policyengine_api_simulation.simulation_output_builder import * # noqa: F403 diff --git a/projects/policyengine-api-simulation/src/modal/telemetry.py b/projects/policyengine-api-simulation/src/modal/telemetry.py deleted file mode 100644 index ea120cb12..000000000 --- a/projects/policyengine-api-simulation/src/modal/telemetry.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Compatibility shim for packaged simulation helpers.""" - -from policyengine_api_simulation.telemetry import * # noqa: F403 From ee0e8a5bfab6405b9abb0eb557a988ee5f5dbef6 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 28 May 2026 16:26:16 +0200 Subject: [PATCH 19/23] chore: remove stale console exporter comment --- .../src/policyengine_api_simulation/main.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/main.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/main.py index 0759fe8cd..757a166d2 100644 --- 
a/projects/policyengine-api-simulation/src/policyengine_api_simulation/main.py +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/main.py @@ -63,8 +63,7 @@ async def lifespan(app: FastAPI): match get_settings().environment: case Environment.DESKTOP: - pass # Don't print opentelemetry to console- this makes it impossible to read the logs. Alternatively, do by uncommenting this line. - # export_ot_to_console(resource) + pass # Don't print OpenTelemetry to console; it makes logs unreadable. case Environment.PRODUCTION: export_ot_to_gcp(resource) case value: From 24bc2c3ba8114bb204d33dbc3bde59096abe2041 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 28 May 2026 16:31:40 +0200 Subject: [PATCH 20/23] chore: restore console exporter comment --- .../src/policyengine_api_simulation/main.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/main.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/main.py index 757a166d2..8606f67b2 100644 --- a/projects/policyengine-api-simulation/src/policyengine_api_simulation/main.py +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/main.py @@ -4,6 +4,7 @@ from policyengine_fastapi.opentelemetry import ( GCPLoggingInstrumentor, FastAPIEnhancedInstrumenter, + export_ot_to_console, export_ot_to_gcp, ) from policyengine_fastapi.exit import exit @@ -63,7 +64,8 @@ async def lifespan(app: FastAPI): match get_settings().environment: case Environment.DESKTOP: - pass # Don't print OpenTelemetry to console; it makes logs unreadable. + pass # Don't print opentelemetry to console- this makes it impossible to read the logs. Alternatively, do by uncommenting this line. 
+ # export_ot_to_console(resource) case Environment.PRODUCTION: export_ot_to_gcp(resource) case value: From 915b3e355dabfef9d4f9754aad897edd09f777ea Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 28 May 2026 16:48:26 +0200 Subject: [PATCH 21/23] refactor: split simulation output segments --- .../simulation_output_budget.py | 68 +++ .../simulation_output_builder.py | 556 ++---------------- .../simulation_output_cliff.py | 25 + .../simulation_output_common.py | 103 ++++ .../simulation_output_distribution.py | 96 +++ .../simulation_output_geographic.py | 87 +++ .../simulation_output_inequality.py | 32 + .../simulation_output_labor.py | 18 + .../simulation_output_poverty.py | 251 ++++++++ .../tests/test_simulation_output_builder.py | 8 +- .../test_standalone_simulation_contract.py | 8 + 11 files changed, 740 insertions(+), 512 deletions(-) create mode 100644 projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_budget.py create mode 100644 projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_cliff.py create mode 100644 projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_common.py create mode 100644 projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_distribution.py create mode 100644 projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_geographic.py create mode 100644 projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_inequality.py create mode 100644 projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_labor.py create mode 100644 projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_poverty.py diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_budget.py 
b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_budget.py new file mode 100644 index 000000000..e50aa201c --- /dev/null +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_budget.py @@ -0,0 +1,68 @@ +"""Budget output segment builders.""" + +from __future__ import annotations + +from typing import Any + +from policyengine_api_simulation.simulation_macro_output import ( + BudgetaryImpact, + DetailedBudgetOutput, + DetailedBudgetProgramOutput, +) +from policyengine_api_simulation.simulation_output_common import ( + _change_output_variable, + _collection_records, + _number, + _sum_output_variable, +) + + +def build_detailed_budget(analysis: Any) -> DetailedBudgetOutput: + collection = getattr(analysis, "program_statistics", None) + if isinstance(collection, DetailedBudgetOutput): + return collection + detailed_budget: dict[str, DetailedBudgetProgramOutput] = {} + for row in _collection_records(collection): + program_name = row.get("program_name") + if not program_name: + continue + baseline = _number(row.get("baseline_total")) + reform = _number(row.get("reform_total")) + detailed_budget[str(program_name)] = DetailedBudgetProgramOutput( + baseline=baseline, + reform=reform, + difference=_number(row.get("change"), reform - baseline), + ) + return DetailedBudgetOutput(detailed_budget) + + +def build_budgetary_impact(country: str, baseline, reform) -> BudgetaryImpact: + tax_revenue_impact = _change_output_variable( + baseline, reform, "household_tax", entity="household" + ) + benefit_spending_impact = _change_output_variable( + baseline, reform, "household_benefits", entity="household" + ) + state_tax_revenue_impact = ( + _change_output_variable( + baseline, + reform, + "household_state_income_tax", + entity="household", + ) + if country == "us" + else 0.0 + ) + + return BudgetaryImpact( + tax_revenue_impact=tax_revenue_impact, + state_tax_revenue_impact=state_tax_revenue_impact, + 
benefit_spending_impact=benefit_spending_impact, + budgetary_impact=tax_revenue_impact - benefit_spending_impact, + households=_sum_output_variable( + baseline, "household_weight", entity="household" + ), + baseline_net_income=_sum_output_variable( + baseline, "household_net_income", entity="household" + ), + ) diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_builder.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_builder.py index 8a2c2d7d9..f68b4ac78 100644 --- a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_builder.py +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_builder.py @@ -2,239 +2,33 @@ from __future__ import annotations -import logging -import math -from collections.abc import Iterable, Mapping from dataclasses import dataclass, field -from importlib import import_module from typing import Any +from policyengine_api_simulation import simulation_output_budget +from policyengine_api_simulation import simulation_output_cliff +from policyengine_api_simulation import simulation_output_distribution +from policyengine_api_simulation import simulation_output_geographic +from policyengine_api_simulation import simulation_output_inequality +from policyengine_api_simulation import simulation_output_labor +from policyengine_api_simulation import simulation_output_poverty from policyengine_api_simulation.release_bundle import get_country_release_bundle from policyengine_api_simulation.simulation_macro_output import ( - AgePovertyOutput, - BaselineReformValue, BudgetaryImpact, - CliffImpactInSimulation, CliffImpactOutput, DecileOutput, DetailedBudgetOutput, - DetailedBudgetProgramOutput, GeographicImpactOutput, - GenderPovertyOutput, InequalityOutput, IntraDecileOutput, LaborSupplyResponseOutput, - PovertyModuleOutputs, PovertyByGenderOutput, PovertyByRaceOutput, + PovertyModuleOutputs, 
PovertyOutput, - RacePovertyOutput, SingleYearMacroOutput, ) -logger = logging.getLogger(__name__) - -INTRA_DECILE_COLUMNS = { - "Lose more than 5%": "lose_more_than_5pct", - "Lose less than 5%": "lose_less_than_5pct", - "No change": "no_change", - "Gain less than 5%": "gain_less_than_5pct", - "Gain more than 5%": "gain_more_than_5pct", -} - -US_POVERTY_TYPES = { - "spm": "poverty", - "spm_deep": "deep_poverty", -} - -UK_POVERTY_TYPES = { - "relative_bhc": "poverty", - "absolute_bhc": "deep_poverty", -} - - -def _number(value: Any, default: float = 0.0) -> float: - if value is None: - return default - try: - result = float(value) - except (TypeError, ValueError): - return default - if math.isnan(result) or math.isinf(result): - return default - return result - - -def _collection_records(collection: Any) -> list[dict[str, Any]]: - if collection is None: - return [] - dataframe = getattr(collection, "dataframe", None) - if dataframe is not None: - return list(dataframe.to_dict("records")) - if isinstance(collection, list): - return [dict(item) for item in collection if isinstance(item, Mapping)] - return [] - - -def _output_model_dump(value: Any) -> Any: - if value is None: - return None - if hasattr(value, "model_dump"): - return value.model_dump(mode="json") - if isinstance(value, Mapping): - return dict(value) - return None - - -def _empty_baseline_reform_value() -> dict[str, float]: - return {"baseline": 0.0, "reform": 0.0} - - -def _empty_age_poverty() -> dict[str, dict[str, float]]: - return { - "child": _empty_baseline_reform_value(), - "adult": _empty_baseline_reform_value(), - "senior": _empty_baseline_reform_value(), - "all": _empty_baseline_reform_value(), - } - - -def _empty_gender_poverty() -> dict[str, dict[str, float]]: - return { - "male": _empty_baseline_reform_value(), - "female": _empty_baseline_reform_value(), - } - - -def _poverty_type(country: str, row: Mapping[str, Any]) -> str | None: - poverty_type = str(row.get("poverty_type") or "").lower() 
- if country == "us": - return US_POVERTY_TYPES.get(poverty_type) - return UK_POVERTY_TYPES.get(poverty_type) - - -def _fill_poverty_block( - *, - country: str, - output: dict[str, dict[str, dict[str, float]]], - baseline_records: Iterable[Mapping[str, Any]], - reform_records: Iterable[Mapping[str, Any]], - default_group: str, -) -> None: - for side, records in (("baseline", baseline_records), ("reform", reform_records)): - for row in records: - poverty_type = _poverty_type(country, row) - if poverty_type is None: - continue - if poverty_type not in output: - continue - group = str(row.get("filter_group") or default_group).lower() - if group not in output[poverty_type]: - continue - output[poverty_type][group][side] = _number(row.get("rate")) - - -def _age_poverty_output(values: dict[str, dict[str, float]]) -> AgePovertyOutput: - return AgePovertyOutput( - child=BaselineReformValue(**values["child"]), - adult=BaselineReformValue(**values["adult"]), - senior=BaselineReformValue(**values["senior"]), - all=BaselineReformValue(**values["all"]), - ) - - -def _gender_poverty_output( - values: dict[str, dict[str, float]], -) -> GenderPovertyOutput: - return GenderPovertyOutput( - male=BaselineReformValue(**values["male"]), - female=BaselineReformValue(**values["female"]), - ) - - -def _race_poverty_output(values: dict[str, dict[str, float]]) -> RacePovertyOutput: - return RacePovertyOutput( - white=BaselineReformValue(**values["white"]), - black=BaselineReformValue(**values["black"]), - hispanic=BaselineReformValue(**values["hispanic"]), - other=BaselineReformValue(**values["other"]), - ) - - -def _entity_data(simulation, entity: str): - if simulation.output_dataset is None or simulation.output_dataset.data is None: - simulation.ensure() - return getattr(simulation.output_dataset.data, entity) - - -def _sum_output_variable(simulation, variable: str, entity: str) -> float: - data = _entity_data(simulation, entity) - if variable in data.columns: - return 
float(data[variable].sum()) - - from policyengine.outputs import Aggregate, AggregateType - - output = Aggregate( - simulation=simulation, - variable=variable, - entity=entity, - aggregate_type=AggregateType.SUM, - ) - output.run() - return float(output.result) - - -def _try_sum_output_variable(simulation, variable: str, entity: str) -> float: - try: - return _sum_output_variable(simulation, variable, entity) - except Exception: - logger.warning("Unable to calculate sum for %s", variable, exc_info=True) - return 0.0 - - -def _change_output_variable(baseline, reform, variable: str, entity: str) -> float: - baseline_data = _entity_data(baseline, entity) - reform_data = _entity_data(reform, entity) - if variable in baseline_data.columns and variable in reform_data.columns: - return float((reform_data[variable] - baseline_data[variable]).sum()) - - from policyengine.outputs import ChangeAggregate, ChangeAggregateType - - output = ChangeAggregate( - baseline_simulation=baseline, - reform_simulation=reform, - variable=variable, - entity=entity, - aggregate_type=ChangeAggregateType.SUM, - ) - output.run() - return float(output.result) - - -def _try_change_output_variable(baseline, reform, variable: str, entity: str) -> float: - try: - return _change_output_variable(baseline, reform, variable, entity) - except Exception: - logger.warning("Unable to calculate change for %s", variable, exc_info=True) - return 0.0 - - -def _output_module_function(module_name: str, name: str): - module = import_module(f"policyengine.outputs.{module_name}") - return getattr(module, name) - - -def _poverty_module_function(name: str): - return _output_module_function("poverty", name) - - -def _try_compute_output(label: str, fn, *args, **kwargs): - try: - return fn(*args, **kwargs) - except Exception: - logger.warning("Unable to calculate %s", label, exc_info=True) - return None - @dataclass class SimulationOutputBuilder: @@ -294,233 +88,61 @@ def serialize(self) -> dict[str, Any]: return 
self.build().model_dump(mode="json") def _build_detailed_budget(self) -> DetailedBudgetOutput: - collection = getattr(self.analysis, "program_statistics", None) - if isinstance(collection, DetailedBudgetOutput): - return collection - detailed_budget: dict[str, DetailedBudgetProgramOutput] = {} - for row in _collection_records(collection): - program_name = row.get("program_name") - if not program_name: - continue - baseline = _number(row.get("baseline_total")) - reform = _number(row.get("reform_total")) - detailed_budget[str(program_name)] = DetailedBudgetProgramOutput( - baseline=baseline, - reform=reform, - difference=_number(row.get("change"), reform - baseline), - ) - return DetailedBudgetOutput(detailed_budget) + return simulation_output_budget.build_detailed_budget(self.analysis) def _build_decile(self) -> DecileOutput: - return self._build_decile_output(getattr(self.analysis, "decile_impacts", None)) + return simulation_output_distribution.build_decile(self.analysis) def _build_inequality(self) -> InequalityOutput: - baseline = getattr(self.analysis, "baseline_inequality", None) - reform = getattr(self.analysis, "reform_inequality", None) - if isinstance(baseline, InequalityOutput): - return baseline - return InequalityOutput( - gini=BaselineReformValue( - baseline=_number(getattr(baseline, "gini", None)), - reform=_number(getattr(reform, "gini", None)), - ), - top_10_pct_share=BaselineReformValue( - baseline=_number(getattr(baseline, "top_10_share", None)), - reform=_number(getattr(reform, "top_10_share", None)), - ), - top_1_pct_share=BaselineReformValue( - baseline=_number(getattr(baseline, "top_1_share", None)), - reform=_number(getattr(reform, "top_1_share", None)), - ), - ) + return simulation_output_inequality.build_inequality(self.analysis) def _build_budgetary_impact(self) -> BudgetaryImpact: - tax_revenue_impact = _change_output_variable( - self.baseline, self.reform, "household_tax", entity="household" - ) - benefit_spending_impact = 
_change_output_variable( - self.baseline, self.reform, "household_benefits", entity="household" - ) - state_tax_revenue_impact = ( - _change_output_variable( - self.baseline, - self.reform, - "household_state_income_tax", - entity="household", - ) - if self.country == "us" - else 0.0 - ) - - return BudgetaryImpact( - tax_revenue_impact=tax_revenue_impact, - state_tax_revenue_impact=state_tax_revenue_impact, - benefit_spending_impact=benefit_spending_impact, - budgetary_impact=tax_revenue_impact - benefit_spending_impact, - households=_sum_output_variable( - self.baseline, "household_weight", entity="household" - ), - baseline_net_income=_sum_output_variable( - self.baseline, "household_net_income", entity="household" - ), + return simulation_output_budget.build_budgetary_impact( + self.country, self.baseline, self.reform ) def _build_poverty_outputs(self) -> PovertyModuleOutputs: - prefix = "us" if self.country == "us" else "uk" - baseline_poverty_by_age = _try_compute_output( - "baseline poverty by age", - _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), - self.baseline, - ) - reform_poverty_by_age = _try_compute_output( - "reform poverty by age", - _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), - self.reform, - ) - baseline_poverty_by_gender = _try_compute_output( - "baseline poverty by gender", - _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), - self.baseline, - ) - reform_poverty_by_gender = _try_compute_output( - "reform poverty by gender", - _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), - self.reform, - ) - baseline_poverty_by_race = None - reform_poverty_by_race = None - if self.country == "us": - baseline_poverty_by_race = _try_compute_output( - "baseline poverty by race", - _poverty_module_function("calculate_us_poverty_by_race"), - self.baseline, - ) - reform_poverty_by_race = _try_compute_output( - "reform poverty by race", - _poverty_module_function("calculate_us_poverty_by_race"), 
- self.reform, - ) - return PovertyModuleOutputs( - poverty=self._build_poverty_output( - baseline=getattr(self.analysis, "baseline_poverty", None), - reform=getattr(self.analysis, "reform_poverty", None), - baseline_by_age=baseline_poverty_by_age, - reform_by_age=reform_poverty_by_age, - ), - poverty_by_gender=self._build_poverty_by_gender_output( - baseline_by_gender=baseline_poverty_by_gender, - reform_by_gender=reform_poverty_by_gender, - ), - poverty_by_race=( - self._build_poverty_by_race_output( - baseline_by_race=baseline_poverty_by_race, - reform_by_race=reform_poverty_by_race, - ) - if self.country == "us" - else None - ), + return simulation_output_poverty.build_poverty_outputs( + self.country, self.baseline, self.reform, self.analysis ) def _build_intra_decile_output(self) -> IntraDecileOutput: - from policyengine.outputs.intra_decile_impact import ( - compute_intra_decile_impacts, + return simulation_output_distribution.build_intra_decile_output( + self.baseline, self.reform ) - collection = _try_compute_output( - "intra-decile impacts", - compute_intra_decile_impacts, - self.baseline, - self.reform, - income_variable="household_net_income", - entity="household", + def _build_wealth_decile(self, wealth_decile: Any) -> DecileOutput | None: + return simulation_output_distribution.build_wealth_decile( + self.country, wealth_decile ) - return self._build_intra_decile_output_from_collection(collection) - - def _build_wealth_decile(self, wealth_decile) -> DecileOutput | None: - if self.country != "uk": - return None - return self._build_decile_output(wealth_decile) def _build_intra_wealth_decile( - self, intra_wealth_decile + self, intra_wealth_decile: Any ) -> IntraDecileOutput | None: - if self.country != "uk": - return None - return self._build_intra_decile_output_from_collection(intra_wealth_decile) + return simulation_output_distribution.build_intra_wealth_decile( + self.country, intra_wealth_decile + ) def _build_labor_supply_response(self) -> 
LaborSupplyResponseOutput | None: - labor_supply_response = getattr(self.analysis, "labor_supply_response", None) - if isinstance(labor_supply_response, LaborSupplyResponseOutput): - return labor_supply_response - output = _output_model_dump(labor_supply_response) - return LaborSupplyResponseOutput(output) if isinstance(output, dict) else None + return simulation_output_labor.build_labor_supply_response(self.analysis) def _build_cliff_impact(self) -> CliffImpactOutput | None: - cliff_impact = getattr(self.analysis, "cliff_impact", None) - if isinstance(cliff_impact, CliffImpactOutput): - return cliff_impact - output = _output_model_dump(cliff_impact) - if not isinstance(output, Mapping): - return None - return CliffImpactOutput( - baseline=CliffImpactInSimulation(**output["baseline"]), - reform=CliffImpactInSimulation(**output["reform"]), - ) + return simulation_output_cliff.build_cliff_impact(self.analysis) def _build_geographic_impact_output( self, value: Any ) -> GeographicImpactOutput | None: - if isinstance(value, GeographicImpactOutput): - return value - records = _output_model_dump(value) - if isinstance(records, list): - return GeographicImpactOutput( - [dict(item) for item in records if isinstance(item, Mapping)] - ) - if isinstance(value, list): - return GeographicImpactOutput( - [dict(item) for item in value if isinstance(item, Mapping)] - ) - return None + return simulation_output_geographic.build_geographic_impact_output(value) def _build_decile_output(self, collection: Any) -> DecileOutput: - if isinstance(collection, DecileOutput): - return collection - average: dict[str, float] = {} - relative: dict[str, float] = {} - for row in sorted( - _collection_records(collection), - key=lambda item: _number(item.get("decile")), - ): - decile = int(_number(row.get("decile"))) - if decile <= 0: - continue - key = str(decile) - average[key] = _number(row.get("absolute_change")) - relative[key] = _number(row.get("relative_change")) - return 
DecileOutput(average=average, relative=relative) + return simulation_output_distribution.build_decile_output(collection) def _build_intra_decile_output_from_collection( self, collection: Any ) -> IntraDecileOutput: - if isinstance(collection, IntraDecileOutput): - return collection - deciles: dict[str, list[float]] = {label: [] for label in INTRA_DECILE_COLUMNS} - all_values: dict[str, float] = {label: 0.0 for label in INTRA_DECILE_COLUMNS} - rows = [ - row - for row in sorted( - _collection_records(collection), - key=lambda item: _number(item.get("decile")), - ) - if int(_number(row.get("decile"))) > 0 - ] - - for label, column in INTRA_DECILE_COLUMNS.items(): - values = [_number(row.get(column)) for row in rows] - deciles[label] = values - all_values[label] = sum(values) / len(values) if values else 0.0 - return IntraDecileOutput(deciles=deciles, all=all_values) + return simulation_output_distribution.build_intra_decile_output_from_collection( + collection + ) def _build_poverty_output( self, @@ -530,29 +152,12 @@ def _build_poverty_output( baseline_by_age: Any, reform_by_age: Any, ) -> PovertyOutput: - if isinstance(baseline, PovertyOutput): - return baseline - result = { - "poverty": _empty_age_poverty(), - "deep_poverty": _empty_age_poverty(), - } - _fill_poverty_block( - country=self.country, - output=result, - baseline_records=_collection_records(baseline), - reform_records=_collection_records(reform), - default_group="all", - ) - _fill_poverty_block( + return simulation_output_poverty.build_poverty_output( country=self.country, - output=result, - baseline_records=_collection_records(baseline_by_age), - reform_records=_collection_records(reform_by_age), - default_group="all", - ) - return PovertyOutput( - poverty=_age_poverty_output(result["poverty"]), - deep_poverty=_age_poverty_output(result["deep_poverty"]), + baseline=baseline, + reform=reform, + baseline_by_age=baseline_by_age, + reform_by_age=reform_by_age, ) def _build_poverty_by_gender_output( @@ 
-561,22 +166,10 @@ def _build_poverty_by_gender_output( baseline_by_gender: Any, reform_by_gender: Any, ) -> PovertyByGenderOutput: - if isinstance(baseline_by_gender, PovertyByGenderOutput): - return baseline_by_gender - result = { - "poverty": _empty_gender_poverty(), - "deep_poverty": _empty_gender_poverty(), - } - _fill_poverty_block( + return simulation_output_poverty.build_poverty_by_gender_output( country=self.country, - output=result, - baseline_records=_collection_records(baseline_by_gender), - reform_records=_collection_records(reform_by_gender), - default_group="all", - ) - return PovertyByGenderOutput( - poverty=_gender_poverty_output(result["poverty"]), - deep_poverty=_gender_poverty_output(result["deep_poverty"]), + baseline_by_gender=baseline_by_gender, + reform_by_gender=reform_by_gender, ) def _build_poverty_by_race_output( @@ -585,79 +178,26 @@ def _build_poverty_by_race_output( baseline_by_race: Any, reform_by_race: Any, ) -> PovertyByRaceOutput: - if isinstance(baseline_by_race, PovertyByRaceOutput): - return baseline_by_race - result = { - "poverty": { - "white": _empty_baseline_reform_value(), - "black": _empty_baseline_reform_value(), - "hispanic": _empty_baseline_reform_value(), - "other": _empty_baseline_reform_value(), - } - } - _fill_poverty_block( - country="us", - output=result, - baseline_records=_collection_records(baseline_by_race), - reform_records=_collection_records(reform_by_race), - default_group="all", + return simulation_output_poverty.build_poverty_by_race_output( + baseline_by_race=baseline_by_race, + reform_by_race=reform_by_race, ) - return PovertyByRaceOutput(poverty=_race_poverty_output(result["poverty"])) def _build_congressional_district_impact( self, ) -> GeographicImpactOutput | None: - if self.country != "us": - return None - - from policyengine.outputs.congressional_district_impact import ( - compute_us_congressional_district_impacts, - ) - - impact = _try_compute_output( - "congressional district impacts", - 
compute_us_congressional_district_impacts, - self.baseline, - self.reform, - ) - return self._build_geographic_impact_output( - getattr(impact, "district_results", None) if impact is not None else None + return simulation_output_geographic.build_congressional_district_impact( + self.country, self.baseline, self.reform ) def _build_uk_constituency_impact(self) -> GeographicImpactOutput | None: - if self.country != "uk": - return None - - impact = _try_compute_output( - "constituency impacts", - _output_module_function( - "constituency_impact", "compute_uk_constituency_impacts" - ), - self.baseline, - self.reform, - ) - if impact is None: - return None - return self._build_geographic_impact_output( - getattr(impact, "constituency_results", None) + return simulation_output_geographic.build_uk_constituency_impact( + self.country, self.baseline, self.reform ) def _build_uk_local_authority_impact(self) -> GeographicImpactOutput | None: - if self.country != "uk": - return None - - impact = _try_compute_output( - "local authority impacts", - _output_module_function( - "local_authority_impact", "compute_uk_local_authority_impacts" - ), - self.baseline, - self.reform, - ) - if impact is None: - return None - return self._build_geographic_impact_output( - getattr(impact, "local_authority_results", None) + return simulation_output_geographic.build_uk_local_authority_impact( + self.country, self.baseline, self.reform ) def _model_version(self) -> str: diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_cliff.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_cliff.py new file mode 100644 index 000000000..a5a38a7d1 --- /dev/null +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_cliff.py @@ -0,0 +1,25 @@ +"""Cliff impact output segment builders.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + 
+from policyengine_api_simulation.simulation_macro_output import ( + CliffImpactInSimulation, + CliffImpactOutput, +) +from policyengine_api_simulation.simulation_output_common import _output_model_dump + + +def build_cliff_impact(analysis: Any) -> CliffImpactOutput | None: + cliff_impact = getattr(analysis, "cliff_impact", None) + if isinstance(cliff_impact, CliffImpactOutput): + return cliff_impact + output = _output_model_dump(cliff_impact) + if not isinstance(output, Mapping): + return None + return CliffImpactOutput( + baseline=CliffImpactInSimulation(**output["baseline"]), + reform=CliffImpactInSimulation(**output["reform"]), + ) diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_common.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_common.py new file mode 100644 index 000000000..fb804d9f2 --- /dev/null +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_common.py @@ -0,0 +1,103 @@ +"""Shared helpers for simulation macro output serialization.""" + +from __future__ import annotations + +import logging +import math +from collections.abc import Mapping +from importlib import import_module +from typing import Any + +logger = logging.getLogger(__name__) + + +def _number(value: Any, default: float = 0.0) -> float: + if value is None: + return default + try: + result = float(value) + except (TypeError, ValueError): + return default + if math.isnan(result) or math.isinf(result): + return default + return result + + +def _collection_records(collection: Any) -> list[dict[str, Any]]: + if collection is None: + return [] + dataframe = getattr(collection, "dataframe", None) + if dataframe is not None: + return list(dataframe.to_dict("records")) + if isinstance(collection, list): + return [dict(item) for item in collection if isinstance(item, Mapping)] + return [] + + +def _output_model_dump(value: Any) -> Any: + if value is None: + return 
None + if hasattr(value, "model_dump"): + return value.model_dump(mode="json") + if isinstance(value, Mapping): + return dict(value) + return None + + +def _entity_data(simulation, entity: str): + if simulation.output_dataset is None or simulation.output_dataset.data is None: + simulation.ensure() + return getattr(simulation.output_dataset.data, entity) + + +def _sum_output_variable(simulation, variable: str, entity: str) -> float: + data = _entity_data(simulation, entity) + if variable in data.columns: + return float(data[variable].sum()) + + from policyengine.outputs import Aggregate, AggregateType + + output = Aggregate( + simulation=simulation, + variable=variable, + entity=entity, + aggregate_type=AggregateType.SUM, + ) + output.run() + return float(output.result) + + +def _change_output_variable(baseline, reform, variable: str, entity: str) -> float: + baseline_data = _entity_data(baseline, entity) + reform_data = _entity_data(reform, entity) + if variable in baseline_data.columns and variable in reform_data.columns: + return float((reform_data[variable] - baseline_data[variable]).sum()) + + from policyengine.outputs import ChangeAggregate, ChangeAggregateType + + output = ChangeAggregate( + baseline_simulation=baseline, + reform_simulation=reform, + variable=variable, + entity=entity, + aggregate_type=ChangeAggregateType.SUM, + ) + output.run() + return float(output.result) + + +def _output_module_function(module_name: str, name: str): + module = import_module(f"policyengine.outputs.{module_name}") + return getattr(module, name) + + +def _poverty_module_function(name: str): + return _output_module_function("poverty", name) + + +def _try_compute_output(label: str, fn, *args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception: + logger.warning("Unable to calculate %s", label, exc_info=True) + return None diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_distribution.py 
b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_distribution.py new file mode 100644 index 000000000..58aef26a3 --- /dev/null +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_distribution.py @@ -0,0 +1,96 @@ +"""Distributional output segment builders.""" + +from __future__ import annotations + +from typing import Any + +from policyengine_api_simulation.simulation_macro_output import ( + DecileOutput, + IntraDecileOutput, +) +from policyengine_api_simulation.simulation_output_common import ( + _collection_records, + _number, + _try_compute_output, +) + +INTRA_DECILE_COLUMNS = { + "Lose more than 5%": "lose_more_than_5pct", + "Lose less than 5%": "lose_less_than_5pct", + "No change": "no_change", + "Gain less than 5%": "gain_less_than_5pct", + "Gain more than 5%": "gain_more_than_5pct", +} + + +def build_decile(analysis: Any) -> DecileOutput: + return build_decile_output(getattr(analysis, "decile_impacts", None)) + + +def build_decile_output(collection: Any) -> DecileOutput: + if isinstance(collection, DecileOutput): + return collection + average: dict[str, float] = {} + relative: dict[str, float] = {} + for row in sorted( + _collection_records(collection), + key=lambda item: _number(item.get("decile")), + ): + decile = int(_number(row.get("decile"))) + if decile <= 0: + continue + key = str(decile) + average[key] = _number(row.get("absolute_change")) + relative[key] = _number(row.get("relative_change")) + return DecileOutput(average=average, relative=relative) + + +def build_intra_decile_output(baseline, reform) -> IntraDecileOutput: + from policyengine.outputs.intra_decile_impact import ( + compute_intra_decile_impacts, + ) + + collection = _try_compute_output( + "intra-decile impacts", + compute_intra_decile_impacts, + baseline, + reform, + income_variable="household_net_income", + entity="household", + ) + return build_intra_decile_output_from_collection(collection) + + +def 
build_wealth_decile(country: str, wealth_decile: Any) -> DecileOutput | None: + if country != "uk": + return None + return build_decile_output(wealth_decile) + + +def build_intra_wealth_decile( + country: str, intra_wealth_decile: Any +) -> IntraDecileOutput | None: + if country != "uk": + return None + return build_intra_decile_output_from_collection(intra_wealth_decile) + + +def build_intra_decile_output_from_collection(collection: Any) -> IntraDecileOutput: + if isinstance(collection, IntraDecileOutput): + return collection + deciles: dict[str, list[float]] = {label: [] for label in INTRA_DECILE_COLUMNS} + all_values: dict[str, float] = {label: 0.0 for label in INTRA_DECILE_COLUMNS} + rows = [ + row + for row in sorted( + _collection_records(collection), + key=lambda item: _number(item.get("decile")), + ) + if int(_number(row.get("decile"))) > 0 + ] + + for label, column in INTRA_DECILE_COLUMNS.items(): + values = [_number(row.get(column)) for row in rows] + deciles[label] = values + all_values[label] = sum(values) / len(values) if values else 0.0 + return IntraDecileOutput(deciles=deciles, all=all_values) diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_geographic.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_geographic.py new file mode 100644 index 000000000..463d9c323 --- /dev/null +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_geographic.py @@ -0,0 +1,87 @@ +"""Geographic output segment builders.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from policyengine_api_simulation.simulation_macro_output import GeographicImpactOutput +from policyengine_api_simulation.simulation_output_common import ( + _output_model_dump, + _output_module_function, + _try_compute_output, +) + + +def build_geographic_impact_output(value: Any) -> GeographicImpactOutput | None: + if 
isinstance(value, GeographicImpactOutput): + return value + records = _output_model_dump(value) + if isinstance(records, list): + return GeographicImpactOutput( + [dict(item) for item in records if isinstance(item, Mapping)] + ) + if isinstance(value, list): + return GeographicImpactOutput( + [dict(item) for item in value if isinstance(item, Mapping)] + ) + return None + + +def build_congressional_district_impact( + country: str, baseline, reform +) -> GeographicImpactOutput | None: + if country != "us": + return None + + from policyengine.outputs.congressional_district_impact import ( + compute_us_congressional_district_impacts, + ) + + impact = _try_compute_output( + "congressional district impacts", + compute_us_congressional_district_impacts, + baseline, + reform, + ) + return build_geographic_impact_output( + getattr(impact, "district_results", None) if impact is not None else None + ) + + +def build_uk_constituency_impact( + country: str, baseline, reform +) -> GeographicImpactOutput | None: + if country != "uk": + return None + + impact = _try_compute_output( + "constituency impacts", + _output_module_function("constituency_impact", "compute_uk_constituency_impacts"), + baseline, + reform, + ) + if impact is None: + return None + return build_geographic_impact_output(getattr(impact, "constituency_results", None)) + + +def build_uk_local_authority_impact( + country: str, baseline, reform +) -> GeographicImpactOutput | None: + if country != "uk": + return None + + impact = _try_compute_output( + "local authority impacts", + _output_module_function( + "local_authority_impact", "compute_uk_local_authority_impacts" + ), + baseline, + reform, + ) + if impact is None: + return None + return build_geographic_impact_output( + getattr(impact, "local_authority_results", None) + ) diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_inequality.py 
b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_inequality.py new file mode 100644 index 000000000..e0c8d6a9b --- /dev/null +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_inequality.py @@ -0,0 +1,32 @@ +"""Inequality output segment builders.""" + +from __future__ import annotations + +from typing import Any + +from policyengine_api_simulation.simulation_macro_output import ( + BaselineReformValue, + InequalityOutput, +) +from policyengine_api_simulation.simulation_output_common import _number + + +def build_inequality(analysis: Any) -> InequalityOutput: + baseline = getattr(analysis, "baseline_inequality", None) + reform = getattr(analysis, "reform_inequality", None) + if isinstance(baseline, InequalityOutput): + return baseline + return InequalityOutput( + gini=BaselineReformValue( + baseline=_number(getattr(baseline, "gini", None)), + reform=_number(getattr(reform, "gini", None)), + ), + top_10_pct_share=BaselineReformValue( + baseline=_number(getattr(baseline, "top_10_share", None)), + reform=_number(getattr(reform, "top_10_share", None)), + ), + top_1_pct_share=BaselineReformValue( + baseline=_number(getattr(baseline, "top_1_share", None)), + reform=_number(getattr(reform, "top_1_share", None)), + ), + ) diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_labor.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_labor.py new file mode 100644 index 000000000..596abfdac --- /dev/null +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_labor.py @@ -0,0 +1,18 @@ +"""Labor supply response output segment builders.""" + +from __future__ import annotations + +from typing import Any + +from policyengine_api_simulation.simulation_macro_output import ( + LaborSupplyResponseOutput, +) +from policyengine_api_simulation.simulation_output_common import 
_output_model_dump + + +def build_labor_supply_response(analysis: Any) -> LaborSupplyResponseOutput | None: + labor_supply_response = getattr(analysis, "labor_supply_response", None) + if isinstance(labor_supply_response, LaborSupplyResponseOutput): + return labor_supply_response + output = _output_model_dump(labor_supply_response) + return LaborSupplyResponseOutput(output) if isinstance(output, dict) else None diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_poverty.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_poverty.py new file mode 100644 index 000000000..3ae54e86c --- /dev/null +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_poverty.py @@ -0,0 +1,251 @@ +"""Poverty output segment builders.""" + +from __future__ import annotations + +from collections.abc import Iterable, Mapping +from typing import Any + +from policyengine_api_simulation.simulation_macro_output import ( + AgePovertyOutput, + BaselineReformValue, + GenderPovertyOutput, + PovertyByGenderOutput, + PovertyByRaceOutput, + PovertyModuleOutputs, + PovertyOutput, + RacePovertyOutput, +) +from policyengine_api_simulation.simulation_output_common import ( + _collection_records, + _number, + _poverty_module_function, + _try_compute_output, +) + +US_POVERTY_TYPES = { + "spm": "poverty", + "spm_deep": "deep_poverty", +} + +UK_POVERTY_TYPES = { + "relative_bhc": "poverty", + "absolute_bhc": "deep_poverty", +} + + +def _empty_baseline_reform_value() -> dict[str, float]: + return {"baseline": 0.0, "reform": 0.0} + + +def _empty_age_poverty() -> dict[str, dict[str, float]]: + return { + "child": _empty_baseline_reform_value(), + "adult": _empty_baseline_reform_value(), + "senior": _empty_baseline_reform_value(), + "all": _empty_baseline_reform_value(), + } + + +def _empty_gender_poverty() -> dict[str, dict[str, float]]: + return { + "male": 
_empty_baseline_reform_value(), + "female": _empty_baseline_reform_value(), + } + + +def _poverty_type(country: str, row: Mapping[str, Any]) -> str | None: + poverty_type = str(row.get("poverty_type") or "").lower() + if country == "us": + return US_POVERTY_TYPES.get(poverty_type) + return UK_POVERTY_TYPES.get(poverty_type) + + +def _fill_poverty_block( + *, + country: str, + output: dict[str, dict[str, dict[str, float]]], + baseline_records: Iterable[Mapping[str, Any]], + reform_records: Iterable[Mapping[str, Any]], + default_group: str, +) -> None: + for side, records in (("baseline", baseline_records), ("reform", reform_records)): + for row in records: + poverty_type = _poverty_type(country, row) + if poverty_type is None: + continue + if poverty_type not in output: + continue + group = str(row.get("filter_group") or default_group).lower() + if group not in output[poverty_type]: + continue + output[poverty_type][group][side] = _number(row.get("rate")) + + +def _age_poverty_output(values: dict[str, dict[str, float]]) -> AgePovertyOutput: + return AgePovertyOutput( + child=BaselineReformValue(**values["child"]), + adult=BaselineReformValue(**values["adult"]), + senior=BaselineReformValue(**values["senior"]), + all=BaselineReformValue(**values["all"]), + ) + + +def _gender_poverty_output( + values: dict[str, dict[str, float]], +) -> GenderPovertyOutput: + return GenderPovertyOutput( + male=BaselineReformValue(**values["male"]), + female=BaselineReformValue(**values["female"]), + ) + + +def _race_poverty_output(values: dict[str, dict[str, float]]) -> RacePovertyOutput: + return RacePovertyOutput( + white=BaselineReformValue(**values["white"]), + black=BaselineReformValue(**values["black"]), + hispanic=BaselineReformValue(**values["hispanic"]), + other=BaselineReformValue(**values["other"]), + ) + + +def build_poverty_outputs(country: str, baseline, reform, analysis: Any): + prefix = "us" if country == "us" else "uk" + baseline_poverty_by_age = _try_compute_output( + 
"baseline poverty by age", + _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), + baseline, + ) + reform_poverty_by_age = _try_compute_output( + "reform poverty by age", + _poverty_module_function(f"calculate_{prefix}_poverty_by_age"), + reform, + ) + baseline_poverty_by_gender = _try_compute_output( + "baseline poverty by gender", + _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), + baseline, + ) + reform_poverty_by_gender = _try_compute_output( + "reform poverty by gender", + _poverty_module_function(f"calculate_{prefix}_poverty_by_gender"), + reform, + ) + baseline_poverty_by_race = None + reform_poverty_by_race = None + if country == "us": + baseline_poverty_by_race = _try_compute_output( + "baseline poverty by race", + _poverty_module_function("calculate_us_poverty_by_race"), + baseline, + ) + reform_poverty_by_race = _try_compute_output( + "reform poverty by race", + _poverty_module_function("calculate_us_poverty_by_race"), + reform, + ) + return PovertyModuleOutputs( + poverty=build_poverty_output( + country=country, + baseline=getattr(analysis, "baseline_poverty", None), + reform=getattr(analysis, "reform_poverty", None), + baseline_by_age=baseline_poverty_by_age, + reform_by_age=reform_poverty_by_age, + ), + poverty_by_gender=build_poverty_by_gender_output( + country=country, + baseline_by_gender=baseline_poverty_by_gender, + reform_by_gender=reform_poverty_by_gender, + ), + poverty_by_race=( + build_poverty_by_race_output( + baseline_by_race=baseline_poverty_by_race, + reform_by_race=reform_poverty_by_race, + ) + if country == "us" + else None + ), + ) + + +def build_poverty_output( + *, + country: str, + baseline: Any, + reform: Any, + baseline_by_age: Any, + reform_by_age: Any, +) -> PovertyOutput: + if isinstance(baseline, PovertyOutput): + return baseline + result = { + "poverty": _empty_age_poverty(), + "deep_poverty": _empty_age_poverty(), + } + _fill_poverty_block( + country=country, + output=result, + 
baseline_records=_collection_records(baseline), + reform_records=_collection_records(reform), + default_group="all", + ) + _fill_poverty_block( + country=country, + output=result, + baseline_records=_collection_records(baseline_by_age), + reform_records=_collection_records(reform_by_age), + default_group="all", + ) + return PovertyOutput( + poverty=_age_poverty_output(result["poverty"]), + deep_poverty=_age_poverty_output(result["deep_poverty"]), + ) + + +def build_poverty_by_gender_output( + *, + country: str, + baseline_by_gender: Any, + reform_by_gender: Any, +) -> PovertyByGenderOutput: + if isinstance(baseline_by_gender, PovertyByGenderOutput): + return baseline_by_gender + result = { + "poverty": _empty_gender_poverty(), + "deep_poverty": _empty_gender_poverty(), + } + _fill_poverty_block( + country=country, + output=result, + baseline_records=_collection_records(baseline_by_gender), + reform_records=_collection_records(reform_by_gender), + default_group="all", + ) + return PovertyByGenderOutput( + poverty=_gender_poverty_output(result["poverty"]), + deep_poverty=_gender_poverty_output(result["deep_poverty"]), + ) + + +def build_poverty_by_race_output( + *, + baseline_by_race: Any, + reform_by_race: Any, +) -> PovertyByRaceOutput: + if isinstance(baseline_by_race, PovertyByRaceOutput): + return baseline_by_race + result = { + "poverty": { + "white": _empty_baseline_reform_value(), + "black": _empty_baseline_reform_value(), + "hispanic": _empty_baseline_reform_value(), + "other": _empty_baseline_reform_value(), + } + } + _fill_poverty_block( + country="us", + output=result, + baseline_records=_collection_records(baseline_by_race), + reform_records=_collection_records(reform_by_race), + default_group="all", + ) + return PovertyByRaceOutput(poverty=_race_poverty_output(result["poverty"])) diff --git a/projects/policyengine-api-simulation/tests/test_simulation_output_builder.py b/projects/policyengine-api-simulation/tests/test_simulation_output_builder.py index 
550032a7b..364b8871b 100644 --- a/projects/policyengine-api-simulation/tests/test_simulation_output_builder.py +++ b/projects/policyengine-api-simulation/tests/test_simulation_output_builder.py @@ -145,7 +145,7 @@ def compute(simulation): return compute monkeypatch.setattr( - "policyengine_api_simulation.simulation_output_builder._poverty_module_function", + "policyengine_api_simulation.simulation_output_poverty._poverty_module_function", fake_poverty_module_function, ) monkeypatch.setattr( @@ -606,7 +606,7 @@ def fail_change_output_variable(*args, **kwargs): raise RuntimeError("household_tax missing") monkeypatch.setattr( - "policyengine_api_simulation.simulation_output_builder._change_output_variable", + "policyengine_api_simulation.simulation_output_budget._change_output_variable", fail_change_output_variable, ) @@ -631,7 +631,7 @@ def compute(baseline_simulation, reform_simulation): return compute monkeypatch.setattr( - "policyengine_api_simulation.simulation_output_builder._output_module_function", + "policyengine_api_simulation.simulation_output_geographic._output_module_function", fake_output_module_function, ) @@ -666,7 +666,7 @@ def compute(baseline_simulation, reform_simulation): return compute monkeypatch.setattr( - "policyengine_api_simulation.simulation_output_builder._output_module_function", + "policyengine_api_simulation.simulation_output_geographic._output_module_function", fake_output_module_function, ) diff --git a/projects/policyengine-api-simulation/tests/test_standalone_simulation_contract.py b/projects/policyengine-api-simulation/tests/test_standalone_simulation_contract.py index f764c0a97..75e8b2ece 100644 --- a/projects/policyengine-api-simulation/tests/test_standalone_simulation_contract.py +++ b/projects/policyengine-api-simulation/tests/test_standalone_simulation_contract.py @@ -15,7 +15,15 @@ "policyengine_api_simulation.release_bundle", "policyengine_api_simulation.simulation", "policyengine_api_simulation.simulation_macro_output", + 
"policyengine_api_simulation.simulation_output_budget", "policyengine_api_simulation.simulation_output_builder", + "policyengine_api_simulation.simulation_output_cliff", + "policyengine_api_simulation.simulation_output_common", + "policyengine_api_simulation.simulation_output_distribution", + "policyengine_api_simulation.simulation_output_geographic", + "policyengine_api_simulation.simulation_output_inequality", + "policyengine_api_simulation.simulation_output_labor", + "policyengine_api_simulation.simulation_output_poverty", "policyengine_api_simulation.simulation_runtime", "policyengine_api_simulation.telemetry", ) From c8fef3179db0d20882656f1e5e3d9bacc9c86f0d Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 28 May 2026 19:36:12 +0200 Subject: [PATCH 22/23] chore: restore simulation package license file reference --- projects/policyengine-api-simulation/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/policyengine-api-simulation/pyproject.toml b/projects/policyengine-api-simulation/pyproject.toml index 2f8df2fce..13ce121b2 100644 --- a/projects/policyengine-api-simulation/pyproject.toml +++ b/projects/policyengine-api-simulation/pyproject.toml @@ -9,7 +9,7 @@ readme = "README.md" authors = [ {name = "PolicyEngine", email = "hello@policyengine.org"}, ] -license = "AGPL-3.0-only" +license = {file = "../../LICENSE"} requires-python = ">=3.13,<3.14" dependencies = [ "opentelemetry-instrumentation-sqlalchemy (>=0.51b0,<0.52)", From 993365e5ed70147cce7de70bed19294bce26096a Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 28 May 2026 20:03:18 +0200 Subject: [PATCH 23/23] style: format simulation API files --- .../simulation_output_geographic.py | 4 +++- .../tests/gateway/test_auth.py | 16 ++++++++++++---- .../tests/gateway/test_models.py | 8 ++------ .../tests/test_modal_scripts.py | 6 +++--- .../tests/test_modal_telemetry.py | 5 ++++- .../tests/test_update_version_registry.py | 8 ++------ 6 files changed, 26 
insertions(+), 21 deletions(-) diff --git a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_geographic.py b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_geographic.py index 463d9c323..c85e72fcd 100644 --- a/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_geographic.py +++ b/projects/policyengine-api-simulation/src/policyengine_api_simulation/simulation_output_geographic.py @@ -57,7 +57,9 @@ def build_uk_constituency_impact( impact = _try_compute_output( "constituency impacts", - _output_module_function("constituency_impact", "compute_uk_constituency_impacts"), + _output_module_function( + "constituency_impact", "compute_uk_constituency_impacts" + ), baseline, reform, ) diff --git a/projects/policyengine-api-simulation/tests/gateway/test_auth.py b/projects/policyengine-api-simulation/tests/gateway/test_auth.py index 0a6e4b3be..ca037b00a 100644 --- a/projects/policyengine-api-simulation/tests/gateway/test_auth.py +++ b/projects/policyengine-api-simulation/tests/gateway/test_auth.py @@ -332,7 +332,9 @@ def test__given_auth_optional_and_unset__then_guard_noops(self, monkeypatch): def test__given_partial_auth_config__then_guard_raises(self, monkeypatch): monkeypatch.delenv(auth_module.GATEWAY_AUTH_DISABLED_ENV, raising=False) monkeypatch.delenv(auth_module.GATEWAY_AUTH_REQUIRED_ENV, raising=False) - monkeypatch.setenv(auth_module.GATEWAY_AUTH_ISSUER_ENV, "https://issuer.example/") + monkeypatch.setenv( + auth_module.GATEWAY_AUTH_ISSUER_ENV, "https://issuer.example/" + ) monkeypatch.delenv(auth_module.GATEWAY_AUTH_AUDIENCE_ENV, raising=False) with pytest.raises(auth_module.AuthMisconfiguredError): @@ -350,7 +352,9 @@ def test__given_required_and_missing__then_guard_raises(self, monkeypatch): def test__given_required_and_configured__then_guard_noops(self, monkeypatch): monkeypatch.delenv(auth_module.GATEWAY_AUTH_DISABLED_ENV, raising=False) 
monkeypatch.setenv(auth_module.GATEWAY_AUTH_REQUIRED_ENV, "1") - monkeypatch.setenv(auth_module.GATEWAY_AUTH_ISSUER_ENV, "https://issuer.example/") + monkeypatch.setenv( + auth_module.GATEWAY_AUTH_ISSUER_ENV, "https://issuer.example/" + ) monkeypatch.setenv(auth_module.GATEWAY_AUTH_AUDIENCE_ENV, "aud") auth_module.enforce_auth_configured_guard() @@ -362,7 +366,9 @@ class TestIssuerNormalization: def test__given_issuer_without_slash__then_decoder_receives_normalized_value( self, monkeypatch ): - monkeypatch.setenv(auth_module.GATEWAY_AUTH_ISSUER_ENV, "https://issuer.example") + monkeypatch.setenv( + auth_module.GATEWAY_AUTH_ISSUER_ENV, "https://issuer.example" + ) monkeypatch.setenv(auth_module.GATEWAY_AUTH_AUDIENCE_ENV, "aud") auth_module.reset_decoder_cache() @@ -385,7 +391,9 @@ def fake_builder(issuer, audience): def test__given_issuer_with_slash__then_decoder_receives_unchanged_value( self, monkeypatch ): - monkeypatch.setenv(auth_module.GATEWAY_AUTH_ISSUER_ENV, "https://issuer.example/") + monkeypatch.setenv( + auth_module.GATEWAY_AUTH_ISSUER_ENV, "https://issuer.example/" + ) monkeypatch.setenv(auth_module.GATEWAY_AUTH_AUDIENCE_ENV, "aud") auth_module.reset_decoder_cache() diff --git a/projects/policyengine-api-simulation/tests/gateway/test_models.py b/projects/policyengine-api-simulation/tests/gateway/test_models.py index a8d6717fb..526555c25 100644 --- a/projects/policyengine-api-simulation/tests/gateway/test_models.py +++ b/projects/policyengine-api-simulation/tests/gateway/test_models.py @@ -283,9 +283,7 @@ def test_job_submit_response_creates_with_all_fields(self): assert response.poll_url == "/jobs/fc-abc123" assert response.country == "us" assert response.version == "1.459.0" - assert ( - response.resolved_app_name == "policyengine-simulation-py3-9-0" - ) + assert response.resolved_app_name == "policyengine-simulation-py3-9-0" assert response.policyengine_bundle.model_version == "1.459.0" assert response.policyengine_bundle.policyengine_version is None 
assert response.policyengine_bundle.dataset == ( @@ -357,9 +355,7 @@ def test_job_status_response_accepts_bundle_metadata(self): }, ) - assert ( - response.resolved_app_name == "policyengine-simulation-py3-9-0" - ) + assert response.resolved_app_name == "policyengine-simulation-py3-9-0" assert response.policyengine_bundle is not None assert response.policyengine_bundle.dataset == ( "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.115.5" diff --git a/projects/policyengine-api-simulation/tests/test_modal_scripts.py b/projects/policyengine-api-simulation/tests/test_modal_scripts.py index 4911ef4ab..476c234b2 100644 --- a/projects/policyengine-api-simulation/tests/test_modal_scripts.py +++ b/projects/policyengine-api-simulation/tests/test_modal_scripts.py @@ -516,6 +516,6 @@ def test_all_scripts_have_valid_syntax(self, all_modal_scripts): capture_output=True, text=True, ) - assert ( - result.returncode == 0 - ), f"{script.name} has syntax errors: {result.stderr}" + assert result.returncode == 0, ( + f"{script.name} has syntax errors: {result.stderr}" + ) diff --git a/projects/policyengine-api-simulation/tests/test_modal_telemetry.py b/projects/policyengine-api-simulation/tests/test_modal_telemetry.py index 59155537d..243a48a5f 100644 --- a/projects/policyengine-api-simulation/tests/test_modal_telemetry.py +++ b/projects/policyengine-api-simulation/tests/test_modal_telemetry.py @@ -1,4 +1,7 @@ -from policyengine_api_simulation.telemetry import TelemetryEnvelope, split_internal_payload +from policyengine_api_simulation.telemetry import ( + TelemetryEnvelope, + split_internal_payload, +) def test_split_internal_payload__removes_internal_fields(): diff --git a/projects/policyengine-api-simulation/tests/test_update_version_registry.py b/projects/policyengine-api-simulation/tests/test_update_version_registry.py index 1b4a3acf7..280f01f83 100644 --- a/projects/policyengine-api-simulation/tests/test_update_version_registry.py +++ 
b/projects/policyengine-api-simulation/tests/test_update_version_registry.py @@ -150,9 +150,7 @@ def fake_country_bundle_metadata( ), "model_version": "1.0.0" if country == "us" else "2.0.0", "data_package_name": ( - "policyengine-us-data" - if country == "us" - else "policyengine-uk-data" + "policyengine-us-data" if country == "us" else "policyengine-uk-data" ), "data_version": "3.0.0" if country == "us" else "4.0.0", "default_dataset": "default", @@ -171,9 +169,7 @@ def fake_country_bundle_metadata( policyengine_version="4.10.0", ) - snapshot = patched_modal[ - "main/simulation-api-app-release-bundles" - ].snapshot() + snapshot = patched_modal["main/simulation-api-app-release-bundles"].snapshot() metadata = snapshot["policyengine-simulation-py4-10-0"] assert snapshot["4.10.0"] == metadata assert metadata["policyengine_version"] == "4.10.0"