@@ -3501,8 +3501,8 @@ def get_router_config(self) -> RouterConfig:
35013501 assert "model_list: list" in all_extracted_code , "Should include model_list field from Router"
35023502
35033503
3504- def test_enrich_testgen_context_extracts_userdict (tmp_path : Path ) -> None :
3505- """Extracts __init__ from collections.UserDict when a class inherits from it ."""
3504+ def test_enrich_testgen_context_skips_stdlib_userdict (tmp_path : Path ) -> None :
3505+ """Skips stdlib classes like collections.UserDict."""
35063506 code = """from collections import UserDict
35073507
35083508class MyCustomDict(UserDict):
@@ -3514,14 +3514,7 @@ class MyCustomDict(UserDict):
35143514 context = CodeStringsMarkdown (code_strings = [CodeString (code = code , file_path = code_path )])
35153515 result = enrich_testgen_context (context , tmp_path )
35163516
3517- assert len (result .code_strings ) == 1
3518- code_string = result .code_strings [0 ]
3519-
3520- assert "class UserDict" in code_string .code
3521- assert "def __init__" in code_string .code
3522- assert "self.data = {}" in code_string .code
3523- assert code_string .file_path is not None
3524- assert code_string .file_path .as_posix ().endswith ("collections/__init__.py" )
3517+ assert len (result .code_strings ) == 0 , "Should not extract stdlib classes"
35253518
35263519
35273520def test_enrich_testgen_context_skips_unresolvable_base_classes (tmp_path : Path ) -> None :
@@ -3555,24 +3548,24 @@ def test_enrich_testgen_context_skips_builtin_base_classes(tmp_path: Path) -> No
35553548
35563549
35573550def test_enrich_testgen_context_deduplicates (tmp_path : Path ) -> None :
3558- """Extracts the same external base class only once even when inherited multiple times."""
3559- code = """from collections import UserDict
3560-
3561- class MyDict1(UserDict):
3562- pass
3551+ """Extracts the same project class only once even when imported multiple times."""
3552+ package_dir = tmp_path / "mypkg"
3553+ package_dir .mkdir ()
3554+ (package_dir / "__init__.py" ).write_text ("" , encoding = "utf-8" )
3555+ (package_dir / "base.py" ).write_text (
3556+ "class Base:\n def __init__(self, x: int):\n self.x = x\n " ,
3557+ encoding = "utf-8" ,
3558+ )
35633559
3564- class MyDict2(UserDict):
3565- pass
3566- """
3567- code_path = tmp_path / "mydicts.py"
3560+ code = "from mypkg.base import Base\n \n class A(Base):\n pass\n \n class B(Base):\n pass\n "
3561+ code_path = package_dir / "children.py"
35683562 code_path .write_text (code , encoding = "utf-8" )
35693563
35703564 context = CodeStringsMarkdown (code_strings = [CodeString (code = code , file_path = code_path )])
35713565 result = enrich_testgen_context (context , tmp_path )
35723566
35733567 assert len (result .code_strings ) == 1
3574- assert "class UserDict" in result .code_strings [0 ].code
3575- assert "def __init__" in result .code_strings [0 ].code
3568+ assert "class Base" in result .code_strings [0 ].code
35763569
35773570
35783571def test_enrich_testgen_context_empty_when_no_inheritance (tmp_path : Path ) -> None :
@@ -3699,18 +3692,17 @@ def reify_channel_message(data: dict) -> MessageIn:
36993692
37003693
37013694def test_testgen_context_includes_external_base_inits (tmp_path : Path ) -> None :
3702- """Test that external base class __init__ methods are included in testgen context.
3703-
3704- This covers line 65 in code_context_extractor.py where external_base_inits.code_strings
3705- are appended to the testgen context when a class inherits from an external library.
3706- """
3707- code = """from collections import UserDict
3695+ """Test that base class definitions from project modules are included in testgen context."""
3696+ package_dir = tmp_path / "mypkg"
3697+ package_dir .mkdir ()
3698+ (package_dir / "__init__.py" ).write_text ("" , encoding = "utf-8" )
3699+ (package_dir / "base.py" ).write_text (
3700+ "class BaseDict:\n def __init__(self, data=None):\n self.data = data or {}\n " ,
3701+ encoding = "utf-8" ,
3702+ )
37083703
3709- class MyCustomDict(UserDict):
3710- def target_method(self):
3711- return self.data
3712- """
3713- file_path = tmp_path / "test_code.py"
3704+ code = "from mypkg.base import BaseDict\n \n class MyCustomDict(BaseDict):\n def target_method(self):\n return self.data\n "
3705+ file_path = package_dir / "test_code.py"
37143706 file_path .write_text (code , encoding = "utf-8" )
37153707
37163708 func_to_optimize = FunctionToOptimize (
@@ -3721,11 +3713,10 @@ def target_method(self):
37213713
37223714 code_ctx = get_code_optimization_context (function_to_optimize = func_to_optimize , project_root_path = tmp_path )
37233715
3724- # The testgen context should include the UserDict __init__ method
37253716 testgen_context = code_ctx .testgen_context .markdown
3726- assert "class UserDict " in testgen_context , "UserDict class should be in testgen context"
3727- assert "def __init__" in testgen_context , "UserDict __init__ should be in testgen context"
3728- assert "self.data = {} " in testgen_context , "UserDict __init__ body should be included"
3717+ assert "class BaseDict " in testgen_context , "BaseDict class should be in testgen context"
3718+ assert "def __init__" in testgen_context , "BaseDict __init__ should be in testgen context"
3719+ assert "self.data" in testgen_context , "BaseDict __init__ body should be included"
37293720
37303721
37313722def test_testgen_raises_when_exceeds_limit (tmp_path : Path ) -> None :
@@ -3756,26 +3747,24 @@ def target_function():
37563747
37573748
37583749def test_enrich_testgen_context_attribute_base (tmp_path : Path ) -> None :
3759- """Test handling of base class accessed as module.ClassName (ast.Attribute).
3760-
3761- This covers line 616 in code_context_extractor.py.
3762- """
3763- # Use the standard import style which the code actually handles
3764- code = """from collections import UserDict
3750+ """Test handling of base class in a project module."""
3751+ package_dir = tmp_path / "mypkg"
3752+ package_dir .mkdir ()
3753+ (package_dir / "__init__.py" ).write_text ("" , encoding = "utf-8" )
3754+ (package_dir / "base.py" ).write_text (
3755+ "class CustomDict:\n def __init__(self, data=None):\n self.data = data or {}\n " ,
3756+ encoding = "utf-8" ,
3757+ )
37653758
3766- class MyDict(UserDict):
3767- def custom_method(self):
3768- return self.data
3769- """
3770- code_path = tmp_path / "mydict.py"
3759+ code = "from mypkg.base import CustomDict\n \n class MyDict(CustomDict):\n def custom_method(self):\n return self.data\n "
3760+ code_path = package_dir / "mydict.py"
37713761 code_path .write_text (code , encoding = "utf-8" )
37723762
37733763 context = CodeStringsMarkdown (code_strings = [CodeString (code = code , file_path = code_path )])
37743764 result = enrich_testgen_context (context , tmp_path )
37753765
3776- # Should extract UserDict
37773766 assert len (result .code_strings ) == 1
3778- assert "class UserDict " in result .code_strings [0 ].code
3767+ assert "class CustomDict " in result .code_strings [0 ].code
37793768 assert "def __init__" in result .code_strings [0 ].code
37803769
37813770
@@ -4026,8 +4015,8 @@ def my_func() -> None:
40264015 assert result .code_strings == []
40274016
40284017
4029- def test_enrich_testgen_context_skips_object_init (tmp_path : Path ) -> None :
4030- """QName has a real class definition in stdlib source, so it gets extracted ."""
4018+ def test_enrich_testgen_context_skips_stdlib (tmp_path : Path ) -> None :
4019+ """Skips stdlib classes like QName ."""
40314020 code = """from xml.etree.ElementTree import QName
40324021
40334022def my_func(q: QName) -> None:
@@ -4039,9 +4028,7 @@ def my_func(q: QName) -> None:
40394028 context = CodeStringsMarkdown (code_strings = [CodeString (code = code , file_path = code_path )])
40404029 result = enrich_testgen_context (context , tmp_path )
40414030
4042- # QName has its own class definition in ElementTree source
4043- assert len (result .code_strings ) == 1
4044- assert "class QName" in result .code_strings [0 ].code
4031+ assert result .code_strings == [], "Should not extract stdlib classes"
40454032
40464033
40474034def test_enrich_testgen_context_empty_when_no_imports (tmp_path : Path ) -> None :
0 commit comments