Skip to content

Commit 5cee1b5

Browse files
committed
feat: improve test generation context for external library types
Extend extract_parameter_type_constructors to scan function bodies for isinstance(x, T) and type(x) is/== T patterns, and to collect base class names from the enclosing class when the target is a method. Add one-level transitive stub extraction so the LLM also sees constructor signatures for types referenced in the __init__/__post_init__ parameters of already-extracted stubs. In enrich_testgen_context, branch on where a class is defined: project classes get full definitions, while third-party (site-packages) classes get compact __init__ stubs to avoid exceeding token limits.
1 parent d284be5 commit 5cee1b5

2 files changed

Lines changed: 270 additions & 8 deletions

File tree

codeflash/languages/python/context/code_context_extractor.py

Lines changed: 104 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,41 @@ def extract_parameter_type_constructors(
837837
if func_node.args.kwarg:
838838
type_names |= collect_type_names_from_annotation(func_node.args.kwarg.annotation)
839839

840+
# Scan function body for isinstance(x, SomeType) and type(x) is/== SomeType patterns
841+
for body_node in ast.walk(func_node):
842+
if (
843+
isinstance(body_node, ast.Call)
844+
and isinstance(body_node.func, ast.Name)
845+
and body_node.func.id == "isinstance"
846+
):
847+
if len(body_node.args) >= 2:
848+
second_arg = body_node.args[1]
849+
if isinstance(second_arg, ast.Name):
850+
type_names.add(second_arg.id)
851+
elif isinstance(second_arg, ast.Tuple):
852+
for elt in second_arg.elts:
853+
if isinstance(elt, ast.Name):
854+
type_names.add(elt.id)
855+
elif isinstance(body_node, ast.Compare):
856+
# type(x) is/== SomeType
857+
if (
858+
isinstance(body_node.left, ast.Call)
859+
and isinstance(body_node.left.func, ast.Name)
860+
and body_node.left.func.id == "type"
861+
):
862+
for comparator in body_node.comparators:
863+
if isinstance(comparator, ast.Name):
864+
type_names.add(comparator.id)
865+
866+
# Collect base class names from enclosing class (if this is a method)
867+
if function_to_optimize.class_name is not None:
868+
for top_node in ast.walk(tree):
869+
if isinstance(top_node, ast.ClassDef) and top_node.name == function_to_optimize.class_name:
870+
for base in top_node.bases:
871+
if isinstance(base, ast.Name):
872+
type_names.add(base.id)
873+
break
874+
840875
type_names -= BUILTIN_AND_TYPING_NAMES
841876
type_names -= existing_class_names
842877
if not type_names:
@@ -881,6 +916,58 @@ def extract_parameter_type_constructors(
881916
logger.debug(f"Error extracting constructor stub for {type_name} from {module_name}")
882917
continue
883918

919+
# Transitive extraction (one level): for each extracted stub, find __init__ param types and extract their stubs
920+
# Build an extended import map that includes imports from source modules of already-extracted stubs
921+
transitive_import_map = dict(import_map)
922+
for _, cached_tree in module_cache.values():
923+
for cache_node in ast.walk(cached_tree):
924+
if isinstance(cache_node, ast.ImportFrom) and cache_node.module:
925+
for alias in cache_node.names:
926+
name = alias.asname if alias.asname else alias.name
927+
if name not in transitive_import_map:
928+
transitive_import_map[name] = cache_node.module
929+
930+
emitted_names = type_names | existing_class_names | BUILTIN_AND_TYPING_NAMES
931+
transitive_type_names: set[str] = set()
932+
for cs in code_strings:
933+
try:
934+
stub_tree = ast.parse(cs.code)
935+
except SyntaxError:
936+
continue
937+
for stub_node in ast.walk(stub_tree):
938+
if isinstance(stub_node, (ast.FunctionDef, ast.AsyncFunctionDef)) and stub_node.name in (
939+
"__init__",
940+
"__post_init__",
941+
):
942+
for arg in stub_node.args.args + stub_node.args.posonlyargs + stub_node.args.kwonlyargs:
943+
transitive_type_names |= collect_type_names_from_annotation(arg.annotation)
944+
transitive_type_names -= emitted_names
945+
for type_name in sorted(transitive_type_names):
946+
module_name = transitive_import_map.get(type_name)
947+
if not module_name:
948+
continue
949+
try:
950+
script_code = f"from {module_name} import {type_name}"
951+
script = jedi.Script(script_code, project=jedi.Project(path=project_root_path))
952+
definitions = script.goto(1, len(f"from {module_name} import ") + len(type_name), follow_imports=True)
953+
if not definitions:
954+
continue
955+
module_path = definitions[0].module_path
956+
if not module_path:
957+
continue
958+
if module_path in module_cache:
959+
mod_source, mod_tree = module_cache[module_path]
960+
else:
961+
mod_source = module_path.read_text(encoding="utf-8")
962+
mod_tree = ast.parse(mod_source)
963+
module_cache[module_path] = (mod_source, mod_tree)
964+
stub = extract_init_stub_from_class(type_name, mod_source, mod_tree)
965+
if stub:
966+
code_strings.append(CodeString(code=stub, file_path=module_path))
967+
except Exception:
968+
logger.debug(f"Error extracting transitive constructor stub for {type_name} from {module_name}")
969+
continue
970+
884971
return CodeStringsMarkdown(code_strings=code_strings)
885972

886973

@@ -1004,12 +1091,23 @@ def extract_class_and_bases(
10041091
continue
10051092
module_source, module_tree = mod_result
10061093

1007-
extract_class_and_bases(name, module_path, module_source, module_tree)
1008-
1009-
if (module_path, name) not in extracted_classes:
1010-
resolved_class = resolve_instance_class_name(name, module_tree)
1011-
if resolved_class and resolved_class not in existing_classes:
1012-
extract_class_and_bases(resolved_class, module_path, module_source, module_tree)
1094+
if is_project:
1095+
extract_class_and_bases(name, module_path, module_source, module_tree)
1096+
if (module_path, name) not in extracted_classes:
1097+
resolved_class = resolve_instance_class_name(name, module_tree)
1098+
if resolved_class and resolved_class not in existing_classes:
1099+
extract_class_and_bases(resolved_class, module_path, module_source, module_tree)
1100+
elif is_third_party:
1101+
target_name = name
1102+
if not any(isinstance(n, ast.ClassDef) and n.name == name for n in ast.walk(module_tree)):
1103+
resolved_class = resolve_instance_class_name(name, module_tree)
1104+
if resolved_class:
1105+
target_name = resolved_class
1106+
if target_name not in emitted_class_names:
1107+
stub = extract_init_stub_from_class(target_name, module_source, module_tree)
1108+
if stub:
1109+
code_strings.append(CodeString(code=stub, file_path=module_path))
1110+
emitted_class_names.add(target_name)
10131111

10141112
except Exception:
10151113
logger.debug(f"Error extracting class definition for {name} from {module_name}")

tests/test_code_context_extractor.py

Lines changed: 166 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99

1010
import pytest
1111

12-
from codeflash.languages.python.static_analysis.code_extractor import GlobalAssignmentCollector, add_global_assignments
13-
from codeflash.languages.python.static_analysis.code_replacer import replace_functions_and_add_imports
1412
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1513
from codeflash.languages.python.context.code_context_extractor import (
1614
collect_type_names_from_annotation,
@@ -20,6 +18,8 @@
2018
get_code_optimization_context,
2119
resolve_instance_class_name,
2220
)
21+
from codeflash.languages.python.static_analysis.code_extractor import GlobalAssignmentCollector, add_global_assignments
22+
from codeflash.languages.python.static_analysis.code_replacer import replace_functions_and_add_imports
2323
from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
2424
from codeflash.optimization.optimizer import Optimizer
2525

@@ -4701,3 +4701,167 @@ def get_log_level() -> str:
47014701
combined = "\n".join(cs.code for cs in result.code_strings)
47024702
assert "class AppConfig:" in combined
47034703
assert "@property" in combined
4704+
4705+
def test_extract_parameter_type_constructors_isinstance_single(tmp_path: Path) -> None:
    """isinstance(x, SomeType) in function body should be picked up.

    Builds a throwaway ``mypkg`` package in which ``processor.check`` mentions
    ``Widget`` only inside an ``isinstance`` call (its parameter is unannotated),
    then expects exactly one extracted code string, containing ``Widget``'s
    ``__init__``.
    """
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    # The embedded module must be valid Python: the method is nested inside the
    # class and its body is nested inside the method.
    (pkg / "models.py").write_text(
        "class Widget:\n"
        "    def __init__(self, size: int):\n"
        "        self.size = size\n",
        encoding="utf-8",
    )
    # `check` occupies lines 3-4 of processor.py, matching the range given below.
    (pkg / "processor.py").write_text(
        "from mypkg.models import Widget\n"
        "\n"
        "def check(x) -> bool:\n"
        "    return isinstance(x, Widget)\n",
        encoding="utf-8",
    )
    fto = FunctionToOptimize(
        function_name="check", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
    )
    result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
    assert len(result.code_strings) == 1
    assert "class Widget:" in result.code_strings[0].code
    assert "__init__" in result.code_strings[0].code
4725+
4726+
4727+
def test_extract_parameter_type_constructors_isinstance_tuple(tmp_path: Path) -> None:
    """isinstance(x, (TypeA, TypeB)) should pick up both types.

    ``processor.check`` names ``Alpha`` and ``Beta`` only inside the tuple passed
    to ``isinstance``; the test expects two extracted code strings, one per class.
    """
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    # The embedded module must be valid Python: each method is nested inside its
    # class and each method body is nested inside the method.
    (pkg / "models.py").write_text(
        "class Alpha:\n"
        "    def __init__(self, a: int):\n"
        "        self.a = a\n"
        "\n"
        "class Beta:\n"
        "    def __init__(self, b: str):\n"
        "        self.b = b\n",
        encoding="utf-8",
    )
    # `check` occupies lines 3-4 of processor.py, matching the range given below.
    (pkg / "processor.py").write_text(
        "from mypkg.models import Alpha, Beta\n"
        "\n"
        "def check(x) -> bool:\n"
        "    return isinstance(x, (Alpha, Beta))\n",
        encoding="utf-8",
    )
    fto = FunctionToOptimize(
        function_name="check", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
    )
    result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
    assert len(result.code_strings) == 2
    combined = "\n".join(cs.code for cs in result.code_strings)
    assert "class Alpha:" in combined
    assert "class Beta:" in combined
4749+
4750+
4751+
def test_extract_parameter_type_constructors_type_is_pattern(tmp_path: Path) -> None:
    """type(x) is SomeType pattern should be picked up.

    ``processor.check`` names ``Gadget`` only as the right-hand side of a
    ``type(x) is ...`` comparison; the test expects exactly one extracted code
    string, for ``Gadget``.
    """
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    # The embedded module must be valid Python: the method is nested inside the
    # class and its body is nested inside the method.
    (pkg / "models.py").write_text(
        "class Gadget:\n"
        "    def __init__(self, val: float):\n"
        "        self.val = val\n",
        encoding="utf-8",
    )
    # `check` occupies lines 3-4 of processor.py, matching the range given below.
    (pkg / "processor.py").write_text(
        "from mypkg.models import Gadget\n"
        "\n"
        "def check(x) -> bool:\n"
        "    return type(x) is Gadget\n",
        encoding="utf-8",
    )
    fto = FunctionToOptimize(
        function_name="check", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
    )
    result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
    assert len(result.code_strings) == 1
    assert "class Gadget:" in result.code_strings[0].code
4770+
4771+
4772+
def test_extract_parameter_type_constructors_base_classes(tmp_path: Path) -> None:
    """Base classes of enclosing class should be picked up for methods.

    The target is ``ChildProcessor.process``, which has no annotated parameters;
    ``BaseProcessor`` appears only as the base of the enclosing class. The test
    expects exactly one extracted code string, for ``BaseProcessor``.
    """
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    # Both embedded modules must be valid Python: methods are nested inside their
    # class and method bodies are nested inside the method.
    (pkg / "base.py").write_text(
        "class BaseProcessor:\n"
        "    def __init__(self, config: str):\n"
        "        self.config = config\n",
        encoding="utf-8",
    )
    # `process` occupies lines 4-5 of child.py, matching the range given below.
    (pkg / "child.py").write_text(
        "from mypkg.base import BaseProcessor\n"
        "\n"
        "class ChildProcessor(BaseProcessor):\n"
        "    def process(self) -> str:\n"
        "        return self.config\n",
        encoding="utf-8",
    )
    fto = FunctionToOptimize(
        function_name="process",
        file_path=(pkg / "child.py").resolve(),
        starting_line=4,
        ending_line=5,
        parents=[FunctionParent(name="ChildProcessor", type="ClassDef")],
    )
    result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
    assert len(result.code_strings) == 1
    assert "class BaseProcessor:" in result.code_strings[0].code
4796+
4797+
4798+
def test_extract_parameter_type_constructors_isinstance_builtins_excluded(tmp_path: Path) -> None:
    """isinstance checks against builtins (int, str, float) should not produce stubs."""
    package_dir = tmp_path / "mypkg"
    package_dir.mkdir()
    (package_dir / "__init__.py").write_text("", encoding="utf-8")

    # The function under test mentions only builtin types, and only inside isinstance().
    module_path = package_dir / "func.py"
    module_source = "def check(x) -> bool:\n return isinstance(x, (int, str, float))\n"
    module_path.write_text(module_source, encoding="utf-8")

    target = FunctionToOptimize(
        function_name="check", file_path=module_path.resolve(), starting_line=1, ending_line=2
    )
    extracted = extract_parameter_type_constructors(target, tmp_path.resolve(), set())
    assert len(extracted.code_strings) == 0
4812+
4813+
4814+
def test_extract_parameter_type_constructors_transitive(tmp_path: Path) -> None:
    """Transitive extraction: if Widget.__init__ takes a Config, Config's stub should also appear.

    ``processor.process`` is annotated with ``Widget`` only; ``Config`` is
    reachable solely through the annotation on ``Widget.__init__``. The test
    expects both class headers to appear in the combined extracted code.
    """
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    # The embedded modules must be valid Python: methods are nested inside their
    # class and method bodies are nested inside the method.
    (pkg / "config.py").write_text(
        "class Config:\n"
        "    def __init__(self, debug: bool = False):\n"
        "        self.debug = debug\n",
        encoding="utf-8",
    )
    (pkg / "models.py").write_text(
        "from mypkg.config import Config\n"
        "\n"
        "class Widget:\n"
        "    def __init__(self, cfg: Config):\n"
        "        self.cfg = cfg\n",
        encoding="utf-8",
    )
    # `process` occupies lines 3-4 of processor.py, matching the range given below.
    (pkg / "processor.py").write_text(
        "from mypkg.models import Widget\n"
        "\n"
        "def process(w: Widget) -> str:\n"
        "    return str(w)\n",
        encoding="utf-8",
    )
    fto = FunctionToOptimize(
        function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
    )
    result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
    combined = "\n".join(cs.code for cs in result.code_strings)
    assert "class Widget:" in combined
    assert "class Config:" in combined
4839+
4840+
4841+
4842+
4843+
def test_enrich_testgen_context_third_party_uses_stubs(tmp_path: Path) -> None:
    """Third-party classes should yield compact __init__ stubs rather than full class source."""
    # A real third-party package (pydantic) is used so that jedi can actually resolve it.
    consumer_source = (
        "from pydantic import BaseModel\n\n"
        "class MyModel(BaseModel):\n"
        " name: str\n\n"
        "def process(m: MyModel) -> str:\n"
        " return m.name\n"
    )
    consumer_file = tmp_path / "consumer.py"
    consumer_file.write_text(consumer_source, encoding="utf-8")

    testgen_context = CodeStringsMarkdown(
        code_strings=[CodeString(code=consumer_source, file_path=consumer_file)]
    )
    enriched = enrich_testgen_context(testgen_context, tmp_path)

    # BaseModel lives in site-packages, so it should get stub treatment (compact
    # __init__) instead of the full class definition with hundreds of methods.
    # Only the first code string that mentions BaseModel is inspected.
    # NOTE(review): if no code string mentions BaseModel, nothing is asserted and
    # the test passes; confirm whether that leniency is intended.
    base_model_code = next((cs.code for cs in enriched.code_strings if "BaseModel" in cs.code), None)
    if base_model_code is not None:
        assert "class BaseModel:" in base_model_code
        assert "__init__" in base_model_code
        # The full BaseModel has many methods; a stub should only carry __init__/properties.
        assert "model_dump" not in base_model_code

0 commit comments

Comments
 (0)