Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 6 additions & 110 deletions aws_lambda_powertools/event_handler/http_resolver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import base64
import inspect
import warnings
from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import parse_qs
Expand All @@ -10,10 +9,7 @@
ApiGatewayResolver,
BaseRouter,
ProxyEventType,
Response,
Route,
)
from aws_lambda_powertools.event_handler.middlewares.async_utils import wrap_middleware_async
from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent

Expand Down Expand Up @@ -240,113 +236,13 @@ def _get_base_path(self) -> str:
return ""

async def _resolve_async(self) -> dict: # type: ignore[override]
"""Async version of resolve that supports async handlers."""
method = self.current_event.http_method.upper()
path = self._remove_prefix(self.current_event.path)

registered_routes = self._static_routes + self._dynamic_routes

for route in registered_routes:
if method != route.method:
continue
match_results = route.rule.match(path)
if match_results:
self.append_context(_route=route, _path=path)
route_keys = self._convert_matches_into_route_keys(match_results)
return await self._call_route_async(route, route_keys)

# Handle not found
return await self._handle_not_found_async()

async def _call_route_async(self, route: Route, route_arguments: dict[str, str]) -> dict: # type: ignore[override]
"""Call route handler, supporting both sync and async handlers."""
from aws_lambda_powertools.event_handler.api_gateway import ResponseBuilder

try:
self._reset_processed_stack()

# Get the route args (may be modified by validation middleware)
self.append_context(_route_args=route_arguments)

# Run middleware chain (sync for now, handlers can be async)
response = await self._run_middleware_chain_async(route)

response_builder: ResponseBuilder = ResponseBuilder(
response=response,
serializer=self._serializer,
route=route,
)

return response_builder.build(self.current_event, self._cors)

except Exception as exc:
exc_response_builder = self._call_exception_handler(exc, route)
if exc_response_builder:
return exc_response_builder.build(self.current_event, self._cors)
raise

async def _run_middleware_chain_async(self, route: Route) -> Response:
"""Run the middleware chain, awaiting async handlers."""
# Build middleware list
all_middlewares: list[Callable[..., Any]] = []

# Determine if validation should be enabled for this route
# If route has explicit enable_validation setting, use it; otherwise, use resolver's global setting
route_validation_enabled = (
route.enable_validation if route.enable_validation is not None else self._enable_validation
)

if route_validation_enabled and hasattr(self, "_request_validation_middleware"):
all_middlewares.append(self._request_validation_middleware)

all_middlewares.extend(self._router_middlewares + route.middlewares)

if route_validation_enabled and hasattr(self, "_response_validation_middleware"):
all_middlewares.append(self._response_validation_middleware)

# Create the final handler that calls the route function
async def final_handler(app):
route_args = app.context.get("_route_args", {})
result = route.func(**route_args)

# Await if coroutine
if inspect.iscoroutine(result):
result = await result

return self._to_response(result)

# Build middleware chain from end to start
next_handler = final_handler

for middleware in reversed(all_middlewares):
next_handler = wrap_middleware_async(middleware, next_handler)

return await next_handler(self)

async def _handle_not_found_async(self, method: str = "", path: str = "") -> dict: # type: ignore[override]
"""Handle 404 responses, using custom not_found handler if registered."""
from http import HTTPStatus

from aws_lambda_powertools.event_handler.api_gateway import ResponseBuilder
from aws_lambda_powertools.event_handler.exceptions import NotFoundError

# Check for custom not_found handler
custom_not_found_handler = self.exception_handler_manager.lookup_exception_handler(NotFoundError)
if custom_not_found_handler:
response = custom_not_found_handler(NotFoundError())
else:
response = Response(
status_code=HTTPStatus.NOT_FOUND.value,
content_type="application/json",
body={"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"},
)

response_builder: ResponseBuilder = ResponseBuilder(
response=response,
serializer=self._serializer,
route=None,
)
"""Thin async resolver: delegates entirely to the parent and serializes to dict.

The parent's _resolve_async handles route matching, CORS preflight, not-found
logic, and exception handling. The only adaptation needed here is converting
the returned ResponseBuilder into the dict format that asgi_handler expects.
"""
response_builder = await super()._resolve_async()
return response_builder.build(self.current_event, self._cors)

async def asgi_handler(self, scope: dict, receive: Callable, send: Callable) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1242,3 +1242,219 @@ def hello():

# THEN it returns 404 (method mismatch is treated as not found)
assert captured["status_code"] == 404


# =============================================================================
# CORS Tests (issue #8267)
# =============================================================================


@pytest.mark.asyncio
async def test_cors_options_preflight_returns_204():
# GIVEN an app with CORSConfig and a POST route
from aws_lambda_powertools.event_handler.api_gateway import CORSConfig

app = HttpResolverLocal(cors=CORSConfig(allow_origin="*"))

@app.post("/items")
def create_item():
return {"ok": True}

# WHEN a browser sends a CORS preflight OPTIONS request
scope = {
"type": "http",
"method": "OPTIONS",
"path": "/items",
"query_string": b"",
"headers": [
(b"origin", b"http://localhost:3000"),
(b"access-control-request-method", b"POST"),
],
}

receive = make_asgi_receive()
captured: dict[str, Any] = {"status_code": None, "headers": []}

