Skip to content

Commit 5a8e95b

Browse files
authored
Merge pull request #445 from PolicyEngine/feat/multi-year-sim-api-orchestration
Implement budget-window batch orchestration
2 parents 2ef0ea2 + 53ef9b1 commit 5a8e95b

15 files changed

Lines changed: 1732 additions & 50 deletions

projects/policyengine-api-simulation/fixtures/gateway/test_endpoints.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,29 @@ def __init__(self):
5959
self.last_payload = None
6060
self.last_from_name_call = None
6161
self.last_call = None
62+
self.calls = []
6263

63-
def spawn(self, payload: dict) -> MockFunctionCall:
64-
self.last_payload = payload
65-
self.last_call = MockFunctionCall()
66-
return self.last_call
64+
def bind(self, app_name: str, func_name: str) -> "BoundMockFunction":
65+
return BoundMockFunction(self, app_name, func_name)
6766

68-
@classmethod
69-
def from_name(cls, app_name: str, func_name: str):
70-
"""Mock from_name that returns a MockFunction."""
71-
raise NotImplementedError("Mock not configured")
67+
68+
class BoundMockFunction:
69+
"""Function handle returned by Modal.Function.from_name."""
70+
71+
def __init__(self, recorder: MockFunction, app_name: str, func_name: str):
72+
self.recorder = recorder
73+
self.app_name = app_name
74+
self.func_name = func_name
75+
76+
def spawn(self, payload: dict) -> MockFunctionCall:
77+
self.recorder.last_payload = payload
78+
is_batch = self.func_name == "run_budget_window_batch"
79+
object_id = "mock-batch-job-id-123" if is_batch else "mock-job-id-123"
80+
self.recorder.last_call = MockFunctionCall(object_id=object_id)
81+
if is_batch:
82+
self.recorder.last_call.running = True
83+
self.recorder.calls.append((self.app_name, self.func_name, payload, object_id))
84+
return self.recorder.last_call
7285

7386

7487
@pytest.fixture
@@ -94,7 +107,7 @@ class MockModalFunction:
94107
@staticmethod
95108
def from_name(app_name: str, func_name: str):
96109
mock_func.last_from_name_call = (app_name, func_name)
97-
return mock_func
110+
return mock_func.bind(app_name, func_name)
98111

99112
class MockModal:
100113
Dict = MockModalDict

