Skip to content

Commit bfe2098

Browse files
committed
.
1 parent 06a5f5d commit bfe2098

5 files changed

Lines changed: 118 additions & 38 deletions

File tree

mypy/checker.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@
3232
)
3333
from mypy.checkpattern import PatternChecker
3434
from mypy.constraints import SUPERTYPE_OF
35-
from mypy.erasetype import erase_type, erase_typevars, remove_instance_last_known_values
35+
from mypy.erasetype import (
36+
erase_type,
37+
erase_typevars,
38+
remove_instance_last_known_values,
39+
shallow_erase_type_for_equality,
40+
)
3641
from mypy.errorcodes import TYPE_VAR, UNUSED_AWAITABLE, UNUSED_COROUTINE, ErrorCode
3742
from mypy.errors import (
3843
ErrorInfo,
@@ -45,7 +50,7 @@
4550
from mypy.expandtype import expand_type
4651
from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash
4752
from mypy.maptype import map_instance_to_supertype
48-
from mypy.meet import is_overlapping_erased_types, is_overlapping_types, meet_types
53+
from mypy.meet import is_overlapping_types, meet_types
4954
from mypy.message_registry import ErrorMessage
5055
from mypy.messages import (
5156
SUGGESTED_TEST_FIXTURES,
@@ -6540,19 +6545,6 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
65406545
narrowable_indices={0},
65416546
)
65426547

6543-
# We only try and narrow away 'None' for now
6544-
if (
6545-
not is_unreachable_map(if_map)
6546-
and is_overlapping_none(item_type)
6547-
and not is_overlapping_none(collection_item_type)
6548-
and not (
6549-
isinstance(collection_item_type, Instance)
6550-
and collection_item_type.type.fullname == "builtins.object"
6551-
)
6552-
and is_overlapping_erased_types(item_type, collection_item_type)
6553-
):
6554-
if_map[operands[left_index]] = remove_optional(item_type)
6555-
65566548
if right_index in narrowable_operand_index_to_hash:
65576549
if_type, else_type = self.conditional_types_for_iterable(
65586550
item_type, iterable_type
@@ -6676,6 +6668,9 @@ def narrow_type_by_identity_equality(
66766668
target_type = operand_types[j]
66776669
if should_coerce_literals:
66786670
target_type = coerce_to_literal(target_type)
6671+
# Type A[T1] could compare equal to A[T2] even if T1 is disjoint from T2
6672+
# e.g. cast(list[int], []) == cast(list[str], [])
6673+
target_type = shallow_erase_type_for_equality(target_type)
66796674

66806675
if (
66816676
# See comments in ambiguous_enum_equality_keys
@@ -6689,7 +6684,7 @@ def narrow_type_by_identity_equality(
66896684
if_map, else_map = conditional_types_to_typemaps(
66906685
operands[i], *conditional_types(expr_type, [target])
66916686
)
6692-
if is_target_for_value_narrowing(get_proper_type(target_type)):
6687+
if is_target_for_value_narrowing(target_type):
66936688
all_if_maps.append(if_map)
66946689
all_else_maps.append(else_map)
66956690
else:
@@ -6758,13 +6753,15 @@ def narrow_type_by_identity_equality(
67586753
target_type = operand_types[j]
67596754
if should_coerce_literals:
67606755
target_type = coerce_to_literal(target_type)
6756+
target_type = shallow_erase_type_for_equality(target_type)
6757+
67616758
target = TypeRange(target_type, is_upper_bound=False)
67626759

67636760
if_map, else_map = conditional_types_to_typemaps(
67646761
operands[i], *conditional_types(expr_type, [target], default=expr_type)
67656762
)
67666763
or_if_maps.append(if_map)
6767-
if is_target_for_value_narrowing(get_proper_type(target_type)):
6764+
if is_target_for_value_narrowing(target_type):
67686765
or_else_maps.append(else_map)
67696766

67706767
all_if_maps.append(reduce_or_conditional_type_maps(or_if_maps))
@@ -8609,13 +8606,7 @@ def reduce_and_conditional_type_maps(ms: list[TypeMap], *, use_meet: bool) -> Ty
86098606
return result
86108607

86118608

8612-
BUILTINS_CUSTOM_EQ_CHECKS: Final = {
8613-
"builtins.bytearray",
8614-
"builtins.memoryview",
8615-
"builtins.list",
8616-
"builtins.dict",
8617-
"builtins.set",
8618-
}
8609+
BUILTINS_CUSTOM_EQ_CHECKS: Final = {"builtins.bytearray", "builtins.memoryview"}
86198610

86208611

86218612
def has_custom_eq_checks(t: Type) -> bool:

mypy/erasetype.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,17 @@ def visit_union_type(self, t: UnionType) -> Type:
285285
merged.append(orig_item)
286286
return UnionType.make_union(merged)
287287
return new
288+
289+
290+
def shallow_erase_type_for_equality(typ: Type) -> ProperType:
291+
"""Erase type variables from Instance's inside a type."""
292+
p_typ = get_proper_type(typ)
293+
if isinstance(p_typ, Instance):
294+
if not p_typ.args:
295+
return p_typ
296+
args = erased_vars(p_typ.type.defn.type_vars, TypeOfAny.special_form)
297+
return Instance(p_typ.type, args, p_typ.line)
298+
if isinstance(p_typ, UnionType):
299+
items = [shallow_erase_type_for_equality(item) for item in p_typ.items]
300+
return UnionType.make_union(items)
301+
return p_typ

mypy/meet.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from collections.abc import Callable
44

55
from mypy import join
6-
from mypy.erasetype import erase_type
76
from mypy.maptype import map_instance_to_supertype
87
from mypy.state import state
98
from mypy.subtypes import (
@@ -657,18 +656,6 @@ def _type_object_overlap(left: Type, right: Type) -> bool:
657656
return False
658657

659658

660-
def is_overlapping_erased_types(
661-
left: Type, right: Type, *, ignore_promotions: bool = False
662-
) -> bool:
663-
"""The same as 'is_overlapping_erased_types', except the types are erased first."""
664-
return is_overlapping_types(
665-
erase_type(left),
666-
erase_type(right),
667-
ignore_promotions=ignore_promotions,
668-
prohibit_none_typevar_overlap=True,
669-
)
670-
671-
672659
def are_typed_dicts_overlapping(
673660
left: TypedDictType, right: TypedDictType, is_overlapping: Callable[[Type, Type], bool]
674661
) -> bool:

test-data/unit/check-narrowing.test

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,92 @@ def f(x: Custom, y: CustomSub):
10651065
reveal_type(y) # N: Revealed type is "__main__.CustomSub"
10661066
[builtins fixtures/tuple.pyi]
10671067

1068+
[case testNarrowingCustomEqualityGeneric]
1069+
# flags: --strict-equality --warn-unreachable
1070+
from __future__ import annotations
1071+
from typing import Union
1072+
1073+
class Custom:
1074+
def __eq__(self, other: object) -> bool:
1075+
raise
1076+
1077+
class Default: ...
1078+
1079+
def f1(x: list[Custom] | Default, y: list[int]):
1080+
if x == y: # E: Non-overlapping equality check (left operand type: "list[Custom] | Default", right operand type: "list[int]")
1081+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom]"
1082+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]"
1083+
else:
1084+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom] | __main__.Default"
1085+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]"
1086+
1087+
f1([], [])
1088+
1089+
def f2(x: list[Custom] | Default, y: list[int] | list[Default]):
1090+
if x == y: # E: Non-overlapping equality check (left operand type: "list[Custom] | Default", right operand type: "list[int] | list[Default]")
1091+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom]"
1092+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"
1093+
else:
1094+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom] | __main__.Default"
1095+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"
1096+
1097+
listcustom_or_default = Union[list[Custom], Default]
1098+
listint_or_default = Union[list[int], list[Default]]
1099+
1100+
def f2_with_alias(x: listcustom_or_default, y: listint_or_default):
1101+
if x == y: # E: Non-overlapping equality check (left operand type: "list[Custom] | Default", right operand type: "list[int] | list[Default]")
1102+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom]"
1103+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"
1104+
else:
1105+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom] | __main__.Default"
1106+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"
1107+
1108+
def f3(x: Custom | dict[str, str], y: dict[int, int]):
1109+
if x == y:
1110+
reveal_type(x) # N: Revealed type is "__main__.Custom | builtins.dict[builtins.str, builtins.str]"
1111+
reveal_type(y) # N: Revealed type is "builtins.dict[builtins.int, builtins.int]"
1112+
else:
1113+
reveal_type(x) # N: Revealed type is "__main__.Custom | builtins.dict[builtins.str, builtins.str]"
1114+
reveal_type(y) # N: Revealed type is "builtins.dict[builtins.int, builtins.int]"
1115+
[builtins fixtures/primitives.pyi]
1116+
1117+
[case testNarrowingRecursiveCallable]
1118+
# flags: --strict-equality --warn-unreachable
1119+
from __future__ import annotations
1120+
from typing import Callable
1121+
1122+
class A: ...
1123+
class B: ...
1124+
1125+
T = Callable[[A], "S"]
1126+
S = Callable[[B], "T"]
1127+
1128+
def f(x: S, y: T):
1129+
if x == y: # E: Unsupported left operand type for == ("Callable[[B], T]")
1130+
reveal_type(x) # N: Revealed type is "def (__main__.B) -> def (__main__.A) -> ..."
1131+
reveal_type(y) # N: Revealed type is "def (__main__.A) -> def (__main__.B) -> ..."
1132+
else:
1133+
reveal_type(x) # N: Revealed type is "def (__main__.B) -> def (__main__.A) -> ..."
1134+
reveal_type(y) # N: Revealed type is "def (__main__.A) -> def (__main__.B) -> ..."
1135+
[builtins fixtures/tuple.pyi]
1136+
1137+
[case testNarrowingRecursiveUnion]
1138+
# flags: --strict-equality --warn-unreachable
1139+
from __future__ import annotations
1140+
from typing import Union
1141+
1142+
class A: ...
1143+
class B: ...
1144+
1145+
T = Union[A, "S"]
1146+
S = Union[B, "T"] # E: Invalid recursive alias: a union item of itself
1147+
1148+
def f(x: S, y: T):
1149+
if x == y:
1150+
reveal_type(x) # N: Revealed type is "Any"
1151+
reveal_type(y) # N: Revealed type is "__main__.A | Any"
1152+
[builtins fixtures/tuple.pyi]
1153+
10681154
[case testNarrowingUnreachableCases]
10691155
# flags: --strict-equality --warn-unreachable
10701156
from typing import Literal, Union

test-data/unit/check-tuples.test

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1540,7 +1540,9 @@ class B: pass
15401540

15411541
def f1(possibles: Tuple[int, Tuple[A]], x: Optional[Tuple[B]]):
15421542
if x in possibles:
1543-
reveal_type(x) # N: Revealed type is "tuple[__main__.B]"
1543+
# TODO: this branch is actually unreachable
1544+
# This is an easy fix: https://github.com/python/mypy/pull/20660
1545+
reveal_type(x) # N: Revealed type is "tuple[__main__.B] | None"
15441546
else:
15451547
reveal_type(x) # N: Revealed type is "tuple[__main__.B] | None"
15461548

0 commit comments

Comments
 (0)