Skip to content

Commit d29a6df

Browse files
committed
Improve narrowing logic for Enum int and str subclasses
1 parent 6424d0b commit d29a6df

2 files changed

Lines changed: 67 additions & 42 deletions

File tree

mypy/checker.py

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6718,6 +6718,7 @@ def narrow_type_by_equality(
67186718
is_target_for_value_narrowing = is_singleton_identity_type
67196719
should_coerce_literals = True
67206720
should_narrow_by_identity_equality = True
6721+
enum_comparison_is_ambiguous = False
67216722

67226723
elif operator in {"==", "!="}:
67236724
is_target_for_value_narrowing = is_singleton_equality_type
@@ -6730,9 +6731,8 @@ def narrow_type_by_equality(
67306731
break
67316732

67326733
expr_types = [operand_types[i] for i in expr_indices]
6733-
should_narrow_by_identity_equality = not any(
6734-
map(has_custom_eq_checks, expr_types)
6735-
) and not is_ambiguous_mix_of_enums(expr_types)
6734+
should_narrow_by_identity_equality = not any(map(has_custom_eq_checks, expr_types))
6735+
enum_comparison_is_ambiguous = True
67366736
else:
67376737
raise AssertionError
67386738

@@ -6765,11 +6765,18 @@ def narrow_type_by_equality(
67656765
for i in expr_indices:
67666766
if i not in narrowable_indices:
67676767
continue
6768+
expr_type = coerce_to_literal(operand_types[i])
6769+
expr_type = try_expanding_sum_type_to_union(expr_type, None)
6770+
expr_enum_keys = ambiguous_enum_equality_keys(expr_type)
67686771
for j, target in value_targets:
67696772
if i == j:
67706773
continue
6771-
expr_type = coerce_to_literal(operand_types[i])
6772-
expr_type = try_expanding_sum_type_to_union(expr_type, None)
6774+
if (
6775+
# See comments in ambiguous_enum_equality_keys
6776+
enum_comparison_is_ambiguous
6777+
and len(expr_enum_keys | ambiguous_enum_equality_keys(target.item)) > 1
6778+
):
6779+
continue
67736780
if_map, else_map = conditional_types_to_typemaps(
67746781
operands[i], *conditional_types(expr_type, [target])
67756782
)
@@ -6779,10 +6786,10 @@ def narrow_type_by_equality(
67796786
for i in expr_indices:
67806787
if i not in narrowable_indices:
67816788
continue
6789+
expr_type = operand_types[i]
67826790
for j, target in type_targets:
67836791
if i == j:
67846792
continue
6785-
expr_type = operand_types[i]
67866793
if_map, else_map = conditional_types_to_typemaps(
67876794
operands[i], *conditional_types(expr_type, [target])
67886795
)
@@ -9371,47 +9378,44 @@ def visit_starred_pattern(self, p: StarredPattern) -> None:
93719378
self.lvalue = False
93729379

93739380

9374-
def is_ambiguous_mix_of_enums(types: list[Type]) -> bool:
9375-
"""Do types have IntEnum/StrEnum types that are potentially overlapping with other types?
9381+
def ambiguous_enum_equality_keys(t: Type) -> set[str]:
9382+
"""
9383+
Used when narrowing types based on equality.
93769384
9377-
If True, we shouldn't attempt type narrowing based on enum values, as it gets
9378-
too ambiguous.
9385+
Certain kinds of enums can compare equal to values of other types, so doing type math
9386+
the way `conditional_types` does will be misleading if you expect it to correspond to
9387+
conditions based on equality comparisons.
93799388
9380-
For example, return True if there's an 'int' type together with an IntEnum literal.
9381-
However, IntEnum together with a literal of the same IntEnum type is not ambiguous.
9389+
For example, StrEnum classes can compare equal to str values. So if we see
9390+
`val: StrEnum; if val == "foo": ...` we currently avoid narrowing.
9391+
Note that we do wish to continue narrowing for `if val == StrEnum.MEMBER: ...`
93829392
"""
93839393
# We need these things for this to be ambiguous:
9384-
# (1) an IntEnum or StrEnum type
9394+
# (1) an IntEnum or StrEnum type or enum subclass of int or str
93859395
# (2) either a different IntEnum/StrEnum type or a non-enum type ("<other>")
9386-
#
9387-
# It would be slightly more correct to calculate this separately for IntEnum and
9388-
# StrEnum related types, as an IntEnum can't be confused with a StrEnum.
9389-
return len(_ambiguous_enum_variants(types)) > 1
9390-
9391-
9392-
def _ambiguous_enum_variants(types: list[Type]) -> set[str]:
93939396
result = set()
9394-
for t in types:
9395-
t = get_proper_type(t)
9396-
if isinstance(t, UnionType):
9397-
result.update(_ambiguous_enum_variants(t.items))
9398-
elif isinstance(t, Instance):
9399-
if t.last_known_value:
9400-
result.update(_ambiguous_enum_variants([t.last_known_value]))
9401-
elif t.type.is_enum and any(
9402-
base.fullname in ("enum.IntEnum", "enum.StrEnum") for base in t.type.mro
9403-
):
9404-
result.add(t.type.fullname)
9405-
elif not t.type.is_enum:
9406-
# These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so
9407-
# let's be conservative
9408-
result.add("<other>")
9409-
elif isinstance(t, LiteralType):
9410-
result.update(_ambiguous_enum_variants([t.fallback]))
9411-
elif isinstance(t, NoneType):
9412-
pass
9413-
else:
9397+
t = get_proper_type(t)
9398+
if isinstance(t, UnionType):
9399+
for item in t.items:
9400+
result.update(ambiguous_enum_equality_keys(item))
9401+
elif isinstance(t, Instance):
9402+
if t.last_known_value:
9403+
result.update(ambiguous_enum_equality_keys(t.last_known_value))
9404+
elif t.type.is_enum and any(
9405+
base.fullname in ("enum.IntEnum", "enum.StrEnum", "builtins.str", "builtins.int")
9406+
for base in t.type.mro
9407+
):
9408+
result.add(t.type.fullname)
9409+
elif not t.type.is_enum:
9410+
# These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so
9411+
# let's be conservative
94149412
result.add("<other>")
9413+
elif isinstance(t, LiteralType):
9414+
result.update(ambiguous_enum_equality_keys(t.fallback))
9415+
elif isinstance(t, NoneType):
9416+
pass
9417+
else:
9418+
result.add("<other>")
94159419
return result
94169420

94179421

test-data/unit/check-narrowing.test

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2124,7 +2124,7 @@ else:
21242124
[builtins fixtures/ops.pyi]
21252125

21262126
[case testNarrowingWithIntEnum]
2127-
# mypy: strict-equality
2127+
# flags: --strict-equality --warn-unreachable
21282128
from __future__ import annotations
21292129
from typing import Any
21302130
from enum import IntEnum
@@ -2179,7 +2179,7 @@ def f6(x: IE) -> None:
21792179
[builtins fixtures/primitives.pyi]
21802180

21812181
[case testNarrowingWithIntEnum2]
2182-
# mypy: strict-equality
2182+
# flags: --strict-equality --warn-unreachable
21832183
from __future__ import annotations
21842184
from typing import Any
21852185
from enum import IntEnum, Enum
@@ -2284,6 +2284,27 @@ def f4(x: SE) -> None:
22842284
reveal_type(x) # N: Revealed type is "Literal[__main__.SE.B]"
22852285
[builtins fixtures/primitives.pyi]
22862286

2287+
[case testNarrowingWithEnumStrSubclass]
2288+
# flags: --strict-equality --warn-unreachable
2289+
from enum import Enum
2290+
2291+
class ParameterLocation(str, Enum):
2292+
QUERY = "query"
2293+
HEADER = "header"
2294+
PATH = "path"
2295+
2296+
def foo(location: ParameterLocation):
2297+
if location == "path":
2298+
reveal_type(location) # N: Revealed type is "__main__.ParameterLocation"
2299+
else:
2300+
reveal_type(location) # N: Revealed type is "__main__.ParameterLocation"
2301+
2302+
if location == ParameterLocation.PATH:
2303+
reveal_type(location) # N: Revealed type is "Literal[__main__.ParameterLocation.PATH]"
2304+
else:
2305+
reveal_type(location) # N: Revealed type is "Literal[__main__.ParameterLocation.QUERY] | Literal[__main__.ParameterLocation.HEADER]"
2306+
[builtins fixtures/primitives.pyi]
2307+
22872308
[case testConsistentNarrowingEqAndIn]
22882309
# flags: --python-version 3.10
22892310

0 commit comments

Comments
 (0)