@staticmethod
def is_pytest_fixture(node: cst.FunctionDef) -> bool:
    """Return True when *node* carries a pytest fixture decorator.

    Fixtures are test scaffolding rather than production code, so the
    visitor uses this check to keep them out of the optimization candidates.

    Recognized spellings, each with or without a call
    (``@pytest.fixture`` as well as ``@pytest.fixture(scope="session")``):

    * ``@pytest.fixture``
    * ``@pytest_asyncio.fixture`` (fixture decorator of the pytest-asyncio plugin)
    * ``@fixture`` (the name bound by ``from pytest import fixture``)

    The check is purely syntactic: imports are not resolved, so any bare
    decorator named ``fixture`` matches, and an aliased import such as
    ``import pytest as pt`` does not.

    Args:
        node: The libcst function definition whose decorators are inspected.

    Returns:
        True if any decorator matches one of the spellings above, else False.
    """
    fixture_modules = ("pytest", "pytest_asyncio")
    for decorator in node.decorators:
        target = decorator.decorator
        # ``@pytest.fixture(...)`` wraps the callee in a Call node; unwrap it
        # so the called and uncalled forms are handled by the same checks.
        if isinstance(target, cst.Call):
            target = target.func
        if isinstance(target, cst.Name):
            if target.value == "fixture":
                return True
        elif (
            isinstance(target, cst.Attribute)
            and target.attr.value == "fixture"
            and isinstance(target.value, cst.Name)
            and target.value.value in fixture_modules
        ):
            return True
    return False
def is_test_file(file_path_normalized: str) -> bool:
    """Report whether *file_path_normalized* should be treated as a test file."""
    if not tests_root_overlaps_source:
        # Tests live in their own directory: everything below tests_root is a test file.
        return file_path_normalized.startswith(tests_root_str + os.sep)

    # tests_root overlaps the source tree, so rely on naming conventions instead.
    lowered_path = file_path_normalized.lower()

    # Python conventions: test_*.py modules and pytest's conftest.py.
    file_name = Path(lowered_path).name
    if file_name.startswith("test_") or file_name == "conftest.py":
        return True

    # Other file-name patterns (e.g. .test.ts, .spec.ts).
    for name_pattern in test_file_name_patterns:
        if name_pattern in lowered_path:
            return True

    # Directory patterns are matched only against the part of the path below the
    # project root, to avoid false positives from parent directories
    # (e.g. a project located at /home/user/tests/myproject). When the path is
    # not under the project root, directory patterns are not checked at all.
    if not project_root_str or not lowered_path.startswith(project_root_str.lower()):
        return False
    path_in_project = lowered_path[len(project_root_str) :]
    return any(dir_pattern in path_in_project for dir_pattern in test_dir_patterns)
+ + Python's standard test file naming uses the test_ prefix (e.g., test_utils.py), + which was previously not caught by the pattern matching in overlapping mode. + """ + with tempfile.TemporaryDirectory() as temp_dir_str: + temp_dir = Path(temp_dir_str) + + # Source file that should NOT be filtered + source_file = temp_dir / "utils.py" + with source_file.open("w") as f: + f.write("def process(): return 1") + + # Python test file with test_ prefix - SHOULD be filtered + test_prefix_file = temp_dir / "test_utils.py" + with test_prefix_file.open("w") as f: + f.write("def test_process(): return 1") + + # conftest.py - SHOULD be filtered + conftest_file = temp_dir / "conftest.py" + with conftest_file.open("w") as f: + f.write(""" +import pytest + +@pytest.fixture +def sample_data(): + return [1, 2, 3] +""") + + # File in a test_ prefixed directory - should NOT be filtered by file patterns + # (directory patterns don't cover test_ prefix dirs, which is fine) + test_subdir = temp_dir / "test_integration" + test_subdir.mkdir() + file_in_test_dir = test_subdir / "helpers.py" + with file_in_test_dir.open("w") as f: + f.write("def helper(): return 1") + + # test_ prefix file inside a subdirectory - SHOULD be filtered + test_in_subdir = test_subdir / "test_helpers.py" + with test_in_subdir.open("w") as f: + f.write("def test_helper(): return 1") + + all_functions = {} + for file_path in [source_file, test_prefix_file, conftest_file, file_in_test_dir, test_in_subdir]: + discovered = find_all_functions_in_file(file_path) + all_functions.update(discovered) + + with unittest.mock.patch( + "codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={} + ): + filtered, count = filter_functions( + all_functions, + tests_root=temp_dir, # Overlapping case + ignore_paths=[], + project_root=temp_dir, + module_root=temp_dir, + ) + + # source_file and file_in_test_dir should remain + # test_prefix_file, conftest_file, and test_in_subdir should be filtered + 
def test_pytest_fixture_not_discovered():
    """Fixture-decorated functions must be absent from libcst-based discovery.

    Writes a conftest.py containing one plain function, several fixture
    spellings (attribute and bare-name form, with and without call
    parentheses or arguments) and a class holding both a fixture and a
    regular method, then checks which names discovery reports.
    """
    from codeflash.languages.python.support import PythonSupport

    conftest_source = """
import pytest
from pytest import fixture

def regular_function():
    return 42

@pytest.fixture
def sample_data():
    return [1, 2, 3]

@pytest.fixture()
def sample_config():
    return {"key": "value"}

@fixture
def direct_import_fixture():
    return "data"

@fixture()
def direct_import_fixture_with_parens():
    return "data"

@pytest.fixture(scope="session")
def session_fixture():
    return "session"

class TestHelpers:
    @pytest.fixture
    def class_fixture(self):
        return "class_data"

    def helper_method(self):
        return "helper"
"""
    # Names that are ordinary code and must still be discovered.
    expected_present = ("regular_function", "helper_method")
    # Names that are fixtures and must be filtered out.
    expected_absent = (
        "sample_data",
        "sample_config",
        "direct_import_fixture",
        "direct_import_fixture_with_parens",
        "session_fixture",
        "class_fixture",
    )

    with tempfile.TemporaryDirectory() as temp_dir_str:
        conftest_path = Path(temp_dir_str) / "conftest.py"
        conftest_path.write_text(conftest_source)

        discovered = PythonSupport().discover_functions(conftest_path)
        discovered_names = [fn.function_name for fn in discovered]

        for name in expected_present:
            assert name in discovered_names
        for name in expected_absent:
            assert name not in discovered_names