Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 181 additions & 18 deletions reflex/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,30 @@
)
from reflex.vars.object import ObjectVar

_global_event_handlers: dict[str, EventHandler] = {}


def register_event_handler(name: str, handler: EventHandler) -> None:
"""Register a decentralized event handler.

Args:
name: The name of the event handler.
handler: The event handler.
"""
_global_event_handlers[name] = handler


def get_event_handler(name: str) -> EventHandler | None:
"""Get a decentralized event handler by name.

Args:
name: The name of the event handler.

Returns:
The event handler, or None if not found.
"""
return _global_event_handlers.get(name)


@dataclasses.dataclass(
init=True,
Expand Down Expand Up @@ -178,7 +202,7 @@ class EventHandler(EventActionsMixin):

# The full name of the state class this event handler is attached to.
# Empty string means this event handler is a server side event.
state_full_name: str = dataclasses.field(default="")
state_full_name: str | None = dataclasses.field(default="")

@classmethod
def __class_getitem__(cls, args_spec: str) -> Annotated:
Expand Down Expand Up @@ -261,6 +285,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> EventSpec:
) from e
payload = tuple(zip(fn_args, values, strict=False))

# Check if this is a decentralized event handler
if self.state_full_name is None:
from reflex.utils import format

name = format.to_snake_case(self.fn.__qualname__)
register_event_handler(name, self)

