@@ -6247,43 +6247,41 @@ def find_type_equals_check(
62476247 expr_indices: The list of indices of expressions in ``node`` that are being
62486248 compared
62496249 """
6250-
6251- def is_type_call (expr : CallExpr ) -> bool :
6252- """Is expr a call to type with one argument?"""
6253- return refers_to_fullname (expr .callee , "builtins.type" ) and len (expr .args ) == 1
6254-
62556250 # exprs that are being passed into type
62566251 exprs_in_type_calls : list [Expression ] = []
6257- # type that is being compared to type(expr)
6258- type_being_compared : list [TypeRange ] | None = None
6259- # whether the type being compared to is final
6260- is_final = False
62616252
62626253 for index in expr_indices :
62636254 expr = node .operands [index ]
6264-
62656255 if isinstance (expr , CallExpr ) and is_type_call (expr ):
62666256 exprs_in_type_calls .append (expr .args [0 ])
6267- else :
6268- current_type = self .get_isinstance_type (expr )
6269- if current_type is None :
6270- continue
6271- if type_being_compared is not None :
6272- # It doesn't really make sense to have several types being
6273- # compared to the output of type (like type(x) == int == str)
6274- # because whether that's true is solely dependent on what the
6275- # types being compared are, so we don't try to narrow types any
6276- # further because we can't really get any information about the
6277- # type of x from that check
6278- return {}, {}
6279- else :
6280- if isinstance (expr , RefExpr ) and isinstance (expr .node , TypeInfo ):
6281- is_final = expr .node .is_final
6282- type_being_compared = current_type
62836257
62846258 if not exprs_in_type_calls :
62856259 return {}, {}
62866260
6261+ # type that is being compared to type(expr)
6262+ type_being_compared : list [TypeRange ] | None = None
6263+ # whether the type being compared to is final
6264+ is_final = False
6265+
6266+ for index in expr_indices :
6267+ expr = node .operands [index ]
6268+ if isinstance (expr , CallExpr ) and is_type_call (expr ):
6269+ continue
6270+ current_type = self .get_isinstance_type (expr )
6271+ if current_type is None :
6272+ continue
6273+ if type_being_compared is not None :
6274+ # It doesn't really make sense to have several types being
6275+ # compared to the output of type (like type(x) == int == str)
6276+ # because whether that's true is solely dependent on what the
6277+ # types being compared are, so we don't try to narrow types any
6278+ # further because we can't really get any information about the
6279+ # type of x from that check
6280+ return {}, {}
6281+ if isinstance (expr , RefExpr ) and isinstance (expr .node , TypeInfo ):
6282+ is_final = expr .node .is_final
6283+ type_being_compared = current_type
6284+
62876285 if_maps : list [TypeMap ] = []
62886286 else_maps : list [TypeMap ] = []
62896287 for expr in exprs_in_type_calls :
@@ -6663,8 +6661,10 @@ def equality_type_narrowing_helper(
66636661 expr_indices ,
66646662 narrowable_operand_index_to_hash .keys (),
66656663 )
6666- if if_map == {} and else_map == {} and node is not None :
6667- if_map , else_map = self .find_type_equals_check (node , expr_indices )
6664+ if node is not None :
6665+ type_if_map , type_else_map = self .find_type_equals_check (node , expr_indices )
6666+ if_map = and_conditional_maps (if_map , type_if_map )
6667+ else_map = and_conditional_maps (else_map , type_else_map )
66686668 return if_map , else_map
66696669
66706670 def narrow_type_by_equality (
@@ -6696,28 +6696,19 @@ def narrow_type_by_equality(
66966696 is_valid_target : Callable [[Type ], bool ] = is_singleton_type
66976697 coerce_only_in_literal_context = False
66986698 should_narrow_by_identity = True
6699- else :
6700-
6701- def is_exactly_literal_type (t : Type ) -> bool :
6702- return isinstance (get_proper_type (t ), LiteralType )
6703-
6704- def has_no_custom_eq_checks (t : Type ) -> bool :
6705- return not custom_special_method (
6706- t , "__eq__" , check_all = False
6707- ) and not custom_special_method (t , "__ne__" , check_all = False )
6708-
6709- is_valid_target = is_exactly_literal_type
6699+ elif operator in {"==" , "!=" }:
6700+ is_valid_target = is_singleton_value
67106701 coerce_only_in_literal_context = True
67116702
67126703 expr_types = [operand_types [i ] for i in expr_indices ]
6713- should_narrow_by_identity = all (
6714- map (has_no_custom_eq_checks , expr_types )
6704+ should_narrow_by_identity = not any (
6705+ map (has_custom_eq_checks , expr_types )
67156706 ) and not is_ambiguous_mix_of_enums (expr_types )
6707+ else :
6708+ raise AssertionError
67166709
6717- if_map : TypeMap = {}
6718- else_map : TypeMap = {}
67196710 if should_narrow_by_identity :
6720- if_map , else_map = self .refine_identity_comparison_expression (
6711+ return self .refine_identity_comparison_expression (
67216712 operands ,
67226713 operand_types ,
67236714 expr_indices ,
@@ -6726,11 +6717,9 @@ def has_no_custom_eq_checks(t: Type) -> bool:
67266717 coerce_only_in_literal_context ,
67276718 )
67286719
6729- if if_map == {} and else_map == {}:
6730- if_map , else_map = self .refine_away_none_in_comparison (
6731- operands , operand_types , expr_indices , narrowable_indices
6732- )
6733- return if_map , else_map
6720+ return self .refine_away_none_in_comparison (
6721+ operands , operand_types , expr_indices , narrowable_indices
6722+ )
67346723
67356724 def propagate_up_typemap_info (self , new_types : TypeMap ) -> TypeMap :
67366725 """Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types.
@@ -6948,113 +6937,73 @@ def refine_identity_comparison_expression(
69486937 expressions in the chain to a Literal type. Performing this coercion is sometimes
69496938 too aggressive of a narrowing, depending on context.
69506939 """
6951- should_coerce = True
6952- if coerce_only_in_literal_context :
69536940
6954- def should_coerce_inner (typ : Type ) -> bool :
6955- typ = get_proper_type (typ )
6956- return is_literal_type_like (typ ) or (
6957- isinstance (typ , Instance ) and typ .type .is_enum
6958- )
6959-
6960- should_coerce = any (should_coerce_inner (operand_types [i ]) for i in chain_indices )
6941+ if coerce_only_in_literal_context :
6942+ should_coerce = False
6943+ for i in chain_indices :
6944+ typ = get_proper_type (operand_types [i ])
6945+ if is_literal_type_like (typ ) or (isinstance (typ , Instance ) and typ .type .is_enum ):
6946+ should_coerce = True
6947+ break
6948+ else :
6949+ should_coerce = True
69616950
6962- target : Type | None = None
6963- possible_target_indices = []
6951+ operator_specific_targets = []
6952+ type_targets = []
69646953 for i in chain_indices :
69656954 expr_type = operand_types [i ]
69666955 if should_coerce :
6967- # TODO: doing this prevents narrowing a single-member Enum to literal
6968- # of its member, because we expand it here and then refuse to add equal
6969- # types to typemaps. As a result, `x: Foo; x == Foo.A` does not narrow
6970- # `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
6971- # See testMatchEnumSingleChoice
69726956 expr_type = coerce_to_literal (expr_type )
6973- if not is_valid_target (get_proper_type (expr_type )):
6974- continue
6975- if target and not is_same_type (target , expr_type ):
6976- # We have multiple disjoint target types. So the 'if' branch
6977- # must be unreachable.
6978- return None , {}
6979- target = expr_type
6980- possible_target_indices .append (i )
6981-
6982- # There's nothing we can currently infer if none of the operands are valid targets,
6983- # so we end early and infer nothing.
6984- if target is None :
6985- return {}, {}
6986-
6987- # If possible, use an unassignable expression as the target.
6988- # We skip refining the type of the target below, so ideally we'd
6989- # want to pick an expression we were going to skip anyways.
6990- singleton_index = - 1
6991- for i in possible_target_indices :
6992- if i not in narrowable_operand_indices :
6993- singleton_index = i
6994-
6995- # But if none of the possible singletons are unassignable ones, we give up
6996- # and arbitrarily pick the last item, mostly because other parts of the
6997- # type narrowing logic bias towards picking the rightmost item and it'd be
6998- # nice to stay consistent.
6999- #
7000- # That said, it shouldn't matter which index we pick. For example, suppose we
7001- # have this if statement, where 'x' and 'y' both have singleton types:
7002- #
7003- # if x is y:
7004- # reveal_type(x)
7005- # reveal_type(y)
7006- # else:
7007- # reveal_type(x)
7008- # reveal_type(y)
7009- #
7010- # At this point, 'x' and 'y' *must* have the same singleton type: we would have
7011- # ended early in the first for-loop in this function if they weren't.
7012- #
7013- # So, we should always get the same result in the 'if' case no matter which
7014- # index we pick. And while we do end up getting different results in the 'else'
7015- # case depending on the index (e.g. if we pick 'y', then its type stays the same
7016- # while 'x' is narrowed to '<uninhabited>'), this distinction is also moot: mypy
7017- # currently will just mark the whole branch as unreachable if either operand is
7018- # narrowed to <uninhabited>.
7019- if singleton_index == - 1 :
7020- singleton_index = possible_target_indices [- 1 ]
7021-
7022- sum_type_name = None
7023- target = get_proper_type (target )
7024- if isinstance (target , LiteralType ) and (
7025- target .is_enum_literal () or isinstance (target .value , bool )
7026- ):
7027- sum_type_name = target .fallback .type .fullname
6957+ if is_valid_target (get_proper_type (expr_type )):
6958+ operator_specific_targets .append ((i , TypeRange (expr_type , is_upper_bound = False )))
6959+ else :
6960+ type_targets .append ((i , TypeRange (expr_type , is_upper_bound = False )))
70286961
7029- target_type = [TypeRange (target , is_upper_bound = False )]
6962+ # print = lambda *a: None
6963+ print ()
6964+ print ("operands" , operands )
6965+ print ("operand_types" , operand_types )
6966+ print ("operator_specific_targets" , operator_specific_targets )
6967+ print ("type_targets" , type_targets )
70306968
70316969 partial_type_maps = []
7032- for i in chain_indices :
7033- # If we try refining a type against itself, conditional_type_map
7034- # will end up assuming that the 'else' branch is unreachable. This is
7035- # typically not what we want: generally the user will intend for the
7036- # target type to be some fixed 'sentinel' value and will want to refine
7037- # the other exprs against this one instead.
7038- if i == singleton_index :
7039- continue
7040-
7041- # Naturally, we can't refine operands which are not permitted to be refined.
7042- if i not in narrowable_operand_indices :
7043- continue
7044-
7045- expr = operands [i ]
7046- expr_type = coerce_to_literal (operand_types [i ])
7047-
7048- if sum_type_name is not None :
7049- expr_type = try_expanding_sum_type_to_union (expr_type , sum_type_name )
70506970
7051- # We intentionally use 'conditional_types' directly here instead of
7052- # 'self.conditional_types_with_intersection': we only compute ad-hoc
7053- # intersections when working with pure instances.
7054- types = conditional_types (expr_type , target_type )
7055- partial_type_maps .append (conditional_types_to_typemaps (expr , * types ))
6971+ if operator_specific_targets :
6972+ for i in chain_indices :
6973+ if i not in narrowable_operand_indices :
6974+ continue
6975+ targets = [t for j , t in operator_specific_targets if j != i ]
6976+ if targets :
6977+ expr_type = coerce_to_literal (operand_types [i ])
6978+ expr_type = try_expanding_sum_type_to_union (expr_type , None )
6979+ if_map , else_map = conditional_types_to_typemaps (
6980+ operands [i ], * conditional_types (expr_type , targets )
6981+ )
6982+ print ("ooo if_map" , if_map )
6983+ print ("ooo else_map" , else_map )
6984+ partial_type_maps .append ((if_map , else_map ))
70566985
7057- return reduce_conditional_maps (partial_type_maps )
6986+ if type_targets :
6987+ for i in chain_indices :
6988+ if i not in narrowable_operand_indices :
6989+ continue
6990+ targets = [t for j , t in type_targets if j != i ]
6991+ if targets :
6992+ expr_type = operand_types [i ]
6993+ if_map , else_map = conditional_types_to_typemaps (
6994+ operands [i ], * conditional_types (expr_type , targets )
6995+ )
6996+ if if_map :
6997+ else_map = {}
6998+ print ("ttt targets" , targets )
6999+ print ("ttt if_map" , if_map )
7000+ print ("ttt else_map" , else_map )
7001+ partial_type_maps .append ((if_map , else_map ))
7002+
7003+ final_if_map , final_else_map = reduce_conditional_maps (partial_type_maps )
7004+ print ("final_if_map" , final_if_map )
7005+ print ("final_else_map" , final_else_map )
7006+ return final_if_map , final_else_map
70587007
70597008 def refine_away_none_in_comparison (
70607009 self ,
@@ -8648,6 +8597,40 @@ def reduce_conditional_maps(
86488597 return final_if_map , final_else_map
86498598
86508599
8600+ def is_singleton_value (t : Type ) -> bool :
8601+ t = get_proper_type (t )
8602+ # TODO: check the type object thing
8603+ ret = isinstance (t , LiteralType ) or t .is_singleton_type () or (isinstance (t , CallableType ) and t .is_type_obj ())
8604+ print ("!!!" , t , type (t ), ret )
8605+ return ret
8606+
8607+
8608+ BUILTINS_CUSTOM_EQ_CHECKS : Final = {
8609+ "builtins.bytes" ,
8610+ "builtins.bytearray" ,
8611+ "builtins.memoryview" ,
8612+ "builtins.tuple" ,
8613+ "builtins.list" ,
8614+ "builtins.dict" ,
8615+ "builtins.set" ,
8616+ }
8617+
8618+
8619+ def has_custom_eq_checks (t : Type ) -> bool :
8620+ return (
8621+ custom_special_method (t , "__eq__" , check_all = False )
8622+ or custom_special_method (t , "__ne__" , check_all = False )
8623+ # TODO: make the hack more principled. explain what and why we're doing this though
8624+ # custom_special_method has special casing for builtins
8625+ or (isinstance (t , Instance ) and t .type .fullname in BUILTINS_CUSTOM_EQ_CHECKS )
8626+ )
8627+
8628+
8629+ def is_type_call (expr : CallExpr ) -> bool :
8630+ """Is expr a call to type with one argument?"""
8631+ return refers_to_fullname (expr .callee , "builtins.type" ) and len (expr .args ) == 1
8632+
8633+
86518634def convert_to_typetype (type_map : TypeMap ) -> TypeMap :
86528635 converted_type_map : dict [Expression , Type ] = {}
86538636 if type_map is None :
0 commit comments