Skip to content

Commit 5ef4ed4

Browse files
authored
Merge pull request #1014 from codeflash-ai/feat/extract-imported-class-definitions
feat: extract imported class definitions for testgen context
2 parents bfc183f + fdb1d61 commit 5ef4ed4

5 files changed

Lines changed: 509 additions & 5 deletions

File tree

codeflash/context/code_context_extractor.py

Lines changed: 170 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,20 +127,48 @@ def get_code_optimization_context(
127127
remove_docstrings=False,
128128
code_context_type=CodeContextType.TESTGEN,
129129
)
130+
131+
# Extract class definitions for imported types from project modules
132+
# This helps the LLM understand class constructors and structure
133+
imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
134+
if imported_class_context.code_strings:
135+
# Merge imported class definitions into testgen context
136+
testgen_context = CodeStringsMarkdown(
137+
code_strings=testgen_context.code_strings + imported_class_context.code_strings
138+
)
139+
130140
testgen_markdown_code = testgen_context.markdown
131141
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
132142
if testgen_code_token_length > testgen_token_limit:
143+
# First try removing docstrings
133144
testgen_context = extract_code_markdown_context_from_files(
134145
helpers_of_fto_dict,
135146
helpers_of_helpers_dict,
136147
project_root_path,
137148
remove_docstrings=True,
138149
code_context_type=CodeContextType.TESTGEN,
139150
)
151+
# Re-extract imported classes (they may still fit)
152+
imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
153+
if imported_class_context.code_strings:
154+
testgen_context = CodeStringsMarkdown(
155+
code_strings=testgen_context.code_strings + imported_class_context.code_strings
156+
)
140157
testgen_markdown_code = testgen_context.markdown
141158
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
142159
if testgen_code_token_length > testgen_token_limit:
143-
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
160+
# If still over limit, try without imported class definitions
161+
testgen_context = extract_code_markdown_context_from_files(
162+
helpers_of_fto_dict,
163+
helpers_of_helpers_dict,
164+
project_root_path,
165+
remove_docstrings=True,
166+
code_context_type=CodeContextType.TESTGEN,
167+
)
168+
testgen_markdown_code = testgen_context.markdown
169+
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
170+
if testgen_code_token_length > testgen_token_limit:
171+
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
144172
code_hash_context = hashing_code_context.markdown
145173
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
146174

@@ -489,6 +517,147 @@ def get_function_sources_from_jedi(
489517
return file_path_to_function_source, function_source_list
490518

491519

520+
def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown:
    """Extract class definitions for imported types from project modules.

    The code strings of ``code_context`` are concatenated and parsed; every
    ``from <module> import <name>`` found there is resolved to a file with Jedi.
    When that file lives under ``project_root_path`` (and not in site-packages)
    and defines a class called ``<name>``, the class source, including its
    decorators, is returned, prefixed by the imports its header needs.

    Args:
        code_context: The already extracted code context containing imports.
        project_root_path: Root path of the project.

    Returns:
        CodeStringsMarkdown containing one code string per class that was found.
        It is empty when the context cannot be parsed, has no ``from`` imports,
        or none of the imported names resolve to a project class.

    """
    import jedi

    # Collect all code from the context
    all_code = "\n".join(cs.code for cs in code_context.code_strings)

    # Parse to find import statements
    try:
        tree = ast.parse(all_code)
    except SyntaxError:
        return CodeStringsMarkdown(code_strings=[])

    # (class_name, module) pairs in first-seen order. The key uses the name as it is
    # spelled in the defining module (alias.name), not the local alias: for
    # `from m import Foo as Bar` the class to look up in `m` is `Foo`.
    # Keying on the pair also keeps same-named classes from different modules apart.
    imported_classes: dict[tuple[str, str], None] = {}
    # Classes the context already defines; these need no extra definition.
    existing_definitions: set[str] = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom) and node.module:
            # NOTE(review): relative imports (node.level > 0) are looked up as if they
            # were absolute, because the importing file is unknown once the code
            # strings have been joined - confirm whether they should be skipped.
            for alias in node.names:
                if alias.name != "*":
                    imported_classes[(alias.name, node.module)] = None
        elif isinstance(node, ast.ClassDef):
            existing_definitions.add(node.name)

    if not imported_classes:
        return CodeStringsMarkdown(code_strings=[])

    # Track which classes we've already extracted to avoid duplicates
    extracted_classes: set[tuple[Path, str]] = set()  # (file_path, class_name)
    class_code_strings: list[CodeString] = []

    project_prefix = str(project_root_path) + os.sep
    # Built lazily inside the try block so a failure stays best-effort, but only once.
    jedi_project = None

    for name, module_name in imported_classes:
        # Skip if already defined in context
        if name in existing_definitions:
            continue

        # Try to find the module file using Jedi
        try:
            if jedi_project is None:
                jedi_project = jedi.Project(path=project_root_path)

            # Create a script that imports the module to resolve it
            test_code = f"import {module_name}"
            script = jedi.Script(test_code, project=jedi_project)
            completions = script.goto(1, len(test_code))

            if not completions:
                continue

            module_path = completions[0].module_path
            if not module_path:
                continue

            # Check if this is a project module (not stdlib/third-party)
            if not str(module_path).startswith(project_prefix):
                continue
            if path_belongs_to_site_packages(module_path):
                continue

            # Skip if we've already extracted this class
            if (module_path, name) in extracted_classes:
                continue

            # Parse the module to find the class definition
            module_source = module_path.read_text(encoding="utf-8")
            module_tree = ast.parse(module_source)
            lines = module_source.split("\n")

            for node in ast.walk(module_tree):
                if isinstance(node, ast.ClassDef) and node.name == name:
                    # ClassDef.lineno points at the `class` keyword, so start at the
                    # first decorator (if any) to keep e.g. @dataclass in the snippet.
                    first_line = min([node.lineno, *(dec.lineno for dec in node.decorator_list)])
                    class_source = "\n".join(lines[first_line - 1 : node.end_lineno])

                    # Also extract any necessary imports for the class (base classes, type hints)
                    class_imports = _extract_imports_for_class(module_tree, node, module_source)

                    full_source = class_imports + "\n\n" + class_source if class_imports else class_source

                    class_code_strings.append(CodeString(code=full_source, file_path=module_path))
                    extracted_classes.add((module_path, name))
                    break

        except Exception:
            # Best effort: a class that cannot be resolved or read is simply left out.
            logger.debug(f"Error extracting class definition for {name} from {module_name}")
            continue

    return CodeStringsMarkdown(code_strings=class_code_strings)
