Skip to content

Commit af85fa2

Browse files
committed
fixed buffer upload by shortcircuiting
1 parent 3dc14f1 commit af85fa2

4 files changed

Lines changed: 182 additions & 3 deletions

File tree

packages/reflex-base/src/reflex_base/event/processor/event_processor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ async def enqueue_stream_delta(
380380
self,
381381
token: str,
382382
event: Event,
383+
on_task_future: Callable[[EventFuture], None] | None = None,
383384
) -> AsyncGenerator[Mapping[str, Any]]:
384385
"""Enqueue an event to be processed and yield deltas emitted by the event handler.
385386
@@ -393,6 +394,8 @@ async def enqueue_stream_delta(
393394
Args:
394395
token: The client token associated with the event.
395396
event: The event to be enqueued.
397+
on_task_future: Optional callback invoked with the EventFuture for the
398+
enqueued handler as soon as it is created.
396399
397400
Yields:
398401
Deltas emitted by the event handler for the specified token.
@@ -425,6 +428,8 @@ async def _emit_delta_impl(
425428
emit_delta_impl=_emit_delta_impl,
426429
),
427430
)
431+
if on_task_future is not None:
432+
on_task_future(task_future)
428433
all_task_futures = asyncio.create_task(task_future.wait_all())
429434
waiting_for = {all_task_futures, asyncio.create_task(deltas.get())}
430435
try:

packages/reflex-components-core/src/reflex_components_core/core/_upload.py

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
from typing_extensions import Self
2323

2424
if TYPE_CHECKING:
25-
from reflex_base.utils.types import Receive, Scope, Send
25+
from reflex_base.event.processor import EventFuture
26+
from reflex_base.utils.types import Message, Receive, Scope, Send
2627

2728
from reflex.app import App
2829

@@ -403,20 +404,70 @@ class _UploadStreamingResponse(StreamingResponse):
403404
"""Streaming response that always releases upload form resources."""
404405

405406
_on_finish: Callable[[], Awaitable[None]]
407+
_on_disconnect: Callable[[], None] | None
408+
_disconnect_handled: bool
406409

407410
def __init__(
408411
self,
409412
*args: Any,
410413
on_finish: Callable[[], Awaitable[None]],
414+
on_disconnect: Callable[[], None] | None = None,
411415
**kwargs: Any,
412416
) -> None:
413417
super().__init__(*args, **kwargs)
414418
self._on_finish = on_finish
419+
self._on_disconnect = on_disconnect
420+
self._disconnect_handled = False
421+
422+
def _handle_disconnect(self) -> None:
423+
"""Run disconnect cleanup exactly once."""
424+
if self._disconnect_handled or self._on_disconnect is None:
425+
return
426+
self._disconnect_handled = True
427+
self._on_disconnect()
428+
429+
async def _watch_disconnect(self, receive: Receive) -> None:
430+
"""Wait for the client connection to close."""
431+
while True:
432+
message = await receive()
433+
if message["type"] == "http.disconnect":
434+
self._handle_disconnect()
435+
return
415436

416437
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
438+
spec_version = tuple(
439+
map(int, scope.get("asgi", {}).get("spec_version", "2.0").split("."))
440+
)
441+
disconnect_task: asyncio.Task[None] | None = None
442+
use_watcher = spec_version >= (2, 4) and self._on_disconnect is not None
443+
444+
async def wrapped_receive() -> Message:
445+
message = await receive()
446+
if message.get("type") == "http.disconnect":
447+
self._handle_disconnect()
448+
return message
449+
417450
try:
418-
await super().__call__(scope, receive, send)
451+
if use_watcher:
452+
# ASGI >= 2.4: use a dedicated task to watch for disconnect
453+
# concurrently. Pass raw `receive` to Starlette — the watcher
454+
# owns disconnect detection; using wrapped_receive here would
455+
# race on the same receive callable.
456+
disconnect_task = asyncio.create_task(self._watch_disconnect(receive))
457+
try:
458+
await super().__call__(
459+
scope,
460+
wrapped_receive if not use_watcher else receive,
461+
send,
462+
)
463+
except ClientDisconnect:
464+
self._handle_disconnect()
465+
raise
419466
finally:
467+
if disconnect_task is not None:
468+
disconnect_task.cancel()
469+
with contextlib.suppress(asyncio.CancelledError):
470+
await disconnect_task
420471
await self._on_finish()
421472

422473

@@ -515,20 +566,46 @@ def _create_upload_event() -> Event:
515566
msg = "Upload event was not created."
516567
raise RuntimeError(msg)
517568

569+
task_future: EventFuture | None = None
570+
disconnect_seen = False
571+
572+
def _try_cancel() -> None:
573+
"""Cancel the task future if it exists and is still running."""
574+
if task_future is not None and not task_future.done():
575+
task_future.cancel()
576+
577+
def _remember_task_future(future: EventFuture) -> None:
578+
"""Keep a handle to the upload task for disconnect cancellation."""
579+
nonlocal task_future
580+
task_future = future
581+
if disconnect_seen:
582+
_try_cancel()
583+
584+
def _cancel_upload_task() -> None:
585+
"""Cancel the queued upload handler when the client disconnects."""
586+
nonlocal disconnect_seen
587+
disconnect_seen = True
588+
_try_cancel()
589+
518590
async def _ndjson_updates():
519591
"""Process the upload event, generating ndjson updates.
520592
521593
Yields:
522594
Each state update as newline-delimited JSON.
523595
"""
524596
# Enqueue the task on the main event loop, but emit deltas to the local queue.
525-
async for delta in app.event_processor.enqueue_stream_delta(token, event):
597+
async for delta in app.event_processor.enqueue_stream_delta(
598+
token,
599+
event,
600+
on_task_future=_remember_task_future,
601+
):
526602
yield json_dumps(StateUpdate(delta=delta)) + "\n"
527603

528604
return _UploadStreamingResponse(
529605
_ndjson_updates(),
530606
media_type="application/x-ndjson",
531607
on_finish=_close_form_data,
608+
on_disconnect=_cancel_upload_task,
532609
)
533610

534611

tests/units/reflex_base/event/processor/test_event_processor.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,30 @@ async def test_stream_delta_not_configured_raises():
518518
pass
519519

520520

521+
async def test_stream_delta_calls_on_task_future(token: str):
522+
"""enqueue_stream_delta exposes the tracked EventFuture immediately.
523+
524+
Args:
525+
token: The client token.
526+
"""
527+
ep = EventProcessor(graceful_shutdown_timeout=2)
528+
ep.configure()
529+
captured = []
530+
async with ep:
531+
event = Event.from_event_type(noop_event())[0]
532+
collected = [
533+
d
534+
async for d in ep.enqueue_stream_delta(
535+
token,
536+
event,
537+
on_task_future=captured.append,
538+
)
539+
]
540+
assert collected == []
541+
assert len(captured) == 1
542+
assert captured[0].done()
543+
544+
521545
async def test_sequential_chained_events_run_in_order(token: str):
522546
"""Chained events enqueued by a handler run in the order they were enqueued.
523547

