Skip to content

Commit 525c6d1

Browse files
authored
Fix bug when narrowing union containing custom eq against custom eq (#20754)
I rewrote the logic for custom equality in #20643 . This is a soundness bug in that rewrite. Fixes #20750
1 parent 62575ba commit 525c6d1

2 files changed

Lines changed: 49 additions & 1 deletion

File tree

mypy/checker.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6736,13 +6736,17 @@ def narrow_type_by_identity_equality(
67366736
or_else_maps: list[TypeMap] = []
67376737
for expr_type in union_expr_type.items:
67386738
if has_custom_eq_checks(expr_type):
6739-
# Always include union items with custom __eq__ in the type
6739+
# Always include the union items with custom __eq__ in the type
67406740
# we narrow to in the if_map
67416741
or_if_maps.append({operands[i]: expr_type})
67426742

67436743
expr_type = coerce_to_literal(try_expanding_sum_type_to_union(expr_type, None))
67446744
for j in expr_indices:
67456745
if j in custom_eq_indices:
6746+
if i == j:
6747+
continue
6748+
# If we compare to a target with custom __eq__, we cannot narrow at all
6749+
or_if_maps.append({})
67466750
continue
67476751
target_type = operand_types[j]
67486752
if should_coerce_literals:

test-data/unit/check-narrowing.test

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,6 +1005,50 @@ def f(x: Custom | None, y: int | None):
10051005
reveal_type(y) # N: Revealed type is "builtins.int | None"
10061006
[builtins fixtures/primitives.pyi]
10071007

1008+
[case testNarrowingCustomEqualityUnion4]
1009+
# flags: --strict-equality --warn-unreachable
1010+
from __future__ import annotations
1011+
from typing import Any
1012+
1013+
class Custom1:
1014+
def __eq__(self, other: object) -> bool:
1015+
raise
1016+
1017+
class Custom2:
1018+
def __eq__(self, other: object) -> bool:
1019+
raise
1020+
1021+
def f(x: Custom1 | int, y: Custom2 | int):
1022+
if x == y:
1023+
reveal_type(x) # N: Revealed type is "__main__.Custom1 | builtins.int"
1024+
reveal_type(y) # N: Revealed type is "__main__.Custom2 | builtins.int"
1025+
else:
1026+
reveal_type(x) # N: Revealed type is "__main__.Custom1 | builtins.int"
1027+
reveal_type(y) # N: Revealed type is "__main__.Custom2 | builtins.int"
1028+
[builtins fixtures/primitives.pyi]
1029+
1030+
[case testNarrowingCustomEqualitySubclass]
1031+
# flags: --strict-equality --warn-unreachable
1032+
from __future__ import annotations
1033+
from typing import Any
1034+
1035+
class Custom:
1036+
def __eq__(self, other: object) -> bool:
1037+
raise
1038+
1039+
class CustomSub(Custom):
1040+
def __eq__(self, other: object) -> bool:
1041+
raise
1042+
1043+
def f(x: Custom, y: CustomSub):
1044+
if x == y:
1045+
reveal_type(x) # N: Revealed type is "__main__.Custom"
1046+
reveal_type(y) # N: Revealed type is "__main__.CustomSub"
1047+
else:
1048+
reveal_type(x) # N: Revealed type is "__main__.Custom"
1049+
reveal_type(y) # N: Revealed type is "__main__.CustomSub"
1050+
[builtins fixtures/tuple.pyi]
1051+
10081052
[case testNarrowingUnreachableCases]
10091053
# flags: --strict-equality --warn-unreachable
10101054
from typing import Literal, Union

0 commit comments

Comments
 (0)