Skip to content

Commit 42a1150

Browse files
authored
Merge pull request #1481 from codeflash-ai/include-external-class-inits-in-testgen
feat: include external class __init__ signatures with transitive type deps in testgen context
2 parents 1d9824c + 4f44286 commit 42a1150

8 files changed

Lines changed: 809 additions & 271 deletions

File tree

CLAUDE.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ uv run mypy codeflash/ # Type check
2424
uv run ruff check codeflash/ # Lint
2525
uv run ruff format codeflash/ # Format
2626

27-
# Linting (run before committing)
27+
# Linting (run before committing, checks staged files)
28+
uv run prek run
29+
30+
# Linting in CI (checks all files changed since main)
2831
uv run prek run --from-ref origin/main
2932

3033
# Mypy type checking (run on changed files before committing)

codeflash/context/code_context_extractor.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ def build_testgen_context(
7070
code_strings=testgen_context.code_strings + external_base_inits.code_strings
7171
)
7272

73+
external_class_inits = get_external_class_inits(testgen_context, project_root_path)
74+
if external_class_inits.code_strings:
75+
testgen_context = CodeStringsMarkdown(
76+
code_strings=testgen_context.code_strings + external_class_inits.code_strings
77+
)
78+
7379
return testgen_context
7480

7581

@@ -821,6 +827,210 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
821827
return CodeStringsMarkdown(code_strings=code_strings)
822828

823829

830+
# Number of levels of __init__ type-annotation dependencies that get_external_class_inits
# follows beyond the directly imported classes (which are seeded at depth 0).
MAX_TRANSITIVE_DEPTH = 2
831+
832+
833+
def extract_classes_from_type_hint(hint: object) -> list[type]:
    """Recursively extract concrete class objects from a type annotation.

    Unwraps Optional, Union (including PEP 604 ``X | Y`` unions), List, Dict, Callable,
    Annotated, etc. Filters out builtins and typing module types.

    Args:
        hint: A resolved annotation object, e.g. a value returned by ``typing.get_type_hints``.

    Returns:
        The classes found, in encounter order. A class that appears more than once in the
        annotation is returned more than once.
    """
    import typing

    classes: list[type] = []
    origin = getattr(hint, "__origin__", None)
    args = getattr(hint, "__args__", None)

    # ``X | Y`` (types.UnionType on Python 3.10-3.13) carries __args__ but no __origin__, so a
    # non-class object that has args is unwrapped as well instead of being silently dropped.
    if args and (origin is not None or not isinstance(hint, type)):
        for arg in args:
            classes.extend(extract_classes_from_type_hint(arg))
    elif isinstance(hint, type):
        module = getattr(hint, "__module__", "")
        if module not in ("builtins", "typing", "typing_extensions", "types"):
            classes.append(hint)

    # Handle typing.Annotated on older Pythons where __origin__ may not be set
    if hasattr(typing, "get_args") and origin is None and args is None:
        try:
            inner_args = typing.get_args(hint)
        except Exception:
            # Best-effort fallback: an exotic annotation object must not abort extraction.
            inner_args = ()
        for arg in inner_args:
            classes.extend(extract_classes_from_type_hint(arg))

    return classes
863+
864+
865+
def resolve_transitive_type_deps(cls: type) -> list[type]:
    """Find external classes referenced in cls.__init__ type annotations.

    Args:
        cls: The class whose ``__init__`` annotations are inspected.

    Returns:
        Classes from site-packages that have a custom __init__, each listed once in the order
        they are first encountered. ``cls`` itself is never included. An empty list is returned
        when the annotations cannot be resolved.
    """
    import inspect
    import typing

    try:
        init_method = getattr(cls, "__init__")
        hints = typing.get_type_hints(init_method)
    except Exception:
        # Unresolvable forward references and similar failures mean there is nothing to follow.
        return []

    deps: list[type] = []
    # Identity-based bookkeeping: skips cls itself and any class that was already examined, so a
    # class used by several parameters is reported (and checked on disk) only once.
    seen_ids: set[int] = {id(cls)}
    for param_name, hint in hints.items():
        if param_name == "return":
            continue
        for dep_cls in extract_classes_from_type_hint(hint):
            if id(dep_cls) in seen_ids:
                continue
            seen_ids.add(id(dep_cls))

            dep_init = getattr(dep_cls, "__init__", None)
            if dep_init is None or dep_init is object.__init__:
                continue
            try:
                class_file = Path(inspect.getfile(dep_cls))
            except (OSError, TypeError):
                # Built-in / C-implemented classes have no source file.
                continue
            if not path_belongs_to_site_packages(class_file):
                continue
            deps.append(dep_cls)

    return deps
