Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 68 additions & 7 deletions reflex/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def stop_propagation(self) -> Self:
"""
return dataclasses.replace(
self,
event_actions={"stopPropagation": True, **self.event_actions},
event_actions={**self.event_actions, "stopPropagation": True},
)

@property
Expand All @@ -121,7 +121,7 @@ def prevent_default(self) -> Self:
"""
return dataclasses.replace(
self,
event_actions={"preventDefault": True, **self.event_actions},
event_actions={**self.event_actions, "preventDefault": True},
)

def throttle(self, limit_ms: int) -> Self:
Expand All @@ -135,7 +135,7 @@ def throttle(self, limit_ms: int) -> Self:
"""
return dataclasses.replace(
self,
event_actions={"throttle": limit_ms, **self.event_actions},
event_actions={**self.event_actions, "throttle": limit_ms},
)

def debounce(self, delay_ms: int) -> Self:
Expand All @@ -149,7 +149,7 @@ def debounce(self, delay_ms: int) -> Self:
"""
return dataclasses.replace(
self,
event_actions={"debounce": delay_ms, **self.event_actions},
event_actions={**self.event_actions, "debounce": delay_ms},
)

@property
Expand All @@ -161,7 +161,7 @@ def temporal(self) -> Self:
"""
return dataclasses.replace(
self,
event_actions={"temporal": True, **self.event_actions},
event_actions={**self.event_actions, "temporal": True},
)


Expand Down Expand Up @@ -2211,7 +2211,15 @@ class EventNamespace:

@overload
def __new__(
cls, func: None = None, *, background: bool | None = None
cls,
func: None = None,
*,
background: bool | None = None,
stop_propagation: bool | None = None,
prevent_default: bool | None = None,
throttle: int | None = None,
debounce: int | None = None,
temporal: bool | None = None,
) -> Callable[
[Callable[[BASE_STATE, Unpack[P]], Any]], EventCallback[Unpack[P]] # pyright: ignore [reportInvalidTypeVarUse]
]: ...
Expand All @@ -2222,13 +2230,23 @@ def __new__(
func: Callable[[BASE_STATE, Unpack[P]], Any],
*,
background: bool | None = None,
stop_propagation: bool | None = None,
prevent_default: bool | None = None,
throttle: int | None = None,
debounce: int | None = None,
temporal: bool | None = None,
) -> EventCallback[Unpack[P]]: ...

def __new__(
cls,
func: Callable[[BASE_STATE, Unpack[P]], Any] | None = None,
*,
background: bool | None = None,
stop_propagation: bool | None = None,
prevent_default: bool | None = None,
throttle: int | None = None,
debounce: int | None = None,
temporal: bool | None = None,
) -> (
EventCallback[Unpack[P]]
| Callable[[Callable[[BASE_STATE, Unpack[P]], Any]], EventCallback[Unpack[P]]]
Expand All @@ -2238,6 +2256,11 @@ def __new__(
Args:
func: The function to wrap.
background: Whether the event should be run in the background. Defaults to False.
stop_propagation: Whether to stop the event from bubbling up the DOM tree.
prevent_default: Whether to prevent the default behavior of the event.
throttle: Throttle the event handler to limit calls (in milliseconds).
debounce: Debounce the event handler to delay calls (in milliseconds).
temporal: Whether the event should be temporal.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
temporal: Whether the event should be temporal.
temporal: Whether the event should be dropped when the backend is down.


Raises:
TypeError: If background is True and the function is not a coroutine or async generator. # noqa: DAR402
Expand All @@ -2246,6 +2269,30 @@ def __new__(
The wrapped function.
"""

def _build_event_actions():
"""Build event_actions dict from decorator parameters.

Returns:
Dict of event actions to apply, or empty dict if none specified.
"""
if not any(
[stop_propagation, prevent_default, throttle, debounce, temporal]
):
return {}

event_actions = {}
if stop_propagation is not None:
event_actions["stopPropagation"] = stop_propagation
if prevent_default is not None:
event_actions["preventDefault"] = prevent_default
if throttle is not None:
event_actions["throttle"] = throttle
if debounce is not None:
event_actions["debounce"] = debounce
if temporal is not None:
event_actions["temporal"] = temporal
return event_actions

