@@ -144,6 +144,39 @@ def visit_Name(self, node):
144144 self .local_variables .add (node .id )
145145 return node
146146
147+ def is_expression_context (self , node ):
148+ if not hasattr (node , "_parent" ) or node ._parent is None :
149+ return False
150+ parent = node ._parent
151+ # Common expression contexts:
152+ if isinstance (
153+ parent ,
154+ (
155+ ast .BinOp ,
156+ ast .UnaryOp ,
157+ ast .BoolOp ,
158+ ast .Call ,
159+ ast .Subscript ,
160+ ast .Compare ,
161+ ast .Return ,
162+ ast .Expr ,
163+ ast .If ,
164+ ast .While ,
165+ ),
166+ ):
167+ return True
168+ # RHS of assignment: parent is Assign and node is in value
169+ if isinstance (parent , ast .Assign ) and node in ast .walk (parent .value ):
170+ return True
171+ return False
172+
173+ def _attach_parents (self , node , parent = None ):
174+ node ._parent = parent
175+ if parent and not hasattr (node , "lineno" ):
176+ node .lineno = parent .lineno
177+ for child in ast .iter_child_nodes (node ):
178+ self ._attach_parents (child , node )
179+
147180 def visit_FunctionDef (self , node ):
148181 # Capture argument names for branch functions
149182 old_args = self .current_func_args
@@ -152,8 +185,11 @@ def visit_FunctionDef(self, node):
152185 for n in node .body :
153186 visited = self .visit (n )
154187 if isinstance (visited , list ):
188+ for n in visited :
189+ self ._attach_parents (n , node )
155190 new_body .extend (visited )
156191 else :
192+ self ._attach_parents (visited , node )
157193 new_body .append (visited )
158194 node .body = new_body
159195 self .current_func_args = old_args
@@ -455,10 +491,10 @@ def visit_If(self, node):
455491 f"Inconsistencies between n_returned_values={ n_returned_values } , "
456492 f"dropped={ dropped } , tgt.elts={ tgt .elts } , tgt_elts={ tgt_elts } "
457493 )
458- tgt = ast .Tuple (tgt_elts , ctx = ast .Store ())
494+ tgt = ast .Tuple (list ( tgt_elts ) , ctx = ast .Store ())
459495
460496 added = {tgt .id } if isinstance (tgt , ast .Name ) else set (t .id for t in tgt .elts )
461- assign = ast .Assign (targets = [tgt ], value = call )
497+ assign = ast .Assign (targets = [tgt ], value = call , ctx = ast . Store () )
462498 ast .copy_location (assign , node )
463499 ast .fix_missing_locations (assign )
464500 self .local_variables = known_local_variables | added
@@ -631,7 +667,7 @@ def visit_For(self, node):
631667 ),
632668 ],
633669 decorator_list = [],
634- ctx = ast .Store (),
670+ ctx = ast .Load (),
635671 )
636672
637673 # final rewriting
@@ -654,7 +690,7 @@ def visit_For(self, node):
654690 args = [
655691 ast .Name (id = func_name , ctx = ast .Load ()),
656692 ast .List (
657- elts = [ast .Name (id = v , ctx = ast .Load ()) for v in init_vars ], ctx = ast .Store ()
693+ elts = [ast .Name (id = v , ctx = ast .Load ()) for v in init_vars ], ctx = ast .Load ()
658694 ),
659695 ast .List (
660696 elts = [
@@ -700,13 +736,13 @@ def visit_For(self, node):
700736 ],
701737 * [ast .Name (id = v , ctx = ast .Load ()) for v in scan_vars ],
702738 ],
703- ctx = ast .Store (),
739+ ctx = ast .Load (),
704740 ),
705741 ast .List (
706742 elts = [
707743 ast .Name (id = v , ctx = ast .Load ()) for v in [* scan_shape_vars , * input_vars ]
708744 ],
709- ctx = ast .Store (),
745+ ctx = ast .Load (),
710746 ),
711747 ],
712748 keywords = [],
@@ -923,23 +959,28 @@ def forward(self, x, y):
923959 )
924960 try :
925961 mod = compile (new_tree , filename = "<ast>" , mode = "exec" )
926- except TypeError as e :
927- if 'required field "lineno" missing from stmt' in str (e ):
962+ except (TypeError , ValueError ) as e :
963+ se = str (e )
964+ if (
965+ 'required field "lineno" missing from' in se
966+ or "expression must have Load context but has Store instead" in se
967+ ):
928968 # Could not find a way to avoid compiling a string.
929969 # The error message still pops up without indicating which node is not
930970 # properly set.
931971 code = ast .unparse (new_tree )
932972 try :
933973 mod = compile (code , filename = "<source>" , mode = "exec" )
934974 except IndentationError as ee :
935- raise RuntimeError (f"Unable to compile\n { code } " ) from ee
975+ raise RuntimeError (f"Unable to compile due to { ee } (and { e } ) \n { code } " ) from ee
936976 else :
937977 kws = dict (include_attributes = True , annotate_fields = True , indent = 4 )
938978 raise RuntimeError (
939- f"Unable to compile code\n --CODE--\n "
979+ f"Unable to compile code due to { e } \n --CODE--\n "
940980 f"{ ast .unparse (new_tree )} \n --TREE--\n "
941981 f"{ ast .dump (new_tree , ** kws )} "
942982 ) from e
983+
943984 namespace : Dict [str , type ] = {}
944985 globs = func .__globals__ .copy ()
945986 exec (mod , globs , namespace )
0 commit comments