tests/units/test_app.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,6 +1306,79 @@ async def send(_message):
13061306
assert bio.closed
13071307

13081308

1309+
@pytest.mark.asyncio
1310+
async def test_upload_file_cancels_buffered_handler_on_disconnect_before_future_capture(
1311+
token: str,
1312+
):
1313+
"""Buffered uploads cancel the handler even if disconnect wins the race.
1314+
1315+
This exercises the ASGI 2.4 path where the response must watch
1316+
``receive()`` directly because Starlette does not listen for disconnects
1317+
while streaming the response body.
1318+
1319+
Args:
1320+
token: A token.
1321+
"""
1322+
request_mock = unittest.mock.Mock()
1323+
request_mock.headers = {
1324+
"reflex-client-token": token,
1325+
"reflex-event-handler": f"{FileUploadState.get_full_name()}.multi_handle_upload",
1326+
}
1327+
1328+
bio = io.BytesIO(b"contents of image one")
1329+
file1 = UploadFile(filename="image1.jpg", file=bio)
1330+
form_data = FormData([("files", file1)])
1331+
original_close = form_data.close
1332+
form_close = AsyncMock(side_effect=original_close)
1333+
form_data.close = form_close
1334+
1335+
async def form(): # noqa: RUF029
1336+
return form_data
1337+
1338+
request_mock.form = form
1339+
1340+
cancelled = asyncio.Event()
1341+
task_future = Mock()
1342+
task_future.done = Mock(side_effect=cancelled.is_set)
1343+
task_future.cancel = Mock(side_effect=cancelled.set)
1344+
1345+
async def enqueue_stream_delta(_token, _event, on_task_future=None):
1346+
assert on_task_future is not None
1347+
on_task_future(task_future)
1348+
await cancelled.wait()
1349+
if False: # pragma: no cover
1350+
yield {}
1351+
1352+
app = Mock(
1353+
event_processor=Mock(enqueue_stream_delta=enqueue_stream_delta),
1354+
)
1355+
1356+
upload_fn = upload(app)
1357+
streaming_response = await upload_fn(request_mock)
1358+
1359+
assert isinstance(streaming_response, StreamingResponse)
1360+
1361+
async def receive():
1362+
await asyncio.sleep(0)
1363+
return {"type": "http.disconnect"}
1364+
1365+
async def send(_message): # noqa: RUF029
1366+
return None
1367+
1368+
await asyncio.wait_for(
1369+
streaming_response(
1370+
{"type": "http", "asgi": {"spec_version": "2.4"}},
1371+
receive,
1372+
send,
1373+
),
1374+
timeout=1,
1375+
)
1376+
1377+
assert task_future.cancel.call_count == 1
1378+
assert form_close.await_count == 1
1379+
assert bio.closed
1380+
1381+
13091382
@pytest.mark.asyncio
13101383
@pytest.mark.parametrize(
13111384
"state",

0 commit comments

Comments
 (0)