@@ -958,38 +958,44 @@ def get_dict_base_type(self, expr: Expression) -> list[Instance]:
958958
959959 This is useful for dict subclasses like SymbolTable.
960960 """
961- target_type = get_proper_type (self .types [expr ])
961+ return self .get_dict_base_type_from_type (self .types [expr ])
962+
963+ def get_dict_base_type_from_type (self , target_type : Type ) -> list [Instance ]:
964+ target_type = get_proper_type (target_type )
962965 if isinstance (target_type , UnionType ):
963- types = [get_proper_type (item ) for item in target_type .items ]
966+ return [
967+ inner
968+ for item in target_type .items
969+ for inner in self .get_dict_base_type_from_type (item )
970+ ]
971+ if isinstance (target_type , TypeVarLikeType ):
972+ # Match behaviour of self.node_type
973+ # We can only reach this point if `target_type` was a TypeVar(bound=dict[...])
974+ # or a ParamSpec.
975+ return self .get_dict_base_type_from_type (target_type .upper_bound )
976+
977+ if isinstance (target_type , TypedDictType ):
978+ target_type = target_type .fallback
979+ dict_base = next (
980+ base for base in target_type .type .mro if base .fullname == "typing.Mapping"
981+ )
982+ elif isinstance (target_type , Instance ):
983+ dict_base = next (
984+ base for base in target_type .type .mro if base .fullname == "builtins.dict"
985+ )
964986 else :
965- types = [target_type ]
966-
967- dict_types = []
968- for t in types :
969- if isinstance (t , TypedDictType ):
970- t = t .fallback
971- dict_base = next (base for base in t .type .mro if base .fullname == "typing.Mapping" )
972- else :
973- assert isinstance (t , Instance ), t
974- dict_base = next (base for base in t .type .mro if base .fullname == "builtins.dict" )
975- dict_types .append (map_instance_to_supertype (t , dict_base ))
976- return dict_types
987+ assert False , f"Failed to extract dict base from { target_type } "
988+ return [map_instance_to_supertype (target_type , dict_base )]
977989
978990 def get_dict_key_type (self , expr : Expression ) -> RType :
979991 dict_base_types = self .get_dict_base_type (expr )
980- if len (dict_base_types ) == 1 :
981- return self .type_to_rtype (dict_base_types [0 ].args [0 ])
982- else :
983- rtypes = [self .type_to_rtype (t .args [0 ]) for t in dict_base_types ]
984- return RUnion .make_simplified_union (rtypes )
992+ rtypes = [self .type_to_rtype (t .args [0 ]) for t in dict_base_types ]
993+ return RUnion .make_simplified_union (rtypes )
985994
986995 def get_dict_value_type (self , expr : Expression ) -> RType :
987996 dict_base_types = self .get_dict_base_type (expr )
988- if len (dict_base_types ) == 1 :
989- return self .type_to_rtype (dict_base_types [0 ].args [1 ])
990- else :
991- rtypes = [self .type_to_rtype (t .args [1 ]) for t in dict_base_types ]
992- return RUnion .make_simplified_union (rtypes )
997+ rtypes = [self .type_to_rtype (t .args [1 ]) for t in dict_base_types ]
998+ return RUnion .make_simplified_union (rtypes )
993999
9941000 def get_dict_item_type (self , expr : Expression ) -> RType :
9951001 key_type = self .get_dict_key_type (expr )
0 commit comments