diff --git a/reflex/event.py b/reflex/event.py index afb03367cb5..bb8898907d8 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -55,6 +55,30 @@ ) from reflex.vars.object import ObjectVar +_global_event_handlers: dict[str, EventHandler] = {} + + +def register_event_handler(name: str, handler: EventHandler) -> None: + """Register a decentralized event handler. + + Args: + name: The name of the event handler. + handler: The event handler. + """ + _global_event_handlers[name] = handler + + +def get_event_handler(name: str) -> EventHandler | None: + """Get a decentralized event handler by name. + + Args: + name: The name of the event handler. + + Returns: + The event handler, or None if not found. + """ + return _global_event_handlers.get(name) + @dataclasses.dataclass( init=True, @@ -178,7 +202,7 @@ class EventHandler(EventActionsMixin): # The full name of the state class this event handler is attached to. # Empty string means this event handler is a server side event. - state_full_name: str = dataclasses.field(default="") + state_full_name: str | None = dataclasses.field(default="") @classmethod def __class_getitem__(cls, args_spec: str) -> Annotated: @@ -261,6 +285,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> EventSpec: ) from e payload = tuple(zip(fn_args, values, strict=False)) + # Check if this is a decentralized event handler + if self.state_full_name is None: + from reflex.utils import format + + name = format.to_snake_case(self.fn.__qualname__) + register_event_handler(name, self) + # Return the event spec. return EventSpec( handler=self, args=payload, event_actions=self.event_actions.copy() @@ -496,13 +527,19 @@ def create( # If the input is a callable, create an event chain. elif isinstance(value, Callable): - result = call_event_fn(value, args_spec, key=key) - if isinstance(result, Var): - # Recursively call this function if the lambda returned an EventChain Var. - return cls.create( - value=result, args_spec=args_spec, key=key, **event_chain_kwargs - ) - events = [*result] + # Check if this is a decentralized event handler + if is_decentralized_event_handler(value): + wrapped_fn = wrap_decentralized_handler(value) + # Create an event spec directly + events = [wrapped_fn()] + else: + result = call_event_fn(value, args_spec, key=key) + if isinstance(result, Var): + # Recursively call this function if the lambda returned an EventChain Var. + return cls.create( + value=result, args_spec=args_spec, key=key, **event_chain_kwargs + ) + events = [*result] # Otherwise, raise an error. else: @@ -1299,13 +1336,14 @@ def call_event_handler( # Handle partial application of EventSpec args return event_callback.add_args(*event_spec_args) - check_fn_match_arg_spec( - event_callback.fn, - event_spec, - key, - bool(event_callback.state_full_name), - event_callback.fn.__qualname__, - ) + if not is_decentralized_event_handler(event_callback.fn): + check_fn_match_arg_spec( + event_callback.fn, + event_spec, + key, + bool(event_callback.state_full_name), + event_callback.fn.__qualname__, + ) all_acceptable_specs = ( [event_spec] if not isinstance(event_spec, Sequence) else event_spec @@ -1492,6 +1530,9 @@ def check_fn_match_arg_spec( Raises: EventFnArgMismatchError: Raised if the number of mandatory arguments do not match """ + if is_decentralized_event_handler(user_func): + return + user_args = list(inspect.signature(user_func).parameters) # Drop the first argument if it's a bound method if inspect.ismethod(user_func) and user_func.__self__ is not None: @@ -1518,6 +1559,102 @@ def check_fn_match_arg_spec( ) +DECENTRALIZED_EVENT_MARKER = "_rx_decentralized_event" + + +def is_decentralized_event_handler(fn: Callable) -> bool: + """Check if a function is a decentralized event handler. + + Args: + fn: The function to check. + + Returns: + Whether the function is a decentralized event handler. + """ + # Check if the function has been decorated with @rx.event + if not hasattr(fn, "__qualname__"): + return False + + # Check if the function has the decentralized event marker + return hasattr(fn, DECENTRALIZED_EVENT_MARKER) + + +def wrap_decentralized_handler(fn: Callable) -> Callable: + """Wrap a decentralized event handler to be used with component events. + + This creates a wrapper that doesn't require the state parameter when called + from a component event, but will pass the state when the event is processed. + + Args: + fn: The decentralized event handler to wrap. + + Returns: + A wrapped function that can be used with component events. + + Raises: + ValueError: If the event handler doesn't have at least one parameter (state). + """ + # Get the signature of the function to determine parameter names + sig = inspect.signature(fn) + param_names = list(sig.parameters.keys()) + + # The first parameter should be the state parameter + if not param_names: + raise ValueError( + f"Event handler {fn.__name__} must have at least one parameter (state)" + ) + + # Create a wrapper function that doesn't require the state parameter + def wrapper(*args, **kwargs): + # Get or create the event handler + from reflex.utils import format + + name = format.to_snake_case(fn.__qualname__) + handler = _global_event_handlers.get(name) + if handler is None: + handler = EventHandler(fn=fn, state_full_name=None) + register_event_handler(name, handler) + + # Create an event spec with the provided arguments + arg_specs = [] + + # Skip the first parameter (state) when creating arg specs + param_offset = 1 # Skip the state parameter + + for i, arg in enumerate(args): + # Create a var for the arg + var_arg = Var.create(arg) + + # Get the parameter name if available, otherwise use a generic name + param_name = ( + param_names[i + param_offset] + if i + param_offset < len(param_names) + else f"arg{i}" + ) + + # Add the arg to the arg specs + arg_specs.append((Var.create_safe(param_name), var_arg)) + + for name, arg in kwargs.items(): + # Create a var for the arg + var_arg = Var.create(arg) + + # Add the arg to the arg specs + arg_specs.append((Var.create_safe(name), var_arg)) + + return EventSpec(handler=handler, args=tuple(arg_specs)) + + wrapper.__name__ = fn.__name__ + wrapper.__qualname__ = fn.__qualname__ + wrapper.__doc__ = fn.__doc__ + wrapper.__module__ = fn.__module__ + + # Preserve the decentralized event marker + setattr(wrapper, DECENTRALIZED_EVENT_MARKER, True) + + return wrapper + + def call_event_fn( fn: Callable, arg_spec: ArgsSpec | Sequence[ArgsSpec], @@ -1543,8 +1680,18 @@ def call_event_fn( from reflex.event import EventHandler, EventSpec from reflex.utils.exceptions import EventHandlerValueError - # Check that fn signature matches arg_spec - check_fn_match_arg_spec(fn, arg_spec, key=key) + # Check if this is a decentralized event handler + if is_decentralized_event_handler(fn): + wrapped_fn = wrap_decentralized_handler(fn) + + parsed_args = parse_args_spec(arg_spec) + + # Call the wrapped function with the parsed arguments + return [wrapped_fn(*parsed_args)] + + # Check that fn signature matches arg_spec (skip for decentralized event handlers) + if not is_decentralized_event_handler(fn): + check_fn_match_arg_spec(fn, arg_spec, key=key) parsed_args = parse_args_spec(arg_spec) @@ -2066,7 +2213,23 @@ def wrapper( setattr(func, BACKGROUND_TASK_MARKER, True) if getattr(func, "__name__", "").startswith("_"): raise ValueError("Event handlers cannot be private.") - return func # pyright: ignore [reportReturnType] + + # Check if this is a method (defined in a class) or a standalone function + if hasattr(func, "__qualname__") and "." in func.__qualname__: + return func # pyright: ignore [reportReturnType] + else: + # This is a decentralized event handler + handler = EventHandler(fn=func, state_full_name=None) + if background: + setattr(handler, BACKGROUND_TASK_MARKER, True) + # Mark the function as a decentralized event handler + setattr(func, DECENTRALIZED_EVENT_MARKER, True) + + # Create a wrapped version that can handle parameters + wrapped = wrap_decentralized_handler(func) + + # Return the wrapped function instead of the original + return wrapped # pyright: ignore [reportReturnType] if func is not None: return wrapper(func) diff --git a/reflex/state.py b/reflex/state.py index 588383518b9..46b5ba08acc 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1586,7 +1586,6 @@ def _get_event_handler( Args: event: The event to get the handler for. - Returns: The event handler. @@ -1595,6 +1594,29 @@ def _get_event_handler( """ # Get the event handler. path = event.name.split(".") + + if "." not in event.name: + # Check if the event handler exists in the class's event_handlers + cls = type(self) + if event.name in cls.event_handlers: + handler = cls.event_handlers[event.name] + + # For background tasks, proxy the state + if handler.is_background: + return StateProxy(self), handler + + return self, handler + + from reflex.event import get_event_handler + + handler = get_event_handler(event.name) + if handler is not None: + # For background tasks, proxy the state + if handler.is_background: + return StateProxy(self), handler + + return self, handler + path, name = path[:-1], path[-1] substate = self.get_substate(path) if not substate: @@ -1753,7 +1775,10 @@ async def _process_event( from reflex.utils import telemetry # Get the function to process the event. - fn = functools.partial(handler.fn, state) + if handler.state_full_name is None: + fn = handler.fn + else: + fn = functools.partial(handler.fn, state) try: type_hints = typing.get_type_hints(handler.fn) diff --git a/reflex/utils/format.py b/reflex/utils/format.py index cba91b4ddb9..12619e90115 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -483,6 +483,11 @@ def format_event_handler(handler: EventHandler) -> str: Returns: The formatted function. """ + if handler.state_full_name is None: + from reflex.utils import format + + return format.to_snake_case(handler.fn.__qualname__) + state, name = get_event_handler_parts(handler) if state == "": return name diff --git a/tests/integration/test_decentralized_event_handlers.py b/tests/integration/test_decentralized_event_handlers.py new file mode 100644 index 00000000000..42f32a9a6a3 --- /dev/null +++ b/tests/integration/test_decentralized_event_handlers.py @@ -0,0 +1,129 @@ +"""Test decentralized event handlers functionality.""" + +import pytest +from selenium.webdriver.common.by import By + +from reflex.testing import AppHarness, WebDriver + + +def DecentralizedEventHandlers(): + """Test that decentralized event handlers work as expected.""" + import reflex as rx + + class TestState(rx.State): + count: int = 0 + value: int = 0 + + @rx.event + def increment(self): + """Increment the counter.""" + self.count += 1 + + @rx.event + def on_load(state: TestState): + """Event handler for loading the state. + + Args: + state: The state to modify. + """ + state.count = 10 + + @rx.event + def reset_count(state: TestState): + """Event handler for resetting the count. + + Args: + state: The state to modify. + """ + state.count = 0 + + @rx.event + def set_value(state: TestState, value: str): + """Set the value with a parameter. + + Args: + state: The state to modify. + value: The value to set. + """ + state.value = int(value) + + def index(): + return rx.vstack( + rx.heading(TestState.count, id="counter"), + rx.heading(TestState.value, id="value"), + rx.button("Increment", on_click=TestState.increment, id="increment"), + rx.button("Reset", on_click=reset_count, id="reset"), + rx.button("Set Value", on_click=set_value("42"), id="set-value"), + rx.text("Loaded", on_mount=on_load, id="loaded"), + ) + + app = rx.App() + app.add_page(index) + + +@pytest.fixture(scope="module") +def decentralized_handlers( + tmp_path_factory, +): + """Start DecentralizedEventHandlers app at tmp_path via AppHarness. + + Args: + tmp_path_factory: pytest tmp_path_factory fixture + + Yields: + running AppHarness instance + """ + with AppHarness.create( + root=tmp_path_factory.mktemp("decentralized_handlers"), + app_source=DecentralizedEventHandlers, + ) as harness: + yield harness + + +@pytest.fixture +def driver(decentralized_handlers: AppHarness): + """Get an instance of the browser open to the app. + + Args: + decentralized_handlers: harness for DecentralizedEventHandlers app + + Yields: + WebDriver instance. + """ + assert decentralized_handlers.app_instance is not None, "app is not running" + driver = decentralized_handlers.frontend() + try: + yield driver + finally: + driver.quit() + + +def test_decentralized_event_handlers( + decentralized_handlers: AppHarness, + driver: WebDriver, +): + """Test that decentralized event handlers work as expected. + + Args: + decentralized_handlers: harness for DecentralizedEventHandlers app + driver: WebDriver instance + """ + assert decentralized_handlers.app_instance is not None + + counter = driver.find_element(By.ID, "counter") + value = driver.find_element(By.ID, "value") + increment_button = driver.find_element(By.ID, "increment") + reset_button = driver.find_element(By.ID, "reset") + set_value_button = driver.find_element(By.ID, "set-value") + + assert decentralized_handlers._poll_for(lambda: counter.text == "10", timeout=5) + assert value.text == "0" + + increment_button.click() + assert decentralized_handlers._poll_for(lambda: counter.text == "11", timeout=5) + + reset_button.click() + assert decentralized_handlers._poll_for(lambda: counter.text == "0", timeout=5) + + set_value_button.click() + assert decentralized_handlers._poll_for(lambda: value.text == "42", timeout=5) diff --git a/tests/units/test_decentralized_handlers_simple.py b/tests/units/test_decentralized_handlers_simple.py new file mode 100644 index 00000000000..2d8bed41a33 --- /dev/null +++ b/tests/units/test_decentralized_handlers_simple.py @@ -0,0 +1,48 @@ +"""Simple test for decentralized event handlers.""" + +import reflex as rx + + +class TestState(rx.State): + """Test state class for decentralized event handlers.""" + + count: int = 0 + + +@rx.event +def reset_count(state: TestState): + """Reset the count to zero. + + Args: + state: The test state to modify. + """ + state.count = 0 + + +@rx.event +def set_count(state: TestState, value: str): + """Set the count to a specific value. + + Args: + state: The test state to modify. + value: The value to set as count. + """ + state.count = int(value) + + +def test_is_decentralized(): + """Test if functions are correctly identified as decentralized event handlers.""" + from reflex.event import is_decentralized_event_handler, wrap_decentralized_handler + + assert is_decentralized_event_handler(reset_count) + + wrapped = wrap_decentralized_handler(reset_count) + assert is_decentralized_event_handler(wrapped) + + assert is_decentralized_event_handler(set_count) + wrapped_with_params = wrap_decentralized_handler(set_count) + assert is_decentralized_event_handler(wrapped_with_params) + + +if __name__ == "__main__": + test_is_decentralized()