623+
624+
625+
def _extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
    """Extract import statements needed for a class definition.

    Collects the module's top-level ``import`` / ``from ... import`` statements that
    bind a name used in the class header: base classes, class keywords such as
    ``metaclass=...``, and decorators.

    Args:
        module_tree: Parsed AST of the module that defines the class.
        class_node: The class definition whose header names should be resolved.
        module_source: The source text ``module_tree`` was parsed from.

    Returns:
        The matching import statements in module order, joined by newlines.
        A statement spanning several lines is returned in full. The result is an
        empty string when no import matches.
    """
    needed_names: set[str] = set()

    header_exprs = [
        *class_node.bases,
        *(keyword.value for keyword in class_node.keywords),
        *class_node.decorator_list,
    ]
    for expr in header_exprs:
        # Reduce `pkg.mod.Base`, `Generic[T]` or `decorator(arg)` to the leftmost
        # name, which is the one an import statement has to provide.
        while isinstance(expr, (ast.Call, ast.Subscript, ast.Attribute)):
            expr = expr.func if isinstance(expr, ast.Call) else expr.value
        if isinstance(expr, ast.Name):
            needed_names.add(expr.id)

    # Find imports that provide these names
    import_lines: list[str] = []
    source_lines = module_source.split("\n")

    for node in module_tree.body:
        if isinstance(node, ast.Import):
            # `import a.b` binds `a`; `import a.b as c` binds `c`.
            bound_names = {alias.asname or alias.name.split(".")[0] for alias in node.names}
        elif isinstance(node, ast.ImportFrom):
            bound_names = {alias.asname or alias.name for alias in node.names}
        else:
            continue

        if bound_names & needed_names:
            # Copy every physical line of the statement: a parenthesised import
            # spans several lines and its first line alone is not valid code.
            last_line = node.end_lineno or node.lineno
            import_lines.append("\n".join(source_lines[node.lineno - 1 : last_line]))

    return "\n".join(import_lines)
659+
660+
492661
def is_dunder_method(name: str) -> bool:
    """Return True when ``name`` is an ASCII name longer than four characters wrapped in double underscores."""
    # Too short to hold anything between the underscores, or not plain ASCII.
    if len(name) <= 4 or not name.isascii():
        return False
    return name[:2] == "__" and name[-2:] == "__"
494663

codeflash/models/models.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -655,16 +655,18 @@ def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Option
655655
def get_src_code(self, test_path: Path) -> Optional[str]:
656656
if not test_path.exists():
657657
return None
658-
test_src = test_path.read_text(encoding="utf-8")
659-
module_node = cst.parse_module(test_src)
658+
try:
659+
test_src = test_path.read_text(encoding="utf-8")
660+
module_node = cst.parse_module(test_src)
661+
except Exception:
662+
return None
660663

661664
if self.test_class_name:
662665
for stmt in module_node.body:
663666
if isinstance(stmt, cst.ClassDef) and stmt.name.value == self.test_class_name:
664667
func_node = self.find_func_in_class(stmt, self.test_function_name)
665668
if func_node:
666669
return module_node.code_for_node(func_node).strip()
667-
# class not found
668670
return None
669671

670672
# Otherwise, look for a top level function

codeflash/verification/codeflash_capture.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,17 @@ def get_test_info_from_stack(tests_root: str) -> tuple[str, str | None, str, str
8383
# Go to the previous frame
8484
frame = frame.f_back
8585

86+
# If stack walking didn't find test info, fall back to environment variables
87+
if not test_name:
88+
env_test_function = os.environ.get("CODEFLASH_TEST_FUNCTION")
89+
if env_test_function:
90+
test_name = env_test_function
91+
if not test_module_name:
92+
test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
93+
if not test_class_name:
94+
env_class = os.environ.get("CODEFLASH_TEST_CLASS")
95+
test_class_name = env_class if env_class else None
96+
8697
return test_module_name, test_class_name, test_name, line_id
8798

8899

0 commit comments

Comments
 (0)