# Return the event spec.
return EventSpec(
handler=self, args=payload, event_actions=self.event_actions.copy()
Expand Down Expand Up @@ -496,13 +527,19 @@ def create(

# If the input is a callable, create an event chain.
elif isinstance(value, Callable):
result = call_event_fn(value, args_spec, key=key)
if isinstance(result, Var):
# Recursively call this function if the lambda returned an EventChain Var.
return cls.create(
value=result, args_spec=args_spec, key=key, **event_chain_kwargs
)
events = [*result]
# Check if this is a decentralized event handler
if is_decentralized_event_handler(value):
wrapped_fn = wrap_decentralized_handler(value)
# Create an event spec directly
events = [wrapped_fn()]
else:
result = call_event_fn(value, args_spec, key=key)
if isinstance(result, Var):
# Recursively call this function if the lambda returned an EventChain Var.
return cls.create(
value=result, args_spec=args_spec, key=key, **event_chain_kwargs
)
events = [*result]

# Otherwise, raise an error.
else:
Expand Down Expand Up @@ -1299,13 +1336,14 @@ def call_event_handler(
# Handle partial application of EventSpec args
return event_callback.add_args(*event_spec_args)

check_fn_match_arg_spec(
event_callback.fn,
event_spec,
key,
bool(event_callback.state_full_name),
event_callback.fn.__qualname__,
)
if not is_decentralized_event_handler(event_callback.fn):
check_fn_match_arg_spec(
event_callback.fn,
event_spec,
key,
bool(event_callback.state_full_name),
event_callback.fn.__qualname__,
)

all_acceptable_specs = (
[event_spec] if not isinstance(event_spec, Sequence) else event_spec
Expand Down Expand Up @@ -1492,6 +1530,9 @@ def check_fn_match_arg_spec(
Raises:
EventFnArgMismatchError: Raised if the number of mandatory arguments do not match
"""
if is_decentralized_event_handler(user_func):
return

user_args = list(inspect.signature(user_func).parameters)
# Drop the first argument if it's a bound method
if inspect.ismethod(user_func) and user_func.__self__ is not None:
Expand All @@ -1518,6 +1559,102 @@ def check_fn_match_arg_spec(
)


DECENTRALIZED_EVENT_MARKER = "_rx_decentralized_event"


def is_decentralized_event_handler(fn: Callable) -> bool:
"""Check if a function is a decentralized event handler.

Args:
fn: The function to check.

Returns:
Whether the function is a decentralized event handler.
"""
# Check if the function has been decorated with @rx.event
if not hasattr(fn, "__qualname__"):
return False

# Check if the function has the decentralized event marker
return hasattr(fn, DECENTRALIZED_EVENT_MARKER)


def wrap_decentralized_handler(fn: Callable) -> Callable:
"""Wrap a decentralized event handler to be used with component events.

This creates a wrapper that doesn't require the state parameter when called
from a component event, but will pass the state when the event is processed.

Args:
fn: The decentralized event handler to wrap.

Returns:
A wrapped function that can be used with component events.

Raises:
ValueError: If the event handler doesn't have at least one parameter (state).
"""
# Get the signature of the function to determine parameter names
sig = inspect.signature(fn)
param_names = list(sig.parameters.keys())

# The first parameter should be the state parameter
if not param_names:
raise ValueError(
f"Event handler {fn.__name__} must have at least one parameter (state)"
)

# Create a wrapper function that doesn't require the state parameter
def wrapper(*args, **kwargs):
# Get or create the event handler
from reflex.utils import format

name = format.to_snake_case(fn.__qualname__)
handler = _global_event_handlers.get(name)
if handler is None:
handler = EventHandler(fn=fn, state_full_name=None)
register_event_handler(name, handler)

# Create an event spec with the provided arguments
arg_specs = []

# Skip the first parameter (state) when creating arg specs
param_offset = 1 # Skip the state parameter

for i, arg in enumerate(args):
# Create a var for the arg
var_arg = Var.create(arg)

# Get the parameter name if available, otherwise use a generic name
param_name = (
param_names[i + param_offset]
if i + param_offset < len(param_names)
else f"arg{i}"
)

# Add the arg to the arg specs
arg_specs.append((Var.create_safe(param_name), var_arg))

for name, arg in kwargs.items():
# Create a var for the arg
var_arg = Var.create(arg)

# Add the arg to the arg specs
arg_specs.append((Var.create_safe(name), var_arg))

return EventSpec(handler=handler, args=tuple(arg_specs))

wrapper.__name__ = fn.__name__
wrapper.__qualname__ = fn.__qualname__
wrapper.__doc__ = fn.__doc__
wrapper.__module__ = fn.__module__

# Preserve the decentralized event marker
setattr(wrapper, DECENTRALIZED_EVENT_MARKER, True)

return wrapper


def call_event_fn(
fn: Callable,
arg_spec: ArgsSpec | Sequence[ArgsSpec],
Expand All @@ -1543,8 +1680,18 @@ def call_event_fn(
from reflex.event import EventHandler, EventSpec
from reflex.utils.exceptions import EventHandlerValueError

# Check that fn signature matches arg_spec
check_fn_match_arg_spec(fn, arg_spec, key=key)
# Check if this is a decentralized event handler
if is_decentralized_event_handler(fn):
wrapped_fn = wrap_decentralized_handler(fn)

parsed_args = parse_args_spec(arg_spec)

# Call the wrapped function with the parsed arguments
return [wrapped_fn(*parsed_args)]

# Check that fn signature matches arg_spec (skip for decentralized event handlers)
if not is_decentralized_event_handler(fn):
check_fn_match_arg_spec(fn, arg_spec, key=key)

parsed_args = parse_args_spec(arg_spec)

Expand Down Expand Up @@ -2066,7 +2213,23 @@ def wrapper(
setattr(func, BACKGROUND_TASK_MARKER, True)
if getattr(func, "__name__", "").startswith("_"):
raise ValueError("Event handlers cannot be private.")
return func # pyright: ignore [reportReturnType]

# Check if this is a method (defined in a class) or a standalone function
if hasattr(func, "__qualname__") and "." in func.__qualname__:
return func # pyright: ignore [reportReturnType]
else:
# This is a decentralized event handler
handler = EventHandler(fn=func, state_full_name=None)
if background:
setattr(handler, BACKGROUND_TASK_MARKER, True)
# Mark the function as a decentralized event handler
setattr(func, DECENTRALIZED_EVENT_MARKER, True)

# Create a wrapped version that can handle parameters
wrapped = wrap_decentralized_handler(func)

# Return the wrapped function instead of the original
return wrapped # pyright: ignore [reportReturnType]

if func is not None:
return wrapper(func)
Expand Down
29 changes: 27 additions & 2 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1586,7 +1586,6 @@ def _get_event_handler(
Args:
event: The event to get the handler for.


Returns:
The event handler.

Expand All @@ -1595,6 +1594,29 @@ def _get_event_handler(
"""
# Get the event handler.
path = event.name.split(".")

if "." not in event.name:
# Check if the event handler exists in the class's event_handlers
cls = type(self)
if event.name in cls.event_handlers:
handler = cls.event_handlers[event.name]

# For background tasks, proxy the state
if handler.is_background:
return StateProxy(self), handler

return self, handler

from reflex.event import get_event_handler

handler = get_event_handler(event.name)
if handler is not None:
# For background tasks, proxy the state
if handler.is_background:
return StateProxy(self), handler

return self, handler

path, name = path[:-1], path[-1]
substate = self.get_substate(path)
if not substate:
Expand Down Expand Up @@ -1753,7 +1775,10 @@ async def _process_event(
from reflex.utils import telemetry

# Get the function to process the event.
fn = functools.partial(handler.fn, state)
if handler.state_full_name is None:
fn = handler.fn
else:
fn = functools.partial(handler.fn, state)

try:
type_hints = typing.get_type_hints(handler.fn)
Expand Down
5 changes: 5 additions & 0 deletions reflex/utils/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,11 @@ def format_event_handler(handler: EventHandler) -> str:
Returns:
The formatted function.
"""
if handler.state_full_name is None:
from reflex.utils import format

return format.to_snake_case(handler.fn.__qualname__)

state, name = get_event_handler_parts(handler)
if state == "":
return name
Expand Down
Loading
Loading