Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 154 additions & 93 deletions reflex/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,118 @@ def get_hydrate_event(state: BaseState) -> str:
return get_event(state, constants.CompileVars.HYDRATE)


def _values_returned_from_event(
event_spec: ArgsSpec | Sequence[ArgsSpec],
) -> list[Any]:
return [
event_spec_return_type
for arg_spec in (
[event_spec] if not isinstance(event_spec, Sequence) else list(event_spec)
)
if (event_spec_return_type := get_type_hints(arg_spec).get("return", None))
is not None
and get_origin(event_spec_return_type) is tuple
]


def _check_event_args_subclass_of_callback(
callback_params_names: list[str],
provided_event_types: list[Any],
callback_param_name_to_type: dict[str, Any],
callback_name: str = "",
key: str = "",
):
"""Check if the event handler arguments are subclass of the callback.

Args:
callback_params_names: The names of the callback parameters.
provided_event_types: The event types.
callback_param_name_to_type: The callback parameter name to type mapping.
callback_name: The name of the callback.
key: The key.

Raises:
TypeError: If the event handler arguments are invalid.
EventHandlerArgTypeMismatchError: If the event handler arguments do not match the callback.

# noqa: DAR401 delayed_exceptions[]
# noqa: DAR402 EventHandlerArgTypeMismatchError
"""
type_match_found: dict[str, bool] = {}
delayed_exceptions: list[EventHandlerArgTypeMismatchError] = []

for event_spec_index, event_spec_return_type in enumerate(provided_event_types):
args = get_args(event_spec_return_type)

args_types_without_vars = [
arg if get_origin(arg) is not Var else get_args(arg)[0] for arg in args
]

# check that args of event handler are matching the spec if type hints are provided
for i, arg in enumerate(callback_params_names[: len(args_types_without_vars)]):
if arg not in callback_param_name_to_type:
continue

type_match_found.setdefault(arg, False)

try:
compare_result = typehint_issubclass(
args_types_without_vars[i], callback_param_name_to_type[arg]
)
except TypeError as te:
callback_name_context = f" of {callback_name}" if callback_name else ""
key_context = f" for {key}" if key else ""
raise TypeError(
f"Could not compare types {args_types_without_vars[i]} and {callback_param_name_to_type[arg]} for argument {arg}{callback_name_context}{key_context}."
) from te

if compare_result:
type_match_found[arg] = True
continue
else:
type_match_found[arg] = False
as_annotated_in = (
f" as annotated in {callback_name}" if callback_name else ""
)
delayed_exceptions.append(
EventHandlerArgTypeMismatchError(
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {callback_param_name_to_type[arg]}{as_annotated_in} instead."
)
)

if all(type_match_found.values()):
delayed_exceptions.clear()
if event_spec_index:
args = get_args(provided_event_types[0])

args_types_without_vars = [
arg if get_origin(arg) is not Var else get_args(arg)[0]
for arg in args
]

expect_string = ", ".join(
repr(arg) for arg in args_types_without_vars
).replace("[", "\\[")

given_string = ", ".join(
repr(callback_param_name_to_type.get(arg, Any))
for arg in callback_params_names
).replace("[", "\\[")

as_annotated_in = (
f" as annotated in {callback_name}" if callback_name else ""
)

console.warn(
f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> (){as_annotated_in} instead. "
f"This may lead to unexpected behavior but is intentionally ignored for {key}."
)
break

if delayed_exceptions:
raise delayed_exceptions[0]


