@@ -6718,6 +6718,7 @@ def narrow_type_by_equality(
67186718 is_target_for_value_narrowing = is_singleton_identity_type
67196719 should_coerce_literals = True
67206720 should_narrow_by_identity_equality = True
6721+ enum_comparison_is_ambiguous = False
67216722
67226723 elif operator in {"==" , "!=" }:
67236724 is_target_for_value_narrowing = is_singleton_equality_type
@@ -6730,9 +6731,8 @@ def narrow_type_by_equality(
67306731 break
67316732
67326733 expr_types = [operand_types [i ] for i in expr_indices ]
6733- should_narrow_by_identity_equality = not any (
6734- map (has_custom_eq_checks , expr_types )
6735- ) and not is_ambiguous_mix_of_enums (expr_types )
6734+ should_narrow_by_identity_equality = not any (map (has_custom_eq_checks , expr_types ))
6735+ enum_comparison_is_ambiguous = True
67366736 else :
67376737 raise AssertionError
67386738
@@ -6765,11 +6765,18 @@ def narrow_type_by_equality(
67656765 for i in expr_indices :
67666766 if i not in narrowable_indices :
67676767 continue
6768+ expr_type = coerce_to_literal (operand_types [i ])
6769+ expr_type = try_expanding_sum_type_to_union (expr_type , None )
6770+ expr_enum_keys = ambiguous_enum_equality_keys (expr_type )
67686771 for j , target in value_targets :
67696772 if i == j :
67706773 continue
6771- expr_type = coerce_to_literal (operand_types [i ])
6772- expr_type = try_expanding_sum_type_to_union (expr_type , None )
6774+ if (
6775+ # See comments in ambiguous_enum_equality_keys
6776+ enum_comparison_is_ambiguous
6777+ and len (expr_enum_keys | ambiguous_enum_equality_keys (target .item )) > 1
6778+ ):
6779+ continue
67736780 if_map , else_map = conditional_types_to_typemaps (
67746781 operands [i ], * conditional_types (expr_type , [target ])
67756782 )
@@ -6779,10 +6786,10 @@ def narrow_type_by_equality(
67796786 for i in expr_indices :
67806787 if i not in narrowable_indices :
67816788 continue
6789+ expr_type = operand_types [i ]
67826790 for j , target in type_targets :
67836791 if i == j :
67846792 continue
6785- expr_type = operand_types [i ]
67866793 if_map , else_map = conditional_types_to_typemaps (
67876794 operands [i ], * conditional_types (expr_type , [target ])
67886795 )
@@ -9371,47 +9378,44 @@ def visit_starred_pattern(self, p: StarredPattern) -> None:
93719378 self .lvalue = False
93729379
93739380
9374- def is_ambiguous_mix_of_enums (types : list [Type ]) -> bool :
9375- """Do types have IntEnum/StrEnum types that are potentially overlapping with other types?
9381+ def ambiguous_enum_equality_keys (t : Type ) -> set [str ]:
9382+ """
9383+ Used when narrowing types based on equality.
93769384
9377- If True, we shouldn't attempt type narrowing based on enum values, as it gets
9378- too ambiguous.
9385+ Certain kinds of enums can compare equal to values of other types, so doing type math
9386+ the way `conditional_types` does will be misleading if you expect it to correspond to
9387+ conditions based on equality comparisons.
93799388
9380- For example, return True if there's an 'int' type together with an IntEnum literal.
9381- However, IntEnum together with a literal of the same IntEnum type is not ambiguous.
9389+ For example, StrEnum classes can compare equal to str values. So if we see
9390+ `val: StrEnum; if val == "foo": ...` we currently avoid narrowing.
9391+ Note that we do wish to continue narrowing for `if val == StrEnum.MEMBER: ...`
93829392 """
93839393 # We need these things for this to be ambiguous:
9384- # (1) an IntEnum or StrEnum type
9394+ # (1) an IntEnum or StrEnum type or enum subclass of int or str
93859395 # (2) either a different IntEnum/StrEnum type or a non-enum type ("<other>")
9386- #
9387- # It would be slightly more correct to calculate this separately for IntEnum and
9388- # StrEnum related types, as an IntEnum can't be confused with a StrEnum.
9389- return len (_ambiguous_enum_variants (types )) > 1
9390-
9391-
9392- def _ambiguous_enum_variants (types : list [Type ]) -> set [str ]:
93939396 result = set ()
9394- for t in types :
9395- t = get_proper_type (t )
9396- if isinstance (t , UnionType ):
9397- result .update (_ambiguous_enum_variants (t .items ))
9398- elif isinstance (t , Instance ):
9399- if t .last_known_value :
9400- result .update (_ambiguous_enum_variants ([t .last_known_value ]))
9401- elif t .type .is_enum and any (
9402- base .fullname in ("enum.IntEnum" , "enum.StrEnum" ) for base in t .type .mro
9403- ):
9404- result .add (t .type .fullname )
9405- elif not t .type .is_enum :
9406- # These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so
9407- # let's be conservative
9408- result .add ("<other>" )
9409- elif isinstance (t , LiteralType ):
9410- result .update (_ambiguous_enum_variants ([t .fallback ]))
9411- elif isinstance (t , NoneType ):
9412- pass
9413- else :
9397+ t = get_proper_type (t )
9398+ if isinstance (t , UnionType ):
9399+ for item in t .items :
9400+ result .update (ambiguous_enum_equality_keys (item ))
9401+ elif isinstance (t , Instance ):
9402+ if t .last_known_value :
9403+ result .update (ambiguous_enum_equality_keys (t .last_known_value ))
9404+ elif t .type .is_enum and any (
9405+ base .fullname in ("enum.IntEnum" , "enum.StrEnum" , "builtins.str" , "builtins.int" )
9406+ for base in t .type .mro
9407+ ):
9408+ result .add (t .type .fullname )
9409+ elif not t .type .is_enum :
9410+ # These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so
9411+ # let's be conservative
94149412 result .add ("<other>" )
9413+ elif isinstance (t , LiteralType ):
9414+ result .update (ambiguous_enum_equality_keys (t .fallback ))
9415+ elif isinstance (t , NoneType ):
9416+ pass
9417+ else :
9418+ result .add ("<other>" )
94159419 return result
94169420
94179421
0 commit comments