Skip to content

Commit 16b415d

Browse files
committed
Fix gateway run ID polling and pytest plugin loading
1 parent eef995c commit 16b415d

5 files changed

Lines changed: 43 additions & 6 deletions

File tree

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import sys
22
from pathlib import Path
33

4+
pytest_plugins = ("fixtures.ping.shared",)
5+
46
library_root = Path(__file__).parent.parent
57
if str(library_root) not in sys.path:
68
sys.path.insert(0, str(library_root))

libs/policyengine-fastapi/tests/ping/conftest.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

projects/policyengine-api-simulation/src/modal/gateway/endpoints.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,15 @@ def _build_policyengine_bundle(
6060
)
6161

6262

63-
def _serialize_job_metadata(resolved_app_name: str, bundle: PolicyEngineBundle) -> dict:
63+
def _serialize_job_metadata(
64+
resolved_app_name: str,
65+
bundle: PolicyEngineBundle,
66+
run_id: str | None = None,
67+
) -> dict:
6468
return {
6569
"resolved_app_name": resolved_app_name,
6670
"policyengine_bundle": bundle.model_dump(),
71+
"run_id": run_id,
6772
}
6873

6974

@@ -136,7 +141,7 @@ async def submit_simulation(request: SimulationRequest):
136141
call = sim_func.spawn(payload)
137142

138143
bundle = _build_policyengine_bundle(request.country, resolved_version, payload)
139-
job_metadata = _serialize_job_metadata(app_name, bundle)
144+
job_metadata = _serialize_job_metadata(app_name, bundle, run_id)
140145
_job_metadata_store()[call.object_id] = job_metadata
141146

142147
return JobSubmitResponse(
@@ -151,7 +156,11 @@ async def submit_simulation(request: SimulationRequest):
151156
)
152157

153158

154-
@router.get("/jobs/{job_id}")
159+
@router.get(
160+
"/jobs/{job_id}",
161+
response_model=JobStatusResponse,
162+
response_model_exclude_none=True,
163+
)
155164
async def get_job_status(job_id: str):
156165
"""
157166
Poll for job status.

projects/policyengine-api-simulation/src/modal/gateway/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class JobStatusResponse(BaseModel):
6363
error: Optional[str] = None
6464
resolved_app_name: Optional[str] = None
6565
policyengine_bundle: Optional[PolicyEngineBundle] = None
66+
run_id: Optional[str] = None
6667

6768

6869
class PingRequest(BaseModel):

projects/policyengine-api-simulation/tests/gateway/test_endpoints.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,10 +363,36 @@ def test__given_submitted_job__then_job_status_includes_bundle_metadata(
363363
assert response.status_code == 200
364364
data = response.json()
365365
assert data["status"] == "complete"
366+
assert "run_id" not in data
366367
assert data["resolved_app_name"] == "policyengine-simulation-us1-500-0-uk2-66-0"
367368
assert data["policyengine_bundle"] == {
368369
"model_version": "1.500.0",
369-
"policyengine_version": None,
370-
"data_version": None,
371370
"dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0",
372371
}
372+
373+
def test__given_submitted_job_with_telemetry__then_polling_echoes_run_id(
374+
self, mock_modal, client: TestClient
375+
):
376+
mock_modal["dicts"]["simulation-api-us-versions"] = {
377+
"latest": "1.500.0",
378+
"1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0",
379+
}
380+
381+
submit_response = client.post(
382+
"/simulate/economy/comparison",
383+
json={
384+
"country": "us",
385+
"scope": "macro",
386+
"reform": {},
387+
"_telemetry": {
388+
"run_id": "run-123",
389+
"process_id": "proc-123",
390+
"capture_mode": "disabled",
391+
},
392+
},
393+
)
394+
395+
response = client.get(f"/jobs/{submit_response.json()['job_id']}")
396+
397+
assert response.status_code == 200
398+
assert response.json()["run_id"] == "run-123"

0 commit comments

Comments
 (0)