Skip to content

Commit 798eae0

Browse files
committed
Better handling of generics when narrowing
Notably we preserve behaviour on the testNarrowingCollections test I added
1 parent 87e9425 commit 798eae0

3 files changed

Lines changed: 135 additions & 22 deletions

File tree

mypy/checker.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
)
3333
from mypy.checkpattern import PatternChecker
3434
from mypy.constraints import SUPERTYPE_OF
35-
from mypy.erasetype import erase_type, erase_typevars, remove_instance_last_known_values
35+
from mypy.erasetype import erase_type, erase_typevars, shallow_erase_type_for_equality, remove_instance_last_known_values
3636
from mypy.errorcodes import TYPE_VAR, UNUSED_AWAITABLE, UNUSED_COROUTINE, ErrorCode
3737
from mypy.errors import (
3838
ErrorInfo,
@@ -6540,6 +6540,9 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
65406540
narrowable_indices={0},
65416541
)
65426542

6543+
# TODO: This remove_optional code should no longer be needed. The only
6544+
# thing it does is paper over a pre-existing deficiency in equality
6545+
# narrowing w.r.t to enums.
65436546
# We only try and narrow away 'None' for now
65446547
if (
65456548
not is_unreachable_map(if_map)
@@ -6688,7 +6691,7 @@ def narrow_type_by_identity_equality(
66886691

66896692
if_map, else_map = conditional_types_to_typemaps(
66906693
operands[i],
6691-
*conditional_types(expr_type, [target], consider_promotion_overlap=True),
6694+
*conditional_types(expr_type, [target], from_equality=True),
66926695
)
66936696
if is_target_for_value_narrowing(get_proper_type(target_type)):
66946697
all_if_maps.append(if_map)
@@ -6727,7 +6730,7 @@ def narrow_type_by_identity_equality(
67276730
if_map, else_map = conditional_types_to_typemaps(
67286731
operands[i],
67296732
*conditional_types(
6730-
expr_type, [target], consider_promotion_overlap=True
6733+
expr_type, [target], from_equality=True
67316734
),
67326735
)
67336736
all_else_maps.append(else_map)
@@ -6767,7 +6770,7 @@ def narrow_type_by_identity_equality(
67676770
if_map, else_map = conditional_types_to_typemaps(
67686771
operands[i],
67696772
*conditional_types(
6770-
expr_type, [target], default=expr_type, consider_promotion_overlap=True
6773+
expr_type, [target], default=expr_type, from_equality=True
67716774
),
67726775
)
67736776
or_if_maps.append(if_map)
@@ -8271,7 +8274,7 @@ def conditional_types(
82718274
default: None = None,
82728275
*,
82738276
consider_runtime_isinstance: bool = True,
8274-
consider_promotion_overlap: bool = False,
8277+
from_equality: bool = False,
82758278
) -> tuple[Type | None, Type | None]: ...
82768279

82778280

@@ -8282,7 +8285,7 @@ def conditional_types(
82828285
default: Type,
82838286
*,
82848287
consider_runtime_isinstance: bool = True,
8285-
consider_promotion_overlap: bool = False,
8288+
from_equality: bool = False,
82868289
) -> tuple[Type, Type]: ...
82878290

82888291

@@ -8292,7 +8295,7 @@ def conditional_types(
82928295
default: Type | None = None,
82938296
*,
82948297
consider_runtime_isinstance: bool = True,
8295-
consider_promotion_overlap: bool = False,
8298+
from_equality: bool = False,
82968299
) -> tuple[Type | None, Type | None]:
82978300
"""Takes in the current type and a proposed type of an expression.
82988301
@@ -8337,7 +8340,7 @@ def conditional_types(
83378340
proposed_type_ranges,
83388341
default=union_item,
83398342
consider_runtime_isinstance=consider_runtime_isinstance,
8340-
consider_promotion_overlap=consider_promotion_overlap,
8343+
from_equality=from_equality,
83418344
)
83428345
yes_items.append(yes_type)
83438346
no_items.append(no_type)
@@ -8382,17 +8385,29 @@ def conditional_types(
83828385
consider_runtime_isinstance=consider_runtime_isinstance,
83838386
)
83848387
return default, remainder
8385-
if not is_overlapping_types(
8386-
current_type, proposed_type, ignore_promotions=not consider_promotion_overlap
8387-
):
8388-
# Expression is never of any type in proposed_type_ranges
8389-
return UninhabitedType(), default
8390-
if consider_promotion_overlap and not is_overlapping_types(
8391-
current_type, proposed_type, ignore_promotions=True
8392-
):
8393-
# We set consider_promotion_overlap when comparing equality. This is one of the places
8394-
# at runtime where subtyping with promotion does happen to match runtime semantics
8395-
return default, default
8388+
8389+
if from_equality:
8390+
# We erase generic args because values with different generic types can compare equal
8391+
# For instance, cast(list[str], []) and cast(list[int], [])
8392+
proposed_type = shallow_erase_type_for_equality(proposed_type)
8393+
if not is_overlapping_types(
8394+
current_type, proposed_type, ignore_promotions=False
8395+
):
8396+
# Equality narrowing is one of the places at runtime where subtyping with promotion
8397+
# does happen to match runtime semantics
8398+
# Expression is never of any type in proposed_type_ranges
8399+
return UninhabitedType(), default
8400+
if not is_overlapping_types(
8401+
current_type, proposed_type, ignore_promotions=True
8402+
):
8403+
return default, default
8404+
else:
8405+
if not is_overlapping_types(
8406+
current_type, proposed_type, ignore_promotions=True
8407+
):
8408+
# Expression is never of any type in proposed_type_ranges
8409+
return UninhabitedType(), default
8410+
83968411
# we can only restrict when the type is precise, not bounded
83978412
proposed_precise_type = UnionType.make_union(
83988413
[type_range.item for type_range in proposed_type_ranges if not type_range.is_upper_bound]
@@ -8641,9 +8656,6 @@ def reduce_and_conditional_type_maps(ms: list[TypeMap], *, use_meet: bool) -> Ty
86418656
BUILTINS_CUSTOM_EQ_CHECKS: Final = {
86428657
"builtins.bytearray",
86438658
"builtins.memoryview",
8644-
"builtins.list",
8645-
"builtins.dict",
8646-
"builtins.set",
86478659
}
86488660

86498661

mypy/erasetype.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,18 @@ def visit_union_type(self, t: UnionType) -> Type:
285285
merged.append(orig_item)
286286
return UnionType.make_union(merged)
287287
return new
288+
289+
290+
291+
def shallow_erase_type_for_equality(typ: Type) -> ProperType:
292+
"""Erase type variables from Instance's"""
293+
p_typ = get_proper_type(typ)
294+
if isinstance(p_typ, Instance):
295+
if not p_typ.args:
296+
return p_typ
297+
args = erased_vars(p_typ.type.defn.type_vars, TypeOfAny.special_form)
298+
return Instance(p_typ.type, args, p_typ.line)
299+
if isinstance(p_typ, UnionType):
300+
items = [shallow_erase_type_for_equality(item) for item in p_typ.items]
301+
return UnionType.make_union(items)
302+
return p_typ

test-data/unit/check-narrowing.test

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,92 @@ def f(x: Custom, y: CustomSub):
10651065
reveal_type(y) # N: Revealed type is "__main__.CustomSub"
10661066
[builtins fixtures/tuple.pyi]
10671067

1068+
[case testNarrowingCustomEqualityGeneric]
1069+
# flags: --strict-equality --warn-unreachable
1070+
from __future__ import annotations
1071+
from typing import Union
1072+
1073+
class Custom:
1074+
def __eq__(self, other: object) -> bool:
1075+
raise
1076+
1077+
class Default: ...
1078+
1079+
def f1(x: list[Custom] | Default, y: list[int]):
1080+
if x == y: # E: Non-overlapping equality check (left operand type: "list[Custom] | Default", right operand type: "list[int]")
1081+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom]"
1082+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]"
1083+
else:
1084+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom] | __main__.Default"
1085+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]"
1086+
1087+
f1([], [])
1088+
1089+
def f2(x: list[Custom] | Default, y: list[int] | list[Default]):
1090+
if x == y: # E: Non-overlapping equality check (left operand type: "list[Custom] | Default", right operand type: "list[int] | list[Default]")
1091+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom]"
1092+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"
1093+
else:
1094+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom] | __main__.Default"
1095+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"
1096+
1097+
listcustom_or_default = Union[list[Custom], Default]
1098+
listint_or_default = Union[list[int], list[Default]]
1099+
1100+
def f2_with_alias(x: listcustom_or_default, y: listint_or_default):
1101+
if x == y: # E: Non-overlapping equality check (left operand type: "list[Custom] | Default", right operand type: "list[int] | list[Default]")
1102+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom]"
1103+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"
1104+
else:
1105+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom] | __main__.Default"
1106+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"
1107+
1108+
def f3(x: Custom | dict[str, str], y: dict[int, int]):
1109+
if x == y:
1110+
reveal_type(x) # N: Revealed type is "__main__.Custom | builtins.dict[builtins.str, builtins.str]"
1111+
reveal_type(y) # N: Revealed type is "builtins.dict[builtins.int, builtins.int]"
1112+
else:
1113+
reveal_type(x) # N: Revealed type is "__main__.Custom | builtins.dict[builtins.str, builtins.str]"
1114+
reveal_type(y) # N: Revealed type is "builtins.dict[builtins.int, builtins.int]"
1115+
[builtins fixtures/primitives.pyi]
1116+
1117+
[case testNarrowingRecursiveCallable]
1118+
# flags: --strict-equality --warn-unreachable
1119+
from __future__ import annotations
1120+
from typing import Callable
1121+
1122+
class A: ...
1123+
class B: ...
1124+
1125+
T = Callable[[A], "S"]
1126+
S = Callable[[B], "T"]
1127+
1128+
def f(x: S, y: T):
1129+
if x == y: # E: Unsupported left operand type for == ("Callable[[B], T]")
1130+
reveal_type(x) # N: Revealed type is "def (__main__.B) -> def (__main__.A) -> ..."
1131+
reveal_type(y) # N: Revealed type is "def (__main__.A) -> def (__main__.B) -> ..."
1132+
else:
1133+
reveal_type(x) # N: Revealed type is "def (__main__.B) -> def (__main__.A) -> ..."
1134+
reveal_type(y) # N: Revealed type is "def (__main__.A) -> def (__main__.B) -> ..."
1135+
[builtins fixtures/tuple.pyi]
1136+
1137+
[case testNarrowingRecursiveUnion]
1138+
# flags: --strict-equality --warn-unreachable
1139+
from __future__ import annotations
1140+
from typing import Union
1141+
1142+
class A: ...
1143+
class B: ...
1144+
1145+
T = Union[A, "S"]
1146+
S = Union[B, "T"] # E: Invalid recursive alias: a union item of itself
1147+
1148+
def f(x: S, y: T):
1149+
if x == y:
1150+
reveal_type(x) # N: Revealed type is "Any"
1151+
reveal_type(y) # N: Revealed type is "__main__.A | Any"
1152+
[builtins fixtures/tuple.pyi]
1153+
10681154
[case testNarrowingUnreachableCases]
10691155
# flags: --strict-equality --warn-unreachable
10701156
from typing import Literal, Union

0 commit comments

Comments
 (0)