diff --git a/discord/commands/core.py b/discord/commands/core.py index d2f8f52f69..1cc5ee902a 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -43,6 +43,7 @@ Generic, TypeVar, Union, + get_type_hints, ) from ..channel import PartialMessageable, _threaded_guild_channel_factory @@ -800,9 +801,14 @@ def _parse_options(self, params, *, check_params: bool = True) -> list[Option]: else: params = iter(params.items()) + try: + hints = get_type_hints(self.callback, include_extras=True) + except Exception: + hints = {} + final_options = [] for p_name, p_obj in params: - option = p_obj.annotation + option = hints.get(p_name, p_obj.annotation) if option == inspect.Parameter.empty: option = str @@ -884,6 +890,11 @@ def _match_option_param_names(self, params, options): options = list(options) params = self._check_required_params(params) + try: + hints = get_type_hints(self.callback, include_extras=True) + except Exception: + hints = {} + check_annotations: list[Callable[[Option, type], bool]] = [ lambda o, a: ( o.input_type == SlashCommandOptionType.string @@ -909,7 +920,7 @@ def _match_option_param_names(self, params, options): p_name, p_obj = next(params) except StopIteration: # not enough params for all the options raise ClientException("Too many arguments passed to the options kwarg.") - p_obj = p_obj.annotation + p_obj = hints.get(p_name, p_obj.annotation) if not any(check(o, p_obj) for check in check_annotations): raise TypeError( @@ -1088,7 +1099,7 @@ async def _invoke(self, ctx: ApplicationContext) -> None: ): pass - elif issubclass(op._raw_type, Enum): + elif isinstance(op._raw_type, type) and issubclass(op._raw_type, Enum): if isinstance(arg, str) and arg.isdigit(): try: arg = op._raw_type(int(arg)) diff --git a/tests/commands/__init__.py b/tests/commands/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/commands/test_pep563.py b/tests/commands/test_pep563.py new file mode 100644 index 0000000000..716b0cab73 --- /dev/null +++ b/tests/commands/test_pep563.py @@ -0,0 +1,163 @@ +""" +The MIT License (MIT) + +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +# PEP 563: all annotations in this module are stored as strings at runtime. +# This file intentionally uses `from __future__ import annotations` so that +# every callback defined here exercises the exact scenario reported in #513. +from __future__ import annotations + +import discord +from discord.commands import SlashCommand +from discord.commands.options import Option +from discord.enums import SlashCommandOptionType + +# --------------------------------------------------------------------------- +# Callbacks — defined here so their __annotations__ are PEP-563 strings. +# --------------------------------------------------------------------------- + + +async def _ann_member(ctx, user: Option(discord.Member, "A member")): + pass + + +async def _ann_str(ctx, name: Option(str, "A name")): + pass + + +async def _ann_int(ctx, count: Option(int, "A count")): + pass + + +async def _ann_role(ctx, role: Option(discord.Role, "A role")): + pass + + +async def _ann_not_required( + ctx, user: Option(discord.Member, "optional", required=False) +): + pass + + +async def _plain_str(ctx, name: str): + pass + + +async def _plain_int(ctx, count: int): + pass + + +async def _default_member( + ctx, user: discord.Member = Option(discord.Member, "A member") +): + pass + + +async def _default_not_required( + ctx, + user: discord.Member = Option( + discord.Member, "optional", required=False, default=None + ), +): + pass + + +# --------------------------------------------------------------------------- +# Option(...) as annotation under PEP 563 +# --------------------------------------------------------------------------- + + +class TestOptionAsAnnotation: + def test_member_input_type(self): + cmd = SlashCommand(_ann_member, name="test") + assert cmd.options[0].input_type == SlashCommandOptionType.user + + def test_str_input_type(self): + cmd = SlashCommand(_ann_str, name="test") + assert cmd.options[0].input_type == SlashCommandOptionType.string + + def test_int_input_type(self): + cmd = SlashCommand(_ann_int, name="test") + assert cmd.options[0].input_type == SlashCommandOptionType.integer + + def test_role_input_type(self): + cmd = SlashCommand(_ann_role, name="test") + assert cmd.options[0].input_type == SlashCommandOptionType.role + + def test_option_name_matches_param(self): + cmd = SlashCommand(_ann_member, name="test") + assert cmd.options[0].name == "user" + + def test_not_required_flag(self): + cmd = SlashCommand(_ann_not_required, name="test") + assert not cmd.options[0].required + + def test_raw_type_is_never_string(self): + # Before the fix, _raw_type would be a str like "discord.Member"; after it + # must always be an actual type or SlashCommandOptionType enum value. + for func in (_ann_member, _ann_str, _ann_int, _ann_role): + cmd = SlashCommand(func, name="test") + raw = cmd.options[0]._raw_type + assert isinstance( + raw, (type, SlashCommandOptionType) + ), f"{func.__name__}: _raw_type={raw!r} should be a class, not a string" + + +# --------------------------------------------------------------------------- +# Plain type annotations under PEP 563 — regression +# --------------------------------------------------------------------------- + + +class TestPlainAnnotationRegression: + def test_str_annotation(self): + cmd = SlashCommand(_plain_str, name="test") + assert cmd.options[0].input_type == SlashCommandOptionType.string + + def test_int_annotation(self): + cmd = SlashCommand(_plain_int, name="test") + assert cmd.options[0].input_type == SlashCommandOptionType.integer + + def test_raw_type_is_never_string(self): + for func in (_plain_str, _plain_int): + cmd = SlashCommand(func, name="test") + raw = cmd.options[0]._raw_type + assert isinstance( + raw, (type, SlashCommandOptionType) + ), f"{func.__name__}: _raw_type={raw!r} should be a class, not a string" + + +# --------------------------------------------------------------------------- +# Option(...) as default value — pre-existing workaround must keep working +# --------------------------------------------------------------------------- + + +class TestOptionAsDefaultRegression: + def test_member_default(self): + cmd = SlashCommand(_default_member, name="test") + assert cmd.options[0].input_type == SlashCommandOptionType.user + + def test_not_required_default(self): + cmd = SlashCommand(_default_not_required, name="test") + opt = cmd.options[0] + assert opt.input_type == SlashCommandOptionType.user + assert not opt.required