@@ -4928,6 +4928,8 @@ def infer_context_dependent(
49284928 return typ
49294929
49304930 # If there are errors with the original type context, try re-inferring in empty context.
4931+ # However, skip this fallback if the expression contains assignment expressions (walrus
4932+ # operator), as they can cause incorrect type inference when the context is removed.
49314933 original_messages = msg .filtered_errors ()
49324934 original_type_map = type_map
49334935 with self .msg .filter_errors (
@@ -4937,10 +4939,11 @@ def infer_context_dependent(
49374939 alt_typ = get_proper_type (
49384940 self .expr_checker .accept (expr , None , allow_none_return = allow_none_func_call )
49394941 )
4940- if not msg .has_new_errors () and is_subtype (alt_typ , type_ctx ):
4942+
4943+ if not msg .has_new_errors () and is_subtype (alt_typ , type_ctx ) and not self .contains_assignment_expr (expr ):
49414944 self .store_types (type_map )
49424945 return alt_typ
4943-
4946+
49444947 # If empty fallback didn't work, use results from the original type context.
49454948 self .msg .add_errors (original_messages )
49464949 self .store_types (original_type_map )
@@ -4979,7 +4982,7 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
49794982
49804983 # Return with a value.
49814984 if (
4982- isinstance (s .expr , (CallExpr , ListExpr , TupleExpr , DictExpr , SetExpr , OpExpr ))
4985+ isinstance (s .expr , (CallExpr , ListExpr , TupleExpr , DictExpr , SetExpr , OpExpr , AssignmentExpr ))
49834986 or isinstance (s .expr , AwaitExpr )
49844987 and isinstance (s .expr .expr , CallExpr )
49854988 ):
@@ -5056,6 +5059,80 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
50565059
50575060 if self .in_checked_function ():
50585061 self .fail (message_registry .RETURN_VALUE_EXPECTED , s )
5062+
5063+ def contains_assignment_expr (self , expr : Expression ) -> bool :
5064+ """Check if expression contains any AssignmentExpr (walrus operator)."""
5065+ # Base case: found an assignment expression
5066+ if isinstance (expr , AssignmentExpr ):
5067+ return True
5068+
5069+ # Recursively check nested expressions in various expression types
5070+
5071+ # Container expressions
5072+ if isinstance (expr , (TupleExpr , ListExpr , SetExpr )):
5073+ return any (self .contains_assignment_expr (item ) for item in expr .items )
5074+
5075+ if isinstance (expr , DictExpr ):
5076+ # Check both keys and values
5077+ for k , v in zip (expr .items , expr .values ):
5078+ if self .contains_assignment_expr (k ) or self .contains_assignment_expr (v ):
5079+ return True
5080+ return False
5081+
5082+ # Binary operations (left and right operands)
5083+ if isinstance (expr , OpExpr ):
5084+ return (
5085+ self .contains_assignment_expr (expr .left )
5086+ or self .contains_assignment_expr (expr .right )
5087+ )
5088+
5089+ # Unary operations
5090+ if isinstance (expr , UnaryExpr ):
5091+ return self .contains_assignment_expr (expr .expr )
5092+
5093+ # Comparison expressions (multiple operands)
5094+ if isinstance (expr , ComparisonExpr ):
5095+ return any (self .contains_assignment_expr (operand ) for operand in expr .operands )
5096+
5097+ # Function calls (check arguments)
5098+ if isinstance (expr , CallExpr ):
5099+ # Check callee and all arguments
5100+ if self .contains_assignment_expr (expr .callee ):
5101+ return True
5102+ return any (self .contains_assignment_expr (arg ) for arg in expr .args )
5103+
5104+ # Index expressions (subscripts)
5105+ if isinstance (expr , IndexExpr ):
5106+ if self .contains_assignment_expr (expr .base ):
5107+ return True
5108+ if expr .index is not None :
5109+ return self .contains_assignment_expr (expr .index )
5110+ return False
5111+
5112+ # Member access
5113+ if isinstance (expr , MemberExpr ):
5114+ return self .contains_assignment_expr (expr .expr )
5115+
5116+ # Starred expressions (unpacking)
5117+ if isinstance (expr , StarExpr ):
5118+ return self .contains_assignment_expr (expr .expr )
5119+
5120+ # Await expressions
5121+ if isinstance (expr , AwaitExpr ):
5122+ return self .contains_assignment_expr (expr .expr )
5123+
5124+ # Yield expressions
5125+ if isinstance (expr , YieldExpr ):
5126+ if expr .expr is not None :
5127+ return self .contains_assignment_expr (expr .expr )
5128+ return False
5129+
5130+ # Conditional expressions (ternary operator)
5131+ # Note: ConditionalExpr might not be in imports, but if it exists, handle it
5132+ # For now, we'll skip it if it's not imported
5133+
5134+ # All other expression types (NameExpr, IntExpr, StrExpr, etc.) don't contain nested expressions
5135+ return False
50595136
50605137 def visit_if_stmt (self , s : IfStmt ) -> None :
50615138 """Type check an if statement."""
0 commit comments