From 3dd705b438b28ea94a325d1bc3905112a94d8703 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Mon, 2 Jun 2025 19:22:19 -0700 Subject: [PATCH 01/21] rename conftest --- codeflash/code_utils/code_utils.py | 15 ++++++++++++++- codeflash/code_utils/config_parser.py | 13 ++++++++++++- codeflash/discovery/discover_unit_tests.py | 5 +++-- codeflash/verification/test_runner.py | 4 ++-- 4 files changed, 31 insertions(+), 6 deletions(-) diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 4c167fb69..65a6c2db6 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -13,7 +13,7 @@ import tomlkit from codeflash.cli_cmds.console import logger -from codeflash.code_utils.config_parser import find_pyproject_toml +from codeflash.code_utils.config_parser import find_pyproject_toml, find_conftest @contextmanager @@ -84,6 +84,19 @@ def add_addopts_to_pyproject() -> None: with Path.open(pyproject_file, "w", encoding="utf-8") as f: f.write(original_content) +@contextmanager +def rename_conftest() -> None: + conftest_file = find_conftest() + tmp_conftest_file = Path(conftest_file + ".tmp") + try: + # Rename original file + if conftest_file.exists(): + conftest_file.rename(tmp_conftest_file) + yield + finally: + # Restore original file + if conftest_file.exists(): + tmp_conftest_file.rename(conftest_file) def encoded_tokens_len(s: str) -> int: """Return the approximate length of the encoded tokens. diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index 7b6243a75..0dfc601b2 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any +from typing import Any, Union import tomlkit @@ -30,6 +30,17 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path: raise ValueError(msg) +def find_conftest() -> Union[Path, None]: + # Find the conftest file on the root of the project + dir_path = Path.cwd() + while dir_path != dir_path.parent: + config_file = dir_path / "conftest.py" + if config_file.exists(): + return config_file + # Search for conftest.py in the parent directories + dir_path = dir_path.parent + return None + def parse_config_file( config_file_path: Path | None = None, diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index b76e63a91..2eeabe2d9 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -16,7 +16,8 @@ from pydantic.dataclasses import dataclass from codeflash.cli_cmds.console import console, logger, test_files_progress_bar -from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file, module_name_from_file_path +from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file, module_name_from_file_path, \ + rename_conftest from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType @@ -149,7 +150,7 @@ def discover_tests_pytest( project_root = cfg.project_root_path tmp_pickle_path = get_run_tmp_file("collected_tests.pkl") - with custom_addopts(): + with custom_addopts(), rename_conftest(): result = subprocess.run( [ SAFE_SYS_EXECUTABLE, diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py index f0d1c01a5..042441e24 100644 --- a/codeflash/verification/test_runner.py +++ b/codeflash/verification/test_runner.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING from codeflash.cli_cmds.console import logger -from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file +from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file, rename_conftest from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME from codeflash.code_utils.coverage_utils import prepare_coverage_files @@ -23,7 +23,7 @@ def execute_test_subprocess( cmd_list: list[str], cwd: Path, env: dict[str, str] | None, timeout: int = 600 ) -> subprocess.CompletedProcess: """Execute a subprocess with the given command list, working directory, environment variables, and timeout.""" - with custom_addopts(): + with custom_addopts(), rename_conftest(): logger.debug(f"executing test run with command: {' '.join(cmd_list)}") return subprocess.run(cmd_list, capture_output=True, cwd=cwd, env=env, text=True, timeout=timeout, check=False) From 7137ac63c1b3feb02a0701e43a2a66abb1618355 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Mon, 2 Jun 2025 19:22:28 -0700 Subject: [PATCH 02/21] rename conftest --- codeflash/code_utils/code_utils.py | 4 +++- codeflash/code_utils/config_parser.py | 1 + codeflash/discovery/discover_unit_tests.py | 8 ++++++-- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 65a6c2db6..97444f066 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -13,7 +13,7 @@ import tomlkit from codeflash.cli_cmds.console import logger -from codeflash.code_utils.config_parser import find_pyproject_toml, find_conftest +from codeflash.code_utils.config_parser import find_conftest, find_pyproject_toml @contextmanager @@ -84,6 +84,7 @@ def add_addopts_to_pyproject() -> None: with Path.open(pyproject_file, "w", encoding="utf-8") as f: f.write(original_content) + @contextmanager def rename_conftest() -> None: conftest_file = find_conftest() @@ -98,6 +99,7 @@ def rename_conftest() -> None: if conftest_file.exists(): tmp_conftest_file.rename(conftest_file) + def encoded_tokens_len(s: str) -> int: """Return the approximate length of the encoded tokens. diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index 0dfc601b2..ea9917e31 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -30,6 +30,7 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path: raise ValueError(msg) + def find_conftest() -> Union[Path, None]: # Find the conftest file on the root of the project dir_path = Path.cwd() diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 2eeabe2d9..0fad9eac4 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -16,8 +16,12 @@ from pydantic.dataclasses import dataclass from codeflash.cli_cmds.console import console, logger, test_files_progress_bar -from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file, module_name_from_file_path, \ - rename_conftest +from codeflash.code_utils.code_utils import ( + custom_addopts, + get_run_tmp_file, + module_name_from_file_path, + rename_conftest, +) from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType From e433989620ac0d8a0c5390491a1afdba1625022d Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Mon, 2 Jun 2025 20:05:22 -0700 Subject: [PATCH 03/21] bugfix --- codeflash/code_utils/code_utils.py | 11 ++++++----- codeflash/code_utils/config_parser.py | 9 +++++---- codeflash/discovery/discover_unit_tests.py | 2 +- codeflash/verification/test_runner.py | 4 ++-- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index ebd8d1462..66b501ac6 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -88,17 +88,18 @@ def add_addopts_to_pyproject() -> None: @contextmanager -def rename_conftest() -> None: - conftest_file = find_conftest() - tmp_conftest_file = Path(conftest_file + ".tmp") +def rename_conftest(tests_path: Path) -> None: + conftest_file = find_conftest(tests_path) + tmp_conftest_file = None try: # Rename original file - if conftest_file.exists(): + if conftest_file: + tmp_conftest_file = Path(str(conftest_file) + ".tmp") conftest_file.rename(tmp_conftest_file) yield finally: # Restore original file - if conftest_file.exists(): + if conftest_file: tmp_conftest_file.rename(conftest_file) diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index ea9917e31..74b77f04c 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -31,15 +31,16 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path: raise ValueError(msg) -def find_conftest() -> Union[Path, None]: +def find_conftest(tests_path: Path) -> Union[Path, None]: # Find the conftest file on the root of the project dir_path = Path.cwd() - while dir_path != dir_path.parent: - config_file = dir_path / "conftest.py" + cur_path = tests_path + while cur_path != dir_path: + config_file = cur_path / "conftest.py" if config_file.exists(): return config_file # Search for conftest.py in the parent directories - dir_path = dir_path.parent + cur_path = cur_path.parent return None diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 69cd6038f..58fe7abd0 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -157,7 +157,7 @@ def discover_tests_pytest( project_root = cfg.project_root_path tmp_pickle_path = get_run_tmp_file("collected_tests.pkl") - with custom_addopts(), rename_conftest(): + with custom_addopts(), rename_conftest(tests_root): result = subprocess.run( [ SAFE_SYS_EXECUTABLE, diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py index 042441e24..f0d1c01a5 100644 --- a/codeflash/verification/test_runner.py +++ b/codeflash/verification/test_runner.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING from codeflash.cli_cmds.console import logger -from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file, rename_conftest +from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME from codeflash.code_utils.coverage_utils import prepare_coverage_files @@ -23,7 +23,7 @@ def execute_test_subprocess( cmd_list: list[str], cwd: Path, env: dict[str, str] | None, timeout: int = 600 ) -> subprocess.CompletedProcess: """Execute a subprocess with the given command list, working directory, environment variables, and timeout.""" - with custom_addopts(), rename_conftest(): + with custom_addopts(): logger.debug(f"executing test run with command: {' '.join(cmd_list)}") return subprocess.run(cmd_list, capture_output=True, cwd=cwd, env=env, text=True, timeout=timeout, check=False) From 1271f2622cc01a85df7935d67851dc711514f5e2 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Tue, 3 Jun 2025 15:45:30 -0700 Subject: [PATCH 04/21] wip --- code_to_optimize/tests/pytest/test_bubble_sort.py | 3 ++- codeflash/verification/test_runner.py | 1 + pyproject.toml | 4 ++++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/code_to_optimize/tests/pytest/test_bubble_sort.py b/code_to_optimize/tests/pytest/test_bubble_sort.py index eccad6e09..cd775d2c3 100644 --- a/code_to_optimize/tests/pytest/test_bubble_sort.py +++ b/code_to_optimize/tests/pytest/test_bubble_sort.py @@ -1,6 +1,7 @@ from code_to_optimize.bubble_sort import sorter +import pytest - +@pytest.mark.no_autouse def test_sort(): input = [5, 4, 3, 2, 1, 0] output = sorter(input) diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py index f0d1c01a5..072c1023d 100644 --- a/codeflash/verification/test_runner.py +++ b/codeflash/verification/test_runner.py @@ -53,6 +53,7 @@ def run_behavioral_tests( ) else: test_files.append(str(file.instrumented_behavior_file_path)) + test_files = ['/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/test_bubble_sort__perfinstrumented.py'] pytest_cmd_list = ( shlex.split(f"{SAFE_SYS_EXECUTABLE} -m pytest", posix=IS_POSIX) if pytest_cmd == "pytest" diff --git a/pyproject.toml b/pyproject.toml index c3e48f889..b5f774b5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -238,6 +238,10 @@ formatter-cmds = [ "uvx ruff check --exit-zero --fix $file", "uvx ruff format $file", ] +[tool.pytest.ini_options] +markers = [ + "no_autouse" +] [build-system] From fee9f114c863e0709b9382e9307567309bd23777 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Tue, 3 Jun 2025 16:47:20 -0700 Subject: [PATCH 05/21] wip --- code_to_optimize/tests/pytest/conftest.py.tmp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 code_to_optimize/tests/pytest/conftest.py.tmp diff --git a/code_to_optimize/tests/pytest/conftest.py.tmp b/code_to_optimize/tests/pytest/conftest.py.tmp new file mode 100644 index 000000000..0efaf5189 --- /dev/null +++ b/code_to_optimize/tests/pytest/conftest.py.tmp @@ -0,0 +1,14 @@ +import pytest +import time + + +@pytest.fixture(autouse=True) +def fixture(request): + if request.node.get_closest_marker("no_autouse"): + # Skip the fixture logic + yield + else: + start_time = time.time() + time.sleep(0.1) + yield + print(f"Took {time.time() - start_time} seconds") \ No newline at end of file From 81ce1fed2aff1dd385791c12fdb19073eb2458c6 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Wed, 4 Jun 2025 18:33:59 -0700 Subject: [PATCH 06/21] wip --- .../benchmarks/test_benchmark_bubble_sort.py | 19 +++++++++ code_to_optimize/tests/pytest/conftest.py | 26 ++++++++++++ code_to_optimize/tests/pytest/conftest.py.tmp | 14 ------- codeflash/code_utils/code_replacer.py | 7 ++++ codeflash/code_utils/code_utils.py | 42 +++++++++++++++++++ pyproject.toml | 2 + 6 files changed, 96 insertions(+), 14 deletions(-) create mode 100644 code_to_optimize/tests/pytest/conftest.py delete mode 100644 code_to_optimize/tests/pytest/conftest.py.tmp diff --git a/code_to_optimize/tests/pytest/benchmarks/test_benchmark_bubble_sort.py b/code_to_optimize/tests/pytest/benchmarks/test_benchmark_bubble_sort.py index 3d7b24a6c..ce7839119 100644 --- a/code_to_optimize/tests/pytest/benchmarks/test_benchmark_bubble_sort.py +++ b/code_to_optimize/tests/pytest/benchmarks/test_benchmark_bubble_sort.py @@ -3,10 +3,29 @@ from code_to_optimize.bubble_sort import sorter +class DummyBenchmark: + def __call__(self, func, *args, **kwargs): + # Mimic calling benchmark(func, *args, **kwargs) + return func(*args, **kwargs) + + def __getattr__(self, name): + # Mimic benchmark attributes like .pedantic, .extra_info etc. + def dummy(*args, **kwargs): + return None + + return dummy + + +@pytest.fixture +def benchmark(): + return DummyBenchmark() + + def test_sort(benchmark): result = benchmark(sorter, list(reversed(range(500)))) assert result == list(range(500)) + # This should not be picked up as a benchmark test def test_sort2(): result = sorter(list(reversed(range(500)))) diff --git a/code_to_optimize/tests/pytest/conftest.py b/code_to_optimize/tests/pytest/conftest.py new file mode 100644 index 000000000..aff5ceecd --- /dev/null +++ b/code_to_optimize/tests/pytest/conftest.py @@ -0,0 +1,26 @@ +import pytest +import time + +# @pytest.fixture(autouse=True) +# def fixture(request): +# if request.node.get_closest_marker("no_autouse"): +# # Skip the fixture logic +# yield +# else: +# start_time = time.time() +# time.sleep(0.1) +# yield +# print(f"Took {time.time() - start_time} seconds") + + +@pytest.fixture(autouse=True) +def fixture1(request): # We don't need this fixture during testing + start_time = time.time() + time.sleep(0.1) + yield + print(f"Took {time.time() - start_time} seconds") + + +@pytest.fixture(autouse=True) +def fixture2(request): # We need it + yield diff --git a/code_to_optimize/tests/pytest/conftest.py.tmp b/code_to_optimize/tests/pytest/conftest.py.tmp deleted file mode 100644 index 0efaf5189..000000000 --- a/code_to_optimize/tests/pytest/conftest.py.tmp +++ /dev/null @@ -1,14 +0,0 @@ -import pytest -import time - - -@pytest.fixture(autouse=True) -def fixture(request): - if request.node.get_closest_marker("no_autouse"): - # Skip the fixture logic - yield - else: - start_time = time.time() - time.sleep(0.1) - yield - print(f"Took {time.time() - start_time} seconds") \ No newline at end of file diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index eb367bdfa..cb5632fcd 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -33,6 +33,13 @@ def normalize_code(code: str) -> str: return ast.unparse(normalize_node(ast.parse(code))) +def modify_autouse_fixture(): + # find fixutre definition in conftetst.py (the one closest to the test) + # get fixtures present in override-fixtures in pyproject.toml + # add if marker closest return + autousetransformer + + class OptimFunctionCollector(cst.CSTVisitor): METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,) diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 66b501ac6..43f7a1879 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -87,6 +87,48 @@ def add_addopts_to_pyproject() -> None: f.write(original_content) +@contextmanager +def add_override_fixtures_to_pyproject() -> None: + pyproject_file = find_pyproject_toml() + try: + # Read original file + if pyproject_file.exists(): + with Path.open(pyproject_file, encoding="utf-8") as f: + original_content = f.read() + data = tomlkit.parse(original_content) + # Backup original markers + original_fixtures = data.get("tool", {}).get("codeflash", {}).get("override-fixtures", []) + original_fixtures.append("please_put_your_fixtures_here") + data["tool"]["pytest"]["override-fixtures"]["markers"] = list(original_fixtures) + with Path.open(pyproject_file, "w", encoding="utf-8") as f: + f.write(tomlkit.dumps(data)) + yield + finally: + with Path.open(pyproject_file, "w", encoding="utf-8") as f: + f.write(original_content) + + +@contextmanager +def add_custom_markers_to_pyproject() -> None: + pyproject_file = find_pyproject_toml() + try: + # Read original file + if pyproject_file.exists(): + with Path.open(pyproject_file, encoding="utf-8") as f: + original_content = f.read() + data = tomlkit.parse(original_content) + # Backup original markers + original_markers = data.get("tool", {}).get("pytest", {}).get("ini_options", {}).get("markers", []) + original_markers.append("codeflash_no_autouse") + data["tool"]["pytest"]["ini_options"]["markers"] = list(original_markers) + with Path.open(pyproject_file, "w", encoding="utf-8") as f: + f.write(tomlkit.dumps(data)) + yield + finally: + with Path.open(pyproject_file, "w", encoding="utf-8") as f: + f.write(original_content) + + @contextmanager def rename_conftest(tests_path: Path) -> None: conftest_file = find_conftest(tests_path) diff --git a/pyproject.toml b/pyproject.toml index b5f774b5a..169f7d298 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -238,6 +238,8 @@ formatter-cmds = [ "uvx ruff check --exit-zero --fix $file", "uvx ruff format $file", ] +override-fixtures = ["fixture1"] #autouse fixtures present in conftest.py which have to be disabled during test execution + [tool.pytest.ini_options] markers = [ "no_autouse" From ed952d7506ea3370c54eaedc02d4a3fa88df8409 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Thu, 5 Jun 2025 13:31:34 -0700 Subject: [PATCH 07/21] wip --- codeflash/cli_cmds/cmd_init.py | 9 +++- codeflash/code_utils/code_replacer.py | 9 +++- codeflash/code_utils/code_utils.py | 60 +--------------------- codeflash/code_utils/config_parser.py | 9 ++-- codeflash/discovery/discover_unit_tests.py | 3 +- codeflash/verification/test_runner.py | 1 - pyproject.toml | 6 --- 7 files changed, 22 insertions(+), 75 deletions(-) diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index bfe600fa4..965188637 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -34,7 +34,7 @@ from argparse import Namespace CODEFLASH_LOGO: str = ( - f"{LF}" # noqa: ISC003 + f"{LF}" r" _ ___ _ _ " + f"{LF}" r" | | / __)| | | | " + f"{LF}" r" ____ ___ _ | | ____ | |__ | | ____ ___ | | _ " + f"{LF}" @@ -723,11 +723,16 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None: formatter_cmds.append("disabled") check_formatter_installed(formatter_cmds, exit_on_failure=False) codeflash_section["formatter-cmds"] = formatter_cmds + codeflash_section["override-fixtures"] = False # don't override fixtures by default, let the user decide # Add the 'codeflash' section, ensuring 'tool' section exists tool_section = pyproject_data.get("tool", tomlkit.table()) tool_section["codeflash"] = codeflash_section pyproject_data["tool"] = tool_section - + if "tool.pytest.ini_options" not in pyproject_data: + pyproject_data["tool.pytest.ini_options"] = {} + if "markers" not in pyproject_data["tool.pytest.ini_options"]: + pyproject_data["tool.pytest.ini_options"]["markers"] = [] + pyproject_data["tool.pytest.ini_options"]["markers"].append("codeflash_no_autouse") with toml_path.open("w", encoding="utf8") as pyproject_file: pyproject_file.write(tomlkit.dumps(pyproject_data)) click.echo(f"✅ Added Codeflash configuration to {toml_path}") diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index cb5632fcd..602d02c71 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -9,6 +9,7 @@ from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module +from codeflash.code_utils.config_parser import find_conftest_files from codeflash.models.models import FunctionParent if TYPE_CHECKING: @@ -37,7 +38,13 @@ def modify_autouse_fixture(): # find fixutre definition in conftetst.py (the one closest to the test) # get fixtures present in override-fixtures in pyproject.toml # add if marker closest return - autousetransformer + conftest_files = find_conftest_files() + for cf_file in conftest_files: + #iterate over all functions in the file + # if function has autouse fixture, modify function to bypass with custom marker + pass + +#reuse line profiler utils to add decorator and import to test fns class OptimFunctionCollector(cst.CSTVisitor): diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 43f7a1879..5fc9bd9e9 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -13,7 +13,7 @@ import tomlkit from codeflash.cli_cmds.console import logger -from codeflash.code_utils.config_parser import find_conftest, find_pyproject_toml +from codeflash.code_utils.config_parser import find_pyproject_toml ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE) @@ -87,64 +87,6 @@ def add_addopts_to_pyproject() -> None: f.write(original_content) -@contextmanager -def add_override_fixtures_to_pyproject() -> None: - pyproject_file = find_pyproject_toml() - try: - # Read original file - if pyproject_file.exists(): - with Path.open(pyproject_file, encoding="utf-8") as f: - original_content = f.read() - data = tomlkit.parse(original_content) - # Backup original markers - original_fixtures = data.get("tool", {}).get("codeflash", {}).get("override-fixtures", []) - original_fixtures.append("please_put_your_fixtures_here") - data["tool"]["pytest"]["override-fixtures"]["markers"] = list(original_fixtures) - with Path.open(pyproject_file, "w", encoding="utf-8") as f: - f.write(tomlkit.dumps(data)) - yield - finally: - with Path.open(pyproject_file, "w", encoding="utf-8") as f: - f.write(original_content) - - -@contextmanager -def add_custom_markers_to_pyproject() -> None: - pyproject_file = find_pyproject_toml() - try: - # Read original file - if pyproject_file.exists(): - with Path.open(pyproject_file, encoding="utf-8") as f: - original_content = f.read() - data = tomlkit.parse(original_content) - # Backup original markers - original_markers = data.get("tool", {}).get("pytest", {}).get("ini_options", {}).get("markers", []) - original_markers.append("codeflash_no_autouse") - data["tool"]["pytest"]["ini_options"]["markers"] = list(original_markers) - with Path.open(pyproject_file, "w", encoding="utf-8") as f: - f.write(tomlkit.dumps(data)) - yield - finally: - with Path.open(pyproject_file, "w", encoding="utf-8") as f: - f.write(original_content) - - -@contextmanager -def rename_conftest(tests_path: Path) -> None: - conftest_file = find_conftest(tests_path) - tmp_conftest_file = None - try: - # Rename original file - if conftest_file: - tmp_conftest_file = Path(str(conftest_file) + ".tmp") - conftest_file.rename(tmp_conftest_file) - yield - finally: - # Restore original file - if conftest_file: - tmp_conftest_file.rename(conftest_file) - - def encoded_tokens_len(s: str) -> int: """Return the approximate length of the encoded tokens. diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index 74b77f04c..41739a8f7 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any, Union +from typing import Any import tomlkit @@ -31,17 +31,18 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path: raise ValueError(msg) -def find_conftest(tests_path: Path) -> Union[Path, None]: +def find_conftest_files(tests_path: Path) -> list[Path]: # Find the conftest file on the root of the project dir_path = Path.cwd() cur_path = tests_path + list_of_conftest_files = [] while cur_path != dir_path: config_file = cur_path / "conftest.py" if config_file.exists(): - return config_file + list_of_conftest_files.append(config_file) # Search for conftest.py in the parent directories cur_path = cur_path.parent - return None + return list_of_conftest_files def parse_config_file( diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 58fe7abd0..1fd86acce 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -23,7 +23,6 @@ custom_addopts, get_run_tmp_file, module_name_from_file_path, - rename_conftest, ) from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType @@ -157,7 +156,7 @@ def discover_tests_pytest( project_root = cfg.project_root_path tmp_pickle_path = get_run_tmp_file("collected_tests.pkl") - with custom_addopts(), rename_conftest(tests_root): + with custom_addopts(): result = subprocess.run( [ SAFE_SYS_EXECUTABLE, diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py index 072c1023d..f0d1c01a5 100644 --- a/codeflash/verification/test_runner.py +++ b/codeflash/verification/test_runner.py @@ -53,7 +53,6 @@ def run_behavioral_tests( ) else: test_files.append(str(file.instrumented_behavior_file_path)) - test_files = ['/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/test_bubble_sort__perfinstrumented.py'] pytest_cmd_list = ( shlex.split(f"{SAFE_SYS_EXECUTABLE} -m pytest", posix=IS_POSIX) if pytest_cmd == "pytest" diff --git a/pyproject.toml b/pyproject.toml index 169f7d298..c3e48f889 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -238,12 +238,6 @@ formatter-cmds = [ "uvx ruff check --exit-zero --fix $file", "uvx ruff format $file", ] -override-fixtures = ["fixture1"] #autouse fixtures present in conftest.py which have to be disabled during test execution - -[tool.pytest.ini_options] -markers = [ - "no_autouse" -] [build-system] From bfc6bb418a916f95b3a64b54b02c5ce7b3665b4f Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Thu, 5 Jun 2025 13:31:40 -0700 Subject: [PATCH 08/21] wip --- codeflash/code_utils/code_replacer.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 602d02c71..ff0cfd91b 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -9,7 +9,6 @@ from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module -from codeflash.code_utils.config_parser import find_conftest_files from codeflash.models.models import FunctionParent if TYPE_CHECKING: @@ -34,17 +33,18 @@ def normalize_code(code: str) -> str: return ast.unparse(normalize_node(ast.parse(code))) -def modify_autouse_fixture(): - # find fixutre definition in conftetst.py (the one closest to the test) - # get fixtures present in override-fixtures in pyproject.toml - # add if marker closest return - conftest_files = find_conftest_files() - for cf_file in conftest_files: - #iterate over all functions in the file - # if function has autouse fixture, modify function to bypass with custom marker - pass +# def modify_autouse_fixture(): +# # find fixutre definition in conftetst.py (the one closest to the test) +# # get fixtures present in override-fixtures in pyproject.toml +# # add if marker closest return +# conftest_files = find_conftest_files() +# for cf_file in conftest_files: +# # iterate over all functions in the file +# # if function has autouse fixture, modify function to bypass with custom marker +# pass -#reuse line profiler utils to add decorator and import to test fns + +# reuse line profiler utils to add decorator and import to test fns class OptimFunctionCollector(cst.CSTVisitor): From 83f4354f099799b0816e9fc3daa403bc3410aab9 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Thu, 5 Jun 2025 14:30:12 -0700 Subject: [PATCH 09/21] wip --- codeflash/code_utils/config_parser.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index 41739a8f7..a0f1c7019 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -31,18 +31,19 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path: raise ValueError(msg) -def find_conftest_files(tests_path: Path) -> list[Path]: - # Find the conftest file on the root of the project - dir_path = Path.cwd() - cur_path = tests_path - list_of_conftest_files = [] - while cur_path != dir_path: - config_file = cur_path / "conftest.py" - if config_file.exists(): - list_of_conftest_files.append(config_file) - # Search for conftest.py in the parent directories - cur_path = cur_path.parent - return list_of_conftest_files +def find_conftest_files(test_paths: list[Path]) -> list[Path]: + list_of_conftest_files = set() + for test_path in test_paths: + # Find the conftest file on the root of the project + dir_path = Path.cwd() + cur_path = test_path + while cur_path != dir_path: + config_file = cur_path / "conftest.py" + if config_file.exists(): + list_of_conftest_files.add(config_file) + # Search for conftest.py in the parent directories + cur_path = cur_path.parent + return list(list_of_conftest_files) def parse_config_file( From c5c95332e2ff6e0cc9ed1fdbdca07a287b9d84c0 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Thu, 5 Jun 2025 16:38:32 -0700 Subject: [PATCH 10/21] some progress --- code_to_optimize/tests/pytest/conftest.py | 17 +--- .../tests/pytest/test_bubble_sort.py | 3 +- codeflash/cli_cmds/cmd_init.py | 2 +- codeflash/code_utils/code_extractor.py | 49 ++++++++++ codeflash/code_utils/code_replacer.py | 91 ++++++++++++++++--- codeflash/optimization/function_optimizer.py | 10 +- 6 files changed, 144 insertions(+), 28 deletions(-) diff --git a/code_to_optimize/tests/pytest/conftest.py b/code_to_optimize/tests/pytest/conftest.py index aff5ceecd..d382726e0 100644 --- a/code_to_optimize/tests/pytest/conftest.py +++ b/code_to_optimize/tests/pytest/conftest.py @@ -1,20 +1,9 @@ import pytest import time -# @pytest.fixture(autouse=True) -# def fixture(request): -# if request.node.get_closest_marker("no_autouse"): -# # Skip the fixture logic -# yield -# else: -# start_time = time.time() -# time.sleep(0.1) -# yield -# print(f"Took {time.time() - start_time} seconds") - @pytest.fixture(autouse=True) -def fixture1(request): # We don't need this fixture during testing +def fixture1(request): start_time = time.time() time.sleep(0.1) yield @@ -22,5 +11,7 @@ def fixture1(request): # We don't need this fixture during testing @pytest.fixture(autouse=True) -def fixture2(request): # We need it +def fixture2(request): # We don't need this fixture during testing + print("not doing anything") yield + print("did nothing") diff --git a/code_to_optimize/tests/pytest/test_bubble_sort.py b/code_to_optimize/tests/pytest/test_bubble_sort.py index cd775d2c3..eccad6e09 100644 --- a/code_to_optimize/tests/pytest/test_bubble_sort.py +++ b/code_to_optimize/tests/pytest/test_bubble_sort.py @@ -1,7 +1,6 @@ from code_to_optimize.bubble_sort import sorter -import pytest -@pytest.mark.no_autouse + def test_sort(): input = [5, 4, 3, 2, 1, 0] output = sorter(input) diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 965188637..32e826b81 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -34,7 +34,7 @@ from argparse import Namespace CODEFLASH_LOGO: str = ( - f"{LF}" + f"{LF}" # noqa : ISC003 r" _ ___ _ _ " + f"{LF}" r" | | / __)| | | | " + f"{LF}" r" ____ ___ _ | | ____ | |__ | | ____ ___ | | _ " + f"{LF}" diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 0dcc2357f..dfd0dbc08 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -5,9 +5,11 @@ from typing import TYPE_CHECKING, Optional import libcst as cst +from libcst import MetadataWrapper from libcst.codemod import CodemodContext from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor from libcst.helpers import calculate_module_and_package +from libcst.metadata import FullyQualifiedNameProvider from codeflash.cli_cmds.console import logger from codeflash.models.models import FunctionParent @@ -21,6 +23,53 @@ from codeflash.models.models import FunctionSource +class FunctionNameCollector(cst.CSTVisitor): + """A LibCST visitor that collects the fully qualified names of all functions.""" + + METADATA_DEPENDENCIES = (FullyQualifiedNameProvider,) + + def __init__(self) -> None: + super().__init__() + self.qualified_names: set[str] = set() + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + """Visits a function definition node and extracts its qualified name.""" + try: + q_names = self.get_metadata(FullyQualifiedNameProvider, node) + for q_name in q_names: + self.qualified_names.add(q_name.name) + except KeyError: + # This can happen for functions defined in scopes where a qualified + # name cannot be determined. + pass + + +def get_function_qualified_names(file_path: Path) -> list[str]: + """Parse a Python file and returns a list of fully qualified function names. + + Args: + file_path: The path to the Python file. + + Returns: + A list of string representations of the qualified function names. + + """ + with file_path.open("r") as f: + source_code = f.read() + + # Parse the source code into a CST + module = cst.parse_module(source_code) + + # Wrap the module with a metadata wrapper to enable name resolution + wrapper = MetadataWrapper(module) + + # Create an instance of the visitor and visit the wrapped module + visitor = FunctionNameCollector() + wrapper.visit(visitor) + + return list(visitor.qualified_names) + + class GlobalAssignmentCollector(cst.CSTVisitor): """Collects all global assignment statements.""" diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index ff0cfd91b..982a21f80 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -1,14 +1,23 @@ from __future__ import annotations import ast +import contextlib from collections import defaultdict from functools import lru_cache from typing import TYPE_CHECKING, Optional, TypeVar +import isort import libcst as cst +import libcst.matchers as m from codeflash.cli_cmds.console import logger -from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module +from codeflash.code_utils.code_extractor import ( + add_global_assignments, + add_needed_imports_from_module, + get_function_qualified_names, +) +from codeflash.code_utils.config_parser import find_conftest_files +from codeflash.code_utils.line_profile_utils import ImportAdder, add_decorator_to_qualified_function from codeflash.models.models import FunctionParent if TYPE_CHECKING: @@ -33,18 +42,78 @@ def normalize_code(code: str) -> str: return ast.unparse(normalize_node(ast.parse(code))) -# def modify_autouse_fixture(): -# # find fixutre definition in conftetst.py (the one closest to the test) -# # get fixtures present in override-fixtures in pyproject.toml -# # add if marker closest return -# conftest_files = find_conftest_files() -# for cf_file in conftest_files: -# # iterate over all functions in the file -# # if function has autouse fixture, modify function to bypass with custom marker -# pass +class AutouseFixtureModifier(cst.CSTTransformer): + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: + # Matcher for '@fixture' or '@pytest.fixture' + fixture_decorator_func = m.Name("fixture") | m.Attribute(value=m.Name("pytest"), attr=m.Name("fixture")) + + for decorator in original_node.decorators: + if m.matches( + decorator, + m.Decorator( + decorator=m.Call( + func=fixture_decorator_func, args=[m.Arg(value=m.Name("True"), keyword=m.Name("autouse"))] + ) + ), + ): + # Found a matching fixture with autouse=True + + # 1. The original body of the function will become the 'else' block. + # updated_node.body is an IndentedBlock, which is what cst.Else expects. + else_block = cst.Else(body=updated_node.body) + # 2. Create the new 'if' block that will exit the fixture early. + if_test = cst.parse_expression('request.node.get_closest_marker("no_autouse")') + yield_statement = cst.parse_statement("yield") + if_body = cst.IndentedBlock(body=[yield_statement]) -# reuse line profiler utils to add decorator and import to test fns + # 3. Construct the full if/else statement. + new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block) + + # 4. Replace the entire function's body with our new single statement. + return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement])) + return updated_node + + +@contextlib.contextmanager +def disable_autouse(test_path: Path) -> None: + file_content = test_path.read_text(encoding="utf-8") + try: + module = cst.parse_module(file_content) + disable_autouse_fixture = AutouseFixtureModifier() + modified_module = module.visit(disable_autouse_fixture) + test_path.write_text(modified_module.code, encoding="utf-8") + finally: + test_path.write_text(file_content, encoding="utf-8") + + +def modify_autouse_fixture(test_paths: list[Path]) -> None: + # find fixutre definition in conftetst.py (the one closest to the test) + # get fixtures present in override-fixtures in pyproject.toml + # add if marker closest return + conftest_files = find_conftest_files(test_paths) + for cf_file in conftest_files: + # iterate over all functions in the file + # if function has autouse fixture, modify function to bypass with custom marker + disable_autouse(cf_file) + + +# # reuse line profiler utils to add decorator and import to test fns +def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None: + for test_path in test_paths: + # read file + file_content = test_path.read_text(encoding="utf-8") + module = cst.parse_module(file_content) + importadder = ImportAdder("import pytest") + modified_module = module.visit(importadder) + modified_module = isort.code(modified_module.code, float_to_top=True) + qualified_fn_names = get_function_qualified_names(test_path) + for fn_name in qualified_fn_names: + modified_module = add_decorator_to_qualified_function( + modified_module, fn_name, "pytest.mark.codeflash_no_autouse" + ) + # write the modified module back to the file + test_path.write_text(modified_module.code, encoding="utf-8") class OptimFunctionCollector(cst.CSTVisitor): diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 12aeff3fa..fbf69a948 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -21,7 +21,11 @@ from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import code_print, console, logger, progress_bar from codeflash.code_utils import env_utils -from codeflash.code_utils.code_replacer import replace_function_definitions_in_module +from codeflash.code_utils.code_replacer import ( + add_custom_marker_to_all_tests, + modify_autouse_fixture, + replace_function_definitions_in_module, +) from codeflash.code_utils.code_utils import ( ImportErrorPattern, cleanup_paths, @@ -742,6 +746,10 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, list[Functi f"{concolic_coverage_test_files_count} concolic coverage test file" f"{'s' if concolic_coverage_test_files_count != 1 else ''} for {func_qualname}" ) + logger.debug("disabling all autouse fixtures associated with the test files") + modify_autouse_fixture(list(unique_instrumented_test_files)) + logger.debug("add custom marker to all tests") + add_custom_marker_to_all_tests(list(unique_instrumented_test_files)) return unique_instrumented_test_files def generate_tests_and_optimizations( From b065cbcfce041654c4681d8de7946b5aa6db5243 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Thu, 5 Jun 2025 16:51:12 -0700 Subject: [PATCH 11/21] start testing now --- .../benchmarks/test_benchmark_bubble_sort.py | 19 ----------------- codeflash/code_utils/code_replacer.py | 21 ++++++++++--------- codeflash/code_utils/code_utils.py | 5 +++++ codeflash/optimization/function_optimizer.py | 8 ++++--- 4 files changed, 21 insertions(+), 32 deletions(-) diff --git a/code_to_optimize/tests/pytest/benchmarks/test_benchmark_bubble_sort.py b/code_to_optimize/tests/pytest/benchmarks/test_benchmark_bubble_sort.py index ce7839119..3d7b24a6c 100644 --- a/code_to_optimize/tests/pytest/benchmarks/test_benchmark_bubble_sort.py +++ b/code_to_optimize/tests/pytest/benchmarks/test_benchmark_bubble_sort.py @@ -3,29 +3,10 @@ from code_to_optimize.bubble_sort import sorter -class DummyBenchmark: - def __call__(self, func, *args, **kwargs): - # Mimic calling benchmark(func, *args, **kwargs) - return func(*args, **kwargs) - - def __getattr__(self, name): - # Mimic benchmark attributes like .pedantic, .extra_info etc. - def dummy(*args, **kwargs): - return None - - return dummy - - -@pytest.fixture -def benchmark(): - return DummyBenchmark() - - def test_sort(benchmark): result = benchmark(sorter, list(reversed(range(500)))) assert result == list(range(500)) - # This should not be picked up as a benchmark test def test_sort2(): result = sorter(list(reversed(range(500)))) diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 982a21f80..dc59077ab 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -76,26 +76,27 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu @contextlib.contextmanager -def disable_autouse(test_path: Path) -> None: +def disable_autouse(test_path: Path) -> str: file_content = test_path.read_text(encoding="utf-8") - try: - module = cst.parse_module(file_content) - disable_autouse_fixture = AutouseFixtureModifier() - modified_module = module.visit(disable_autouse_fixture) - test_path.write_text(modified_module.code, encoding="utf-8") - finally: - test_path.write_text(file_content, encoding="utf-8") + module = cst.parse_module(file_content) + disable_autouse_fixture = AutouseFixtureModifier() + modified_module = module.visit(disable_autouse_fixture) + test_path.write_text(modified_module.code, encoding="utf-8") + return file_content -def modify_autouse_fixture(test_paths: list[Path]) -> None: +def modify_autouse_fixture(test_paths: list[Path]) -> dict[Path, list[str]]: # find fixutre definition in conftetst.py (the one closest to the test) # get fixtures present in override-fixtures in pyproject.toml # add if marker closest return + file_content_map = {} conftest_files = find_conftest_files(test_paths) for cf_file in conftest_files: # iterate over all functions in the file # if function has autouse fixture, modify function to bypass with custom marker - disable_autouse(cf_file) + original_content = disable_autouse(cf_file) + file_content_map[cf_file] = original_content + return file_content_map # # reuse line profiler utils to add decorator and import to test fns diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 5fc9bd9e9..6a9de176b 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -208,3 +208,8 @@ def cleanup_paths(paths: list[Path]) -> None: shutil.rmtree(path, ignore_errors=True) else: path.unlink(missing_ok=True) + + +def restore_conftest(path_to_content_map: dict[Path, str]) -> None: + for path, file_content in path_to_content_map.items(): + path.write_text(file_content, encoding="utf8") diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index fbf69a948..3aba65847 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -33,6 +33,7 @@ get_run_tmp_file, has_any_async_functions, module_name_from_file_path, + restore_conftest, ) from codeflash.code_utils.config_consts import ( INDIVIDUAL_TESTCASE_TIMEOUT, @@ -212,7 +213,8 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 for key in set(self.function_to_tests) | set(function_to_concolic_tests) } instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests) - + # logger.debug("disabling all autouse fixtures associated with the test files") + original_conftest_content = modify_autouse_fixture(list(instrumented_unittests_created_for_function)) # Get a dict of file_path_to_classes of fto and helpers_of_fto file_path_to_helper_classes = defaultdict(set) for function_source in code_context.helper_functions: @@ -234,6 +236,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 ) if not is_successful(baseline_result): + restore_conftest(original_conftest_content) cleanup_paths(paths_to_cleanup) return Failure(baseline_result.failure()) @@ -241,6 +244,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 if isinstance(original_code_baseline, OriginalCodeBaseline) and not coverage_critic( original_code_baseline.coverage_results, self.args.test_framework ): + restore_conftest(original_conftest_content) cleanup_paths(paths_to_cleanup) return Failure("The threshold for test coverage was not met.") # request for new optimizations but don't block execution, check for completion later @@ -746,8 +750,6 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, list[Functi f"{concolic_coverage_test_files_count} concolic coverage test file" f"{'s' if concolic_coverage_test_files_count != 1 else ''} for {func_qualname}" ) - logger.debug("disabling all autouse fixtures associated with the test files") - modify_autouse_fixture(list(unique_instrumented_test_files)) logger.debug("add custom marker to all tests") add_custom_marker_to_all_tests(list(unique_instrumented_test_files)) return unique_instrumented_test_files From 5c49694180e0ac2b3e8a10d7a9c5b34e6c885d07 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Thu, 5 Jun 2025 19:16:20 -0700 Subject: [PATCH 12/21] start cleaning up and testing --- code_to_optimize/tests/pytest/conftest.py | 17 ---- .../tests/pytest/test_bubble_sort.py | 2 +- codeflash/cli_cmds/cmd_init.py | 16 ++-- codeflash/code_utils/code_extractor.py | 49 ---------- codeflash/code_utils/code_replacer.py | 90 +++++++++++++++---- codeflash/optimization/function_optimizer.py | 13 ++- 6 files changed, 95 insertions(+), 92 deletions(-) delete mode 100644 code_to_optimize/tests/pytest/conftest.py diff --git a/code_to_optimize/tests/pytest/conftest.py b/code_to_optimize/tests/pytest/conftest.py deleted file mode 100644 index d382726e0..000000000 --- a/code_to_optimize/tests/pytest/conftest.py +++ /dev/null @@ -1,17 +0,0 @@ -import pytest -import time - - -@pytest.fixture(autouse=True) -def fixture1(request): - start_time = time.time() - time.sleep(0.1) - yield - print(f"Took {time.time() - start_time} seconds") - - -@pytest.fixture(autouse=True) -def fixture2(request): # We don't need this fixture during testing - print("not doing anything") - yield - print("did nothing") diff --git a/code_to_optimize/tests/pytest/test_bubble_sort.py b/code_to_optimize/tests/pytest/test_bubble_sort.py index eccad6e09..b848a990f 100644 --- a/code_to_optimize/tests/pytest/test_bubble_sort.py +++ b/code_to_optimize/tests/pytest/test_bubble_sort.py @@ -1,5 +1,5 @@ from code_to_optimize.bubble_sort import sorter - +import pytest def test_sort(): input = [5, 4, 3, 2, 1, 0] diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 32e826b81..08743b2ef 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -16,6 +16,7 @@ import tomlkit from git import InvalidGitRepositoryError, Repo from pydantic.dataclasses import dataclass +from tomlkit import table from codeflash.api.cfapi import is_github_app_installed_on_repo from codeflash.cli_cmds.cli_common import apologize_and_exit, inquirer_wrapper, inquirer_wrapper_path @@ -728,11 +729,16 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None: tool_section = pyproject_data.get("tool", tomlkit.table()) tool_section["codeflash"] = codeflash_section pyproject_data["tool"] = tool_section - if "tool.pytest.ini_options" not in pyproject_data: - pyproject_data["tool.pytest.ini_options"] = {} - if "markers" not in pyproject_data["tool.pytest.ini_options"]: - pyproject_data["tool.pytest.ini_options"]["markers"] = [] - pyproject_data["tool.pytest.ini_options"]["markers"].append("codeflash_no_autouse") + # Create [tool.pytest.ini_options] if it doesn't exist + tool_section = pyproject_data.get("tool", table()) + pytest_section = tool_section.get("pytest", table()) + ini_options = pytest_section.get("ini_options", table()) + # Define or overwrite the 'markers' array + ini_options["markers"] = ["codeflash_no_autouse"] + # Set updated sections back + pytest_section["ini_options"] = ini_options + tool_section["pytest"] = pytest_section + pyproject_data["tool"] = tool_section with toml_path.open("w", encoding="utf8") as pyproject_file: pyproject_file.write(tomlkit.dumps(pyproject_data)) click.echo(f"✅ Added Codeflash configuration to {toml_path}") diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index dfd0dbc08..0dcc2357f 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -5,11 +5,9 @@ from typing import TYPE_CHECKING, Optional import libcst as cst -from libcst import MetadataWrapper from libcst.codemod import CodemodContext from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor from libcst.helpers import calculate_module_and_package -from libcst.metadata import FullyQualifiedNameProvider from codeflash.cli_cmds.console import logger from codeflash.models.models import FunctionParent @@ -23,53 +21,6 @@ from codeflash.models.models import FunctionSource -class FunctionNameCollector(cst.CSTVisitor): - """A LibCST visitor that collects the fully qualified names of all functions.""" - - METADATA_DEPENDENCIES = (FullyQualifiedNameProvider,) - - def __init__(self) -> None: - super().__init__() - self.qualified_names: set[str] = set() - - def visit_FunctionDef(self, node: cst.FunctionDef) -> None: - """Visits a function definition node and extracts its qualified name.""" - try: - q_names = self.get_metadata(FullyQualifiedNameProvider, node) - for q_name in q_names: - self.qualified_names.add(q_name.name) - except KeyError: - # This can happen for functions defined in scopes where a qualified - # name cannot be determined. - pass - - -def get_function_qualified_names(file_path: Path) -> list[str]: - """Parse a Python file and returns a list of fully qualified function names. - - Args: - file_path: The path to the Python file. - - Returns: - A list of string representations of the qualified function names. - - """ - with file_path.open("r") as f: - source_code = f.read() - - # Parse the source code into a CST - module = cst.parse_module(source_code) - - # Wrap the module with a metadata wrapper to enable name resolution - wrapper = MetadataWrapper(module) - - # Create an instance of the visitor and visit the wrapped module - visitor = FunctionNameCollector() - wrapper.visit(visitor) - - return list(visitor.qualified_names) - - class GlobalAssignmentCollector(cst.CSTVisitor): """Collects all global assignment statements.""" diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index dc59077ab..ce4dcc9d9 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -1,7 +1,6 @@ from __future__ import annotations import ast -import contextlib from collections import defaultdict from functools import lru_cache from typing import TYPE_CHECKING, Optional, TypeVar @@ -11,13 +10,9 @@ import libcst.matchers as m from codeflash.cli_cmds.console import logger -from codeflash.code_utils.code_extractor import ( - add_global_assignments, - add_needed_imports_from_module, - get_function_qualified_names, -) +from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module from codeflash.code_utils.config_parser import find_conftest_files -from codeflash.code_utils.line_profile_utils import ImportAdder, add_decorator_to_qualified_function +from codeflash.code_utils.line_profile_utils import ImportAdder from codeflash.models.models import FunctionParent if TYPE_CHECKING: @@ -42,6 +37,74 @@ def normalize_code(code: str) -> str: return ast.unparse(normalize_node(ast.parse(code))) +class PytestMarkAdder(cst.CSTTransformer): + """Transformer that adds pytest marks to test functions.""" + + def __init__(self, mark_name: str) -> None: + super().__init__() + self.mark_name = mark_name + self.has_pytest_import = False + + def visit_Module(self, node: cst.Module) -> None: + """Check if pytest is already imported.""" + for statement in node.body: + if isinstance(statement, cst.SimpleStatementLine): + for stmt in statement.body: + if isinstance(stmt, cst.Import): + for import_alias in stmt.names: + if isinstance(import_alias, cst.ImportAlias) and import_alias.name.value == "pytest": + self.has_pytest_import = True + elif isinstance(stmt, cst.ImportFrom) and stmt.module and stmt.module.value == "pytest": + self.has_pytest_import = True + + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 + """Add pytest import if not present.""" + if not self.has_pytest_import: + # Create import statement + import_stmt = cst.SimpleStatementLine(body=[cst.Import(names=[cst.ImportAlias(name=cst.Name("pytest"))])]) + # Add import at the beginning + updated_node = updated_node.with_changes(body=[import_stmt, *updated_node.body]) + return updated_node + + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002 + """Add pytest mark to test functions.""" + # Check if the mark already exists + for decorator in updated_node.decorators: + if self._is_pytest_mark(decorator.decorator, self.mark_name): + return updated_node + + # Create the pytest mark decorator + mark_decorator = self._create_pytest_mark() + + # Add the decorator + new_decorators = [*list(updated_node.decorators), mark_decorator] + return updated_node.with_changes(decorators=new_decorators) + + def _is_pytest_mark(self, decorator: cst.BaseExpression, mark_name: str) -> bool: + """Check if a decorator is a specific pytest mark.""" + if isinstance(decorator, cst.Attribute): + if ( + isinstance(decorator.value, cst.Attribute) + and isinstance(decorator.value.value, cst.Name) + and decorator.value.value.value == "pytest" + and decorator.value.attr.value == "mark" + and decorator.attr.value == mark_name + ): + return True + elif isinstance(decorator, cst.Call) and isinstance(decorator.func, cst.Attribute): + return self._is_pytest_mark(decorator.func, mark_name) + return False + + def _create_pytest_mark(self) -> cst.Decorator: + """Create a pytest mark decorator.""" + # Base: pytest.mark.{mark_name} + mark_attr = cst.Attribute( + value=cst.Attribute(value=cst.Name("pytest"), attr=cst.Name("mark")), attr=cst.Name(self.mark_name) + ) + decorator = mark_attr + return cst.Decorator(decorator=decorator) + + class AutouseFixtureModifier(cst.CSTTransformer): def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # Matcher for '@fixture' or '@pytest.fixture' @@ -63,7 +126,7 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu else_block = cst.Else(body=updated_node.body) # 2. Create the new 'if' block that will exit the fixture early. - if_test = cst.parse_expression('request.node.get_closest_marker("no_autouse")') + if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")') yield_statement = cst.parse_statement("yield") if_body = cst.IndentedBlock(body=[yield_statement]) @@ -75,7 +138,6 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu return updated_node -@contextlib.contextmanager def disable_autouse(test_path: Path) -> str: file_content = test_path.read_text(encoding="utf-8") module = cst.parse_module(file_content) @@ -107,13 +169,9 @@ def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None: module = cst.parse_module(file_content) importadder = ImportAdder("import pytest") modified_module = module.visit(importadder) - modified_module = isort.code(modified_module.code, float_to_top=True) - qualified_fn_names = get_function_qualified_names(test_path) - for fn_name in qualified_fn_names: - modified_module = add_decorator_to_qualified_function( - modified_module, fn_name, "pytest.mark.codeflash_no_autouse" - ) - # write the modified module back to the file + modified_module = cst.parse_module(isort.code(modified_module.code, float_to_top=True)) + pytest_mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = modified_module.visit(pytest_mark_adder) test_path.write_text(modified_module.code, encoding="utf-8") diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 3aba65847..98d15b6a5 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -213,8 +213,15 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 for key in set(self.function_to_tests) | set(function_to_concolic_tests) } instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests) - # logger.debug("disabling all autouse fixtures associated with the test files") - original_conftest_content = modify_autouse_fixture(list(instrumented_unittests_created_for_function)) + logger.debug("disabling all autouse fixtures associated with the test files") + original_conftest_content = modify_autouse_fixture( + generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function) + ) + logger.debug("add custom marker to all tests") + add_custom_marker_to_all_tests( + generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function) + ) + # Get a dict of file_path_to_classes of fto and helpers_of_fto file_path_to_helper_classes = defaultdict(set) for function_source in code_context.helper_functions: @@ -750,8 +757,6 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, list[Functi f"{concolic_coverage_test_files_count} concolic coverage test file" f"{'s' if concolic_coverage_test_files_count != 1 else ''} for {func_qualname}" ) - logger.debug("add custom marker to all tests") - add_custom_marker_to_all_tests(list(unique_instrumented_test_files)) return unique_instrumented_test_files def generate_tests_and_optimizations( From 677f0f862c8395f0c95139621df2ee29f42c4f62 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Thu, 5 Jun 2025 19:17:08 -0700 Subject: [PATCH 13/21] undo bbsort --- code_to_optimize/tests/pytest/test_bubble_sort.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/code_to_optimize/tests/pytest/test_bubble_sort.py b/code_to_optimize/tests/pytest/test_bubble_sort.py index b848a990f..eccad6e09 100644 --- a/code_to_optimize/tests/pytest/test_bubble_sort.py +++ b/code_to_optimize/tests/pytest/test_bubble_sort.py @@ -1,5 +1,5 @@ from code_to_optimize.bubble_sort import sorter -import pytest + def test_sort(): input = [5, 4, 3, 2, 1, 0] From 92c97bb58caeca67d9cd6f1e7e0d22489807eb48 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Fri, 6 Jun 2025 11:39:38 -0700 Subject: [PATCH 14/21] use override arg --- codeflash/cli_cmds/cli.py | 1 + codeflash/cli_cmds/cmd_init.py | 13 +--------- codeflash/code_utils/config_parser.py | 7 +++++- codeflash/optimization/function_optimizer.py | 26 +++++++++++--------- 4 files changed, 23 insertions(+), 24 deletions(-) diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index 3a6f7dba2..d677deed9 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -130,6 +130,7 @@ def process_pyproject_config(args: Namespace) -> Namespace: or not hasattr(args, key.replace("-", "_")) ): setattr(args, key.replace("-", "_"), pyproject_config[key]) + args.override_fixtures = pyproject_config.get("override_fixtures", False) assert args.module_root is not None, "--module-root must be specified" assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory" assert args.tests_root is not None, "--tests-root must be specified" diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 08743b2ef..059a9abe5 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -16,7 +16,6 @@ import tomlkit from git import InvalidGitRepositoryError, Repo from pydantic.dataclasses import dataclass -from tomlkit import table from codeflash.api.cfapi import is_github_app_installed_on_repo from codeflash.cli_cmds.cli_common import apologize_and_exit, inquirer_wrapper, inquirer_wrapper_path @@ -35,7 +34,7 @@ from argparse import Namespace CODEFLASH_LOGO: str = ( - f"{LF}" # noqa : ISC003 + f"{LF}" # noqa: ISC003 r" _ ___ _ _ " + f"{LF}" r" | | / __)| | | | " + f"{LF}" r" ____ ___ _ | | ____ | |__ | | ____ ___ | | _ " + f"{LF}" @@ -729,16 +728,6 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None: tool_section = pyproject_data.get("tool", tomlkit.table()) tool_section["codeflash"] = codeflash_section pyproject_data["tool"] = tool_section - # Create [tool.pytest.ini_options] if it doesn't exist - tool_section = pyproject_data.get("tool", table()) - pytest_section = tool_section.get("pytest", table()) - ini_options = pytest_section.get("ini_options", table()) - # Define or overwrite the 'markers' array - ini_options["markers"] = ["codeflash_no_autouse"] - # Set updated sections back - pytest_section["ini_options"] = ini_options - tool_section["pytest"] = pytest_section - pyproject_data["tool"] = tool_section with toml_path.open("w", encoding="utf8") as pyproject_file: pyproject_file.write(tomlkit.dumps(pyproject_data)) click.echo(f"✅ Added Codeflash configuration to {toml_path}") diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index a0f1c7019..13813cfc1 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -71,7 +71,12 @@ def parse_config_file( path_keys = ["module-root", "tests-root", "benchmarks-root"] path_list_keys = ["ignore-paths"] str_keys = {"pytest-cmd": "pytest", "git-remote": "origin"} - bool_keys = {"disable-telemetry": False, "disable-imports-sorting": False, "benchmark": False} + bool_keys = { + "override-fixtures": False, + "disable-telemetry": False, + "disable-imports-sorting": False, + "benchmark": False, + } list_str_keys = {"formatter-cmds": ["black $file"]} for key, default_value in str_keys.items(): diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 98d15b6a5..89965dd97 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -213,14 +213,15 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 for key in set(self.function_to_tests) | set(function_to_concolic_tests) } instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests) - logger.debug("disabling all autouse fixtures associated with the test files") - original_conftest_content = modify_autouse_fixture( - generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function) - ) - logger.debug("add custom marker to all tests") - add_custom_marker_to_all_tests( - generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function) - ) + if self.args.override_fixtures: + logger.debug("disabling all autouse fixtures associated with the test files") + original_conftest_content = modify_autouse_fixture( + generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function) + ) + logger.debug("add custom marker to all tests") + add_custom_marker_to_all_tests( + generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function) + ) # Get a dict of file_path_to_classes of fto and helpers_of_fto file_path_to_helper_classes = defaultdict(set) @@ -243,7 +244,8 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 ) if not is_successful(baseline_result): - restore_conftest(original_conftest_content) + if self.args.override_fixtures: + restore_conftest(original_conftest_content) cleanup_paths(paths_to_cleanup) return Failure(baseline_result.failure()) @@ -251,7 +253,8 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 if isinstance(original_code_baseline, OriginalCodeBaseline) and not coverage_critic( original_code_baseline.coverage_results, self.args.test_framework ): - restore_conftest(original_conftest_content) + if self.args.override_fixtures: + restore_conftest(original_conftest_content) cleanup_paths(paths_to_cleanup) return Failure("The threshold for test coverage was not met.") # request for new optimizations but don't block execution, check for completion later @@ -360,7 +363,8 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 original_helper_code, self.function_to_optimize.file_path, ) - + if self.args.override_fixtures: + restore_conftest(original_conftest_content) if not best_optimization: return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}") return Success(best_optimization) From 97eabcf68c4997121d56c03b27c29c9cc98cbaee Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Fri, 6 Jun 2025 12:02:52 -0700 Subject: [PATCH 15/21] only modify generated files --- codeflash/optimization/function_optimizer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 89965dd97..7bb2981c8 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -214,13 +214,13 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 } instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests) if self.args.override_fixtures: - logger.debug("disabling all autouse fixtures associated with the test files") + logger.info("Disabling all autouse fixtures associated with the generated test files") original_conftest_content = modify_autouse_fixture( - generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function) + generated_test_paths + generated_perf_test_paths ) - logger.debug("add custom marker to all tests") + logger.info("Add custom marker to generated test files") add_custom_marker_to_all_tests( - generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function) + generated_test_paths + generated_perf_test_paths ) # Get a dict of file_path_to_classes of fto and helpers_of_fto From 8550e69b89fe6226723b4d636d15cc246e650017 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Fri, 6 Jun 2025 12:03:18 -0700 Subject: [PATCH 16/21] precommit fix --- codeflash/optimization/function_optimizer.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 7bb2981c8..35004bb0f 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -215,13 +215,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests) if self.args.override_fixtures: logger.info("Disabling all autouse fixtures associated with the generated test files") - original_conftest_content = modify_autouse_fixture( - generated_test_paths + generated_perf_test_paths - ) + original_conftest_content = modify_autouse_fixture(generated_test_paths + generated_perf_test_paths) logger.info("Add custom marker to generated test files") - add_custom_marker_to_all_tests( - generated_test_paths + generated_perf_test_paths - ) + add_custom_marker_to_all_tests(generated_test_paths + generated_perf_test_paths) # Get a dict of file_path_to_classes of fto and helpers_of_fto file_path_to_helper_classes = defaultdict(set) From a65569cfa04002b15d9fa690dbd4fba51b9e64b4 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Fri, 6 Jun 2025 12:30:30 -0700 Subject: [PATCH 17/21] tests --- tests/test_code_replacement.py | 371 ++++++++++++++++++++++++++++++++- 1 file changed, 363 insertions(+), 8 deletions(-) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 2e8c2f6fd..0d1940798 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -1,5 +1,6 @@ from __future__ import annotations - +import libcst as cst +from codeflash.code_utils.code_replacer import AutouseFixtureModifier, PytestMarkAdder import dataclasses import os from collections import defaultdict @@ -1139,8 +1140,8 @@ def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]: ) assert ( - new_code - == """from __future__ import annotations + new_code + == """from __future__ import annotations import sys from codeflash.verification.comparator import comparator from enum import Enum @@ -1345,8 +1346,8 @@ def cosine_similarity_top_k( project_root_path=Path(__file__).parent.parent.resolve(), ) assert ( - new_code - == '''import numpy as np + new_code + == '''import numpy as np from pydantic.dataclasses import dataclass from typing import List, Optional, Tuple, Union @dataclass(config=dict(arbitrary_types_allowed=True)) @@ -1404,8 +1405,8 @@ def cosine_similarity_top_k( ) assert ( - new_helper_code - == '''import numpy as np + new_helper_code + == '''import numpy as np from pydantic.dataclasses import dataclass from typing import List, Optional, Tuple, Union @dataclass(config=dict(arbitrary_types_allowed=True)) @@ -1662,6 +1663,7 @@ def new_function2(value): ) assert new_code == original_code + def test_global_reassignment() -> None: original_code = """a=1 print("Hello world") @@ -2131,4 +2133,357 @@ def new_function2(value): ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) - assert new_code.rstrip() == expected_code.rstrip() \ No newline at end of file + assert new_code.rstrip() == expected_code.rstrip() + + +class TestAutouseFixtureModifier: + """Test cases for AutouseFixtureModifier class.""" + + def test_modifies_autouse_fixture_with_pytest_decorator(self): + """Test that autouse fixture with @pytest.fixture is modified correctly.""" + source_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def my_fixture(request): + print("setup") + yield + print("teardown") +''' + expected_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def my_fixture(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + print("setup") + yield + print("teardown") +''' + module = cst.parse_module(source_code) + modifier = AutouseFixtureModifier() + modified_module = module.visit(modifier) + + # Parse expected to normalize formatting + expected_module = cst.parse_module(expected_code) + assert modified_module.code.strip() == expected_module.code.strip() + + def test_modifies_autouse_fixture_with_fixture_decorator(self): + """Test that autouse fixture with @fixture is modified correctly.""" + source_code = ''' +from pytest import fixture + +@fixture(autouse=True) +def my_fixture(request): + setup_code() + yield "value" + cleanup_code() +''' + module = cst.parse_module(source_code) + modifier = AutouseFixtureModifier() + modified_module = module.visit(modifier) + + # Check that the if statement was added + assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in modified_module.code + assert "yield" in modified_module.code + assert "else:" in modified_module.code + assert "setup_code()" in modified_module.code + assert "cleanup_code()" in modified_module.code + + def test_ignores_non_autouse_fixture(self): + """Test that non-autouse fixtures are not modified.""" + source_code = ''' +import pytest + +@pytest.fixture +def my_fixture(request): + return "test_value" + +@pytest.fixture(scope="session") +def session_fixture(): + return "session_value" +''' + module = cst.parse_module(source_code) + modifier = AutouseFixtureModifier() + modified_module = module.visit(modifier) + + # Code should remain unchanged + assert modified_module.code == source_code + + def test_ignores_regular_functions(self): + """Test that regular functions are not modified.""" + source_code = ''' +def regular_function(): + return "not a fixture" + +@some_other_decorator +def decorated_function(): + return "also not a fixture" +''' + module = cst.parse_module(source_code) + modifier = AutouseFixtureModifier() + modified_module = module.visit(modifier) + + # Code should remain unchanged + assert modified_module.code == source_code + + def test_handles_multiple_autouse_fixtures(self): + """Test that multiple autouse fixtures in the same file are all modified.""" + source_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def fixture_one(request): + yield "one" + +@pytest.fixture(autouse=True) +def fixture_two(request): + yield "two" +''' + module = cst.parse_module(source_code) + modifier = AutouseFixtureModifier() + modified_module = module.visit(modifier) + + # Both fixtures should be modified + code = modified_module.code + assert code.count('if request.node.get_closest_marker("codeflash_no_autouse"):') == 2 + assert code.count("else:") == 2 + + def test_preserves_fixture_with_complex_body(self): + """Test that fixtures with complex bodies are handled correctly.""" + source_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def complex_fixture(request): + try: + setup_database() + configure_logging() + yield get_test_client() + finally: + cleanup_database() + reset_logging() +''' + module = cst.parse_module(source_code) + modifier = AutouseFixtureModifier() + modified_module = module.visit(modifier) + + code = modified_module.code + assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in code + assert "try:" in code + assert "setup_database()" in code + assert "finally:" in code + assert "cleanup_database()" in code + + +class TestPytestMarkAdder: + """Test cases for PytestMarkAdder class.""" + + def test_adds_pytest_import_when_missing(self): + """Test that pytest import is added when not present.""" + source_code = ''' +def test_something(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + assert "import pytest" in code + assert "@pytest.mark.codeflash_no_autouse" in code + + def test_skips_pytest_import_when_present(self): + """Test that pytest import is not duplicated when already present.""" + source_code = ''' +import pytest + +def test_something(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + # Should only have one import pytest line + assert code.count("import pytest") == 1 + assert "@pytest.mark.codeflash_no_autouse" in code + + def test_handles_from_pytest_import(self): + """Test that existing 'from pytest import ...' is recognized.""" + source_code = ''' +from pytest import fixture + +def test_something(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + # Should not add import pytest since pytest is already imported + assert "import pytest" not in code + assert "from pytest import fixture" in code + assert "@pytest.mark.codeflash_no_autouse" in code + + def test_adds_mark_to_all_functions(self): + """Test that marks are added to all functions in the module.""" + source_code = ''' +import pytest + +def test_first(): + assert True + +def test_second(): + assert False + +def helper_function(): + return "not a test" +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + # All functions should get the mark + assert code.count("@pytest.mark.codeflash_no_autouse") == 3 + + def test_skips_existing_mark(self): + """Test that existing marks are not duplicated.""" + source_code = ''' +import pytest + +@pytest.mark.codeflash_no_autouse +def test_already_marked(): + assert True + +def test_needs_mark(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + # Should have exactly 2 marks total (one existing, one added) + assert code.count("@pytest.mark.codeflash_no_autouse") == 2 + + def test_handles_different_mark_names(self): + """Test that different mark names work correctly.""" + source_code = ''' +import pytest + +def test_something(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("slow") + modified_module = module.visit(mark_adder) + + code = modified_module.code + assert "@pytest.mark.slow" in code + assert "codeflash_no_autouse" not in code + + def test_preserves_existing_decorators(self): + """Test that existing decorators are preserved.""" + source_code = ''' +import pytest + +@pytest.mark.parametrize("value", [1, 2, 3]) +@pytest.fixture +def test_with_decorators(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + assert "@pytest.mark.parametrize" in code + assert "@pytest.fixture" in code + assert "@pytest.mark.codeflash_no_autouse" in code + + def test_handles_call_style_existing_marks(self): + """Test recognition of existing marks in call style (with parentheses).""" + source_code = ''' +import pytest + +@pytest.mark.codeflash_no_autouse() +def test_with_call_mark(): + assert True + +def test_needs_mark(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + # Should recognize the existing call-style mark and not duplicate + lines_with_mark = [line for line in code.split('\n') if 'codeflash_no_autouse' in line] + assert len(lines_with_mark) == 2 # One existing, one added + + def test_empty_module(self): + """Test handling of empty module.""" + source_code = '' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + # Should just add the import + code = modified_module.code + assert "import pytest" in code + + def test_module_with_only_imports(self): + """Test handling of module with only imports.""" + source_code = ''' +import os +import sys +from pathlib import Path +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + assert "import pytest" in code + assert "import os" in code + assert "import sys" in code + assert "from pathlib import Path" in code + + +class TestIntegration: + """Integration tests for both transformers working together.""" + + def test_both_transformers_together(self): + """Test that both transformers can work on the same code.""" + source_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def my_fixture(request): + yield "value" + +def test_something(): + assert True +''' + # First apply AutouseFixtureModifier + module = cst.parse_module(source_code) + autouse_modifier = AutouseFixtureModifier() + modified_module = module.visit(autouse_modifier) + + # Then apply PytestMarkAdder + mark_adder = PytestMarkAdder("codeflash_no_autouse") + final_module = modified_module.visit(mark_adder) + + code = final_module.code + # Should have both modifications + assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in code + assert "@pytest.mark.codeflash_no_autouse" in code + # Mark should be added to both functions + assert code.count("@pytest.mark.codeflash_no_autouse") == 2 From c93b80e87bacf6fa8407c7cc56f0ada3692e2ee2 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Fri, 6 Jun 2025 13:11:07 -0700 Subject: [PATCH 18/21] Ready to review --- codeflash/cli_cmds/cli.py | 2 +- codeflash/cli_cmds/cmd_init.py | 2 +- codeflash/code_utils/code_replacer.py | 2 - tests/test_code_replacement.py | 208 ++++++++++++++++++++------ 4 files changed, 166 insertions(+), 48 deletions(-) diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index d677deed9..5edff57a0 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -123,6 +123,7 @@ def process_pyproject_config(args: Namespace) -> Namespace: "disable_telemetry", "disable_imports_sorting", "git_remote", + "override_fixtures", ] for key in supported_keys: if key in pyproject_config and ( @@ -130,7 +131,6 @@ def process_pyproject_config(args: Namespace) -> Namespace: or not hasattr(args, key.replace("-", "_")) ): setattr(args, key.replace("-", "_"), pyproject_config[key]) - args.override_fixtures = pyproject_config.get("override_fixtures", False) assert args.module_root is not None, "--module-root must be specified" assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory" assert args.tests_root is not None, "--tests-root must be specified" diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 059a9abe5..bfe600fa4 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -723,11 +723,11 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None: formatter_cmds.append("disabled") check_formatter_installed(formatter_cmds, exit_on_failure=False) codeflash_section["formatter-cmds"] = formatter_cmds - codeflash_section["override-fixtures"] = False # don't override fixtures by default, let the user decide # Add the 'codeflash' section, ensuring 'tool' section exists tool_section = pyproject_data.get("tool", tomlkit.table()) tool_section["codeflash"] = codeflash_section pyproject_data["tool"] = tool_section + with toml_path.open("w", encoding="utf8") as pyproject_file: pyproject_file.write(tomlkit.dumps(pyproject_data)) click.echo(f"✅ Added Codeflash configuration to {toml_path}") diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index ce4dcc9d9..932053fc6 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -54,8 +54,6 @@ def visit_Module(self, node: cst.Module) -> None: for import_alias in stmt.names: if isinstance(import_alias, cst.ImportAlias) and import_alias.name.value == "pytest": self.has_pytest_import = True - elif isinstance(stmt, cst.ImportFrom) and stmt.module and stmt.module.value == "pytest": - self.has_pytest_import = True def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 """Add pytest import if not present.""" diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 0d1940798..e848e4525 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -2180,17 +2180,25 @@ def my_fixture(request): setup_code() yield "value" cleanup_code() +''' + expected_code = ''' +from pytest import fixture + +@fixture(autouse=True) +def my_fixture(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + setup_code() + yield "value" + cleanup_code() ''' module = cst.parse_module(source_code) modifier = AutouseFixtureModifier() modified_module = module.visit(modifier) # Check that the if statement was added - assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in modified_module.code - assert "yield" in modified_module.code - assert "else:" in modified_module.code - assert "setup_code()" in modified_module.code - assert "cleanup_code()" in modified_module.code + assert modified_module.code.strip() == expected_code.strip() def test_ignores_non_autouse_fixture(self): """Test that non-autouse fixtures are not modified.""" @@ -2241,6 +2249,23 @@ def fixture_one(request): @pytest.fixture(autouse=True) def fixture_two(request): yield "two" +''' + expected_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def fixture_one(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + yield "one" + +@pytest.fixture(autouse=True) +def fixture_two(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + yield "two" ''' module = cst.parse_module(source_code) modifier = AutouseFixtureModifier() @@ -2248,8 +2273,7 @@ def fixture_two(request): # Both fixtures should be modified code = modified_module.code - assert code.count('if request.node.get_closest_marker("codeflash_no_autouse"):') == 2 - assert code.count("else:") == 2 + assert code==expected_code def test_preserves_fixture_with_complex_body(self): """Test that fixtures with complex bodies are handled correctly.""" @@ -2258,24 +2282,39 @@ def test_preserves_fixture_with_complex_body(self): @pytest.fixture(autouse=True) def complex_fixture(request): - try: - setup_database() - configure_logging() - yield get_test_client() - finally: - cleanup_database() - reset_logging() + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + try: + setup_database() + configure_logging() + yield get_test_client() + finally: + cleanup_database() + reset_logging() +''' + expected_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def complex_fixture(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + try: + setup_database() + configure_logging() + yield get_test_client() + finally: + cleanup_database() + reset_logging() ''' module = cst.parse_module(source_code) modifier = AutouseFixtureModifier() modified_module = module.visit(modifier) code = modified_module.code - assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in code - assert "try:" in code - assert "setup_database()" in code - assert "finally:" in code - assert "cleanup_database()" in code + assert code==expected_code class TestPytestMarkAdder: @@ -2284,6 +2323,12 @@ class TestPytestMarkAdder: def test_adds_pytest_import_when_missing(self): """Test that pytest import is added when not present.""" source_code = ''' +def test_something(): + assert True +''' + expected_code = ''' +import pytest +@pytest.mark.codeflash_no_autouse def test_something(): assert True ''' @@ -2292,14 +2337,20 @@ def test_something(): modified_module = module.visit(mark_adder) code = modified_module.code - assert "import pytest" in code - assert "@pytest.mark.codeflash_no_autouse" in code + assert code==expected_code def test_skips_pytest_import_when_present(self): """Test that pytest import is not duplicated when already present.""" source_code = ''' import pytest +def test_something(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.mark.codeflash_no_autouse def test_something(): assert True ''' @@ -2309,8 +2360,7 @@ def test_something(): code = modified_module.code # Should only have one import pytest line - assert code.count("import pytest") == 1 - assert "@pytest.mark.codeflash_no_autouse" in code + assert code==expected_code def test_handles_from_pytest_import(self): """Test that existing 'from pytest import ...' is recognized.""" @@ -2320,15 +2370,21 @@ def test_handles_from_pytest_import(self): def test_something(): assert True ''' + expected_code = ''' +import pytest +from pytest import fixture + +@pytest.mark.codeflash_no_autouse +def test_something(): + assert True + ''' module = cst.parse_module(source_code) mark_adder = PytestMarkAdder("codeflash_no_autouse") modified_module = module.visit(mark_adder) code = modified_module.code # Should not add import pytest since pytest is already imported - assert "import pytest" not in code - assert "from pytest import fixture" in code - assert "@pytest.mark.codeflash_no_autouse" in code + assert code.strip()==expected_code.strip() def test_adds_mark_to_all_functions(self): """Test that marks are added to all functions in the module.""" @@ -2341,6 +2397,21 @@ def test_first(): def test_second(): assert False +def helper_function(): + return "not a test" +''' + expected_code = ''' +import pytest + +@pytest.mark.codeflash_no_autouse +def test_first(): + assert True + +@pytest.mark.codeflash_no_autouse +def test_second(): + assert False + +@pytest.mark.codeflash_no_autouse def helper_function(): return "not a test" ''' @@ -2350,7 +2421,7 @@ def helper_function(): code = modified_module.code # All functions should get the mark - assert code.count("@pytest.mark.codeflash_no_autouse") == 3 + assert code==expected_code def test_skips_existing_mark(self): """Test that existing marks are not duplicated.""" @@ -2361,6 +2432,17 @@ def test_skips_existing_mark(self): def test_already_marked(): assert True +def test_needs_mark(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.mark.codeflash_no_autouse +def test_already_marked(): + assert True + +@pytest.mark.codeflash_no_autouse def test_needs_mark(): assert True ''' @@ -2370,13 +2452,20 @@ def test_needs_mark(): code = modified_module.code # Should have exactly 2 marks total (one existing, one added) - assert code.count("@pytest.mark.codeflash_no_autouse") == 2 + assert code==expected_code def test_handles_different_mark_names(self): """Test that different mark names work correctly.""" source_code = ''' import pytest +def test_something(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.mark.slow def test_something(): assert True ''' @@ -2385,8 +2474,7 @@ def test_something(): modified_module = module.visit(mark_adder) code = modified_module.code - assert "@pytest.mark.slow" in code - assert "codeflash_no_autouse" not in code + assert code==expected_code def test_preserves_existing_decorators(self): """Test that existing decorators are preserved.""" @@ -2395,6 +2483,15 @@ def test_preserves_existing_decorators(self): @pytest.mark.parametrize("value", [1, 2, 3]) @pytest.fixture +def test_with_decorators(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.mark.parametrize("value", [1, 2, 3]) +@pytest.fixture +@pytest.mark.codeflash_no_autouse def test_with_decorators(): assert True ''' @@ -2403,9 +2500,7 @@ def test_with_decorators(): modified_module = module.visit(mark_adder) code = modified_module.code - assert "@pytest.mark.parametrize" in code - assert "@pytest.fixture" in code - assert "@pytest.mark.codeflash_no_autouse" in code + assert code==expected_code def test_handles_call_style_existing_marks(self): """Test recognition of existing marks in call style (with parentheses).""" @@ -2416,6 +2511,17 @@ def test_handles_call_style_existing_marks(self): def test_with_call_mark(): assert True +def test_needs_mark(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.mark.codeflash_no_autouse() +def test_with_call_mark(): + assert True + +@pytest.mark.codeflash_no_autouse def test_needs_mark(): assert True ''' @@ -2425,8 +2531,7 @@ def test_needs_mark(): code = modified_module.code # Should recognize the existing call-style mark and not duplicate - lines_with_mark = [line for line in code.split('\n') if 'codeflash_no_autouse' in line] - assert len(lines_with_mark) == 2 # One existing, one added + assert code==expected_code def test_empty_module(self): """Test handling of empty module.""" @@ -2437,7 +2542,7 @@ def test_empty_module(self): # Should just add the import code = modified_module.code - assert "import pytest" in code + assert code =='import pytest' def test_module_with_only_imports(self): """Test handling of module with only imports.""" @@ -2445,16 +2550,19 @@ def test_module_with_only_imports(self): import os import sys from pathlib import Path +''' + expected_code = ''' +import pytest +import os +import sys +from pathlib import Path ''' module = cst.parse_module(source_code) mark_adder = PytestMarkAdder("codeflash_no_autouse") modified_module = module.visit(mark_adder) code = modified_module.code - assert "import pytest" in code - assert "import os" in code - assert "import sys" in code - assert "from pathlib import Path" in code + assert code==expected_code class TestIntegration: @@ -2469,6 +2577,21 @@ def test_both_transformers_together(self): def my_fixture(request): yield "value" +def test_something(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.fixture(autouse=True) +@pytest.mark.codeflash_no_autouse +def my_fixture(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + yield "value" + +@pytest.mark.codeflash_no_autouse def test_something(): assert True ''' @@ -2483,7 +2606,4 @@ def test_something(): code = final_module.code # Should have both modifications - assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in code - assert "@pytest.mark.codeflash_no_autouse" in code - # Mark should be added to both functions - assert code.count("@pytest.mark.codeflash_no_autouse") == 2 + assert code==expected_code From 62909dbc194e2c260ff8780288c9479a87be8d05 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Fri, 6 Jun 2025 13:19:39 -0700 Subject: [PATCH 19/21] Update tests/test_code_replacement.py --- tests/test_code_replacement.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 8bfb2dc1f..c3a9e468f 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -1664,6 +1664,7 @@ def new_function2(value): ) assert new_code == original_code + def test_global_reassignment() -> None: original_code = """a=1 print("Hello world") From 6abb3df8ef50ccd294bab71662a7b81e7fb0e6ef Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Fri, 6 Jun 2025 13:24:17 -0700 Subject: [PATCH 20/21] 1 test failing --- tests/test_code_replacement.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index c3a9e468f..3dd94b2a7 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -2315,7 +2315,7 @@ def complex_fixture(request): modified_module = module.visit(modifier) code = modified_module.code - assert code==expected_code + assert code.strip()==expected_code.strip() class TestPytestMarkAdder: From 85ce164d6a0e5b80577ffa528b16f438d2f85a13 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Fri, 6 Jun 2025 13:41:34 -0700 Subject: [PATCH 21/21] rstrip for comparing strings --- tests/test_code_replacement.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 3dd94b2a7..1dab67c97 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -820,7 +820,7 @@ def main_method(self): ) func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config) code_context = func_optimizer.get_code_optimization_context().unwrap() - assert code_context.testgen_context_code == get_code_output + assert code_context.testgen_context_code.rstrip() == get_code_output.rstrip() def test_code_replacement11() -> None: @@ -2283,16 +2283,13 @@ def test_preserves_fixture_with_complex_body(self): @pytest.fixture(autouse=True) def complex_fixture(request): - if request.node.get_closest_marker("codeflash_no_autouse"): - yield - else: - try: - setup_database() - configure_logging() - yield get_test_client() - finally: - cleanup_database() - reset_logging() + try: + setup_database() + configure_logging() + yield get_test_client() + finally: + cleanup_database() + reset_logging() ''' expected_code = ''' import pytest @@ -2315,7 +2312,7 @@ def complex_fixture(request): modified_module = module.visit(modifier) code = modified_module.code - assert code.strip()==expected_code.strip() + assert code.rstrip()==expected_code.rstrip() class TestPytestMarkAdder: