Skip to content

Commit 55191ba

Browse files
Fix buffered upload handler not cancelling on client disconnect (#6307)
* Fix buffered upload handler not cancelling on client disconnect When a client disconnects during a buffered file upload, the enqueued event handler was not being cancelled, leaving orphaned tasks. Add disconnect detection to _UploadStreamingResponse (via ASGI 2.4 receive watcher and ClientDisconnect handling) and propagate an on_task_future callback through enqueue_stream_delta so the upload path can cancel the EventFuture when the connection drops. * Update packages/reflex-components-core/src/reflex_components_core/core/_upload.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> * Update packages/reflex-components-core/src/reflex_components_core/core/_upload.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> * Generalize disconnect-cancel for streaming event responses Move disconnect detection out of upload-specific code into a reusable DisconnectAwareStreamingResponse in reflex-base. Drop the on_task_future callback from enqueue_stream_delta — the generator"s finally block already cancels the task future when the body iterator is closed on disconnect. * Add pre-2.4 streaming response disconnect test, remove bogus event processor tests Add test coverage for the pre-2.4 ASGI disconnect path in DisconnectAwareStreamingResponse. Remove test_stream_delta_aclose_cancels_in_flight_event and test_stream_delta_consumer_task_cancellation_cancels_in_flight_event — they passed with or without the task_future.cancel() fix because _on_future_done already cascades cancellation from the EventFuture to the handler task. --------- Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent 9680808 commit 55191ba

File tree

5 files changed

+459
-27
lines changed

5 files changed

+459
-27
lines changed

packages/reflex-base/src/reflex_base/event/processor/event_processor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,10 @@ async def enqueue_stream_delta(
390390
Any frontend events or chained events are handled normally and deltas from chained events
391391
will not be yielded by this method.
392392
393+
If the consumer stops iterating early, the in-flight event future is
394+
cancelled so the handler chain does not continue running in the
395+
background.
396+
393397
Args:
394398
token: The client token associated with the event.
395399
event: The event to be enqueued.
@@ -442,6 +446,9 @@ async def _emit_delta_impl(
442446
finally:
443447
for future in waiting_for:
444448
future.cancel()
449+
# Cancel the event chain if the streaming consumer exits early.
450+
if not task_future.done():
451+
task_future.cancel()
445452
# Raise any exceptions for the caller, waiting for all chained events.
446453
await task_future.wait_all()
447454

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
"""Internal helpers for streaming responses."""
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
import builtins
7+
import contextlib
8+
import sys
9+
from collections.abc import Awaitable, Callable, Generator
10+
from functools import partial
11+
from typing import Any
12+
13+
import anyio
14+
from starlette.requests import ClientDisconnect
15+
from starlette.responses import StreamingResponse
16+
17+
from reflex_base.utils.types import Receive, Scope, Send
18+
19+
_BASE_EXCEPTION_GROUP = getattr(builtins, "BaseExceptionGroup", None)
20+
21+
22+
def _parse_asgi_spec_version(scope: Scope) -> tuple[int, ...]:
23+
"""Parse the ASGI spec version from a scope.
24+
25+
Args:
26+
scope: The ASGI scope.
27+
28+
Returns:
29+
The parsed ASGI spec version, or ``(2, 0)`` if parsing fails.
30+
"""
31+
raw_spec = scope.get("asgi", {}).get("spec_version", "2.0")
32+
try:
33+
return tuple(int(part) for part in str(raw_spec).split("."))
34+
except (TypeError, ValueError):
35+
return (2, 0)
36+
37+
38+
@contextlib.contextmanager
39+
def _collapse_excgroups() -> Generator[None, None, None]:
40+
"""Collapse single-item exception groups to their underlying exception."""
41+
collapsed_exc: BaseException | None = None
42+
try:
43+
yield
44+
except BaseException as exc:
45+
collapsed_exc = exc
46+
if sys.version_info >= (3, 11) and _BASE_EXCEPTION_GROUP is not None:
47+
while isinstance(collapsed_exc, _BASE_EXCEPTION_GROUP):
48+
nested_exceptions = getattr(collapsed_exc, "exceptions", None)
49+
if (
50+
not isinstance(nested_exceptions, tuple)
51+
or len(nested_exceptions) != 1
52+
or not isinstance(nested_exceptions[0], BaseException)
53+
):
54+
break
55+
collapsed_exc = nested_exceptions[0]
56+
if collapsed_exc is exc:
57+
raise
58+
if collapsed_exc is not None:
59+
raise collapsed_exc
60+
61+
62+
class DisconnectAwareStreamingResponse(StreamingResponse):
63+
"""Streaming response that cancels its body task on disconnect."""
64+
65+
_on_finish: Callable[[], Awaitable[None]]
66+
67+
def __init__(
68+
self,
69+
*args: Any,
70+
on_finish: Callable[[], Awaitable[None]],
71+
**kwargs: Any,
72+
) -> None:
73+
"""Initialize the response.
74+
75+
Args:
76+
args: Positional args forwarded to ``StreamingResponse``.
77+
on_finish: Cleanup callback to run exactly once when the response ends.
78+
kwargs: Keyword args forwarded to ``StreamingResponse``.
79+
"""
80+
super().__init__(*args, **kwargs)
81+
self._on_finish = on_finish
82+
83+
async def _watch_disconnect(self, receive: Receive) -> None:
84+
"""Wait for the client connection to close."""
85+
while True:
86+
message = await receive()
87+
if message["type"] == "http.disconnect":
88+
return
89+
90+
async def _close_body_iterator(self) -> None:
91+
"""Close the body iterator if it supports ``aclose``."""
92+
aclose = getattr(self.body_iterator, "aclose", None)
93+
if aclose is not None:
94+
await aclose()
95+
96+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
97+
"""Serve the response and cancel the body task on disconnect."""
98+
spec_version = _parse_asgi_spec_version(scope)
99+
100+
try:
101+
if spec_version < (2, 4):
102+
with _collapse_excgroups():
103+
async with anyio.create_task_group() as task_group:
104+
105+
async def wrap(func: Callable[[], Awaitable[None]]) -> None:
106+
await func()
107+
task_group.cancel_scope.cancel()
108+
109+
task_group.start_soon(wrap, partial(self.stream_response, send))
110+
await wrap(partial(self.listen_for_disconnect, receive))
111+
else:
112+
# Verified against Starlette 0.52.1: the ASGI >= 2.4 path in
113+
# StreamingResponse.__call__ delegates straight to
114+
# stream_response(send) and does not read from receive().
115+
# Keep calling stream_response(send) directly here so the
116+
# disconnect watcher remains the only receive() consumer; if
117+
# Starlette changes that contract, re-check this logic.
118+
stream_task = asyncio.create_task(self.stream_response(send))
119+
disconnect_task = asyncio.create_task(self._watch_disconnect(receive))
120+
should_close_body_iterator = False
121+
122+
try:
123+
done, _ = await asyncio.wait(
124+
{stream_task, disconnect_task},
125+
return_when=asyncio.FIRST_COMPLETED,
126+
)
127+
if disconnect_task in done and not stream_task.done():
128+
should_close_body_iterator = True
129+
stream_task.cancel()
130+
with contextlib.suppress(asyncio.CancelledError):
131+
await stream_task
132+
else:
133+
try:
134+
await stream_task
135+
except OSError as err:
136+
should_close_body_iterator = True
137+
raise ClientDisconnect from err
138+
finally:
139+
if not disconnect_task.done():
140+
disconnect_task.cancel()
141+
with contextlib.suppress(asyncio.CancelledError):
142+
await disconnect_task
143+
if not stream_task.done():
144+
should_close_body_iterator = True
145+
stream_task.cancel()
146+
with contextlib.suppress(asyncio.CancelledError):
147+
await stream_task
148+
if should_close_body_iterator:
149+
await self._close_body_iterator()
150+
finally:
151+
await self._on_finish()
152+
153+
if self.background is not None:
154+
await self.background()
155+
156+
157+
__all__ = ["DisconnectAwareStreamingResponse"]

packages/reflex-components-core/src/reflex_components_core/core/_upload.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
import contextlib
77
import dataclasses
88
from collections import deque
9-
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
9+
from collections.abc import AsyncGenerator, AsyncIterator
1010
from pathlib import Path
1111
from typing import TYPE_CHECKING, Any, BinaryIO, cast
1212

1313
from python_multipart.multipart import MultipartParser, parse_options_header
1414
from reflex_base.utils import exceptions
1515
from reflex_base.utils.format import json_dumps
16+
from reflex_base.utils.streaming_response import DisconnectAwareStreamingResponse
1617
from starlette.datastructures import Headers
1718
from starlette.datastructures import UploadFile as StarletteUploadFile
1819
from starlette.exceptions import HTTPException
@@ -22,8 +23,6 @@
2223
from typing_extensions import Self
2324

2425
if TYPE_CHECKING:
25-
from reflex_base.utils.types import Receive, Scope, Send
26-
2726
from reflex.app import App
2827

2928

@@ -399,27 +398,6 @@ async def parse(self) -> None:
399398
await self._flush_emitted_chunks()
400399

401400

402-
class _UploadStreamingResponse(StreamingResponse):
403-
"""Streaming response that always releases upload form resources."""
404-
405-
_on_finish: Callable[[], Awaitable[None]]
406-
407-
def __init__(
408-
self,
409-
*args: Any,
410-
on_finish: Callable[[], Awaitable[None]],
411-
**kwargs: Any,
412-
) -> None:
413-
super().__init__(*args, **kwargs)
414-
self._on_finish = on_finish
415-
416-
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
417-
try:
418-
await super().__call__(scope, receive, send)
419-
finally:
420-
await self._on_finish()
421-
422-
423401
def _require_upload_headers(request: Request) -> tuple[str, str]:
424402
"""Extract the required upload headers from a request.
425403
@@ -525,7 +503,7 @@ async def _ndjson_updates():
525503
async for delta in app.event_processor.enqueue_stream_delta(token, event):
526504
yield json_dumps(StateUpdate(delta=delta)) + "\n"
527505

528-
return _UploadStreamingResponse(
506+
return DisconnectAwareStreamingResponse(
529507
_ndjson_updates(),
530508
media_type="application/x-ndjson",
531509
on_finish=_close_form_data,

0 commit comments

Comments
 (0)