Skip to content

Commit 507042a

Browse files
Implement decentralized event handlers
Co-Authored-By: khaleel@reflex.dev <khaleel.aladhami@gmail.com>
1 parent 26f1a7b commit 507042a

4 files changed

Lines changed: 243 additions & 4 deletions

File tree

reflex/event.py

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,30 @@
5555
)
5656
from reflex.vars.object import ObjectVar
5757

58+
_global_event_handlers: dict[str, EventHandler] = {}
59+
60+
61+
def register_event_handler(name: str, handler: EventHandler) -> None:
62+
"""Register a decentralized event handler.
63+
64+
Args:
65+
name: The name of the event handler.
66+
handler: The event handler.
67+
"""
68+
_global_event_handlers[name] = handler
69+
70+
71+
def get_event_handler(name: str) -> EventHandler | None:
72+
"""Get a decentralized event handler by name.
73+
74+
Args:
75+
name: The name of the event handler.
76+
77+
Returns:
78+
The event handler, or None if not found.
79+
"""
80+
return _global_event_handlers.get(name)
81+
5882

5983
@dataclasses.dataclass(
6084
init=True,
@@ -178,7 +202,7 @@ class EventHandler(EventActionsMixin):
178202

179203
# The full name of the state class this event handler is attached to.
180204
# Empty string means this event handler is a server side event.
181-
state_full_name: str = dataclasses.field(default="")
205+
state_full_name: str | None = dataclasses.field(default="")
182206

183207
@classmethod
184208
def __class_getitem__(cls, args_spec: str) -> Annotated:
@@ -261,6 +285,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> EventSpec:
261285
) from e
262286
payload = tuple(zip(fn_args, values, strict=False))
263287