898+
899+
900+
def extract_init_stub_for_class(cls: type, class_name: str) -> CodeString | None:
    """Extract a stub containing the class definition with only its __init__ method.

    Args:
        cls: The class object to build a stub for.
        class_name: The name to use in the emitted ``class`` header.

    Returns:
        A CodeString whose code is ``class <class_name>:`` followed by the source of ``__init__``
        and whose file path is made relative to ``site-packages`` when that directory appears in
        the path. Returns None when the class has no custom ``__init__``, does not live in
        site-packages, or its source cannot be retrieved.
    """
    import inspect
    import textwrap

    init_method = getattr(cls, "__init__", None)
    if init_method is None or init_method is object.__init__:
        return None

    try:
        class_file = Path(inspect.getfile(cls))
    except (OSError, TypeError):
        # Built-in / C-implemented classes have no source file.
        return None

    if not path_belongs_to_site_packages(class_file):
        return None

    try:
        init_source = textwrap.dedent(inspect.getsource(init_method))
    except (OSError, TypeError):
        return None

    # Report the path relative to site-packages so the stub is not tied to a local environment.
    parts = class_file.parts
    if "site-packages" in parts:
        idx = parts.index("site-packages")
        class_file = Path(*parts[idx + 1 :])

    # Indent the method by one standard level (four spaces) beneath the synthetic class header.
    class_source = f"class {class_name}:\n" + textwrap.indent(init_source, "    ")
    return CodeString(code=class_source, file_path=class_file)
930+
931+
932+
def get_external_class_inits(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown:
    """Extract __init__ methods from directly imported external library classes.

    Scans the code context for classes imported from external packages (site-packages) and extracts
    their __init__ methods, including transitive type dependencies found in __init__ annotations.
    This helps the LLM understand constructor signatures for instantiation in generated tests.

    Args:
        code_context: Code strings whose combined source is scanned for ``from X import Y`` statements.
        project_root_path: Project root, passed to ``_is_project_module`` to tell project modules
            apart from external ones.

    Returns:
        A CodeStringsMarkdown with one stub per discovered class, in a deterministic order
        (import order first, then breadth-first through transitive dependencies). It is empty
        when the context cannot be parsed or contains no external class imports.
    """
    import importlib
    import inspect
    from collections import deque

    all_code = "\n".join(cs.code for cs in code_context.code_strings)

    try:
        tree = ast.parse(all_code)
    except SyntaxError:
        return CodeStringsMarkdown(code_strings=[])

    # Collect all from X import Y statements: local name -> (module, name exported by the module).
    # The exported name is kept separately so that ``from m import A as B`` is looked up as ``m.A``.
    imported_names: dict[str, tuple[str, str]] = {}

    # Track classes already defined in the context to avoid duplicates
    existing_classes: set[str] = set()

    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom):
            # Relative imports (level > 0) are project-local by definition; treating their module
            # part as an absolute name could resolve to an unrelated top-level package.
            if not node.module or node.level:
                continue
            for alias in node.names:
                if alias.name != "*":
                    imported_names[alias.asname or alias.name] = (node.module, alias.name)
        elif isinstance(node, ast.ClassDef):
            existing_classes.add(node.name)

    if not imported_names:
        return CodeStringsMarkdown(code_strings=[])

    # Filter to external-only imports. A list (not a set) keeps import order, so the emitted
    # stubs come out in the same order on every run regardless of hash randomisation.
    is_project_cache: dict[str, bool] = {}
    external_imports: list[tuple[str, str, str]] = []
    for local_name, (module_name, exported_name) in imported_names.items():
        if local_name in existing_classes:
            continue
        is_project = is_project_cache.get(module_name)
        if is_project is None:
            is_project = _is_project_module(module_name, project_root_path)
            is_project_cache[module_name] = is_project
        if not is_project:
            external_imports.append((local_name, module_name, exported_name))

    if not external_imports:
        return CodeStringsMarkdown(code_strings=[])

    code_strings: list[CodeString] = []
    imported_module_cache: dict[str, object] = {}
    processed_classes: set[type] = set()
    emitted_names: set[str] = set()

    # BFS worklist: (class_object, class_name, depth)
    worklist: deque[tuple[type, str, int]] = deque()

    # Seed the worklist with directly imported classes
    for class_name, module_name, exported_name in external_imports:
        try:
            module = imported_module_cache.get(module_name)
            if module is None:
                module = importlib.import_module(module_name)
                imported_module_cache[module_name] = module
            cls = getattr(module, exported_name, None)
        except Exception:
            # Importing third-party code runs arbitrary module-level statements, so any exception
            # type is possible; this context is optional and must not abort test generation.
            logger.debug(f"Failed to import {module_name}.{exported_name}")
            continue

        if cls is None or not inspect.isclass(cls):
            continue

        # The stub is named after the local (possibly aliased) name used in the code context.
        worklist.append((cls, class_name, 0))

    while worklist:
        cls, class_name, depth = worklist.popleft()

        if cls in processed_classes:
            continue
        processed_classes.add(cls)

        stub = extract_init_stub_for_class(cls, class_name)
        if stub is None:
            continue

        if class_name not in emitted_names:
            code_strings.append(stub)
            emitted_names.add(class_name)

        # Resolve transitive type dependencies up to MAX_TRANSITIVE_DEPTH
        if depth < MAX_TRANSITIVE_DEPTH:
            for dep_cls in resolve_transitive_type_deps(cls):
                if dep_cls not in processed_classes:
                    worklist.append((dep_cls, dep_cls.__name__, depth + 1))

    return CodeStringsMarkdown(code_strings=code_strings)