def wrapper(
func: Callable[[BASE_STATE, Unpack[P]], T],
) -> EventCallback[Unpack[P]]:
Expand Down Expand Up @@ -2281,8 +2328,22 @@ def wrapper(
object.__setattr__(func, "__name__", name)
object.__setattr__(func, "__qualname__", name)
state_cls._add_event_handler(name, func)
return getattr(state_cls, name)
event_callback = getattr(state_cls, name)

# Apply decorator event actions
event_actions = _build_event_actions()
if event_actions:
# Create new EventCallback with updated event_actions
event_callback = dataclasses.replace(
event_callback, event_actions=event_actions
)

return event_callback

# Store decorator event actions on the function for later processing
event_actions = _build_event_actions()
if event_actions:
func._rx_event_actions = event_actions # pyright: ignore [reportFunctionMemberAccess]
return func # pyright: ignore [reportReturnType]

if func is not None:
Expand Down
7 changes: 6 additions & 1 deletion reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,7 +1100,12 @@ def _create_event_handler(cls, fn: Any):
Returns:
The event handler.
"""
return EventHandler(fn=fn, state_full_name=cls.get_full_name())
# Check if function has stored event_actions from decorator
event_actions = getattr(fn, "_rx_event_actions", {})

return EventHandler(
fn=fn, state_full_name=cls.get_full_name(), event_actions=event_actions
)

@classmethod
def _create_setvar(cls):
Expand Down
152 changes: 152 additions & 0 deletions tests/units/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import reflex as rx
from reflex.constants.compiler import Hooks, Imports
from reflex.event import (
BACKGROUND_TASK_MARKER,
Event,
EventChain,
EventHandler,
Expand Down Expand Up @@ -493,6 +494,157 @@ def get_handler(self, arg: Var[str]):
_ = rx.input(on_change=w.get_handler)


def test_event_decorator_with_event_actions():
"""Test that @rx.event decorator can accept event action parameters."""

class MyTestState(BaseState):
# Test individual event actions
@event(stop_propagation=True)
def handle_stop_prop(self):
pass

@event(prevent_default=True)
def handle_prevent_default(self):
pass

@event(throttle=500)
def handle_throttle(self):
pass

@event(debounce=300)
def handle_debounce(self):
pass

@event(temporal=True)
def handle_temporal(self):
pass

# Test multiple event actions combined
@event(stop_propagation=True, prevent_default=True, throttle=1000)
def handle_multiple(self):
pass

# Test with background parameter (existing functionality)
@event(background=True, temporal=True)
async def handle_background_temporal(self):
pass

# Test no event actions (existing behavior)
@event
def handle_no_actions(self):
pass

# Test individual event actions are applied
stop_prop_handler = MyTestState.handle_stop_prop
assert isinstance(stop_prop_handler, EventHandler)
assert stop_prop_handler.event_actions == {"stopPropagation": True}

prevent_default_handler = MyTestState.handle_prevent_default
assert prevent_default_handler.event_actions == {"preventDefault": True}

throttle_handler = MyTestState.handle_throttle
assert throttle_handler.event_actions == {"throttle": 500}

debounce_handler = MyTestState.handle_debounce
assert debounce_handler.event_actions == {"debounce": 300}

temporal_handler = MyTestState.handle_temporal
assert temporal_handler.event_actions == {"temporal": True}

# Test multiple event actions are combined correctly
multiple_handler = MyTestState.handle_multiple
assert multiple_handler.event_actions == {
"stopPropagation": True,
"preventDefault": True,
"throttle": 1000,
}

# Test background + event actions work together
bg_temporal_handler = MyTestState.handle_background_temporal
assert bg_temporal_handler.event_actions == {"temporal": True}
assert hasattr(bg_temporal_handler.fn, BACKGROUND_TASK_MARKER) # pyright: ignore [reportAttributeAccessIssue]

# Test no event actions (existing behavior preserved)
no_actions_handler = MyTestState.handle_no_actions
assert no_actions_handler.event_actions == {}


def test_event_decorator_actions_can_be_overridden():
"""Test that decorator event actions can still be overridden by chaining."""

class MyTestState(BaseState):
@event(throttle=500, stop_propagation=True)
def handle_with_defaults(self):
pass

# Get the handler with default actions
handler = MyTestState.handle_with_defaults
assert handler.event_actions == {"throttle": 500, "stopPropagation": True}

# Chain additional actions - should combine
handler_with_prevent_default = handler.prevent_default
assert handler_with_prevent_default.event_actions == {
"throttle": 500,
"stopPropagation": True,
"preventDefault": True,
}

# Chain throttle with different value - should override
handler_with_new_throttle = handler.throttle(1000)
assert handler_with_new_throttle.event_actions == {
"throttle": 1000, # New value overrides default
"stopPropagation": True,
}

# Original handler should be unchanged
assert handler.event_actions == {"throttle": 500, "stopPropagation": True}


def test_event_decorator_with_none_values():
"""Test that None values in decorator don't create event actions."""

class MyTestState(BaseState):
@event(stop_propagation=None, prevent_default=None, throttle=None)
def handle_all_none(self):
pass

@event(stop_propagation=True, prevent_default=None, throttle=500, debounce=None)
def handle_mixed(self):
pass

# All None should result in no event actions
all_none_handler = MyTestState.handle_all_none
assert all_none_handler.event_actions == {}

# Only non-None values should be included
mixed_handler = MyTestState.handle_mixed
assert mixed_handler.event_actions == {"stopPropagation": True, "throttle": 500}


def test_event_decorator_backward_compatibility():
"""Test that existing code without event action parameters continues to work."""

class MyTestState(BaseState):
@event
def handle_old_style(self):
pass

@event(background=True)
async def handle_old_background(self):
pass

# Old style without parameters should work unchanged
old_handler = MyTestState.handle_old_style
assert isinstance(old_handler, EventHandler)
assert old_handler.event_actions == {}
assert not hasattr(old_handler.fn, BACKGROUND_TASK_MARKER) # pyright: ignore [reportAttributeAccessIssue]

# Old background parameter should work unchanged
bg_handler = MyTestState.handle_old_background
assert bg_handler.event_actions == {}
assert hasattr(bg_handler.fn, BACKGROUND_TASK_MARKER) # pyright: ignore [reportAttributeAccessIssue]


def test_event_var_in_rx_cond():
"""Test that EventVar and EventChainVar cannot be used in rx.cond()."""
from reflex.components.core.cond import cond as rx_cond
Expand Down