Skip to content

Commit b269212

Browse files
committed
context extraction improvements
1 parent 7f5e163 commit b269212

2 files changed

Lines changed: 49 additions & 55 deletions

File tree

codeflash/languages/python/context/code_context_extractor.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -950,7 +950,7 @@ def extract_class_and_bases(
950950
emitted_class_names.add(class_name)
951951

952952
for name, module_name in imported_names.items():
953-
if name in existing_classes:
953+
if name in existing_classes or module_name == "__future__":
954954
continue
955955
try:
956956
test_code = f"import {module_name}"
@@ -964,6 +964,13 @@ def extract_class_and_bases(
964964
if not module_path:
965965
continue
966966

967+
resolved_module = module_path.resolve()
968+
module_str = str(resolved_module)
969+
is_project = module_str.startswith(str(project_root_path.resolve()))
970+
is_third_party = "site-packages" in module_str
971+
if not is_project and not is_third_party:
972+
continue
973+
967974
mod_result = get_module_source_and_tree(module_path)
968975
if mod_result is None:
969976
continue

tests/test_code_context_extractor.py

Lines changed: 41 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3501,8 +3501,8 @@ def get_router_config(self) -> RouterConfig:
35013501
assert "model_list: list" in all_extracted_code, "Should include model_list field from Router"
35023502

35033503

3504-
def test_enrich_testgen_context_extracts_userdict(tmp_path: Path) -> None:
3505-
"""Extracts __init__ from collections.UserDict when a class inherits from it."""
3504+
def test_enrich_testgen_context_skips_stdlib_userdict(tmp_path: Path) -> None:
3505+
"""Skips stdlib classes like collections.UserDict."""
35063506
code = """from collections import UserDict
35073507
35083508
class MyCustomDict(UserDict):
@@ -3514,14 +3514,7 @@ class MyCustomDict(UserDict):
35143514
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
35153515
result = enrich_testgen_context(context, tmp_path)
35163516

3517-
assert len(result.code_strings) == 1
3518-
code_string = result.code_strings[0]
3519-
3520-
assert "class UserDict" in code_string.code
3521-
assert "def __init__" in code_string.code
3522-
assert "self.data = {}" in code_string.code
3523-
assert code_string.file_path is not None
3524-
assert code_string.file_path.as_posix().endswith("collections/__init__.py")
3517+
assert len(result.code_strings) == 0, "Should not extract stdlib classes"
35253518

35263519

35273520
def test_enrich_testgen_context_skips_unresolvable_base_classes(tmp_path: Path) -> None:
@@ -3555,24 +3548,24 @@ def test_enrich_testgen_context_skips_builtin_base_classes(tmp_path: Path) -> No
35553548

35563549

35573550
def test_enrich_testgen_context_deduplicates(tmp_path: Path) -> None:
3558-
"""Extracts the same external base class only once even when inherited multiple times."""
3559-
code = """from collections import UserDict
3560-
3561-
class MyDict1(UserDict):
3562-
pass
3551+
"""Extracts the same project class only once even when imported multiple times."""
3552+
package_dir = tmp_path / "mypkg"
3553+
package_dir.mkdir()
3554+
(package_dir / "__init__.py").write_text("", encoding="utf-8")
3555+
(package_dir / "base.py").write_text(
3556+
"class Base:\n def __init__(self, x: int):\n self.x = x\n",
3557+
encoding="utf-8",
3558+
)
35633559

3564-
class MyDict2(UserDict):
3565-
pass
3566-
"""
3567-
code_path = tmp_path / "mydicts.py"
3560+
code = "from mypkg.base import Base\n\nclass A(Base):\n pass\n\nclass B(Base):\n pass\n"
3561+
code_path = package_dir / "children.py"
35683562
code_path.write_text(code, encoding="utf-8")
35693563

35703564
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
35713565
result = enrich_testgen_context(context, tmp_path)
35723566

35733567
assert len(result.code_strings) == 1
3574-
assert "class UserDict" in result.code_strings[0].code
3575-
assert "def __init__" in result.code_strings[0].code
3568+
assert "class Base" in result.code_strings[0].code
35763569

35773570

35783571
def test_enrich_testgen_context_empty_when_no_inheritance(tmp_path: Path) -> None:
@@ -3699,18 +3692,17 @@ def reify_channel_message(data: dict) -> MessageIn:
36993692

37003693

37013694
def test_testgen_context_includes_external_base_inits(tmp_path: Path) -> None:
3702-
"""Test that external base class __init__ methods are included in testgen context.
3703-
3704-
This covers line 65 in code_context_extractor.py where external_base_inits.code_strings
3705-
are appended to the testgen context when a class inherits from an external library.
3706-
"""
3707-
code = """from collections import UserDict
3695+
"""Test that base class definitions from project modules are included in testgen context."""
3696+
package_dir = tmp_path / "mypkg"
3697+
package_dir.mkdir()
3698+
(package_dir / "__init__.py").write_text("", encoding="utf-8")
3699+
(package_dir / "base.py").write_text(
3700+
"class BaseDict:\n def __init__(self, data=None):\n self.data = data or {}\n",
3701+
encoding="utf-8",
3702+
)
37083703

3709-
class MyCustomDict(UserDict):
3710-
def target_method(self):
3711-
return self.data
3712-
"""
3713-
file_path = tmp_path / "test_code.py"
3704+
code = "from mypkg.base import BaseDict\n\nclass MyCustomDict(BaseDict):\n def target_method(self):\n return self.data\n"
3705+
file_path = package_dir / "test_code.py"
37143706
file_path.write_text(code, encoding="utf-8")
37153707

37163708
func_to_optimize = FunctionToOptimize(
@@ -3721,11 +3713,10 @@ def target_method(self):
37213713

37223714
code_ctx = get_code_optimization_context(function_to_optimize=func_to_optimize, project_root_path=tmp_path)
37233715

3724-
# The testgen context should include the UserDict __init__ method
37253716
testgen_context = code_ctx.testgen_context.markdown
3726-
assert "class UserDict" in testgen_context, "UserDict class should be in testgen context"
3727-
assert "def __init__" in testgen_context, "UserDict __init__ should be in testgen context"
3728-
assert "self.data = {}" in testgen_context, "UserDict __init__ body should be included"
3717+
assert "class BaseDict" in testgen_context, "BaseDict class should be in testgen context"
3718+
assert "def __init__" in testgen_context, "BaseDict __init__ should be in testgen context"
3719+
assert "self.data" in testgen_context, "BaseDict __init__ body should be included"
37293720

37303721

37313722
def test_testgen_raises_when_exceeds_limit(tmp_path: Path) -> None:
@@ -3756,26 +3747,24 @@ def target_function():
37563747

37573748

37583749
def test_enrich_testgen_context_attribute_base(tmp_path: Path) -> None:
3759-
"""Test handling of base class accessed as module.ClassName (ast.Attribute).
3760-
3761-
This covers line 616 in code_context_extractor.py.
3762-
"""
3763-
# Use the standard import style which the code actually handles
3764-
code = """from collections import UserDict
3750+
"""Test handling of base class in a project module."""
3751+
package_dir = tmp_path / "mypkg"
3752+
package_dir.mkdir()
3753+
(package_dir / "__init__.py").write_text("", encoding="utf-8")
3754+
(package_dir / "base.py").write_text(
3755+
"class CustomDict:\n def __init__(self, data=None):\n self.data = data or {}\n",
3756+
encoding="utf-8",
3757+
)
37653758

3766-
class MyDict(UserDict):
3767-
def custom_method(self):
3768-
return self.data
3769-
"""
3770-
code_path = tmp_path / "mydict.py"
3759+
code = "from mypkg.base import CustomDict\n\nclass MyDict(CustomDict):\n def custom_method(self):\n return self.data\n"
3760+
code_path = package_dir / "mydict.py"
37713761
code_path.write_text(code, encoding="utf-8")
37723762

37733763
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
37743764
result = enrich_testgen_context(context, tmp_path)
37753765

3776-
# Should extract UserDict
37773766
assert len(result.code_strings) == 1
3778-
assert "class UserDict" in result.code_strings[0].code
3767+
assert "class CustomDict" in result.code_strings[0].code
37793768
assert "def __init__" in result.code_strings[0].code
37803769

37813770

@@ -4026,8 +4015,8 @@ def my_func() -> None:
40264015
assert result.code_strings == []
40274016

40284017

4029-
def test_enrich_testgen_context_skips_object_init(tmp_path: Path) -> None:
4030-
"""QName has a real class definition in stdlib source, so it gets extracted."""
4018+
def test_enrich_testgen_context_skips_stdlib(tmp_path: Path) -> None:
4019+
"""Skips stdlib classes like QName."""
40314020
code = """from xml.etree.ElementTree import QName
40324021
40334022
def my_func(q: QName) -> None:
@@ -4039,9 +4028,7 @@ def my_func(q: QName) -> None:
40394028
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
40404029
result = enrich_testgen_context(context, tmp_path)
40414030

4042-
# QName has its own class definition in ElementTree source
4043-
assert len(result.code_strings) == 1
4044-
assert "class QName" in result.code_strings[0].code
4031+
assert result.code_strings == [], "Should not extract stdlib classes"
40454032

40464033

40474034
def test_enrich_testgen_context_empty_when_no_imports(tmp_path: Path) -> None:

0 commit comments

Comments
 (0)