Skip to content

Commit 7b33e8b

Browse files
committed
refactor: smarter placement of global assignments based on dependencies
Assignments that don't reference module-level definitions are now placed right after imports. Only assignments that reference classes/functions are placed after those definitions to prevent NameError.
1 parent 257c5f2 commit 7b33e8b

5 files changed

Lines changed: 74 additions & 46 deletions

File tree

codeflash/code_utils/code_extractor.py

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,21 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
115115
return updated_node.with_changes(body=new_statements)
116116

117117

118+
def collect_referenced_names(node: cst.CSTNode) -> set[str]:
119+
"""Collect all names referenced in a CST node using recursive traversal."""
120+
names: set[str] = set()
121+
122+
def _collect(n: cst.CSTNode) -> None:
123+
if isinstance(n, cst.Name):
124+
names.add(n.value)
125+
# Recursively process all children
126+
for child in n.children:
127+
_collect(child)
128+
129+
_collect(node)
130+
return names
131+
132+
118133
class GlobalAssignmentCollector(cst.CSTVisitor):
119134
"""Collects all global assignment statements."""
120135

@@ -274,37 +289,69 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
274289

275290
# Find assignments to append
276291
assignments_to_append = [
277-
self.new_assignments[name]
292+
(name, self.new_assignments[name])
278293
for name in self.new_assignment_order
279294
if name not in self.processed_assignments and name in self.new_assignments
280295
]
281296

282-
if assignments_to_append:
283-
# Start after imports, then advance past class/function definitions
284-
# to ensure assignments can reference any classes defined in the module
297+
if not assignments_to_append:
298+
return updated_node.with_changes(body=new_statements)
299+
300+
# Collect all class and function names defined in the module
301+
# These are the names that assignments might reference
302+
module_defined_names: set[str] = set()
303+
for stmt in new_statements:
304+
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
305+
module_defined_names.add(stmt.name.value)
306+
307+
# Partition assignments: those that reference module definitions go at the end,
308+
# those that don't can go right after imports
309+
assignments_after_imports: list[tuple[str, cst.Assign | cst.AnnAssign]] = []
310+
assignments_after_definitions: list[tuple[str, cst.Assign | cst.AnnAssign]] = []
311+
312+
for name, assignment in assignments_to_append:
313+
# Get the value being assigned
314+
if isinstance(assignment, (cst.Assign, cst.AnnAssign)) and assignment.value is not None:
315+
value_node = assignment.value
316+
else:
317+
# No value to analyze, safe to place after imports
318+
assignments_after_imports.append((name, assignment))
319+
continue
320+
321+
# Collect names referenced in the assignment value
322+
referenced_names = collect_referenced_names(value_node)
323+
324+
# Check if any referenced names are module-level definitions
325+
if referenced_names & module_defined_names:
326+
# This assignment references a class/function, place it after definitions
327+
assignments_after_definitions.append((name, assignment))
328+
else:
329+
# Safe to place right after imports
330+
assignments_after_imports.append((name, assignment))
331+
332+
# Insert assignments that don't depend on module definitions right after imports
333+
if assignments_after_imports:
285334
insert_index = find_insertion_index_after_imports(updated_node)
335+
assignment_lines = [
336+
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
337+
for _, assignment in assignments_after_imports
338+
]
339+
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
340+
341+
# Insert assignments that depend on module definitions after all class/function definitions
342+
if assignments_after_definitions:
343+
# Find the position after the last function or class definition
344+
insert_index = find_insertion_index_after_imports(cst.Module(body=new_statements))
286345
for i, stmt in enumerate(new_statements):
287346
if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)):
288347
insert_index = i + 1
289348

290349
assignment_lines = [
291350
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
292-
for assignment in assignments_to_append
351+
for _, assignment in assignments_after_definitions
293352
]
294-
295353
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
296354

297-
# Add a blank line after the last assignment if needed
298-
after_index = insert_index + len(assignment_lines)
299-
if after_index < len(new_statements):
300-
next_stmt = new_statements[after_index]
301-
# If there's no empty line, add one
302-
has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines)
303-
if not has_empty:
304-
new_statements[after_index] = next_stmt.with_changes(
305-
leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines]
306-
)
307-
308355
return updated_node.with_changes(body=new_statements)
309356

310357