def call_event_handler(
event_callback: EventHandler | EventSpec,
event_spec: ArgsSpec | Sequence[ArgsSpec],
Expand All @@ -1278,17 +1390,13 @@ def call_event_handler(
event_spec: The lambda that define the argument(s) to pass to the event handler.
key: The key to pass to the event handler.

Raises:
EventHandlerArgTypeMismatchError: If the event handler arguments do not match the event spec. #noqa: DAR402
TypeError: If the event handler arguments are invalid.

Returns:
The event spec from calling the event handler.

#noqa: DAR401
"""
event_spec_args = parse_args_spec(event_spec)

event_spec_return_types = _values_returned_from_event(event_spec)

if isinstance(event_callback, EventSpec):
check_fn_match_arg_spec(
event_callback.handler.fn,
Expand All @@ -1297,6 +1405,32 @@ def call_event_handler(
bool(event_callback.handler.state_full_name) + len(event_callback.args),
event_callback.handler.fn.__qualname__,
)

event_callback_spec_args = list(
inspect.signature(event_callback.handler.fn).parameters.keys()
)

try:
type_hints_of_provided_callback = get_type_hints(event_callback.handler.fn)
except NameError:
type_hints_of_provided_callback = {}

argument_names = [str(arg) for arg, value in event_callback.args]

_check_event_args_subclass_of_callback(
[
arg
for arg in event_callback_spec_args[
bool(event_callback.handler.state_full_name) :
]
if arg not in argument_names
],
event_spec_return_types,
type_hints_of_provided_callback,
event_callback.handler.fn.__qualname__,
key or "",
)

# Handle partial application of EventSpec args
return event_callback.add_args(*event_spec_args)

Expand All @@ -1308,98 +1442,23 @@ def call_event_handler(
event_callback.fn.__qualname__,
)

all_acceptable_specs = (
[event_spec] if not isinstance(event_spec, Sequence) else event_spec
)

event_spec_return_types = list(
filter(
lambda event_spec_return_type: event_spec_return_type is not None
and get_origin(event_spec_return_type) is tuple,
(
get_type_hints(arg_spec).get("return", None)
for arg_spec in all_acceptable_specs
),
)
)
type_match_found: dict[str, bool] = {}
delayed_exceptions: list[EventHandlerArgTypeMismatchError] = []

try:
type_hints_of_provided_callback = get_type_hints(event_callback.fn)
except NameError:
type_hints_of_provided_callback = {}

if event_spec_return_types:
event_callback_spec_args = list(
inspect.signature(event_callback.fn).parameters.keys()
)

for event_spec_index, event_spec_return_type in enumerate(
event_spec_return_types
):
args = get_args(event_spec_return_type)

args_types_without_vars = [
arg if get_origin(arg) is not Var else get_args(arg)[0] for arg in args
]

# check that args of event handler are matching the spec if type hints are provided
for i, arg in enumerate(
event_callback_spec_args[1 : len(args_types_without_vars) + 1]
):
if arg not in type_hints_of_provided_callback:
continue

type_match_found.setdefault(arg, False)

try:
compare_result = typehint_issubclass(
args_types_without_vars[i], type_hints_of_provided_callback[arg]
)
except TypeError as te:
raise TypeError(
f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_callback.fn.__qualname__} provided for {key}."
) from te

if compare_result:
type_match_found[arg] = True
continue
else:
type_match_found[arg] = False
delayed_exceptions.append(
EventHandlerArgTypeMismatchError(
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_callback.fn.__qualname__} instead."
)
)

if all(type_match_found.values()):
delayed_exceptions.clear()
if event_spec_index:
args = get_args(event_spec_return_types[0])

args_types_without_vars = [
arg if get_origin(arg) is not Var else get_args(arg)[0]
for arg in args
]

expect_string = ", ".join(
repr(arg) for arg in args_types_without_vars
).replace("[", "\\[")

given_string = ", ".join(
repr(type_hints_of_provided_callback.get(arg, Any))
for arg in event_callback_spec_args[1:]
).replace("[", "\\[")

console.warn(
f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_callback.fn.__qualname__} instead. "
f"This may lead to unexpected behavior but is intentionally ignored for {key}."
)
break

if delayed_exceptions:
raise delayed_exceptions[0]
try:
type_hints_of_provided_callback = get_type_hints(event_callback.fn)
except NameError:
type_hints_of_provided_callback = {}

_check_event_args_subclass_of_callback(
event_callback_spec_args[1:],
event_spec_return_types,
type_hints_of_provided_callback,
event_callback.fn.__qualname__,
key or "",
)

return event_callback(*event_spec_args)

Expand Down Expand Up @@ -1958,6 +2017,8 @@ def __get__(self, instance: Any, owner: Any) -> Callable:
class LambdaEventCallback(Protocol[Unpack[P]]):
"""A protocol for a lambda event callback."""

__code__: types.CodeType

@overload
def __call__(self: LambdaEventCallback[()]) -> Any: ...

Expand Down
16 changes: 16 additions & 0 deletions tests/units/components/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,22 @@ def test_invalid_event_handler_args(component2, test_state):
component2.create(on_user_list_changed=test_state.do_something_with_int)
with pytest.raises(EventHandlerArgTypeMismatchError):
component2.create(on_user_list_changed=test_state.do_something_with_list_int)
with pytest.raises(EventHandlerArgTypeMismatchError):
component2.create(
on_user_visited_count_changed=test_state.do_something_with_bool()
)
with pytest.raises(EventHandlerArgTypeMismatchError):
component2.create(on_user_list_changed=test_state.do_something_with_int())
with pytest.raises(EventHandlerArgTypeMismatchError):
component2.create(on_user_list_changed=test_state.do_something_with_list_int())

component2.create(
on_user_visited_count_changed=test_state.do_something_with_bool(False)
)
component2.create(on_user_list_changed=test_state.do_something_with_int(23))
component2.create(
on_user_list_changed=test_state.do_something_with_list_int([2321, 321])
)

component2.create(on_open=test_state.do_something_with_int)
component2.create(on_open=test_state.do_something_with_bool)
Expand Down