Skip to content

Commit 1efb385

Browse files
Kunal SaliKunal Sali
authored andcommitted
Fix false negative with walrus operator in tuples
1 parent a547c51 commit 1efb385

1 file changed

Lines changed: 80 additions & 3 deletions

File tree

mypy/checker.py

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4928,6 +4928,8 @@ def infer_context_dependent(
49284928
return typ
49294929

49304930
# If there are errors with the original type context, try re-inferring in empty context.
4931+
# However, skip this fallback if the expression contains assignment expressions (walrus
4932+
# operator), as they can cause incorrect type inference when the context is removed.
49314933
original_messages = msg.filtered_errors()
49324934
original_type_map = type_map
49334935
with self.msg.filter_errors(
@@ -4937,10 +4939,11 @@ def infer_context_dependent(
49374939
alt_typ = get_proper_type(
49384940
self.expr_checker.accept(expr, None, allow_none_return=allow_none_func_call)
49394941
)
4940-
if not msg.has_new_errors() and is_subtype(alt_typ, type_ctx):
4942+
4943+
if not msg.has_new_errors() and is_subtype(alt_typ, type_ctx) and not self.contains_assignment_expr(expr):
49414944
self.store_types(type_map)
49424945
return alt_typ
4943-
4946+
49444947
# If empty fallback didn't work, use results from the original type context.
49454948
self.msg.add_errors(original_messages)
49464949
self.store_types(original_type_map)
@@ -4979,7 +4982,7 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
49794982

49804983
# Return with a value.
49814984
if (
4982-
isinstance(s.expr, (CallExpr, ListExpr, TupleExpr, DictExpr, SetExpr, OpExpr))
4985+
isinstance(s.expr, (CallExpr, ListExpr, TupleExpr, DictExpr, SetExpr, OpExpr, AssignmentExpr))
49834986
or isinstance(s.expr, AwaitExpr)
49844987
and isinstance(s.expr.expr, CallExpr)
49854988
):
@@ -5056,6 +5059,80 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
50565059

50575060
if self.in_checked_function():
50585061
self.fail(message_registry.RETURN_VALUE_EXPECTED, s)
5062+
5063+
def contains_assignment_expr(self, expr: Expression) -> bool:
5064+
"""Check if expression contains any AssignmentExpr (walrus operator)."""
5065+
# Base case: found an assignment expression
5066+
if isinstance(expr, AssignmentExpr):
5067+
return True
5068+
5069+
# Recursively check nested expressions in various expression types
5070+
5071+
# Container expressions
5072+
if isinstance(expr, (TupleExpr, ListExpr, SetExpr)):
5073+
return any(self.contains_assignment_expr(item) for item in expr.items)
5074+
5075+
if isinstance(expr, DictExpr):
5076+
# Check both keys and values
5077+
for k, v in zip(expr.items, expr.values):
5078+
if self.contains_assignment_expr(k) or self.contains_assignment_expr(v):
5079+
return True
5080+
return False
5081+
5082+
# Binary operations (left and right operands)
5083+
if isinstance(expr, OpExpr):
5084+
return (
5085+
self.contains_assignment_expr(expr.left)
5086+
or self.contains_assignment_expr(expr.right)
5087+
)
5088+
5089+
# Unary operations
5090+
if isinstance(expr, UnaryExpr):
5091+
return self.contains_assignment_expr(expr.expr)
5092+
5093+
# Comparison expressions (multiple operands)
5094+
if isinstance(expr, ComparisonExpr):
5095+
return any(self.contains_assignment_expr(operand) for operand in expr.operands)
5096+
5097+
# Function calls (check arguments)
5098+
if isinstance(expr, CallExpr):
5099+
# Check callee and all arguments
5100+
if self.contains_assignment_expr(expr.callee):
5101+
return True
5102+
return any(self.contains_assignment_expr(arg) for arg in expr.args)
5103+
5104+
# Index expressions (subscripts)
5105+
if isinstance(expr, IndexExpr):
5106+
if self.contains_assignment_expr(expr.base):
5107+
return True
5108+
if expr.index is not None:
5109+
return self.contains_assignment_expr(expr.index)
5110+
return False
5111+
5112+
# Member access
5113+
if isinstance(expr, MemberExpr):
5114+
return self.contains_assignment_expr(expr.expr)
5115+
5116+
# Starred expressions (unpacking)
5117+
if isinstance(expr, StarExpr):
5118+
return self.contains_assignment_expr(expr.expr)
5119+
5120+
# Await expressions
5121+
if isinstance(expr, AwaitExpr):
5122+
return self.contains_assignment_expr(expr.expr)
5123+
5124+
# Yield expressions
5125+
if isinstance(expr, YieldExpr):
5126+
if expr.expr is not None:
5127+
return self.contains_assignment_expr(expr.expr)
5128+
return False
5129+
5130+
# Conditional expressions (ternary operator)
5131+
# Note: ConditionalExpr might not be in imports, but if it exists, handle it
5132+
# For now, we'll skip it if it's not imported
5133+
5134+
# All other expression types (NameExpr, IntExpr, StrExpr, etc.) don't contain nested expressions
5135+
return False
50595136

50605137
def visit_if_stmt(self, s: IfStmt) -> None:
50615138
"""Type check an if statement."""

0 commit comments

Comments
 (0)