Skip to content

Commit 6656b04

Browse files
committed
Support FunctionVar handlers in EventChain rendering
Allow EventChain to accept frontend FunctionVar handlers alongside EventSpec, EventVar, and EventCallback values. When a chain contains FunctionVars, keep backend events grouped through addEvents(...) and invoke frontend functions inline with the trigger arguments so mixed chains preserve execution order and DOM event actions like preventDefault and stopPropagation. Wrap inline arrow functions before emitting VarOperationCall JS so direct invocation renders valid JavaScript, add unit coverage for pure/mixed event-chain formatting and creation, and move upload exception docs to the helper that actually raises them to satisfy darglint.
1 parent fe34ae8 commit 6656b04

5 files changed

Lines changed: 210 additions & 17 deletions

File tree

reflex/app.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1931,11 +1931,6 @@ async def upload_file(request: Request):
19311931
Returns:
19321932
StreamingResponse yielding newline-delimited JSON of StateUpdate
19331933
emitted by the upload handler.
1934-
1935-
Raises:
1936-
UploadValueError: if there are no args with supported annotation.
1937-
UploadTypeError: if a background task is used as the handler.
1938-
HTTPException: when the request does not include token / handler headers.
19391934
"""
19401935
from reflex.utils.exceptions import UploadTypeError, UploadValueError
19411936

@@ -1960,6 +1955,11 @@ async def _create_upload_event() -> Event:
19601955
19611956
Returns:
19621957
The upload event backed by the original temp files.
1958+
1959+
Raises:
1960+
UploadValueError: If there are no uploaded files or supported annotations.
1961+
UploadTypeError: If a background task is used as the handler.
1962+
HTTPException: If the request is missing token or handler headers.
19631963
"""
19641964
files = form_data.getlist("files")
19651965
if not files:

reflex/event.py

Lines changed: 76 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -449,8 +449,8 @@ def __call__(self, *args, **kwargs) -> EventSpec:
449449
class EventChain(EventActionsMixin):
450450
"""Container for a chain of events that will be executed in order."""
451451

