Skip to content

Commit 5e61a8b

Browse files
authored
Merge pull request #1166 from codeflash-ai/skyvern-grace
feat: improve dependency tracking and base class extraction
2 parents 53fbe57 + 214d891 commit 5e61a8b

12 files changed

Lines changed: 2592 additions & 502 deletions

.github/workflows/mypy.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,10 @@ jobs:
2222

2323
- name: Install uv
2424
uses: astral-sh/setup-uv@v6
25-
with:
26-
version: "0.5.30"
2725

2826
- name: sync uv
2927
run: |
28+
uv venv --seed
3029
uv sync
3130
3231

codeflash/code_utils/code_extractor.py

Lines changed: 268 additions & 76 deletions
Large diffs are not rendered by default.

codeflash/context/code_context_extractor.py

Lines changed: 416 additions & 292 deletions
Large diffs are not rendered by default.

codeflash/context/unused_definition_remover.py

Lines changed: 55 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -295,11 +295,18 @@ def visit_Name(self, node: cst.Name) -> None:
295295
return
296296

297297
if name in self.definitions and name != self.current_top_level_name:
298-
# skip if we are refrencing a class attribute and not a top-level definition
298+
# Skip if this Name is the .attr part of an Attribute (e.g., 'x' in 'self.x')
299+
# We only want to track the base/value of attribute access, not the attribute name itself
299300
if self.class_depth > 0:
300301
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
301302
if parent is not None and isinstance(parent, cst.Attribute):
302-
return
303+
# Check if this Name is the .attr (property name), not the .value (base)
304+
# If it's the .attr, skip it - attribute names aren't references to definitions
305+
if parent.attr is node:
306+
return
307+
# If it's the .value (base), only skip if it's self/cls
308+
if name in ("self", "cls"):
309+
return
303310
self.definitions[self.current_top_level_name].dependencies.add(name)
304311

305312

@@ -553,16 +560,6 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
553560
return code
554561

555562