1032+
1033+
8241034
def _is_project_module(module_name: str, project_root_path: Path) -> bool:
8251035
"""Check if a module is part of the project (not external/stdlib)."""
8261036
import importlib.util

codeflash/languages/javascript/instrument.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1354,12 +1354,10 @@ def fix_mock_path(match: re.Match[str]) -> str:
13541354
or source_relative_resolved.with_suffix(".jsx").exists()
13551355
):
13561356
# Calculate the correct relative path from test_dir to source_relative_resolved
1357-
new_rel_path = os.path.relpath(str(source_relative_resolved), str(test_dir))
1357+
new_rel_path = Path(os.path.relpath(source_relative_resolved, test_dir)).as_posix()
13581358
# Ensure it starts with ./ or ../
13591359
if not new_rel_path.startswith("../") and not new_rel_path.startswith("./"):
13601360
new_rel_path = f"./{new_rel_path}"
1361-
# Use forward slashes
1362-
new_rel_path = new_rel_path.replace("\\", "/")
13631361

13641362
logger.debug(f"Fixed jest.mock path: {rel_path} -> {new_rel_path}")
13651363
return f"{prefix}{new_rel_path}{suffix}"

codeflash/verification/parse_test_output.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import contextlib
43
import os
54
import re
65
import sqlite3
@@ -22,6 +21,9 @@
2221
)
2322
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
2423
from codeflash.languages import is_javascript
24+
25+
# Import Jest-specific parsing from the JavaScript language module
26+
from codeflash.languages.javascript.parse import parse_jest_test_xml as _parse_jest_test_xml
2527
from codeflash.models.models import (
2628
ConcurrencyMetrics,
2729
FunctionTestInvocation,
@@ -32,10 +34,6 @@
3234
)
3335
from codeflash.verification.coverage_utils import CoverageUtils, JestCoverageUtils
3436

35-
# Import Jest-specific parsing from the JavaScript language module
36-
from codeflash.languages.javascript.parse import jest_end_pattern, jest_start_pattern
37-
from codeflash.languages.javascript.parse import parse_jest_test_xml as _parse_jest_test_xml
38-
3937
if TYPE_CHECKING:
4038
import subprocess
4139

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ ignore = [
289289
"SIM108", # Ternary operator suggestion
290290
"F841", # Unused variable (often intentional)
291291
"ANN202", # Missing return type for private functions
292+
"B009", # getattr-with-constant - needed to avoid mypy [misc] on dunder access
292293
]
293294

294295
[tool.ruff.lint.flake8-type-checking]

tests/languages/javascript/test_vitest_junit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pytest
1313
from junitparser import JUnitXml
1414

15-
from codeflash.verification.parse_test_output import jest_end_pattern, jest_start_pattern
15+
from codeflash.languages.javascript.parse import jest_end_pattern, jest_start_pattern
1616

1717

1818
class TestVitestJunitXmlFormat:

0 commit comments

Comments
 (0)