452-
events: Sequence["EventSpec | EventVar | EventCallback"] = dataclasses.field(
453-
default_factory=list
452+
events: Sequence["EventSpec | EventVar | FunctionVar | EventCallback"] = (
453+
dataclasses.field(default_factory=list)
454454
)
455455

456456
args_spec: Callable | Sequence[Callable] | None = dataclasses.field(default=None)
@@ -483,6 +483,8 @@ def create(
483483
if isinstance(value, Var):
484484
if isinstance(value, EventChainVar):
485485
return value
486+
if isinstance(value, FunctionVar):
487+
return value
486488
if isinstance(value, EventVar):
487489
value = [value]
488490
elif safe_issubclass(value._var_type, (EventChain, EventSpec)):
@@ -505,23 +507,26 @@ def create(
505507

506508
# If the input is a list of event handlers, create an event chain.
507509
if isinstance(value, list):
508-
events: list[EventSpec | EventVar] = []
510+
events: list[EventSpec | EventVar | FunctionVar] = []
509511
for v in value:
510512
if isinstance(v, (EventHandler, EventSpec)):
511513
# Call the event handler to get the event.
512514
events.append(call_event_handler(v, args_spec, key=key))
515+
elif isinstance(v, (EventVar, FunctionVar)):
516+
events.append(v)
513517
elif isinstance(v, Callable):
514518
# Call the lambda to get the event chain.
515519
result = call_event_fn(v, args_spec, key=key)
516520
if isinstance(result, Var):
521+
if isinstance(result, (EventVar, FunctionVar)):
522+
events.append(result)
523+
continue
517524
msg = (
518525
f"Invalid event chain: {v}. Cannot use a Var-returning "
519526
"lambda inside an EventChain list."
520527
)
521528
raise ValueError(msg)
522529
events.extend(result)
523-
elif isinstance(v, EventVar):
524-
events.append(v)
525530
else:
526531
msg = f"Invalid event: {v}"
527532
raise ValueError(msg)
@@ -2077,12 +2082,15 @@ def create(
20772082
sig = inspect.signature(arg_spec) # pyright: ignore [reportArgumentType]
20782083
if sig.parameters:
20792084
arg_def = tuple(f"_{p}" for p in sig.parameters)
2080-
arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])
2085+
arg_vars = tuple(Var(_js_expr=arg) for arg in arg_def)
2086+
arg_def_expr = LiteralVar.create(list(arg_vars))
2087+
call_args = arg_vars
20812088
else:
20822089
# add a default argument for addEvents if none were specified in value.args_spec
20832090
# used to trigger the preventDefault() on the event.
20842091
arg_def = ("...args",)
20852092
arg_def_expr = Var(_js_expr="args")
2093+
call_args = (Var(_js_expr="...args"),)
20862094

20872095
if value.invocation is None:
20882096
invocation = FunctionStringVar.create(
@@ -2099,16 +2107,73 @@ def create(
20992107
msg = f"EventChain invocation must be a FunctionVar, got {invocation!s} of type {invocation._var_type!s}."
21002108
raise ValueError(msg)
21012109

2110+
has_function_var = any(isinstance(e, FunctionVar) for e in value.events)
2111+
2112+
if not has_function_var:
2113+
return_expr = invocation.call(
2114+
LiteralVar.create([LiteralVar.create(event) for event in value.events]),
2115+
arg_def_expr,
2116+
value.event_actions,
2117+
)
2118+
else:
2119+
statement_js: list[str] = []
2120+
statement_var_data: list[VarData | None] = []
2121+
queueable_group: list[EventSpec | EventVar | EventCallback] = []
2122+
2123+
if value.event_actions.get("preventDefault") or value.event_actions.get(
2124+
"stopPropagation"
2125+
):
2126+
statement_js.append(
2127+
"const _reflex_dom_event = "
2128+
f"{arg_def_expr}.filter((o) => o?.preventDefault !== undefined)[0];"
2129+
)
2130+
if value.event_actions.get("preventDefault"):
2131+
statement_js.append(
2132+
"if (_reflex_dom_event?.preventDefault) "
2133+
"{_reflex_dom_event.preventDefault();}"
2134+
)
2135+
if value.event_actions.get("stopPropagation"):
2136+
statement_js.append(
2137+
"if (_reflex_dom_event?.stopPropagation) "
2138+
"{_reflex_dom_event.stopPropagation();}"
2139+
)
2140+
2141+
def flush_queueable_group() -> None:
2142+
if not queueable_group:
2143+
return
2144+
queue_call = invocation.call(
2145+
LiteralVar.create([
2146+
LiteralVar.create(event) for event in queueable_group
2147+
]),
2148+
arg_def_expr,
2149+
{},
2150+
)
2151+
statement_js.append(f"{queue_call!s};")
2152+
statement_var_data.append(queue_call._get_all_var_data())
2153+
queueable_group.clear()
2154+
2155+
for event in value.events:
2156+
if isinstance(event, FunctionVar):
2157+
flush_queueable_group()
2158+
function_call = event.call(*call_args)
2159+
statement_js.append(f"{function_call!s};")
2160+
statement_var_data.append(function_call._get_all_var_data())
2161+
else:
2162+
queueable_group.append(event)
2163+
2164+
flush_queueable_group()
2165+
2166+
return_expr = Var(
2167+
_js_expr=f"{{{''.join(statement_js)}}}",
2168+
_var_data=VarData.merge(*statement_var_data),
2169+
)
2170+
21022171
return cls(
21032172
_js_expr="",
21042173
_var_type=EventChain,
21052174
_var_data=_var_data,
21062175
_args=FunctionArgs(arg_def),
2107-
_return_expr=invocation.call(
2108-
LiteralVar.create([LiteralVar.create(event) for event in value.events]),
2109-
arg_def_expr,
2110-
value.event_actions,
2111-
),
2176+
_return_expr=return_expr,
21122177
_var_value=value,
21132178
)
21142179

reflex/vars/function.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,11 @@ def _cached_var_name(self) -> str:
239239
Returns:
240240
The name of the var.
241241
"""
242-
return f"({self._func!s}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"
242+
func_expr = str(self._func)
243+
if "=>" in func_expr and not format.is_wrapped(func_expr, "("):
244+
func_expr = format.wrap(func_expr, "(")
245+
246+
return f"({func_expr}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"
243247

244248
@cached_property_no_lock
245249
def _cached_get_all_var_data(self) -> VarData | None:

tests/units/test_event.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ def make_var(value) -> Var:
3232
return Var(_js_expr=value)
3333

3434

35+
def make_timeout_logger():
36+
return rx.vars.FunctionStringVar.create(
37+
"(...args) => { setTimeout(() => console.log('Timeout reached!', args), 1000); }"
38+
).to(EventChain)
39+
40+
3541
def test_create_event():
3642
"""Test creating an event."""
3743
event = Event(token="token", name="state.do_thing", payload={"arg": "value"})
@@ -668,6 +674,41 @@ def _args_spec() -> tuple:
668674
assert "to bool" in str(err.value)
669675

670676

677+
def test_event_chain_create_allows_plain_function_var():
678+
"""Plain FunctionVars should be usable as frontend event handlers."""
679+
frontend_handler = make_timeout_logger()
680+
681+
assert EventChain.create(frontend_handler, args_spec=lambda: ()) is frontend_handler
682+
683+
684+
def test_event_chain_create_allows_function_var_in_list():
685+
"""FunctionVars should be allowed inside EventChain lists."""
686+
frontend_handler = make_timeout_logger()
687+
688+
chain = EventChain.create([frontend_handler], args_spec=lambda: ())
689+
690+
assert isinstance(chain, EventChain)
691+
assert chain.events == [frontend_handler]
692+
693+
694+
def test_button_accepts_mixed_event_handler_and_function_var():
695+
"""Components should accept mixed backend/frontend event chains."""
696+
697+
class MixedState(BaseState):
698+
@event
699+
def do_a_thing(self):
700+
pass
701+
702+
log_after_timeout = make_timeout_logger()
703+
704+
button = rx.button(
705+
"Do both",
706+
on_click=[MixedState.do_a_thing, log_after_timeout],
707+
)
708+
709+
assert isinstance(button.event_triggers["on_click"], EventChain)
710+
711+
671712
def test_decentralized_event_with_args():
672713
"""Test the decentralized event."""
673714

tests/units/utils/test_format.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from reflex.utils import format
2121
from reflex.utils.serializers import serialize_figure
2222
from reflex.vars.base import LiteralVar, Var
23+
from reflex.vars.function import FunctionStringVar
2324
from reflex.vars.object import ObjectVar
2425

2526
pytest.importorskip("pydantic")
@@ -41,6 +42,88 @@ def mock_event(arg):
4142
pass
4243

4344

45+
def mock_event_two(arg):
46+
pass
47+
48+
49+
def make_timeout_logger():
50+
return FunctionStringVar.create(
51+
"(...args) => { setTimeout(() => console.log('Timeout reached!', args), 1000); }"
52+
).to(EventChain)
53+
54+
55+
def test_format_prop_event_chain_pure_eventspec_grouped():
56+
"""Pure EventSpec chains should still collapse to one addEvents call."""
57+
chain = EventChain(
58+
events=[
59+
EventSpec(handler=EventHandler(fn=mock_event)),
60+
EventSpec(handler=EventHandler(fn=mock_event_two)),
61+
],
62+
args_spec=lambda e: [e],
63+
)
64+
65+
assert format.format_prop(LiteralVar.create(chain)) == (
66+
'((_e) => (addEvents([(ReflexEvent("mock_event", ({ }), ({ }))), '
67+
'(ReflexEvent("mock_event_two", ({ }), ({ })))], [_e], ({ }))))'
68+
)
69+
70+
71+
def test_format_prop_event_chain_pure_function_var():
72+
"""Pure FunctionVar chains should render as direct frontend calls."""
73+
log_after_timeout = make_timeout_logger()
74+
chain = EventChain(
75+
events=[log_after_timeout],
76+
args_spec=lambda e: [e],
77+
)
78+
79+
assert format.format_prop(LiteralVar.create(chain)) == (
80+
"((_e) => {(((...args) => { setTimeout(() => console.log('Timeout reached!', "
81+
"args), 1000); })(_e));})"
82+
)
83+
84+
85+
def test_format_prop_event_chain_mixed_queue_and_function():
86+
"""Mixed chains should alternate addEvents and direct calls in order."""
87+
log_after_timeout = make_timeout_logger()
88+
chain = EventChain(
89+
events=[
90+
EventSpec(handler=EventHandler(fn=mock_event)),
91+
log_after_timeout,
92+
EventSpec(handler=EventHandler(fn=mock_event_two)),
93+
],
94+
args_spec=lambda e: [e],
95+
)
96+
97+
assert format.format_prop(LiteralVar.create(chain)) == (
98+
'((_e) => {(addEvents([(ReflexEvent("mock_event", ({ }), ({ })))], '
99+
"[_e], ({ })));(((...args) => { setTimeout(() => console.log('Timeout reached!', "
100+
'args), 1000); })(_e));(addEvents([(ReflexEvent("mock_event_two", '
101+
"({ }), ({ })))], [_e], ({ })));})"
102+
)
103+
104+
105+
def test_format_prop_event_chain_mixed_with_event_actions():
106+
"""Mixed chains should preserve DOM event actions on the wrapper callback."""
107+
log_after_timeout = make_timeout_logger()
108+
chain = EventChain(
109+
events=[
110+
EventSpec(handler=EventHandler(fn=mock_event)),
111+
log_after_timeout,
112+
],
113+
args_spec=lambda e: [e],
114+
event_actions={"preventDefault": True, "stopPropagation": True},
115+
)
116+
117+
assert format.format_prop(LiteralVar.create(chain)) == (
118+
"((_e) => {const _reflex_dom_event = [_e].filter((o) => "
119+
"o?.preventDefault !== undefined)[0];if (_reflex_dom_event?.preventDefault) "
120+
"{_reflex_dom_event.preventDefault();}if (_reflex_dom_event?.stopPropagation) "
121+
'{_reflex_dom_event.stopPropagation();}(addEvents([(ReflexEvent("mock_event", '
122+
"({ }), ({ })))], [_e], ({ })));(((...args) => { setTimeout(() => "
123+
"console.log('Timeout reached!', args), 1000); })(_e));})"
124+
)
125+
126+
44127
@pytest.mark.parametrize(
45128
("input", "output"),
46129
[

0 commit comments

Comments
 (0)