Skip to content

Commit 620c84d

Browse files
committed
Skip buffered upload handler when probe chunk detects client disconnect
Insert a b"\n" probe chunk before the real response body so that an early client disconnect is detected before the upload handler is enqueued. Add asyncio.sleep(0) yield point in _upload_buffered_file to let the disconnect watcher fire first, and return early if disconnect was seen. Includes a test covering the probe-based disconnect detection path.
1 parent af85fa2 commit 620c84d

2 files changed

Lines changed: 90 additions & 5 deletions

File tree

packages/reflex-components-core/src/reflex_components_core/core/_upload.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -441,18 +441,30 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
441441
disconnect_task: asyncio.Task[None] | None = None
442442
use_watcher = spec_version >= (2, 4) and self._on_disconnect is not None
443443

444+
if use_watcher:
445+
body_iterator = self.body_iterator
446+
447+
async def body_with_probe() -> AsyncGenerator[
448+
str | bytes | memoryview[int], None
449+
]:
450+
"""Yield a tiny probe chunk before the real response body."""
451+
yield b"\n"
452+
async for chunk in body_iterator:
453+
yield chunk
454+
455+
self.body_iterator = body_with_probe()
456+
444457
async def wrapped_receive() -> Message:
445458
message = await receive()
446-
if message.get("type") == "http.disconnect":
459+
if message["type"] == "http.disconnect":
447460
self._handle_disconnect()
448461
return message
449462

450463
try:
451464
if use_watcher:
452-
# ASGI >= 2.4: use a dedicated task to watch for disconnect
453-
# concurrently. Pass raw `receive` to Starlette — the watcher
454-
# owns disconnect detection; using wrapped_receive here would
455-
# race on the same receive callable.
465+
# ASGI >= 2.4: Starlette does not call receive() while
466+
# streaming. Use a dedicated task so disconnect fires the
467+
# callback; pass raw receive to avoid racing wrapped_receive.
456468
disconnect_task = asyncio.create_task(self._watch_disconnect(receive))
457469
try:
458470
await super().__call__(
@@ -593,6 +605,10 @@ async def _ndjson_updates():
593605
Yields:
594606
Each state update as newline-delimited JSON.
595607
"""
608+
# Let the disconnect watcher run before we enqueue the upload handler.
609+
await asyncio.sleep(0)
610+
if disconnect_seen:
611+
return
596612
# Enqueue the task on the main event loop, but emit deltas to the local queue.
597613
async for delta in app.event_processor.enqueue_stream_delta(
598614
token,

tests/units/test_app.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from reflex_components_radix.themes.typography.text import Text
3131
from starlette.applications import Starlette
3232
from starlette.datastructures import FormData, Headers, UploadFile
33+
from starlette.requests import ClientDisconnect
3334
from starlette.responses import StreamingResponse
3435
from starlette_admin.auth import AuthProvider
3536

@@ -1379,6 +1380,74 @@ async def send(_message): # noqa: RUF029
13791380
assert bio.closed
13801381

13811382

1383+
@pytest.mark.asyncio
1384+
async def test_upload_file_skips_buffered_handler_when_disconnect_detected_on_probe(
1385+
token: str,
1386+
):
1387+
"""Buffered uploads skip handler dispatch when the probe send disconnects.
1388+
1389+
This models ASGI 2.4+ behavior where the upload request can finish parsing,
1390+
but the client disconnect is only surfaced on the first response-body send.
1391+
1392+
Args:
1393+
token: A token.
1394+
"""
1395+
request_mock = unittest.mock.Mock()
1396+
request_mock.headers = {
1397+
"reflex-client-token": token,
1398+
"reflex-event-handler": f"{FileUploadState.get_full_name()}.multi_handle_upload",
1399+
}
1400+
1401+
bio = io.BytesIO(b"contents of image one")
1402+
file1 = UploadFile(filename="image1.jpg", file=bio)
1403+
form_data = FormData([("files", file1)])
1404+
original_close = form_data.close
1405+
form_close = AsyncMock(side_effect=original_close)
1406+
form_data.close = form_close
1407+
1408+
async def form(): # noqa: RUF029
1409+
return form_data
1410+
1411+
request_mock.form = form
1412+
1413+
msg = "upload handler should not be enqueued"
1414+
probe_chunk = b"\n"
1415+
asgi_24_scope = {"type": "http", "asgi": {"spec_version": "2.4"}}
1416+
enqueue_stream_delta = Mock(side_effect=AssertionError(msg))
1417+
app = Mock(
1418+
event_processor=Mock(enqueue_stream_delta=enqueue_stream_delta),
1419+
)
1420+
1421+
upload_fn = upload(app)
1422+
streaming_response = await upload_fn(request_mock)
1423+
1424+
assert isinstance(streaming_response, StreamingResponse)
1425+
1426+
async def receive():
1427+
await asyncio.sleep(0)
1428+
return {"type": "http.disconnect"}
1429+
1430+
async def send(message):
1431+
await asyncio.sleep(0)
1432+
if (
1433+
message.get("type") == "http.response.body"
1434+
and message.get("body") == probe_chunk
1435+
):
1436+
err = "client disconnected"
1437+
raise OSError(err)
1438+
1439+
with pytest.raises(ClientDisconnect):
1440+
await streaming_response(
1441+
asgi_24_scope,
1442+
receive,
1443+
send,
1444+
)
1445+
1446+
assert enqueue_stream_delta.call_count == 0
1447+
assert form_close.await_count == 1
1448+
assert bio.closed
1449+
1450+
13821451
@pytest.mark.asyncio
13831452
@pytest.mark.parametrize(
13841453
"state",

0 commit comments

Comments
 (0)