Skip to content

Commit ad9cdb5

Browse files
authored
Merge pull request #2 from feldroy/fix-middleware-return-await
Pass inner app results through the middleware
2 parents 191ba51 + 5bd404a commit ad9cdb5

2 files changed

Lines changed: 145 additions & 93 deletions

File tree

src/staticware/middleware.py

Lines changed: 39 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@
1919
import hashlib
2020
import mimetypes
2121
import re
22-
from pathlib import Path
2322
from collections.abc import Awaitable, Callable
23+
from pathlib import Path
2424
from typing import Any
2525

2626
# ASGI protocol types — inlined so we depend on nothing.
2727
type Scope = dict[str, Any]
2828
type Receive = Callable[[], Awaitable[dict[str, Any]]]
2929
type Send = Callable[[dict[str, Any]], Awaitable[None]]
30-
type ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]
30+
type ASGIApp = Callable[[Scope, Receive, Send], Awaitable[Any]]
3131

3232

3333
class HashedStatic:
@@ -102,7 +102,7 @@ def _hash_files(self) -> None:
102102

103103
self.file_map[relative] = hashed
104104
self._reverse[hashed] = relative
105-
self._etags[relative] = f'"{hash_val}"'.encode('latin-1')
105+
self._etags[relative] = f'"{hash_val}"'.encode("latin-1")
106106

