Skip to content

Commit db5c0c0

Browse files
committed
Fix how type parameters are collected from Protocol bases
1 parent ccbfd9b commit db5c0c0

2 files changed

Lines changed: 171 additions & 6 deletions

File tree

src/test_typing_extensions.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3580,12 +3580,14 @@ class C: pass
35803580

35813581
def test_defining_generic_protocols(self):
35823582
T = TypeVar('T')
3583+
T2 = TypeVar('T2')
35833584
S = TypeVar('S')
35843585
@runtime_checkable
35853586
class PR(Protocol[T, S]):
35863587
def meth(self): pass
35873588
class P(PR[int, T], Protocol[T]):
35883589
y = 1
3590+
self.assertEqual(P.__parameters__, (T,))
35893591
with self.assertRaises(TypeError):
35903592
issubclass(PR[int, T], PR)
35913593
with self.assertRaises(TypeError):
@@ -3594,16 +3596,23 @@ class P(PR[int, T], Protocol[T]):
35943596
PR[int]
35953597
with self.assertRaises(TypeError):
35963598
P[int, str]
3599+
with self.assertRaisesRegex(
3600+
TypeError,
3601+
re.escape('Some type variables (~S) are not listed in Protocol[~T, ~T2]'),
3602+
):
3603+
class ExtraTypeVars(P[S], Protocol[T, T2]): ...
35973604
if not TYPING_3_10_0:
35983605
with self.assertRaises(TypeError):
35993606
PR[int, 1]
36003607
with self.assertRaises(TypeError):
36013608
PR[int, ClassVar]
36023609
class C(PR[int, T]): pass
3610+
self.assertEqual(C.__parameters__, (T,))
36033611
self.assertIsInstance(C[str](), C)
36043612

36053613
def test_defining_generic_protocols_old_style(self):
36063614
T = TypeVar('T')
3615+
T2 = TypeVar('T2')
36073616
S = TypeVar('S')
36083617
@runtime_checkable
36093618
class PR(Protocol, Generic[T, S]):
@@ -3620,8 +3629,15 @@ class P(PR[int, str], Protocol):
36203629
PR[int, 1]
36213630
class P1(Protocol, Generic[T]):
36223631
def bar(self, x: T) -> str: ...
3632+
self.assertEqual(P1.__parameters__, (T,))
36233633
class P2(Generic[T], Protocol):
36243634
def bar(self, x: T) -> str: ...
3635+
self.assertEqual(P2.__parameters__, (T,))
3636+
msg = re.escape('Some type variables (~S) are not listed in Protocol[~T, ~T2]')
3637+
with self.assertRaisesRegex(TypeError, msg):
3638+
class ExtraTypeVars(P1[S], Protocol[T, T2]): ...
3639+
with self.assertRaisesRegex(TypeError, msg):
3640+
class ExtraTypeVars(P2[S], Protocol[T, T2]): ...
36253641
@runtime_checkable
36263642
class PSub(P1[str], Protocol):
36273643
x = 1
@@ -3634,9 +3650,33 @@ def bar(self, x: str) -> str:
36343650
with self.assertRaises(TypeError):
36353651
PR[int, ClassVar]
36363652

