|
5 | 5 |
|
6 | 6 | import mypy.errorcodes as codes |
7 | 7 | from mypy import message_registry |
8 | | -from mypy.nodes import DictExpr, IntExpr, StrExpr, UnaryExpr |
| 8 | +from mypy.nodes import DictExpr, Expression, IntExpr, StrExpr, UnaryExpr |
9 | 9 | from mypy.plugin import ( |
10 | 10 | AttributeContext, |
11 | 11 | ClassDefContext, |
@@ -263,30 +263,40 @@ def typed_dict_get_callback(ctx: MethodContext) -> Type: |
263 | 263 | if keys is None: |
264 | 264 | return ctx.default_return_type |
265 | 265 |
|
| 266 | + default_type: Type |
| 267 | + default_arg: Expression | None |
| 268 | + if len(ctx.arg_types) <= 1 or not ctx.arg_types[1]: |
| 269 | + default_arg = None |
| 270 | + default_type = NoneType() |
| 271 | + elif len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1: |
| 272 | + default_arg = ctx.args[1][0] |
| 273 | + default_type = ctx.arg_types[1][0] |
| 274 | + else: |
| 275 | + return ctx.default_return_type |
| 276 | + |
266 | 277 | output_types: list[Type] = [] |
267 | 278 | for key in keys: |
268 | | - value_type = get_proper_type(ctx.type.items.get(key)) |
| 279 | + value_type: Type | None = ctx.type.items.get(key) |
269 | 280 | if value_type is None: |
270 | 281 | return ctx.default_return_type |
271 | 282 |
|
272 | | - if len(ctx.arg_types) == 1: |
| 283 | + if key in ctx.type.required_keys: |
273 | 284 | output_types.append(value_type) |
274 | | - elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1: |
275 | | - default_arg = ctx.args[1][0] |
| 285 | + else: |
| 286 | + # HACK to deal with get(key, {}) |
276 | 287 | if ( |
277 | 288 | isinstance(default_arg, DictExpr) |
278 | 289 | and len(default_arg.items) == 0 |
279 | | - and isinstance(value_type, TypedDictType) |
| 290 | + and isinstance(vt := get_proper_type(value_type), TypedDictType) |
280 | 291 | ): |
281 | | - # Special case '{}' as the default for a typed dict type. |
282 | | - output_types.append(value_type.copy_modified(required_keys=set())) |
| 292 | + output_types.append(vt.copy_modified(required_keys=set())) |
283 | 293 | else: |
284 | 294 | output_types.append(value_type) |
285 | | - output_types.append(ctx.arg_types[1][0]) |
286 | | - |
287 | | - if len(ctx.arg_types) == 1: |
288 | | - output_types.append(NoneType()) |
| 295 | + output_types.append(default_type) |
289 | 296 |
|
| 297 | + # for nicer reveal_type, put default at the end, if it is present |
| 298 | + if default_type in output_types: |
| 299 | + output_types = [t for t in output_types if t != default_type] + [default_type] |
290 | 300 | return make_simplified_union(output_types) |
291 | 301 | return ctx.default_return_type |
292 | 302 |
|
|
0 commit comments