tests/test_code_context_extractor.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2975,11 +2975,11 @@ def method(self):
29752975
return cached_helper(5)
29762976
"""
29772977

2978-
# Global assignments are now inserted AFTER class/function definitions
2979-
# to ensure they can reference classes defined in the module
29802978
expected = """\
29812979
from typing import Any
29822980
2981+
_LOCAL_CACHE: dict[str, int] = {}
2982+
29832983
class MyClass:
29842984
def method(self):
29852985
return cached_helper(5)
@@ -2992,8 +2992,6 @@ def cached_helper(x: int) -> int:
29922992
29932993
def regular_helper():
29942994
return "regular"
2995-
2996-
_LOCAL_CACHE: dict[str, int] = {}
29972995
"""
29982996

29992997
result = add_global_assignments(source_code, destination_code)
@@ -3111,11 +3109,11 @@ def handle_message(kind):
31113109
return "reply"
31123110
"""
31133111

3114-
# Global statements (function calls) should be inserted AFTER all class/function
3115-
# definitions to ensure they can reference any function defined in the module
31163112
expected = """\
31173113
import enum
31183114
3115+
_factories = {}
3116+
31193117
class MessageKind(enum.StrEnum):
31203118
ASK = "ask"
31213119
REPLY = "reply"
@@ -3129,8 +3127,6 @@ def handle_message(kind):
31293127
def _register(kind, factory):
31303128
_factories[kind] = factory
31313129
3132-
_factories = {}
3133-
31343130
31353131
_register(MessageKind.ASK, lambda: "ask handler")
31363132

tests/test_code_replacement.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2116,12 +2116,9 @@ def new_function2(value):
21162116
print("Hello world")
21172117
```
21182118
"""
2119-
# Global assignments are now inserted AFTER class/function definitions
2120-
# to ensure they can reference any classes defined in the module.
2121-
# This prevents NameError when LLM-generated optimizations like
2122-
# `_HANDLERS = {MessageKind.XXX: ...}` reference classes.
21232119
expected_code = """import numpy as np
21242120
2121+
a = 6
21252122
if 2<3:
21262123
a=4
21272124
else:
@@ -2143,8 +2140,6 @@ def __call__(self, value):
21432140
return "I am still old"
21442141
def new_function2(value):
21452142
return cst.ensure_type(value, str)
2146-
2147-
a = 6
21482143
"""
21492144
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
21502145
code_path.write_text(original_code, encoding="utf-8")
@@ -3371,9 +3366,6 @@ def hydrate_input_text_actions_with_field_names(
33713366
return updated_actions_by_task
33723367
```
33733368
'''
3374-
# Global assignments are now inserted AFTER class/function definitions
3375-
# to ensure they can reference any classes defined in the module.
3376-
# This prevents NameError when LLM-generated optimizations reference classes.
33773369
expected = '''"""
33783370
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
33793371
"""
@@ -3388,6 +3380,8 @@ def hydrate_input_text_actions_with_field_names(
33883380
from skyvern.webeye.actions.actions import ActionType
33893381
import re
33903382
3383+
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
3384+
33913385
LOG = structlog.get_logger(__name__)
33923386
33933387
# Initialize prompt engine
@@ -3441,8 +3435,6 @@ def hydrate_input_text_actions_with_field_names(
34413435
updated_actions_by_task[task_id] = updated_actions
34423436
34433437
return updated_actions_by_task
3444-
3445-
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
34463438
'''
34473439

34483440
func = FunctionToOptimize(function_name="hydrate_input_text_actions_with_field_names", parents=[], file_path=main_file)

tests/test_get_read_writable_code.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,6 @@ class Inner:
218218
def target(self):
219219
pass
220220
"""
221-
# Nested class methods (MyClass.Inner.target) aren't directly targetable,
222-
# but the outer class is kept when the qualified name starts with it.
223-
# This is because the dependency tracking marks "MyClass" as used when it
224-
# sees "MyClass.Inner.target" as a target function.
225221
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
226222
expected = dedent("""
227223
class MyClass:

tests/test_multi_file_code_replacement.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,14 @@ def _get_string_usage(text: str) -> Usage:
124124

125125
helper_file.unlink(missing_ok=True)
126126
main_file.unlink(missing_ok=True)
127-
128-
# Global assignments are now inserted AFTER class/function definitions
129-
# to prevent NameError when they reference classes or functions.
130-
# See commit 50fba096 for details.
127+
131128
expected_helper = """import re
132129
from collections.abc import Sequence
133130
134131
from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent
135132
133+
_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
134+
136135
_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
137136
138137
def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
@@ -159,8 +158,6 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
159158
tokens += len(part.data)
160159
161160
return tokens
162-
163-
_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
164161
"""
165162

166163
assert new_code.rstrip() == original_main.rstrip() # No Change

0 commit comments

Comments
 (0)