Skip to content

Commit 14cb35f

Browse files
nicolaracco and claude committed
fix(cql2-rewrite-links): preserve client filter from POST body
Cql2RewriteLinksFilterMiddleware read the client's original filter only from request.query_params, which is empty for POST /search where the filter lives in the JSON body. With no original filter to restore, the middleware fell into the strip-from-body branch and popped both `filter` and `filter-lang` from every paginated next-link body. Clients following `next` silently lost their filter, broadening every page after the first.

Capture the client's `filter` and `filter-lang` from the POST/PUT/PATCH request body for endpoints matching ^/search$ (kept in lock-step with Cql2ApplyFilterBodyMiddleware.search_body_endpoints), echo them back verbatim into next-link bodies, and replay the buffered bytes via a wrapped receive so inner middlewares still see the unmodified body.

Existing GET-with-query-string behavior is unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent c9dc496 commit 14cb35f

2 files changed

Lines changed: 285 additions & 8 deletions

File tree

src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py

Lines changed: 83 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Middleware to rewrite 'filter' in .links of the JSON response, removing the filter from the request state."""
22

33
import json
4+
import re
45
from dataclasses import dataclass
56
from logging import getLogger
6-
from typing import Optional
7+
from typing import Any, Optional
78
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
89

910
from cql2 import Expr
@@ -12,6 +13,12 @@
1213

1314
logger = getLogger(__name__)
1415

16+
_UNSET: Any = object()
17+
18+
# Endpoints whose request body (POST/PUT/PATCH) carries the client's CQL2 filter. Kept in sync with
19+
# Cql2ApplyFilterBodyMiddleware.search_body_endpoints.
20+
_SEARCH_BODY_ENDPOINTS = (re.compile(r"^/search$"),)
21+
1522

1623
@dataclass(frozen=True)
1724
class Cql2RewriteLinksFilterMiddleware:
@@ -32,6 +39,54 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
3239
# No filter set, just pass through
3340
return await self.app(scope, receive, send)
3441

42+
# When the client sends the filter in the request body (POST /search etc.),
43+
# query_params won't expose it. Capture it here so we can put it back on
44+
# paginated next-link bodies. We use a sentinel to distinguish "client sent
45+
# no filter" (drop the field) from "client sent some filter value" (echo it
46+
# back verbatim).
47+
original_body_filter: Any = _UNSET
48+
original_body_filter_lang: Any = _UNSET
49+
50+
if request.method in ("POST", "PUT", "PATCH") and any(
51+
pattern.match(request.url.path) for pattern in _SEARCH_BODY_ENDPOINTS
52+
):
53+
buffered_body = b""
54+
more_body = True
55+
while more_body:
56+
message = await receive()
57+
if message["type"] == "http.request":
58+
buffered_body += message.get("body", b"")
59+
more_body = message.get("more_body", False)
60+
else:
61+
# Disconnect or unexpected message; stop reading and carry on with whatever was buffered so far.
62+
break
63+
64+
try:
65+
body_json = json.loads(buffered_body) if buffered_body else None
66+
except json.JSONDecodeError:
67+
body_json = None
68+
69+
if isinstance(body_json, dict):
70+
if "filter" in body_json:
71+
original_body_filter = body_json["filter"]
72+
if "filter-lang" in body_json:
73+
original_body_filter_lang = body_json["filter-lang"]
74+
75+
replayed = False
76+
77+
async def replay_receive() -> Message:
78+
nonlocal replayed
79+
if not replayed:
80+
replayed = True
81+
return {
82+
"type": "http.request",
83+
"body": buffered_body,
84+
"more_body": False,
85+
}
86+
return await receive()
87+
88+
receive = replay_receive
89+
3590
# Intercept the response
3691
response_start = None
3792
body_chunks = []
@@ -46,7 +101,12 @@ async def send_wrapper(message: Message):
46101
more_body = message.get("more_body", False)
47102
if not more_body:
48103
await self._process_and_send_response(
49-
response_start, body_chunks, send, original_filter
104+
response_start,
105+
body_chunks,
106+
send,
107+
original_filter,
108+
original_body_filter,
109+
original_body_filter_lang,
50110
)
51111
else:
52112
await send(message)
@@ -59,6 +119,8 @@ async def _process_and_send_response(
59119
body_chunks: list[bytes],
60120
send: Send,
61121
original_filter: Optional[str],
122+
original_body_filter: Any = _UNSET,
123+
original_body_filter_lang: Any = _UNSET,
62124
):
63125
body = b"".join(body_chunks)
64126
try:
@@ -87,12 +149,25 @@ async def _process_and_send_response(
87149

88150
# Handle filter in body (for POST links)
89151
if "body" in link and isinstance(link["body"], dict):
90-
if "filter" in link["body"]:
91-
if cql2_filter:
92-
link["body"]["filter"] = cql2_filter.to_json()
93-
else:
94-
link["body"].pop("filter", None)
95-
link["body"].pop("filter-lang", None)
152+
had_filter = "filter" in link["body"]
153+
154+
if original_body_filter is not _UNSET:
155+
# Client originally sent a CQL2 filter in the request
156+
# body (POST /search). Echo it back verbatim so
157+
# paginated requests carry the same filter shape and
158+
# serialization.
159+
link["body"]["filter"] = original_body_filter
160+
elif had_filter and cql2_filter:
161+
# Filter came from the query string; emit it in the
162+
# body as JSON so the next-link POST is self-contained.
163+
link["body"]["filter"] = cql2_filter.to_json()
164+
elif had_filter:
165+
link["body"].pop("filter", None)
166+
167+
if original_body_filter_lang is not _UNSET:
168+
link["body"]["filter-lang"] = original_body_filter_lang
169+
elif had_filter and not cql2_filter:
170+
link["body"].pop("filter-lang", None)
96171

97172
# Send the modified response
98173
new_body = json.dumps(data).encode("utf-8")

tests/test_cql2_rewrite_links_filter_middleware.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Test Cql2RewriteLinksFilterMiddleware."""
22

