Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import io
import json
import sys
import time
import traceback
import urllib.parse
from collections.abc import (
Expand Down Expand Up @@ -1582,7 +1583,10 @@ async def _coro():
sid=state.router.session.session_id,
)

task = asyncio.create_task(_coro())
task = asyncio.create_task(
_coro(),
name=f"reflex_background_task|{event.name}|{time.time()}|{event.token}",
)
self._background_tasks.add(task)
# Clean up task from background_tasks set when complete.
task.add_done_callback(self._background_tasks.discard)
Expand Down Expand Up @@ -1727,7 +1731,8 @@ async def process(
"reload",
data=event,
to=sid,
)
),
name=f"reflex_emit_reload|{event.name}|{time.time()}|{event.token}",
)
return
# re-assign only when the value is different
Expand Down Expand Up @@ -2028,7 +2033,8 @@ def on_disconnect(self, sid: str):
if disconnect_token:
# Use async cleanup through token manager
task = asyncio.create_task(
self._token_manager.disconnect_token(disconnect_token, sid)
self._token_manager.disconnect_token(disconnect_token, sid),
name=f"reflex_disconnect_token|{disconnect_token}|{time.time()}",
)
# Don't await to avoid blocking disconnect, but handle potential errors
task.add_done_callback(
Expand All @@ -2047,12 +2053,14 @@ async def emit_update(self, update: StateUpdate, sid: str) -> None:
# If the sid is None, we are not connected to a client. Prevent sending
# updates to all clients.
return
if sid not in self.sid_to_token:
token = self.sid_to_token.get(sid)
if token is None:
console.warn(f"Attempting to send delta to disconnected websocket {sid}")
return
# Creating a task prevents the update from being blocked behind other coroutines.
await asyncio.create_task(
self.emit(str(constants.SocketEvent.EVENT), update, to=sid)
self.emit(str(constants.SocketEvent.EVENT), update, to=sid),
name=f"reflex_emit_event|{token}|{sid}|{time.time()}",
)

async def on_event(self, sid: str, data: Any):
Expand Down
10 changes: 8 additions & 2 deletions reflex/app_mixins/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import dataclasses
import functools
import inspect
import time
from collections.abc import Callable, Coroutine

from starlette.applications import Starlette
Expand Down Expand Up @@ -36,6 +37,7 @@ async def _run_lifespan_tasks(self, app: Starlette):
if isinstance(task, asyncio.Task):
running_tasks.append(task)
else:
task_name = task.__name__
signature = inspect.signature(task)
if "app" in signature.parameters:
task = functools.partial(task, app=app)
Expand All @@ -44,7 +46,10 @@ async def _run_lifespan_tasks(self, app: Starlette):
await stack.enter_async_context(_t)
console.debug(run_msg.format(type="asynccontextmanager"))
elif isinstance(_t, Coroutine):
task_ = asyncio.create_task(_t)
task_ = asyncio.create_task(
_t,
name=f"reflex_lifespan_task|{task_name}|{time.time()}",
)
task_.add_done_callback(lambda t: t.result())
running_tasks.append(task_)
console.debug(run_msg.format(type="coroutine"))
Expand All @@ -70,9 +75,10 @@ def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
msg = f"Task {task.__name__} of type generator must be decorated with contextlib.asynccontextmanager."
raise InvalidLifespanTaskTypeError(msg)

task_name = task.__name__ # pyright: ignore [reportAttributeAccessIssue]
if task_kwargs:
original_task = task
task = functools.partial(task, **task_kwargs) # pyright: ignore [reportArgumentType]
functools.update_wrapper(task, original_task) # pyright: ignore [reportArgumentType]
self.lifespan_tasks.add(task)
console.debug(f"Registered lifespan task: {task.__name__}") # pyright: ignore [reportAttributeAccessIssue]
console.debug(f"Registered lifespan task: {task_name}")
3 changes: 2 additions & 1 deletion reflex/istate/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,8 @@ async def set_state(
_substate_key(client_token, substate),
substate,
lock_id,
)
),
name=f"reflex_set_state|{client_token}|{substate.get_full_name()}",
)
for substate in state.substates.values()
]
Expand Down
6 changes: 5 additions & 1 deletion reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import inspect
import pickle
import sys
import time
import typing
import warnings
from collections.abc import AsyncIterator, Callable, Sequence
Expand Down Expand Up @@ -284,7 +285,10 @@ async def _resolve_delta(delta: Delta) -> Delta:
for state_name, state_delta in delta.items():
for var_name, value in state_delta.items():
if asyncio.iscoroutine(value):
tasks[state_name, var_name] = asyncio.create_task(value)
tasks[state_name, var_name] = asyncio.create_task(
value,
name=f"reflex_resolve_delta|{state_name}|{var_name}|{time.time()}",
)
for (state_name, var_name), task in tasks.items():
delta[state_name][var_name] = await task
return delta
Expand Down
5 changes: 4 additions & 1 deletion reflex/utils/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,10 @@ async def async_send(event: str, telemetry_enabled: bool | None, **kwargs):

try:
# Within an event loop context, send the event asynchronously.
task = asyncio.create_task(async_send(event, telemetry_enabled, **kwargs))
task = asyncio.create_task(
async_send(event, telemetry_enabled, **kwargs),
name=f"reflex_send_telemetry_event|{event}",
)
background_tasks.add(task)
task.add_done_callback(background_tasks.discard)
except RuntimeError:
Expand Down
Loading