Skip to content

Commit 1313a58

Browse files
authored
Merge pull request #477 from PolicyEngine/fix/budget-window-batch-model-import
Fix budget-window batch worker gateway imports
2 parents 614f233 + 8ae055d commit 1313a58

6 files changed

Lines changed: 290 additions & 10 deletions

File tree

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""Fixtures for gateway package import regression tests."""
2+
3+
from __future__ import annotations
4+
5+
import importlib
6+
import sys
7+
from collections.abc import Iterator
8+
from dataclasses import dataclass
9+
10+
import pytest
11+
12+
13+
GATEWAY_MODEL_MODULE = "src.modal.gateway.models"
14+
GATEWAY_ENDPOINTS_MODULE = "src.modal.gateway.endpoints"
15+
GATEWAY_PACKAGE_MODULE = "src.modal.gateway"
16+
FASTAPI_MODULE = "fastapi"
17+
18+
GATEWAY_MODEL_IMPORT_MODULES = (
19+
FASTAPI_MODULE,
20+
GATEWAY_PACKAGE_MODULE,
21+
GATEWAY_ENDPOINTS_MODULE,
22+
GATEWAY_MODEL_MODULE,
23+
)
24+
25+
26+
@dataclass(frozen=True)
27+
class GatewayImportModuleNames:
28+
"""Module names involved in the gateway model import boundary."""
29+
30+
endpoints: str = GATEWAY_ENDPOINTS_MODULE
31+
fastapi: str = FASTAPI_MODULE
32+
33+
34+
@pytest.fixture()
35+
def gateway_import_module_names() -> GatewayImportModuleNames:
36+
return GatewayImportModuleNames()
37+
38+
39+
@pytest.fixture()
40+
def isolated_gateway_model_import_modules() -> Iterator[None]:
41+
"""Temporarily clear modules that would mask import side effects."""
42+
previous_modules = {
43+
module_name: sys.modules.pop(module_name, None)
44+
for module_name in GATEWAY_MODEL_IMPORT_MODULES
45+
}
46+
47+
try:
48+
yield
49+
finally:
50+
for module_name in GATEWAY_MODEL_IMPORT_MODULES:
51+
sys.modules.pop(module_name, None)
52+
sys.modules.update(
53+
{
54+
module_name: module
55+
for module_name, module in previous_modules.items()
56+
if module is not None
57+
}
58+
)
59+
60+
61+
@pytest.fixture()
62+
def import_gateway_models(isolated_gateway_model_import_modules):
63+
"""Import gateway models from a clean module state."""
64+
65+
def import_models():
66+
return importlib.import_module(GATEWAY_MODEL_MODULE)
67+
68+
return import_models
Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,3 @@
11
"""
22
Gateway package for PolicyEngine Simulation API.
33
"""
4-
5-
from .endpoints import router
6-
from .models import JobStatusResponse, JobSubmitResponse, SimulationRequest
7-
8-
__all__ = [
9-
"router",
10-
"SimulationRequest",
11-
"JobSubmitResponse",
12-
"JobStatusResponse",
13-
]

projects/policyengine-api-simulation/tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
pytest_plugins = (
88
"fixtures.gateway.shared",
99
"fixtures.gateway.test_endpoints",
10+
"fixtures.gateway.package_imports",
1011
)
1112

1213
project_root = Path(__file__).parent.parent
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import sys
2+
3+
4+
def test_gateway_models_import_does_not_import_fastapi_endpoints(
5+
import_gateway_models,
6+
gateway_import_module_names,
7+
):
8+
import_gateway_models()
9+
10+
assert gateway_import_module_names.endpoints not in sys.modules
11+
assert gateway_import_module_names.fastapi not in sys.modules

projects/policyengine-apis-integ/tests/simulation/conftest.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,30 @@
1+
import json
2+
import time
3+
from http import HTTPStatus
4+
15
import httpx
26
import pytest
37
from pydantic_settings import BaseSettings, SettingsConfigDict
48

59
from policyengine_api_simulation_client import AuthenticatedClient, Client
10+
from policyengine_api_simulation_client.api.default import (
11+
get_budget_window_job_status_budget_window_jobs_batch_job_id_get,
12+
submit_budget_window_batch_simulate_economy_budget_window_post,
13+
)
14+
from policyengine_api_simulation_client.models import (
15+
BudgetWindowBatchRequest,
16+
BudgetWindowBatchStatusResponse,
17+
)
18+
19+
20+
BUDGET_WINDOW_YEARS = ["2026", "2027"]
21+
BUDGET_WINDOW_REFORM = {
22+
"gov.irs.credits.ctc.refundable.fully_refundable": {"2023-01-01.2100-12-31": True}
23+
}
24+
BUDGET_WINDOW_DATASET = "gs://policyengine-us-data/enhanced_cps_2024.h5"
25+
BUDGET_WINDOW_REGION = "us"
26+
BUDGET_WINDOW_SUBSAMPLE = 200
27+
BUDGET_WINDOW_MAX_PARALLEL = 2
628

729

830
class Settings(BaseSettings):
@@ -49,3 +71,122 @@ def poll_interval() -> float:
4971
def max_wait_seconds() -> float:
5072
"""Return max wait time in seconds."""
5173
return settings.timeout_in_millis / 1000
74+
75+
76+
def _decode_response_content(content: bytes) -> str:
77+
try:
78+
return json.dumps(json.loads(content), sort_keys=True)
79+
except (json.JSONDecodeError, UnicodeDecodeError):
80+
return content.decode("utf-8", errors="replace")
81+
82+
83+
def _poll_budget_window_batch(
84+
*,
85+
client: Client | AuthenticatedClient,
86+
batch_job_id: str,
87+
max_wait_seconds: float,
88+
poll_interval: float,
89+
) -> BudgetWindowBatchStatusResponse:
90+
deadline = time.monotonic() + max_wait_seconds
91+
last_status_code: HTTPStatus | None = None
92+
last_content = b""
93+
94+
while time.monotonic() < deadline:
95+
response = get_budget_window_job_status_budget_window_jobs_batch_job_id_get.sync_detailed(
96+
batch_job_id=batch_job_id, client=client
97+
)
98+
last_status_code = response.status_code
99+
last_content = response.content
100+
101+
if response.status_code == HTTPStatus.ACCEPTED:
102+
time.sleep(poll_interval)
103+
continue
104+
105+
if response.status_code == HTTPStatus.OK:
106+
assert isinstance(response.parsed, BudgetWindowBatchStatusResponse), (
107+
f"Unexpected response type: {type(response.parsed)}"
108+
)
109+
assert response.parsed.status == "complete", (
110+
f"Unexpected budget-window status: {response.parsed}"
111+
)
112+
return response.parsed
113+
114+
if response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR:
115+
raise AssertionError(
116+
"Budget-window batch failed: "
117+
f"{_decode_response_content(response.content)}"
118+
)
119+
120+
raise AssertionError(
121+
"Unexpected budget-window poll status "
122+
f"{response.status_code}: {_decode_response_content(response.content)}"
123+
)
124+
125+
raise TimeoutError(
126+
f"Budget-window batch {batch_job_id} did not complete within "
127+
f"{max_wait_seconds}s; last response was "
128+
f"{last_status_code}: {_decode_response_content(last_content)}"
129+
)
130+
131+
132+
@pytest.fixture()
133+
def budget_window_years() -> list[str]:
134+
"""Return the annual rows expected from the staging budget-window smoke run."""
135+
return list(BUDGET_WINDOW_YEARS)
136+
137+
138+
@pytest.fixture()
139+
def budget_window_request(us_model_version: str) -> BudgetWindowBatchRequest:
140+
"""Build the standard staging budget-window smoke request."""
141+
return BudgetWindowBatchRequest.from_dict(
142+
{
143+
"country": "us",
144+
"version": us_model_version,
145+
"region": BUDGET_WINDOW_REGION,
146+
"scope": "macro",
147+
"reform": BUDGET_WINDOW_REFORM,
148+
"subsample": BUDGET_WINDOW_SUBSAMPLE,
149+
"data": BUDGET_WINDOW_DATASET,
150+
"start_year": BUDGET_WINDOW_YEARS[0],
151+
"window_size": len(BUDGET_WINDOW_YEARS),
152+
"max_parallel": BUDGET_WINDOW_MAX_PARALLEL,
153+
}
154+
)
155+
156+
157+
@pytest.fixture()
158+
def decode_response_content():
159+
"""Return a compact formatter for non-OK HTTP response payloads."""
160+
return _decode_response_content
161+
162+
163+
@pytest.fixture()
164+
def submit_budget_window_batch(client: Client | AuthenticatedClient):
165+
"""Submit a budget-window batch through the generated client."""
166+
167+
def submit(request: BudgetWindowBatchRequest):
168+
return submit_budget_window_batch_simulate_economy_budget_window_post.sync_detailed(
169+
client=client,
170+
body=request,
171+
)
172+
173+
return submit
174+
175+
176+
@pytest.fixture()
177+
def poll_budget_window_batch(
178+
client: Client | AuthenticatedClient,
179+
max_wait_seconds: float,
180+
poll_interval: float,
181+
):
182+
"""Poll a budget-window batch through the generated client."""
183+
184+
def poll(batch_job_id: str) -> BudgetWindowBatchStatusResponse:
185+
return _poll_budget_window_batch(
186+
client=client,
187+
batch_job_id=batch_job_id,
188+
max_wait_seconds=max_wait_seconds,
189+
poll_interval=poll_interval,
190+
)
191+
192+
return poll
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""
2+
Integration tests for Modal-based budget-window batches.
3+
4+
These tests run against the staging Modal deployment and verify that the
5+
gateway can spawn the parent budget-window worker, the parent can spawn child
6+
simulation workers, and the completed batch result has the public response
7+
shape expected by API consumers.
8+
"""
9+
10+
from http import HTTPStatus
11+
12+
import pytest
13+
14+
from policyengine_api_simulation_client.models import (
15+
BudgetWindowBatchSubmitResponse,
16+
BudgetWindowResult,
17+
)
18+
from policyengine_api_simulation_client.types import Unset
19+
20+
21+
@pytest.mark.beta_only
22+
def test_budget_window_multi_year_batch_completes(
23+
budget_window_request,
24+
budget_window_years,
25+
decode_response_content,
26+
submit_budget_window_batch,
27+
poll_budget_window_batch,
28+
us_model_version: str,
29+
):
30+
"""
31+
Given a two-year US budget-window request
32+
When the batch is submitted and polled to completion
33+
Then the response contains 2026 and 2027 annual impacts plus totals.
34+
"""
35+
submit_response = submit_budget_window_batch(budget_window_request)
36+
37+
assert submit_response.status_code == HTTPStatus.OK, (
38+
"Unexpected submit status "
39+
f"{submit_response.status_code}: "
40+
f"{decode_response_content(submit_response.content)}"
41+
)
42+
assert isinstance(submit_response.parsed, BudgetWindowBatchSubmitResponse), (
43+
f"Unexpected response type: {type(submit_response.parsed)}"
44+
)
45+
assert submit_response.parsed.status == "submitted"
46+
assert submit_response.parsed.version == us_model_version
47+
48+
batch_job_id = submit_response.parsed.batch_job_id
49+
assert submit_response.parsed.poll_url == f"/budget-window-jobs/{batch_job_id}"
50+
51+
completed = poll_budget_window_batch(batch_job_id)
52+
53+
assert completed.status == "complete"
54+
assert completed.progress == 100
55+
assert completed.error is None or isinstance(completed.error, Unset)
56+
assert isinstance(completed.result, BudgetWindowResult)
57+
58+
result = completed.result
59+
assert result.kind == "budgetWindow"
60+
assert result.start_year == budget_window_years[0]
61+
assert result.end_year == budget_window_years[-1]
62+
assert result.window_size == len(budget_window_years)
63+
annual_impacts = result.annual_impacts
64+
assert not isinstance(annual_impacts, Unset)
65+
assert [impact.year for impact in annual_impacts] == budget_window_years
66+
assert result.totals.year == "Total"
67+
assert all(
68+
isinstance(impact.budgetary_impact, int | float) for impact in annual_impacts
69+
)

0 commit comments

Comments
 (0)