Skip to content

Commit 2f059ab

Browse files
authored
Merge pull request #486 from PolicyEngine/rollback-policyengine-443
Roll back PolicyEngine 4.4 schema change
2 parents (6985143 + be5af17); commit 2f059ab

10 files changed

Lines changed: 192 additions & 391 deletions

File tree

.github/scripts/modal-sync-secrets.sh

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -62,11 +62,6 @@ if [ -n "${GCP_CREDENTIALS_JSON:-}" ]; then
6262
--force || true
6363
fi
6464

65-
uv run modal secret create policyengine-data-credentials \
66-
"HUGGING_FACE_TOKEN=${HUGGING_FACE_TOKEN:-}" \
67-
--env="$MODAL_ENV" \
68-
--force || true
69-
7065
# Sync gateway auth config. The gateway runtime only needs issuer/audience and
7166
# the explicit requirement flag; client credentials stay on the GitHub side and
7267
# are only used to mint integration-test tokens.

.github/workflows/modal-deploy.reusable.yml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,6 @@ jobs:
6262
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
6363
LOGFIRE_TOKEN: ${{ secrets.LOGFIRE_TOKEN }}
6464
GCP_CREDENTIALS_JSON: ${{ secrets.GCP_CREDENTIALS_JSON }}
65-
HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
6665
GATEWAY_AUTH_ISSUER: ${{ secrets.GATEWAY_AUTH_ISSUER }}
6766
GATEWAY_AUTH_AUDIENCE: ${{ secrets.GATEWAY_AUTH_AUDIENCE }}
6867
GATEWAY_AUTH_CLIENT_ID: ${{ secrets.GATEWAY_AUTH_CLIENT_ID }}