288+
# Check if this is a decentralized event handler
289+
if self.state_full_name is None:
290+
from reflex.utils import format
291+
292+
name = format.to_snake_case(self.fn.__qualname__)
293+
register_event_handler(name, self)
294+
264295
# Return the event spec.
265296
return EventSpec(
266297
handler=self, args=payload, event_actions=self.event_actions.copy()
@@ -1492,6 +1523,9 @@ def check_fn_match_arg_spec(
14921523
Raises:
14931524
EventFnArgMismatchError: Raised if the number of mandatory arguments do not match
14941525
"""
1526+
if is_decentralized_event_handler(user_func):
1527+
return
1528+
14951529
user_args = list(inspect.signature(user_func).parameters)
14961530
# Drop the first argument if it's a bound method
14971531
if inspect.ismethod(user_func) and user_func.__self__ is not None:
@@ -1518,6 +1552,61 @@ def check_fn_match_arg_spec(
15181552
)
15191553

15201554

1555+
DECENTRALIZED_EVENT_MARKER = "_rx_decentralized_event"
1556+
1557+
1558+
def is_decentralized_event_handler(fn: Callable) -> bool:
1559+
"""Check if a function is a decentralized event handler.
1560+
1561+
Args:
1562+
fn: The function to check.
1563+
1564+
Returns:
1565+
Whether the function is a decentralized event handler.
1566+
"""
1567+
# Check if the function has been decorated with @rx.event
1568+
if not hasattr(fn, "__qualname__"):
1569+
return False
1570+
1571+
# Check if the function has the decentralized event marker
1572+
return hasattr(fn, DECENTRALIZED_EVENT_MARKER)
1573+
1574+
1575+
def wrap_decentralized_handler(fn: Callable) -> Callable:
1576+
"""Wrap a decentralized event handler to be used with component events.
1577+
1578+
This creates a wrapper that doesn't require the state parameter when called
1579+
from a component event, but will pass the state when the event is processed.
1580+
1581+
Args:
1582+
fn: The decentralized event handler to wrap.
1583+
1584+
Returns:
1585+
A wrapped function that can be used with component events.
1586+
"""
1587+
1588+
# Create a wrapper function that doesn't require the state parameter
1589+
def wrapper(*args, **kwargs):
1590+
# Get or create the event handler
1591+
from reflex.utils import format
1592+
1593+
name = format.to_snake_case(fn.__qualname__)
1594+
handler = _global_event_handlers.get(name)
1595+
if handler is None:
1596+
handler = EventHandler(fn=fn, state_full_name=None)
1597+
register_event_handler(name, handler)
1598+
1599+
# Create an event spec with no arguments - the state will be provided
1600+
return EventSpec(handler=handler, args=())
1601+
1602+
wrapper.__name__ = fn.__name__
1603+
wrapper.__qualname__ = fn.__qualname__
1604+
wrapper.__doc__ = fn.__doc__
1605+
wrapper.__module__ = fn.__module__
1606+
1607+
return wrapper
1608+
1609+
15211610
def call_event_fn(
15221611
fn: Callable,
15231612
arg_spec: ArgsSpec | Sequence[ArgsSpec],
@@ -1543,6 +1632,11 @@ def call_event_fn(
15431632
from reflex.event import EventHandler, EventSpec
15441633
from reflex.utils.exceptions import EventHandlerValueError
15451634

1635+
# Check if this is a decentralized event handler
1636+
if is_decentralized_event_handler(fn):
1637+
wrapped_fn = wrap_decentralized_handler(fn)
1638+
return call_event_fn(wrapped_fn, arg_spec, key=key)
1639+
15461640
# Check that fn signature matches arg_spec
15471641
check_fn_match_arg_spec(fn, arg_spec, key=key)
15481642

@@ -2066,7 +2160,19 @@ def wrapper(
20662160
setattr(func, BACKGROUND_TASK_MARKER, True)
20672161
if getattr(func, "__name__", "").startswith("_"):
20682162
raise ValueError("Event handlers cannot be private.")
2069-
return func # pyright: ignore [reportReturnType]
2163+
2164+
# Check if this is a method (defined in a class) or a standalone function
2165+
if hasattr(func, "__qualname__") and "." in func.__qualname__:
2166+
return func # pyright: ignore [reportReturnType]
2167+
else:
2168+
# This is a decentralized event handler
2169+
handler = EventHandler(fn=func, state_full_name=None)
2170+
if background:
2171+
setattr(handler, BACKGROUND_TASK_MARKER, True)
2172+
# Mark the function as a decentralized event handler
2173+
setattr(func, DECENTRALIZED_EVENT_MARKER, True)
2174+
# Return the original function so it can be called normally
2175+
return func # pyright: ignore [reportReturnType]
20702176

20712177
if func is not None:
20722178
return wrapper(func)

reflex/state.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,15 +1586,30 @@ def _get_event_handler(
15861586
Args:
15871587
event: The event to get the handler for.
15881588
1589-
15901589
Returns:
15911590
The event handler.
15921591
15931592
Raises:
15941593
ValueError: If the event handler or substate is not found.
1594+
EventHandlerValueError: If the event handler is not found.
15951595
"""
15961596
# Get the event handler.
15971597
path = event.name.split(".")
1598+
1599+
if "." not in event.name:
1600+
from reflex.event import get_event_handler
1601+
from reflex.utils.exceptions import EventHandlerValueError
1602+
1603+
handler = get_event_handler(event.name)
1604+
if handler is None:
1605+
raise EventHandlerValueError(f"Event handler {event.name} not found.")
1606+
1607+
# For background tasks, proxy the state
1608+
if handler.is_background:
1609+
return StateProxy(self), handler
1610+
1611+
return self, handler
1612+
15981613
path, name = path[:-1], path[-1]
15991614
substate = self.get_substate(path)
16001615
if not substate:
@@ -1753,7 +1768,10 @@ async def _process_event(
17531768
from reflex.utils import telemetry
17541769

17551770
# Get the function to process the event.
1756-
fn = functools.partial(handler.fn, state)
1771+
if handler.state_full_name is None:
1772+
fn = handler.fn
1773+
else:
1774+
fn = functools.partial(handler.fn, state)
17571775

17581776
try:
17591777
type_hints = typing.get_type_hints(handler.fn)

reflex/utils/format.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,11 @@ def format_event_handler(handler: EventHandler) -> str:
483483
Returns:
484484
The formatted function.
485485
"""
486+
if handler.state_full_name is None:
487+
from reflex.utils import format
488+
489+
return format.to_snake_case(handler.fn.__qualname__)
490+
486491
state, name = get_event_handler_parts(handler)
487492
if state == "":
488493
return name
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""Test decentralized event handlers functionality."""
2+
3+
import pytest
4+
from selenium.webdriver.common.by import By
5+
6+
from reflex.testing import AppHarness, WebDriver
7+
8+
9+
def DecentralizedEventHandlers():
10+
"""Test that decentralized event handlers work as expected."""
11+
import reflex as rx
12+
13+
class TestState(rx.State):
14+
count: int = 0
15+
16+
@rx.event
17+
def increment(self):
18+
"""Increment the counter."""
19+
self.count += 1
20+
21+
@rx.event
22+
def on_load(state: TestState):
23+
"""Event handler for loading the state.
24+
25+
Args:
26+
state: The state to modify.
27+
"""
28+
state.count = 10
29+
30+
@rx.event
31+
def reset_count(state: TestState):
32+
"""Event handler for resetting the count.
33+
34+
Args:
35+
state: The state to modify.
36+
"""
37+
state.count = 0
38+
39+
def index():
40+
return rx.vstack(
41+
rx.heading(TestState.count, id="counter"),
42+
rx.button("Increment", on_click=TestState.increment, id="increment"),
43+
rx.button("Reset", on_click=reset_count, id="reset"),
44+
rx.text("Loaded", on_mount=on_load, id="loaded"),
45+
)
46+
47+
app = rx.App()
48+
app.add_page(index)
49+
50+
51+
@pytest.fixture(scope="module")
52+
def decentralized_handlers(
53+
tmp_path_factory,
54+
):
55+
"""Start DecentralizedEventHandlers app at tmp_path via AppHarness.
56+
57+
Args:
58+
tmp_path_factory: pytest tmp_path_factory fixture
59+
60+
Yields:
61+
running AppHarness instance
62+
"""
63+
with AppHarness.create(
64+
root=tmp_path_factory.mktemp("decentralized_handlers"),
65+
app_source=DecentralizedEventHandlers,
66+
) as harness:
67+
yield harness
68+
69+
70+
@pytest.fixture
71+
def driver(decentralized_handlers: AppHarness):
72+
"""Get an instance of the browser open to the app.
73+
74+
Args:
75+
decentralized_handlers: harness for DecentralizedEventHandlers app
76+
77+
Yields:
78+
WebDriver instance.
79+
"""
80+
assert decentralized_handlers.app_instance is not None, "app is not running"
81+
driver = decentralized_handlers.frontend()
82+
try:
83+
yield driver
84+
finally:
85+
driver.quit()
86+
87+
88+
def test_decentralized_event_handlers(
89+
decentralized_handlers: AppHarness,
90+
driver: WebDriver,
91+
):
92+
"""Test that decentralized event handlers work as expected.
93+
94+
Args:
95+
decentralized_handlers: harness for DecentralizedEventHandlers app
96+
driver: WebDriver instance
97+
"""
98+
assert decentralized_handlers.app_instance is not None
99+
100+
counter = driver.find_element(By.ID, "counter")
101+
increment_button = driver.find_element(By.ID, "increment")
102+
reset_button = driver.find_element(By.ID, "reset")
103+
104+
assert decentralized_handlers._poll_for(lambda: counter.text == "10", timeout=5)
105+
106+
increment_button.click()
107+
assert decentralized_handlers._poll_for(lambda: counter.text == "11", timeout=5)
108+
109+
reset_button.click()
110+
assert decentralized_handlers._poll_for(lambda: counter.text == "0", timeout=5)

0 commit comments

Comments
 (0)