Skip to content

Commit 0c8bc05

Browse files
committed
Generalize Params unwrapping in _callable_type_to_signature
Extract _unwrap_params helper that handles Params[...], list, and tuple[...] uniformly for all three callable kinds (Callable, classmethod, staticmethod). Standard callables produce simple positional-only signatures; Param types without Params wrapper are rejected.
1 parent 9518ad0 commit 0c8bc05

File tree

3 files changed

+41
-139
lines changed

3 files changed

+41
-139
lines changed

tests/test_type_eval.py

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,58 +1828,6 @@ def test_callable_to_signature_01():
18281828
)
18291829

18301830

1831-
def test_callable_to_signature_02():
1832-
from typemap.type_eval._eval_operators import _callable_type_to_signature
1833-
1834-
class C:
1835-
pass
1836-
1837-
callable_type = classmethod[
1838-
C,
1839-
tuple[
1840-
Param[None, int],
1841-
Param[Literal["b"], int],
1842-
Param[Literal["c"], int, Literal["default"]],
1843-
Param[None, int, Literal["*"]],
1844-
Param[Literal["d"], int, Literal["keyword"]],
1845-
Param[Literal["e"], int, Literal["default", "keyword"]],
1846-
Param[None, int, Literal["**"]],
1847-
],
1848-
int,
1849-
]
1850-
sig = _callable_type_to_signature(callable_type)
1851-
assert str(sig) == (
1852-
'(cls: tests.test_type_eval.test_callable_to_signature_02.<locals>.C, '
1853-
'_arg1: int, /, b: int, c: int = ..., *args: int, '
1854-
'd: int, e: int = ..., **kwargs: int) -> int'
1855-
)
1856-
1857-
1858-
def test_callable_to_signature_03():
1859-
from typemap.type_eval._eval_operators import _callable_type_to_signature
1860-
1861-
class C:
1862-
pass
1863-
1864-
callable_type = staticmethod[
1865-
tuple[
1866-
Param[None, int],
1867-
Param[Literal["b"], int],
1868-
Param[Literal["c"], int, Literal["default"]],
1869-
Param[None, int, Literal["*"]],
1870-
Param[Literal["d"], int, Literal["keyword"]],
1871-
Param[Literal["e"], int, Literal["default", "keyword"]],
1872-
Param[None, int, Literal["**"]],
1873-
],
1874-
int,
1875-
]
1876-
sig = _callable_type_to_signature(callable_type)
1877-
assert str(sig) == (
1878-
'(_arg0: int, /, b: int, c: int = ..., *args: int, '
1879-
'd: int, e: int = ..., **kwargs: int) -> int'
1880-
)
1881-
1882-
18831831
def test_new_protocol_with_methods_01():
18841832
class C:
18851833
def member_method(self, x: int) -> int: ...

typemap/type_eval/_apply_generic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ def substitute(ty, args):
8686
ty, (typing_GenericAlias, types.GenericAlias, types.UnionType)
8787
):
8888
return ty.__origin__[*[substitute(t, args) for t in ty.__args__]]
89+
elif isinstance(ty, list):
90+
return [substitute(t, args) for t in ty]
8991
else:
9092
return ty
9193

typemap/type_eval/_eval_operators.py

Lines changed: 39 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -492,90 +492,44 @@ def __repr__(self):
492492
_DUMMY_DEFAULT = _DummyDefault()
493493

494494

