|
9 | 9 |
|
10 | 10 | import pytest |
11 | 11 |
|
12 | | -from codeflash.languages.python.static_analysis.code_extractor import GlobalAssignmentCollector, add_global_assignments |
13 | | -from codeflash.languages.python.static_analysis.code_replacer import replace_functions_and_add_imports |
14 | 12 | from codeflash.discovery.functions_to_optimize import FunctionToOptimize |
15 | 13 | from codeflash.languages.python.context.code_context_extractor import ( |
16 | 14 | collect_type_names_from_annotation, |
|
20 | 18 | get_code_optimization_context, |
21 | 19 | resolve_instance_class_name, |
22 | 20 | ) |
| 21 | +from codeflash.languages.python.static_analysis.code_extractor import GlobalAssignmentCollector, add_global_assignments |
| 22 | +from codeflash.languages.python.static_analysis.code_replacer import replace_functions_and_add_imports |
23 | 23 | from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent |
24 | 24 | from codeflash.optimization.optimizer import Optimizer |
25 | 25 |
|
@@ -4701,3 +4701,167 @@ def get_log_level() -> str: |
4701 | 4701 | combined = "\n".join(cs.code for cs in result.code_strings) |
4702 | 4702 | assert "class AppConfig:" in combined |
4703 | 4703 | assert "@property" in combined |
| 4704 | + |
| 4705 | +def test_extract_parameter_type_constructors_isinstance_single(tmp_path: Path) -> None: |
| 4706 | + """isinstance(x, SomeType) in function body should be picked up.""" |
| 4707 | + pkg = tmp_path / "mypkg" |
| 4708 | + pkg.mkdir() |
| 4709 | + (pkg / "__init__.py").write_text("", encoding="utf-8") |
| 4710 | + (pkg / "models.py").write_text( |
| 4711 | + "class Widget:\n def __init__(self, size: int):\n self.size = size\n", |
| 4712 | + encoding="utf-8", |
| 4713 | + ) |
| 4714 | + (pkg / "processor.py").write_text( |
| 4715 | + "from mypkg.models import Widget\n\ndef check(x) -> bool:\n return isinstance(x, Widget)\n", |
| 4716 | + encoding="utf-8", |
| 4717 | + ) |
| 4718 | + fto = FunctionToOptimize( |
| 4719 | + function_name="check", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4 |
| 4720 | + ) |
| 4721 | + result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set()) |
| 4722 | + assert len(result.code_strings) == 1 |
| 4723 | + assert "class Widget:" in result.code_strings[0].code |
| 4724 | + assert "__init__" in result.code_strings[0].code |
| 4725 | + |
| 4726 | + |
| 4727 | +def test_extract_parameter_type_constructors_isinstance_tuple(tmp_path: Path) -> None: |
| 4728 | + """isinstance(x, (TypeA, TypeB)) should pick up both types.""" |
| 4729 | + pkg = tmp_path / "mypkg" |
| 4730 | + pkg.mkdir() |
| 4731 | + (pkg / "__init__.py").write_text("", encoding="utf-8") |
| 4732 | + (pkg / "models.py").write_text( |
| 4733 | + "class Alpha:\n def __init__(self, a: int):\n self.a = a\n\n" |
| 4734 | + "class Beta:\n def __init__(self, b: str):\n self.b = b\n", |
| 4735 | + encoding="utf-8", |
| 4736 | + ) |
| 4737 | + (pkg / "processor.py").write_text( |
| 4738 | + "from mypkg.models import Alpha, Beta\n\ndef check(x) -> bool:\n return isinstance(x, (Alpha, Beta))\n", |
| 4739 | + encoding="utf-8", |
| 4740 | + ) |
| 4741 | + fto = FunctionToOptimize( |
| 4742 | + function_name="check", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4 |
| 4743 | + ) |
| 4744 | + result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set()) |
| 4745 | + assert len(result.code_strings) == 2 |
| 4746 | + combined = "\n".join(cs.code for cs in result.code_strings) |
| 4747 | + assert "class Alpha:" in combined |
| 4748 | + assert "class Beta:" in combined |
| 4749 | + |
| 4750 | + |
| 4751 | +def test_extract_parameter_type_constructors_type_is_pattern(tmp_path: Path) -> None: |
| 4752 | + """type(x) is SomeType pattern should be picked up.""" |
| 4753 | + pkg = tmp_path / "mypkg" |
| 4754 | + pkg.mkdir() |
| 4755 | + (pkg / "__init__.py").write_text("", encoding="utf-8") |
| 4756 | + (pkg / "models.py").write_text( |
| 4757 | + "class Gadget:\n def __init__(self, val: float):\n self.val = val\n", |
| 4758 | + encoding="utf-8", |
| 4759 | + ) |
| 4760 | + (pkg / "processor.py").write_text( |
| 4761 | + "from mypkg.models import Gadget\n\ndef check(x) -> bool:\n return type(x) is Gadget\n", |
| 4762 | + encoding="utf-8", |
| 4763 | + ) |
| 4764 | + fto = FunctionToOptimize( |
| 4765 | + function_name="check", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4 |
| 4766 | + ) |
| 4767 | + result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set()) |
| 4768 | + assert len(result.code_strings) == 1 |
| 4769 | + assert "class Gadget:" in result.code_strings[0].code |
| 4770 | + |
| 4771 | + |
| 4772 | +def test_extract_parameter_type_constructors_base_classes(tmp_path: Path) -> None: |
| 4773 | + """Base classes of enclosing class should be picked up for methods.""" |
| 4774 | + pkg = tmp_path / "mypkg" |
| 4775 | + pkg.mkdir() |
| 4776 | + (pkg / "__init__.py").write_text("", encoding="utf-8") |
| 4777 | + (pkg / "base.py").write_text( |
| 4778 | + "class BaseProcessor:\n def __init__(self, config: str):\n self.config = config\n", |
| 4779 | + encoding="utf-8", |
| 4780 | + ) |
| 4781 | + (pkg / "child.py").write_text( |
| 4782 | + "from mypkg.base import BaseProcessor\n\nclass ChildProcessor(BaseProcessor):\n" |
| 4783 | + " def process(self) -> str:\n return self.config\n", |
| 4784 | + encoding="utf-8", |
| 4785 | + ) |
| 4786 | + fto = FunctionToOptimize( |
| 4787 | + function_name="process", |
| 4788 | + file_path=(pkg / "child.py").resolve(), |
| 4789 | + starting_line=4, |
| 4790 | + ending_line=5, |
| 4791 | + parents=[FunctionParent(name="ChildProcessor", type="ClassDef")], |
| 4792 | + ) |
| 4793 | + result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set()) |
| 4794 | + assert len(result.code_strings) == 1 |
| 4795 | + assert "class BaseProcessor:" in result.code_strings[0].code |
| 4796 | + |
| 4797 | + |
| 4798 | +def test_extract_parameter_type_constructors_isinstance_builtins_excluded(tmp_path: Path) -> None: |
| 4799 | + """Isinstance with builtins (int, str, etc.) should not produce stubs.""" |
| 4800 | + pkg = tmp_path / "mypkg" |
| 4801 | + pkg.mkdir() |
| 4802 | + (pkg / "__init__.py").write_text("", encoding="utf-8") |
| 4803 | + (pkg / "func.py").write_text( |
| 4804 | + "def check(x) -> bool:\n return isinstance(x, (int, str, float))\n", |
| 4805 | + encoding="utf-8", |
| 4806 | + ) |
| 4807 | + fto = FunctionToOptimize( |
| 4808 | + function_name="check", file_path=(pkg / "func.py").resolve(), starting_line=1, ending_line=2 |
| 4809 | + ) |
| 4810 | + result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set()) |
| 4811 | + assert len(result.code_strings) == 0 |
| 4812 | + |
| 4813 | + |
| 4814 | +def test_extract_parameter_type_constructors_transitive(tmp_path: Path) -> None: |
| 4815 | + """Transitive extraction: if Widget.__init__ takes a Config, Config's stub should also appear.""" |
| 4816 | + pkg = tmp_path / "mypkg" |
| 4817 | + pkg.mkdir() |
| 4818 | + (pkg / "__init__.py").write_text("", encoding="utf-8") |
| 4819 | + (pkg / "config.py").write_text( |
| 4820 | + "class Config:\n def __init__(self, debug: bool = False):\n self.debug = debug\n", |
| 4821 | + encoding="utf-8", |
| 4822 | + ) |
| 4823 | + (pkg / "models.py").write_text( |
| 4824 | + "from mypkg.config import Config\n\n" |
| 4825 | + "class Widget:\n def __init__(self, cfg: Config):\n self.cfg = cfg\n", |
| 4826 | + encoding="utf-8", |
| 4827 | + ) |
| 4828 | + (pkg / "processor.py").write_text( |
| 4829 | + "from mypkg.models import Widget\n\ndef process(w: Widget) -> str:\n return str(w)\n", |
| 4830 | + encoding="utf-8", |
| 4831 | + ) |
| 4832 | + fto = FunctionToOptimize( |
| 4833 | + function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4 |
| 4834 | + ) |
| 4835 | + result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set()) |
| 4836 | + combined = "\n".join(cs.code for cs in result.code_strings) |
| 4837 | + assert "class Widget:" in combined |
| 4838 | + assert "class Config:" in combined |
| 4839 | + |
| 4840 | + |
| 4841 | + |
| 4842 | + |
| 4843 | +def test_enrich_testgen_context_third_party_uses_stubs(tmp_path: Path) -> None: |
| 4844 | + """Third-party classes should produce compact __init__ stubs, not full class source.""" |
| 4845 | + # Use a real third-party package (pydantic) so jedi can actually resolve it |
| 4846 | + context_code = ( |
| 4847 | + "from pydantic import BaseModel\n\n" |
| 4848 | + "class MyModel(BaseModel):\n" |
| 4849 | + " name: str\n\n" |
| 4850 | + "def process(m: MyModel) -> str:\n" |
| 4851 | + " return m.name\n" |
| 4852 | + ) |
| 4853 | + consumer_path = tmp_path / "consumer.py" |
| 4854 | + consumer_path.write_text(context_code, encoding="utf-8") |
| 4855 | + |
| 4856 | + context = CodeStringsMarkdown(code_strings=[CodeString(code=context_code, file_path=consumer_path)]) |
| 4857 | + result = enrich_testgen_context(context, tmp_path) |
| 4858 | + |
| 4859 | + # BaseModel lives in site-packages so should get stub treatment (compact __init__), |
| 4860 | + # not the full class definition with hundreds of methods |
| 4861 | + for cs in result.code_strings: |
| 4862 | + if "BaseModel" in cs.code: |
| 4863 | + assert "class BaseModel:" in cs.code |
| 4864 | + assert "__init__" in cs.code |
| 4865 | + # Full BaseModel has many methods; stubs should only have __init__/properties |
| 4866 | + assert "model_dump" not in cs.code |
| 4867 | + break |
0 commit comments