Skip to content

Commit 53dd1da

Browse files
authored
Merge pull request #3643 from PolicyEngine/migration-pr2-fastapi-shell
Stage 2: Add FastAPI shell with Flask fallback
2 parents 890ae22 + d36050d commit 53dd1da

13 files changed

Lines changed: 586 additions & 10 deletions

File tree

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ setup-env:
77
debug:
88
FLASK_APP=policyengine_api.api FLASK_DEBUG=1 flask run --without-threads
99

10+
debug-asgi:
11+
FLASK_DEBUG=1 uvicorn policyengine_api.asgi:app --reload --port 8000
12+
1013
test-env-vars:
1114
pytest tests/env_variables
1215

changelog.d/fastapi-shell.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added a FastAPI ASGI compatibility shell that can serve the existing Flask API through WSGI fallback.

docs/engineering/skills/testing.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,24 @@ python scripts/export_migration_contracts.py
2222
python -m pytest tests/contract tests/unit/test_migration_flags.py tests/unit/test_migration_contract_artifacts.py tests/unit/test_capture_migration_baseline.py tests/unit/routes/test_migration_context_logging.py -q
2323
```
2424

25+
For PR 2 FastAPI shell or Flask fallback changes, verify the ASGI entrypoint and
26+
the v1 route contracts together:
27+
28+
```bash
29+
FLASK_DEBUG=1 python -m pytest tests/unit/test_asgi_factory.py tests/contract/test_v1_route_contracts.py tests/unit/routes/test_migration_context_logging.py -q
30+
```
31+
32+
If the change touches service compatibility behavior used by migrated or
33+
candidate endpoints, add the relevant focused service tests. For budget-window
34+
simulation compatibility, run:
35+
36+
```bash
37+
FLASK_DEBUG=1 python -m pytest tests/unit/services/test_economy_service.py::TestEconomyService::TestGetBudgetWindowEconomicImpact -q
38+
```
39+
40+
Regenerate and review `docs/engineering/generated/migration_contracts.md` when
41+
route inventory, migration registry flags, or v1 contract expectations change.
42+
FastAPI shell-only fallback changes should not change the route catalog.
43+
2544
Run `ruff format --check` and `ruff check` on changed Python files before
2645
handoff.
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# PR 2 FastAPI Shell Runbook
2+
3+
PR 2 adds an ASGI FastAPI shell around the existing Flask API. It is a
4+
compatibility step only.
5+
6+
## Included
7+
8+
- Native FastAPI `GET /health`.
9+
- Flask fallback for all existing API v1 routes through WSGI middleware.
10+
- ASGI parity tests for current app-v2 contract routes.
11+
- Local Uvicorn run command.
12+
13+
## Not Included
14+
15+
- No production traffic shift.
16+
- No Cloud Run deployment.
17+
- No native FastAPI route migration beyond `GET /health`.
18+
- No Supabase, Alembic, SQLAlchemy, or Modal compute changes.
19+
20+
## Local Smoke
21+
22+
Run:
23+
24+
```bash
25+
FLASK_DEBUG=1 uvicorn policyengine_api.asgi:app --port 8000
26+
```
27+
28+
Smoke-check:
29+
30+
```bash
31+
curl -i http://localhost:8000/health
32+
curl -i http://localhost:8000/readiness-check
33+
curl -i http://localhost:8000/liveness-check
34+
curl -i http://localhost:8000/zz/metadata
35+
```
36+
37+
Expected behavior:
38+
39+
- `/health` returns FastAPI JSON: `{"status":"healthy"}`.
40+
- `/readiness-check` and `/liveness-check` return existing Flask text `OK`.
41+
- Existing v1 routes continue to use Flask fallback behavior.

policyengine_api/asgi.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""ASGI entrypoint for the Stage 2 FastAPI compatibility shell."""
2+
3+
from __future__ import annotations
4+
5+
from policyengine_api.api import app as flask_app
6+
from policyengine_api.asgi_factory import create_asgi_app
7+
8+
9+
app = application = create_asgi_app(flask_app)

policyengine_api/asgi_factory.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""FastAPI shell for serving the existing Flask API through ASGI."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Literal
6+
7+
from a2wsgi import WSGIMiddleware
8+
from fastapi import FastAPI
9+
from pydantic import BaseModel
10+
11+
from policyengine_api.constants import VERSION
12+
13+
14+
class HealthResponse(BaseModel):
15+
status: Literal["healthy"]
16+
17+
18+
def _add_vary_origin(response) -> None:
19+
vary = response.headers.get("Vary")
20+
if vary is None:
21+
response.headers["Vary"] = "Origin"
22+
return
23+
if "origin" not in {value.strip().lower() for value in vary.split(",")}:
24+
response.headers["Vary"] = f"{vary}, Origin"
25+
26+
27+
def create_asgi_app(wsgi_app) -> FastAPI:
28+
"""Create the Stage 2 FastAPI shell around the existing Flask app."""
29+
30+
app = FastAPI(
31+
title="PolicyEngine API",
32+
version=VERSION,
33+
docs_url=None,
34+
redoc_url=None,
35+
openapi_url=None,
36+
)
37+
38+
@app.middleware("http")
39+
async def add_cors_for_native_routes(request, call_next):
40+
response = await call_next(request)
41+
origin = request.headers.get("origin")
42+
if origin and "access-control-allow-origin" not in response.headers:
43+
response.headers["Access-Control-Allow-Origin"] = origin
44+
_add_vary_origin(response)
45+
return response
46+
47+
@app.get("/health", response_model=HealthResponse)
48+
def health() -> HealthResponse:
49+
return HealthResponse(status="healthy")
50+
51+
app.mount("/", WSGIMiddleware(wsgi_app))
52+
return app

policyengine_api/services/economy_service.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from policyengine_api.utils import budget_window as budget_window_utils
2525
from policyengine.simulation import SimulationOptions
2626
from policyengine.utils.data.datasets import get_default_dataset
27+
import httpx
2728
import json
2829
import datetime
2930
import hashlib
@@ -77,6 +78,7 @@ class ImpactStatus(Enum):
7778
BUDGET_WINDOW_MAX_ACTIVE_YEARS = budget_window_utils.BUDGET_WINDOW_MAX_ACTIVE_YEARS
7879
BUDGET_WINDOW_MAX_YEARS = budget_window_utils.BUDGET_WINDOW_MAX_YEARS
7980
BUDGET_WINDOW_MAX_END_YEAR = budget_window_utils.BUDGET_WINDOW_MAX_END_YEAR
81+
BUDGET_WINDOW_SUBMISSION_VALIDATION_ERROR_STATUS_CODES = {400, 422}
8082

8183

8284
class EconomicImpactSetupOptions(BaseModel):
@@ -348,6 +350,18 @@ def get_budget_window_economic_impact(
348350
budget_window_cache.store_batch_job_id(
349351
cache_key, batch_execution.batch_job_id
350352
)
353+
except httpx.HTTPStatusError as error:
354+
budget_window_cache.clear_starting_claim(cache_key, claim_token)
355+
if (
356+
error.response.status_code
357+
in BUDGET_WINDOW_SUBMISSION_VALIDATION_ERROR_STATUS_CODES
358+
):
359+
return BudgetWindowEconomicImpactResult.failed(
360+
self._build_budget_window_submission_error_message(error),
361+
queued_years=years,
362+
cache_status=cache_status,
363+
)
364+
raise
351365
except Exception:
352366
budget_window_cache.clear_starting_claim(cache_key, claim_token)
353367
raise
@@ -443,6 +457,26 @@ def _start_budget_window_batch(
443457

444458
return simulation_api.run_budget_window_batch(sim_params)
445459

460+
def _build_budget_window_submission_error_message(
461+
self, error: httpx.HTTPStatusError
462+
) -> str:
463+
try:
464+
response_json = error.response.json()
465+
except ValueError:
466+
response_json = None
467+
468+
if isinstance(response_json, dict):
469+
for key in ("detail", "message", "error"):
470+
value = response_json.get(key)
471+
if value:
472+
return str(value)
473+
474+
response_text = error.response.text.strip()
475+
if response_text:
476+
return response_text
477+
478+
return str(error)
479+
446480
def _get_budget_window_result_from_batch_job_id(
447481
self,
448482
*,

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@ classifiers = [
2121
"License :: OSI Approved :: GNU Affero General Public License v3",
2222
]
2323
dependencies = [
24+
"a2wsgi>=1.10,<2",
2425
"anthropic",
2526
"assertpy",
2627
"click>=8,<9",
2728
"cloud-sql-python-connector",
2829
"faiss-cpu",
30+
"fastapi>=0.115,<1",
2931
"flask>=3,<4",
3032
"flask-cors>=5,<6",
3133
"Flask-Caching>=2,<3",
@@ -50,6 +52,7 @@ dependencies = [
5052
"rq",
5153
"sqlalchemy>=2,<3",
5254
"streamlit",
55+
"uvicorn[standard]>=0.32,<1",
5356
"werkzeug",
5457
]
5558

tests/contract/clients.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Mapping, Protocol
5+
6+
from fastapi.testclient import TestClient
7+
from flask import Flask
8+
9+
from policyengine_api.asgi_factory import create_asgi_app
10+
11+
12+
@dataclass(frozen=True)
13+
class ContractResponse:
14+
status_code: int
15+
body: bytes
16+
headers: Mapping[str, str]
17+
content_type: str | None
18+
19+
@property
20+
def data(self) -> bytes:
21+
return self.body
22+
23+
24+
class ContractClient(Protocol):
25+
def open(
26+
self,
27+
path: str,
28+
*,
29+
method: str,
30+
json: dict | None = None,
31+
headers: dict | None = None,
32+
) -> ContractResponse: ...
33+
34+
35+
class FlaskContractClient:
36+
def __init__(self, app: Flask):
37+
self._client = app.test_client()
38+
39+
def open(
40+
self,
41+
path: str,
42+
*,
43+
method: str,
44+
json: dict | None = None,
45+
headers: dict | None = None,
46+
) -> ContractResponse:
47+
response = self._client.open(
48+
path,
49+
method=method,
50+
json=json,
51+
headers=headers,
52+
)
53+
return ContractResponse(
54+
status_code=response.status_code,
55+
body=response.data,
56+
headers=dict(response.headers),
57+
content_type=response.content_type,
58+
)
59+
60+
61+
class ASGIContractClient:
62+
def __init__(self, app: Flask):
63+
self._client = TestClient(create_asgi_app(app))
64+
65+
def open(
66+
self,
67+
path: str,
68+
*,
69+
method: str,
70+
json: dict | None = None,
71+
headers: dict | None = None,
72+
) -> ContractResponse:
73+
response = self._client.request(
74+
method,
75+
path,
76+
json=json,
77+
headers=headers,
78+
)
79+
return ContractResponse(
80+
status_code=response.status_code,
81+
body=response.content,
82+
headers=dict(response.headers),
83+
content_type=response.headers.get("content-type"),
84+
)

tests/contract/test_v1_route_contracts.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
from policyengine_api.routes.policy_routes import policy_bp
1414
from policyengine_api.routes.report_output_routes import report_output_bp
1515
from policyengine_api.routes.simulation_routes import simulation_bp
16+
from tests.contract.clients import (
17+
ASGIContractClient,
18+
ContractClient,
19+
FlaskContractClient,
20+
)
1621
from tests.contract.helpers import (
1722
assert_field_path_exists,
1823
assert_subset,
@@ -121,7 +126,7 @@ def _load_contract_economy_blueprint():
121126
)
122127

123128

124-
def _client():
129+
def create_contract_flask_app() -> Flask:
125130
app = Flask(__name__)
126131
app.config["TESTING"] = True
127132
app.register_blueprint(_load_contract_metadata_blueprint())
@@ -141,7 +146,17 @@ def liveness_check():
141146
def readiness_check():
142147
return Response("OK", status=200, mimetype="text/plain")
143148

144-
return app.test_client()
149+
return app
150+
151+
152+
@pytest.fixture(params=("flask_direct", "fastapi_fallback"))
153+
def contract_client(request) -> ContractClient:
154+
app = create_contract_flask_app()
155+
if request.param == "flask_direct":
156+
return FlaskContractClient(app)
157+
if request.param == "fastapi_fallback":
158+
return ASGIContractClient(app)
159+
raise AssertionError(f"Unknown contract client: {request.param}")
145160

146161

147162
def _resolved_path(path: str) -> str:
@@ -375,9 +390,12 @@ def _expected_subset(contract: ContractRequest) -> dict:
375390
APP_V2_ROUTE_CONTRACTS,
376391
ids=lambda contract: f"{contract.method} {contract.path}",
377392
)
378-
def test_app_v2_api_v1_route_contract(contract):
393+
def test_app_v2_api_v1_route_contract(
394+
contract: ContractRequest,
395+
contract_client: ContractClient,
396+
):
379397
with _patched_route_dependencies():
380-
response = _client().open(
398+
response = contract_client.open(
381399
_resolved_path(contract.path),
382400
method=contract.method,
383401
json=_json_payload(contract),
@@ -390,10 +408,9 @@ def test_app_v2_api_v1_route_contract(contract):
390408
assert_field_path_exists(payload, field_path)
391409

392410

393-
def test_health_routes_contract():
394-
client = _client()
395-
liveness = client.get("/liveness-check")
396-
readiness = client.get("/readiness-check")
411+
def test_health_routes_contract(contract_client: ContractClient):
412+
liveness = contract_client.open("/liveness-check", method="GET")
413+
readiness = contract_client.open("/readiness-check", method="GET")
397414

398415
assert liveness.status_code == 200
399416
assert liveness.data == b"OK"
@@ -403,8 +420,8 @@ def test_health_routes_contract():
403420
assert "text/plain" in readiness.content_type
404421

405422

406-
def test_invalid_country_contract():
407-
response = _client().get("/zz/metadata")
423+
def test_invalid_country_contract(contract_client: ContractClient):
424+
response = contract_client.open("/zz/metadata", method="GET")
408425

409426
assert response.status_code == 400
410427
assert_subset(

0 commit comments

Comments
 (0)