107107
def url(self, path: str) -> str:
108108
"""Return the cache-busted URL for a static file path.
@@ -165,9 +165,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
165165
if hdr_name == b"if-none-match" and hdr_value == etag:
166166
await _send_text(send, 304, b"")
167167
return
168-
await _send_file(
169-
send, file_path, extra_headers=[(b"etag", etag)]
170-
)
168+
await _send_file(send, file_path, extra_headers=[(b"etag", etag)])
171169
else:
172170
await _send_file(send, file_path)
173171
return
@@ -201,10 +199,9 @@ def _replace(self, match: re.Match[str]) -> str:
201199
return f"{self.static.prefix}/{self.static.file_map[path]}"
202200
return match.group(0)
203201

204-
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
202+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> Any:
205203
if scope["type"] != "http":
206-
await self.app(scope, receive, send)
207-
return
204+
return await self.app(scope, receive, send)
208205

209206
response_start: dict[str, Any] | None = None
210207
body_parts: list[bytes] = []
@@ -225,9 +222,7 @@ async def send_wrapper(message: dict[str, Any]) -> None:
225222

226223
if message["type"] == "http.response.body":
227224
if response_start is None:
228-
raise RuntimeError(
229-
"http.response.body received before http.response.start"
230-
)
225+
raise RuntimeError("http.response.body received before http.response.start")
231226
if not is_html:
232227
await send(message)
233228
return
@@ -246,21 +241,17 @@ async def send_wrapper(message: dict[str, Any]) -> None:
246241
pass
247242

248243
if response_start is None:
249-
raise RuntimeError(
250-
"http.response.body received before http.response.start"
251-
)
244+
raise RuntimeError("http.response.body received before http.response.start")
252245
new_headers = [
253-
(k, str(len(full_body)).encode("latin-1"))
254-
if k == b"content-length"
255-
else (k, v)
246+
(k, str(len(full_body)).encode("latin-1")) if k == b"content-length" else (k, v)
256247
for k, v in response_start.get("headers", [])
257248
]
258249
response_start["headers"] = new_headers
259250
await send(response_start)
260251
await send({"type": "http.response.body", "body": full_body})
261252
return
262253

263-
await self.app(scope, receive, send_wrapper)
254+
return await self.app(scope, receive, send_wrapper)
264255

265256

266257
# ── Raw ASGI helpers ────────────────────────────────────────────────────
@@ -283,28 +274,36 @@ async def _send_file(
283274
if extra_headers:
284275
headers.extend(extra_headers)
285276

286-
await send({
287-
"type": "http.response.start",
288-
"status": 200,
289-
"headers": headers,
290-
})
291-
await send({
292-
"type": "http.response.body",
293-
"body": content,
294-
})
277+
await send(
278+
{
279+
"type": "http.response.start",
280+
"status": 200,
281+
"headers": headers,
282+
}
283+
)
284+
await send(
285+
{
286+
"type": "http.response.body",
287+
"body": content,
288+
}
289+
)
295290

296291

297292
async def _send_text(send: Send, status: int, body: bytes) -> None:
298293
"""Send a plain-text ASGI response."""
299-
await send({
300-
"type": "http.response.start",
301-
"status": status,
302-
"headers": [
303-
(b"content-type", b"text/plain"),
304-
(b"content-length", str(len(body)).encode("latin-1")),
305-
],
306-
})
307-
await send({
308-
"type": "http.response.body",
309-
"body": body,
310-
})
294+
await send(
295+
{
296+
"type": "http.response.start",
297+
"status": status,
298+
"headers": [
299+
(b"content-type", b"text/plain"),
300+
(b"content-length", str(len(body)).encode("latin-1")),
301+
],
302+
}
303+
)
304+
await send(
305+
{
306+
"type": "http.response.body",
307+
"body": body,
308+
}
309+
)

tests/test_staticware.py

Lines changed: 106 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
from staticware import HashedStatic, StaticRewriteMiddleware
2424

25-
2625
# ── Helpers ──────────────────────────────────────────────────────────────
2726

2827

@@ -61,7 +60,6 @@ def expected_hash(content: bytes, length: int = 8) -> str:
6160
# ── HashedStatic: hashing and url() ──────────────────────────────────────
6261

6362

64-
6563
def test_file_map_contains_all_files(static: HashedStatic, static_dir: Path) -> None:
6664
assert "styles.css" in static.file_map
6765
assert "images/logo.png" in static.file_map
@@ -216,14 +214,16 @@ def make_html_app(html: str):
216214
body = html.encode("utf-8")
217215

218216
async def app(scope: dict, receive: Any, send: Any) -> None:
219-
await send({
220-
"type": "http.response.start",
221-
"status": 200,
222-
"headers": [
223-
(b"content-type", b"text/html; charset=utf-8"),
224-
(b"content-length", str(len(body)).encode("latin-1")),
225-
],
226-
})
217+
await send(
218+
{
219+
"type": "http.response.start",
220+
"status": 200,
221+
"headers": [
222+
(b"content-type", b"text/html; charset=utf-8"),
223+
(b"content-length", str(len(body)).encode("latin-1")),
224+
],
225+
}
226+
)
227227
await send({"type": "http.response.body", "body": body})
228228

229229
return app
@@ -233,14 +233,16 @@ def make_json_app(data: bytes):
233233
"""Create a dummy ASGI app that returns JSON."""
234234

235235
async def app(scope: dict, receive: Any, send: Any) -> None:
236-
await send({
237-
"type": "http.response.start",
238-
"status": 200,
239-
"headers": [
240-
(b"content-type", b"application/json"),
241-
(b"content-length", str(len(data)).encode("latin-1")),
242-
],
243-
})
236+
await send(
237+
{
238+
"type": "http.response.start",
239+
"status": 200,
240+
"headers": [
241+
(b"content-type", b"application/json"),
242+
(b"content-length", str(len(data)).encode("latin-1")),
243+
],
244+
}
245+
)
244246
await send({"type": "http.response.body", "body": data})
245247

246248
return app
@@ -318,10 +320,12 @@ async def test_rewrite_raises_runtime_error_on_body_before_start(
318320

319321
async def broken_app(scope: dict, receive: Any, send: Any) -> None:
320322
# Skip http.response.start entirely — straight to body.
321-
await send({
322-
"type": "http.response.body",
323-
"body": b"<html>oops</html>",
324-
})
323+
await send(
324+
{
325+
"type": "http.response.body",
326+
"body": b"<html>oops</html>",
327+
}
328+
)
325329

326330
app = StaticRewriteMiddleware(broken_app, static=static)
327331
with pytest.raises(RuntimeError):
@@ -335,14 +339,16 @@ async def test_rewrite_streaming_html_response(static: HashedStatic) -> None:
335339

336340
async def streaming_app(scope: dict, receive: Any, send: Any) -> None:
337341
total = len(chunk1) + len(chunk2)
338-
await send({
339-
"type": "http.response.start",
340-
"status": 200,
341-
"headers": [
342-
(b"content-type", b"text/html; charset=utf-8"),
343-
(b"content-length", str(total).encode("latin-1")),
344-
],
345-
})
342+
await send(
343+
{
344+
"type": "http.response.start",
345+
"status": 200,
346+
"headers": [
347+
(b"content-type", b"text/html; charset=utf-8"),
348+
(b"content-length", str(total).encode("latin-1")),
349+
],
350+
}
351+
)
346352
await send({"type": "http.response.body", "body": chunk1, "more_body": True})
347353
await send({"type": "http.response.body", "body": chunk2, "more_body": False})
348354

@@ -373,14 +379,16 @@ async def test_rewrite_non_utf8_html_passes_through(static: HashedStatic) -> Non
373379
raw_body = b"<html>\x80\x81\x82 not valid utf-8</html>"
374380

375381
async def bad_encoding_app(scope: dict, receive: Any, send: Any) -> None:
376-
await send({
377-
"type": "http.response.start",
378-
"status": 200,
379-
"headers": [
380-
(b"content-type", b"text/html; charset=utf-8"),
381-
(b"content-length", str(len(raw_body)).encode("latin-1")),
382-
],
383-
})
382+
await send(
383+
{
384+
"type": "http.response.start",
385+
"status": 200,
386+
"headers": [
387+
(b"content-type", b"text/html; charset=utf-8"),
388+
(b"content-length", str(len(raw_body)).encode("latin-1")),
389+
],
390+
}
391+
)
384392
await send({"type": "http.response.body", "body": raw_body})
385393

386394
app = StaticRewriteMiddleware(bad_encoding_app, static=static)
@@ -397,9 +405,7 @@ def make_mount_scope(path: str, *, root_path: str = "") -> dict[str, Any]:
397405
return {"type": "http", "path": path, "root_path": root_path, "method": "GET"}
398406

399407

400-
async def test_serve_with_root_path_scope(
401-
static: HashedStatic, static_dir: Path
402-
) -> None:
408+
async def test_serve_with_root_path_scope(static: HashedStatic, static_dir: Path) -> None:
403409
"""Starlette-style mount: root_path set, path still includes the prefix.
404410
405411
Starlette sets scope["root_path"] = "/static" and leaves
@@ -414,9 +420,7 @@ async def test_serve_with_root_path_scope(
414420
assert resp.text == "body { color: red; }"
415421

416422

417-
async def test_serve_with_stripped_path(
418-
static: HashedStatic, static_dir: Path
419-
) -> None:
423+
async def test_serve_with_stripped_path(static: HashedStatic, static_dir: Path) -> None:
420424
"""Litestar-style mount: framework strips the prefix from scope["path"].
421425
422426
The sub-app sees scope["root_path"] = "/static" and
@@ -466,18 +470,14 @@ async def test_serve_with_mismatched_mount_and_prefix(static_dir: Path) -> None:
466470
# ── HashedStatic: ETag and conditional requests ──────────────────────
467471

468472

469-
def make_scope_with_headers(
470-
path: str, headers: list[tuple[bytes, bytes]] | None = None
471-
) -> dict[str, Any]:
473+
def make_scope_with_headers(path: str, headers: list[tuple[bytes, bytes]] | None = None) -> dict[str, Any]:
472474
scope: dict[str, Any] = {"type": "http", "path": path, "method": "GET"}
473475
if headers:
474476
scope["headers"] = headers
475477
return scope
476478

477479

478-
async def test_etag_on_unhashed_response(
479-
static: HashedStatic, static_dir: Path
480-
) -> None:
480+
async def test_etag_on_unhashed_response(static: HashedStatic, static_dir: Path) -> None:
481481
"""Original filename response includes an ETag header with the content hash."""
482482
resp = ResponseCollector()
483483
await static(make_scope("/static/styles.css"), receive, resp)
@@ -489,9 +489,7 @@ async def test_etag_on_unhashed_response(
489489
assert resp.headers[b"etag"] == f'"{h}"'.encode("latin-1")
490490

491491

492-
async def test_conditional_request_returns_304(
493-
static: HashedStatic, static_dir: Path
494-
) -> None:
492+
async def test_conditional_request_returns_304(static: HashedStatic, static_dir: Path) -> None:
495493
"""If-None-Match with matching ETag returns 304 and empty body."""
496494
css_content = (static_dir / "styles.css").read_bytes()
497495
h = expected_hash(css_content)
@@ -528,3 +526,58 @@ async def test_hashed_url_no_etag(static: HashedStatic) -> None:
528526
await static(make_scope(f"/static/{hashed_name}"), receive, resp)
529527
assert resp.status == 200
530528
assert b"etag" not in resp.headers, "Hashed URL should not include an etag header"
529+
530+
531+
# ── StaticRewriteMiddleware: return value propagation ──────────────
532+
533+
534+
async def test_rewrite_middleware_returns_inner_app_result(
535+
static: HashedStatic,
536+
) -> None:
537+
"""Middleware should propagate the inner app's return value on the HTTP path.
538+
539+
ASGI apps normally return None, but the spec does not forbid return values.
540+
Frameworks like Starlette rely on ``return await self.app(...)`` so that
541+
return values propagate through the middleware chain. A bare ``await``
542+
without ``return`` silently discards the result.
543+
"""
544+
sentinel = "app_result"
545+
546+
async def inner_app(scope: dict, receive: Any, send: Any) -> str:
547+
body = b"<html>hello</html>"
548+
await send(
549+
{
550+
"type": "http.response.start",
551+
"status": 200,
552+
"headers": [
553+
(b"content-type", b"text/html; charset=utf-8"),
554+
(b"content-length", str(len(body)).encode("latin-1")),
555+
],
556+
}
557+
)
558+
await send({"type": "http.response.body", "body": body})
559+
return sentinel
560+
561+
app = StaticRewriteMiddleware(inner_app, static=static)
562+
resp = ResponseCollector()
563+
result = await app(make_scope("/"), receive, resp)
564+
assert result == sentinel
565+
566+
567+
async def test_rewrite_middleware_returns_inner_app_result_non_http(
568+
static: HashedStatic,
569+
) -> None:
570+
"""Middleware should propagate the inner app's return value for non-HTTP scopes.
571+
572+
When the scope type is not "http", the middleware forwards directly to the
573+
inner app. It must ``return await self.app(...)`` so the return value is
574+
not silently discarded.
575+
"""
576+
sentinel = "ws_result"
577+
578+
async def inner_app(scope: dict, receive: Any, send: Any) -> str:
579+
return sentinel
580+
581+
app = StaticRewriteMiddleware(inner_app, static=static)
582+
result = await app({"type": "websocket", "path": "/"}, receive, ResponseCollector())
583+
assert result == sentinel

0 commit comments

Comments
 (0)