3+
import json
34
import re
45

56
import pytest
@@ -335,3 +336,204 @@ async def test_endpoint(request: Request):
335336

336337
# Other data should be preserved
337338
assert body["other_data"] == "preserved"
339+
340+
341+
class TestPostBodyClientFilterPreservation:
342+
"""Regression: client filters sent in a POST search body must be preserved
343+
in the next-link body. The middleware previously read the original filter
344+
only from the query string, which silently dropped POST-body filters.
345+
"""
346+
347+
@pytest.mark.parametrize(
348+
"system_filter,client_filter,client_filter_lang,expected_filter,expected_filter_lang",
349+
[
350+
# CQL2-JSON client filter must be echoed back unchanged
351+
(
352+
"private = false",
353+
{"op": "<", "args": [{"property": "cloud_coverage"}, 50]},
354+
"cql2-json",
355+
{"op": "<", "args": [{"property": "cloud_coverage"}, 50]},
356+
"cql2-json",
357+
),
358+
# Different client filter
359+
(
360+
"collection = 'landsat'",
361+
{"op": ">", "args": [{"property": "datetime"}, "2023-01-01"]},
362+
"cql2-json",
363+
{"op": ">", "args": [{"property": "datetime"}, "2023-01-01"]},
364+
"cql2-json",
365+
),
366+
# CQL2-text client filter must also be preserved verbatim
367+
(
368+
"private = false",
369+
"cloud_coverage < 30",
370+
"cql2-text",
371+
"cloud_coverage < 30",
372+
"cql2-text",
373+
),
374+
# No client filter in body — filter/filter-lang stay stripped from next.body
375+
(
376+
"private = false",
377+
None,
378+
None,
379+
None,
380+
None,
381+
),
382+
],
383+
)
384+
def test_preserves_client_filter_from_post_body(
385+
self,
386+
system_filter,
387+
client_filter,
388+
client_filter_lang,
389+
expected_filter,
390+
expected_filter_lang,
391+
):
392+
"""POST /search with filter in body keeps that filter in the next link body."""
393+
app = FastAPI()
394+
395+
class MockBuildFilterMiddleware:
396+
def __init__(self, app, state_key="cql2_filter"):
397+
self.app = app
398+
self.state_key = state_key
399+
400+
async def __call__(self, scope, receive, send):
401+
if scope["type"] == "http":
402+
request = Request(scope)
403+
setattr(request.state, self.state_key, Expr(system_filter))
404+
await self.app(scope, receive, send)
405+
406+
app.add_middleware(Cql2RewriteLinksFilterMiddleware)
407+
app.add_middleware(MockBuildFilterMiddleware)
408+
409+
@app.post("/search")
410+
async def search_endpoint(request: Request):
411+
body_json = await request.json()
412+
system_expr = getattr(request.state, "cql2_filter", None)
413+
user_filter = body_json.get("filter")
414+
user_filter_lang = body_json.get("filter-lang")
415+
416+
combined = None
417+
if system_expr is not None and user_filter is not None:
418+
combined = system_expr + Expr(user_filter)
419+
elif system_expr is not None:
420+
combined = system_expr
421+
elif user_filter is not None:
422+
combined = Expr(user_filter)
423+
424+
next_body = {
425+
"collections": body_json.get("collections", []),
426+
"limit": body_json.get("limit", 10),
427+
"token": "next-token",
428+
}
429+
if combined is not None:
430+
lang = user_filter_lang or "cql2-json"
431+
next_body["filter-lang"] = lang
432+
next_body["filter"] = (
433+
combined.to_text() if lang == "cql2-text" else combined.to_json()
434+
)
435+
436+
return {
437+
"type": "FeatureCollection",
438+
"links": [
439+
{
440+
"rel": "next",
441+
"method": "POST",
442+
"href": "http://example.com/search",
443+
"body": next_body,
444+
}
445+
],
446+
}
447+
448+
request_body = {"collections": ["col1"], "limit": 10}
449+
if client_filter is not None:
450+
request_body["filter"] = client_filter
451+
request_body["filter-lang"] = client_filter_lang
452+
453+
client = TestClient(app)
454+
response = client.post("/search", json=request_body)
455+
assert response.status_code == 200, response.text
456+
data = response.json()
457+
458+
next_link = next(link for link in data["links"] if link.get("rel") == "next")
459+
body = next_link["body"]
460+
461+
# Pagination metadata is always carried through
462+
assert body["token"] == "next-token"
463+
assert body["collections"] == ["col1"]
464+
assert body["limit"] == 10
465+
466+
if expected_filter is None:
467+
assert "filter" not in body
468+
assert "filter-lang" not in body
469+
else:
470+
assert body["filter"] == expected_filter
471+
assert body["filter-lang"] == expected_filter_lang
472+
473+
def test_request_body_is_intact_for_inner_app(self):
474+
"""Body capture must replay the exact original bytes to the inner app."""
475+
app = FastAPI()
476+
477+
class MockBuildFilterMiddleware:
478+
def __init__(self, app, state_key="cql2_filter"):
479+
self.app = app
480+
self.state_key = state_key
481+
482+
async def __call__(self, scope, receive, send):
483+
if scope["type"] == "http":
484+
request = Request(scope)
485+
setattr(request.state, self.state_key, Expr("private = false"))
486+
await self.app(scope, receive, send)
487+
488+
app.add_middleware(Cql2RewriteLinksFilterMiddleware)
489+
app.add_middleware(MockBuildFilterMiddleware)
490+
491+
@app.post("/search")
492+
async def search_endpoint(request: Request):
493+
received = await request.body()
494+
return {"echo": json.loads(received)}
495+
496+
request_body = {
497+
"collections": ["a", "b"],
498+
"filter": {"op": "=", "args": [{"property": "x"}, 1]},
499+
"filter-lang": "cql2-json",
500+
}
501+
client = TestClient(app)
502+
response = client.post("/search", json=request_body)
503+
assert response.status_code == 200, response.text
504+
assert response.json()["echo"] == request_body
505+
506+
def test_malformed_json_body_does_not_break_middleware(self):
507+
"""An unparseable body must pass through without the middleware crashing."""
508+
app = FastAPI()
509+
510+
class MockBuildFilterMiddleware:
511+
def __init__(self, app, state_key="cql2_filter"):
512+
self.app = app
513+
self.state_key = state_key
514+
515+
async def __call__(self, scope, receive, send):
516+
if scope["type"] == "http":
517+
request = Request(scope)
518+
setattr(request.state, self.state_key, Expr("private = false"))
519+
await self.app(scope, receive, send)
520+
521+
app.add_middleware(Cql2RewriteLinksFilterMiddleware)
522+
app.add_middleware(MockBuildFilterMiddleware)
523+
524+
@app.post("/search")
525+
async def search_endpoint(request: Request):
526+
raw = await request.body()
527+
return Response(
528+
content=raw,
529+
media_type="application/octet-stream",
530+
)
531+
532+
client = TestClient(app)
533+
response = client.post(
534+
"/search",
535+
content=b"not json",
536+
headers={"content-type": "application/json"},
537+
)
538+
assert response.status_code == 200
539+
assert response.content == b"not json"

0 commit comments

Comments
 (0)