projects/policyengine-api-simulation/src/modal/app.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,32 @@ def run_simulation(params: dict) -> dict:
104104
return result
105105
finally:
106106
logfire.force_flush()
107+
108+
109+
@app.function(
110+
image=simulation_image,
111+
cpu=1.0,
112+
memory=4096,
113+
timeout=3600,
114+
retries=0,
115+
max_containers=100,
116+
secrets=[gcp_secret, logfire_secret],
117+
)
118+
def run_budget_window_batch(params: dict) -> dict:
119+
"""Execute a multi-year budget-window batch orchestration."""
120+
import logfire
121+
122+
from src.modal.budget_window_batch import run_budget_window_batch_impl
123+
124+
configure_logfire()
125+
126+
try:
127+
with logfire.span(
128+
"run_budget_window_batch",
129+
input_params=params,
130+
) as span:
131+
result = run_budget_window_batch_impl(params)
132+
span.set_attribute("output_result", result)
133+
return result
134+
finally:
135+
logfire.force_flush()
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""Thin entrypoint for budget-window batch execution."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any
6+
7+
import modal
8+
9+
from src.modal.budget_window_context import build_batch_context
10+
from src.modal.budget_window_scheduler import BudgetWindowBatchRunner
11+
12+
13+
def run_budget_window_batch_impl(params: dict[str, Any]) -> dict[str, Any]:
14+
context = build_batch_context(
15+
params,
16+
batch_job_id=modal.current_function_call_id(),
17+
)
18+
runner = BudgetWindowBatchRunner(context)
19+
return runner.run()
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""Typed context and child request helpers for budget-window batches."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass
6+
from typing import Any
7+
8+
from src.modal.gateway.models import BudgetWindowBatchRequest, PolicyEngineBundle
9+
10+
BATCH_ONLY_FIELDS = {
11+
"version",
12+
"start_year",
13+
"window_size",
14+
"max_parallel",
15+
"target",
16+
"_metadata",
17+
"_telemetry",
18+
}
19+
20+
21+
@dataclass(frozen=True)
22+
class BudgetWindowBatchContext:
23+
"""Resolved parent batch execution context."""
24+
25+
batch_job_id: str
26+
request: BudgetWindowBatchRequest
27+
resolved_version: str
28+
resolved_app_name: str
29+
bundle: PolicyEngineBundle
30+
raw_params: dict[str, Any]
31+
32+
33+
@dataclass(frozen=True)
34+
class ChildSimulationRequest:
35+
"""Expanded single-year child simulation request."""
36+
37+
simulation_year: str
38+
payload: dict[str, Any]
39+
40+
41+
@dataclass
42+
class ChildSimulationHandle:
43+
"""Tracked child job handle for a single simulation year."""
44+
45+
simulation_year: str
46+
job_id: str
47+
call: Any | None = None
48+
49+
50+
def build_batch_context(
51+
params: dict[str, Any],
52+
*,
53+
batch_job_id: str,
54+
) -> BudgetWindowBatchContext:
55+
request = BudgetWindowBatchRequest.model_validate(params)
56+
metadata = params.get("_metadata")
57+
if not isinstance(metadata, dict):
58+
raise ValueError("Missing internal batch metadata")
59+
60+
resolved_app_name = metadata.get("resolved_app_name")
61+
resolved_version = metadata.get("resolved_version")
62+
bundle_payload = metadata.get("policyengine_bundle")
63+
64+
if not isinstance(resolved_app_name, str) or not resolved_app_name:
65+
raise ValueError("Missing resolved_app_name in batch metadata")
66+
if not isinstance(resolved_version, str) or not resolved_version:
67+
raise ValueError("Missing resolved_version in batch metadata")
68+
if not isinstance(bundle_payload, dict):
69+
raise ValueError("Missing policyengine_bundle in batch metadata")
70+
71+
return BudgetWindowBatchContext(
72+
batch_job_id=batch_job_id,
73+
request=request,
74+
resolved_version=resolved_version,
75+
resolved_app_name=resolved_app_name,
76+
bundle=PolicyEngineBundle.model_validate(bundle_payload),
77+
raw_params=params,
78+
)
79+
80+
81+
def build_child_simulation_request(
82+
context: BudgetWindowBatchContext,
83+
*,
84+
simulation_year: str,
85+
) -> ChildSimulationRequest:
86+
payload = {
87+
key: value
88+
for key, value in context.raw_params.items()
89+
if key not in BATCH_ONLY_FIELDS
90+
}
91+
payload["time_period"] = simulation_year
92+
93+
telemetry = context.raw_params.get("_telemetry")
94+
if isinstance(telemetry, dict):
95+
payload["_telemetry"] = telemetry
96+
97+
return ChildSimulationRequest(
98+
simulation_year=simulation_year,
99+
payload=payload,
100+
)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""Budget-window annual result extraction and aggregation helpers."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any
6+
7+
from src.modal.gateway.models import (
8+
BudgetWindowAnnualImpact,
9+
BudgetWindowResult,
10+
BudgetWindowTotals,
11+
)
12+
13+
REQUIRED_BUDGET_KEYS = (
14+
"tax_revenue_impact",
15+
"state_tax_revenue_impact",
16+
"benefit_spending_impact",
17+
"budgetary_impact",
18+
)
19+
20+
21+
def extract_annual_impact(
22+
*,
23+
simulation_year: str,
24+
child_result: dict[str, Any],
25+
) -> BudgetWindowAnnualImpact:
26+
budget = child_result.get("budget", {})
27+
if not isinstance(budget, dict):
28+
raise ValueError("Malformed budget-window child result: missing budget object")
29+
30+
missing_keys = [
31+
key
32+
for key in REQUIRED_BUDGET_KEYS
33+
if not isinstance(budget.get(key), int | float)
34+
]
35+
if missing_keys:
36+
missing = ", ".join(f"budget.{key}" for key in missing_keys)
37+
raise ValueError(
38+
f"Malformed budget-window child result: missing numeric {missing}"
39+
)
40+
41+
state_tax_revenue_impact = budget["state_tax_revenue_impact"]
42+
tax_revenue_impact = budget["tax_revenue_impact"]
43+
44+
return BudgetWindowAnnualImpact(
45+
year=simulation_year,
46+
taxRevenueImpact=tax_revenue_impact,
47+
federalTaxRevenueImpact=tax_revenue_impact - state_tax_revenue_impact,
48+
stateTaxRevenueImpact=state_tax_revenue_impact,
49+
benefitSpendingImpact=budget["benefit_spending_impact"],
50+
budgetaryImpact=budget["budgetary_impact"],
51+
)
52+
53+
54+
def sum_annual_impacts(
55+
annual_impacts: list[BudgetWindowAnnualImpact],
56+
) -> BudgetWindowTotals:
57+
totals = {
58+
"taxRevenueImpact": 0,
59+
"federalTaxRevenueImpact": 0,
60+
"stateTaxRevenueImpact": 0,
61+
"benefitSpendingImpact": 0,
62+
"budgetaryImpact": 0,
63+
}
64+
65+
for annual_impact in annual_impacts:
66+
totals["taxRevenueImpact"] += annual_impact.taxRevenueImpact
67+
totals["federalTaxRevenueImpact"] += annual_impact.federalTaxRevenueImpact
68+
totals["stateTaxRevenueImpact"] += annual_impact.stateTaxRevenueImpact
69+
totals["benefitSpendingImpact"] += annual_impact.benefitSpendingImpact
70+
totals["budgetaryImpact"] += annual_impact.budgetaryImpact
71+
72+
return BudgetWindowTotals(**totals)
73+
74+
75+
def build_budget_window_result(
76+
*,
77+
start_year: str,
78+
window_size: int,
79+
annual_impacts: list[BudgetWindowAnnualImpact],
80+
) -> BudgetWindowResult:
81+
return BudgetWindowResult(
82+
startYear=start_year,
83+
endYear=str(int(start_year) + window_size - 1),
84+
windowSize=window_size,
85+
annualImpacts=annual_impacts,
86+
totals=sum_annual_impacts(annual_impacts),
87+
)

0 commit comments

Comments
 (0)