495-
def _callable_type_to_signature(callable_type: object) -> inspect.Signature:
496-
"""Convert a Callable type to an inspect.Signature.
497-
498-
Extended callables use the form:
499-
Callable[Params[Param[name, type, quals], ...], return_type]
495+
def _unwrap_params(param_types) -> list:
496+
"""Unwrap params into a list of Param types.
500497
501-
Standard callables use the form:
502-
Callable[[type, ...], return_type]
498+
Accepts Params[...] (extended format), or a list/tuple of plain
499+
types (standard format, converted to positional-only Params).
503500
"""
504-
args = typing.get_args(callable_type)
505-
if (
506-
isinstance(callable_type, types.GenericAlias)
507-
and callable_type.__origin__ is classmethod
508-
):
509-
if len(args) != 3:
510-
raise TypeError(
511-
f"Expected classmethod[cls, [...], ret], got {callable_type}"
512-
)
513-
514-
receiver, param_types, return_type = typing.get_args(callable_type)
515-
param_types = [
516-
Param[
517-
typing.Literal["cls"],
518-
receiver, # type: ignore[valid-type]
519-
typing.Literal["positional"],
520-
],
521-
*typing.get_args(param_types),
522-
]
523-
524-
elif (
525-
isinstance(callable_type, types.GenericAlias)
526-
and callable_type.__origin__ is staticmethod
527-
):
528-
if len(args) != 2:
529-
raise TypeError(
530-
f"Expected staticmethod[...], ret], got {callable_type}"
531-
)
532-
533-
param_types, return_type = typing.get_args(callable_type)
534-
param_types = list(typing.get_args(param_types))
501+
if typing.get_origin(param_types) is Params:
502+
return list(typing.get_args(param_types))
535503

504+
if isinstance(param_types, (list, tuple)):
505+
items = list(param_types)
536506
else:
537-
if len(args) != 2:
507+
raise TypeError(
508+
f"Expected Params[...] or list of types, got {param_types}"
509+
)
510+
# Error if someone passes Param types without Params wrapper
511+
for t in items:
512+
if typing.get_origin(t) is Param:
538513
raise TypeError(
539-
f"Expected Callable[[...], ret], got {callable_type}"
514+
f"Param types must be wrapped in Params[...], got [{t}, ...]"
540515
)
516+
# Convert standard types to positional-only Params
517+
return [
518+
Param[typing.Literal[None], t] # type: ignore[valid-type]
519+
for t in items
520+
]
541521

542-
param_types, return_type = args
543-
# Unwrap Params wrapper
544-
if typing.get_origin(param_types) is Params:
545-
param_types = list(typing.get_args(param_types))
546-
else:
547-
# Standard callable (no Params wrapping) — build simple
548-
# positional parameters from the type list
549-
if isinstance(param_types, (list, tuple)):
550-
# Error if someone passes Param types without Params wrapper
551-
for t in param_types:
552-
if typing.get_origin(t) is Param:
553-
raise TypeError(
554-
f"Param types must be wrapped in Params[...], "
555-
f"got Callable[[{t}, ...], ...]"
556-
)
557-
params = []
558-
for i, t in enumerate(param_types):
559-
params.append(
560-
inspect.Parameter(
561-
f"_arg{i}",
562-
kind=inspect.Parameter.POSITIONAL_ONLY,
563-
annotation=t,
564-
)
565-
)
566-
if return_type is type(None):
567-
return_type = None
568-
return inspect.Signature(
569-
parameters=params,
570-
return_annotation=return_type,
571-
)
572-
raise TypeError(
573-
f"Expected Params[...] or list of types, got {param_types}"
574-
)
575522

576-
# Handle the case where param_types is a list of Param types
577-
if not isinstance(param_types, (list, tuple)):
578-
raise TypeError(f"Expected list of Param types, got {param_types}")
523+
def _callable_type_to_signature(callable_type: object) -> inspect.Signature:
524+
"""Convert a Callable[Params[...], ret] type to an inspect.Signature."""
525+
args = typing.get_args(callable_type)
526+
if len(args) != 2:
527+
raise TypeError(
528+
f"Expected Callable[Params[...], ret], got {callable_type}"
529+
)
530+
531+
raw_params, return_type = args
532+
param_types = _unwrap_params(raw_params)
579533

580534
parameters: list[inspect.Parameter] = []
581535
saw_keyword_only = False
@@ -706,10 +660,11 @@ def _callable_type_to_method(name, typ, ctx):
706660

707661
if head is classmethod:
708662
# XXX: handle other amounts
709-
cls, params, ret = typing.get_args(typ)
663+
cls, raw_params, ret = typing.get_args(typ)
664+
param_list = _unwrap_params(raw_params)
710665
# We have to make class positional only if there is some other
711666
# positional only argument. Annoying!
712-
has_pos_only = any(_is_pos_only(p) for p in typing.get_args(params))
667+
has_pos_only = any(_is_pos_only(p) for p in param_list)
713668
quals = typing.Literal["positional"] if has_pos_only else typing.Never
714669
# Override the receiver type with type[Self].
715670
if name == "__init_subclass__" and isinstance(cls, typing.TypeVar):
@@ -718,17 +673,14 @@ def _callable_type_to_method(name, typ, ctx):
718673
else:
719674
cls_typ = type[typing.Self] # type: ignore[name-defined]
720675
cls_param = Param[typing.Literal["cls"], cls_typ, quals]
721-
typ = typing.Callable[Params[cls_param, *typing.get_args(params)], ret]
676+
typ = typing.Callable[Params[cls_param, *param_list], ret]
722677
elif head is staticmethod:
723-
params, ret = typing.get_args(typ)
724-
typ = typing.Callable[Params[*typing.get_args(params)], ret]
678+
raw_params, ret = typing.get_args(typ)
679+
param_list = _unwrap_params(raw_params)
680+
typ = typing.Callable[Params[*param_list], ret]
725681
else:
726-
params, ret = typing.get_args(typ)
727-
# Unwrap Params wrapper if present
728-
if typing.get_origin(params) is Params:
729-
param_list = list(typing.get_args(params))
730-
else:
731-
param_list = list(params)
682+
raw_params, ret = typing.get_args(typ)
683+
param_list = _unwrap_params(raw_params)
732684
# Override the annotations for methods
733685
# - use Self for the "self" param, otherwise the fully qualified cls
734686
# name gets used. This ends up being long and annoying to handle.

0 commit comments

Comments
 (0)