Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 78 additions & 8 deletions src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
from dataclasses import dataclass
from logging import getLogger
from typing import Optional
from typing import Any, Optional
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

from cql2 import Expr
Expand All @@ -12,6 +12,8 @@

logger = getLogger(__name__)

_UNSET: Any = object()


@dataclass(frozen=True)
class Cql2RewriteLinksFilterMiddleware:
Expand All @@ -32,6 +34,54 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
# No filter set, just pass through
return await self.app(scope, receive, send)

# When the client sends the filter in the request body (POST /search etc.),
# query_params won't expose it. Capture it here so we can put it back on
# paginated next-link bodies. We use a sentinel to distinguish "client sent
# no filter" (drop the field) from "client sent some filter value" (echo it
# back verbatim). Mirroring the query-string read above, we attempt this
# for any method that can carry a body and let the JSON-decode no-op when
# the body is absent or unparseable.
original_body_filter: Any = _UNSET
original_body_filter_lang: Any = _UNSET

if request.method in ("POST", "PUT", "PATCH"):
buffered_body = b""
more_body = True
while more_body:
message = await receive()
if message["type"] == "http.request":
buffered_body += message.get("body", b"")
more_body = message.get("more_body", False)
else:
# Disconnect or unexpected message; bail out without capture.
break

try:
body_json = json.loads(buffered_body) if buffered_body else None
except json.JSONDecodeError:
body_json = None

if isinstance(body_json, dict):
if "filter" in body_json:
original_body_filter = body_json["filter"]
if "filter-lang" in body_json:
original_body_filter_lang = body_json["filter-lang"]

replayed = False

async def replay_receive() -> Message:
nonlocal replayed
if not replayed:
replayed = True
return {
"type": "http.request",
"body": buffered_body,
"more_body": False,
}
return await receive()

receive = replay_receive