async def send(message: dict[str, Any]) -> None:
await asyncio.sleep(0)
if message["type"] == "http.response.start":
captured["status_code"] = message["status"]
captured["headers"].extend(message.get("headers", []))

await app(scope, receive, send)

# THEN it returns 204 with CORS headers (not 500 or 404)
assert captured["status_code"] == 204

header_names = [name.lower() for name, _ in captured["headers"]]
assert b"access-control-allow-origin" in header_names
assert b"access-control-allow-methods" in header_names


@pytest.mark.asyncio
async def test_cors_options_preflight_with_exception_handler_does_not_return_500():
# GIVEN an app with CORSConfig and a generic exception handler that returns 500
import json

from aws_lambda_powertools.event_handler.api_gateway import CORSConfig

app = HttpResolverLocal(cors=CORSConfig(allow_origin="*"))

@app.post("/items")
def create_item():
return {"ok": True}

@app.exception_handler(Exception)
def handle_server_error(ex: Exception):
return Response(
status_code=500,
content_type="application/json",
body=json.dumps({"error": "internal"}),
)

# WHEN a browser sends a CORS preflight OPTIONS request
scope = {
"type": "http",
"method": "OPTIONS",
"path": "/items",
"query_string": b"",
"headers": [
(b"origin", b"http://localhost:3000"),
(b"access-control-request-method", b"POST"),
],
}

receive = make_asgi_receive()
captured: dict[str, Any] = {"status_code": None, "headers": []}

async def send(message: dict[str, Any]) -> None:
await asyncio.sleep(0)
if message["type"] == "http.response.start":
captured["status_code"] = message["status"]
captured["headers"].extend(message.get("headers", []))

await app(scope, receive, send)

# THEN the OPTIONS request returns 204, not 500
assert captured["status_code"] == 204
header_names = [name.lower() for name, _ in captured["headers"]]
assert b"access-control-allow-origin" in header_names


@pytest.mark.asyncio
async def test_no_cors_options_returns_404():
# GIVEN an app WITHOUT CORSConfig
app = HttpResolverLocal()

@app.post("/items")
def create_item():
return {"ok": True}

# WHEN a browser sends an OPTIONS request (no CORS configured)
scope = {
"type": "http",
"method": "OPTIONS",
"path": "/items",
"query_string": b"",
"headers": [],
}

receive = make_asgi_receive()
send, captured = make_asgi_send()

await app(scope, receive, send)

# THEN it returns 404 (no CORS config, no special handling)
assert captured["status_code"] == 404


@pytest.mark.asyncio
async def test_cors_options_includes_allowed_methods_header():
# GIVEN an app with CORSConfig and multiple routes
from aws_lambda_powertools.event_handler.api_gateway import CORSConfig

app = HttpResolverLocal(cors=CORSConfig(allow_origin="https://example.com"))

@app.get("/resource")
def get_resource():
return {"method": "GET"}

@app.post("/resource")
def post_resource():
return {"method": "POST"}

# WHEN an OPTIONS preflight is sent
scope = {
"type": "http",
"method": "OPTIONS",
"path": "/resource",
"query_string": b"",
"headers": [
(b"origin", b"https://example.com"),
(b"access-control-request-method", b"GET"),
],
}

receive = make_asgi_receive()
captured: dict[str, Any] = {"status_code": None, "headers": []}

async def send(message: dict[str, Any]) -> None:
await asyncio.sleep(0)
if message["type"] == "http.response.start":
captured["status_code"] = message["status"]
captured["headers"].extend(message.get("headers", []))

await app(scope, receive, send)

# THEN 204 is returned with Access-Control-Allow-Methods header
assert captured["status_code"] == 204
allow_methods_headers = [v for name, v in captured["headers"] if name.lower() == b"access-control-allow-methods"]
assert len(allow_methods_headers) == 1


@pytest.mark.asyncio
async def test_cors_disallowed_header_not_in_allow_headers():
# GIVEN an app with CORSConfig that only allows specific headers
from aws_lambda_powertools.event_handler.api_gateway import CORSConfig

app = HttpResolverLocal(cors=CORSConfig(allow_origin="*", allow_headers=["X-Custom-Allowed"]))

@app.post("/items")
def create_item():
return {"ok": True}

# WHEN a preflight requests an unlisted header
scope = {
"type": "http",
"method": "OPTIONS",
"path": "/items",
"query_string": b"",
"headers": [
(b"origin", b"http://localhost:3000"),
(b"access-control-request-method", b"POST"),
(b"access-control-request-headers", b"X-Not-Allowed"),
],
}

receive = make_asgi_receive()
captured: dict[str, Any] = {"status_code": None, "headers": []}

async def send(message: dict[str, Any]) -> None:
await asyncio.sleep(0)
if message["type"] == "http.response.start":
captured["status_code"] = message["status"]
captured["headers"].extend(message.get("headers", []))

await app(scope, receive, send)

# THEN the server still returns 204 (browser enforces the rejection, not the server)
assert captured["status_code"] == 204

# AND the unlisted header is absent from Access-Control-Allow-Headers
allow_headers_value = next(
(v.decode() for name, v in captured["headers"] if name.lower() == b"access-control-allow-headers"),
"",
)
assert "X-Not-Allowed" not in allow_headers_value
# AND the explicitly allowed header IS present
assert "X-Custom-Allowed" in allow_headers_value