Skip to content

Commit ac86a83

Browse files
authored
Merge pull request #1005 from codeflash-ai/unstructured-fixes
2 parents fae2c6a + 9c468cd commit ac86a83

5 files changed

Lines changed: 282 additions & 22 deletions

File tree

codeflash/context/code_context_extractor.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -446,31 +446,45 @@ def get_function_sources_from_jedi(
446446
definition_path = definition.module_path
447447

448448
# The definition is part of this project and not defined within the original function
449-
if (
449+
is_valid_definition = (
450450
str(definition_path).startswith(str(project_root_path) + os.sep)
451451
and not path_belongs_to_site_packages(definition_path)
452452
and definition.full_name
453-
and definition.type == "function"
454453
and not belongs_to_function_qualified(definition, qualified_function_name)
455454
and definition.full_name.startswith(definition.module_name)
455+
)
456+
if is_valid_definition and definition.type == "function":
457+
qualified_name = get_qualified_name(definition.module_name, definition.full_name)
456458
# Avoid nested functions or classes. Only class.function is allowed
457-
and len(
458-
(qualified_name := get_qualified_name(definition.module_name, definition.full_name)).split(
459-
"."
459+
if len(qualified_name.split(".")) <= 2:
460+
function_source = FunctionSource(
461+
file_path=definition_path,
462+
qualified_name=qualified_name,
463+
fully_qualified_name=definition.full_name,
464+
only_function_name=definition.name,
465+
source_code=definition.get_line_code(),
466+
jedi_definition=definition,
460467
)
468+
file_path_to_function_source[definition_path].add(function_source)
469+
function_source_list.append(function_source)
470+
# When a class is instantiated (e.g., MyClass()), track its __init__ as a helper
471+
# This ensures the class definition with constructor is included in testgen context
472+
elif is_valid_definition and definition.type == "class":
473+
init_qualified_name = get_qualified_name(
474+
definition.module_name, f"{definition.full_name}.__init__"
461475
)
462-
<= 2
463-
):
464-
function_source = FunctionSource(
465-
file_path=definition_path,
466-
qualified_name=qualified_name,
467-
fully_qualified_name=definition.full_name,
468-
only_function_name=definition.name,
469-
source_code=definition.get_line_code(),
470-
jedi_definition=definition,
471-
)
472-
file_path_to_function_source[definition_path].add(function_source)
473-
function_source_list.append(function_source)
476+
# Only include if it's a top-level class (not nested)
477+
if len(init_qualified_name.split(".")) <= 2:
478+
function_source = FunctionSource(
479+
file_path=definition_path,
480+
qualified_name=init_qualified_name,
481+
fully_qualified_name=f"{definition.full_name}.__init__",
482+
only_function_name="__init__",
483+
source_code=definition.get_line_code(),
484+
jedi_definition=definition,
485+
)
486+
file_path_to_function_source[definition_path].add(function_source)
487+
function_source_list.append(function_source)
474488

475489
return file_path_to_function_source, function_source_list
476490

@@ -647,7 +661,10 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
647661

648662
if isinstance(node, cst.FunctionDef):
649663
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
650-
if qualified_name in target_functions:
664+
# For hashing, exclude __init__ methods even if in target_functions
665+
# because they don't affect the semantic behavior being hashed
666+
# But include other dunder methods like __call__ which do affect behavior
667+
if qualified_name in target_functions and node.name.value != "__init__":
651668
new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body
652669
return node.with_changes(body=new_body), True
653670
return None, False
@@ -666,7 +683,9 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
666683
for stmt in node.body.body:
667684
if isinstance(stmt, cst.FunctionDef):
668685
qualified_name = f"{class_prefix}.{stmt.name.value}"
669-
if qualified_name in target_functions:
686+
# For hashing, exclude __init__ methods even if in target_functions
687+
# but include other methods like __call__ which affect behavior
688+
if qualified_name in target_functions and stmt.name.value != "__init__":
670689
stmt_with_changes = stmt.with_changes(
671690
body=remove_docstring_from_body(cast("cst.IndentedBlock", stmt.body))
672691
)

codeflash/context/unused_definition_remover.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,18 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None:
223223
self.current_class = class_name
224224
self.current_top_level_name = class_name
225225

226+
# Track base classes as dependencies
227+
for base in node.bases:
228+
if isinstance(base.value, cst.Name):
229+
base_name = base.value.value
230+
if base_name in self.definitions and class_name in self.definitions:
231+
self.definitions[class_name].dependencies.add(base_name)
232+
elif isinstance(base.value, cst.Attribute):
233+
# Handle cases like module.ClassName
234+
attr_name = base.value.attr.value
235+
if attr_name in self.definitions and class_name in self.definitions:
236+
self.definitions[class_name].dependencies.add(attr_name)
237+
226238
self.class_depth += 1
227239

228240
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002

tests/test_code_context_extractor.py

Lines changed: 168 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ def test_code_replacement10() -> None:
8484

8585
code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
8686
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
87-
assert qualified_names == {"HelperClass.helper_method"} # Nested method should not be in here
87+
# HelperClass.__init__ is now tracked because HelperClass(self.name) instantiates the class
88+
assert qualified_names == {"HelperClass.helper_method", "HelperClass.__init__"} # Nested method should not be in here
8889
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
8990
hashing_context = code_ctx.hashing_code_context
9091

@@ -570,6 +571,8 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
570571
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
571572
"""Interface for cache backends used by the persistent cache decorator."""
572573
574+
def __init__(self) -> None: ...
575+
573576
def hash_key(
574577
self,
575578
*,
@@ -1296,6 +1299,8 @@ def __repr__(self) -> str:
12961299
```
12971300
```python:{path_to_transform_utils.relative_to(project_root)}
12981301
class DataTransformer:
1302+
def __init__(self):
1303+
self.data = None
12991304
13001305
def transform(self, data):
13011306
self.data = data
@@ -1599,7 +1604,11 @@ def __repr__(self) -> str:
15991604
\"\"\"Return a string representation of the DataProcessor.\"\"\"
16001605
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
16011606
```
1602-
1607+
```python:{path_to_transform_utils.relative_to(project_root)}
1608+
class DataTransformer:
1609+
def __init__(self):
1610+
self.data = None
1611+
```
16031612
"""
16041613
expected_hashing_context = f"""
16051614
```python:utils.py
@@ -1705,13 +1714,19 @@ def test_direct_module_import() -> None:
17051714

17061715
expected_read_only_context = """
17071716
```python:utils.py
1717+
import math
17081718
from transform_utils import DataTransformer
17091719
17101720
class DataProcessor:
17111721
\"\"\"A class for processing data.\"\"\"
17121722
17131723
number = 1
17141724
1725+
def __init__(self, default_prefix: str = "PREFIX_"):
1726+
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
1727+
self.default_prefix = default_prefix
1728+
self.number += math.log(self.number)
1729+
17151730
def __repr__(self) -> str:
17161731
\"\"\"Return a string representation of the DataProcessor.\"\"\"
17171732
return f"DataProcessor(default_prefix={self.default_prefix!r})"
@@ -2727,3 +2742,154 @@ async def async_function():
27272742
# Verify correct order
27282743
expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"]
27292744
assert collector.assignment_order == expected_order
2745+
2746+
2747+
def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None:
    """Check that instantiating a class registers its ``__init__`` as a helper.

    The target function below only constructs ``DataDumper`` and never calls
    another method on it. The test asserts three things about the resulting
    optimization context:

    * ``DataDumper.__init__`` is listed among the helper functions,
    * the testgen context shows the class together with its constructor
      signature (so a test-writing LLM can see how to build the object),
    * the hashing context does not mention ``__init__`` at all.
    """
    module_source = '''
class DataDumper:
    """A class that dumps data."""

    def __init__(self, data):
        """Initialize with data."""
        self.data = data

    def dump(self):
        """Dump the data."""
        return self.data


def target_function():
    # Only instantiates DataDumper, doesn't call any other methods
    dumper = DataDumper({"key": "value"})
    return dumper
'''
    module_path = tmp_path / "test_code.py"
    module_path.write_text(module_source, encoding="utf-8")

    # The optimizer is built only so the project root can be read back from its args.
    cli_args = Namespace(
        project_root=module_path.parent.resolve(),
        disable_telemetry=True,
        tests_root="tests",
        test_framework="pytest",
        pytest_cmd="pytest",
        experiment_id=None,
        test_project_root=Path().resolve(),
    )
    optimizer = Optimizer(cli_args)
    target = FunctionToOptimize(
        function_name="target_function", file_path=module_path, parents=[], starting_line=None, ending_line=None
    )

    context = get_code_optimization_context(
        function_to_optimize=target, project_root_path=optimizer.args.project_root
    )

    # Helper tracking: DataDumper(...) in the target must surface the constructor.
    helper_names = {helper.qualified_name for helper in context.helper_functions}
    assert (
        "DataDumper.__init__" in helper_names
    ), "DataDumper.__init__ should be tracked as a helper when the class is instantiated"

    # Testgen context: both the class header and the constructor signature must be visible.
    testgen_markdown = context.testgen_context.markdown
    assert "class DataDumper:" in testgen_markdown, "DataDumper class should be in testgen context"
    assert "def __init__(self, data):" in testgen_markdown, "__init__ method should be included in testgen context"

    # Hashing context: constructors are deliberately left out for hash stability.
    hashing_code = context.hashing_code_context
    assert "__init__" not in hashing_code, "__init__ should NOT be in hashing context (excluded for hash stability)"
2815+
2816+
def test_class_instantiation_preserves_full_class_in_testgen(tmp_path: Path) -> None:
2817+
"""Test that instantiated classes are fully preserved in testgen context.
2818+
2819+
This is specifically for the unstructured LayoutDumper bug where helper classes
2820+
that were instantiated but had no other methods called were being excluded
2821+
from the testgen context.
2822+
"""
2823+
code = '''
2824+
class LayoutDumper:
2825+
"""Base class for layout dumpers."""
2826+
layout_source: str = "unknown"
2827+
2828+
def __init__(self, layout):
2829+
self._layout = layout
2830+
2831+
def dump(self) -> dict:
2832+
raise NotImplementedError()
2833+
2834+
2835+
class ObjectDetectionLayoutDumper(LayoutDumper):
2836+
"""Specific dumper for object detection layouts."""
2837+
2838+
def __init__(self, layout):
2839+
super().__init__(layout)
2840+
2841+
def dump(self) -> dict:
2842+
return {"type": "object_detection", "layout": self._layout}
2843+
2844+
2845+
def dump_layout(layout_type, layout):
2846+
"""Dump a layout based on its type."""
2847+
if layout_type == "object_detection":
2848+
dumper = ObjectDetectionLayoutDumper(layout)
2849+
else:
2850+
dumper = LayoutDumper(layout)
2851+
return dumper.dump()
2852+
'''
2853+
file_path = tmp_path / "test_code.py"
2854+
file_path.write_text(code, encoding="utf-8")
2855+
opt = Optimizer(
2856+
Namespace(
2857+
project_root=file_path.parent.resolve(),
2858+
disable_telemetry=True,
2859+
tests_root="tests",
2860+
test_framework="pytest",
2861+
pytest_cmd="pytest",
2862+
experiment_id=None,
2863+
test_project_root=Path().resolve(),
2864+
)
2865+
)
2866+
function_to_optimize = FunctionToOptimize(
2867+
function_name="dump_layout",
2868+
file_path=file_path,
2869+
parents=[],
2870+
starting_line=None,
2871+
ending_line=None,
2872+
)
2873+
2874+
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
2875+
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
2876+
2877+
# Both class __init__ methods should be tracked as helpers
2878+
assert "ObjectDetectionLayoutDumper.__init__" in qualified_names, (
2879+
"ObjectDetectionLayoutDumper.__init__ should be tracked"
2880+
)
2881+
assert "LayoutDumper.__init__" in qualified_names, (
2882+
"LayoutDumper.__init__ should be tracked"
2883+
)
2884+
2885+
# The testgen context should include both classes with their __init__ methods
2886+
testgen_context = code_ctx.testgen_context.markdown
2887+
assert "class LayoutDumper:" in testgen_context, "LayoutDumper should be in testgen context"
2888+
assert "class ObjectDetectionLayoutDumper" in testgen_context, (
2889+
"ObjectDetectionLayoutDumper should be in testgen context"
2890+
)
2891+
2892+
# Both __init__ methods should be in the testgen context (so LLM knows constructor signatures)
2893+
assert testgen_context.count("def __init__") >= 2, (
2894+
"Both __init__ methods should be in testgen context"
2895+
)

tests/test_instrument_line_profiler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def hi():
5555
5656
5757
class BubbleSortClass:
58+
@codeflash_line_profile
5859
def __init__(self):
5960
pass
6061
@@ -117,7 +118,9 @@ def sort_classmethod(x):
117118
return y.sorter(x)
118119
"""
119120
assert code_path.read_text("utf-8") == expected_code_main
120-
assert code_context.helper_functions.__len__() == 0
121+
# WrapperClass.__init__ is now detected as a helper since WrapperClass.BubbleSortClass() instantiates it
122+
assert len(code_context.helper_functions) == 1
123+
assert code_context.helper_functions[0].qualified_name == "WrapperClass.__init__"
121124
finally:
122125
func_optimizer.write_code_and_helpers(
123126
func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path
@@ -283,6 +286,7 @@ def sorter(arr):
283286
ans = helper(arr)
284287
return ans
285288
class helper:
289+
@codeflash_line_profile
286290
def __init__(self, arr):
287291
return arr.sort()
288292
"""

0 commit comments

Comments
 (0)