556-
def print_definitions(definitions: dict[str, UsageInfo]) -> None:
557-
"""Print information about each definition without the complex node object, used for debugging."""
558-
print(f"Found {len(definitions)} definitions:")
559-
for name, info in sorted(definitions.items()):
560-
print(f" - Name: {name}")
561-
print(f" Used by qualified function: {info.used_by_qualified_function}")
562-
print(f" Dependencies: {', '.join(sorted(info.dependencies)) if info.dependencies else 'None'}")
563-
print()
564-
565-
566563
def revert_unused_helper_functions(
567564
project_root: Path, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str]
568565
) -> None:
@@ -637,43 +634,40 @@ def _analyze_imports_in_optimized_code(
637634
func_name = helper.only_function_name
638635
module_name = helper.file_path.stem
639636
# Cache function lookup for this (module, func)
640-
file_entry = helpers_by_file_and_func[module_name]
641-
if func_name in file_entry:
642-
file_entry[func_name].append(helper)
643-
else:
644-
file_entry[func_name] = [helper]
637+
helpers_by_file_and_func[module_name].setdefault(func_name, []).append(helper)
645638
helpers_by_file[module_name].append(helper)
646639

647-
# Optimize attribute lookups and method binding outside the loop
648-
helpers_by_file_and_func_get = helpers_by_file_and_func.get
649-
helpers_by_file_get = helpers_by_file.get
650-
651640
for node in ast.walk(optimized_ast):
652641
if isinstance(node, ast.ImportFrom):
653642
# Handle "from module import function" statements
654643
module_name = node.module
655644
if module_name:
656-
file_entry = helpers_by_file_and_func_get(module_name, None)
645+
file_entry = helpers_by_file_and_func.get(module_name)
657646
if file_entry:
658647
for alias in node.names:
659648
imported_name = alias.asname if alias.asname else alias.name
660649
original_name = alias.name
661-
helpers = file_entry.get(original_name, None)
650+
helpers = file_entry.get(original_name)
662651
if helpers:
652+
imported_set = imported_names_map[imported_name]
663653
for helper in helpers:
664-
imported_names_map[imported_name].add(helper.qualified_name)
665-
imported_names_map[imported_name].add(helper.fully_qualified_name)
654+
imported_set.add(helper.qualified_name)
655+
imported_set.add(helper.fully_qualified_name)
666656

667657
elif isinstance(node, ast.Import):
668658
# Handle "import module" statements
669659
for alias in node.names:
670660
imported_name = alias.asname if alias.asname else alias.name
671661
module_name = alias.name
672-
for helper in helpers_by_file_get(module_name, []):
673-
# For "import module" statements, functions would be called as module.function
674-
full_call = f"{imported_name}.{helper.only_function_name}"
675-
imported_names_map[full_call].add(helper.qualified_name)
676-
imported_names_map[full_call].add(helper.fully_qualified_name)
662+
helpers = helpers_by_file.get(module_name)
663+
if helpers:
664+
imported_set = imported_names_map[f"{imported_name}.{{func}}"]
665+
for helper in helpers:
666+
# For "import module" statements, functions would be called as module.function
667+
full_call = f"{imported_name}.{helper.only_function_name}"
668+
full_call_set = imported_names_map[full_call]
669+
full_call_set.add(helper.qualified_name)
670+
full_call_set.add(helper.fully_qualified_name)
677671

678672
return dict(imported_names_map)
679673

@@ -753,27 +747,31 @@ def detect_unused_helper_functions(
753747
called_name = node.func.id
754748
called_function_names.add(called_name)
755749
# Also add the qualified name if this is an imported function
756-
if called_name in imported_names_map:
757-
called_function_names.update(imported_names_map[called_name])
750+
mapped_names = imported_names_map.get(called_name)
751+
if mapped_names:
752+
called_function_names.update(mapped_names)
758753
elif isinstance(node.func, ast.Attribute):
759754
# Method call: obj.method() or self.method() or module.function()
760755
if isinstance(node.func.value, ast.Name):
761-
if node.func.value.id == "self":
756+
attr_name = node.func.attr
757+
value_id = node.func.value.id
758+
if value_id == "self":
762759
# self.method_name() -> add both method_name and ClassName.method_name
763-
called_function_names.add(node.func.attr)
760+
called_function_names.add(attr_name)
761+
# For class methods, also add the qualified name
764762
# For class methods, also add the qualified name
765763
if hasattr(function_to_optimize, "parents") and function_to_optimize.parents:
766764
class_name = function_to_optimize.parents[0].name
767-
called_function_names.add(f"{class_name}.{node.func.attr}")
765+
called_function_names.add(f"{class_name}.{attr_name}")
768766
else:
769-
# obj.method() or module.function()
770-
attr_name = node.func.attr
771767
called_function_names.add(attr_name)
772-
called_function_names.add(f"{node.func.value.id}.{attr_name}")
768+
full_call = f"{value_id}.{attr_name}"
769+
called_function_names.add(full_call)
773770
# Check if this is a module.function call that maps to a helper
774-
full_call = f"{node.func.value.id}.{attr_name}"
775-
if full_call in imported_names_map:
776-
called_function_names.update(imported_names_map[full_call])
771+
mapped_names = imported_names_map.get(full_call)
772+
if mapped_names:
773+
called_function_names.update(mapped_names)
774+
# Handle nested attribute access like obj.attr.method()
777775
# Handle nested attribute access like obj.attr.method()
778776
else:
779777
called_function_names.add(node.func.attr)
@@ -783,36 +781,38 @@ def detect_unused_helper_functions(
783781

784782
# Find helper functions that are no longer called
785783
unused_helpers = []
784+
entrypoint_file_path = function_to_optimize.file_path
786785
for helper_function in code_context.helper_functions:
787786
if helper_function.jedi_definition.type != "class":
788787
# Check if the helper function is called using multiple name variants
789788
helper_qualified_name = helper_function.qualified_name
790789
helper_simple_name = helper_function.only_function_name
791790
helper_fully_qualified_name = helper_function.fully_qualified_name
792791

793-
# Create a set of all possible names this helper might be called by
794-
possible_call_names = {helper_qualified_name, helper_simple_name, helper_fully_qualified_name}
795-
792+
# Check membership efficiently - exit early on first match
793+
if (
794+
helper_qualified_name in called_function_names
795+
or helper_simple_name in called_function_names
796+
or helper_fully_qualified_name in called_function_names
797+
):
798+
is_called = True
796799
# For cross-file helpers, also consider module-based calls
797-
if helper_function.file_path != function_to_optimize.file_path:
800+
elif helper_function.file_path != entrypoint_file_path:
798801
# Add potential module.function combinations
799802
module_name = helper_function.file_path.stem
800-
possible_call_names.add(f"{module_name}.{helper_simple_name}")
801-
802-
# Check if any of the possible names are in the called functions
803-
is_called = bool(possible_call_names.intersection(called_function_names))
803+
module_call = f"{module_name}.{helper_simple_name}"
804+
is_called = module_call in called_function_names
805+
else:
806+
is_called = False
804807

805808
if not is_called:
806809
unused_helpers.append(helper_function)
807810
logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code")
808-
logger.debug(f" Checked names: {possible_call_names}")
809811
else:
810812
logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code")
811-
logger.debug(f" Called via: {possible_call_names.intersection(called_function_names)}")
812-
813-
ret_val = unused_helpers
814813

815814
except Exception as e:
816815
logger.debug(f"Error detecting unused helper functions: {e}")
817-
ret_val = []
818-
return ret_val
816+
return []
817+
else:
818+
return unused_helpers

codeflash/verification/codeflash_capture.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import dill as pickle
1616
from dill import PicklingWarning
1717

18+
from codeflash.picklepatch.pickle_patcher import PicklePatcher
19+
1820
warnings.filterwarnings("ignore", category=PicklingWarning)
1921

2022

@@ -148,18 +150,29 @@ def wrapper(*args, **kwargs) -> None: # noqa: ANN002, ANN003
148150
print(f"!######{test_stdout_tag}######!")
149151

150152
# Capture instance state after initialization
151-
if hasattr(args[0], "__dict__"):
152-
instance_state = args[
153-
0
154-
].__dict__ # self is always the first argument, this is ensured during instrumentation
153+
# self is always the first argument, this is ensured during instrumentation
154+
instance = args[0]
155+
if hasattr(instance, "__dict__"):
156+
instance_state = instance.__dict__
157+
elif hasattr(instance, "__slots__"):
158+
# For classes using __slots__, capture slot values
159+
instance_state = {
160+
slot: getattr(instance, slot, None) for slot in instance.__slots__ if hasattr(instance, slot)
161+
}
155162
else:
156-
raise ValueError("Instance state could not be captured.")
163+
# For C extension types or other special classes (e.g., Playwright's Page),
164+
# capture all non-private, non-callable attributes
165+
instance_state = {
166+
attr: getattr(instance, attr)
167+
for attr in dir(instance)
168+
if not attr.startswith("_") and not callable(getattr(instance, attr, None))
169+
}
157170
codeflash_cur.execute(
158171
"CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)"
159172
)
160173

161174
# Write to sqlite
162-
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(instance_state)
175+
pickled_return_value = pickle.dumps(exception) if exception else PicklePatcher.dumps(instance_state)
163176
codeflash_cur.execute(
164177
"INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
165178
(

codeflash/verification/equivalence.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@
1919
test_diff_repr = reprlib_repr.repr
2020

2121

22+
def safe_repr(obj: object) -> str:
23+
"""Safely get repr of an object, handling Mock objects with corrupted state."""
24+
try:
25+
return repr(obj)
26+
except (AttributeError, TypeError, RecursionError) as e:
27+
return f"<repr failed: {type(e).__name__}: {e}>"
28+
29+
2230
def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]:
2331
# This is meant to be only called with test results for the first loop index
2432
if len(original_results) == 0 or len(candidate_results) == 0:
@@ -77,8 +85,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
7785
test_diffs.append(
7886
TestDiff(
7987
scope=TestDiffScope.RETURN_VALUE,
80-
original_value=test_diff_repr(repr(original_test_result.return_value)),
81-
candidate_value=test_diff_repr(repr(cdd_test_result.return_value)),
88+
original_value=test_diff_repr(safe_repr(original_test_result.return_value)),
89+
candidate_value=test_diff_repr(safe_repr(cdd_test_result.return_value)),
8290
test_src_code=original_test_result.id.get_src_code(original_test_result.file_name),
8391
candidate_pytest_error=cdd_pytest_error,
8492
original_pass=original_test_result.did_pass,

0 commit comments

Comments
 (0)