From 2c3ca72c205ee21c9aa111b20da65cb235e0d455 Mon Sep 17 00:00:00 2001 From: Lendemor Date: Tue, 15 Jul 2025 19:00:24 +0200 Subject: [PATCH] add event_action flags to rx.event decorator --- reflex/event.py | 75 +++++++++++++++++-- reflex/state.py | 7 +- tests/units/test_event.py | 152 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 226 insertions(+), 8 deletions(-) diff --git a/reflex/event.py b/reflex/event.py index f0dc035c282..07da9f90bf9 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -109,7 +109,7 @@ def stop_propagation(self) -> Self: """ return dataclasses.replace( self, - event_actions={"stopPropagation": True, **self.event_actions}, + event_actions={**self.event_actions, "stopPropagation": True}, ) @property @@ -121,7 +121,7 @@ def prevent_default(self) -> Self: """ return dataclasses.replace( self, - event_actions={"preventDefault": True, **self.event_actions}, + event_actions={**self.event_actions, "preventDefault": True}, ) def throttle(self, limit_ms: int) -> Self: @@ -135,7 +135,7 @@ def throttle(self, limit_ms: int) -> Self: """ return dataclasses.replace( self, - event_actions={"throttle": limit_ms, **self.event_actions}, + event_actions={**self.event_actions, "throttle": limit_ms}, ) def debounce(self, delay_ms: int) -> Self: @@ -149,7 +149,7 @@ def debounce(self, delay_ms: int) -> Self: """ return dataclasses.replace( self, - event_actions={"debounce": delay_ms, **self.event_actions}, + event_actions={**self.event_actions, "debounce": delay_ms}, ) @property @@ -161,7 +161,7 @@ def temporal(self) -> Self: """ return dataclasses.replace( self, - event_actions={"temporal": True, **self.event_actions}, + event_actions={**self.event_actions, "temporal": True}, ) @@ -2211,7 +2211,15 @@ class EventNamespace: @overload def __new__( - cls, func: None = None, *, background: bool | None = None + cls, + func: None = None, + *, + background: bool | None = None, + stop_propagation: bool | None = None, + prevent_default: bool | None = None, + throttle: int | None = None, + debounce: int | None = None, + temporal: bool | None = None, ) -> Callable[ [Callable[[BASE_STATE, Unpack[P]], Any]], EventCallback[Unpack[P]] # pyright: ignore [reportInvalidTypeVarUse] ]: ... @@ -2222,6 +2230,11 @@ def __new__( func: Callable[[BASE_STATE, Unpack[P]], Any], *, background: bool | None = None, + stop_propagation: bool | None = None, + prevent_default: bool | None = None, + throttle: int | None = None, + debounce: int | None = None, + temporal: bool | None = None, ) -> EventCallback[Unpack[P]]: ... def __new__( @@ -2229,6 +2242,11 @@ def __new__( func: Callable[[BASE_STATE, Unpack[P]], Any] | None = None, *, background: bool | None = None, + stop_propagation: bool | None = None, + prevent_default: bool | None = None, + throttle: int | None = None, + debounce: int | None = None, + temporal: bool | None = None, ) -> ( EventCallback[Unpack[P]] | Callable[[Callable[[BASE_STATE, Unpack[P]], Any]], EventCallback[Unpack[P]]] @@ -2238,6 +2256,11 @@ def __new__( Args: func: The function to wrap. background: Whether the event should be run in the background. Defaults to False. + stop_propagation: Whether to stop the event from bubbling up the DOM tree. + prevent_default: Whether to prevent the default behavior of the event. + throttle: Throttle the event handler to limit calls (in milliseconds). + debounce: Debounce the event handler to delay calls (in milliseconds). + temporal: Whether the event should be temporal. Raises: TypeError: If background is True and the function is not a coroutine or async generator. # noqa: DAR402 @@ -2246,6 +2269,30 @@ def __new__( The wrapped function. """ + def _build_event_actions(): + """Build event_actions dict from decorator parameters. + + Returns: + Dict of event actions to apply, or empty dict if none specified. + """ + if not any( + [stop_propagation, prevent_default, throttle, debounce, temporal] + ): + return {} + + event_actions = {} + if stop_propagation is not None: + event_actions["stopPropagation"] = stop_propagation + if prevent_default is not None: + event_actions["preventDefault"] = prevent_default + if throttle is not None: + event_actions["throttle"] = throttle + if debounce is not None: + event_actions["debounce"] = debounce + if temporal is not None: + event_actions["temporal"] = temporal + return event_actions + def wrapper( func: Callable[[BASE_STATE, Unpack[P]], T], ) -> EventCallback[Unpack[P]]: @@ -2281,8 +2328,22 @@ def wrapper( object.__setattr__(func, "__name__", name) object.__setattr__(func, "__qualname__", name) state_cls._add_event_handler(name, func) - return getattr(state_cls, name) + event_callback = getattr(state_cls, name) + + # Apply decorator event actions + event_actions = _build_event_actions() + if event_actions: + # Create new EventCallback with updated event_actions + event_callback = dataclasses.replace( + event_callback, event_actions=event_actions + ) + + return event_callback + # Store decorator event actions on the function for later processing + event_actions = _build_event_actions() + if event_actions: + func._rx_event_actions = event_actions # pyright: ignore [reportFunctionMemberAccess] return func # pyright: ignore [reportReturnType] if func is not None: diff --git a/reflex/state.py b/reflex/state.py index 92f6744138c..aa269581d95 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1100,7 +1100,12 @@ def _create_event_handler(cls, fn: Any): Returns: The event handler. """ - return EventHandler(fn=fn, state_full_name=cls.get_full_name()) + # Check if function has stored event_actions from decorator + event_actions = getattr(fn, "_rx_event_actions", {}) + + return EventHandler( + fn=fn, state_full_name=cls.get_full_name(), event_actions=event_actions + ) @classmethod def _create_setvar(cls): diff --git a/tests/units/test_event.py b/tests/units/test_event.py index 3f101a4d51f..7731e49e18b 100644 --- a/tests/units/test_event.py +++ b/tests/units/test_event.py @@ -5,6 +5,7 @@ import reflex as rx from reflex.constants.compiler import Hooks, Imports from reflex.event import ( + BACKGROUND_TASK_MARKER, Event, EventChain, EventHandler, @@ -493,6 +494,157 @@ def get_handler(self, arg: Var[str]): _ = rx.input(on_change=w.get_handler) +def test_event_decorator_with_event_actions(): + """Test that @rx.event decorator can accept event action parameters.""" + + class MyTestState(BaseState): + # Test individual event actions + @event(stop_propagation=True) + def handle_stop_prop(self): + pass + + @event(prevent_default=True) + def handle_prevent_default(self): + pass + + @event(throttle=500) + def handle_throttle(self): + pass + + @event(debounce=300) + def handle_debounce(self): + pass + + @event(temporal=True) + def handle_temporal(self): + pass + + # Test multiple event actions combined + @event(stop_propagation=True, prevent_default=True, throttle=1000) + def handle_multiple(self): + pass + + # Test with background parameter (existing functionality) + @event(background=True, temporal=True) + async def handle_background_temporal(self): + pass + + # Test no event actions (existing behavior) + @event + def handle_no_actions(self): + pass + + # Test individual event actions are applied + stop_prop_handler = MyTestState.handle_stop_prop + assert isinstance(stop_prop_handler, EventHandler) + assert stop_prop_handler.event_actions == {"stopPropagation": True} + + prevent_default_handler = MyTestState.handle_prevent_default + assert prevent_default_handler.event_actions == {"preventDefault": True} + + throttle_handler = MyTestState.handle_throttle + assert throttle_handler.event_actions == {"throttle": 500} + + debounce_handler = MyTestState.handle_debounce + assert debounce_handler.event_actions == {"debounce": 300} + + temporal_handler = MyTestState.handle_temporal + assert temporal_handler.event_actions == {"temporal": True} + + # Test multiple event actions are combined correctly + multiple_handler = MyTestState.handle_multiple + assert multiple_handler.event_actions == { + "stopPropagation": True, + "preventDefault": True, + "throttle": 1000, + } + + # Test background + event actions work together + bg_temporal_handler = MyTestState.handle_background_temporal + assert bg_temporal_handler.event_actions == {"temporal": True} + assert hasattr(bg_temporal_handler.fn, BACKGROUND_TASK_MARKER) # pyright: ignore [reportAttributeAccessIssue] + + # Test no event actions (existing behavior preserved) + no_actions_handler = MyTestState.handle_no_actions + assert no_actions_handler.event_actions == {} + + +def test_event_decorator_actions_can_be_overridden(): + """Test that decorator event actions can still be overridden by chaining.""" + + class MyTestState(BaseState): + @event(throttle=500, stop_propagation=True) + def handle_with_defaults(self): + pass + + # Get the handler with default actions + handler = MyTestState.handle_with_defaults + assert handler.event_actions == {"throttle": 500, "stopPropagation": True} + + # Chain additional actions - should combine + handler_with_prevent_default = handler.prevent_default + assert handler_with_prevent_default.event_actions == { + "throttle": 500, + "stopPropagation": True, + "preventDefault": True, + } + + # Chain throttle with different value - should override + handler_with_new_throttle = handler.throttle(1000) + assert handler_with_new_throttle.event_actions == { + "throttle": 1000, # New value overrides default + "stopPropagation": True, + } + + # Original handler should be unchanged + assert handler.event_actions == {"throttle": 500, "stopPropagation": True} + + +def test_event_decorator_with_none_values(): + """Test that None values in decorator don't create event actions.""" + + class MyTestState(BaseState): + @event(stop_propagation=None, prevent_default=None, throttle=None) + def handle_all_none(self): + pass + + @event(stop_propagation=True, prevent_default=None, throttle=500, debounce=None) + def handle_mixed(self): + pass + + # All None should result in no event actions + all_none_handler = MyTestState.handle_all_none + assert all_none_handler.event_actions == {} + + # Only non-None values should be included + mixed_handler = MyTestState.handle_mixed + assert mixed_handler.event_actions == {"stopPropagation": True, "throttle": 500} + + +def test_event_decorator_backward_compatibility(): + """Test that existing code without event action parameters continues to work.""" + + class MyTestState(BaseState): + @event + def handle_old_style(self): + pass + + @event(background=True) + async def handle_old_background(self): + pass + + # Old style without parameters should work unchanged + old_handler = MyTestState.handle_old_style + assert isinstance(old_handler, EventHandler) + assert old_handler.event_actions == {} + assert not hasattr(old_handler.fn, BACKGROUND_TASK_MARKER) # pyright: ignore [reportAttributeAccessIssue] + + # Old background parameter should work unchanged + bg_handler = MyTestState.handle_old_background + assert bg_handler.event_actions == {} + assert hasattr(bg_handler.fn, BACKGROUND_TASK_MARKER) # pyright: ignore [reportAttributeAccessIssue] + + def test_event_var_in_rx_cond(): """Test that EventVar and EventChainVar cannot be used in rx.cond().""" from reflex.components.core.cond import cond as rx_cond