Skip to content

Commit 2534663

Browse files
Refactor: Use DictExpr attribute to validate string keys, including bytes and nested unpacking
1 parent 172578d commit 2534663

4 files changed

Lines changed: 17 additions & 3 deletions

File tree

mypy/checkexpr.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5436,6 +5436,15 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
54365436
expected_types.append(
54375437
self.chk.named_generic_type("_typeshed.SupportsKeysAndGetItem", [kt, vt])
54385438
)
5439+
# If this DictExpr came from a dict() call translation, validate that
5440+
# any unpacked dict has string keys (keywords must be strings)
5441+
if e.from_dict_call:
5442+
value_type = self.accept(value)
5443+
if not self.is_valid_keyword_var_arg(value_type):
5444+
is_mapping = is_subtype(
5445+
value_type, self.chk.named_type("_typeshed.SupportsKeysAndGetItem")
5446+
)
5447+
self.msg.invalid_keyword_var_arg(value_type, is_mapping, value)
54395448
else:
54405449
tup = TupleExpr([key, value])
54415450
if key.line >= 0:

mypy/nodes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2672,15 +2672,17 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
26722672
class DictExpr(Expression):
26732673
"""Dictionary literal expression {key: value, ...}."""
26742674

2675-
__slots__ = ("items",)
2675+
__slots__ = ("items", "from_dict_call")
26762676

26772677
__match_args__ = ("items",)
26782678

26792679
items: list[tuple[Expression | None, Expression]]
2680+
from_dict_call: bool # True if this came from a dict(...) call translation
26802681

26812682
def __init__(self, items: list[tuple[Expression | None, Expression]]) -> None:
26822683
super().__init__()
26832684
self.items = items
2685+
self.from_dict_call = False
26842686

26852687
def accept(self, visitor: ExpressionVisitor[T]) -> T:
26862688
return visitor.visit_dict_expr(self)

mypy/semanal.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6056,9 +6056,9 @@ def translate_dict_call(self, call: CallExpr) -> DictExpr | None:
60566056
# will catch the "keywords must be strings" error.
60576057
for kind, arg in zip(call.arg_kinds, call.args):
60586058
if kind == ARG_STAR2 and isinstance(arg, DictExpr):
6059-
# Check if all keys in the dict literal are strings
6059+
# Check if all keys in the dict literal are strings (not bytes!)
60606060
for key, _ in arg.items:
6061-
if key is not None and not isinstance(key, (StrExpr, BytesExpr)):
6061+
if key is not None and not isinstance(key, StrExpr):
60626062
# Non-string key found, don't translate
60636063
for a in call.args:
60646064
a.accept(self)
@@ -6070,6 +6070,7 @@ def translate_dict_call(self, call: CallExpr) -> DictExpr | None:
60706070
]
60716071
)
60726072
expr.set_line(call)
6073+
expr.from_dict_call = True
60736074
expr.accept(self)
60746075
return expr
60756076

test-data/unit/check-expressions.test

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2587,4 +2587,6 @@ def last_known_value() -> None:
25872587
[case testDictUnpackNonStringKey]
25882588
def f() -> None:
25892589
dict(**{10: 20}) # E: Argument after ** must have string keys
2590+
dict(**{**{1: 1}}) # E: Argument after ** must have string keys
2591+
dict(**{b'a': 1}) # E: Argument after ** must have string keys
25902592
[builtins fixtures/dict.pyi]

0 commit comments

Comments
 (0)