@@ -6494,7 +6494,7 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
64946494 and not self .is_literal_enum (expr )
64956495 # CallableType type objects are usually already maximally specific
64966496 and not (
6497- isinstance (p_expr := get_proper_type (expr_type ), CallableType )
6497+ isinstance (p_expr := get_proper_type (expr_type ), FunctionLike )
64986498 and p_expr .is_type_obj ()
64996499 )
65006500 # This is a little ad hoc, in the absence of intersection types
@@ -6679,7 +6679,8 @@ def narrow_type_by_identity_equality(
66796679 else :
66806680 raise AssertionError
66816681
6682- partial_type_maps = []
6682+ all_if_maps : list [TypeMap ] = []
6683+ all_else_maps : list [TypeMap ] = []
66836684
66846685 # For each narrowable index, we see what we can narrow based on each relevant target
66856686 for i in expr_indices :
@@ -6690,10 +6691,8 @@ def narrow_type_by_identity_equality(
66906691 continue
66916692
66926693 expr_type = operand_types [i ]
6693- expanded_expr_type = try_expanding_sum_type_to_union (
6694- coerce_to_literal (expr_type ), None
6695- )
66966694 expr_enum_keys = ambiguous_enum_equality_keys (expr_type )
6695+ expr_type = try_expanding_sum_type_to_union (coerce_to_literal (expr_type ), None )
66976696 for j in expr_indices :
66986697 if i == j :
66996698 continue
@@ -6703,11 +6702,6 @@ def narrow_type_by_identity_equality(
67036702 continue
67046703 target_type = operand_types [j ]
67056704 if should_coerce_literals :
6706- # TODO: doing this prevents narrowing a single-member Enum to literal
6707- # of its member, because we expand it here and then refuse to add equal
6708- # types to typemaps. As a result, `x: Foo; x == Foo.A` does not narrow
6709- # `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
6710- # See testMatchEnumSingleChoice
67116705 target_type = coerce_to_literal (target_type )
67126706
67136707 if (
@@ -6718,24 +6712,21 @@ def narrow_type_by_identity_equality(
67186712 continue
67196713
67206714 target = TypeRange (target_type , is_upper_bound = False )
6721- is_value_target = is_target_for_value_narrowing (get_proper_type (target_type ))
67226715
6723- if is_value_target :
6724- if_map , else_map = conditional_types_to_typemaps (
6725- operands [i ], * conditional_types (expanded_expr_type , [target ])
6726- )
6727- partial_type_maps .append ((if_map , else_map ))
6716+ if_map , else_map = conditional_types_to_typemaps (
6717+ operands [i ], * conditional_types (expr_type , [target ])
6718+ )
6719+ if is_target_for_value_narrowing (get_proper_type (target_type )):
6720+ all_if_maps .append (if_map )
6721+ all_else_maps .append (else_map )
67286722 else :
6729- if_map , else_map = conditional_types_to_typemaps (
6730- operands [i ], * conditional_types (expr_type , [target ])
6731- )
67326723 # For value targets, it is safe to narrow in the negative case.
67336724 # e.g. if (x: Literal[5] | None) != (y: Literal[5]), we can narrow x to None
67346725 # However, for non-value targets, we cannot do this narrowing,
67356726 # and so we ignore else_map
67366727 # e.g. if (x: str | None) != (y: str), we cannot narrow x to None
6737- if if_map :
6738- partial_type_maps .append (( if_map , {}) )
6728+ if if_map is not None : # TODO: this gate is incorrect and should be removed
6729+ all_if_maps .append (if_map )
67396730
67406731 # Handle narrowing for operands with custom __eq__ methods specially
67416732 # In most cases, we won't be able to do any narrowing
@@ -6757,14 +6748,12 @@ def narrow_type_by_identity_equality(
67576748 if should_coerce_literals :
67586749 target_type = coerce_to_literal (target_type )
67596750 target = TypeRange (target_type , is_upper_bound = False )
6760- is_value_target = is_target_for_value_narrowing (get_proper_type (target_type ))
6761-
6762- if is_value_target :
6751+ if is_target_for_value_narrowing (get_proper_type (target_type )):
67636752 if_map , else_map = conditional_types_to_typemaps (
67646753 operands [i ], * conditional_types (expr_type , [target ])
67656754 )
67666755 if else_map :
6767- partial_type_maps .append (({}, else_map ) )
6756+ all_else_maps .append (else_map )
67686757 continue
67696758
67706759 # If our operand with custom __eq__ is a union, where only some members of the union
@@ -6778,37 +6767,24 @@ def narrow_type_by_identity_equality(
67786767 # we narrow to in the if_map
67796768 or_if_maps .append ({operands [i ]: expr_type })
67806769
6770+ expr_type = coerce_to_literal (try_expanding_sum_type_to_union (expr_type , None ))
67816771 for j in expr_indices :
67826772 if j in custom_eq_indices :
67836773 continue
67846774 target_type = operand_types [j ]
67856775 if should_coerce_literals :
67866776 target_type = coerce_to_literal (target_type )
67876777 target = TypeRange (target_type , is_upper_bound = False )
6788- is_value_target = is_target_for_value_narrowing (get_proper_type (target_type ))
67896778
6790- if is_value_target :
6791- expr_type = coerce_to_literal (expr_type )
6792- expr_type = try_expanding_sum_type_to_union (expr_type , None )
67936779 if_map , else_map = conditional_types_to_typemaps (
67946780 operands [i ], * conditional_types (expr_type , [target ], default = expr_type )
67956781 )
67966782 or_if_maps .append (if_map )
6797- if is_value_target :
6783+ if is_target_for_value_narrowing ( get_proper_type ( target_type )) :
67986784 or_else_maps .append (else_map )
67996785
6800- final_if_map : TypeMap = {}
6801- final_else_map : TypeMap = {}
6802- if or_if_maps :
6803- final_if_map = or_if_maps [0 ]
6804- for if_map in or_if_maps [1 :]:
6805- final_if_map = or_conditional_maps (final_if_map , if_map )
6806- if or_else_maps :
6807- final_else_map = or_else_maps [0 ]
6808- for else_map in or_else_maps [1 :]:
6809- final_else_map = or_conditional_maps (final_else_map , else_map )
6810-
6811- partial_type_maps .append ((final_if_map , final_else_map ))
6786+ all_if_maps .append (reduce_or_conditional_type_maps (or_if_maps ))
6787+ all_else_maps .append (reduce_or_conditional_type_maps (or_else_maps ))
68126788
68136789 # Handle narrowing for comparisons that produce additional narrowing, like
68146790 # `type(x) == T` or `x.__class__ is T`
@@ -6851,13 +6827,16 @@ def narrow_type_by_identity_equality(
68516827 if isinstance (expr , RefExpr ) and isinstance (expr .node , TypeInfo )
68526828 else False
68536829 )
6854- if not is_final :
6855- else_map = {}
6856- partial_type_maps .append ((if_map , else_map ))
6830+ all_if_maps .append (if_map )
6831+ if is_final :
6832+ # We can only narrow `type(x) == T` in the negative case if T is final
6833+ all_else_maps .append (else_map )
68576834
68586835 # We will not have duplicate entries in our type maps if we only have two operands,
68596836 # so we can skip running meets on the intersections
6860- return reduce_conditional_maps (partial_type_maps , use_meet = len (operands ) > 2 )
6837+ if_map = reduce_and_conditional_type_maps (all_if_maps , use_meet = len (operands ) > 2 )
6838+ else_map = reduce_or_conditional_type_maps (all_else_maps )
6839+ return if_map , else_map
68616840
68626841 def propagate_up_typemap_info (self , new_types : TypeMap ) -> TypeMap :
68636842 """Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types.
@@ -8529,7 +8508,7 @@ def builtin_item_type(tp: Type) -> Type | None:
85298508 return None
85308509
85318510
8532- def and_conditional_maps (m1 : TypeMap , m2 : TypeMap , use_meet : bool = False ) -> TypeMap :
8511+ def and_conditional_maps (m1 : TypeMap , m2 : TypeMap , * , use_meet : bool = False ) -> TypeMap :
85338512 """Calculate what information we can learn from the truth of (e1 and e2)
85348513 in terms of the information that we can learn from the truth of e1 and
85358514 the truth of e2.
@@ -8562,7 +8541,7 @@ def and_conditional_maps(m1: TypeMap, m2: TypeMap, use_meet: bool = False) -> Ty
85628541 return result
85638542
85648543
8565- def or_conditional_maps (m1 : TypeMap , m2 : TypeMap , coalesce_any : bool = False ) -> TypeMap :
8544+ def or_conditional_maps (m1 : TypeMap , m2 : TypeMap , * , coalesce_any : bool = False ) -> TypeMap :
85668545 """Calculate what information we can learn from the truth of (e1 or e2)
85678546 in terms of the information that we can learn from the truth of e1 and
85688547 the truth of e2. If coalesce_any is True, consider Any a supertype when
@@ -8627,6 +8606,30 @@ def reduce_conditional_maps(
86278606 return final_if_map , final_else_map
86288607
86298608
8609+ def reduce_or_conditional_type_maps (ms : list [TypeMap ]) -> TypeMap :
8610+ """Reduces a list of TypeMaps into a single TypeMap by "or"-ing them together."""
8611+ if len (ms ) == 0 :
8612+ return {}
8613+ if len (ms ) == 1 :
8614+ return ms [0 ]
8615+ result = ms [0 ]
8616+ for m in ms [1 :]:
8617+ result = or_conditional_maps (result , m )
8618+ return result
8619+
8620+
8621+ def reduce_and_conditional_type_maps (ms : list [TypeMap ], * , use_meet : bool ) -> TypeMap :
8622+ """Reduces a list of TypeMaps into a single TypeMap by "and"-ing them together."""
8623+ if len (ms ) == 0 :
8624+ return {}
8625+ if len (ms ) == 1 :
8626+ return ms [0 ]
8627+ result = ms [0 ]
8628+ for m in ms [1 :]:
8629+ result = and_conditional_maps (result , m , use_meet = use_meet )
8630+ return result
8631+
8632+
86308633BUILTINS_CUSTOM_EQ_CHECKS : Final = {
86318634 "builtins.bytes" ,
86328635 "builtins.bytearray" ,
@@ -8681,12 +8684,13 @@ def flatten(t: Expression) -> list[Expression]:
86818684def flatten_types_if_tuple (t : Type ) -> list [Type ]:
86828685 """Flatten a nested sequence of tuples into one list of nodes."""
86838686 t = get_proper_type (t )
8687+ if isinstance (t , UnionType ):
8688+ return [b for a in t .items for b in flatten_types (a )]
86848689 if isinstance (t , TupleType ):
86858690 return [b for a in t .items for b in flatten_types_if_tuple (a )]
86868691 elif is_named_instance (t , "builtins.tuple" ):
86878692 return [t .args [0 ]]
8688- else :
8689- return [t ]
8693+ return [t ]
86908694
86918695
86928696def expand_func (defn : FuncItem , map : dict [TypeVarId , Type ]) -> FuncItem :
0 commit comments