Skip to content

Commit 95e3c9c

Browse files
committed
Tighten budget-window batch state contract
1 parent 71f7952 commit 95e3c9c

5 files changed

Lines changed: 77 additions & 2 deletions

File tree

projects/policyengine-api-simulation/src/modal/budget_window_state.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,15 @@ def create_initial_batch_state(
4444
batch_job_id=batch_job_id,
4545
status="submitted",
4646
country=request.country,
47+
region=request.region,
4748
version=resolved_version,
49+
target=request.target,
4850
resolved_app_name=resolved_app_name,
4951
policyengine_bundle=bundle,
5052
start_year=request.start_year,
5153
window_size=request.window_size,
5254
max_parallel=request.max_parallel,
55+
request_payload=request.model_dump(exclude={"telemetry"}, mode="json"),
5356
years=years,
5457
queued_years=list(years),
5558
running_years=[],

projects/policyengine-api-simulation/src/modal/gateway/models.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Pydantic models for the Gateway API.
33
"""
44

5-
from typing import ClassVar, Literal, Optional
5+
from typing import Any, ClassVar, Literal, Optional
66

77
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
88

@@ -80,11 +80,12 @@ class BudgetWindowBatchRequest(GatewayRequestBase):
8080

8181
MAX_YEARS: ClassVar[int] = 75
8282
MAX_END_YEAR: ClassVar[int] = 2099
83+
MAX_PARALLEL: ClassVar[int] = 3
8384

8485
region: str
8586
start_year: str
8687
window_size: int = Field(ge=1, le=MAX_YEARS)
87-
max_parallel: int = Field(default=3, ge=1)
88+
max_parallel: int = Field(default=MAX_PARALLEL, ge=1, le=MAX_PARALLEL)
8889
target: Literal["general"] = "general"
8990

9091
@field_validator("start_year")
@@ -182,12 +183,15 @@ class BudgetWindowBatchState(BaseModel):
182183
batch_job_id: str
183184
status: str
184185
country: str
186+
region: str
185187
version: str
188+
target: Literal["general"] = "general"
186189
resolved_app_name: str
187190
policyengine_bundle: PolicyEngineBundle
188191
start_year: str
189192
window_size: int
190193
max_parallel: int
194+
request_payload: dict[str, Any] = Field(default_factory=dict)
191195
years: list[str] = Field(default_factory=list)
192196
queued_years: list[str] = Field(default_factory=list)
193197
running_years: list[str] = Field(default_factory=list)

projects/policyengine-api-simulation/tests/gateway/test_budget_window_state.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from src.modal.budget_window_state import (
44
build_batch_status_response,
55
create_initial_batch_state,
6+
get_batch_job_state,
7+
put_batch_job_state,
68
)
79
from src.modal.gateway.models import BudgetWindowBatchRequest, PolicyEngineBundle
810

@@ -14,6 +16,9 @@ def test_create_initial_batch_state_builds_queued_years_and_run_id():
1416
start_year="2026",
1517
window_size=3,
1618
max_parallel=2,
19+
dataset="enhanced_cps_2024",
20+
scope="macro",
21+
reform={},
1722
_telemetry={
1823
"run_id": "batch-run-123",
1924
"process_id": "proc-123",
@@ -31,8 +36,13 @@ def test_create_initial_batch_state_builds_queued_years_and_run_id():
3136

3237
assert state.batch_job_id == "fc-parent-123"
3338
assert state.status == "submitted"
39+
assert state.region == "us"
40+
assert state.target == "general"
3441
assert state.years == ["2026", "2027", "2028"]
3542
assert state.queued_years == ["2026", "2027", "2028"]
43+
assert state.request_payload["dataset"] == "enhanced_cps_2024"
44+
assert state.request_payload["scope"] == "macro"
45+
assert state.request_payload["reform"] == {}
3646
assert state.run_id == "batch-run-123"
3747

3848

@@ -61,3 +71,32 @@ def test_build_batch_status_response_computes_progress_from_completed_years():
6171
assert response.completed_years == ["2026", "2027"]
6272
assert response.running_years == ["2028"]
6373
assert response.queued_years == ["2029"]
74+
75+
76+
def test_batch_state_round_trips_through_modal_dict(mock_modal):
77+
request = BudgetWindowBatchRequest(
78+
country="us",
79+
region="state/ca",
80+
start_year="2026",
81+
window_size=2,
82+
max_parallel=2,
83+
scope="macro",
84+
reform={"foo": True},
85+
)
86+
87+
state = create_initial_batch_state(
88+
batch_job_id="fc-parent-123",
89+
request=request,
90+
resolved_version="1.500.0",
91+
resolved_app_name="policyengine-simulation-us1-500-0-uk2-66-0",
92+
bundle=PolicyEngineBundle(model_version="1.500.0"),
93+
)
94+
put_batch_job_state(state)
95+
96+
restored = get_batch_job_state("fc-parent-123")
97+
98+
assert restored is not None
99+
assert restored.region == "state/ca"
100+
assert restored.target == "general"
101+
assert restored.request_payload["scope"] == "macro"
102+
assert restored.request_payload["reform"] == {"foo": True}

projects/policyengine-api-simulation/tests/gateway/test_endpoints.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,3 +527,22 @@ def test__given_non_general_target__then_budget_window_submit_returns_422(
527527

528528
assert response.status_code == 422
529529
assert response.json()["detail"][0]["loc"] == ["body", "target"]
530+
531+
def test__given_max_parallel_above_active_limit__then_budget_window_submit_returns_422(
532+
self, mock_modal, client: TestClient
533+
):
534+
response = client.post(
535+
"/simulate/economy/budget-window",
536+
json={
537+
"country": "us",
538+
"region": "us",
539+
"scope": "macro",
540+
"reform": {},
541+
"start_year": "2026",
542+
"window_size": 3,
543+
"max_parallel": 4,
544+
},
545+
)
546+
547+
assert response.status_code == 422
548+
assert response.json()["detail"][0]["loc"] == ["body", "max_parallel"]

projects/policyengine-api-simulation/tests/gateway/test_models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,16 @@ def test_budget_window_batch_request_rejects_non_general_target(self):
379379
target="cliff",
380380
)
381381

382+
def test_budget_window_batch_request_rejects_max_parallel_above_active_limit(self):
383+
with pytest.raises(ValidationError):
384+
BudgetWindowBatchRequest(
385+
country="us",
386+
region="us",
387+
start_year="2026",
388+
window_size=3,
389+
max_parallel=4,
390+
)
391+
382392
def test_budget_window_batch_request_accepts_extra_simulation_fields(self):
383393
request = BudgetWindowBatchRequest(
384394
country="us",

0 commit comments

Comments
 (0)