3653+
def test_protocol_parameter_order(self):
3654+
# https://github.com/python/cpython/issues/137191
3655+
T1 = TypeVar("T1")
3656+
T2 = TypeVar("T2", default=object)
3657+
3658+
class A(Protocol[T1]): ...
3659+
3660+
class B0(A[T2], Generic[T1, T2]): ...
3661+
self.assertEqual(B0.__parameters__, (T1, T2))
3662+
3663+
class B1(A[T2], Protocol, Generic[T1, T2]): ...
3664+
self.assertEqual(B1.__parameters__, (T1, T2))
3665+
3666+
class B2(A[T2], Protocol[T1, T2]): ...
3667+
self.assertEqual(B2.__parameters__, (T1, T2))
3668+
36373669
if hasattr(typing, "TypeAliasType"):
36383670
exec(textwrap.dedent(
36393671
"""
3672+
def test_pep695_protocol_parameter_order(self):
3673+
class A[T1](Protocol): ...
3674+
class B3[T1, T2](A[T2], Protocol):
3675+
@staticmethod
3676+
def get_typeparams():
3677+
return (T1, T2)
3678+
self.assertEqual(B3.__parameters__, B3.get_typeparams())
3679+
36403680
def test_pep695_generic_protocol_callable_members(self):
36413681
@runtime_checkable
36423682
class Foo[T](Protocol):

src/typing_extensions.py

Lines changed: 131 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3203,7 +3203,13 @@ def _is_unpacked_typevartuple(x) -> bool:
32033203
)
32043204

32053205

3206-
# Python 3.11+ _collect_type_vars was renamed to _collect_parameters
3206+
# - Python 3.11+ _collect_type_vars was renamed to _collect_parameters.
3207+
# Breakpoint: https://github.com/python/cpython/pull/31143
3208+
# - Python 3.13+ _collect_parameters was renamed to _collect_type_parameters.
3209+
# Breakpoint: https://github.com/python/cpython/pull/118900
3210+
# - Monkey patch Generic.__init_subclass__ on <3.15 to fix type parameter
3211+
# collection from Protocol bases with listed parameters.
3212+
# Breakpoint: https://github.com/python/cpython/pull/137281
32073213
if hasattr(typing, '_collect_type_vars'):
32083214
def _collect_type_vars(types, typevar_types=None):
32093215
"""Collect all type variable contained in types in order of
@@ -3253,21 +3259,82 @@ def _collect_type_vars(types, typevar_types=None):
32533259
tvars.append(collected)
32543260
return tuple(tvars)
32553261

3262+
def _generic_init_subclass(cls, *args, **kwargs):
3263+
super(Generic, cls).__init_subclass__(*args, **kwargs)
3264+
tvars = []
3265+
if '__orig_bases__' in cls.__dict__:
3266+
error = Generic in cls.__orig_bases__
3267+
else:
3268+
error = (Generic in cls.__bases__ and
3269+
cls.__name__ != 'Protocol' and
3270+
type(cls) not in (_TypedDictMeta, typing._TypedDictMeta))
3271+
if error:
3272+
raise TypeError("Cannot inherit from plain Generic")
3273+
if '__orig_bases__' in cls.__dict__:
3274+
typevar_types = (TypeVar, typing.TypeVar, ParamSpec)
3275+
if hasattr(typing, "ParamSpec"): # Python 3.10+
3276+
typevar_types += (typing.ParamSpec,)
3277+
tvars = _collect_type_vars(cls.__orig_bases__, typevar_types)
3278+
# Look for Generic[T1, ..., Tn].
3279+
# If found, tvars must be a subset of it.
3280+
# If not found, tvars is it.
3281+
# Also check for and reject plain Generic,
3282+
# and reject multiple Generic[...].
3283+
gvars = None
3284+
basename = None
3285+
for base in cls.__orig_bases__:
3286+
if (isinstance(base, typing._GenericAlias) and
3287+
base.__origin__ in (Generic, typing.Protocol, Protocol)):
3288+
if gvars is not None:
3289+
raise TypeError(
3290+
"Cannot inherit from Generic[...] multiple times."
3291+
)
3292+
gvars = base.__parameters__
3293+
basename = base.__origin__.__name__
3294+
if gvars is not None:
3295+
tvarset = set(tvars)
3296+
gvarset = set(gvars)
3297+
if not tvarset <= gvarset:
3298+
s_vars = ', '.join(str(t) for t in tvars if t not in gvarset)
3299+
s_args = ', '.join(str(g) for g in gvars)
3300+
raise TypeError(
3301+
f"Some type variables ({s_vars}) are"
3302+
f" not listed in {basename}[{s_args}]"
3303+
)
3304+
tvars = gvars
3305+
cls.__parameters__ = tuple(tvars)
3306+
32563307
typing._collect_type_vars = _collect_type_vars
3257-
else:
3258-
def _collect_parameters(args):
3308+
typing.Generic.__init_subclass__ = classmethod(_generic_init_subclass)
3309+
elif sys.version_info < (3, 15):
3310+
def _collect_parameters(
3311+
args,
3312+
*,
3313+
enforce_default_ordering=_marker,
3314+
validate_all=False,
3315+
):
32593316
"""Collect all type variables and parameter specifications in args
32603317
in order of first appearance (lexicographic order).
32613318
3319+
Having an explicit `Generic` or `Protocol` base class determines
3320+
the exact parameter order.
3321+
32623322
For example::
32633323
3264-
assert _collect_parameters((T, Callable[P, T])) == (T, P)
3324+
>>> P = ParamSpec('P')
3325+
>>> T = TypeVar('T')
3326+
>>> _collect_parameters((T, Callable[P, T]))
3327+
(~T, ~P)
3328+
>>> _collect_parameters((list[T], Generic[P, T]))
3329+
(~P, ~T)
32653330
"""
32663331
parameters = []
32673332

32683333
# A required TypeVarLike cannot appear after a TypeVarLike with default
32693334
# if it was a direct call to `Generic[]` or `Protocol[]`
3270-
enforce_default_ordering = _has_generic_or_protocol_as_origin()
3335+
if enforce_default_ordering is _marker:
3336+
enforce_default_ordering = _has_generic_or_protocol_as_origin()
3337+
32713338
default_encountered = False
32723339

32733340
# Also, a TypeVarLike with a default cannot appear after a TypeVarTuple
@@ -3302,6 +3369,17 @@ def _collect_parameters(args):
33023369
' follows type parameter with a default')
33033370

33043371
parameters.append(t)
3372+
elif (
3373+
not validate_all
3374+
and isinstance(t, typing._GenericAlias)
3375+
and t.__origin__ in (Generic, typing.Protocol, Protocol)
3376+
):
3377+
# If we see explicit `Generic[...]` or `Protocol[...]` base classes,
3378+
# we need to just copy them as-is.
3379+
# Unless `validate_all` is passed, in this case it means that
3380+
# we are doing a validation of `Generic` subclasses,
3381+
# then we collect all unique parameters to be able to inspect them.
3382+
parameters = t.__parameters__
33053383
else:
33063384
if _is_unpacked_typevartuple(t):
33073385
type_var_tuple_encountered = True
@@ -3311,8 +3389,55 @@ def _collect_parameters(args):
33113389

33123390
return tuple(parameters)
33133391

3314-
if not _PEP_696_IMPLEMENTED:
3392+
def _generic_init_subclass(cls, *args, **kwargs):
3393+
super(Generic, cls).__init_subclass__(*args, **kwargs)
3394+
tvars = []
3395+
if '__orig_bases__' in cls.__dict__:
3396+
error = Generic in cls.__orig_bases__
3397+
else:
3398+
error = (Generic in cls.__bases__ and
3399+
cls.__name__ != 'Protocol' and
3400+
type(cls) not in (_TypedDictMeta, typing._TypedDictMeta))
3401+
if error:
3402+
raise TypeError("Cannot inherit from plain Generic")
3403+
if '__orig_bases__' in cls.__dict__:
3404+
tvars = _collect_parameters(cls.__orig_bases__, validate_all=True)
3405+
# Look for Generic[T1, ..., Tn].
3406+
# If found, tvars must be a subset of it.
3407+
# If not found, tvars is it.
3408+
# Also check for and reject plain Generic,
3409+
# and reject multiple Generic[...].
3410+
gvars = None
3411+
basename = None
3412+
for base in cls.__orig_bases__:
3413+
if (isinstance(base, typing._GenericAlias) and
3414+
base.__origin__ in (Generic, typing.Protocol, Protocol)):
3415+
if gvars is not None:
3416+
raise TypeError(
3417+
"Cannot inherit from Generic[...] multiple times."
3418+
)
3419+
gvars = base.__parameters__
3420+
basename = base.__origin__.__name__
3421+
if gvars is not None:
3422+
tvarset = set(tvars)
3423+
gvarset = set(gvars)
3424+
if not tvarset <= gvarset:
3425+
s_vars = ', '.join(str(t) for t in tvars if t not in gvarset)
3426+
s_args = ', '.join(str(g) for g in gvars)
3427+
raise TypeError(
3428+
f"Some type variables ({s_vars}) are"
3429+
f" not listed in {basename}[{s_args}]"
3430+
)
3431+
tvars = gvars
3432+
cls.__parameters__ = tuple(tvars)
3433+
3434+
if _PEP_696_IMPLEMENTED:
3435+
typing._collect_type_parameters = _collect_parameters
3436+
typing._generic_init_subclass = _generic_init_subclass
3437+
else:
33153438
typing._collect_parameters = _collect_parameters
3439+
typing.Generic.__init_subclass__ = classmethod(_generic_init_subclass)
3440+
33163441

33173442
# Backport typing.NamedTuple as it exists in Python 3.13.
33183443
# In 3.11, the ability to define generic `NamedTuple`s was supported.

0 commit comments

Comments
 (0)