Skip to content

Commit e900923

Browse files
authored
stubtest: attempt to resolve decorators from their type (#20867)
Fixes #19689
1 parent e413220 commit e900923

File tree

2 files changed

+111
-3
lines changed

2 files changed

+111
-3
lines changed

mypy/stubtest.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import enum
1313
import functools
1414
import importlib
15-
import importlib.machinery
1615
import inspect
1716
import os
1817
import pkgutil
@@ -35,11 +34,11 @@
3534

3635
import mypy.build
3736
import mypy.checkexpr
38-
import mypy.checkmember
3937
import mypy.erasetype
4038
import mypy.modulefinder
4139
import mypy.nodes
4240
import mypy.state
41+
import mypy.subtypes
4342
import mypy.types
4443
import mypy.version
4544
from mypy import nodes
@@ -1529,7 +1528,7 @@ def apply_decorator_to_funcitem(
15291528
):
15301529
return func
15311530
if decorator.fullname == "builtins.classmethod":
1532-
if func.arguments[0].variable.name not in ("cls", "mcs", "metacls"):
1531+
if func.arguments[0].variable.name not in ("_cls", "cls", "mcs", "metacls"):
15331532
raise StubtestFailure(
15341533
f"unexpected class parameter name {func.arguments[0].variable.name!r} "
15351534
f"in {dec.fullname}"
@@ -1547,11 +1546,70 @@ def apply_decorator_to_funcitem(
15471546
for decorator in dec.original_decorators:
15481547
resulting_func = apply_decorator_to_funcitem(decorator, func)
15491548
if resulting_func is None:
1549+
# We couldn't figure out how to apply the decorator by transforming nodes, so try to
1550+
# reconstitute a FuncDef from the resulting type of the decorator
1551+
# This is worse because e.g. we lose the values of defaults
1552+
dec_type = mypy.types.get_proper_type(dec.type)
1553+
callable_type = None
1554+
if isinstance(dec_type, mypy.types.Instance):
1555+
callable_type = mypy.subtypes.find_member(
1556+
"__call__", dec_type, dec_type, is_operator=True
1557+
)
1558+
elif isinstance(dec_type, mypy.types.CallableType):
1559+
callable_type = dec_type
1560+
1561+
callable_type = mypy.types.get_proper_type(callable_type)
1562+
if isinstance(callable_type, mypy.types.CallableType):
1563+
return _resolve_funcitem_from_callable_type(dec, callable_type)
15501564
return None
1565+
15511566
func = resulting_func
15521567
return func
15531568

15541569

1570+
def _resolve_funcitem_from_callable_type(
1571+
dec: nodes.Decorator, typ: mypy.types.CallableType
1572+
) -> nodes.FuncDef | None:
1573+
if (
1574+
typ.arg_kinds == [nodes.ARG_STAR, nodes.ARG_STAR2]
1575+
and (var_arg := typ.var_arg()) is not None
1576+
and isinstance(mypy.types.get_proper_type(var_arg.typ), mypy.types.AnyType)
1577+
and (var_kwarg := typ.kw_arg()) is not None
1578+
and isinstance(mypy.types.get_proper_type(var_kwarg.typ), mypy.types.AnyType)
1579+
):
1580+
# There isn't a FuncDef we can invent corresponding to a Callable[..., T]
1581+
return None
1582+
1583+
args: list[nodes.Argument] = []
1584+
for i, (arg_type, arg_kind, arg_name) in enumerate(
1585+
zip(typ.arg_types, typ.arg_kinds, typ.arg_names, strict=True)
1586+
):
1587+
var_name = arg_name if arg_name is not None else f"__arg{i}"
1588+
var = nodes.Var(var_name, arg_type)
1589+
pos_only = arg_name is None and arg_kind == nodes.ARG_POS
1590+
args.append(
1591+
nodes.Argument(
1592+
variable=var,
1593+
type_annotation=arg_type,
1594+
initializer=None, # CallableType doesn't store the values of defaults
1595+
kind=arg_kind,
1596+
pos_only=pos_only,
1597+
)
1598+
)
1599+
1600+
if dec.func.is_class:
1601+
if not args:
1602+
return None
1603+
# Munge classmethods, similar to logic in _resolve_funcitem_from_decorator
1604+
if args[0].variable.name not in ("_cls", "cls", "mcs", "metacls"):
1605+
return None
1606+
args.pop(0)
1607+
1608+
ret = nodes.FuncDef(name=typ.name or "", arguments=args, body=nodes.Block([]), typ=typ)
1609+
ret.is_class = dec.func.is_class
1610+
return ret
1611+
1612+
15551613
@verify.register(nodes.Decorator)
15561614
def verify_decorator(
15571615
stub: nodes.Decorator, runtime: MaybeMissing[Any], object_path: list[str]

mypy/test/teststubtest.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,56 @@ def f(a, *args): ...
900900
error=None,
901901
)
902902

903+
@collect_cases
904+
def test_decorated_overload(self) -> Iterator[Case]:
905+
yield Case(
906+
stub="""
907+
from typing import overload
908+
909+
class _dec1:
910+
def __init__(self, func: object) -> None: ...
911+
def __call__(self, x: str) -> str: ...
912+
913+
@overload
914+
def good1(x: int) -> int: ...
915+
@overload
916+
@_dec1
917+
def good1(unrelated: int, whatever: str) -> str: ...
918+
""",
919+
runtime="def good1(x): ...",
920+
error=None,
921+
)
922+
yield Case(
923+
stub="""
924+
class _dec2:
925+
def __init__(self, func: object) -> None: ...
926+
def __call__(self, x: str, y: int) -> str: ...
927+
928+
@overload
929+
def good2(x: int) -> str: ...
930+
@overload
931+
@_dec2
932+
def good2(unrelated: int, whatever: str) -> str: ...
933+
""",
934+
runtime="def good2(x, y=...): ...",
935+
error=None,
936+
)
937+
yield Case(
938+
stub="""
939+
class _dec3:
940+
def __init__(self, func: object) -> None: ...
941+
def __call__(self, x: str, y: int) -> str: ...
942+
943+
@overload
944+
def bad(x: int) -> str: ...
945+
@overload
946+
@_dec3
947+
def bad(unrelated: int, whatever: str) -> str: ...
948+
""",
949+
runtime="def bad(x): ...",
950+
error="bad",
951+
)
952+
903953
@collect_cases
904954
def test_property(self) -> Iterator[Case]:
905955
yield Case(

0 commit comments

Comments
 (0)