# Intercept the response
response_start = None
body_chunks = []
Expand All @@ -46,7 +96,12 @@ async def send_wrapper(message: Message):
more_body = message.get("more_body", False)
if not more_body:
await self._process_and_send_response(
response_start, body_chunks, send, original_filter
response_start,
body_chunks,
send,
original_filter,
original_body_filter,
original_body_filter_lang,
)
else:
await send(message)
Expand All @@ -59,6 +114,8 @@ async def _process_and_send_response(
body_chunks: list[bytes],
send: Send,
original_filter: Optional[str],
original_body_filter: Any = _UNSET,
original_body_filter_lang: Any = _UNSET,
):
body = b"".join(body_chunks)
try:
Expand Down Expand Up @@ -87,12 +144,25 @@ async def _process_and_send_response(

# Handle filter in body (for POST links)
if "body" in link and isinstance(link["body"], dict):
if "filter" in link["body"]:
if cql2_filter:
link["body"]["filter"] = cql2_filter.to_json()
else:
link["body"].pop("filter", None)
link["body"].pop("filter-lang", None)
had_filter = "filter" in link["body"]

if original_body_filter is not _UNSET:
# Client originally sent a CQL2 filter in the request
# body (POST /search). Echo it back verbatim so
# paginated requests carry the same filter shape and
# serialization.
link["body"]["filter"] = original_body_filter
elif had_filter and cql2_filter:
# Filter came from the query string; emit it in the
# body as JSON so the next-link POST is self-contained.
link["body"]["filter"] = cql2_filter.to_json()
elif had_filter:
link["body"].pop("filter", None)

if original_body_filter_lang is not _UNSET:
link["body"]["filter-lang"] = original_body_filter_lang
elif had_filter and not cql2_filter:
link["body"].pop("filter-lang", None)

# Send the modified response
new_body = json.dumps(data).encode("utf-8")
Expand Down
202 changes: 202 additions & 0 deletions tests/test_cql2_rewrite_links_filter_middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test Cql2RewriteLinksFilterMiddleware."""

import json
import re

import pytest
Expand Down Expand Up @@ -335,3 +336,204 @@ async def test_endpoint(request: Request):

# Other data should be preserved
assert body["other_data"] == "preserved"


class TestPostBodyClientFilterPreservation:
"""Regression: client filters sent in a POST search body must be preserved
in the next-link body. The middleware previously read the original filter
only from the query string, which silently dropped POST-body filters.
"""

@pytest.mark.parametrize(
"system_filter,client_filter,client_filter_lang,expected_filter,expected_filter_lang",
[
# CQL2-JSON client filter must be echoed back unchanged
(
"private = false",
{"op": "<", "args": [{"property": "cloud_coverage"}, 50]},
"cql2-json",
{"op": "<", "args": [{"property": "cloud_coverage"}, 50]},
"cql2-json",
),
# Different client filter
(
"collection = 'landsat'",
{"op": ">", "args": [{"property": "datetime"}, "2023-01-01"]},
"cql2-json",
{"op": ">", "args": [{"property": "datetime"}, "2023-01-01"]},
"cql2-json",
),
# CQL2-text client filter must also be preserved verbatim
(
"private = false",
"cloud_coverage < 30",
"cql2-text",
"cloud_coverage < 30",
"cql2-text",
),
# No client filter in body — filter/filter-lang stay stripped from next.body
(
"private = false",
None,
None,
None,
None,
),
],
)
def test_preserves_client_filter_from_post_body(
self,
system_filter,
client_filter,
client_filter_lang,
expected_filter,
expected_filter_lang,
):
"""POST /search with filter in body keeps that filter in the next link body."""
app = FastAPI()

class MockBuildFilterMiddleware:
def __init__(self, app, state_key="cql2_filter"):
self.app = app
self.state_key = state_key

async def __call__(self, scope, receive, send):
if scope["type"] == "http":
request = Request(scope)
setattr(request.state, self.state_key, Expr(system_filter))
await self.app(scope, receive, send)

app.add_middleware(Cql2RewriteLinksFilterMiddleware)
app.add_middleware(MockBuildFilterMiddleware)

@app.post("/search")
async def search_endpoint(request: Request):
body_json = await request.json()
system_expr = getattr(request.state, "cql2_filter", None)
user_filter = body_json.get("filter")
user_filter_lang = body_json.get("filter-lang")

combined = None
if system_expr is not None and user_filter is not None:
combined = system_expr + Expr(user_filter)
elif system_expr is not None:
combined = system_expr
elif user_filter is not None:
combined = Expr(user_filter)

next_body = {
"collections": body_json.get("collections", []),
"limit": body_json.get("limit", 10),
"token": "next-token",
}
if combined is not None:
lang = user_filter_lang or "cql2-json"
next_body["filter-lang"] = lang
next_body["filter"] = (
combined.to_text() if lang == "cql2-text" else combined.to_json()
)

return {
"type": "FeatureCollection",
"links": [
{
"rel": "next",
"method": "POST",
"href": "http://example.com/search",
"body": next_body,
}
],
}

request_body = {"collections": ["col1"], "limit": 10}
if client_filter is not None:
request_body["filter"] = client_filter
request_body["filter-lang"] = client_filter_lang

client = TestClient(app)
response = client.post("/search", json=request_body)
assert response.status_code == 200, response.text
data = response.json()

next_link = next(link for link in data["links"] if link.get("rel") == "next")
body = next_link["body"]

# Pagination metadata is always carried through
assert body["token"] == "next-token"
assert body["collections"] == ["col1"]
assert body["limit"] == 10

if expected_filter is None:
assert "filter" not in body
assert "filter-lang" not in body
else:
assert body["filter"] == expected_filter
assert body["filter-lang"] == expected_filter_lang

def test_request_body_is_intact_for_inner_app(self):
"""Body capture must replay the exact original bytes to the inner app."""
app = FastAPI()

class MockBuildFilterMiddleware:
def __init__(self, app, state_key="cql2_filter"):
self.app = app
self.state_key = state_key

async def __call__(self, scope, receive, send):
if scope["type"] == "http":
request = Request(scope)
setattr(request.state, self.state_key, Expr("private = false"))
await self.app(scope, receive, send)

app.add_middleware(Cql2RewriteLinksFilterMiddleware)
app.add_middleware(MockBuildFilterMiddleware)

@app.post("/search")
async def search_endpoint(request: Request):
received = await request.body()
return {"echo": json.loads(received)}

request_body = {
"collections": ["a", "b"],
"filter": {"op": "=", "args": [{"property": "x"}, 1]},
"filter-lang": "cql2-json",
}
client = TestClient(app)
response = client.post("/search", json=request_body)
assert response.status_code == 200, response.text
assert response.json()["echo"] == request_body

def test_malformed_json_body_does_not_break_middleware(self):
"""An unparseable body must pass through without the middleware crashing."""
app = FastAPI()

class MockBuildFilterMiddleware:
def __init__(self, app, state_key="cql2_filter"):
self.app = app
self.state_key = state_key

async def __call__(self, scope, receive, send):
if scope["type"] == "http":
request = Request(scope)
setattr(request.state, self.state_key, Expr("private = false"))
await self.app(scope, receive, send)

app.add_middleware(Cql2RewriteLinksFilterMiddleware)
app.add_middleware(MockBuildFilterMiddleware)

@app.post("/search")
async def search_endpoint(request: Request):
raw = await request.body()
return Response(
content=raw,
media_type="application/octet-stream",
)

client = TestClient(app)
response = client.post(
"/search",
content=b"not json",
headers={"content-type": "application/json"},
)
assert response.status_code == 200
assert response.content == b"not json"