diff --git a/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py b/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py index 1909bfd..2bf92a2 100644 --- a/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py +++ b/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py @@ -3,7 +3,7 @@ import json from dataclasses import dataclass from logging import getLogger -from typing import Optional +from typing import Any, Optional from urllib.parse import parse_qs, urlencode, urlparse, urlunparse from cql2 import Expr @@ -12,6 +12,8 @@ logger = getLogger(__name__) +_UNSET: Any = object() + @dataclass(frozen=True) class Cql2RewriteLinksFilterMiddleware: @@ -32,6 +34,54 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # No filter set, just pass through return await self.app(scope, receive, send) + # When the client sends the filter in the request body (POST /search etc.), + # query_params won't expose it. Capture it here so we can put it back on + # paginated next-link bodies. We use a sentinel to distinguish "client sent + # no filter" (drop the field) from "client sent some filter value" (echo it + # back verbatim). Mirroring the query-string read above, we attempt this + # for any method that can carry a body and let the JSON-decode no-op when + # the body is absent or unparseable. + original_body_filter: Any = _UNSET + original_body_filter_lang: Any = _UNSET + + if request.method in ("POST", "PUT", "PATCH"): + buffered_body = b"" + more_body = True + while more_body: + message = await receive() + if message["type"] == "http.request": + buffered_body += message.get("body", b"") + more_body = message.get("more_body", False) + else: + # Disconnect or unexpected message; bail out without capture. + break + + try: + body_json = json.loads(buffered_body) if buffered_body else None + except json.JSONDecodeError: + body_json = None + + if isinstance(body_json, dict): + if "filter" in body_json: + original_body_filter = body_json["filter"] + if "filter-lang" in body_json: + original_body_filter_lang = body_json["filter-lang"] + + replayed = False + + async def replay_receive() -> Message: + nonlocal replayed + if not replayed: + replayed = True + return { + "type": "http.request", + "body": buffered_body, + "more_body": False, + } + return await receive() + + receive = replay_receive + # Intercept the response response_start = None body_chunks = [] @@ -46,7 +96,12 @@ async def send_wrapper(message: Message): more_body = message.get("more_body", False) if not more_body: await self._process_and_send_response( - response_start, body_chunks, send, original_filter + response_start, + body_chunks, + send, + original_filter, + original_body_filter, + original_body_filter_lang, ) else: await send(message) @@ -59,6 +114,8 @@ async def _process_and_send_response( body_chunks: list[bytes], send: Send, original_filter: Optional[str], + original_body_filter: Any = _UNSET, + original_body_filter_lang: Any = _UNSET, ): body = b"".join(body_chunks) try: @@ -87,12 +144,25 @@ async def _process_and_send_response( # Handle filter in body (for POST links) if "body" in link and isinstance(link["body"], dict): - if "filter" in link["body"]: - if cql2_filter: - link["body"]["filter"] = cql2_filter.to_json() - else: - link["body"].pop("filter", None) - link["body"].pop("filter-lang", None) + had_filter = "filter" in link["body"] + + if original_body_filter is not _UNSET: + # Client originally sent a CQL2 filter in the request + # body (POST /search). Echo it back verbatim so + # paginated requests carry the same filter shape and + # serialization. + link["body"]["filter"] = original_body_filter + elif had_filter and cql2_filter: + # Filter came from the query string; emit it in the + # body as JSON so the next-link POST is self-contained. + link["body"]["filter"] = cql2_filter.to_json() + elif had_filter: + link["body"].pop("filter", None) + + if original_body_filter_lang is not _UNSET: + link["body"]["filter-lang"] = original_body_filter_lang + elif had_filter and not cql2_filter: + link["body"].pop("filter-lang", None) # Send the modified response new_body = json.dumps(data).encode("utf-8") diff --git a/tests/test_cql2_rewrite_links_filter_middleware.py b/tests/test_cql2_rewrite_links_filter_middleware.py index d5b07cd..16036fa 100644 --- a/tests/test_cql2_rewrite_links_filter_middleware.py +++ b/tests/test_cql2_rewrite_links_filter_middleware.py @@ -1,5 +1,6 @@ """Test Cql2RewriteLinksFilterMiddleware.""" +import json import re import pytest @@ -335,3 +336,204 @@ async def test_endpoint(request: Request): # Other data should be preserved assert body["other_data"] == "preserved" + + +class TestPostBodyClientFilterPreservation: + """Regression: client filters sent in a POST search body must be preserved + in the next-link body. The middleware previously read the original filter + only from the query string, which silently dropped POST-body filters. + """ + + @pytest.mark.parametrize( + "system_filter,client_filter,client_filter_lang,expected_filter,expected_filter_lang", + [ + # CQL2-JSON client filter must be echoed back unchanged + ( + "private = false", + {"op": "<", "args": [{"property": "cloud_coverage"}, 50]}, + "cql2-json", + {"op": "<", "args": [{"property": "cloud_coverage"}, 50]}, + "cql2-json", + ), + # Different client filter + ( + "collection = 'landsat'", + {"op": ">", "args": [{"property": "datetime"}, "2023-01-01"]}, + "cql2-json", + {"op": ">", "args": [{"property": "datetime"}, "2023-01-01"]}, + "cql2-json", + ), + # CQL2-text client filter must also be preserved verbatim + ( + "private = false", + "cloud_coverage < 30", + "cql2-text", + "cloud_coverage < 30", + "cql2-text", + ), + # No client filter in body — filter/filter-lang stay stripped from next.body + ( + "private = false", + None, + None, + None, + None, + ), + ], + ) + def test_preserves_client_filter_from_post_body( + self, + system_filter, + client_filter, + client_filter_lang, + expected_filter, + expected_filter_lang, + ): + """POST /search with filter in body keeps that filter in the next link body.""" + app = FastAPI() + + class MockBuildFilterMiddleware: + def __init__(self, app, state_key="cql2_filter"): + self.app = app + self.state_key = state_key + + async def __call__(self, scope, receive, send): + if scope["type"] == "http": + request = Request(scope) + setattr(request.state, self.state_key, Expr(system_filter)) + await self.app(scope, receive, send) + + app.add_middleware(Cql2RewriteLinksFilterMiddleware) + app.add_middleware(MockBuildFilterMiddleware) + + @app.post("/search") + async def search_endpoint(request: Request): + body_json = await request.json() + system_expr = getattr(request.state, "cql2_filter", None) + user_filter = body_json.get("filter") + user_filter_lang = body_json.get("filter-lang") + + combined = None + if system_expr is not None and user_filter is not None: + combined = system_expr + Expr(user_filter) + elif system_expr is not None: + combined = system_expr + elif user_filter is not None: + combined = Expr(user_filter) + + next_body = { + "collections": body_json.get("collections", []), + "limit": body_json.get("limit", 10), + "token": "next-token", + } + if combined is not None: + lang = user_filter_lang or "cql2-json" + next_body["filter-lang"] = lang + next_body["filter"] = ( + combined.to_text() if lang == "cql2-text" else combined.to_json() + ) + + return { + "type": "FeatureCollection", + "links": [ + { + "rel": "next", + "method": "POST", + "href": "http://example.com/search", + "body": next_body, + } + ], + } + + request_body = {"collections": ["col1"], "limit": 10} + if client_filter is not None: + request_body["filter"] = client_filter + request_body["filter-lang"] = client_filter_lang + + client = TestClient(app) + response = client.post("/search", json=request_body) + assert response.status_code == 200, response.text + data = response.json() + + next_link = next(link for link in data["links"] if link.get("rel") == "next") + body = next_link["body"] + + # Pagination metadata is always carried through + assert body["token"] == "next-token" + assert body["collections"] == ["col1"] + assert body["limit"] == 10 + + if expected_filter is None: + assert "filter" not in body + assert "filter-lang" not in body + else: + assert body["filter"] == expected_filter + assert body["filter-lang"] == expected_filter_lang + + def test_request_body_is_intact_for_inner_app(self): + """Body capture must replay the exact original bytes to the inner app.""" + app = FastAPI() + + class MockBuildFilterMiddleware: + def __init__(self, app, state_key="cql2_filter"): + self.app = app + self.state_key = state_key + + async def __call__(self, scope, receive, send): + if scope["type"] == "http": + request = Request(scope) + setattr(request.state, self.state_key, Expr("private = false")) + await self.app(scope, receive, send) + + app.add_middleware(Cql2RewriteLinksFilterMiddleware) + app.add_middleware(MockBuildFilterMiddleware) + + @app.post("/search") + async def search_endpoint(request: Request): + received = await request.body() + return {"echo": json.loads(received)} + + request_body = { + "collections": ["a", "b"], + "filter": {"op": "=", "args": [{"property": "x"}, 1]}, + "filter-lang": "cql2-json", + } + client = TestClient(app) + response = client.post("/search", json=request_body) + assert response.status_code == 200, response.text + assert response.json()["echo"] == request_body + + def test_malformed_json_body_does_not_break_middleware(self): + """An unparseable body must pass through without the middleware crashing.""" + app = FastAPI() + + class MockBuildFilterMiddleware: + def __init__(self, app, state_key="cql2_filter"): + self.app = app + self.state_key = state_key + + async def __call__(self, scope, receive, send): + if scope["type"] == "http": + request = Request(scope) + setattr(request.state, self.state_key, Expr("private = false")) + await self.app(scope, receive, send) + + app.add_middleware(Cql2RewriteLinksFilterMiddleware) + app.add_middleware(MockBuildFilterMiddleware) + + @app.post("/search") + async def search_endpoint(request: Request): + raw = await request.body() + return Response( + content=raw, + media_type="application/octet-stream", + ) + + client = TestClient(app) + response = client.post( + "/search", + content=b"not json", + headers={"content-type": "application/json"}, + ) + assert response.status_code == 200 + assert response.content == b"not json"