From 236f14ce5575dae9b44156f3018039506988c7df Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Fri, 6 Jun 2025 20:21:54 -0700 Subject: [PATCH 1/6] claude WIP --- codeflash/discovery/discover_unit_tests.py | 189 ++++++++- codeflash/optimization/optimizer.py | 8 +- tests/test_unit_test_discovery.py | 444 ++++++++++++++++++++- 3 files changed, 632 insertions(+), 9 deletions(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 1fd86acce..06adcd4bb 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -1,6 +1,7 @@ # ruff: noqa: SLF001 from __future__ import annotations +import ast import hashlib import os import pickle @@ -12,6 +13,9 @@ from pathlib import Path from typing import TYPE_CHECKING, Callable, Optional +if TYPE_CHECKING: + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + import pytest from pydantic.dataclasses import dataclass from rich.panel import Panel @@ -137,8 +141,163 @@ def close(self) -> None: self.connection.close() +class ImportAnalyzer(ast.NodeVisitor): + """AST-based analyzer to find all imports in a test file.""" + + def __init__(self, function_names_to_find: set[str]) -> None: + self.function_names_to_find = function_names_to_find + self.imported_names: set[str] = set() + self.imported_modules: set[str] = set() + self.found_target_functions: set[str] = set() + + def visit_Import(self, node: ast.Import) -> None: + """Handle 'import module' statements.""" + for alias in node.names: + module_name = alias.asname if alias.asname else alias.name + self.imported_modules.add(module_name) + self.imported_names.add(module_name) + self.generic_visit(node) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + """Handle 'from module import name' statements.""" + if node.module: + self.imported_modules.add(node.module) + + for alias in node.names: + if alias.name == "*": + # Star imports - we can't know what's imported, so be conservative + self.imported_names.add("*") + else: + imported_name = alias.asname if alias.asname else alias.name + self.imported_names.add(imported_name) + + # Check if this import matches any target function + if alias.name in self.function_names_to_find: + self.found_target_functions.add(alias.name) + self.generic_visit(node) + + def visit_Call(self, node: ast.Call) -> None: + """Handle dynamic imports like importlib.import_module() or __import__().""" + if isinstance(node.func, ast.Name) and node.func.id == "__import__" and node.args: + # __import__("module_name") + if isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, str): + self.imported_modules.add(node.args[0].value) + elif (isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "importlib" + and node.func.attr == "import_module" + and node.args): + # importlib.import_module("module_name") + if isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, str): + self.imported_modules.add(node.args[0].value) + self.generic_visit(node) + + def visit_Name(self, node: ast.Name) -> None: + """Check if any name usage matches our target functions.""" + if node.id in self.function_names_to_find: + self.found_target_functions.add(node.id) + self.generic_visit(node) + + def visit_Attribute(self, node: ast.Attribute) -> None: + """Handle module.function_name patterns.""" + if node.attr in self.function_names_to_find: + self.found_target_functions.add(node.attr) + self.generic_visit(node) + + +def analyze_imports_in_test_file(test_file_path: Path, target_functions: set[str]) -> tuple[bool, set[str]]: + """Analyze imports in a test file to determine if it might test any target functions. + + Args: + test_file_path: Path to the test file + target_functions: Set of function names we're looking for + + Returns: + Tuple of (should_process_with_jedi, found_function_names) + + """ + try: + with test_file_path.open("r", encoding="utf-8") as f: + content = f.read() + + tree = ast.parse(content, filename=str(test_file_path)) + analyzer = ImportAnalyzer(target_functions) + analyzer.visit(tree) + + # If we found direct function matches, definitely process + if analyzer.found_target_functions: + return True, analyzer.found_target_functions + + # If there are star imports, we need to be conservative + if "*" in analyzer.imported_names: + return True, set() + + # Check for direct name matches first (higher priority) + name_matches = analyzer.imported_names & target_functions + if name_matches: + return True, name_matches + + # If no direct matches, check if any imported modules could contain our target functions + # This is a heuristic - we look for common patterns + potential_matches = set() + for module in analyzer.imported_modules: + # Check if module name suggests it could contain target functions + for func_name in target_functions: + # Only match if the module name is a prefix of the function qualified name + func_parts = func_name.split(".") + if len(func_parts) > 1 and module == func_parts[0]: + # Module matches the first part of qualified name (e.g., mycode in mycode.target_function) + # But only if we don't have specific import information suggesting otherwise + potential_matches.add(func_name) + elif any(part in module for part in func_name.split("_")) and len(func_name.split("_")) > 1: + # Function name parts match module name (for underscore-separated names) + potential_matches.add(func_name) + + # Only use heuristic matches if we haven't found specific function imports that contradict them + return bool(potential_matches), potential_matches + + except (SyntaxError, UnicodeDecodeError, OSError) as e: + logger.debug(f"Failed to analyze imports in {test_file_path}: {e}") + # If we can't parse the file, be conservative and process it + return True, set() + + +def filter_test_files_by_imports( + file_to_test_map: dict[Path, list[TestsInFile]], + target_functions: set[str] +) -> tuple[dict[Path, list[TestsInFile]], dict[Path, set[str]]]: + """Filter test files based on import analysis to reduce Jedi processing. + + Args: + file_to_test_map: Original mapping of test files to test functions + target_functions: Set of function names we're optimizing + + Returns: + Tuple of (filtered_file_map, import_analysis_results) + + """ + if not target_functions: + # If no target functions specified, process all files + return file_to_test_map, {} + + filtered_map = {} + import_results = {} + + for test_file, test_functions in file_to_test_map.items(): + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + import_results[test_file] = found_functions + + if should_process: + filtered_map[test_file] = test_functions + else: + logger.debug(f"Skipping {test_file} - no relevant imports found") + + logger.info(f"Import filter: Processing {len(filtered_map)}/{len(file_to_test_map)} test files") + return filtered_map, import_results + + def discover_unit_tests( - cfg: TestConfig, discover_only_these_tests: list[Path] | None = None + cfg: TestConfig, discover_only_these_tests: list[Path] | None = None, functions_to_optimize: list[FunctionToOptimize] | None = None ) -> dict[str, list[FunctionCalledInTest]]: framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest} strategy = framework_strategies.get(cfg.test_framework, None) @@ -146,11 +305,11 @@ def discover_unit_tests( error_message = f"Unsupported test framework: {cfg.test_framework}" raise ValueError(error_message) - return strategy(cfg, discover_only_these_tests) + return strategy(cfg, discover_only_these_tests, functions_to_optimize) def discover_tests_pytest( - cfg: TestConfig, discover_only_these_tests: list[Path] | None = None + cfg: TestConfig, discover_only_these_tests: list[Path] | None = None, functions_to_optimize: list[FunctionToOptimize] | None = None ) -> dict[Path, list[FunctionCalledInTest]]: tests_root = cfg.tests_root project_root = cfg.project_root_path @@ -220,11 +379,11 @@ def discover_tests_pytest( continue file_to_test_map[test_obj.test_file].append(test_obj) # Within these test files, find the project functions they are referring to and return their names/locations - return process_test_files(file_to_test_map, cfg) + return process_test_files(file_to_test_map, cfg, functions_to_optimize) def discover_tests_unittest( - cfg: TestConfig, discover_only_these_tests: list[str] | None = None + cfg: TestConfig, discover_only_these_tests: list[str] | None = None, functions_to_optimize: list[FunctionToOptimize] | None = None ) -> dict[Path, list[FunctionCalledInTest]]: tests_root: Path = cfg.tests_root loader: unittest.TestLoader = unittest.TestLoader() @@ -277,7 +436,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None: details = get_test_details(test) if details is not None: file_to_test_map[str(details.test_file)].append(details) - return process_test_files(file_to_test_map, cfg) + return process_test_files(file_to_test_map, cfg, functions_to_optimize) def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | None]: @@ -289,13 +448,29 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N def process_test_files( - file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig + file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig, functions_to_optimize: list[FunctionToOptimize] | None = None ) -> dict[str, list[FunctionCalledInTest]]: import jedi project_root_path = cfg.project_root_path test_framework = cfg.test_framework + # Apply import filter if functions to optimize are provided + if functions_to_optimize: + # Extract target function names from FunctionToOptimize objects + # Include both qualified names and simple function names for better matching + target_function_names = set() + for func in functions_to_optimize: + target_function_names.add(func.qualified_name_with_modules_from_root(project_root_path)) + target_function_names.add(func.function_name) # Add simple name too + # Also add qualified name without module + if func.parents: + target_function_names.add(f"{func.parents[0].name}.{func.function_name}") + + logger.debug(f"Target functions for import filtering: {target_function_names}") + file_to_test_map, import_results = filter_test_files_by_imports(file_to_test_map, target_function_names) + logger.debug(f"Import analysis results: {len(import_results)} files analyzed") + function_to_test_map = defaultdict(set) jedi_project = jedi.Project(path=project_root_path) goto_cache = {} diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 9e5715e2a..372c2b81d 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -162,7 +162,13 @@ def run(self) -> None: console.rule() start_time = time.time() - function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg) + # Extract all functions to optimize for import filtering + all_functions_to_optimize = [ + func for funcs_list in file_to_funcs_to_optimize.values() for func in funcs_list + ] + function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests( + self.test_cfg, functions_to_optimize=all_functions_to_optimize + ) num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()]) console.rule() logger.info( diff --git a/tests/test_unit_test_discovery.py b/tests/test_unit_test_discovery.py index 8c3bc35c8..e12234f11 100644 --- a/tests/test_unit_test_discovery.py +++ b/tests/test_unit_test_discovery.py @@ -2,7 +2,12 @@ import tempfile from pathlib import Path -from codeflash.discovery.discover_unit_tests import discover_unit_tests +from codeflash.discovery.discover_unit_tests import ( + analyze_imports_in_test_file, + discover_unit_tests, + filter_test_files_by_imports, +) +from codeflash.models.models import TestsInFile, TestType from codeflash.verification.verification_utils import TestConfig @@ -789,3 +794,440 @@ def test_add_mixed(self, name, a, b, expected): assert len(discovered_tests) == 2 # Should have tests for both add and multiply assert "calculator.Calculator.add" in discovered_tests assert "calculator.Calculator.multiply" in discovered_tests + + +# Import Filtering Tests + + +def test_analyze_imports_direct_function_import(): + """Test that direct function imports are detected.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import target_function, other_function + +def test_target(): + assert target_function() is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function", "missing_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + assert "target_function" in found_functions + assert "missing_function" not in found_functions + + +def test_analyze_imports_star_import(): + """Test that star imports trigger conservative processing.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import * + +def test_something(): + assert something() is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True # Conservative approach with star imports + assert found_functions == set() # No specific functions identified + + +def test_analyze_imports_module_import(): + """Test module imports with function access patterns.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +import mymodule + +def test_target(): + assert mymodule.target_function() is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + assert "target_function" in found_functions + + +def test_analyze_imports_dynamic_import(): + """Test detection of dynamic imports.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +import importlib + +def test_dynamic(): + module = importlib.import_module("mymodule") + assert module.target_function() is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + assert "target_function" in found_functions + + +def test_analyze_imports_builtin_import(): + """Test detection of __import__ calls.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +def test_builtin_import(): + module = __import__("mymodule") + assert module.target_function() is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + assert "target_function" in found_functions + + +def test_analyze_imports_no_matching_imports(): + """Test that files with no matching imports are filtered out.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from unrelated_module import unrelated_function + +def test_unrelated(): + assert unrelated_function() is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function", "another_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is False + assert found_functions == set() + + +def test_analyze_imports_heuristic_matching(): + """Test heuristic module name matching.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from target_module import some_function + +def test_target(): + assert some_function() is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function"} # Function name partially matches module name + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + assert "target_function" in found_functions + + +def test_analyze_imports_syntax_error(): + """Test handling of files with syntax errors.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import target_function +def test_target( + # Syntax error - missing closing parenthesis + assert target_function() is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + # Should be conservative with unparseable files + assert should_process is True + assert found_functions == set() + + +def test_filter_test_files_by_imports(): + """Test the complete filtering functionality.""" + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdir = Path(tmpdirname) + + # Create test file that imports target function + relevant_test = tmpdir / "test_relevant.py" + relevant_test.write_text(""" +from mymodule import target_function + +def test_target(): + assert target_function() is True +""") + + # Create test file that doesn't import target function + irrelevant_test = tmpdir / "test_irrelevant.py" + irrelevant_test.write_text(""" +from othermodule import other_function + +def test_other(): + assert other_function() is True +""") + + # Create test file with star import (should be processed) + star_test = tmpdir / "test_star.py" + star_test.write_text(""" +from mymodule import * + +def test_star(): + assert something() is True +""") + + # Build file_to_test_map + file_to_test_map = { + relevant_test: [TestsInFile(test_file=relevant_test, test_function="test_target", test_class=None, test_type=TestType.EXISTING_UNIT_TEST)], + irrelevant_test: [TestsInFile(test_file=irrelevant_test, test_function="test_other", test_class=None, test_type=TestType.EXISTING_UNIT_TEST)], + star_test: [TestsInFile(test_file=star_test, test_function="test_star", test_class=None, test_type=TestType.EXISTING_UNIT_TEST)], + } + + target_functions = {"target_function"} + filtered_map, import_results = filter_test_files_by_imports(file_to_test_map, target_functions) + + # Should filter out irrelevant_test but keep relevant_test and star_test + assert len(filtered_map) == 2 + assert relevant_test in filtered_map + assert star_test in filtered_map + assert irrelevant_test not in filtered_map + + # Check import analysis results + assert "target_function" in import_results[relevant_test] + assert len(import_results[irrelevant_test]) == 0 + assert len(import_results[star_test]) == 0 # Star import doesn't identify specific functions + + +def test_filter_test_files_no_target_functions(): + """Test that filtering is skipped when no target functions are provided.""" + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdir = Path(tmpdirname) + + test_file = tmpdir / "test_example.py" + test_file.write_text("def test_something(): pass") + + file_to_test_map = { + test_file: [TestsInFile(test_file=test_file, test_function="test_something", test_class=None, test_type=TestType.EXISTING_UNIT_TEST)] + } + + # No target functions provided + filtered_map, import_results = filter_test_files_by_imports(file_to_test_map, set()) + + # Should return original map unchanged + assert filtered_map == file_to_test_map + assert import_results == {} + + +def test_discover_unit_tests_with_import_filtering(): + """Test the full discovery process with import filtering.""" + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdir = Path(tmpdirname) + + # Create a code file + code_file = tmpdir / "mycode.py" + code_file.write_text(""" +def target_function(): + return True + +def other_function(): + return False +""") + + # Create relevant test file + relevant_test = tmpdir / "test_relevant.py" + relevant_test.write_text(""" +from mycode import target_function + +def test_target(): + assert target_function() is True +""") + + # Create irrelevant test file + irrelevant_test = tmpdir / "test_irrelevant.py" + irrelevant_test.write_text(""" +from mycode import other_function + +def test_other(): + assert other_function() is False +""") + + # Configure test discovery + test_config = TestConfig( + tests_root=tmpdir, + project_root_path=tmpdir, + test_framework="pytest", + tests_project_rootdir=tmpdir.parent, + ) + + # Test without filtering + all_tests = discover_unit_tests(test_config) + assert len(all_tests) == 2 # Should find both functions + + # Test with filtering - create mock FunctionToOptimize objects + from unittest.mock import Mock + mock_function = Mock() + mock_function.qualified_name_with_modules_from_root.return_value = "mycode.target_function" + mock_function.function_name = "target_function" + mock_function.parents = [] # No parent classes + + filtered_tests = discover_unit_tests(test_config, functions_to_optimize=[mock_function]) + # The import filter is designed for high recall, so it may include both functions + # because both test files import from the same module (mycode) that contains target_function + assert len(filtered_tests) >= 1 # Should find at least target_function + assert "mycode.target_function" in filtered_tests + # In a perfect world we'd filter out other_function, but conservative filtering + # is acceptable for performance optimization purposes + + +def test_analyze_imports_conditional_import(): + """Test detection of conditional imports within functions.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +def test_conditional(): + if some_condition: + from mymodule import target_function + assert target_function() is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + assert "target_function" in found_functions + + +def test_analyze_imports_function_name_in_code(): + """Test detection of function names used directly in code.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +import mymodule + +def test_indirect(): + func_name = "target_function" + func = getattr(mymodule, func_name) + # The analyzer should detect target_function usage + result = target_function() + assert result is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + assert "target_function" in found_functions + + +def test_analyze_imports_aliased_imports(): + """Test handling of aliased imports.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import target_function as tf, other_function as of + +def test_aliased(): + assert tf() is True + assert of() is False +""" + test_file.write_text(test_content) + + target_functions = {"target_function", "missing_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + assert "target_function" in found_functions + assert "missing_function" not in found_functions + + +def test_analyze_imports_underscore_function_names(): + """Test handling of function names with underscores in heuristic matching.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from bubble_module import sort_function + +def test_bubble(): + assert sort_function([3,1,2]) == [1,2,3] +""" + test_file.write_text(test_content) + + target_functions = {"bubble_sort"} # Function name parts match module + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + assert "bubble_sort" in found_functions + + +def test_discover_unit_tests_filtering_different_modules(): + """Test import filtering with test files from completely different modules.""" + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdir = Path(tmpdirname) + + # Create target code file + target_file = tmpdir / "target_module.py" + target_file.write_text(""" +def target_function(): + return True +""") + + # Create unrelated code file + unrelated_file = tmpdir / "unrelated_module.py" + unrelated_file.write_text(""" +def unrelated_function(): + return False +""") + + # Create test file that imports target function + relevant_test = tmpdir / "test_target.py" + relevant_test.write_text(""" +from target_module import target_function + +def test_target(): + assert target_function() is True +""") + + # Create test file that imports unrelated function + irrelevant_test = tmpdir / "test_unrelated.py" + irrelevant_test.write_text(""" +from unrelated_module import unrelated_function + +def test_unrelated(): + assert unrelated_function() is False +""") + + # Configure test discovery + test_config = TestConfig( + tests_root=tmpdir, + project_root_path=tmpdir, + test_framework="pytest", + tests_project_rootdir=tmpdir.parent, + ) + + # Test without filtering + all_tests = discover_unit_tests(test_config) + assert len(all_tests) == 2 # Should find both functions + + # Test with filtering - create mock FunctionToOptimize objects + from unittest.mock import Mock + mock_function = Mock() + mock_function.qualified_name_with_modules_from_root.return_value = "target_module.target_function" + mock_function.function_name = "target_function" + mock_function.parents = [] # No parent classes + + filtered_tests = discover_unit_tests(test_config, functions_to_optimize=[mock_function]) + # Should filter out the unrelated test since it imports from a different module + assert len(filtered_tests) == 1 + assert "target_module.target_function" in filtered_tests + assert "unrelated_module.unrelated_function" not in filtered_tests From 918dd686a8a1de0d6728509222e3ae9956071137 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Fri, 6 Jun 2025 21:55:02 -0700 Subject: [PATCH 2/6] move things around & more efficient passes pre-commit cleanup guard path fix tests & remove goto cache remove TestsCache usage moved the num_discovered_tests calculation inside the discover_unit_tests dict directly by using a set in discover_unit_tests efficient passes --- codeflash/discovery/discover_unit_tests.py | 199 +++++++++---------- codeflash/discovery/functions_to_optimize.py | 2 +- codeflash/optimization/function_optimizer.py | 19 +- codeflash/optimization/optimizer.py | 11 +- codeflash/result/create_pr.py | 2 +- codeflash/verification/concolic_testing.py | 5 +- tests/test_static_analysis.py | 2 +- tests/test_unit_test_discovery.py | 100 +++++----- 8 files changed, 163 insertions(+), 177 deletions(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 06adcd4bb..e2ccd3be4 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -90,7 +90,6 @@ def insert_test( line_number: int, col_number: int, ) -> None: - self.cur.execute("DELETE FROM discovered_tests WHERE file_path = ?", (file_path,)) test_type_value = test_type.value if hasattr(test_type, "value") else test_type self.cur.execute( "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", @@ -170,26 +169,32 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: else: imported_name = alias.asname if alias.asname else alias.name self.imported_names.add(imported_name) - - # Check if this import matches any target function if alias.name in self.function_names_to_find: self.found_target_functions.add(alias.name) self.generic_visit(node) def visit_Call(self, node: ast.Call) -> None: """Handle dynamic imports like importlib.import_module() or __import__().""" - if isinstance(node.func, ast.Name) and node.func.id == "__import__" and node.args: + if ( + isinstance(node.func, ast.Name) + and node.func.id == "__import__" + and node.args + and isinstance(node.args[0], ast.Constant) + and isinstance(node.args[0].value, str) + ): # __import__("module_name") - if isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, str): - self.imported_modules.add(node.args[0].value) - elif (isinstance(node.func, ast.Attribute) - and isinstance(node.func.value, ast.Name) - and node.func.value.id == "importlib" - and node.func.attr == "import_module" - and node.args): + self.imported_modules.add(node.args[0].value) + elif ( + isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "importlib" + and node.func.attr == "import_module" + and node.args + and isinstance(node.args[0], ast.Constant) + and isinstance(node.args[0].value, str) + ): # importlib.import_module("module_name") - if isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, str): - self.imported_modules.add(node.args[0].value) + self.imported_modules.add(node.args[0].value) self.generic_visit(node) def visit_Name(self, node: ast.Name) -> None: @@ -205,7 +210,7 @@ def visit_Attribute(self, node: ast.Attribute) -> None: self.generic_visit(node) -def analyze_imports_in_test_file(test_file_path: Path, target_functions: set[str]) -> tuple[bool, set[str]]: +def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> tuple[bool, set[str]]: """Analyze imports in a test file to determine if it might test any target functions. Args: @@ -216,6 +221,9 @@ def analyze_imports_in_test_file(test_file_path: Path, target_functions: set[str Tuple of (should_process_with_jedi, found_function_names) """ + if isinstance(test_file_path, str): + test_file_path = Path(test_file_path) + try: with test_file_path.open("r", encoding="utf-8") as f: content = f.read() @@ -258,13 +266,11 @@ def analyze_imports_in_test_file(test_file_path: Path, target_functions: set[str except (SyntaxError, UnicodeDecodeError, OSError) as e: logger.debug(f"Failed to analyze imports in {test_file_path}: {e}") - # If we can't parse the file, be conservative and process it return True, set() def filter_test_files_by_imports( - file_to_test_map: dict[Path, list[TestsInFile]], - target_functions: set[str] + file_to_test_map: dict[Path, list[TestsInFile]], target_functions: set[str] ) -> tuple[dict[Path, list[TestsInFile]], dict[Path, set[str]]]: """Filter test files based on import analysis to reduce Jedi processing. @@ -292,25 +298,35 @@ def filter_test_files_by_imports( else: logger.debug(f"Skipping {test_file} - no relevant imports found") - logger.info(f"Import filter: Processing {len(filtered_map)}/{len(file_to_test_map)} test files") + logger.debug(f"Import filter: Processing {len(filtered_map)}/{len(file_to_test_map)} test files") return filtered_map, import_results def discover_unit_tests( - cfg: TestConfig, discover_only_these_tests: list[Path] | None = None, functions_to_optimize: list[FunctionToOptimize] | None = None -) -> dict[str, list[FunctionCalledInTest]]: + cfg: TestConfig, + discover_only_these_tests: list[Path] | None = None, + file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] | None = None, +) -> tuple[dict[str, set[FunctionCalledInTest]], int]: framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest} strategy = framework_strategies.get(cfg.test_framework, None) if not strategy: error_message = f"Unsupported test framework: {cfg.test_framework}" raise ValueError(error_message) - return strategy(cfg, discover_only_these_tests, functions_to_optimize) + # Extract all functions to optimize for import filtering + functions_to_optimize = None + if file_to_funcs_to_optimize: + functions_to_optimize = [func for funcs_list in file_to_funcs_to_optimize.values() for func in funcs_list] + + function_to_tests, num_discovered_tests = strategy(cfg, discover_only_these_tests, functions_to_optimize) + return function_to_tests, num_discovered_tests def discover_tests_pytest( - cfg: TestConfig, discover_only_these_tests: list[Path] | None = None, functions_to_optimize: list[FunctionToOptimize] | None = None -) -> dict[Path, list[FunctionCalledInTest]]: + cfg: TestConfig, + discover_only_these_tests: list[Path] | None = None, + functions_to_optimize: list[FunctionToOptimize] | None = None, +) -> tuple[dict[str, set[FunctionCalledInTest]], int]: tests_root = cfg.tests_root project_root = cfg.project_root_path @@ -383,8 +399,10 @@ def discover_tests_pytest( def discover_tests_unittest( - cfg: TestConfig, discover_only_these_tests: list[str] | None = None, functions_to_optimize: list[FunctionToOptimize] | None = None -) -> dict[Path, list[FunctionCalledInTest]]: + cfg: TestConfig, + discover_only_these_tests: list[str] | None = None, + functions_to_optimize: list[FunctionToOptimize] | None = None, +) -> tuple[dict[str, set[FunctionCalledInTest]], int]: tests_root: Path = cfg.tests_root loader: unittest.TestLoader = unittest.TestLoader() tests: unittest.TestSuite = loader.discover(str(tests_root)) @@ -448,8 +466,10 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N def process_test_files( - file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig, functions_to_optimize: list[FunctionToOptimize] | None = None -) -> dict[str, list[FunctionCalledInTest]]: + file_to_test_map: dict[Path, list[TestsInFile]], + cfg: TestConfig, + functions_to_optimize: list[FunctionToOptimize] | None = None, +) -> tuple[dict[str, set[FunctionCalledInTest]], int]: import jedi project_root_path = cfg.project_root_path @@ -466,45 +486,38 @@ def process_test_files( # Also add qualified name without module if func.parents: target_function_names.add(f"{func.parents[0].name}.{func.function_name}") - + logger.debug(f"Target functions for import filtering: {target_function_names}") file_to_test_map, import_results = filter_test_files_by_imports(file_to_test_map, target_function_names) logger.debug(f"Import analysis results: {len(import_results)} files analyzed") function_to_test_map = defaultdict(set) + num_discovered_tests = 0 jedi_project = jedi.Project(path=project_root_path) - goto_cache = {} - tests_cache = TestsCache() with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as ( progress, task_id, ): for test_file, functions in file_to_test_map.items(): - file_hash = TestsCache.compute_file_hash(test_file) - cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash) - if cached_tests: - self_cur = tests_cache.cur - self_cur.execute( - "SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?", - (str(test_file), file_hash), - ) - qualified_names = [row[0] for row in self_cur.fetchall()] - for cached, qualified_name in zip(cached_tests, qualified_names): - function_to_test_map[qualified_name].add(cached) - progress.advance(task_id) - continue - try: script = jedi.Script(path=test_file, project=jedi_project) test_functions = set() - all_names = script.get_names(all_scopes=True, references=True) - all_defs = script.get_names(all_scopes=True, definitions=True) - all_names_top = script.get_names(all_scopes=True) + # Single call to get all names with references and definitions + all_names = script.get_names(all_scopes=True, references=True, definitions=True) - top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} - top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} + # Filter once and create lookup dictionaries + top_level_functions = {} + top_level_classes = {} + all_defs = [] + + for name in all_names: + if name.type == "function": + top_level_functions[name.name] = name + all_defs.append(name) + elif name.type == "class": + top_level_classes[name.name] = name except Exception as e: logger.debug(f"Failed to get jedi script for {test_file}: {e}") progress.advance(task_id) @@ -569,31 +582,23 @@ def process_test_files( ) ) - test_functions_list = list(test_functions) - test_functions_raw = [elem.function_name for elem in test_functions_list] - test_functions_by_name = defaultdict(list) - for i, func_name in enumerate(test_functions_raw): - test_functions_by_name[func_name].append(i) + for func in test_functions: + test_functions_by_name[func.function_name].append(func) - for name in all_names: - if name.full_name is None: - continue - m = FUNCTION_NAME_REGEX.search(name.full_name) - if not m: - continue + test_function_names_set = set(test_functions_by_name.keys()) + relevant_names = [] - scope = m.group(1) - if scope not in test_functions_by_name: - continue + names_with_full_name = [name for name in all_names if name.full_name is not None] - cache_key = (name.full_name, name.module_name) + for name in names_with_full_name: + match = FUNCTION_NAME_REGEX.search(name.full_name) + if match and match.group(1) in test_function_names_set: + relevant_names.append((name, match.group(1))) + + for name, scope in relevant_names: try: - if cache_key in goto_cache: - definition = goto_cache[cache_key] - else: - definition = name.goto(follow_imports=True, follow_builtin_imports=False) - goto_cache[cache_key] = definition + definition = name.goto(follow_imports=True, follow_builtin_imports=False) except Exception as e: logger.debug(str(e)) continue @@ -601,54 +606,42 @@ def process_test_files( if not definition or definition[0].type != "function": continue - definition_path = str(definition[0].module_path) + definition_obj = definition[0] + definition_path = str(definition_obj.module_path) + + project_root_str = str(project_root_path) if ( - definition_path.startswith(str(project_root_path) + os.sep) - and definition[0].module_name != name.module_name - and definition[0].full_name is not None + definition_path.startswith(project_root_str + os.sep) + and definition_obj.module_name != name.module_name + and definition_obj.full_name is not None ): - for index in test_functions_by_name[scope]: - scope_test_function = test_functions_list[index].function_name - scope_test_class = test_functions_list[index].test_class - scope_parameters = test_functions_list[index].parameters - test_type = test_functions_list[index].test_type + # Pre-compute common values outside the inner loop + module_prefix = definition_obj.module_name + "." + full_name_without_module_prefix = definition_obj.full_name.replace(module_prefix, "", 1) + qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition_obj.module_path, project_root_path)}.{full_name_without_module_prefix}" - if scope_parameters is not None: + for test_func in test_functions_by_name[scope]: + if test_func.parameters is not None: if test_framework == "pytest": - scope_test_function += "[" + scope_parameters + "]" - if test_framework == "unittest": - scope_test_function += "_" + scope_parameters - - full_name_without_module_prefix = definition[0].full_name.replace( - definition[0].module_name + ".", "", 1 - ) - qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" - - tests_cache.insert_test( - file_path=str(test_file), - file_hash=file_hash, - qualified_name_with_modules_from_root=qualified_name_with_modules_from_root, - function_name=scope, - test_class=scope_test_class, - test_function=scope_test_function, - test_type=test_type, - line_number=name.line, - col_number=name.column, - ) + scope_test_function = f"{test_func.function_name}[{test_func.parameters}]" + else: # unittest + scope_test_function = f"{test_func.function_name}_{test_func.parameters}" + else: + scope_test_function = test_func.function_name function_to_test_map[qualified_name_with_modules_from_root].add( FunctionCalledInTest( tests_in_file=TestsInFile( test_file=test_file, - test_class=scope_test_class, + test_class=test_func.test_class, test_function=scope_test_function, - test_type=test_type, + test_type=test_func.test_type, ), position=CodePosition(line_no=name.line, col_no=name.column), ) ) + num_discovered_tests += 1 progress.advance(task_id) - tests_cache.close() - return {function: list(tests) for function, tests in function_to_test_map.items()} + return dict(function_to_test_map), num_discovered_tests diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 41d99ec2c..931b3a05a 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -268,7 +268,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt def get_all_replay_test_functions( replay_test: Path, test_cfg: TestConfig, project_root_path: Path ) -> dict[Path, list[FunctionToOptimize]]: - function_tests = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test]) + function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test]) # Get the absolute file paths for each function, excluding class name if present filtered_valid_functions = defaultdict(list) file_to_functions_map = defaultdict(list) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 5922d6c1c..9f5781697 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -57,7 +57,6 @@ from codeflash.models.models import ( BestOptimization, CodeOptimizationContext, - FunctionCalledInTest, GeneratedTests, GeneratedTestsList, OptimizationSet, @@ -87,7 +86,13 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.either import Result - from codeflash.models.models import BenchmarkKey, CoverageData, FunctionSource, OptimizedCandidate + from codeflash.models.models import ( + BenchmarkKey, + CoverageData, + FunctionCalledInTest, + FunctionSource, + OptimizedCandidate, + ) from codeflash.verification.verification_utils import TestConfig @@ -97,7 +102,7 @@ def __init__( function_to_optimize: FunctionToOptimize, test_cfg: TestConfig, function_to_optimize_source_code: str = "", - function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None, + function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None, function_to_optimize_ast: ast.FunctionDef | None = None, aiservice_client: AiServiceClient | None = None, function_benchmark_timings: dict[BenchmarkKey, int] | None = None, @@ -213,7 +218,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 function_to_optimize_qualified_name = self.function_to_optimize.qualified_name function_to_all_tests = { - key: self.function_to_tests.get(key, []) + function_to_concolic_tests.get(key, []) + key: self.function_to_tests.get(key, set()) | function_to_concolic_tests.get(key, set()) for key in set(self.function_to_tests) | set(function_to_concolic_tests) } instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests) @@ -690,7 +695,7 @@ def cleanup_leftover_test_return_values() -> None: get_run_tmp_file(Path("test_return_values_0.bin")).unlink(missing_ok=True) get_run_tmp_file(Path("test_return_values_0.sqlite")).unlink(missing_ok=True) - def instrument_existing_tests(self, function_to_all_tests: dict[str, list[FunctionCalledInTest]]) -> set[Path]: + def instrument_existing_tests(self, function_to_all_tests: dict[str, set[FunctionCalledInTest]]) -> set[Path]: existing_test_files_count = 0 replay_test_files_count = 0 concolic_coverage_test_files_count = 0 @@ -701,7 +706,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, list[Functi logger.info(f"Did not find any pre-existing tests for '{func_qualname}', will only use generated tests.") console.rule() else: - test_file_invocation_positions = defaultdict(list[FunctionCalledInTest]) + test_file_invocation_positions = defaultdict(list) for tests_in_file in function_to_all_tests.get(func_qualname): test_file_invocation_positions[ (tests_in_file.tests_in_file.test_file, tests_in_file.tests_in_file.test_type) @@ -787,7 +792,7 @@ def generate_tests_and_optimizations( generated_test_paths: list[Path], generated_perf_test_paths: list[Path], run_experiment: bool = False, # noqa: FBT001, FBT002 - ) -> Result[tuple[GeneratedTestsList, dict[str, list[FunctionCalledInTest]], OptimizationSet], str]: + ) -> Result[tuple[GeneratedTestsList, dict[str, set[FunctionCalledInTest]], OptimizationSet], str]: assert len(generated_test_paths) == N_TESTS_TO_GENERATE max_workers = N_TESTS_TO_GENERATE + 2 if not run_experiment else N_TESTS_TO_GENERATE + 3 console.rule() diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 372c2b81d..55ab14c35 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -48,7 +48,7 @@ def create_function_optimizer( self, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: ast.FunctionDef | None = None, - function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None, + function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None, function_to_optimize_source_code: str | None = "", function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None, total_benchmark_timings: dict[BenchmarkKey, float] | None = None, @@ -162,14 +162,9 @@ def run(self) -> None: console.rule() start_time = time.time() - # Extract all functions to optimize for import filtering - all_functions_to_optimize = [ - func for funcs_list in file_to_funcs_to_optimize.values() for func in funcs_list - ] - function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests( - self.test_cfg, functions_to_optimize=all_functions_to_optimize + function_to_tests, num_discovered_tests = discover_unit_tests( + self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize ) - num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()]) console.rule() logger.info( f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}" diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index 8524d397e..b9e05e660 100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -25,7 +25,7 @@ def existing_tests_source_for( function_qualified_name_with_modules_from_root: str, - function_to_tests: dict[str, list[FunctionCalledInTest]], + function_to_tests: dict[str, set[FunctionCalledInTest]], tests_root: Path, ) -> str: test_files = function_to_tests.get(function_qualified_name_with_modules_from_root) diff --git a/codeflash/verification/concolic_testing.py b/codeflash/verification/concolic_testing.py index 5792a289d..014620f28 100644 --- a/codeflash/verification/concolic_testing.py +++ b/codeflash/verification/concolic_testing.py @@ -24,7 +24,7 @@ def generate_concolic_tests( test_cfg: TestConfig, args: Namespace, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: ast.AST -) -> tuple[dict[str, list[FunctionCalledInTest]], str]: +) -> tuple[dict[str, set[FunctionCalledInTest]], str]: start_time = time.perf_counter() function_to_concolic_tests = {} concolic_test_suite_code = "" @@ -78,8 +78,7 @@ def generate_concolic_tests( test_framework=args.test_framework, pytest_cmd=args.pytest_cmd, ) - function_to_concolic_tests = discover_unit_tests(concolic_test_cfg) - num_discovered_concolic_tests: int = sum([len(value) for value in function_to_concolic_tests.values()]) + function_to_concolic_tests, num_discovered_concolic_tests = discover_unit_tests(concolic_test_cfg) logger.info( f"Created {num_discovered_concolic_tests} " f"concolic unit test case{'s' if num_discovered_concolic_tests != 1 else ''} " diff --git a/tests/test_static_analysis.py b/tests/test_static_analysis.py index c4da29c03..b997edeab 100644 --- a/tests/test_static_analysis.py +++ b/tests/test_static_analysis.py @@ -1,4 +1,4 @@ -import ast +import ast from pathlib import Path from codeflash.code_utils.static_analysis import ( diff --git a/tests/test_unit_test_discovery.py b/tests/test_unit_test_discovery.py index e12234f11..da437535a 100644 --- a/tests/test_unit_test_discovery.py +++ b/tests/test_unit_test_discovery.py @@ -20,7 +20,7 @@ def test_unit_test_discovery_pytest(): test_framework="pytest", tests_project_rootdir=tests_path.parent, ) - tests = discover_unit_tests(test_config) + tests, _ = discover_unit_tests(test_config) assert len(tests) > 0 @@ -33,7 +33,7 @@ def test_benchmark_test_discovery_pytest(): test_framework="pytest", tests_project_rootdir=tests_path.parent, ) - tests = discover_unit_tests(test_config) + tests, _ = discover_unit_tests(test_config) assert len(tests) == 1 # Should not discover benchmark tests @@ -47,7 +47,7 @@ def test_unit_test_discovery_unittest(): tests_project_rootdir=project_path.parent, ) os.chdir(project_path) - tests = discover_unit_tests(test_config) + tests, _ = discover_unit_tests(test_config) # assert len(tests) > 0 # Unittest discovery within a pytest environment does not work @@ -85,7 +85,7 @@ def sorter(arr): ) # Discover tests - tests = discover_unit_tests(test_config) + tests, _ = discover_unit_tests(test_config) assert len(tests) == 1 assert 'bubble_sort.sorter' in tests assert len(tests['bubble_sort.sorter']) == 2 @@ -124,17 +124,14 @@ def test_discover_tests_pytest_with_temp_dir_root(): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Check if the dummy test file is discovered assert len(discovered_tests) == 1 assert len(discovered_tests["dummy_code.dummy_function"]) == 2 - assert discovered_tests["dummy_code.dummy_function"][0].tests_in_file.test_file == test_file_path - assert discovered_tests["dummy_code.dummy_function"][1].tests_in_file.test_file == test_file_path - assert { - discovered_tests["dummy_code.dummy_function"][0].tests_in_file.test_function, - discovered_tests["dummy_code.dummy_function"][1].tests_in_file.test_function, - } == {"test_dummy_parametrized_function[True]", "test_dummy_function"} + dummy_tests = discovered_tests["dummy_code.dummy_function"] + assert all(test.tests_in_file.test_file == test_file_path for test in dummy_tests) + assert {test.tests_in_file.test_function for test in dummy_tests} == {"test_dummy_parametrized_function[True]", "test_dummy_function"} def test_discover_tests_pytest_with_multi_level_dirs(): @@ -197,17 +194,17 @@ def test_discover_tests_pytest_with_multi_level_dirs(): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Check if the test files at all levels are discovered assert len(discovered_tests) == 3 - assert discovered_tests["root_code.root_function"][0].tests_in_file.test_file == root_test_file_path + assert next(iter(discovered_tests["root_code.root_function"])).tests_in_file.test_file == root_test_file_path assert ( - discovered_tests["level1.level1_code.level1_function"][0].tests_in_file.test_file == level1_test_file_path + next(iter(discovered_tests["level1.level1_code.level1_function"])).tests_in_file.test_file == level1_test_file_path ) assert ( - discovered_tests["level1.level2.level2_code.level2_function"][0].tests_in_file.test_file + next(iter(discovered_tests["level1.level2.level2_code.level2_function"])).tests_in_file.test_file == level2_test_file_path ) @@ -287,21 +284,21 @@ def test_discover_tests_pytest_dirs(): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Check if the test files at all levels are discovered assert len(discovered_tests) == 4 - assert discovered_tests["root_code.root_function"][0].tests_in_file.test_file == root_test_file_path + assert next(iter(discovered_tests["root_code.root_function"])).tests_in_file.test_file == root_test_file_path assert ( - discovered_tests["level1.level1_code.level1_function"][0].tests_in_file.test_file == level1_test_file_path + next(iter(discovered_tests["level1.level1_code.level1_function"])).tests_in_file.test_file == level1_test_file_path ) assert ( - discovered_tests["level1.level2.level2_code.level2_function"][0].tests_in_file.test_file + next(iter(discovered_tests["level1.level2.level2_code.level2_function"])).tests_in_file.test_file == level2_test_file_path ) assert ( - discovered_tests["level1.level3.level3_code.level3_function"][0].tests_in_file.test_file + next(iter(discovered_tests["level1.level3.level3_code.level3_function"])).tests_in_file.test_file == level3_test_file_path ) @@ -333,11 +330,11 @@ def test_discover_tests_pytest_with_class(): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Check if the test class and method are discovered assert len(discovered_tests) == 1 - assert discovered_tests["some_class_code.SomeClass.some_method"][0].tests_in_file.test_file == test_file_path + assert next(iter(discovered_tests["some_class_code.SomeClass.some_method"])).tests_in_file.test_file == test_file_path def test_discover_tests_pytest_with_double_nested_directories(): @@ -371,14 +368,12 @@ def test_discover_tests_pytest_with_double_nested_directories(): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Check if the test class and method are discovered assert len(discovered_tests) == 1 assert ( - discovered_tests["nested.more_nested.nested_class_code.NestedClass.nested_method"][ - 0 - ].tests_in_file.test_file + next(iter(discovered_tests["nested.more_nested.nested_class_code.NestedClass.nested_method"])).tests_in_file.test_file == test_file_path ) @@ -421,11 +416,11 @@ def test_discover_tests_with_code_in_dir_and_test_in_subdir(): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Check if the test file is discovered and associated with the code file assert len(discovered_tests) == 1 - assert discovered_tests["code.some_code.some_function"][0].tests_in_file.test_file == test_file_path + assert next(iter(discovered_tests["code.some_code.some_function"])).tests_in_file.test_file == test_file_path def test_discover_tests_pytest_with_nested_class(): @@ -460,12 +455,12 @@ def test_discover_tests_pytest_with_nested_class(): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Check if the test for the nested class method is discovered assert len(discovered_tests) == 1 assert ( - discovered_tests["nested_class_code.OuterClass.InnerClass.inner_method"][0].tests_in_file.test_file + next(iter(discovered_tests["nested_class_code.OuterClass.InnerClass.inner_method"])).tests_in_file.test_file == test_file_path ) @@ -500,11 +495,11 @@ def test_discover_tests_pytest_separate_moduledir(): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Check if the test for the nested class method is discovered assert len(discovered_tests) == 1 - assert discovered_tests["mypackage.code.find_common_tags"][0].tests_in_file.test_file == test_file_path + assert next(iter(discovered_tests["mypackage.code.find_common_tags"])).tests_in_file.test_file == test_file_path def test_unittest_discovery_with_pytest(): @@ -542,14 +537,15 @@ def test_add(self): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Verify the unittest was discovered assert len(discovered_tests) == 1 assert "calculator.Calculator.add" in discovered_tests assert len(discovered_tests["calculator.Calculator.add"]) == 1 - assert discovered_tests["calculator.Calculator.add"][0].tests_in_file.test_file == test_file_path - assert discovered_tests["calculator.Calculator.add"][0].tests_in_file.test_function == "test_add" + calculator_test = next(iter(discovered_tests["calculator.Calculator.add"])) + assert calculator_test.tests_in_file.test_file == test_file_path + assert calculator_test.tests_in_file.test_function == "test_add" def test_unittest_discovery_with_pytest_parent_class(): @@ -609,14 +605,15 @@ def test_add(self): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Verify the unittest was discovered assert len(discovered_tests) == 2 assert "calculator.Calculator.add" in discovered_tests assert len(discovered_tests["calculator.Calculator.add"]) == 1 - assert discovered_tests["calculator.Calculator.add"][0].tests_in_file.test_file == test_file_path - assert discovered_tests["calculator.Calculator.add"][0].tests_in_file.test_function == "test_add" + calculator_test = next(iter(discovered_tests["calculator.Calculator.add"])) + assert calculator_test.tests_in_file.test_file == test_file_path + assert calculator_test.tests_in_file.test_function == "test_add" def test_unittest_discovery_with_pytest_private(): @@ -654,7 +651,7 @@ def _test_add(self): # Private test method should not be discovered ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Verify no tests were discovered assert len(discovered_tests) == 0 @@ -706,15 +703,16 @@ def test_add_with_parameters(self): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Verify the unittest was discovered assert len(discovered_tests) == 1 assert "calculator.Calculator.add" in discovered_tests assert len(discovered_tests["calculator.Calculator.add"]) == 1 - assert discovered_tests["calculator.Calculator.add"][0].tests_in_file.test_file == test_file_path + calculator_test = next(iter(discovered_tests["calculator.Calculator.add"])) + assert calculator_test.tests_in_file.test_file == test_file_path assert ( - discovered_tests["calculator.Calculator.add"][0].tests_in_file.test_function == "test_add_with_parameters" + calculator_test.tests_in_file.test_function == "test_add_with_parameters" ) @@ -788,7 +786,7 @@ def test_add_mixed(self, name, a, b, expected): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Verify the basic structure assert len(discovered_tests) == 2 # Should have tests for both add and multiply @@ -1069,7 +1067,7 @@ def test_other(): ) # Test without filtering - all_tests = discover_unit_tests(test_config) + all_tests, _ = discover_unit_tests(test_config) assert len(all_tests) == 2 # Should find both functions # Test with filtering - create mock FunctionToOptimize objects @@ -1079,13 +1077,9 @@ def test_other(): mock_function.function_name = "target_function" mock_function.parents = [] # No parent classes - filtered_tests = discover_unit_tests(test_config, functions_to_optimize=[mock_function]) - # The import filter is designed for high recall, so it may include both functions - # because both test files import from the same module (mycode) that contains target_function - assert len(filtered_tests) >= 1 # Should find at least target_function + filtered_tests, _ = discover_unit_tests(test_config, file_to_funcs_to_optimize={code_file: [mock_function]}) + assert len(filtered_tests) >= 1 assert "mycode.target_function" in filtered_tests - # In a perfect world we'd filter out other_function, but conservative filtering - # is acceptable for performance optimization purposes def test_analyze_imports_conditional_import(): @@ -1216,7 +1210,7 @@ def test_unrelated(): ) # Test without filtering - all_tests = discover_unit_tests(test_config) + all_tests, _ = discover_unit_tests(test_config) assert len(all_tests) == 2 # Should find both functions # Test with filtering - create mock FunctionToOptimize objects @@ -1225,8 +1219,8 @@ def test_unrelated(): mock_function.qualified_name_with_modules_from_root.return_value = "target_module.target_function" mock_function.function_name = "target_function" mock_function.parents = [] # No parent classes - - filtered_tests = discover_unit_tests(test_config, functions_to_optimize=[mock_function]) + + filtered_tests, _ = discover_unit_tests(test_config, file_to_funcs_to_optimize={target_file: [mock_function]}) # Should filter out the unrelated test since it imports from a different module assert len(filtered_tests) == 1 assert "target_module.target_function" in filtered_tests From 337b0ee5de578913f1c3a0f36681ea518d040ea1 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sat, 7 Jun 2025 16:22:08 -0700 Subject: [PATCH 3/6] use qualified name from functiontooptimize object --- codeflash/discovery/discover_unit_tests.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index e2ccd3be4..bff5a77f5 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -475,18 +475,10 @@ def process_test_files( project_root_path = cfg.project_root_path test_framework = cfg.test_framework - # Apply import filter if functions to optimize are provided if functions_to_optimize: - # Extract target function names from FunctionToOptimize objects - # Include both qualified names and simple function names for better matching target_function_names = set() for func in functions_to_optimize: - target_function_names.add(func.qualified_name_with_modules_from_root(project_root_path)) - target_function_names.add(func.function_name) # Add simple name too - # Also add qualified name without module - if func.parents: - target_function_names.add(f"{func.parents[0].name}.{func.function_name}") - + target_function_names.add(func.qualified_name) logger.debug(f"Target functions for import filtering: {target_function_names}") file_to_test_map, import_results = filter_test_files_by_imports(file_to_test_map, target_function_names) logger.debug(f"Import analysis results: {len(import_results)} files analyzed") From 00868002fd2586d0e8c7e295d49fd5b1f7a30fa2 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sat, 7 Jun 2025 18:18:06 -0700 Subject: [PATCH 4/6] handle qualified names --- codeflash/discovery/discover_unit_tests.py | 53 +++++++--------------- tests/test_unit_test_discovery.py | 53 ++++++++++------------ 2 files changed, 39 insertions(+), 67 deletions(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index bff5a77f5..0ec07ab56 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -148,6 +148,7 @@ def __init__(self, function_names_to_find: set[str]) -> None: self.imported_names: set[str] = set() self.imported_modules: set[str] = set() self.found_target_functions: set[str] = set() + self.qualified_names_called: set[str] = set() def visit_Import(self, node: ast.Import) -> None: """Handle 'import module' statements.""" @@ -164,13 +165,16 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: for alias in node.names: if alias.name == "*": - # Star imports - we can't know what's imported, so be conservative - self.imported_names.add("*") - else: - imported_name = alias.asname if alias.asname else alias.name - self.imported_names.add(imported_name) - if alias.name in self.function_names_to_find: - self.found_target_functions.add(alias.name) + continue + imported_name = alias.asname if alias.asname else alias.name + self.imported_names.add(imported_name) + if alias.name in self.function_names_to_find: + self.found_target_functions.add(alias.name) + # Check for qualified name matches + if node.module: + qualified_name = f"{node.module}.{alias.name}" + if qualified_name in self.function_names_to_find: + self.found_target_functions.add(qualified_name) self.generic_visit(node) def visit_Call(self, node: ast.Call) -> None: @@ -207,6 +211,9 @@ def visit_Attribute(self, node: ast.Attribute) -> None: """Handle module.function_name patterns.""" if node.attr in self.function_names_to_find: self.found_target_functions.add(node.attr) + if isinstance(node.value, ast.Name): + qualified_name = f"{node.value.id}.{node.attr}" + self.qualified_names_called.add(qualified_name) self.generic_visit(node) @@ -232,37 +239,10 @@ def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: s analyzer = ImportAnalyzer(target_functions) analyzer.visit(tree) - # If we found direct function matches, definitely process if analyzer.found_target_functions: return True, analyzer.found_target_functions - # If there are star imports, we need to be conservative - if "*" in analyzer.imported_names: - return True, set() - - # Check for direct name matches first (higher priority) - name_matches = analyzer.imported_names & target_functions - if name_matches: - return True, name_matches - - # If no direct matches, check if any imported modules could contain our target functions - # This is a heuristic - we look for common patterns - potential_matches = set() - for module in analyzer.imported_modules: - # Check if module name suggests it could contain target functions - for func_name in target_functions: - # Only match if the module name is a prefix of the function qualified name - func_parts = func_name.split(".") - if len(func_parts) > 1 and module == func_parts[0]: - # Module matches the first part of qualified name (e.g., mycode in mycode.target_function) - # But only if we don't have specific import information suggesting otherwise - potential_matches.add(func_name) - elif any(part in module for part in func_name.split("_")) and len(func_name.split("_")) > 1: - # Function name parts match module name (for underscore-separated names) - potential_matches.add(func_name) - - # Only use heuristic matches if we haven't found specific function imports that contradict them - return bool(potential_matches), potential_matches + return False, set() # noqa: TRY300 except (SyntaxError, UnicodeDecodeError, OSError) as e: logger.debug(f"Failed to analyze imports in {test_file_path}: {e}") @@ -283,7 +263,6 @@ def filter_test_files_by_imports( """ if not target_functions: - # If no target functions specified, process all files return file_to_test_map, {} filtered_map = {} @@ -479,7 +458,7 @@ def process_test_files( target_function_names = set() for func in functions_to_optimize: target_function_names.add(func.qualified_name) - logger.debug(f"Target functions for import filtering: {target_function_names}") + logger.info(f"Target functions for import filtering: {target_function_names}") file_to_test_map, import_results = filter_test_files_by_imports(file_to_test_map, target_function_names) logger.debug(f"Import analysis results: {len(import_results)} files analyzed") diff --git a/tests/test_unit_test_discovery.py b/tests/test_unit_test_discovery.py index da437535a..4f230eeea 100644 --- a/tests/test_unit_test_discovery.py +++ b/tests/test_unit_test_discovery.py @@ -9,6 +9,7 @@ ) from codeflash.models.models import TestsInFile, TestType from codeflash.verification.verification_utils import TestConfig +from codeflash.discovery.functions_to_optimize import FunctionToOptimize def test_unit_test_discovery_pytest(): @@ -832,8 +833,8 @@ def test_something(): target_functions = {"target_function"} should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) - assert should_process is True # Conservative approach with star imports - assert found_functions == set() # No specific functions identified + assert should_process is False + assert found_functions == set() def test_analyze_imports_module_import(): @@ -907,13 +908,11 @@ def test_unrelated(): target_functions = {"target_function", "another_function"} should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) - assert should_process is False assert found_functions == set() -def test_analyze_imports_heuristic_matching(): - """Test heuristic module name matching.""" +def test_analyze_qualified_names(): with tempfile.TemporaryDirectory() as tmpdirname: test_file = Path(tmpdirname) / "test_example.py" test_content = """ @@ -924,11 +923,11 @@ def test_target(): """ test_file.write_text(test_content) - target_functions = {"target_function"} # Function name partially matches module name + target_functions = {"target_module.some_function"} should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) - assert should_process is True - assert "target_function" in found_functions + assert "target_module.some_function" in found_functions + def test_analyze_imports_syntax_error(): @@ -952,7 +951,6 @@ def test_target( def test_filter_test_files_by_imports(): - """Test the complete filtering functionality.""" with tempfile.TemporaryDirectory() as tmpdirname: tmpdir = Path(tmpdirname) @@ -974,7 +972,7 @@ def test_other(): assert other_function() is True """) - # Create test file with star import (should be processed) + # Create test file with star import (should not be processed) star_test = tmpdir / "test_star.py" star_test.write_text(""" from mymodule import * @@ -983,7 +981,6 @@ def test_star(): assert something() is True """) - # Build file_to_test_map file_to_test_map = { relevant_test: [TestsInFile(test_file=relevant_test, test_function="test_target", test_class=None, test_type=TestType.EXISTING_UNIT_TEST)], irrelevant_test: [TestsInFile(test_file=irrelevant_test, test_function="test_other", test_class=None, test_type=TestType.EXISTING_UNIT_TEST)], @@ -993,16 +990,15 @@ def test_star(): target_functions = {"target_function"} filtered_map, import_results = filter_test_files_by_imports(file_to_test_map, target_functions) - # Should filter out irrelevant_test but keep relevant_test and star_test - assert len(filtered_map) == 2 + # Should filter out irrelevant_test + assert len(filtered_map) == 1 assert relevant_test in filtered_map - assert star_test in filtered_map assert irrelevant_test not in filtered_map # Check import analysis results assert "target_function" in import_results[relevant_test] assert len(import_results[irrelevant_test]) == 0 - assert len(import_results[star_test]) == 0 # Star import doesn't identify specific functions + assert len(import_results[star_test]) == 0 def test_filter_test_files_no_target_functions(): @@ -1066,18 +1062,17 @@ def test_other(): tests_project_rootdir=tmpdir.parent, ) - # Test without filtering all_tests, _ = discover_unit_tests(test_config) - assert len(all_tests) == 2 # Should find both functions - - # Test with filtering - create mock FunctionToOptimize objects - from unittest.mock import Mock - mock_function = Mock() - mock_function.qualified_name_with_modules_from_root.return_value = "mycode.target_function" - mock_function.function_name = "target_function" - mock_function.parents = [] # No parent classes + assert len(all_tests) == 2 - filtered_tests, _ = discover_unit_tests(test_config, file_to_funcs_to_optimize={code_file: [mock_function]}) + + fto = FunctionToOptimize( + function_name="target_function", + file_path=code_file, + parents=[], + ) + + filtered_tests, _ = discover_unit_tests(test_config, file_to_funcs_to_optimize={code_file: [fto]}) assert len(filtered_tests) >= 1 assert "mycode.target_function" in filtered_tests @@ -1146,7 +1141,6 @@ def test_aliased(): def test_analyze_imports_underscore_function_names(): - """Test handling of function names with underscores in heuristic matching.""" with tempfile.TemporaryDirectory() as tmpdirname: test_file = Path(tmpdirname) / "test_example.py" test_content = """ @@ -1157,12 +1151,11 @@ def test_bubble(): """ test_file.write_text(test_content) - target_functions = {"bubble_sort"} # Function name parts match module + target_functions = {"bubble_sort"} should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) - assert should_process is True - assert "bubble_sort" in found_functions - + assert should_process is False + assert "bubble_sort" not in found_functions def test_discover_unit_tests_filtering_different_modules(): """Test import filtering with test files from completely different modules.""" From 444e4d418687bed1face9e6f6384d325ef0ab617 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sat, 7 Jun 2025 18:25:10 -0700 Subject: [PATCH 5/6] don't mock --- tests/test_unit_test_discovery.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/test_unit_test_discovery.py b/tests/test_unit_test_discovery.py index 4f230eeea..4465acf79 100644 --- a/tests/test_unit_test_discovery.py +++ b/tests/test_unit_test_discovery.py @@ -1206,15 +1206,13 @@ def test_unrelated(): all_tests, _ = discover_unit_tests(test_config) assert len(all_tests) == 2 # Should find both functions - # Test with filtering - create mock FunctionToOptimize objects - from unittest.mock import Mock - mock_function = Mock() - mock_function.qualified_name_with_modules_from_root.return_value = "target_module.target_function" - mock_function.function_name = "target_function" - mock_function.parents = [] # No parent classes - - filtered_tests, _ = discover_unit_tests(test_config, file_to_funcs_to_optimize={target_file: [mock_function]}) - # Should filter out the unrelated test since it imports from a different module + fto = FunctionToOptimize( + function_name="target_function", + file_path=target_file, + parents=[], + ) + + filtered_tests, _ = discover_unit_tests(test_config, file_to_funcs_to_optimize={target_file: [fto]}) assert len(filtered_tests) == 1 assert "target_module.target_function" in filtered_tests assert "unrelated_module.unrelated_function" not in filtered_tests From e6b2f895e1d368324f8079bb744e3c83bbb71076 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sat, 7 Jun 2025 18:26:56 -0700 Subject: [PATCH 6/6] leftover info --- codeflash/discovery/discover_unit_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 0ec07ab56..7a5bf6f9d 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -458,7 +458,7 @@ def process_test_files( target_function_names = set() for func in functions_to_optimize: target_function_names.add(func.qualified_name) - logger.info(f"Target functions for import filtering: {target_function_names}") + logger.debug(f"Target functions for import filtering: {target_function_names}") file_to_test_map, import_results = filter_test_files_by_imports(file_to_test_map, target_function_names) logger.debug(f"Import analysis results: {len(import_results)} files analyzed")