Skip to content

Commit f0afbc0

Browse files
authored
Update simulation API to PolicyEngine 4.4.3
Update simulation worker dependencies to PolicyEngine 4.4.3, policyengine-us 1.690.7, and matching runtime flow.
1 parent 1313a58 commit f0afbc0

8 files changed

Lines changed: 411 additions & 180 deletions

File tree

projects/policyengine-api-simulation/pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ dependencies = [
1616
"pydantic-settings (>=2.7.1,<3.0.0)",
1717
"opentelemetry-instrumentation-fastapi (>=0.51b0,<0.52)",
1818
"policyengine-fastapi",
19-
"policyengine==0.13.0",
20-
"policyengine-core>=3.23.5",
21-
"policyengine-uk==2.88.0",
22-
"policyengine-us==1.653.3",
19+
"policyengine==4.4.3",
20+
"policyengine-core>=3.26.1",
21+
"policyengine-uk==2.88.14",
22+
"policyengine-us==1.690.7",
2323
"tables>=3.10.2",
2424
"modal>=0.73.0",
2525
"logfire>=3.0.0",

projects/policyengine-api-simulation/src/modal/app.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from src.modal.logging_redaction import redact_params_for_logging
1515

1616
# Get versions from environment or use defaults
17-
US_VERSION = os.environ.get("POLICYENGINE_US_VERSION", "1.562.3")
18-
UK_VERSION = os.environ.get("POLICYENGINE_UK_VERSION", "2.65.9")
17+
US_VERSION = os.environ.get("POLICYENGINE_US_VERSION", "1.690.7")
18+
UK_VERSION = os.environ.get("POLICYENGINE_UK_VERSION", "2.88.14")
1919

2020

2121
def get_app_name(us_version: str, uk_version: str) -> str:
@@ -48,7 +48,7 @@ def get_app_name(us_version: str, uk_version: str) -> str:
4848
.pip_install(
4949
f"policyengine-us=={US_VERSION}",
5050
f"policyengine-uk=={UK_VERSION}",
51-
"policyengine==0.13.0",
51+
"policyengine==4.4.3",
5252
"tables>=3.10.2",
5353
"logfire",
5454
)

projects/policyengine-api-simulation/src/modal/gateway/endpoints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@
4141
JOB_METADATA_DICT_NAME = "simulation-api-job-metadata"
4242
DATASET_URIS = {
4343
"us": {
44-
"enhanced_cps": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0",
45-
"enhanced_cps_2024": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0",
44+
"enhanced_cps": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12",
45+
"enhanced_cps_2024": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12",
4646
"cps": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.77.0",
4747
"cps_2023": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.77.0",
4848
"pooled_cps": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.77.0",

projects/policyengine-api-simulation/src/modal/simulation.py

Lines changed: 284 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,35 @@
1010
import logging
1111
import os
1212
import tempfile
13-
from typing import Iterator
13+
from typing import Any, Iterator
1414

15-
# Module-level imports - these are SNAPSHOTTED at image build time
16-
from policyengine.simulation import Simulation, SimulationOptions
17-
18-
from src.modal.telemetry import split_internal_payload
15+
try:
16+
from src.modal.telemetry import split_internal_payload
17+
except ModuleNotFoundError:
18+
from modal.telemetry import split_internal_payload
1919

2020
logger = logging.getLogger(__name__)
2121

2222

23+
DEFAULT_YEAR = 2026
24+
DATASET_ALIASES = {
25+
"us": {
26+
"enhanced_cps": "enhanced_cps_2024",
27+
"enhanced_cps_2024": "enhanced_cps_2024",
28+
"gs://policyengine-us-data/enhanced_cps_2024.h5": "enhanced_cps_2024",
29+
"hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5": "enhanced_cps_2024",
30+
"cps_small": "cps_small_2024",
31+
"cps_small_2024": "cps_small_2024",
32+
},
33+
"uk": {
34+
"enhanced_frs": "enhanced_frs_2023_24",
35+
"enhanced_frs_2023_24": "enhanced_frs_2023_24",
36+
"frs": "frs_2023_24",
37+
"frs_2023_24": "frs_2023_24",
38+
},
39+
}
40+
41+
2342
def _normalize_credentials_blob(creds_json: str) -> str:
2443
"""Return the raw JSON blob, decoding the outer escape if present.
2544
@@ -115,6 +134,251 @@ def run_simulation_impl(params: dict) -> dict:
115134
return _run_simulation_impl_core(params)
116135

117136

137+
def _parse_year(params: dict[str, Any]) -> int:
138+
value = params.get("time_period") or params.get("year") or DEFAULT_YEAR
139+
return int(value)
140+
141+
142+
def _normalise_period_key(period_key: Any) -> str:
143+
"""Convert legacy ``start.stop`` period keys to v4 effective dates."""
144+
text = str(period_key)
145+
parts = text.split(".")
146+
if len(parts) > 1 and len(parts[0]) == 10:
147+
return parts[0]
148+
return text
149+
150+
151+
def _normalise_reform(reform: dict[str, Any] | None) -> dict[str, Any] | None:
152+
if not reform:
153+
return None
154+
normalised: dict[str, Any] = {}
155+
for parameter, value in reform.items():
156+
if isinstance(value, dict):
157+
normalised[parameter] = {
158+
_normalise_period_key(period): period_value
159+
for period, period_value in value.items()
160+
}
161+
else:
162+
normalised[parameter] = value
163+
return normalised
164+
165+
166+
def _resolve_dataset_name(
167+
country: str, requested_data: str | None, subsample: int | None
168+
) -> str:
169+
if requested_data is None:
170+
return "enhanced_cps_2024" if country == "us" else "enhanced_frs_2023_24"
171+
172+
requested = requested_data.split("@", maxsplit=1)[0]
173+
return DATASET_ALIASES.get(country, {}).get(requested, requested_data)
174+
175+
176+
def _microframe_like(frame, weights: str):
177+
from microdf import MicroDataFrame
178+
179+
return MicroDataFrame(frame.copy(), weights=weights)
180+
181+
182+
def _person_group_column(person, entity: str) -> str:
183+
prefixed = f"person_{entity}_id"
184+
if prefixed in person.columns:
185+
return prefixed
186+
return f"{entity}_id"
187+
188+
189+
def _subsample_us_dataset(dataset, subsample: int | None):
190+
if not subsample:
191+
return dataset
192+
193+
from policyengine.tax_benefit_models.us.datasets import (
194+
PolicyEngineUSDataset,
195+
USYearData,
196+
)
197+
198+
dataset.load()
199+
data = dataset.data
200+
household = data.household.head(int(subsample)).copy()
201+
household_ids = set(household["household_id"])
202+
203+
person_household_col = _person_group_column(data.person, "household")
204+
person = data.person[data.person[person_household_col].isin(household_ids)].copy()
205+
206+
def group_subset(entity: str):
207+
person_col = _person_group_column(person, entity)
208+
entity_id_col = f"{entity}_id"
209+
ids = set(person[person_col])
210+
frame = getattr(data, entity)
211+
return frame[frame[entity_id_col].isin(ids)].copy()
212+
213+
subset_data = USYearData(
214+
person=_microframe_like(person, "person_weight"),
215+
marital_unit=_microframe_like(
216+
group_subset("marital_unit"), "marital_unit_weight"
217+
),
218+
family=_microframe_like(group_subset("family"), "family_weight"),
219+
spm_unit=_microframe_like(group_subset("spm_unit"), "spm_unit_weight"),
220+
tax_unit=_microframe_like(group_subset("tax_unit"), "tax_unit_weight"),
221+
household=_microframe_like(household, "household_weight"),
222+
)
223+
subset_path = os.path.join(
224+
os.environ.get("POLICYENGINE_DATA_FOLDER", "/tmp/policyengine-data"),
225+
f"{dataset.id}_subsample_{subsample}.h5",
226+
)
227+
return PolicyEngineUSDataset(
228+
id=f"{dataset.id}_subsample_{subsample}",
229+
name=f"{dataset.name} subsample {subsample}",
230+
description=dataset.description,
231+
filepath=subset_path,
232+
year=dataset.year,
233+
is_output_dataset=dataset.is_output_dataset,
234+
data=subset_data,
235+
)
236+
237+
238+
def _country_module(country: str):
239+
import policyengine as pe
240+
241+
country = country.lower()
242+
if country == "us":
243+
return pe.us
244+
if country == "uk":
245+
return pe.uk
246+
raise ValueError(f"Unsupported country: {country}")
247+
248+
249+
def _load_dataset(params: dict[str, Any]):
250+
country = params.get("country", "us").lower()
251+
year = _parse_year(params)
252+
country_module = _country_module(country)
253+
dataset_name = _resolve_dataset_name(
254+
country, params.get("data"), params.get("subsample")
255+
)
256+
datasets = country_module.ensure_datasets(
257+
datasets=[dataset_name],
258+
years=[year],
259+
data_folder=os.environ.get(
260+
"POLICYENGINE_DATA_FOLDER", "/tmp/policyengine-data"
261+
),
262+
)
263+
dataset = next(iter(datasets.values()))
264+
if country == "us":
265+
return _subsample_us_dataset(dataset, params.get("subsample"))
266+
return dataset
267+
268+
269+
def _build_simulation(params: dict[str, Any], policy: dict[str, Any] | None):
270+
from policyengine.core import Simulation
271+
272+
country = params.get("country", "us").lower()
273+
country_module = _country_module(country)
274+
dataset = _load_dataset(params)
275+
return Simulation(
276+
dataset=dataset,
277+
tax_benefit_model_version=country_module.model,
278+
policy=policy,
279+
)
280+
281+
282+
def _change_sum(baseline, reform, variable: str, entity: str | None = None) -> float:
283+
from policyengine.outputs import ChangeAggregate, ChangeAggregateType
284+
285+
output = ChangeAggregate(
286+
baseline_simulation=baseline,
287+
reform_simulation=reform,
288+
variable=variable,
289+
entity=entity,
290+
aggregate_type=ChangeAggregateType.SUM,
291+
)
292+
output.run()
293+
return float(output.result)
294+
295+
296+
def _try_change_sum(
297+
baseline, reform, variable: str, entity: str | None = None
298+
) -> float:
299+
try:
300+
return _change_sum(baseline, reform, variable, entity)
301+
except Exception:
302+
logger.warning("Unable to calculate change for %s", variable, exc_info=True)
303+
return 0.0
304+
305+
306+
def _budget_result(country: str, baseline, reform) -> dict[str, float]:
307+
tax_revenue_impact = _try_change_sum(
308+
baseline, reform, "household_tax", entity="household"
309+
)
310+
benefit_spending_impact = _try_change_sum(
311+
baseline, reform, "household_benefits", entity="household"
312+
)
313+
budgetary_impact = tax_revenue_impact - benefit_spending_impact
314+
result = {
315+
"tax_revenue_impact": tax_revenue_impact,
316+
"benefit_spending_impact": benefit_spending_impact,
317+
"budgetary_impact": budgetary_impact,
318+
}
319+
if country == "us":
320+
result["state_tax_revenue_impact"] = _try_change_sum(
321+
baseline,
322+
reform,
323+
"household_state_income_tax",
324+
entity="tax_unit",
325+
)
326+
return result
327+
328+
329+
def _poverty_result(country: str, baseline, reform) -> dict[str, list[dict[str, Any]]]:
330+
import policyengine as pe
331+
332+
if country == "us":
333+
baseline_poverty = pe.us.economic_impact_analysis(
334+
baseline, reform
335+
).baseline_poverty
336+
reform_poverty = pe.us.economic_impact_analysis(baseline, reform).reform_poverty
337+
else:
338+
baseline_poverty = pe.uk.economic_impact_analysis(
339+
baseline, reform
340+
).baseline_poverty
341+
reform_poverty = pe.uk.economic_impact_analysis(baseline, reform).reform_poverty
342+
343+
return {
344+
"baseline": baseline_poverty.dataframe.to_dict("records"),
345+
"reform": reform_poverty.dataframe.to_dict("records"),
346+
}
347+
348+
349+
def _analysis_result(country: str, baseline, reform) -> dict[str, Any]:
350+
import policyengine as pe
351+
352+
if country == "us":
353+
analysis = pe.us.economic_impact_analysis(baseline, reform)
354+
else:
355+
analysis = pe.uk.economic_impact_analysis(baseline, reform)
356+
357+
return {
358+
"decile_impacts": analysis.decile_impacts.dataframe.to_dict("records"),
359+
"program_statistics": analysis.program_statistics.dataframe.to_dict("records"),
360+
"poverty": {
361+
"baseline": analysis.baseline_poverty.dataframe.to_dict("records"),
362+
"reform": analysis.reform_poverty.dataframe.to_dict("records"),
363+
},
364+
"inequality": {
365+
"baseline": _inequality_summary(analysis.baseline_inequality),
366+
"reform": _inequality_summary(analysis.reform_inequality),
367+
},
368+
}
369+
370+
371+
def _inequality_summary(inequality) -> dict[str, Any]:
372+
return {
373+
"income_variable": inequality.income_variable,
374+
"entity": inequality.entity,
375+
"gini": inequality.gini,
376+
"top_10_share": inequality.top_10_share,
377+
"top_1_share": inequality.top_1_share,
378+
"bottom_50_share": inequality.bottom_50_share,
379+
}
380+
381+
118382
def _run_simulation_impl_core(params: dict) -> dict:
119383
simulation_params, telemetry, metadata = split_internal_payload(params)
120384

@@ -127,17 +391,21 @@ def _run_simulation_impl_core(params: dict) -> dict:
127391
if metadata:
128392
logger.info("Received simulation metadata keys: %s", sorted(metadata))
129393

130-
# Validate and create simulation options
131-
options = SimulationOptions.model_validate(simulation_params)
132-
logger.info("Initialising simulation from input")
394+
country = simulation_params.get("country", "us").lower()
395+
baseline_policy = _normalise_reform(simulation_params.get("baseline"))
396+
reform_policy = _normalise_reform(simulation_params.get("reform"))
133397

134-
# Create simulation instance
135-
simulation = Simulation(**options.model_dump())
136-
logger.info("Calculating comparison")
398+
logger.info("Initialising baseline and reform simulations")
399+
baseline = _build_simulation(simulation_params, baseline_policy)
400+
reform = _build_simulation(simulation_params, reform_policy)
137401

138-
# Run the economy comparison calculation
139-
result = simulation.calculate_economy_comparison()
402+
logger.info("Calculating economic impact")
403+
analysis = _analysis_result(country, baseline, reform)
404+
analysis["budget"] = _budget_result(country, baseline, reform)
405+
analysis["metadata"] = {
406+
"country": country,
407+
"year": _parse_year(simulation_params),
408+
"dataset": getattr(baseline.dataset, "filepath", None),
409+
}
140410
logger.info("Comparison complete")
141-
142-
# Use mode='json' to ensure numpy arrays are converted to lists
143-
return result.model_dump(mode="json")
411+
return analysis

0 commit comments

Comments
 (0)