@@ -115,6 +115,21 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
115115 return updated_node .with_changes (body = new_statements )
116116
117117
118+ def collect_referenced_names (node : cst .CSTNode ) -> set [str ]:
119+ """Collect all names referenced in a CST node using recursive traversal."""
120+ names : set [str ] = set ()
121+
122+ def _collect (n : cst .CSTNode ) -> None :
123+ if isinstance (n , cst .Name ):
124+ names .add (n .value )
125+ # Recursively process all children
126+ for child in n .children :
127+ _collect (child )
128+
129+ _collect (node )
130+ return names
131+
132+
118133class GlobalAssignmentCollector (cst .CSTVisitor ):
119134 """Collects all global assignment statements."""
120135
@@ -274,37 +289,69 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
274289
275290 # Find assignments to append
276291 assignments_to_append = [
277- self .new_assignments [name ]
292+ ( name , self .new_assignments [name ])
278293 for name in self .new_assignment_order
279294 if name not in self .processed_assignments and name in self .new_assignments
280295 ]
281296
282- if assignments_to_append :
283- # Start after imports, then advance past class/function definitions
284- # to ensure assignments can reference any classes defined in the module
297+ if not assignments_to_append :
298+ return updated_node .with_changes (body = new_statements )
299+
300+ # Collect all class and function names defined in the module
301+ # These are the names that assignments might reference
302+ module_defined_names : set [str ] = set ()
303+ for stmt in new_statements :
304+ if isinstance (stmt , (cst .ClassDef , cst .FunctionDef )):
305+ module_defined_names .add (stmt .name .value )
306+
307+ # Partition assignments: those that reference module definitions go at the end,
308+ # those that don't can go right after imports
309+ assignments_after_imports : list [tuple [str , cst .Assign | cst .AnnAssign ]] = []
310+ assignments_after_definitions : list [tuple [str , cst .Assign | cst .AnnAssign ]] = []
311+
312+ for name , assignment in assignments_to_append :
313+ # Get the value being assigned
314+ if isinstance (assignment , (cst .Assign , cst .AnnAssign )) and assignment .value is not None :
315+ value_node = assignment .value
316+ else :
317+ # No value to analyze, safe to place after imports
318+ assignments_after_imports .append ((name , assignment ))
319+ continue
320+
321+ # Collect names referenced in the assignment value
322+ referenced_names = collect_referenced_names (value_node )
323+
324+ # Check if any referenced names are module-level definitions
325+ if referenced_names & module_defined_names :
326+ # This assignment references a class/function, place it after definitions
327+ assignments_after_definitions .append ((name , assignment ))
328+ else :
329+ # Safe to place right after imports
330+ assignments_after_imports .append ((name , assignment ))
331+
332+ # Insert assignments that don't depend on module definitions right after imports
333+ if assignments_after_imports :
285334 insert_index = find_insertion_index_after_imports (updated_node )
335+ assignment_lines = [
336+ cst .SimpleStatementLine ([assignment ], leading_lines = [cst .EmptyLine ()])
337+ for _ , assignment in assignments_after_imports
338+ ]
339+ new_statements = list (chain (new_statements [:insert_index ], assignment_lines , new_statements [insert_index :]))
340+
341+ # Insert assignments that depend on module definitions after all class/function definitions
342+ if assignments_after_definitions :
343+ # Find the position after the last function or class definition
344+ insert_index = find_insertion_index_after_imports (cst .Module (body = new_statements ))
286345 for i , stmt in enumerate (new_statements ):
287346 if isinstance (stmt , (cst .FunctionDef , cst .ClassDef )):
288347 insert_index = i + 1
289348
290349 assignment_lines = [
291350 cst .SimpleStatementLine ([assignment ], leading_lines = [cst .EmptyLine ()])
292- for assignment in assignments_to_append
351+ for _ , assignment in assignments_after_definitions
293352 ]
294-
295353 new_statements = list (chain (new_statements [:insert_index ], assignment_lines , new_statements [insert_index :]))
296354
297- # Add a blank line after the last assignment if needed
298- after_index = insert_index + len (assignment_lines )
299- if after_index < len (new_statements ):
300- next_stmt = new_statements [after_index ]
301- # If there's no empty line, add one
302- has_empty = any (isinstance (line , cst .EmptyLine ) for line in next_stmt .leading_lines )
303- if not has_empty :
304- new_statements [after_index ] = next_stmt .with_changes (
305- leading_lines = [cst .EmptyLine (), * next_stmt .leading_lines ]
306- )
307-
308355 return updated_node .with_changes (body = new_statements )
309356
310357
0 commit comments