projects/policyengine-api-simulation/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ dependencies = [
1616
"pydantic-settings (>=2.7.1,<3.0.0)",
1717
"opentelemetry-instrumentation-fastapi (>=0.51b0,<0.52)",
1818
"policyengine-fastapi",
19-
"policyengine==4.4.3",
20-
"policyengine-core>=3.26.1",
19+
"policyengine==0.13.0",
20+
"policyengine-core>=3.23.5",
2121
"policyengine-uk==2.88.14",
2222
"policyengine-us==1.690.7",
2323
"tables>=3.10.2",

projects/policyengine-api-simulation/src/modal/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
# Get versions from environment or use defaults
1717
US_VERSION = os.environ.get("POLICYENGINE_US_VERSION", "1.690.7")
18-
UK_VERSION = os.environ.get("POLICYENGINE_UK_VERSION", "2.88.0")
18+
UK_VERSION = os.environ.get("POLICYENGINE_UK_VERSION", "2.88.14")
1919

2020

2121
def get_app_name(us_version: str, uk_version: str) -> str:
@@ -49,7 +49,7 @@ def get_app_name(us_version: str, uk_version: str) -> str:
4949
.pip_install(
5050
f"policyengine-us=={US_VERSION}",
5151
f"policyengine-uk=={UK_VERSION}",
52-
"policyengine==4.4.3",
52+
"policyengine==0.13.0",
5353
"tables>=3.10.2",
5454
"logfire",
5555
)

projects/policyengine-api-simulation/src/modal/gateway/endpoints.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@
4343
"us": {
4444
"enhanced_cps": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12",
4545
"enhanced_cps_2024": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12",
46-
"cps": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.77.0",
47-
"cps_2023": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.77.0",
48-
"pooled_cps": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.77.0",
49-
"pooled_3_year_cps_2023": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.77.0",
46+
"cps": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.110.12",
47+
"cps_2023": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.110.12",
48+
"pooled_cps": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.110.12",
49+
"pooled_3_year_cps_2023": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.110.12",
5050
},
5151
"uk": {
5252
"enhanced_frs": "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3",

projects/policyengine-api-simulation/src/modal/simulation.py

Lines changed: 15 additions & 275 deletions
Original file line numberDiff line numberDiff line change
@@ -10,41 +10,16 @@
1010
import logging
1111
import os
1212
import tempfile
13-
import importlib
14-
from typing import Any, Iterator
13+
from typing import Iterator
1514

16-
# policyengine.core is imported for every simulation. Without this guard,
17-
# importing the package pulls both country modules into the process; a US run
18-
# can then fail before it starts if UK private-data credentials are absent.
19-
os.environ.setdefault("POLICYENGINE_SKIP_COUNTRY_IMPORTS", "1")
15+
# Module-level imports - these are SNAPSHOTTED at image build time
16+
from policyengine.simulation import Simulation, SimulationOptions
2017

21-
try:
22-
from src.modal.telemetry import split_internal_payload
23-
except ModuleNotFoundError:
24-
from modal.telemetry import split_internal_payload
18+
from src.modal.telemetry import split_internal_payload
2519

2620
logger = logging.getLogger(__name__)
2721

2822

29-
DEFAULT_YEAR = 2026
30-
DATASET_ALIASES = {
31-
"us": {
32-
"enhanced_cps": "enhanced_cps_2024",
33-
"enhanced_cps_2024": "enhanced_cps_2024",
34-
"gs://policyengine-us-data/enhanced_cps_2024.h5": "enhanced_cps_2024",
35-
"hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5": "enhanced_cps_2024",
36-
"cps_small": "cps_small_2024",
37-
"cps_small_2024": "cps_small_2024",
38-
},
39-
"uk": {
40-
"enhanced_frs": "enhanced_frs_2023_24",
41-
"enhanced_frs_2023_24": "enhanced_frs_2023_24",
42-
"frs": "frs_2023_24",
43-
"frs_2023_24": "frs_2023_24",
44-
},
45-
}
46-
47-
4823
def _normalize_credentials_blob(creds_json: str) -> str:
4924
"""Return the raw JSON blob, decoding the outer escape if present.
5025
@@ -140,237 +115,6 @@ def run_simulation_impl(params: dict) -> dict:
140115
return _run_simulation_impl_core(params)
141116

142117

143-
def _parse_year(params: dict[str, Any]) -> int:
144-
value = params.get("time_period") or params.get("year") or DEFAULT_YEAR
145-
return int(value)
146-
147-
148-
def _normalise_period_key(period_key: Any) -> str:
149-
"""Convert legacy ``start.stop`` period keys to v4 effective dates."""
150-
text = str(period_key)
151-
parts = text.split(".")
152-
if len(parts) > 1 and len(parts[0]) == 10:
153-
return parts[0]
154-
return text
155-
156-
157-
def _normalise_reform(reform: dict[str, Any] | None) -> dict[str, Any] | None:
158-
if not reform:
159-
return None
160-
normalised: dict[str, Any] = {}
161-
for parameter, value in reform.items():
162-
if isinstance(value, dict):
163-
normalised[parameter] = {
164-
_normalise_period_key(period): period_value
165-
for period, period_value in value.items()
166-
}
167-
else:
168-
normalised[parameter] = value
169-
return normalised
170-
171-
172-
def _resolve_dataset_name(
173-
country: str, requested_data: str | None, subsample: int | None
174-
) -> str:
175-
if requested_data is None:
176-
return "enhanced_cps_2024" if country == "us" else "enhanced_frs_2023_24"
177-
178-
requested = requested_data.split("@", maxsplit=1)[0]
179-
return DATASET_ALIASES.get(country, {}).get(requested, requested_data)
180-
181-
182-
def _microframe_like(frame, weights: str):
183-
from microdf import MicroDataFrame
184-
185-
return MicroDataFrame(frame.copy(), weights=weights)
186-
187-
188-
def _person_group_column(person, entity: str) -> str:
189-
prefixed = f"person_{entity}_id"
190-
if prefixed in person.columns:
191-
return prefixed
192-
return f"{entity}_id"
193-
194-
195-
def _subsample_us_dataset(dataset, subsample: int | None):
196-
if not subsample:
197-
return dataset
198-
199-
from policyengine.tax_benefit_models.us.datasets import (
200-
PolicyEngineUSDataset,
201-
USYearData,
202-
)
203-
204-
dataset.load()
205-
data = dataset.data
206-
household = data.household.head(int(subsample)).copy()
207-
household_ids = set(household["household_id"])
208-
209-
person_household_col = _person_group_column(data.person, "household")
210-
person = data.person[data.person[person_household_col].isin(household_ids)].copy()
211-
212-
def group_subset(entity: str):
213-
person_col = _person_group_column(person, entity)
214-
entity_id_col = f"{entity}_id"
215-
ids = set(person[person_col])
216-
frame = getattr(data, entity)
217-
return frame[frame[entity_id_col].isin(ids)].copy()
218-
219-
subset_data = USYearData(
220-
person=_microframe_like(person, "person_weight"),
221-
marital_unit=_microframe_like(
222-
group_subset("marital_unit"), "marital_unit_weight"
223-
),
224-
family=_microframe_like(group_subset("family"), "family_weight"),
225-
spm_unit=_microframe_like(group_subset("spm_unit"), "spm_unit_weight"),
226-
tax_unit=_microframe_like(group_subset("tax_unit"), "tax_unit_weight"),
227-
household=_microframe_like(household, "household_weight"),
228-
)
229-
subset_path = os.path.join(
230-
os.environ.get("POLICYENGINE_DATA_FOLDER", "/tmp/policyengine-data"),
231-
f"{dataset.id}_subsample_{subsample}.h5",
232-
)
233-
return PolicyEngineUSDataset(
234-
id=f"{dataset.id}_subsample_{subsample}",
235-
name=f"{dataset.name} subsample {subsample}",
236-
description=dataset.description,
237-
filepath=subset_path,
238-
year=dataset.year,
239-
is_output_dataset=dataset.is_output_dataset,
240-
data=subset_data,
241-
)
242-
243-
244-
def _country_module(country: str):
245-
country = country.lower()
246-
if country == "us":
247-
return importlib.import_module("policyengine.tax_benefit_models.us")
248-
if country == "uk":
249-
return importlib.import_module("policyengine.tax_benefit_models.uk")
250-
raise ValueError(f"Unsupported country: {country}")
251-
252-
253-
def _load_dataset(params: dict[str, Any]):
254-
country = params.get("country", "us").lower()
255-
year = _parse_year(params)
256-
country_module = _country_module(country)
257-
dataset_name = _resolve_dataset_name(
258-
country, params.get("data"), params.get("subsample")
259-
)
260-
datasets = country_module.ensure_datasets(
261-
datasets=[dataset_name],
262-
years=[year],
263-
data_folder=os.environ.get(
264-
"POLICYENGINE_DATA_FOLDER", "/tmp/policyengine-data"
265-
),
266-
)
267-
dataset = next(iter(datasets.values()))
268-
if country == "us":
269-
return _subsample_us_dataset(dataset, params.get("subsample"))
270-
return dataset
271-
272-
273-
def _build_simulation(params: dict[str, Any], policy: dict[str, Any] | None):
274-
from policyengine.core import Simulation
275-
276-
country = params.get("country", "us").lower()
277-
country_module = _country_module(country)
278-
dataset = _load_dataset(params)
279-
return Simulation(
280-
dataset=dataset,
281-
tax_benefit_model_version=country_module.model,
282-
policy=policy,
283-
)
284-
285-
286-
def _change_sum(baseline, reform, variable: str, entity: str | None = None) -> float:
287-
from policyengine.outputs import ChangeAggregate, ChangeAggregateType
288-
289-
output = ChangeAggregate(
290-
baseline_simulation=baseline,
291-
reform_simulation=reform,
292-
variable=variable,
293-
entity=entity,
294-
aggregate_type=ChangeAggregateType.SUM,
295-
)
296-
output.run()
297-
return float(output.result)
298-
299-
300-
def _try_change_sum(
301-
baseline, reform, variable: str, entity: str | None = None
302-
) -> float:
303-
try:
304-
return _change_sum(baseline, reform, variable, entity)
305-
except Exception:
306-
logger.warning("Unable to calculate change for %s", variable, exc_info=True)
307-
return 0.0
308-
309-
310-
def _budget_result(country: str, baseline, reform) -> dict[str, float]:
311-
tax_revenue_impact = _try_change_sum(
312-
baseline, reform, "household_tax", entity="household"
313-
)
314-
benefit_spending_impact = _try_change_sum(
315-
baseline, reform, "household_benefits", entity="household"
316-
)
317-
budgetary_impact = tax_revenue_impact - benefit_spending_impact
318-
result = {
319-
"tax_revenue_impact": tax_revenue_impact,
320-
"benefit_spending_impact": benefit_spending_impact,
321-
"budgetary_impact": budgetary_impact,
322-
}
323-
if country == "us":
324-
result["state_tax_revenue_impact"] = _try_change_sum(
325-
baseline,
326-
reform,
327-
"household_state_income_tax",
328-
entity="tax_unit",
329-
)
330-
return result
331-
332-
333-
def _poverty_result(country: str, baseline, reform) -> dict[str, list[dict[str, Any]]]:
334-
country_module = _country_module(country)
335-
impact = country_module.economic_impact_analysis(baseline, reform)
336-
baseline_poverty = impact.baseline_poverty
337-
reform_poverty = impact.reform_poverty
338-
339-
return {
340-
"baseline": baseline_poverty.dataframe.to_dict("records"),
341-
"reform": reform_poverty.dataframe.to_dict("records"),
342-
}
343-
344-
345-
def _analysis_result(country: str, baseline, reform) -> dict[str, Any]:
346-
country_module = _country_module(country)
347-
analysis = country_module.economic_impact_analysis(baseline, reform)
348-
349-
return {
350-
"decile_impacts": analysis.decile_impacts.dataframe.to_dict("records"),
351-
"program_statistics": analysis.program_statistics.dataframe.to_dict("records"),
352-
"poverty": {
353-
"baseline": analysis.baseline_poverty.dataframe.to_dict("records"),
354-
"reform": analysis.reform_poverty.dataframe.to_dict("records"),
355-
},
356-
"inequality": {
357-
"baseline": _inequality_summary(analysis.baseline_inequality),
358-
"reform": _inequality_summary(analysis.reform_inequality),
359-
},
360-
}
361-
362-
363-
def _inequality_summary(inequality) -> dict[str, Any]:
364-
return {
365-
"income_variable": inequality.income_variable,
366-
"entity": inequality.entity,
367-
"gini": inequality.gini,
368-
"top_10_share": inequality.top_10_share,
369-
"top_1_share": inequality.top_1_share,
370-
"bottom_50_share": inequality.bottom_50_share,
371-
}
372-
373-
374118
def _run_simulation_impl_core(params: dict) -> dict:
375119
simulation_params, telemetry, metadata = split_internal_payload(params)
376120

@@ -383,21 +127,17 @@ def _run_simulation_impl_core(params: dict) -> dict:
383127
if metadata:
384128
logger.info("Received simulation metadata keys: %s", sorted(metadata))
385129

386-
country = simulation_params.get("country", "us").lower()
387-
baseline_policy = _normalise_reform(simulation_params.get("baseline"))
388-
reform_policy = _normalise_reform(simulation_params.get("reform"))
130+
# Validate and create simulation options
131+
options = SimulationOptions.model_validate(simulation_params)
132+
logger.info("Initialising simulation from input")
389133

390-
logger.info("Initialising baseline and reform simulations")
391-
baseline = _build_simulation(simulation_params, baseline_policy)
392-
reform = _build_simulation(simulation_params, reform_policy)
134+
# Create simulation instance
135+
simulation = Simulation(**options.model_dump())
136+
logger.info("Calculating comparison")
393137

394-
logger.info("Calculating economic impact")
395-
analysis = _analysis_result(country, baseline, reform)
396-
analysis["budget"] = _budget_result(country, baseline, reform)
397-
analysis["metadata"] = {
398-
"country": country,
399-
"year": _parse_year(simulation_params),
400-
"dataset": getattr(baseline.dataset, "filepath", None),
401-
}
138+
# Run the economy comparison calculation
139+
result = simulation.calculate_economy_comparison()
402140
logger.info("Comparison complete")
403-
return analysis
141+
142+
# Use mode='json' to ensure numpy arrays are converted to lists
143+
return result.model_dump(mode="json")

0 commit comments

Comments
 (0)