|
15 | 15 |
|
16 | 16 | from abc import abstractmethod |
17 | 17 | from collections.abc import Iterable, Sequence |
18 | | -from typing import Any, Final, Generic, TypeVar, cast |
| 18 | +from typing import TYPE_CHECKING, Any, Final, Generic, TypeVar, cast |
| 19 | + |
| 20 | +if TYPE_CHECKING: |
| 21 | + from mypy.nodes import TypeAlias |
19 | 22 |
|
20 | 23 | from mypy_extensions import mypyc_attr, trait |
21 | 24 |
|
@@ -356,9 +359,12 @@ class TypeQuery(SyntheticTypeVisitor[T]): |
356 | 359 | """ |
357 | 360 |
|
358 | 361 | def __init__(self) -> None: |
359 | | - # Keep track of the type aliases already visited. This is needed to avoid |
360 | | - # infinite recursion on types like A = Union[int, List[A]]. |
361 | | - self.seen_aliases: set[TypeAliasType] | None = None |
| 362 | + # Keep track of the type alias definitions already visited. This is needed |
| 363 | + # to avoid infinite recursion on recursive type aliases. We track by the |
| 364 | + # underlying TypeAlias node (not TypeAliasType) so that recursive aliases |
| 365 | + # with varying type arguments (e.g. A[P] -> A[Concatenate[int, P]]) are |
| 366 | + # still caught. |
| 367 | + self.seen_aliases: set[TypeAlias] | None = None |
362 | 368 | # By default, we eagerly expand type aliases, and query also types in the |
363 | 369 | # alias target. In most cases this is a desired behavior, but we may want |
364 | 370 | # to skip targets in some cases (e.g. when collecting type variables). |
@@ -447,13 +453,14 @@ def visit_placeholder_type(self, t: PlaceholderType, /) -> T: |
447 | 453 | def visit_type_alias_type(self, t: TypeAliasType, /) -> T: |
448 | 454 | if self.skip_alias_target: |
449 | 455 | return self.query_types(t.args) |
450 | | - # Skip type aliases already visited types to avoid infinite recursion |
451 | | - # (also use this as a simple-minded cache). |
| 456 | + # Skip type aliases already visited to avoid infinite recursion. |
| 457 | + # We track by the TypeAlias node so that recursive aliases with varying |
| 458 | + # type arguments are still caught. |
452 | 459 | if self.seen_aliases is None: |
453 | 460 | self.seen_aliases = set() |
454 | | - elif t in self.seen_aliases: |
455 | | - return self.strategy([]) |
456 | | - self.seen_aliases.add(t) |
| 461 | + elif t.alias in self.seen_aliases: |
| 462 | + return self.query_types(t.args) |
| 463 | + self.seen_aliases.add(t.alias) |
457 | 464 | return get_proper_type(t).accept(self) |
458 | 465 |
|
459 | 466 | def query_types(self, types: Iterable[Type]) -> T: |
@@ -487,10 +494,12 @@ def __init__(self, strategy: int) -> None: |
487 | 494 | else: |
488 | 495 | assert strategy == ALL_STRATEGY |
489 | 496 | self.default = True |
490 | | - # Keep track of the type aliases already visited. This is needed to avoid |
491 | | - # infinite recursion on types like A = Union[int, List[A]]. An empty set is |
492 | | - # represented as None as a micro-optimization. |
493 | | - self.seen_aliases: set[TypeAliasType] | None = None |
| 497 | + # Keep track of the type alias definitions already visited. This is needed |
| 498 | + # to avoid infinite recursion on recursive type aliases. We track by the |
| 499 | + # underlying TypeAlias node (not TypeAliasType) so that recursive aliases |
| 500 | + # with varying type arguments (e.g. A[P] -> A[Concatenate[int, P]]) are |
| 501 | + # still caught. An empty set is represented as None as a micro-optimization. |
| 502 | + self.seen_aliases: set[TypeAlias] | None = None |
494 | 503 | # By default, we eagerly expand type aliases, and query also types in the |
495 | 504 | # alias target. In most cases this is a desired behavior, but we may want |
496 | 505 | # to skip targets in some cases (e.g. when collecting type variables). |
@@ -588,13 +597,14 @@ def visit_placeholder_type(self, t: PlaceholderType, /) -> bool: |
588 | 597 | def visit_type_alias_type(self, t: TypeAliasType, /) -> bool: |
589 | 598 | if self.skip_alias_target: |
590 | 599 | return self.query_types(t.args) |
591 | | - # Skip type aliases already visited types to avoid infinite recursion |
592 | | - # (also use this as a simple-minded cache). |
| 600 | + # Skip type aliases already visited to avoid infinite recursion. |
| 601 | + # We track by the TypeAlias node so that recursive aliases with varying |
| 602 | + # type arguments are still caught. |
593 | 603 | if self.seen_aliases is None: |
594 | 604 | self.seen_aliases = set() |
595 | | - elif t in self.seen_aliases: |
596 | | - return self.default |
597 | | - self.seen_aliases.add(t) |
| 605 | + elif t.alias in self.seen_aliases: |
| 606 | + return self.query_types(t.args) |
| 607 | + self.seen_aliases.add(t.alias) |
598 | 608 | return get_proper_type(t).accept(self) |
599 | 609 |
|
600 | 610 | def query_types(self, types: list[Type] | tuple[Type, ...]) -> bool: |
|
0 commit comments