diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index d691072aa..6b17da886 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -42,11 +42,17 @@ jobs: uv venv --seed uv sync + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.AWS_ROLE_TO_ASSUME }} + aws-region: ${{ secrets.AWS_REGION }} + - name: Run Claude Code id: claude uses: anthropics/claude-code-action@v1 with: - use_foundry: "true" + use_bedrock: "true" use_sticky_comment: true allowed_bots: "claude[bot],codeflash-ai[bot]" prompt: | @@ -173,12 +179,9 @@ jobs: 2. For each optimization PR: - Check if CI is passing: `gh pr checks ` - If all checks pass, merge it: `gh pr merge --squash --delete-branch` - claude_args: '--model claude-opus-4-6 --allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh pr checks:*),Bash(gh pr merge:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep,Edit"' + claude_args: '--model us.anthropic.claude-opus-4-6-v1 --allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh pr checks:*),Bash(gh pr merge:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep,Edit"' additional_permissions: | actions: read - env: - ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }} - ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }} # @claude mentions (can edit and push) - restricted to maintainers only claude-mention: @@ -240,14 +243,17 @@ jobs: uv venv --seed uv sync + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.AWS_ROLE_TO_ASSUME }} + aws-region: ${{ secrets.AWS_REGION }} + - name: Run Claude Code id: claude uses: anthropics/claude-code-action@v1 with: - use_foundry: "true" - claude_args: '--model claude-opus-4-6 --allowedTools "Read,Edit,Write,Glob,Grep,Bash(git status*),Bash(git diff*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git log*),Bash(git merge*),Bash(git fetch*),Bash(git checkout*),Bash(git branch*),Bash(uv run prek *),Bash(prek *),Bash(uv run ruff *),Bash(uv run pytest *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(gh pr comment*),Bash(gh pr view*),Bash(gh pr diff*),Bash(gh pr merge*),Bash(gh pr close*)"' + use_bedrock: "true" + claude_args: '--model us.anthropic.claude-opus-4-6-v1 --allowedTools "Read,Edit,Write,Glob,Grep,Bash(git status*),Bash(git diff*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git log*),Bash(git merge*),Bash(git fetch*),Bash(git checkout*),Bash(git branch*),Bash(uv run prek *),Bash(prek *),Bash(uv run ruff *),Bash(uv run pytest *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(gh pr comment*),Bash(gh pr view*),Bash(gh pr diff*),Bash(gh pr merge*),Bash(gh pr close*)"' additional_permissions: | actions: read - env: - ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }} - ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }} diff --git a/.github/workflows/duplicate-code-detector.yml b/.github/workflows/duplicate-code-detector.yml index ea36bf54d..83896d1ea 100644 --- a/.github/workflows/duplicate-code-detector.yml +++ b/.github/workflows/duplicate-code-detector.yml @@ -42,10 +42,16 @@ jobs: } EOF + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.AWS_ROLE_TO_ASSUME }} + aws-region: ${{ secrets.AWS_REGION }} + - name: Run Claude Code uses: anthropics/claude-code-action@v1 with: - use_foundry: "true" + use_bedrock: "true" use_sticky_comment: true allowed_bots: "claude[bot],codeflash-ai[bot]" claude_args: '--mcp-config /tmp/mcp-config/mcp-servers.json --allowedTools "Read,Glob,Grep,Bash(git diff:*),Bash(git log:*),Bash(git show:*),Bash(wc *),Bash(find *),mcp__serena__*"' @@ -105,10 +111,6 @@ jobs: - Concrete refactoring suggestion If no significant duplication is found, say so briefly. Do not create issues — just comment on the PR. - env: - ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }} - ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }} - - name: Stop Serena if: always() run: docker stop serena && docker rm serena || true diff --git a/.github/workflows/js-tests.yml b/.github/workflows/js-tests.yml deleted file mode 100644 index 0d56e8831..000000000 --- a/.github/workflows/js-tests.yml +++ /dev/null @@ -1,50 +0,0 @@ -name: JavaScript/TypeScript Integration Tests - -on: - push: - branches: - - main - pull_request: - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.ref_name }} - cancel-in-progress: true - -jobs: - js-integration-tests: - name: JS/TS Integration Tests - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - token: ${{ secrets.GITHUB_TOKEN }} - - - name: Setup Node.js - uses: actions/setup-node@v4 - with: - node-version: '20' - - - name: Install uv - uses: astral-sh/setup-uv@v6 - - - name: Install Python dependencies - run: | - uv venv --seed - uv sync - - - name: Install npm dependencies for test projects - run: | - npm install --prefix code_to_optimize/js/code_to_optimize_js - npm install --prefix code_to_optimize/js/code_to_optimize_ts - npm install --prefix code_to_optimize/js/code_to_optimize_vitest - - - name: Run JavaScript integration tests - run: | - uv run pytest tests/languages/javascript/ -v - uv run pytest tests/test_languages/test_vitest_e2e.py -v - uv run pytest tests/test_languages/test_javascript_e2e.py -v - uv run pytest tests/test_languages/test_javascript_support.py -v - uv run pytest tests/code_utils/test_config_js.py -v diff --git a/.gitignore b/.gitignore index a22143bf8..c52422253 100644 --- a/.gitignore +++ b/.gitignore @@ -274,3 +274,5 @@ tessl.json # Tessl auto-generates AGENTS.md on install; ignore to avoid cluttering git status AGENTS.md +.serena/ +.codeflash/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..6d6a48b5f --- /dev/null +++ b/LICENSE @@ -0,0 +1,98 @@ +Business Source License 1.1 + +Parameters + +Licensor: CodeFlash Inc. +Licensed Work: Codeflash Client version 0.20.x + The Licensed Work is (c) 2024 CodeFlash Inc. + +Additional Use Grant: None. Production use of the Licensed Work is only permitted + if you have entered into a separate written agreement + with CodeFlash Inc. for production use in connection + with a subscription to CodeFlash's Code Optimization + Platform. Please visit codeflash.ai for further + information. + +Change Date: 2030-01-26 + +Change License: MIT + +Notice + +The Business Source License (this document, or the “License”) is not an Open +Source license. However, the Licensed Work will eventually be made available +under an Open Source License, as stated in this License. + +License text copyright (c) 2017 MariaDB Corporation Ab, All Rights Reserved. +“Business Source License” is a trademark of MariaDB Corporation Ab. + +----------------------------------------------------------------------------- + +Business Source License 1.1 + +Terms + +The Licensor hereby grants you the right to copy, modify, create derivative +works, redistribute, and make non-production use of the Licensed Work. The +Licensor may make an Additional Use Grant, above, permitting limited +production use. + +Effective on the Change Date, or the fourth anniversary of the first publicly +available distribution of a specific version of the Licensed Work under this +License, whichever comes first, the Licensor hereby grants you rights under +the terms of the Change License, and the rights granted in the paragraph +above terminate. + +If your use of the Licensed Work does not comply with the requirements +currently in effect as described in this License, you must purchase a +commercial license from the Licensor, its affiliated entities, or authorized +resellers, or you must refrain from using the Licensed Work. + +All copies of the original and modified Licensed Work, and derivative works +of the Licensed Work, are subject to this License. This License applies +separately for each version of the Licensed Work and the Change Date may vary +for each version of the Licensed Work released by Licensor. + +You must conspicuously display this License on each original or modified copy +of the Licensed Work. If you receive the Licensed Work in original or +modified form from a third party, the terms and conditions set forth in this +License apply to your use of that work. + +Any use of the Licensed Work in violation of this License will automatically +terminate your rights under this License for the current and all other +versions of the Licensed Work. + +This License does not grant you any right in any trademark or logo of +Licensor or its affiliates (provided that you may use a trademark or logo of +Licensor as expressly required by this License). + +TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON +AN “AS IS” BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS, +EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND +TITLE. + +MariaDB hereby grants you permission to use this License’s text to license +your works, and to refer to it using the trademark “Business Source License”, +as long as you comply with the Covenants of Licensor below. + +Covenants of Licensor + +In consideration of the right to use this License’s text and the “Business +Source License” name and trademark, Licensor covenants to MariaDB, and to all +other recipients of the licensed work to be provided by Licensor: + +1. To specify as the Change License the GPL Version 2.0 or any later version, + or a license that is compatible with GPL Version 2.0 or a later version, + where “compatible” means that software provided under the Change License can + be included in a program with software provided under GPL Version 2.0 or a + later version. Licensor may specify additional Change Licenses without + limitation. + +2. To either: (a) specify an additional grant of rights to use that does not + impose any additional restriction on the right granted in this License, as + the Additional Use Grant; or (b) insert the text “None”. + +3. To specify a Change Date. + +4. Not to modify this License in any other way. \ No newline at end of file diff --git a/codeflash-benchmark/LICENSE b/codeflash-benchmark/LICENSE new file mode 100644 index 000000000..6d6a48b5f --- /dev/null +++ b/codeflash-benchmark/LICENSE @@ -0,0 +1,98 @@ +Business Source License 1.1 + +Parameters + +Licensor: CodeFlash Inc. +Licensed Work: Codeflash Client version 0.20.x + The Licensed Work is (c) 2024 CodeFlash Inc. + +Additional Use Grant: None. Production use of the Licensed Work is only permitted + if you have entered into a separate written agreement + with CodeFlash Inc. for production use in connection + with a subscription to CodeFlash's Code Optimization + Platform. Please visit codeflash.ai for further + information. + +Change Date: 2030-01-26 + +Change License: MIT + +Notice + +The Business Source License (this document, or the “License”) is not an Open +Source license. However, the Licensed Work will eventually be made available +under an Open Source License, as stated in this License. + +License text copyright (c) 2017 MariaDB Corporation Ab, All Rights Reserved. +“Business Source License” is a trademark of MariaDB Corporation Ab. + +----------------------------------------------------------------------------- + +Business Source License 1.1 + +Terms + +The Licensor hereby grants you the right to copy, modify, create derivative +works, redistribute, and make non-production use of the Licensed Work. The +Licensor may make an Additional Use Grant, above, permitting limited +production use. + +Effective on the Change Date, or the fourth anniversary of the first publicly +available distribution of a specific version of the Licensed Work under this +License, whichever comes first, the Licensor hereby grants you rights under +the terms of the Change License, and the rights granted in the paragraph +above terminate. + +If your use of the Licensed Work does not comply with the requirements +currently in effect as described in this License, you must purchase a +commercial license from the Licensor, its affiliated entities, or authorized +resellers, or you must refrain from using the Licensed Work. + +All copies of the original and modified Licensed Work, and derivative works +of the Licensed Work, are subject to this License. This License applies +separately for each version of the Licensed Work and the Change Date may vary +for each version of the Licensed Work released by Licensor. + +You must conspicuously display this License on each original or modified copy +of the Licensed Work. If you receive the Licensed Work in original or +modified form from a third party, the terms and conditions set forth in this +License apply to your use of that work. + +Any use of the Licensed Work in violation of this License will automatically +terminate your rights under this License for the current and all other +versions of the Licensed Work. + +This License does not grant you any right in any trademark or logo of +Licensor or its affiliates (provided that you may use a trademark or logo of +Licensor as expressly required by this License). + +TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON +AN “AS IS” BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS, +EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND +TITLE. + +MariaDB hereby grants you permission to use this License’s text to license +your works, and to refer to it using the trademark “Business Source License”, +as long as you comply with the Covenants of Licensor below. + +Covenants of Licensor + +In consideration of the right to use this License’s text and the “Business +Source License” name and trademark, Licensor covenants to MariaDB, and to all +other recipients of the licensed work to be provided by Licensor: + +1. To specify as the Change License the GPL Version 2.0 or any later version, + or a license that is compatible with GPL Version 2.0 or a later version, + where “compatible” means that software provided under the Change License can + be included in a program with software provided under GPL Version 2.0 or a + later version. Licensor may specify additional Change Licenses without + limitation. + +2. To either: (a) specify an additional grant of rights to use that does not + impose any additional restriction on the right granted in this License, as + the Additional Use Grant; or (b) insert the text “None”. + +3. To specify a Change Date. + +4. Not to modify this License in any other way. \ No newline at end of file diff --git a/codeflash-benchmark/README.md b/codeflash-benchmark/README.md new file mode 100644 index 000000000..91d79ae0d --- /dev/null +++ b/codeflash-benchmark/README.md @@ -0,0 +1,15 @@ +# CodeFlash Benchmark + +A pytest benchmarking plugin for [CodeFlash](https://codeflash.ai) - automatic code performance optimization. + +## Installation + +```bash +pip install codeflash-benchmark +``` + +## Usage + +This plugin provides benchmarking capabilities for pytest tests used by CodeFlash's optimization pipeline. + +For more information, visit [codeflash.ai](https://codeflash.ai). diff --git a/codeflash-benchmark/pyproject.toml b/codeflash-benchmark/pyproject.toml index f068f7367..bc5e9040d 100644 --- a/codeflash-benchmark/pyproject.toml +++ b/codeflash-benchmark/pyproject.toml @@ -1,32 +1,32 @@ -[project] -name = "codeflash-benchmark" -version = "0.2.0" -description = "Pytest benchmarking plugin for codeflash.ai - automatic code performance optimization" -authors = [{ name = "CodeFlash Inc.", email = "contact@codeflash.ai" }] -requires-python = ">=3.9" -readme = "README.md" -license = {text = "BSL-1.1"} -keywords = [ - "codeflash", - "benchmark", - "pytest", - "performance", - "testing", -] -dependencies = [ - "pytest>=7.0.0,!=8.3.4", -] - -[project.urls] -Homepage = "https://codeflash.ai" -Repository = "https://github.com/codeflash-ai/codeflash-benchmark" - -[project.entry-points.pytest11] -codeflash-benchmark = "codeflash_benchmark.plugin" - -[build-system] -requires = ["setuptools>=45", "wheel"] -build-backend = "setuptools.build_meta" - -[tool.setuptools] -packages = ["codeflash_benchmark"] +[project] +name = "codeflash-benchmark" +version = "0.2.0" +description = "Pytest benchmarking plugin for codeflash.ai - automatic code performance optimization" +authors = [{ name = "CodeFlash Inc.", email = "contact@codeflash.ai" }] +requires-python = ">=3.9" +readme = "README.md" +license-files = ["LICENSE"] +keywords = [ + "codeflash", + "benchmark", + "pytest", + "performance", + "testing", +] +dependencies = [ + "pytest>=7.0.0,!=8.3.4", +] + +[project.urls] +Homepage = "https://codeflash.ai" +Repository = "https://github.com/codeflash-ai/codeflash-benchmark" + +[project.entry-points.pytest11] +codeflash-benchmark = "codeflash_benchmark.plugin" + +[build-system] +requires = ["setuptools>=45", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +packages = ["codeflash_benchmark"] diff --git a/codeflash/cli_cmds/console.py b/codeflash/cli_cmds/console.py index b1e4b45d8..5ca7f9eea 100644 --- a/codeflash/cli_cmds/console.py +++ b/codeflash/cli_cmds/console.py @@ -40,7 +40,16 @@ logging.basicConfig( level=logging.INFO, - handlers=[RichHandler(rich_tracebacks=True, markup=False, highlighter=NullHighlighter(), console=console, show_path=False, show_time=False)], + handlers=[ + RichHandler( + rich_tracebacks=True, + markup=False, + highlighter=NullHighlighter(), + console=console, + show_path=False, + show_time=False, + ) + ], format=BARE_LOGGING_FORMAT, ) diff --git a/codeflash/cli_cmds/logging_config.py b/codeflash/cli_cmds/logging_config.py index c2f339abd..dbb3663bd 100644 --- a/codeflash/cli_cmds/logging_config.py +++ b/codeflash/cli_cmds/logging_config.py @@ -14,7 +14,16 @@ def set_level(level: int, *, echo_setting: bool = True) -> None: logging.basicConfig( level=level, - handlers=[RichHandler(rich_tracebacks=True, markup=False, highlighter=NullHighlighter(), console=console, show_path=False, show_time=False)], + handlers=[ + RichHandler( + rich_tracebacks=True, + markup=False, + highlighter=NullHighlighter(), + console=console, + show_path=False, + show_time=False, + ) + ], format=BARE_LOGGING_FORMAT, ) logging.getLogger().setLevel(level) @@ -23,7 +32,14 @@ def set_level(level: int, *, echo_setting: bool = True) -> None: logging.basicConfig( format=VERBOSE_LOGGING_FORMAT, handlers=[ - RichHandler(rich_tracebacks=True, markup=False, highlighter=NullHighlighter(), console=console, show_path=False, show_time=False) + RichHandler( + rich_tracebacks=True, + markup=False, + highlighter=NullHighlighter(), + console=console, + show_path=False, + show_time=False, + ) ], force=True, ) diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py index e9afbcc64..d47c8baf0 100644 --- a/codeflash/code_utils/config_consts.py +++ b/codeflash/code_utils/config_consts.py @@ -4,8 +4,8 @@ from typing import Any, Union MAX_TEST_RUN_ITERATIONS = 5 -OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 16000 -TESTGEN_CONTEXT_TOKEN_LIMIT = 16000 +OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 48000 +TESTGEN_CONTEXT_TOKEN_LIMIT = 48000 INDIVIDUAL_TESTCASE_TIMEOUT = 15 # For Python pytest JAVA_TESTCASE_TIMEOUT = 120 # Java Maven tests need more time due to startup overhead MAX_FUNCTION_TEST_SECONDS = 60 diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 006ed63cf..b4a626429 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -1518,73 +1518,207 @@ def _is_target_decorator(self, decorator_node: cst.Name | cst.Attribute | cst.Ca return False -class AsyncDecoratorImportAdder(cst.CSTTransformer): - """Transformer that adds the import for async decorators.""" +ASYNC_HELPER_INLINE_CODE = """import asyncio +import gc +import os +import sqlite3 +import time +from functools import wraps +from pathlib import Path +from tempfile import TemporaryDirectory + +import dill as pickle - def __init__(self, mode: TestingMode = TestingMode.BEHAVIOR) -> None: - self.mode = mode - self.has_import = False - - def _get_decorator_name(self) -> str: - """Get the decorator name based on the testing mode.""" - if self.mode == TestingMode.BEHAVIOR: - return "codeflash_behavior_async" - if self.mode == TestingMode.CONCURRENCY: - return "codeflash_concurrency_async" - return "codeflash_performance_async" - - def visit_ImportFrom(self, node: cst.ImportFrom) -> None: - # Check if the async decorator import is already present - if ( - isinstance(node.module, cst.Attribute) - and isinstance(node.module.value, cst.Attribute) - and isinstance(node.module.value.value, cst.Name) - and node.module.value.value.value == "codeflash" - and node.module.value.attr.value == "code_utils" - and node.module.attr.value == "codeflash_wrap_decorator" - and not isinstance(node.names, cst.ImportStar) - ): - decorator_name = self._get_decorator_name() - for import_alias in node.names: - if import_alias.name.value == decorator_name: - self.has_import = True - - def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: - # If the import is already there, don't add it again - if self.has_import: - return updated_node - - # Choose import based on mode - decorator_name = self._get_decorator_name() - - # Parse the import statement into a CST node - import_node = cst.parse_statement(f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}") - - # Add the import to the module's body - return updated_node.with_changes(body=[import_node, *list(updated_node.body)]) + +def get_run_tmp_file(file_path): + if not hasattr(get_run_tmp_file, "tmpdir"): + get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_") + return Path(get_run_tmp_file.tmpdir.name) / file_path + + +def extract_test_context_from_env(): + test_module = os.environ["CODEFLASH_TEST_MODULE"] + test_class = os.environ.get("CODEFLASH_TEST_CLASS", None) + test_function = os.environ["CODEFLASH_TEST_FUNCTION"] + if test_module and test_function: + return (test_module, test_class if test_class else None, test_function) + raise RuntimeError( + "Test context environment variables not set - ensure tests are run through codeflash test runner" + ) + + +def codeflash_behavior_async(func): + @wraps(func) + async def async_wrapper(*args, **kwargs): + loop = asyncio.get_running_loop() + function_name = func.__name__ + line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"] + loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"]) + test_module_name, test_class_name, test_name = extract_test_context_from_env() + test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}" + if not hasattr(async_wrapper, "index"): + async_wrapper.index = {} + if test_id in async_wrapper.index: + async_wrapper.index[test_id] += 1 + else: + async_wrapper.index[test_id] = 0 + codeflash_test_index = async_wrapper.index[test_id] + invocation_id = f"{line_id}_{codeflash_test_index}" + class_prefix = (test_class_name + ".") if test_class_name else "" + test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}" + print(f"!$######{test_stdout_tag}######$!") + iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0") + db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite")) + codeflash_con = sqlite3.connect(db_path) + codeflash_cur = codeflash_con.cursor() + codeflash_cur.execute( + "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, " + "test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + "runtime INTEGER, return_value BLOB, verification_type TEXT)" + ) + exception = None + counter = loop.time() + gc.disable() + try: + ret = func(*args, **kwargs) + counter = loop.time() + return_value = await ret + codeflash_duration = int((loop.time() - counter) * 1_000_000_000) + except Exception as e: + codeflash_duration = int((loop.time() - counter) * 1_000_000_000) + exception = e + finally: + gc.enable() + print(f"!######{test_stdout_tag}######!") + pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps((args, kwargs, return_value)) + codeflash_cur.execute( + "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + test_module_name, + test_class_name, + test_name, + function_name, + loop_index, + invocation_id, + codeflash_duration, + pickled_return_value, + "function_call", + ), + ) + codeflash_con.commit() + codeflash_con.close() + if exception: + raise exception + return return_value + return async_wrapper + + +def codeflash_performance_async(func): + @wraps(func) + async def async_wrapper(*args, **kwargs): + loop = asyncio.get_running_loop() + function_name = func.__name__ + line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"] + loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"]) + test_module_name, test_class_name, test_name = extract_test_context_from_env() + test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}" + if not hasattr(async_wrapper, "index"): + async_wrapper.index = {} + if test_id in async_wrapper.index: + async_wrapper.index[test_id] += 1 + else: + async_wrapper.index[test_id] = 0 + codeflash_test_index = async_wrapper.index[test_id] + invocation_id = f"{line_id}_{codeflash_test_index}" + class_prefix = (test_class_name + ".") if test_class_name else "" + test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}" + print(f"!$######{test_stdout_tag}######$!") + exception = None + counter = loop.time() + gc.disable() + try: + ret = func(*args, **kwargs) + counter = loop.time() + return_value = await ret + codeflash_duration = int((loop.time() - counter) * 1_000_000_000) + except Exception as e: + codeflash_duration = int((loop.time() - counter) * 1_000_000_000) + exception = e + finally: + gc.enable() + print(f"!######{test_stdout_tag}:{codeflash_duration}######!") + if exception: + raise exception + return return_value + return async_wrapper + + +def codeflash_concurrency_async(func): + @wraps(func) + async def async_wrapper(*args, **kwargs): + function_name = func.__name__ + concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10")) + test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "") + test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "") + test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "") + loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0") + gc.disable() + try: + seq_start = time.perf_counter_ns() + for _ in range(concurrency_factor): + result = await func(*args, **kwargs) + sequential_time = time.perf_counter_ns() - seq_start + finally: + gc.enable() + gc.disable() + try: + conc_start = time.perf_counter_ns() + tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)] + await asyncio.gather(*tasks) + concurrent_time = time.perf_counter_ns() - conc_start + finally: + gc.enable() + tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}" + print(f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!") + return result + return async_wrapper +""" + +ASYNC_HELPER_FILENAME = "codeflash_async_wrapper.py" + + +def get_decorator_name_for_mode(mode: TestingMode) -> str: + if mode == TestingMode.BEHAVIOR: + return "codeflash_behavior_async" + if mode == TestingMode.CONCURRENCY: + return "codeflash_concurrency_async" + return "codeflash_performance_async" + + +def write_async_helper_file(target_dir: Path) -> Path: + """Write the async decorator helper file to the target directory.""" + helper_path = target_dir / ASYNC_HELPER_FILENAME + if not helper_path.exists(): + helper_path.write_text(ASYNC_HELPER_INLINE_CODE, "utf-8") + return helper_path def add_async_decorator_to_function( - source_path: Path, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR + source_path: Path, + function: FunctionToOptimize, + mode: TestingMode = TestingMode.BEHAVIOR, + project_root: Path | None = None, ) -> bool: """Add async decorator to an async function definition and write back to file. - Args: - ---- - source_path: Path to the source file to modify in-place. - function: The FunctionToOptimize object representing the target async function. - mode: The testing mode to determine which decorator to apply. - - Returns: - ------- - Boolean indicating whether the decorator was successfully added. + Writes a helper file containing the decorator implementation to project_root (or source directory + as fallback) and adds a standard import + decorator to the source file. """ if not function.is_async: return False try: - # Read source code with source_path.open(encoding="utf8") as f: source_code = f.read() @@ -1594,10 +1728,14 @@ def add_async_decorator_to_function( decorator_transformer = AsyncDecoratorAdder(function, mode) module = module.visit(decorator_transformer) - # Add the import if decorator was added if decorator_transformer.added_decorator: - import_transformer = AsyncDecoratorImportAdder(mode) - module = module.visit(import_transformer) + # Write the helper file to project_root (on sys.path) or source dir as fallback + helper_dir = project_root if project_root is not None else source_path.parent + write_async_helper_file(helper_dir) + # Add the import via CST so sort_imports can place it correctly + decorator_name = get_decorator_name_for_mode(mode) + import_node = cst.parse_statement(f"from codeflash_async_wrapper import {decorator_name}") + module = module.with_changes(body=[import_node, *list(module.body)]) modified_code = sort_imports(code=module.code, float_to_top=True) except Exception as e: diff --git a/codeflash/languages/base.py b/codeflash/languages/base.py index 9d326f022..3715a05de 100644 --- a/codeflash/languages/base.py +++ b/codeflash/languages/base.py @@ -520,15 +520,6 @@ def get_test_file_suffix(self) -> str: """ ... - def get_comment_prefix(self) -> str: - """Get the comment prefix for this language. - - Returns: - Comment prefix (e.g., "//" for JS, "#" for Python). - - """ - ... - def find_test_root(self, project_root: Path) -> Path | None: """Find the test root directory for a project. diff --git a/codeflash/languages/current.py b/codeflash/languages/current.py index b576e1616..b9e45d367 100644 --- a/codeflash/languages/current.py +++ b/codeflash/languages/current.py @@ -34,7 +34,7 @@ from codeflash.languages.base import LanguageSupport # Module-level singleton for the current language -_current_language: Language | None = None +_current_language: Language = Language.PYTHON def current_language() -> Language: diff --git a/codeflash/languages/java/context.py b/codeflash/languages/java/context.py index 29067f23f..394f52037 100644 --- a/codeflash/languages/java/context.py +++ b/codeflash/languages/java/context.py @@ -887,11 +887,7 @@ def collect_type_identifiers(node: Node) -> None: def get_java_imported_type_skeletons( - imports: list, - project_root: Path, - module_root: Path | None, - analyzer: JavaAnalyzer, - target_code: str = "", + imports: list, project_root: Path, module_root: Path | None, analyzer: JavaAnalyzer, target_code: str = "" ) -> str: """Extract type skeletons for project-internal imported types. @@ -1011,9 +1007,7 @@ def _extract_constructor_summaries(skeleton: TypeSkeleton) -> list[str]: return summaries -def _format_skeleton_for_context( - skeleton: TypeSkeleton, source: str, class_name: str, analyzer: JavaAnalyzer -) -> str: +def _format_skeleton_for_context(skeleton: TypeSkeleton, source: str, class_name: str, analyzer: JavaAnalyzer) -> str: """Format a TypeSkeleton into a context string with method signatures. Includes: type declaration, fields, constructors, and public method signatures @@ -1094,7 +1088,7 @@ def _extract_public_method_signatures(source: str, class_name: str, analyzer: Ja sig_parts_bytes.append(mod_slice) continue - if ctype == "block" or ctype == "constructor_body": + if ctype in {"block", "constructor_body"}: break sig_parts_bytes.append(source_bytes[child.start_byte : child.end_byte]) diff --git a/codeflash/languages/java/instrumentation.py b/codeflash/languages/java/instrumentation.py index 7cad460dd..18fdb1409 100644 --- a/codeflash/languages/java/instrumentation.py +++ b/codeflash/languages/java/instrumentation.py @@ -730,11 +730,17 @@ def split_var_declaration(stmt_node, source_bytes_ref: bytes) -> tuple[str, str] # The variable is assigned inside a for/try block which Java considers # conditionally executed, so an uninitialized declaration would cause # "variable might not have been initialized" errors. - _PRIMITIVE_DEFAULTS = { - "byte": "0", "short": "0", "int": "0", "long": "0L", - "float": "0.0f", "double": "0.0", "char": "'\\0'", "boolean": "false", + primitive_defaults = { + "byte": "0", + "short": "0", + "int": "0", + "long": "0L", + "float": "0.0f", + "double": "0.0", + "char": "'\\0'", + "boolean": "false", } - default_val = _PRIMITIVE_DEFAULTS.get(type_text, "null") + default_val = primitive_defaults.get(type_text, "null") hoisted = f"{type_text} {name_text} = {default_val};" assignment = f"{name_text} = {value_text};" return hoisted, assignment @@ -918,9 +924,7 @@ def build_instrumented_body(body_text: str, next_wrapper_id: int, base_indent: s replacements: list[tuple[int, int, bytes]] = [] wrapper_id = 0 - method_ordinal = 0 - for method_node, body_node in test_methods: - method_ordinal += 1 + for method_ordinal, (method_node, body_node) in enumerate(test_methods, start=1): body_start = body_node.start_byte + 1 # skip '{' body_end = body_node.end_byte - 1 # skip '}' body_text = source_bytes[body_start:body_end].decode("utf8") diff --git a/codeflash/languages/java/replacement.py b/codeflash/languages/java/replacement.py index 23e3c9232..a374043e5 100644 --- a/codeflash/languages/java/replacement.py +++ b/codeflash/languages/java/replacement.py @@ -374,13 +374,7 @@ def replace_function( class_name, ) source = _insert_class_members( - source, - class_name, - new_fields_to_add, - new_helpers_before, - new_helpers_after, - func_name, - analyzer, + source, class_name, new_fields_to_add, new_helpers_before, new_helpers_after, func_name, analyzer ) # Re-find the target method after modifications diff --git a/codeflash/languages/javascript/parse.py b/codeflash/languages/javascript/parse.py index a5e7ae8c6..e3eee4831 100644 --- a/codeflash/languages/javascript/parse.py +++ b/codeflash/languages/javascript/parse.py @@ -527,10 +527,5 @@ def parse_jest_test_xml( f"[LOOP-SUMMARY] Results loop_index: min={min_idx}, max={max_idx}, " f"unique_count={len(unique_loop_indices)}, total_results={len(loop_indices)}" ) - if max_idx == 1 and len(loop_indices) > 1: - logger.warning( - f"[LOOP-WARNING] All {len(loop_indices)} results have loop_index=1. " - "Perf test markers may not have been parsed correctly." - ) return test_results diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index 5d6967442..572f162f0 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -1805,15 +1805,6 @@ def get_test_file_suffix(self) -> str: """ return ".test.js" - def get_comment_prefix(self) -> str: - """Get the comment prefix for JavaScript. - - Returns: - JavaScript single-line comment prefix. - - """ - return "//" - def find_test_root(self, project_root: Path) -> Path | None: """Find the test root directory for a JavaScript project. diff --git a/codeflash/languages/javascript/test_runner.py b/codeflash/languages/javascript/test_runner.py index 1d79ad382..3a193602b 100644 --- a/codeflash/languages/javascript/test_runner.py +++ b/codeflash/languages/javascript/test_runner.py @@ -803,8 +803,6 @@ def run_jest_behavioral_tests( wall_clock_ns = time.perf_counter_ns() - start_time_ns logger.debug(f"Jest behavioral tests completed in {wall_clock_ns / 1e9:.2f}s") - print(result.stdout) - return result_file_path, result, coverage_json_path, None @@ -1046,6 +1044,10 @@ def run_jest_benchmarking_tests( # Create result with combined stdout result = subprocess.CompletedProcess(args=result.args, returncode=result.returncode, stdout=stdout, stderr="") + if result.returncode != 0: + logger.info(f"Jest benchmarking failed with return code {result.returncode}") + logger.info(f"Jest benchmarking stdout: {result.stdout}") + logger.info(f"Jest benchmarking stderr: {result.stderr}") except subprocess.TimeoutExpired: logger.warning(f"Jest benchmarking timed out after {total_timeout}s") diff --git a/codeflash/context/__init__.py b/codeflash/languages/python/context/__init__.py similarity index 100% rename from codeflash/context/__init__.py rename to codeflash/languages/python/context/__init__.py diff --git a/codeflash/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py similarity index 50% rename from codeflash/context/code_context_extractor.py rename to codeflash/languages/python/context/code_context_extractor.py index 7e0f1fa0c..173dbde86 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -6,7 +6,7 @@ from collections import defaultdict from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING import libcst as cst @@ -14,16 +14,17 @@ from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages from codeflash.code_utils.config_consts import OPTIMIZATION_CONTEXT_TOKEN_LIMIT, TESTGEN_CONTEXT_TOKEN_LIMIT -from codeflash.context.unused_definition_remover import ( - collect_top_level_defs_with_usages, - extract_names_from_targets, - get_section_names, - remove_unused_definitions_by_function_names, -) from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001 # Language support imports for multi-language code context extraction from codeflash.languages import Language, is_python +from codeflash.languages.python.context.unused_definition_remover import ( + collect_top_level_defs_with_usages, + get_section_names, + is_assignment_used, + recurse_sections, + remove_unused_definitions_by_function_names, +) from codeflash.models.models import ( CodeContextType, CodeOptimizationContext, @@ -35,20 +36,30 @@ if TYPE_CHECKING: from jedi.api.classes import Name - from libcst import CSTNode - from codeflash.context.unused_definition_remover import UsageInfo from codeflash.languages.base import HelperFunction + from codeflash.languages.python.context.unused_definition_remover import UsageInfo + +# Error message constants +READ_WRITABLE_LIMIT_ERROR = "Read-writable code has exceeded token limit, cannot proceed" +TESTGEN_LIMIT_ERROR = "Testgen code context has exceeded token limit, cannot proceed" + + +def safe_relative_to(path: Path, root: Path) -> Path: + try: + return path.resolve().relative_to(root.resolve()) + except ValueError: + return path def build_testgen_context( helpers_of_fto_dict: dict[Path, set[FunctionSource]], helpers_of_helpers_dict: dict[Path, set[FunctionSource]], project_root_path: Path, - remove_docstrings: bool, - include_imported_classes: bool, + *, + remove_docstrings: bool = False, + include_enrichment: bool = True, ) -> CodeStringsMarkdown: - """Build testgen context with optional imported class definitions and external base inits.""" testgen_context = extract_code_markdown_context_from_files( helpers_of_fto_dict, helpers_of_helpers_dict, @@ -57,24 +68,10 @@ def build_testgen_context( code_context_type=CodeContextType.TESTGEN, ) - if include_imported_classes: - imported_class_context = get_imported_class_definitions(testgen_context, project_root_path) - if imported_class_context.code_strings: - testgen_context = CodeStringsMarkdown( - code_strings=testgen_context.code_strings + imported_class_context.code_strings - ) - - external_base_inits = get_external_base_class_inits(testgen_context, project_root_path) - if external_base_inits.code_strings: - testgen_context = CodeStringsMarkdown( - code_strings=testgen_context.code_strings + external_base_inits.code_strings - ) - - external_class_inits = get_external_class_inits(testgen_context, project_root_path) - if external_class_inits.code_strings: - testgen_context = CodeStringsMarkdown( - code_strings=testgen_context.code_strings + external_class_inits.code_strings - ) + if include_enrichment: + enrichment = enrich_testgen_context(testgen_context, project_root_path) + if enrichment.code_strings: + testgen_context = CodeStringsMarkdown(code_strings=testgen_context.code_strings + enrichment.code_strings) return testgen_context @@ -142,7 +139,7 @@ def get_code_optimization_context( # Handle token limits final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.markdown) if final_read_writable_tokens > optim_token_limit: - raise ValueError("Read-writable code has exceeded token limit, cannot proceed") + raise ValueError(READ_WRITABLE_LIMIT_ERROR) # Setup preexisting objects for code replacer preexisting_objects = set( @@ -153,53 +150,39 @@ def get_code_optimization_context( ) read_only_context_code = read_only_code_markdown.markdown - read_only_code_markdown_tokens = encoded_tokens_len(read_only_context_code) - total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens - if total_tokens > optim_token_limit: + # Progressive fallback for read-only context token limits + read_only_tokens = encoded_tokens_len(read_only_context_code) + if final_read_writable_tokens + read_only_tokens > optim_token_limit: logger.debug("Code context has exceeded token limit, removing docstrings from read-only code") - # Extract read only code without docstrings - read_only_code_no_docstring_markdown = extract_code_markdown_context_from_files( + read_only_code_no_docstrings = extract_code_markdown_context_from_files( helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True ) - read_only_context_code = read_only_code_no_docstring_markdown.markdown - read_only_code_no_docstring_markdown_tokens = encoded_tokens_len(read_only_context_code) - total_tokens = final_read_writable_tokens + read_only_code_no_docstring_markdown_tokens - if total_tokens > optim_token_limit: + read_only_context_code = read_only_code_no_docstrings.markdown + if final_read_writable_tokens + encoded_tokens_len(read_only_context_code) > optim_token_limit: logger.debug("Code context has exceeded token limit, removing read-only code") read_only_context_code = "" - # Extract code context for testgen with progressive fallback for token limits - # Try in order: full context -> remove docstrings -> remove imported classes - testgen_context = build_testgen_context( - helpers_of_fto_dict, - helpers_of_helpers_dict, - project_root_path, - remove_docstrings=False, - include_imported_classes=True, - ) + # Progressive fallback for testgen context token limits + testgen_context = build_testgen_context(helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path) if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: logger.debug("Testgen context exceeded token limit, removing docstrings") testgen_context = build_testgen_context( - helpers_of_fto_dict, - helpers_of_helpers_dict, - project_root_path, - remove_docstrings=True, - include_imported_classes=True, + helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True ) if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: - logger.debug("Testgen context still exceeded token limit, removing imported class definitions") + logger.debug("Testgen context still exceeded token limit, removing enrichment") testgen_context = build_testgen_context( helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True, - include_imported_classes=False, + include_enrichment=False, ) if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: - raise ValueError("Testgen code context has exceeded token limit, cannot proceed") + raise ValueError(TESTGEN_LIMIT_ERROR) code_hash_context = hashing_code_context.markdown code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest() @@ -251,10 +234,7 @@ def get_code_optimization_context_for_language( imports_code = "\n".join(code_context.imports) if code_context.imports else "" # Get relative path for target file - try: - target_relative_path = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve()) - except ValueError: - target_relative_path = function_to_optimize.file_path + target_relative_path = safe_relative_to(function_to_optimize.file_path, project_root_path) # Group helpers by file path helpers_by_file: dict[Path, list[HelperFunction]] = defaultdict(list) @@ -302,10 +282,7 @@ def get_code_optimization_context_for_language( if file_path == function_to_optimize.file_path: continue # Already included in target file - try: - helper_relative_path = file_path.resolve().relative_to(project_root_path.resolve()) - except ValueError: - helper_relative_path = file_path + helper_relative_path = safe_relative_to(file_path, project_root_path) # Combine all helpers from this file combined_helper_code = "\n\n".join(h.source_code for h in file_helpers) @@ -333,11 +310,11 @@ def get_code_optimization_context_for_language( # Check token limits read_writable_tokens = encoded_tokens_len(read_writable_code.markdown) if read_writable_tokens > optim_token_limit: - raise ValueError("Read-writable code has exceeded token limit, cannot proceed") + raise ValueError(READ_WRITABLE_LIMIT_ERROR) testgen_tokens = encoded_tokens_len(testgen_context.markdown) if testgen_tokens > testgen_token_limit: - raise ValueError("Testgen code context has exceeded token limit, cannot proceed") + raise ValueError(TESTGEN_LIMIT_ERROR) # Generate code hash from all read-writable code code_hash = hashlib.sha256(read_writable_code.flat.encode("utf-8")).hexdigest() @@ -355,6 +332,49 @@ def get_code_optimization_context_for_language( ) +def process_file_context( + file_path: Path, + primary_qualified_names: set[str], + secondary_qualified_names: set[str], + code_context_type: CodeContextType, + remove_docstrings: bool, + project_root_path: Path, + helper_functions: list[FunctionSource], +) -> CodeString | None: + try: + original_code = file_path.read_text("utf8") + except Exception as e: + logger.exception(f"Error while parsing {file_path}: {e}") + return None + + try: + all_names = primary_qualified_names | secondary_qualified_names + code_without_unused_defs = remove_unused_definitions_by_function_names(original_code, all_names) + code_context = parse_code_and_prune_cst( + code_without_unused_defs, + code_context_type, + primary_qualified_names, + secondary_qualified_names, + remove_docstrings, + ) + except ValueError as e: + logger.debug(f"Error while getting read-only code: {e}") + return None + + if code_context.strip(): + if code_context_type != CodeContextType.HASHING: + code_context = add_needed_imports_from_module( + src_module_code=original_code, + dst_module_code=code_context, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=helper_functions, + ) + return CodeString(code=code_context, file_path=safe_relative_to(file_path, project_root_path)) + return None + + def extract_code_markdown_context_from_files( helpers_of_fto: dict[Path, set[FunctionSource]], helpers_of_helpers: dict[Path, set[FunctionSource]], @@ -396,79 +416,39 @@ def extract_code_markdown_context_from_files( code_context_markdown = CodeStringsMarkdown() # Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files for file_path, function_sources in helpers_of_fto.items(): - try: - original_code = file_path.read_text("utf8") - except Exception as e: - logger.exception(f"Error while parsing {file_path}: {e}") - continue - try: - qualified_function_names = {func.qualified_name for func in function_sources} - helpers_of_helpers_qualified_names = { - func.qualified_name for func in helpers_of_helpers.get(file_path, set()) - } - code_without_unused_defs = remove_unused_definitions_by_function_names( - original_code, qualified_function_names | helpers_of_helpers_qualified_names - ) - code_context = parse_code_and_prune_cst( - code_without_unused_defs, - code_context_type, - qualified_function_names, - helpers_of_helpers_qualified_names, - remove_docstrings, - ) + qualified_function_names = {func.qualified_name for func in function_sources} + helpers_of_helpers_qualified_names = {func.qualified_name for func in helpers_of_helpers.get(file_path, set())} + helper_functions = list(helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())) + + result = process_file_context( + file_path=file_path, + primary_qualified_names=qualified_function_names, + secondary_qualified_names=helpers_of_helpers_qualified_names, + code_context_type=code_context_type, + remove_docstrings=remove_docstrings, + project_root_path=project_root_path, + helper_functions=helper_functions, + ) - except ValueError as e: - logger.debug(f"Error while getting read-only code: {e}") - continue - if code_context.strip(): - if code_context_type != CodeContextType.HASHING: - code_context = add_needed_imports_from_module( - src_module_code=original_code, - dst_module_code=code_context, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions=list( - helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set()) - ), - ) - code_string_context = CodeString( - code=code_context, file_path=file_path.resolve().relative_to(project_root_path.resolve()) - ) - code_context_markdown.code_strings.append(code_string_context) + if result is not None: + code_context_markdown.code_strings.append(result) # Extract code from file paths containing helpers of helpers for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items(): - try: - original_code = file_path.read_text("utf8") - except Exception as e: - logger.exception(f"Error while parsing {file_path}: {e}") - continue - try: - qualified_helper_function_names = {func.qualified_name for func in helper_function_sources} - code_without_unused_defs = remove_unused_definitions_by_function_names( - original_code, qualified_helper_function_names - ) - code_context = parse_code_and_prune_cst( - code_without_unused_defs, code_context_type, set(), qualified_helper_function_names, remove_docstrings - ) - except ValueError as e: - logger.debug(f"Error while getting read-only code: {e}") - continue + qualified_helper_function_names = {func.qualified_name for func in helper_function_sources} + helper_functions = list(helpers_of_helpers_no_overlap.get(file_path, set())) + + result = process_file_context( + file_path=file_path, + primary_qualified_names=set(), + secondary_qualified_names=qualified_helper_function_names, + code_context_type=code_context_type, + remove_docstrings=remove_docstrings, + project_root_path=project_root_path, + helper_functions=helper_functions, + ) - if code_context.strip(): - if code_context_type != CodeContextType.HASHING: - code_context = add_needed_imports_from_module( - src_module_code=original_code, - dst_module_code=code_context, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())), - ) - code_string_context = CodeString( - code=code_context, file_path=file_path.resolve().relative_to(project_root_path.resolve()) - ) - code_context_markdown.code_strings.append(code_string_context) + if result is not None: + code_context_markdown.code_strings.append(result) return code_context_markdown @@ -539,39 +519,28 @@ def get_function_sources_from_jedi( # The definition is part of this project and not defined within the original function is_valid_definition = ( - str(definition_path).startswith(str(project_root_path) + os.sep) - and not path_belongs_to_site_packages(definition_path) + is_project_path(definition_path, project_root_path) and definition.full_name and not belongs_to_function_qualified(definition, qualified_function_name) and definition.full_name.startswith(definition.module_name) ) - if is_valid_definition and definition.type == "function": - qualified_name = get_qualified_name(definition.module_name, definition.full_name) + if is_valid_definition and definition.type in ("function", "class"): + if definition.type == "function": + fqn = definition.full_name + func_name = definition.name + else: + # When a class is instantiated (e.g., MyClass()), track its __init__ as a helper + # This ensures the class definition with constructor is included in testgen context + fqn = f"{definition.full_name}.__init__" + func_name = "__init__" + qualified_name = get_qualified_name(definition.module_name, fqn) # Avoid nested functions or classes. Only class.function is allowed if len(qualified_name.split(".")) <= 2: function_source = FunctionSource( file_path=definition_path, qualified_name=qualified_name, - fully_qualified_name=definition.full_name, - only_function_name=definition.name, - source_code=definition.get_line_code(), - jedi_definition=definition, - ) - file_path_to_function_source[definition_path].add(function_source) - function_source_list.append(function_source) - # When a class is instantiated (e.g., MyClass()), track its __init__ as a helper - # This ensures the class definition with constructor is included in testgen context - elif is_valid_definition and definition.type == "class": - init_qualified_name = get_qualified_name( - definition.module_name, f"{definition.full_name}.__init__" - ) - # Only include if it's a top-level class (not nested) - if len(init_qualified_name.split(".")) <= 2: - function_source = FunctionSource( - file_path=definition_path, - qualified_name=init_qualified_name, - fully_qualified_name=f"{definition.full_name}.__init__", - only_function_name="__init__", + fully_qualified_name=fqn, + only_function_name=func_name, source_code=definition.get_line_code(), jedi_definition=definition, ) @@ -581,60 +550,123 @@ def get_function_sources_from_jedi( return file_path_to_function_source, function_source_list -def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown: - """Extract class definitions for imported types from project modules. - - This function analyzes the imports in the extracted code context and fetches - class definitions for any classes imported from project modules. This helps - the LLM understand the actual class structure (constructors, methods, inheritance) - rather than just seeing import statements. - - Also recursively extracts base classes when a class inherits from another class - in the same module, ensuring the full inheritance chain is available for - understanding constructor signatures. - - Args: - code_context: The already extracted code context containing imports - project_root_path: Root path of the project - - Returns: - CodeStringsMarkdown containing class definitions from imported project modules - - """ - import jedi - - # Collect all code from the context +def _parse_and_collect_imports(code_context: CodeStringsMarkdown) -> tuple[ast.Module, dict[str, str]] | None: all_code = "\n".join(cs.code for cs in code_context.code_strings) - - # Parse to find import statements try: tree = ast.parse(all_code) except SyntaxError: - return CodeStringsMarkdown(code_strings=[]) + return None + imported_names: dict[str, str] = {} - # Collect imported names and their source modules - imported_names: dict[str, str] = {} # name -> module_path - for node in ast.walk(tree): - if isinstance(node, ast.ImportFrom) and node.module: - for alias in node.names: - if alias.name != "*": - imported_name = alias.asname if alias.asname else alias.name - imported_names[imported_name] = node.module + # Directly iterate over the module body and nested structures instead of ast.walk + # This avoids traversing every single node in the tree + def collect_imports(nodes: list[ast.stmt]) -> None: + for node in nodes: + if isinstance(node, ast.ImportFrom) and node.module: + for alias in node.names: + if alias.name != "*": + imported_name = alias.asname if alias.asname else alias.name + imported_names[imported_name] = node.module + # Recursively check nested structures (function defs, class defs, if statements, etc.) + elif isinstance( + node, + ( + ast.FunctionDef, + ast.AsyncFunctionDef, + ast.ClassDef, + ast.If, + ast.For, + ast.AsyncFor, + ast.While, + ast.With, + ast.AsyncWith, + ast.Try, + ast.ExceptHandler, + ), + ): + if hasattr(node, "body"): + collect_imports(node.body) + if hasattr(node, "orelse"): + collect_imports(node.orelse) + if hasattr(node, "finalbody"): + collect_imports(node.finalbody) + if hasattr(node, "handlers"): + for handler in node.handlers: + collect_imports(handler.body) + # Handle match/case statements (Python 3.10+) + elif hasattr(ast, "Match") and isinstance(node, ast.Match): + for case in node.cases: + collect_imports(case.body) + + collect_imports(tree.body) + return tree, imported_names + + +def collect_existing_class_names(tree: ast.Module) -> set[str]: + class_names = set() + stack = list(tree.body) + + while stack: + node = stack.pop() + if isinstance(node, ast.ClassDef): + class_names.add(node.name) + stack.extend(node.body) + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + stack.extend(node.body) + elif isinstance(node, (ast.If, ast.For, ast.While, ast.With)): + stack.extend(node.body) + if hasattr(node, "orelse"): + stack.extend(node.orelse) + elif isinstance(node, ast.Try): + stack.extend(node.body) + stack.extend(node.orelse) + stack.extend(node.finalbody) + for handler in node.handlers: + stack.extend(handler.body) + + return class_names + + +def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown: + import jedi + + result = _parse_and_collect_imports(code_context) + if result is None: + return CodeStringsMarkdown(code_strings=[]) + tree, imported_names = result if not imported_names: return CodeStringsMarkdown(code_strings=[]) - # Track which classes we've already extracted to avoid duplicates - extracted_classes: set[tuple[Path, str]] = set() # (file_path, class_name) + existing_classes = collect_existing_class_names(tree) - # Also track what's already defined in the context - existing_definitions: set[str] = set() + # Collect base class names from ClassDef nodes (single walk) + base_class_names: set[str] = set() for node in ast.walk(tree): if isinstance(node, ast.ClassDef): - existing_definitions.add(node.name) + for base in node.bases: + if isinstance(base, ast.Name): + base_class_names.add(base.id) + elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name): + base_class_names.add(base.attr) + + # Classify external imports using importlib-based check + is_project_cache: dict[str, bool] = {} + external_base_classes: set[tuple[str, str]] = set() + external_direct_imports: set[tuple[str, str]] = set() + + for name, module_name in imported_names.items(): + if not _is_project_module_cached(module_name, project_root_path, is_project_cache): + if name in base_class_names: + external_base_classes.add((name, module_name)) + if name not in existing_classes: + external_direct_imports.add((name, module_name)) - class_code_strings: list[CodeString] = [] + code_strings: list[CodeString] = [] + emitted_class_names: set[str] = set() + # --- Step 1: Project class definitions (jedi resolution + recursive base extraction) --- + extracted_classes: set[tuple[Path, str]] = set() module_cache: dict[Path, tuple[str, ast.Module]] = {} def get_module_source_and_tree(module_path: Path) -> tuple[str, ast.Module] | None: @@ -652,12 +684,9 @@ def get_module_source_and_tree(module_path: Path) -> tuple[str, ast.Module] | No def extract_class_and_bases( class_name: str, module_path: Path, module_source: str, module_tree: ast.Module ) -> None: - """Extract a class and its base classes recursively from the same module.""" - # Skip if already extracted if (module_path, class_name) in extracted_classes: return - # Find the class definition in the module class_node = None for node in ast.walk(module_tree): if isinstance(node, ast.ClassDef) and node.name == class_name: @@ -667,22 +696,18 @@ def extract_class_and_bases( if class_node is None: return - # First, recursively extract base classes from the same module for base in class_node.bases: base_name = None if isinstance(base, ast.Name): base_name = base.id elif isinstance(base, ast.Attribute): - # For module.ClassName, we skip (cross-module inheritance) continue - if base_name and base_name not in existing_definitions: - # Check if base class is defined in the same module + if base_name and base_name not in existing_classes: extract_class_and_bases(base_name, module_path, module_source, module_tree) - # Now extract this class (after its bases, so base classes appear first) if (module_path, class_name) in extracted_classes: - return # Already added by another path + return lines = module_source.split("\n") start_line = class_node.lineno @@ -690,21 +715,17 @@ def extract_class_and_bases( start_line = min(d.lineno for d in class_node.decorator_list) class_source = "\n".join(lines[start_line - 1 : class_node.end_lineno]) - # Extract imports for the class class_imports = extract_imports_for_class(module_tree, class_node, module_source) full_source = class_imports + "\n\n" + class_source if class_imports else class_source - class_code_strings.append(CodeString(code=full_source, file_path=module_path)) + code_strings.append(CodeString(code=full_source, file_path=module_path)) extracted_classes.add((module_path, class_name)) + emitted_class_names.add(class_name) for name, module_name in imported_names.items(): - # Skip if already defined in context - if name in existing_definitions: + if name in existing_classes: continue - - # Try to find the module file using Jedi try: - # Create a script that imports the module to resolve it test_code = f"import {module_name}" script = jedi.Script(test_code, project=jedi.Project(path=project_root_path)) completions = script.goto(1, len(test_code)) @@ -716,123 +737,85 @@ def extract_class_and_bases( if not module_path: continue - # Check if this is a project module (not stdlib/third-party) - if not str(module_path).startswith(str(project_root_path) + os.sep): - continue - if path_belongs_to_site_packages(module_path): + if not is_project_path(module_path, project_root_path): continue - # Get module source and tree - result = get_module_source_and_tree(module_path) - if result is None: + mod_result = get_module_source_and_tree(module_path) + if mod_result is None: continue - module_source, module_tree = result + module_source, module_tree = mod_result - # Extract the class and its base classes extract_class_and_bases(name, module_path, module_source, module_tree) except Exception: logger.debug(f"Error extracting class definition for {name} from {module_name}") continue - return CodeStringsMarkdown(code_strings=class_code_strings) + # --- Step 2: External base class __init__ stubs --- + if external_base_classes: + for cls, name in resolve_classes_from_modules(external_base_classes): + if name in emitted_class_names: + continue + stub = extract_init_stub(cls, name, require_site_packages=False) + if stub is not None: + code_strings.append(stub) + emitted_class_names.add(name) + + # --- Step 3: External direct import __init__ stubs with BFS --- + if external_direct_imports: + processed_classes: set[type] = set() + worklist: list[tuple[type, str, int]] = [ + (cls, name, 0) for cls, name in resolve_classes_from_modules(external_direct_imports) + ] + + while worklist: + cls, class_name, depth = worklist.pop(0) + + if cls in processed_classes: + continue + processed_classes.add(cls) + stub = extract_init_stub(cls, class_name) + if stub is None: + continue -def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown: - """Extract __init__ methods from external library base classes. + if class_name not in emitted_class_names: + code_strings.append(stub) + emitted_class_names.add(class_name) - Scans the code context for classes that inherit from external libraries and extracts - just their __init__ methods. This helps the LLM understand constructor signatures - for mocking or instantiation. - """ - import importlib - import inspect - import textwrap - - all_code = "\n".join(cs.code for cs in code_context.code_strings) + if depth < MAX_TRANSITIVE_DEPTH: + for dep_cls in resolve_transitive_type_deps(cls): + if dep_cls not in processed_classes: + worklist.append((dep_cls, dep_cls.__name__, depth + 1)) - try: - tree = ast.parse(all_code) - except SyntaxError: - return CodeStringsMarkdown(code_strings=[]) + return CodeStringsMarkdown(code_strings=code_strings) - imported_names: dict[str, str] = {} - # Use a set to deduplicate external base entries to avoid repeated expensive checks/imports. - external_bases_set: set[tuple[str, str]] = set() - # Local cache to avoid repeated _is_project_module calls for the same module_name. - is_project_cache: dict[str, bool] = {} - for node in ast.walk(tree): - if isinstance(node, ast.ImportFrom) and node.module: - for alias in node.names: - if alias.name != "*": - imported_name = alias.asname if alias.asname else alias.name - imported_names[imported_name] = node.module - elif isinstance(node, ast.ClassDef): - for base in node.bases: - base_name = None - if isinstance(base, ast.Name): - base_name = base.id - elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name): - base_name = base.attr - - if base_name and base_name in imported_names: - module_name = imported_names[base_name] - # Check cache first to avoid repeated expensive checks. - cached = is_project_cache.get(module_name) - if cached is None: - is_project = _is_project_module(module_name, project_root_path) - is_project_cache[module_name] = is_project - else: - is_project = cached - - if not is_project: - external_bases_set.add((base_name, module_name)) - - if not external_bases_set: - return CodeStringsMarkdown(code_strings=[]) +def resolve_classes_from_modules(candidates: set[tuple[str, str]]) -> list[tuple[type, str]]: + """Import modules and resolve candidate (class_name, module_name) pairs to class objects.""" + import importlib + import inspect - code_strings: list[CodeString] = [] - # Cache imported modules to avoid repeated importlib.import_module calls. - imported_module_cache: dict[str, object] = {} + resolved: list[tuple[type, str]] = [] + module_cache: dict[str, object] = {} - for base_name, module_name in external_bases_set: + for class_name, module_name in candidates: try: - module = imported_module_cache.get(module_name) + module = module_cache.get(module_name) if module is None: module = importlib.import_module(module_name) - imported_module_cache[module_name] = module - - base_class = getattr(module, base_name, None) - if base_class is None: - continue - - init_method = getattr(base_class, "__init__", None) - if init_method is None: - continue - - try: - init_source = inspect.getsource(init_method) - init_source = textwrap.dedent(init_source) - class_file = Path(inspect.getfile(base_class)) - parts = class_file.parts - if "site-packages" in parts: - idx = parts.index("site-packages") - class_file = Path(*parts[idx + 1 :]) - except (OSError, TypeError): - continue - - class_source = f"class {base_name}:\n" + textwrap.indent(init_source, " ") - code_strings.append(CodeString(code=class_source, file_path=class_file)) + module_cache[module_name] = module + cls = getattr(module, class_name, None) + if cls is not None and inspect.isclass(cls): + resolved.append((cls, class_name)) except (ImportError, ModuleNotFoundError, AttributeError): - logger.debug(f"Failed to extract __init__ for {module_name}.{base_name}") - continue + logger.debug(f"Failed to import {module_name}.{class_name}") - return CodeStringsMarkdown(code_strings=code_strings) + return resolved -MAX_TRANSITIVE_DEPTH = 2 +MAX_TRANSITIVE_DEPTH = 5 def extract_classes_from_type_hint(hint: object) -> list[type]: @@ -902,8 +885,15 @@ def resolve_transitive_type_deps(cls: type) -> list[type]: return deps -def extract_init_stub_for_class(cls: type, class_name: str) -> CodeString | None: - """Extract a stub containing the class definition with only its __init__ method.""" +def extract_init_stub(cls: type, class_name: str, require_site_packages: bool = True) -> CodeString | None: + """Extract a stub containing the class definition with only its __init__ method. + + Args: + cls: The class object to extract __init__ from + class_name: Name to use for the class in the stub + require_site_packages: If True, only extract from site-packages. If False, include stdlib too. + + """ import inspect import textwrap @@ -916,7 +906,7 @@ def extract_init_stub_for_class(cls: type, class_name: str) -> CodeString | None except (OSError, TypeError): return None - if not path_belongs_to_site_packages(class_file): + if require_site_packages and not path_belongs_to_site_packages(class_file): return None try: @@ -934,106 +924,22 @@ def extract_init_stub_for_class(cls: type, class_name: str) -> CodeString | None return CodeString(code=class_source, file_path=class_file) -def get_external_class_inits(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown: - """Extract __init__ methods from directly imported external library classes. +def _is_project_module_cached(module_name: str, project_root_path: Path, cache: dict[str, bool]) -> bool: + cached = cache.get(module_name) + if cached is not None: + return cached + is_project = _is_project_module(module_name, project_root_path) + cache[module_name] = is_project + return is_project - Scans the code context for classes imported from external packages (site-packages) and extracts - their __init__ methods, including transitive type dependencies found in __init__ annotations. - This helps the LLM understand constructor signatures for instantiation in generated tests. - """ - import importlib - import inspect - all_code = "\n".join(cs.code for cs in code_context.code_strings) - - try: - tree = ast.parse(all_code) - except SyntaxError: - return CodeStringsMarkdown(code_strings=[]) - - # Collect all from X import Y statements - imported_names: dict[str, str] = {} - is_project_cache: dict[str, bool] = {} - - # Track classes already defined in the context to avoid duplicates - existing_classes: set[str] = set() - - for node in ast.walk(tree): - if isinstance(node, ast.ImportFrom) and node.module: - for alias in node.names: - if alias.name != "*": - imported_name = alias.asname if alias.asname else alias.name - imported_names[imported_name] = node.module - elif isinstance(node, ast.ClassDef): - existing_classes.add(node.name) - - if not imported_names: - return CodeStringsMarkdown(code_strings=[]) - - # Filter to external-only imports - external_imports: set[tuple[str, str]] = set() - for name, module_name in imported_names.items(): - if name in existing_classes: - continue - cached = is_project_cache.get(module_name) - if cached is None: - is_project = _is_project_module(module_name, project_root_path) - is_project_cache[module_name] = is_project - else: - is_project = cached - if not is_project: - external_imports.add((name, module_name)) - - if not external_imports: - return CodeStringsMarkdown(code_strings=[]) - - code_strings: list[CodeString] = [] - imported_module_cache: dict[str, object] = {} - processed_classes: set[type] = set() - emitted_names: set[str] = set() - - # BFS worklist: (class_object, class_name, depth) - worklist: list[tuple[type, str, int]] = [] - - # Seed the worklist with directly imported classes - for class_name, module_name in external_imports: - try: - module = imported_module_cache.get(module_name) - if module is None: - module = importlib.import_module(module_name) - imported_module_cache[module_name] = module - - cls = getattr(module, class_name, None) - if cls is None or not inspect.isclass(cls): - continue - - worklist.append((cls, class_name, 0)) - except (ImportError, ModuleNotFoundError, AttributeError): - logger.debug(f"Failed to import {module_name}.{class_name}") - continue - - while worklist: - cls, class_name, depth = worklist.pop(0) - - if cls in processed_classes: - continue - processed_classes.add(cls) - - stub = extract_init_stub_for_class(cls, class_name) - if stub is None: - continue - - if class_name not in emitted_names: - code_strings.append(stub) - emitted_names.add(class_name) - - # Resolve transitive type dependencies up to MAX_TRANSITIVE_DEPTH - if depth < MAX_TRANSITIVE_DEPTH: - for dep_cls in resolve_transitive_type_deps(cls): - if dep_cls not in processed_classes: - worklist.append((dep_cls, dep_cls.__name__, depth + 1)) - - return CodeStringsMarkdown(code_strings=code_strings) +def is_project_path(module_path: Path | None, project_root_path: Path) -> bool: + if module_path is None: + return False + # site-packages must be checked first because .venv/site-packages is under project root + if path_belongs_to_site_packages(module_path): + return False + return str(module_path).startswith(str(project_root_path) + os.sep) def _is_project_module(module_name: str, project_root_path: Path) -> bool: @@ -1047,13 +953,7 @@ def _is_project_module(module_name: str, project_root_path: Path) -> bool: else: if spec is None or spec.origin is None: return False - module_path = Path(spec.origin) - # Check if the module is in site-packages (external dependency) - # This must be checked first because .venv/site-packages is under project root - if path_belongs_to_site_packages(module_path): - return False - # Check if the module is within the project root - return str(module_path).startswith(str(project_root_path) + os.sep) + return is_project_path(Path(spec.origin), project_root_path) def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str: @@ -1135,78 +1035,6 @@ def is_dunder_method(name: str) -> bool: return len(name) > 4 and name.isascii() and name.startswith("__") and name.endswith("__") -class UsedNameCollector(cst.CSTVisitor): - """Collects all base names referenced in code (for import preservation).""" - - def __init__(self) -> None: - self.used_names: set[str] = set() - self.defined_names: set[str] = set() - - def visit_Name(self, node: cst.Name) -> None: - self.used_names.add(node.value) - - def visit_Attribute(self, node: cst.Attribute) -> bool | None: - base = node.value - while isinstance(base, cst.Attribute): - base = base.value - if isinstance(base, cst.Name): - self.used_names.add(base.value) - return True - - def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None: - self.defined_names.add(node.name.value) - return True - - def visit_ClassDef(self, node: cst.ClassDef) -> bool | None: - self.defined_names.add(node.name.value) - return True - - def visit_Assign(self, node: cst.Assign) -> bool | None: - for target in node.targets: - names = extract_names_from_targets(target.target) - self.defined_names.update(names) - return True - - def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None: - names = extract_names_from_targets(node.target) - self.defined_names.update(names) - return True - - def get_external_names(self) -> set[str]: - return self.used_names - self.defined_names - {"self", "cls"} - - -def get_imported_names(import_node: cst.Import | cst.ImportFrom) -> set[str]: - """Extract the names made available by an import statement.""" - names: set[str] = set() - if isinstance(import_node, cst.Import): - if isinstance(import_node.names, cst.ImportStar): - return {"*"} - for alias in import_node.names: - if isinstance(alias, cst.ImportAlias): - if alias.asname and isinstance(alias.asname.name, cst.Name): - names.add(alias.asname.name.value) - elif isinstance(alias.name, cst.Name): - names.add(alias.name.value) - elif isinstance(alias.name, cst.Attribute): - # import foo.bar -> accessible as "foo" - base: cst.BaseExpression = alias.name - while isinstance(base, cst.Attribute): - base = base.value - if isinstance(base, cst.Name): - names.add(base.value) - elif isinstance(import_node, cst.ImportFrom): - if isinstance(import_node.names, cst.ImportStar): - return {"*"} - for alias in import_node.names: - if isinstance(alias, cst.ImportAlias): - if alias.asname and isinstance(alias.asname.name, cst.Name): - names.add(alias.asname.name.value) - elif isinstance(alias.name, cst.Name): - names.add(alias.name.value) - return names - - def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode: """Removes the docstring from an indented block if it exists.""" if not isinstance(indented_block.body[0], cst.SimpleStatementLine): @@ -1229,27 +1057,31 @@ def parse_code_and_prune_cst( defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions) if code_context_type == CodeContextType.READ_WRITABLE: - filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions, defs_with_usages) + filtered_node, found_target = prune_cst( + module, target_functions, defs_with_usages=defs_with_usages, keep_class_init=True + ) elif code_context_type == CodeContextType.READ_ONLY: - filtered_node, found_target = prune_cst_for_context( + filtered_node, found_target = prune_cst( module, target_functions, - helpers_of_helper_functions, + helpers=helpers_of_helper_functions, remove_docstrings=remove_docstrings, include_target_in_output=False, - include_init_dunder=False, + include_dunder_methods=True, ) elif code_context_type == CodeContextType.TESTGEN: - filtered_node, found_target = prune_cst_for_context( + filtered_node, found_target = prune_cst( module, target_functions, - helpers_of_helper_functions, + helpers=helpers_of_helper_functions, remove_docstrings=remove_docstrings, - include_target_in_output=True, + include_dunder_methods=True, include_init_dunder=True, ) elif code_context_type == CodeContextType.HASHING: - filtered_node, found_target = prune_cst_for_code_hashing(module, target_functions) + filtered_node, found_target = prune_cst( + module, target_functions, remove_docstrings=True, exclude_init_from_targets=True + ) else: raise ValueError(f"Unknown code_context_type: {code_context_type}") # noqa: EM102 @@ -1263,234 +1095,46 @@ def parse_code_and_prune_cst( return "" -def prune_cst_for_read_writable_code( - node: cst.CSTNode, target_functions: set[str], defs_with_usages: dict[str, UsageInfo], prefix: str = "" -) -> tuple[cst.CSTNode | None, bool]: - """Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions. - - Returns - ------- - (filtered_node, found_target): - filtered_node: The modified CST node or None if it should be removed. - found_target: True if a target function was found in this node's subtree. - - """ - if isinstance(node, (cst.Import, cst.ImportFrom)): - return None, False - - if isinstance(node, cst.FunctionDef): - qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value - if qualified_name in target_functions: - return node, True - return None, False - - if isinstance(node, cst.ClassDef): - # Do not recurse into nested classes - if prefix: - return None, False - - class_name = node.name.value - - # Assuming always an IndentedBlock - if not isinstance(node.body, cst.IndentedBlock): - raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004 - class_prefix = f"{prefix}.{class_name}" if prefix else class_name - - # Check if this class contains any target functions - has_target_functions = any( - isinstance(stmt, cst.FunctionDef) and f"{class_prefix}.{stmt.name.value}" in target_functions - for stmt in node.body.body - ) - - # If the class is used as a dependency (not containing target functions), keep it entirely - # This handles cases like enums, dataclasses, and other types used by the target function - if ( - not has_target_functions - and class_name in defs_with_usages - and defs_with_usages[class_name].used_by_qualified_function - ): - return node, True - - new_body = [] - found_target = False - - for stmt in node.body.body: - if isinstance(stmt, cst.FunctionDef): - qualified_name = f"{class_prefix}.{stmt.name.value}" - if qualified_name in target_functions: - new_body.append(stmt) - found_target = True - elif stmt.name.value == "__init__": - new_body.append(stmt) # enable __init__ optimizations - # If no target functions found, remove the class entirely - if not new_body or not found_target: - return None, False - - return node.with_changes(body=cst.IndentedBlock(body=new_body)), found_target - - if isinstance(node, cst.Assign): - for target in node.targets: - names = extract_names_from_targets(target.target) - for name in names: - if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function: - return node, True - return None, False - - if isinstance(node, (cst.AnnAssign, cst.AugAssign)): - names = extract_names_from_targets(node.target) - for name in names: - if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function: - return node, True - return None, False - - # For other nodes, we preserve them only if they contain target functions in their children. - section_names = get_section_names(node) - if not section_names: - return node, False - - updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {} - found_any_target = False - - for section in section_names: - original_content = getattr(node, section, None) - if isinstance(original_content, (list, tuple)): - new_children = [] - section_found_target = False - for child in original_content: - filtered, found_target = prune_cst_for_read_writable_code( - child, target_functions, defs_with_usages, prefix - ) - if filtered: - new_children.append(filtered) - section_found_target |= found_target - - if section_found_target: - found_any_target = True - updates[section] = new_children - elif original_content is not None: - filtered, found_target = prune_cst_for_read_writable_code( - original_content, target_functions, defs_with_usages, prefix - ) - if found_target: - found_any_target = True - if filtered: - updates[section] = filtered - - if not found_any_target: - return None, False - return (node.with_changes(**updates) if updates else node), True - - -def prune_cst_for_code_hashing( - node: cst.CSTNode, target_functions: set[str], prefix: str = "" -) -> tuple[cst.CSTNode | None, bool]: - """Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions. +def _qualified_name(prefix: str, name: str) -> str: + return f"{prefix}.{name}" if prefix else name - Returns - ------- - (filtered_node, found_target): - filtered_node: The modified CST node or None if it should be removed. - found_target: True if a target function was found in this node's subtree. - """ - if isinstance(node, (cst.Import, cst.ImportFrom)): - return None, False - - if isinstance(node, cst.FunctionDef): - qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value - # For hashing, exclude __init__ methods even if in target_functions - # because they don't affect the semantic behavior being hashed - # But include other dunder methods like __call__ which do affect behavior - if qualified_name in target_functions and node.name.value != "__init__": - new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body - return node.with_changes(body=new_body), True - return None, False - - if isinstance(node, cst.ClassDef): - # Do not recurse into nested classes - if prefix: - return None, False - # Assuming always an IndentedBlock - if not isinstance(node.body, cst.IndentedBlock): - raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004 - class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value - new_class_body: list[cst.CSTNode] = [] - found_target = False - - for stmt in node.body.body: - if isinstance(stmt, cst.FunctionDef): - qualified_name = f"{class_prefix}.{stmt.name.value}" - # For hashing, exclude __init__ methods even if in target_functions - # but include other methods like __call__ which affect behavior - if qualified_name in target_functions and stmt.name.value != "__init__": - stmt_with_changes = stmt.with_changes( - body=remove_docstring_from_body(cast("cst.IndentedBlock", stmt.body)) - ) - new_class_body.append(stmt_with_changes) - found_target = True - # If no target functions found, remove the class entirely - if not new_class_body or not found_target: - return None, False - return node.with_changes( - body=cst.IndentedBlock(cast("list[cst.BaseStatement]", new_class_body)) - ) if new_class_body else None, found_target - - # For other nodes, we preserve them only if they contain target functions in their children. - section_names = get_section_names(node) - if not section_names: - return node, False - - updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {} - found_any_target = False - - for section in section_names: - original_content = getattr(node, section, None) - if isinstance(original_content, (list, tuple)): - new_children = [] - section_found_target = False - for child in original_content: - filtered, found_target = prune_cst_for_code_hashing(child, target_functions, prefix) - if filtered: - new_children.append(filtered) - section_found_target |= found_target - - if section_found_target: - found_any_target = True - updates[section] = new_children - elif original_content is not None: - filtered, found_target = prune_cst_for_code_hashing(original_content, target_functions, prefix) - if found_target: - found_any_target = True - if filtered: - updates[section] = filtered - - if not found_any_target: - return None, False - - return (node.with_changes(**updates) if updates else node), True +def _validate_classdef(node: cst.ClassDef, prefix: str) -> tuple[str, cst.IndentedBlock] | None: + if prefix: + return None + if not isinstance(node.body, cst.IndentedBlock): + raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004 + return _qualified_name(prefix, node.name.value), node.body -def prune_cst_for_context( +def prune_cst( node: cst.CSTNode, target_functions: set[str], - helpers_of_helper_functions: set[str], prefix: str = "", + *, + defs_with_usages: dict[str, UsageInfo] | None = None, + helpers: set[str] | None = None, remove_docstrings: bool = False, - include_target_in_output: bool = False, + include_target_in_output: bool = True, + exclude_init_from_targets: bool = False, + keep_class_init: bool = False, + include_dunder_methods: bool = False, include_init_dunder: bool = False, ) -> tuple[cst.CSTNode | None, bool]: - """Recursively filter the node for code context extraction. + """Unified function to prune CST nodes based on various filtering criteria. Args: node: The CST node to filter target_functions: Set of qualified function names that are targets - helpers_of_helper_functions: Set of helper function qualified names prefix: Current qualified name prefix (for class methods) + defs_with_usages: Dict of definitions with usage info (for READ_WRITABLE mode) + helpers: Set of helper function qualified names (for READ_ONLY/TESTGEN modes) remove_docstrings: Whether to remove docstrings from output - include_target_in_output: If True, include target functions in output (testgen mode) - If False, exclude target functions (read-only mode) - include_init_dunder: If True, include __init__ in dunder methods (testgen mode) - If False, exclude __init__ from dunder methods (read-only mode) + include_target_in_output: Whether to include target functions in output + exclude_init_from_targets: Whether to exclude __init__ from targets (HASHING mode) + keep_class_init: Whether to keep __init__ methods in classes (READ_WRITABLE mode) + include_dunder_methods: Whether to include dunder methods (READ_ONLY/TESTGEN modes) + include_init_dunder: Whether to include __init__ in dunder methods Returns: (filtered_node, found_target): @@ -1502,25 +1146,34 @@ def prune_cst_for_context( return None, False if isinstance(node, cst.FunctionDef): - qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value + qualified_name = _qualified_name(prefix, node.name.value) - # Check if it's a helper of helper function - if qualified_name in helpers_of_helper_functions: + # Check if it's a helper function (higher priority than target) + if helpers and qualified_name in helpers: if remove_docstrings and isinstance(node.body, cst.IndentedBlock): return node.with_changes(body=remove_docstring_from_body(node.body)), True return node, True # Check if it's a target function if qualified_name in target_functions: + # Handle exclude_init_from_targets for HASHING mode + if exclude_init_from_targets and node.name.value == "__init__": + return None, False + if include_target_in_output: if remove_docstrings and isinstance(node.body, cst.IndentedBlock): return node.with_changes(body=remove_docstring_from_body(node.body)), True return node, True return None, True - # Check dunder methods - # For read-only mode, exclude __init__; for testgen mode, include all dunders - if is_dunder_method(node.name.value) and (include_init_dunder or node.name.value != "__init__"): + # Handle class __init__ for READ_WRITABLE mode + if keep_class_init and node.name.value == "__init__": + return node, False + + # Handle dunder methods for READ_ONLY/TESTGEN modes + if include_dunder_methods and is_dunder_method(node.name.value): + if not include_init_dunder and node.name.value == "__init__": + return None, False if remove_docstrings and isinstance(node.body, cst.IndentedBlock): return node.with_changes(body=remove_docstring_from_body(node.body)), False return node, False @@ -1528,26 +1181,44 @@ def prune_cst_for_context( return None, False if isinstance(node, cst.ClassDef): - # Do not recurse into nested classes - if prefix: + result = _validate_classdef(node, prefix) + if result is None: return None, False - # Assuming always an IndentedBlock - if not isinstance(node.body, cst.IndentedBlock): - raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004 + class_prefix, _ = result + class_name = node.name.value - class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value + # Handle dependency classes for READ_WRITABLE mode + if defs_with_usages: + # Check if this class contains any target functions + has_target_functions = any( + isinstance(stmt, cst.FunctionDef) and _qualified_name(class_prefix, stmt.name.value) in target_functions + for stmt in node.body.body + ) - # First pass: detect if there is a target function in the class + # If the class is used as a dependency (not containing target functions), keep it entirely + if ( + not has_target_functions + and class_name in defs_with_usages + and defs_with_usages[class_name].used_by_qualified_function + ): + return node, True + + # Recursively filter each statement in the class body + new_class_body: list[cst.CSTNode] = [] found_in_class = False - new_class_body: list[CSTNode] = [] + for stmt in node.body.body: - filtered, found_target = prune_cst_for_context( + filtered, found_target = prune_cst( stmt, target_functions, - helpers_of_helper_functions, class_prefix, + defs_with_usages=defs_with_usages, + helpers=helpers, remove_docstrings=remove_docstrings, include_target_in_output=include_target_in_output, + exclude_init_from_targets=exclude_init_from_targets, + keep_class_init=keep_class_init, + include_dunder_methods=include_dunder_methods, include_init_dunder=include_init_dunder, ) found_in_class |= found_target @@ -1557,57 +1228,59 @@ def prune_cst_for_context( if not found_in_class: return None, False - if remove_docstrings: - return node.with_changes( - body=remove_docstring_from_body(node.body.with_changes(body=new_class_body)) - ) if new_class_body else None, True + # Apply docstring removal to class if needed + if remove_docstrings and new_class_body: + updated_body = node.body.with_changes(body=new_class_body) + assert isinstance(updated_body, cst.IndentedBlock) + return node.with_changes(body=remove_docstring_from_body(updated_body)), True + return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True - # For other nodes, keep the node and recursively filter children + # Handle assignments for READ_WRITABLE mode + if defs_with_usages is not None: + if isinstance(node, (cst.Assign, cst.AnnAssign, cst.AugAssign)): + if is_assignment_used(node, defs_with_usages): + return node, True + return None, False + + # For other nodes, recursively process children section_names = get_section_names(node) if not section_names: return node, False - updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {} - found_any_target = False - - for section in section_names: - original_content = getattr(node, section, None) - if isinstance(original_content, (list, tuple)): - new_children = [] - section_found_target = False - for child in original_content: - filtered, found_target = prune_cst_for_context( - child, - target_functions, - helpers_of_helper_functions, - prefix, - remove_docstrings=remove_docstrings, - include_target_in_output=include_target_in_output, - include_init_dunder=include_init_dunder, - ) - if filtered: - new_children.append(filtered) - section_found_target |= found_target - - if section_found_target or new_children: - found_any_target |= section_found_target - updates[section] = new_children - elif original_content is not None: - filtered, found_target = prune_cst_for_context( - original_content, + if helpers is not None: + return recurse_sections( + node, + section_names, + lambda child: prune_cst( + child, target_functions, - helpers_of_helper_functions, prefix, + defs_with_usages=defs_with_usages, + helpers=helpers, remove_docstrings=remove_docstrings, include_target_in_output=include_target_in_output, + exclude_init_from_targets=exclude_init_from_targets, + keep_class_init=keep_class_init, + include_dunder_methods=include_dunder_methods, include_init_dunder=include_init_dunder, - ) - found_any_target |= found_target - if filtered: - updates[section] = filtered - - if updates: - return (node.with_changes(**updates), found_any_target) - - return None, False + ), + keep_non_target_children=True, + ) + return recurse_sections( + node, + section_names, + lambda child: prune_cst( + child, + target_functions, + prefix, + defs_with_usages=defs_with_usages, + helpers=helpers, + remove_docstrings=remove_docstrings, + include_target_in_output=include_target_in_output, + exclude_init_from_targets=exclude_init_from_targets, + keep_class_init=keep_class_init, + include_dunder_methods=include_dunder_methods, + include_init_dunder=include_init_dunder, + ), + ) diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/languages/python/context/unused_definition_remover.py similarity index 93% rename from codeflash/context/unused_definition_remover.py rename to codeflash/languages/python/context/unused_definition_remover.py index 5baa51afe..b3d13405f 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/languages/python/context/unused_definition_remover.py @@ -15,6 +15,8 @@ from codeflash.models.models import CodeString, CodeStringsMarkdown if TYPE_CHECKING: + from collections.abc import Callable + from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import CodeOptimizationContext, FunctionSource @@ -49,6 +51,69 @@ def extract_names_from_targets(target: cst.CSTNode) -> list[str]: return names +def is_assignment_used(node: cst.CSTNode, definitions: dict[str, UsageInfo], name_prefix: str = "") -> bool: + if isinstance(node, cst.Assign): + for target in node.targets: + names = extract_names_from_targets(target.target) + for name in names: + lookup = f"{name_prefix}{name}" if name_prefix else name + if lookup in definitions and definitions[lookup].used_by_qualified_function: + return True + return False + if isinstance(node, (cst.AnnAssign, cst.AugAssign)): + names = extract_names_from_targets(node.target) + for name in names: + lookup = f"{name_prefix}{name}" if name_prefix else name + if lookup in definitions and definitions[lookup].used_by_qualified_function: + return True + return False + return False + + +def recurse_sections( + node: cst.CSTNode, + section_names: list[str], + prune_fn: Callable[[cst.CSTNode], tuple[cst.CSTNode | None, bool]], + keep_non_target_children: bool = False, +) -> tuple[cst.CSTNode | None, bool]: + updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {} + found_any_target = False + for section in section_names: + original_content = getattr(node, section, None) + if isinstance(original_content, (list, tuple)): + new_children = [] + section_found_target = False + for child in original_content: + filtered, found_target = prune_fn(child) + if filtered: + new_children.append(filtered) + section_found_target |= found_target + if keep_non_target_children: + if section_found_target or new_children: + found_any_target |= section_found_target + updates[section] = new_children + elif section_found_target: + found_any_target = True + updates[section] = new_children + elif original_content is not None: + filtered, found_target = prune_fn(original_content) + if keep_non_target_children: + found_any_target |= found_target + if filtered: + updates[section] = filtered + elif found_target: + found_any_target = True + if filtered: + updates[section] = filtered + if keep_non_target_children: + if updates: + return node.with_changes(**updates), found_any_target + return None, False + if not found_any_target: + return None, False + return (node.with_changes(**updates) if updates else node), True + + def collect_top_level_definitions( node: cst.CSTNode, definitions: Optional[dict[str, UsageInfo]] = None ) -> dict[str, UsageInfo]: @@ -423,27 +488,9 @@ def remove_unused_definitions_recursively( elif isinstance(statement, (cst.Assign, cst.AnnAssign, cst.AugAssign)): var_used = False - # Check if any variable in this assignment is used - if isinstance(statement, cst.Assign): - for target in statement.targets: - names = extract_names_from_targets(target.target) - for name in names: - class_var_name = f"{class_name}.{name}" - if ( - class_var_name in definitions - and definitions[class_var_name].used_by_qualified_function - ): - var_used = True - method_or_var_used = True - break - elif isinstance(statement, (cst.AnnAssign, cst.AugAssign)): - names = extract_names_from_targets(statement.target) - for name in names: - class_var_name = f"{class_name}.{name}" - if class_var_name in definitions and definitions[class_var_name].used_by_qualified_function: - var_used = True - method_or_var_used = True - break + if is_assignment_used(statement, definitions, name_prefix=f"{class_name}."): + var_used = True + method_or_var_used = True if var_used or class_has_dependencies: new_statements.append(statement) @@ -459,56 +506,19 @@ def remove_unused_definitions_recursively( return node, method_or_var_used or class_has_dependencies - # Handle assignments (Assign and AnnAssign) - if isinstance(node, cst.Assign): - for target in node.targets: - names = extract_names_from_targets(target.target) - for name in names: - if name in definitions and definitions[name].used_by_qualified_function: - return node, True - return None, False - - if isinstance(node, (cst.AnnAssign, cst.AugAssign)): - names = extract_names_from_targets(node.target) - for name in names: - if name in definitions and definitions[name].used_by_qualified_function: - return node, True + # Handle assignments (Assign, AnnAssign, AugAssign) + if isinstance(node, (cst.Assign, cst.AnnAssign, cst.AugAssign)): + if is_assignment_used(node, definitions): + return node, True return None, False # For other nodes, recursively process children section_names = get_section_names(node) if not section_names: return node, False - - updates = {} - found_used = False - - for section in section_names: - original_content = getattr(node, section, None) - if isinstance(original_content, (list, tuple)): - new_children = [] - section_found_used = False - - for child in original_content: - filtered, used = remove_unused_definitions_recursively(child, definitions) - if filtered: - new_children.append(filtered) - section_found_used |= used - - if new_children or section_found_used: - found_used |= section_found_used - updates[section] = new_children - elif original_content is not None: - filtered, used = remove_unused_definitions_recursively(original_content, definitions) - found_used |= used - if filtered: - updates[section] = filtered - if not found_used: - return None, False - if updates: - return node.with_changes(**updates), found_used - - return node, False + return recurse_sections( + node, section_names, lambda child: remove_unused_definitions_recursively(child, definitions) + ) def collect_top_level_defs_with_usages( diff --git a/codeflash/languages/python/support.py b/codeflash/languages/python/support.py index 052e64064..65536a4bc 100644 --- a/codeflash/languages/python/support.py +++ b/codeflash/languages/python/support.py @@ -21,9 +21,25 @@ if TYPE_CHECKING: from collections.abc import Sequence + from codeflash.models.models import FunctionSource + logger = logging.getLogger(__name__) +def function_sources_to_helpers(sources: list[FunctionSource]) -> list[HelperFunction]: + return [ + HelperFunction( + name=fs.only_function_name, + qualified_name=fs.qualified_name, + file_path=fs.file_path, + source_code=fs.source_code, + start_line=fs.jedi_definition.line if fs.jedi_definition else 1, + end_line=fs.jedi_definition.line if fs.jedi_definition else 1, + ) + for fs in sources + ] + + @register_language class PythonSupport: """Python language support implementation. @@ -171,127 +187,39 @@ def discover_tests( # === Code Analysis === def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext: - """Extract function code and its dependencies. - - Uses jedi and libcst for Python code analysis. - - Args: - function: The function to extract context for. - project_root: Root of the project. - module_root: Root of the module containing the function. - - Returns: - CodeContext with target code and dependencies. + """Extract function code and its dependencies via the canonical context pipeline.""" + from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context - """ try: - source = function.file_path.read_text() + result = get_code_optimization_context(function, project_root) except Exception as e: - logger.exception("Failed to read %s: %s", function.file_path, e) + logger.warning("Failed to extract code context for %s: %s", function.function_name, e) return CodeContext(target_code="", target_file=function.file_path, language=Language.PYTHON) - # Extract the function source - lines = source.splitlines(keepends=True) - if function.starting_line and function.ending_line: - target_lines = lines[function.starting_line - 1 : function.ending_line] - target_code = "".join(target_lines) - else: - target_code = "" - - # Find helper functions - helpers = self.find_helper_functions(function, project_root) - - # Extract imports - import_lines = [] - for line in lines: - stripped = line.strip() - if stripped.startswith(("import ", "from ")): - import_lines.append(stripped) - elif stripped and not stripped.startswith("#"): - # Stop at first non-import, non-comment line - break + helpers = function_sources_to_helpers(result.helper_functions) return CodeContext( - target_code=target_code, + target_code=result.read_writable_code.markdown, target_file=function.file_path, helper_functions=helpers, - read_only_context="", - imports=import_lines, + read_only_context=result.read_only_context_code, + imports=[], language=Language.PYTHON, ) def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]: - """Find helper functions called by the target function. - - Uses jedi for Python code analysis. - - Args: - function: The target function to analyze. - project_root: Root of the project. - - Returns: - List of HelperFunction objects. - - """ - helpers: list[HelperFunction] = [] + """Find helper functions called by the target function via the canonical jedi pipeline.""" + from codeflash.languages.python.context.code_context_extractor import get_function_sources_from_jedi try: - import jedi - - from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages - from codeflash.optimization.function_context import belongs_to_function_qualified - - script = jedi.Script(path=function.file_path, project=jedi.Project(path=project_root)) - file_refs = script.get_names(all_scopes=True, definitions=False, references=True) - - qualified_name = function.qualified_name - - for ref in file_refs: - if not ref.full_name or not belongs_to_function_qualified(ref, qualified_name): - continue - - try: - definitions = ref.goto(follow_imports=True, follow_builtin_imports=False) - except Exception: - continue - - for definition in definitions: - definition_path = definition.module_path - if definition_path is None: - continue - - # Check if it's a valid helper (in project, not in target function) - is_valid = ( - str(definition_path).startswith(str(project_root)) - and not path_belongs_to_site_packages(definition_path) - and definition.full_name - and not belongs_to_function_qualified(definition, qualified_name) - and definition.type == "function" - ) - - if is_valid: - helper_qualified_name = get_qualified_name(definition.module_name, definition.full_name) - # Get source code - try: - helper_source = definition.get_line_code() - except Exception: - helper_source = "" - - helpers.append( - HelperFunction( - name=definition.name, - qualified_name=helper_qualified_name, - file_path=definition_path, - source_code=helper_source, - start_line=definition.line or 1, - end_line=definition.line or 1, - ) - ) - + _dict, sources = get_function_sources_from_jedi( + {function.file_path: {function.qualified_name}}, project_root + ) except Exception as e: logger.warning("Failed to find helpers for %s: %s", function.function_name, e) + return [] - return helpers + return function_sources_to_helpers(sources) def find_references( self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 500 @@ -728,15 +656,6 @@ def get_test_file_suffix(self) -> str: """ return ".py" - def get_comment_prefix(self) -> str: - """Get the comment prefix for Python. - - Returns: - Python single-line comment prefix. - - """ - return "#" - def find_test_root(self, project_root: Path) -> Path | None: """Find the test root directory for a Python project. diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 927e6ed9f..bd27a1a7c 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Callable import libcst as cst +from git import Repo as GitRepo from rich.console import Group from rich.panel import Panel from rich.syntax import Syntax @@ -71,8 +72,6 @@ from codeflash.code_utils.shell_utils import make_env_with_project_root from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast from codeflash.code_utils.time_utils import humanize_runtime -from codeflash.context import code_context_extractor -from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions from codeflash.discovery.functions_to_optimize import was_function_previously_optimized from codeflash.either import Failure, Success, is_successful from codeflash.languages import is_java, is_javascript, is_python @@ -80,6 +79,11 @@ from codeflash.languages.current import current_language_support, is_typescript from codeflash.languages.javascript.module_system import detect_module_system from codeflash.languages.javascript.test_runner import clear_created_config_files, get_created_config_files +from codeflash.languages.python.context import code_context_extractor +from codeflash.languages.python.context.unused_definition_remover import ( + detect_unused_helper_functions, + revert_unused_helper_functions, +) from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId from codeflash.models.ExperimentMetadata import ExperimentMetadata @@ -2231,6 +2235,7 @@ def setup_and_establish_baseline( if self.args.override_fixtures: restore_conftest(original_conftest_content) cleanup_paths(paths_to_cleanup) + self.cleanup_async_helper_file() return Failure(baseline_result.failure()) original_code_baseline, test_functions_to_remove = baseline_result.unwrap() @@ -2242,6 +2247,7 @@ def setup_and_establish_baseline( if self.args.override_fixtures: restore_conftest(original_conftest_content) cleanup_paths(paths_to_cleanup) + self.cleanup_async_helper_file() return Failure("The threshold for test confidence was not met.") return Success( @@ -2411,7 +2417,10 @@ def process_review( generated_tests_str = "" code_lang = self.function_to_optimize.language for test in generated_tests.generated_tests: - if map_gen_test_file_to_no_of_tests[test.behavior_file_path] > 0: + if any( + test_file.name == test.behavior_file_path.name and count > 0 + for test_file, count in map_gen_test_file_to_no_of_tests.items() + ): formatted_generated_test = format_generated_code( test.generated_original_test_source, self.args.formatter_cmds ) @@ -2551,11 +2560,11 @@ def process_review( console.print(Panel(panel_content, title="Optimization Review", border_style=display_info[1])) if raise_pr or staging_review: - data["root_dir"] = git_root_dir() + data["root_dir"] = git_root_dir(GitRepo(str(self.args.module_root), search_parent_directories=True)) if raise_pr and not staging_review and opt_review_result.review != "low": # Ensure root_dir is set for PR creation (needed for async functions that skip opt_review) if "root_dir" not in data: - data["root_dir"] = git_root_dir() + data["root_dir"] = git_root_dir(GitRepo(str(self.args.module_root), search_parent_directories=True)) data["git_remote"] = self.args.git_remote # Remove language from data dict as check_create_pr doesn't accept it pr_data = {k: v for k, v in data.items() if k != "language"} @@ -2610,6 +2619,13 @@ def revert_code_and_helpers(self, original_helper_code: dict[Path, str]) -> None self.write_code_and_helpers( self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path ) + self.cleanup_async_helper_file() + + def cleanup_async_helper_file(self) -> None: + from codeflash.code_utils.instrument_existing_tests import ASYNC_HELPER_FILENAME + + helper_path = self.project_root / ASYNC_HELPER_FILENAME + helper_path.unlink(missing_ok=True) def establish_original_code_baseline( self, @@ -2627,7 +2643,10 @@ def establish_original_code_baseline( from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function success = add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR + self.function_to_optimize.file_path, + self.function_to_optimize, + TestingMode.BEHAVIOR, + project_root=self.project_root, ) # Instrument codeflash capture @@ -2692,7 +2711,10 @@ def establish_original_code_baseline( from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE + self.function_to_optimize.file_path, + self.function_to_optimize, + TestingMode.PERFORMANCE, + project_root=self.project_root, ) try: @@ -2866,7 +2888,10 @@ def run_optimized_candidate( from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR + self.function_to_optimize.file_path, + self.function_to_optimize, + TestingMode.BEHAVIOR, + project_root=self.project_root, ) try: @@ -2961,7 +2986,10 @@ def run_optimized_candidate( from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE + self.function_to_optimize.file_path, + self.function_to_optimize, + TestingMode.PERFORMANCE, + project_root=self.project_root, ) try: @@ -3330,7 +3358,10 @@ def run_concurrency_benchmark( try: # Add concurrency decorator to the source function add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.CONCURRENCY + self.function_to_optimize.file_path, + self.function_to_optimize, + TestingMode.CONCURRENCY, + project_root=self.project_root, ) # Run the concurrency benchmark tests diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index de7e5a8d4..e3c564ecb 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -183,18 +183,54 @@ def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize] """Discover functions to optimize.""" from codeflash.discovery.functions_to_optimize import get_functions_to_optimize - return get_functions_to_optimize( + # In worktree mode for git-diff discovery, file paths come from the original repo + # (via get_git_diff using cwd), but module_root/project_root have been mirrored to + # the worktree. Use the original roots for filtering so path comparisons match, + # then remap the discovered file paths to the worktree. + project_root = self.args.project_root + module_root = self.args.module_root + use_original_roots = ( + self.current_worktree and self.original_args_and_test_cfg and not self.args.all and not self.args.file + ) + if use_original_roots: + assert self.original_args_and_test_cfg is not None + original_args, _ = self.original_args_and_test_cfg + project_root = original_args.project_root + module_root = original_args.module_root + + result = get_functions_to_optimize( optimize_all=self.args.all, replay_test=self.args.replay_test, file=self.args.file, only_get_this_function=self.args.function, test_cfg=self.test_cfg, ignore_paths=self.args.ignore_paths, - project_root=self.args.project_root, - module_root=self.args.module_root, + project_root=project_root, + module_root=module_root, previous_checkpoint_functions=self.args.previous_checkpoint_functions, ) + # Remap discovered file paths from the original repo to the worktree so + # downstream optimization reads/writes happen in the worktree. + if use_original_roots: + import dataclasses + + assert self.current_worktree is not None + original_git_root = git_root_dir() + file_to_funcs, count, trace = result + remapped: dict[Path, list[FunctionToOptimize]] = {} + for file_path, funcs in file_to_funcs.items(): + new_path = mirror_path(Path(file_path), original_git_root, self.current_worktree) + remapped[new_path] = [ + dataclasses.replace( + func, file_path=mirror_path(func.file_path, original_git_root, self.current_worktree) + ) + for func in funcs + ] + return remapped, count, trace + + return result + def create_function_optimizer( self, function_to_optimize: FunctionToOptimize, diff --git a/codeflash/verification/parse_line_profile_test_output.py b/codeflash/verification/parse_line_profile_test_output.py index 34b27bdb3..4ef799425 100644 --- a/codeflash/verification/parse_line_profile_test_output.py +++ b/codeflash/verification/parse_line_profile_test_output.py @@ -6,16 +6,14 @@ import json import linecache import os -from typing import TYPE_CHECKING, Optional +from pathlib import Path +from typing import Optional import dill as pickle from codeflash.code_utils.tabulate import tabulate from codeflash.languages import is_python -if TYPE_CHECKING: - from pathlib import Path - def show_func( filename: str, start_lineno: int, func_name: str, timings: list[tuple[int, int, float]], unit: float @@ -80,9 +78,7 @@ def show_text(stats: dict) -> str: return out_table -def show_text_non_python( - stats: dict, line_contents: dict[tuple[str, int], str] -) -> str: +def show_text_non_python(stats: dict, line_contents: dict[tuple[str, int], str]) -> str: """Show text for non-Python timings using profiler-provided line contents.""" out_table = "" out_table += "# Timer unit: {:g} s\n".format(stats["unit"]) @@ -100,13 +96,13 @@ def show_text_non_python( table_rows = [] for lineno, nhits, time in timings: percent = "" if total_time == 0 else "%5.1f" % (100 * time / total_time) - time_disp = "%5.1f" % time + time_disp = f"{time:5.1f}" if len(time_disp) > default_column_sizes["time"]: - time_disp = "%5.1g" % time + time_disp = f"{time:5.1g}" perhit = (float(time) / nhits) if nhits > 0 else 0.0 - perhit_disp = "%5.1f" % perhit + perhit_disp = f"{perhit:5.1f}" if len(perhit_disp) > default_column_sizes["perhit"]: - perhit_disp = "%5.1g" % perhit + perhit_disp = f"{perhit:5.1g}" nhits_disp = "%d" % nhits # noqa: UP031 if len(nhits_disp) > default_column_sizes["hits"]: nhits_disp = f"{nhits:g}" @@ -115,11 +111,7 @@ def show_text_non_python( table_cols = ("Hits", "Time", "Per Hit", "% Time", "Line Contents") out_table += tabulate( - headers=table_cols, - tabular_data=table_rows, - tablefmt="pipe", - colglobalalign=None, - preserve_whitespace=True, + headers=table_cols, tabular_data=table_rows, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True ) out_table += "\n" return out_table @@ -159,9 +151,7 @@ def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> dic line_num = int(line_str) line_num = int(line_num) - lines_by_file.setdefault(file_path, []).append( - (line_num, int(stats.get("hits", 0)), int(stats.get("time", 0))) - ) + lines_by_file.setdefault(file_path, []).append((line_num, int(stats.get("hits", 0)), int(stats.get("time", 0)))) line_contents[(file_path, line_num)] = stats.get("content", "") for file_path, line_stats in lines_by_file.items(): @@ -169,7 +159,7 @@ def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> dic if not sorted_line_stats: continue start_lineno = sorted_line_stats[0][0] - grouped_timings[(file_path, start_lineno, os.path.basename(file_path))] = sorted_line_stats + grouped_timings[(file_path, start_lineno, Path(file_path).name)] = sorted_line_stats stats_dict["timings"] = grouped_timings stats_dict["unit"] = 1e-9 diff --git a/codeflash/version.py b/codeflash/version.py index 6d60ab0c2..5c0c09b55 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ -1,2 +1,2 @@ # These version placeholders will be replaced by uv-dynamic-versioning during build. -__version__ = "0.20.0.post510.dev0+b8932209" +__version__ = "0.20.1" diff --git a/mypy_allowlist.txt b/mypy_allowlist.txt index 6a070b606..e08b14e22 100644 --- a/mypy_allowlist.txt +++ b/mypy_allowlist.txt @@ -6,8 +6,8 @@ codeflash/result/explanation.py codeflash/result/critic.py codeflash/version.py codeflash/optimization/__init__.py -codeflash/context/__init__.py -codeflash/context/code_context_extractor.py +codeflash/languages/python/context/__init__.py +codeflash/languages/python/context/code_context_extractor.py codeflash/discovery/__init__.py codeflash/__init__.py codeflash/models/ExperimentMetadata.py diff --git a/packages/codeflash/runtime/capture.js b/packages/codeflash/runtime/capture.js index 0fdcc5784..d5489aa37 100644 --- a/packages/codeflash/runtime/capture.js +++ b/packages/codeflash/runtime/capture.js @@ -113,21 +113,26 @@ function checkSharedTimeLimit() { /** * Get the current loop index for a specific invocation. - * The loop index represents how many times ALL test files have been run through. - * This is the batch count from the loop-runner. + * When using external loop-runner (Jest), returns the batch number directly. + * When using internal looping (Vitest), tracks and returns the invocation count. + * * @param {string} invocationKey - Unique key for this test invocation - * @returns {number} The current batch number (loop index) + * @returns {number} The loop index for timing markers (1-based) */ function getInvocationLoopIndex(invocationKey) { - // Track local loop count for stopping logic (increments on each call) + // When using external loop-runner, use the batch number directly + // This is reliable because Jest resets module state between batches + const currentBatch = process.env.CODEFLASH_PERF_CURRENT_BATCH; + if (currentBatch !== undefined) { + return parseInt(currentBatch, 10); + } + + // For internal looping (Vitest), track the count locally if (!sharedPerfState.invocationLoopCounts[invocationKey]) { sharedPerfState.invocationLoopCounts[invocationKey] = 0; } ++sharedPerfState.invocationLoopCounts[invocationKey]; - - // Return the batch number as the loop index for timing markers - // This represents how many times all test files have been run through - return parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '1', 10); + return sharedPerfState.invocationLoopCounts[invocationKey]; } /** @@ -693,11 +698,9 @@ function capturePerf(funcName, lineId, fn, ...args) { // If not set, we're in Vitest mode and need to do all loops internally const hasExternalLoopRunner = process.env.CODEFLASH_PERF_CURRENT_BATCH !== undefined; - // Batched looping: run BATCH_SIZE loops per capturePerf call when using loop-runner + // When using external loop-runner (Jest), execute only once per call - the loop-runner handles batching // For Vitest (no loop-runner), do all loops internally in a single call - const batchSize = shouldLoop - ? (hasExternalLoopRunner ? getPerfBatchSize() : getPerfLoopCount()) - : 1; + const batchSize = hasExternalLoopRunner ? 1 : (shouldLoop ? getPerfLoopCount() : 1); // Initialize runtime tracking for this invocation if needed if (!sharedPerfState.invocationRuntimes[invocationKey]) { @@ -710,21 +713,21 @@ function capturePerf(funcName, lineId, fn, ...args) { for (let batchIndex = 0; batchIndex < batchSize; batchIndex++) { // Check shared time limit BEFORE each iteration - if (shouldLoop && checkSharedTimeLimit()) { + if (!hasExternalLoopRunner && shouldLoop && checkSharedTimeLimit()) { break; } // Check if this invocation has already reached stability - if (getPerfStabilityCheck() && sharedPerfState.stableInvocations[invocationKey]) { + if (!hasExternalLoopRunner && getPerfStabilityCheck() && sharedPerfState.stableInvocations[invocationKey]) { break; } - // Get the loop index (batch number) for timing markers + // Get the loop index for timing markers const loopIndex = getInvocationLoopIndex(invocationKey); // Check if we've exceeded max loops for this invocation const totalIterations = getTotalIterations(invocationKey); - if (totalIterations > getPerfLoopCount()) { + if (!hasExternalLoopRunner && totalIterations > getPerfLoopCount()) { break; } @@ -776,7 +779,7 @@ function capturePerf(funcName, lineId, fn, ...args) { } // Check stability after accumulating enough samples - if (getPerfStabilityCheck() && runtimes.length >= getPerfMinLoops()) { + if (!hasExternalLoopRunner && getPerfStabilityCheck() && runtimes.length >= getPerfMinLoops()) { const window = getStabilityWindow(); if (shouldStopStability(runtimes, window, getPerfMinLoops())) { sharedPerfState.stableInvocations[invocationKey] = true; @@ -785,7 +788,7 @@ function capturePerf(funcName, lineId, fn, ...args) { } // If we had an error, stop looping - if (lastError) { + if (!hasExternalLoopRunner && lastError) { break; } } diff --git a/packages/codeflash/runtime/loop-runner.js b/packages/codeflash/runtime/loop-runner.js index c6d476f1f..fc0b88f32 100644 --- a/packages/codeflash/runtime/loop-runner.js +++ b/packages/codeflash/runtime/loop-runner.js @@ -35,69 +35,113 @@ const path = require('path'); const fs = require('fs'); /** - * Validates that a jest-runner path is valid by checking for package.json. - * @param {string} jestRunnerPath - Path to check - * @returns {boolean} True if valid jest-runner package + * Recursively find jest-runner package in node_modules. + * Works with any package manager (npm, yarn, pnpm) by searching for + * jest-runner/package.json anywhere in the tree. + * + * @param {string} nodeModulesPath - Path to node_modules directory + * @param {number} maxDepth - Maximum recursion depth (default: 5) + * @returns {string|null} Path to jest-runner or null if not found */ -function isValidJestRunnerPath(jestRunnerPath) { - if (!fs.existsSync(jestRunnerPath)) { - return false; +function findJestRunnerRecursive(nodeModulesPath, maxDepth = 5) { + function search(dir, depth) { + if (depth > maxDepth || !fs.existsSync(dir)) return null; + + try { + let entries = fs.readdirSync(dir, { withFileTypes: true }); + + // Sort entries: prefer higher versions for jest-runner@X.Y.Z directories + entries = entries.slice().sort((a, b) => { + const aMatch = a.name.match(/^jest-runner@(\d+)/); + const bMatch = b.name.match(/^jest-runner@(\d+)/); + if (aMatch && bMatch) { + return parseInt(bMatch[1], 10) - parseInt(aMatch[1], 10); + } + return a.name.localeCompare(b.name); + }); + + for (const entry of entries) { + if (!entry.isDirectory()) continue; + + const entryPath = path.join(dir, entry.name); + + // Found jest-runner directory - check if it's a valid package + if (entry.name === 'jest-runner') { + const pkgJsonPath = path.join(entryPath, 'package.json'); + if (fs.existsSync(pkgJsonPath)) { + try { + const pkgJson = JSON.parse(fs.readFileSync(pkgJsonPath, 'utf8')); + if (pkgJson.name === 'jest-runner') { + return entryPath; + } + } catch (e) { + // Ignore JSON parse errors + } + } + } + + // Recurse into: + // - node_modules subdirectories + // - scoped packages (@org/pkg) + // - hidden directories (.pnpm, .yarn, etc.) + // - pnpm versioned directories (jest-runner@30.0.5) + const shouldRecurse = entry.name === 'node_modules' || + entry.name.startsWith('@') || + entry.name === '.pnpm' || entry.name === '.yarn' || + entry.name.startsWith('jest-runner@'); + + if (shouldRecurse) { + const result = search(entryPath, depth + 1); + if (result) return result; + } + } + } catch (e) { + // Ignore permission errors + } + + return null; } - const packageJsonPath = path.join(jestRunnerPath, 'package.json'); - return fs.existsSync(packageJsonPath); + + return search(nodeModulesPath, 0); } /** - * Resolve jest-runner with monorepo support. - * Uses CODEFLASH_MONOREPO_ROOT environment variable if available, - * otherwise walks up the directory tree looking for node_modules/jest-runner. + * Resolve jest-runner from the PROJECT's node_modules (not codeflash's). + * + * Uses recursive search to find jest-runner anywhere in node_modules, + * working with any package manager (npm, yarn, pnpm). * * @returns {string} Path to jest-runner package * @throws {Error} If jest-runner cannot be found */ function resolveJestRunner() { - // Try standard resolution first (works in simple projects) - try { - return require.resolve('jest-runner'); - } catch (e) { - // Standard resolution failed - try monorepo-aware resolution - } + const monorepoMarkers = ['yarn.lock', 'pnpm-workspace.yaml', 'lerna.json', 'package-lock.json']; + + // Walk up from cwd to find all potential node_modules locations + let currentDir = process.cwd(); + const visitedDirs = new Set(); // If Python detected a monorepo root, check there first const monorepoRoot = process.env.CODEFLASH_MONOREPO_ROOT; - if (monorepoRoot) { - const jestRunnerPath = path.join(monorepoRoot, 'node_modules', 'jest-runner'); - if (isValidJestRunnerPath(jestRunnerPath)) { - return jestRunnerPath; - } + if (monorepoRoot && !visitedDirs.has(monorepoRoot)) { + visitedDirs.add(monorepoRoot); + const result = findJestRunnerRecursive(path.join(monorepoRoot, 'node_modules')); + if (result) return result; } - // Fallback: Walk up from cwd looking for node_modules/jest-runner - const monorepoMarkers = ['yarn.lock', 'pnpm-workspace.yaml', 'lerna.json', 'package-lock.json']; - let currentDir = process.cwd(); - const visitedDirs = new Set(); - while (currentDir !== path.dirname(currentDir)) { - // Avoid infinite loops if (visitedDirs.has(currentDir)) break; visitedDirs.add(currentDir); - // Try node_modules/jest-runner at this level - const jestRunnerPath = path.join(currentDir, 'node_modules', 'jest-runner'); - if (isValidJestRunnerPath(jestRunnerPath)) { - return jestRunnerPath; - } + const result = findJestRunnerRecursive(path.join(currentDir, 'node_modules')); + if (result) return result; - // Check if this is a workspace root (has monorepo markers) + // Check if this is a workspace root - stop after this const isWorkspaceRoot = monorepoMarkers.some(marker => fs.existsSync(path.join(currentDir, marker)) ); - if (isWorkspaceRoot) { - // Found workspace root but no jest-runner - stop searching - break; - } - + if (isWorkspaceRoot) break; currentDir = path.dirname(currentDir); } @@ -120,10 +164,15 @@ let jestVersion = 0; try { const jestRunnerPath = resolveJestRunner(); - const internalRequire = createRequire(jestRunnerPath); - // Try to get the TestRunner class (Jest 30+) - const jestRunner = internalRequire(jestRunnerPath); + // Read the package.json to find the actual entry point and version + const pkgJsonPath = path.join(jestRunnerPath, 'package.json'); + const pkgJson = JSON.parse(fs.readFileSync(pkgJsonPath, 'utf8')); + + // Require using the full path to the entry point + const entryPoint = path.join(jestRunnerPath, pkgJson.main || 'build/index.js'); + const jestRunner = require(entryPoint); + TestRunner = jestRunner.default || jestRunner.TestRunner; if (TestRunner && TestRunner.prototype && typeof TestRunner.prototype.runTests === 'function') { @@ -131,9 +180,11 @@ try { jestVersion = 30; jestRunnerAvailable = true; } else { - // Try Jest 29 style import + // Try Jest 29 style import - runTest is in build/runTest.js try { - runTest = internalRequire('./runTest').default; + const runTestPath = path.join(jestRunnerPath, 'build', 'runTest.js'); + const runTestModule = require(runTestPath); + runTest = runTestModule.default; if (typeof runTest === 'function') { // Jest 29 - use direct runTest function jestVersion = 29; @@ -141,17 +192,23 @@ try { } } catch (e29) { // Neither Jest 29 nor 30 style import worked - const errorMsg = `Found jest-runner at ${jestRunnerPath} but could not load it. ` + - `This may indicate an unsupported Jest version. ` + - `Supported versions: Jest 29.x and Jest 30.x`; - console.error(errorMsg); jestRunnerAvailable = false; } } } catch (e) { - // jest-runner not installed - this is expected for Vitest projects - // The runner will throw a helpful error if someone tries to use it without jest-runner - jestRunnerAvailable = false; + // try to directly import jest-runner + try { + const jestRunner = require('jest-runner'); + TestRunner = jestRunner.default || jestRunner.TestRunner; + if (TestRunner && TestRunner.prototype && typeof TestRunner.prototype.runTests === 'function') { + jestVersion = 30; + jestRunnerAvailable = true; + } else { + jestRunnerAvailable = false; + } + } catch (e2) { + jestRunnerAvailable = false; + } } // Configuration @@ -233,15 +290,12 @@ class CodeflashLoopRunner { this._context = context || {}; this._eventEmitter = new SimpleEventEmitter(); - // For Jest 30+, create an instance of the base TestRunner for delegation - if (jestVersion >= 30) { - if (!TestRunner) { - throw new Error( - `Jest ${jestVersion} detected but TestRunner class not available. ` + - `This indicates an internal error in loop-runner initialization.` - ); - } - this._baseRunner = new TestRunner(globalConfig, context); + // For Jest 30+, verify TestRunner is available (we create fresh instances per batch) + if (jestVersion >= 30 && !TestRunner) { + throw new Error( + `Jest ${jestVersion} detected but TestRunner class not available. ` + + `This indicates an internal error in loop-runner initialization.` + ); } } @@ -270,7 +324,7 @@ class CodeflashLoopRunner { * @param {Object} options - Jest runner options * @returns {Promise} */ - async runTests(tests, watcher, options) { + async runTests(tests, watcher, ...rest) { const startTime = Date.now(); let batchCount = 0; let hasFailure = false; @@ -289,13 +343,11 @@ class CodeflashLoopRunner { // Check time limit BEFORE each batch if (batchCount > MIN_BATCHES && checkTimeLimit()) { - console.log(`[codeflash] Time limit reached after ${batchCount - 1} batches (${Date.now() - startTime}ms elapsed)`); break; } // Check if interrupted if (watcher.isInterrupted()) { - console.log(`[codeflash] Watcher is interrupted`) break; } @@ -303,57 +355,54 @@ class CodeflashLoopRunner { process.env.CODEFLASH_PERF_CURRENT_BATCH = String(batchCount); // Run all test files in this batch - const batchResult = await this._runAllTestsOnce(tests, watcher, options); + const batchResult = await this._runAllTestsOnce(tests, watcher, ...rest); allConsoleOutput += batchResult.consoleOutput; - // if (batchResult.hasFailure) { - // hasFailure = true; - // break; - // } - // Check time limit AFTER each batch if (checkTimeLimit()) { - console.log(`[codeflash] Time limit reached after ${batchCount} batches (${Date.now() - startTime}ms elapsed)`); break; } } const totalTimeMs = Date.now() - startTime; - console.log(`[codeflash] now: ${Date.now()}`) // Output all collected console logs - this is critical for timing marker extraction // The console output contains the !######...######! timing markers from capturePerf if (allConsoleOutput) { process.stdout.write(allConsoleOutput); } - - console.log(`[codeflash] Batched runner completed: ${batchCount} batches, ${tests.length} test files, ${totalTimeMs}ms total`); } /** * Run all test files once (one batch). * Uses different approaches for Jest 29 vs Jest 30. */ - async _runAllTestsOnce(tests, watcher, options) { + async _runAllTestsOnce(tests, watcher, ...args) { if (jestVersion >= 30) { - return this._runAllTestsOnceJest30(tests, watcher, options); + return this._runAllTestsOnceJest30(tests, watcher, ...args); } else { return this._runAllTestsOnceJest29(tests, watcher); } } /** - * Jest 30+ implementation - delegates to base TestRunner and collects results. + * Jest 30+ implementation - creates a fresh TestRunner for each batch to avoid + * state corruption issues that occur when reusing runners across batches. */ - async _runAllTestsOnceJest30(tests, watcher, options) { + async _runAllTestsOnceJest30(tests, watcher, ...args) { let hasFailure = false; let allConsoleOutput = ''; // For Jest 30, we need to collect results through event listeners const resultsCollector = []; - // Subscribe to events from the base runner - const unsubscribeSuccess = this._baseRunner.on('test-file-success', (testData) => { + // Create a FRESH TestRunner instance for each batch + // Jest 30's TestRunner corrupts its internal state after running tests, + // so we cannot reuse the same instance across multiple batches + const batchRunner = new TestRunner(this._globalConfig, this._context); + + // Subscribe to events from the batch runner + const unsubscribeSuccess = batchRunner.on('test-file-success', (testData) => { const [test, result] = testData; resultsCollector.push({ test, result, success: true }); @@ -369,7 +418,7 @@ class CodeflashLoopRunner { this._eventEmitter.emit('test-file-success', testData); }); - const unsubscribeFailure = this._baseRunner.on('test-file-failure', (testData) => { + const unsubscribeFailure = batchRunner.on('test-file-failure', (testData) => { const [test, error] = testData; resultsCollector.push({ test, error, success: false }); hasFailure = true; @@ -378,14 +427,14 @@ class CodeflashLoopRunner { this._eventEmitter.emit('test-file-failure', testData); }); - const unsubscribeStart = this._baseRunner.on('test-file-start', (testData) => { + const unsubscribeStart = batchRunner.on('test-file-start', (testData) => { // Forward to our event emitter this._eventEmitter.emit('test-file-start', testData); }); try { - // Run tests using the base runner (always serial for benchmarking) - await this._baseRunner.runTests(tests, watcher, { ...options, serial: true }); + // Run tests using the fresh batch runner (always serial for benchmarking) + await batchRunner.runTests(tests, watcher, ...args); } finally { // Cleanup subscriptions if (typeof unsubscribeSuccess === 'function') unsubscribeSuccess(); diff --git a/pyproject.toml b/pyproject.toml index ea5f2140a..263d2f080 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ description = "Client for codeflash.ai - automatic code performance optimization authors = [{ name = "CodeFlash Inc.", email = "contact@codeflash.ai" }] requires-python = ">=3.9" readme = "README.md" -license = {text = "BSL-1.1"} +license-files = ["LICENSE"] keywords = [ "codeflash", "performance", @@ -356,4 +356,3 @@ markers = [ [build-system] requires = ["hatchling", "uv-dynamic-versioning"] build-backend = "hatchling.build" - diff --git a/tests/benchmarks/test_benchmark_code_extract_code_context.py b/tests/benchmarks/test_benchmark_code_extract_code_context.py index bb6140916..77c435720 100644 --- a/tests/benchmarks/test_benchmark_code_extract_code_context.py +++ b/tests/benchmarks/test_benchmark_code_extract_code_context.py @@ -1,8 +1,8 @@ from argparse import Namespace from pathlib import Path -from codeflash.context.code_context_extractor import get_code_optimization_context from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context from codeflash.models.models import FunctionParent from codeflash.optimization.optimizer import Optimizer diff --git a/tests/scripts/end_to_end_test_async.py b/tests/scripts/end_to_end_test_async.py index 0b4bf8957..0e38ae797 100644 --- a/tests/scripts/end_to_end_test_async.py +++ b/tests/scripts/end_to_end_test_async.py @@ -13,7 +13,7 @@ def run_test(expected_improvement_pct: int) -> bool: CoverageExpectation( function_name="retry_with_backoff", expected_coverage=100.0, - expected_lines=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], + expected_lines=[9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], ) ], ) diff --git a/tests/test_async_run_and_parse_tests.py b/tests/test_async_run_and_parse_tests.py index 14bb8b2db..e9d85bf68 100644 --- a/tests/test_async_run_and_parse_tests.py +++ b/tests/test_async_run_and_parse_tests.py @@ -8,7 +8,9 @@ import pytest from codeflash.code_utils.instrument_existing_tests import ( + ASYNC_HELPER_FILENAME, add_async_decorator_to_function, + get_decorator_name_for_mode, inject_profiling_into_existing_test, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -55,16 +57,23 @@ async def test_async_sort(): func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True) # For async functions, instrument the source module directly with decorators - source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR) + source_success = add_async_decorator_to_function( + fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path + ) assert source_success - # Verify the file was modified + # Verify the file was modified with exact expected output instrumented_source = fto_path.read_text("utf-8") - assert ( - '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_behavior_async\n\n\n@codeflash_behavior_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n''' - in instrumented_source + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + decorated_original = original_code.replace( + "async def async_sorter", f"@{decorator_name}\nasync def async_sorter" ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_original}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert instrumented_source.strip() == expected.strip() # Add codeflash capture instrument_codeflash_capture(func, {}, tests_root) @@ -142,6 +151,9 @@ async def test_async_sort(): test_path.unlink() if test_path_perf.exists(): test_path_perf.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -182,7 +194,9 @@ async def test_async_class_sort(): is_async=True, ) - source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR) + source_success = add_async_decorator_to_function( + fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path + ) assert source_success @@ -264,6 +278,9 @@ async def test_async_class_sort(): test_path.unlink() if test_path_perf.exists(): test_path_perf.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -294,16 +311,23 @@ async def test_async_perf(): func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True) # Instrument the source module with async performance decorators - source_success = add_async_decorator_to_function(fto_path, func, TestingMode.PERFORMANCE) + source_success = add_async_decorator_to_function( + fto_path, func, TestingMode.PERFORMANCE, project_root=project_root_path + ) assert source_success # Verify the file was modified instrumented_source = fto_path.read_text("utf-8") - assert ( - instrumented_source - == '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_performance_async\n\n\n@codeflash_performance_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n''' + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE) + decorated_original = original_code.replace( + "async def async_sorter", f"@{decorator_name}\nasync def async_sorter" ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_original}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert instrumented_source.strip() == expected.strip() instrument_codeflash_capture(func, {}, tests_root) @@ -359,6 +383,9 @@ async def test_async_perf(): # Clean up test files if test_path.exists(): test_path.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -404,68 +431,24 @@ async def async_error_function(lst): function_name="async_error_function", parents=[], file_path=Path(fto_path), is_async=True ) - source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR) + source_success = add_async_decorator_to_function( + fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path + ) assert source_success # Verify the file was modified instrumented_source = fto_path.read_text("utf-8") - expected_instrumented_source = """import asyncio -from typing import List, Union - -from codeflash.code_utils.codeflash_wrap_decorator import \\ - codeflash_behavior_async + from codeflash.code_utils.formatter import sort_imports - -async def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]: - \"\"\" - Async bubble sort implementation for testing. - \"\"\" - print("codeflash stdout: Async sorting list") - - await asyncio.sleep(0.01) - - n = len(lst) - for i in range(n): - for j in range(0, n - i - 1): - if lst[j] > lst[j + 1]: - lst[j], lst[j + 1] = lst[j + 1], lst[j] - - result = lst.copy() - print(f"result: {result}") - return result - - -class AsyncBubbleSorter: - \"\"\"Class with async sorting method for testing.\"\"\" - - async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]: - \"\"\" - Async bubble sort implementation within a class. - \"\"\" - print("codeflash stdout: AsyncBubbleSorter.sorter() called") - - # Add some async delay - await asyncio.sleep(0.005) - - n = len(lst) - for i in range(n): - for j in range(0, n - i - 1): - if lst[j] > lst[j + 1]: - lst[j], lst[j + 1] = lst[j + 1], lst[j] - - result = lst.copy() - return result - - -@codeflash_behavior_async -async def async_error_function(lst): - \"\"\"Async function that raises an error for testing.\"\"\" - await asyncio.sleep(0.001) # Small delay - raise ValueError("Test error") -""" - assert expected_instrumented_source == instrumented_source + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + decorated_modified = modified_code.replace( + "async def async_error_function", f"@{decorator_name}\nasync def async_error_function" + ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_modified}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert instrumented_source.strip() == expected.strip() instrument_codeflash_capture(func, {}, tests_root) opt = Optimizer( @@ -526,6 +509,9 @@ async def async_error_function(lst): test_path.unlink() if test_path_perf.exists(): test_path_perf.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -563,7 +549,9 @@ async def test_async_multi(): func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True) - source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR) + source_success = add_async_decorator_to_function( + fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path + ) assert source_success instrument_codeflash_capture(func, {}, tests_root) @@ -636,6 +624,9 @@ async def test_async_multi(): test_path.unlink() if test_path_perf.exists(): test_path_perf.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -678,7 +669,9 @@ async def test_async_edge_cases(): func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True) - source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR) + source_success = add_async_decorator_to_function( + fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path + ) assert source_success instrument_codeflash_capture(func, {}, tests_root) @@ -753,6 +746,9 @@ async def test_async_edge_cases(): test_path.unlink() if test_path_perf.exists(): test_path_perf.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -988,7 +984,9 @@ async def test_mixed_sorting(): function_name="async_merge_sort", parents=[], file_path=Path(mixed_fto_path), is_async=True ) - source_success = add_async_decorator_to_function(mixed_fto_path, async_func, TestingMode.BEHAVIOR) + source_success = add_async_decorator_to_function( + mixed_fto_path, async_func, TestingMode.BEHAVIOR, project_root=project_root_path + ) assert source_success @@ -1061,3 +1059,6 @@ async def test_mixed_sorting(): test_path.unlink() if test_path_perf.exists(): test_path_perf.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 7088e6f1f..add427f32 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -10,17 +10,15 @@ from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_global_assignments from codeflash.code_utils.code_replacer import replace_functions_and_add_imports -from codeflash.context.code_context_extractor import ( +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.languages.python.context.code_context_extractor import ( collect_names_from_annotation, + enrich_testgen_context, extract_classes_from_type_hint, extract_imports_for_class, get_code_optimization_context, - get_external_base_class_inits, - get_external_class_inits, - get_imported_class_definitions, resolve_transitive_type_deps, ) -from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent from codeflash.optimization.optimizer import Optimizer @@ -769,199 +767,6 @@ def helper_method(self): assert hashing_context.strip() == expected_hashing_context.strip() -def test_example_class_token_limit_1(tmp_path: Path) -> None: - docstring_filler = " ".join( - ["This is a long docstring that will be used to fill up the token limit." for _ in range(1000)] - ) - code = f""" -class MyClass: - \"\"\"A class with a helper method. -{docstring_filler}\"\"\" - def __init__(self): - self.x = 1 - def target_method(self): - \"\"\"Docstring for target method\"\"\" - y = HelperClass().helper_method() - -class HelperClass: - \"\"\"A helper class for MyClass.\"\"\" - def __init__(self): - \"\"\"Initialize the HelperClass.\"\"\" - self.x = 1 - def __repr__(self): - \"\"\"Return a string representation of the HelperClass.\"\"\" - return "HelperClass" + str(self.x) - def helper_method(self): - return self.x -""" - # Create a temporary Python file using pytest's tmp_path fixture - file_path = tmp_path / "test_code.py" - file_path.write_text(code, encoding="utf-8") - opt = Optimizer( - Namespace( - project_root=file_path.parent.resolve(), - disable_telemetry=True, - tests_root="tests", - test_framework="pytest", - pytest_cmd="pytest", - experiment_id=None, - test_project_root=Path().resolve(), - ) - ) - function_to_optimize = FunctionToOptimize( - function_name="target_method", - file_path=file_path, - parents=[FunctionParent(name="MyClass", type="ClassDef")], - starting_line=None, - ending_line=None, - ) - - code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) - read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code - hashing_context = code_ctx.hashing_code_context - # In this scenario, the read-only code context is too long, so the read-only docstrings are removed. - expected_read_write_context = f""" -```python:{file_path.relative_to(opt.args.project_root)} -class MyClass: - def __init__(self): - self.x = 1 - def target_method(self): - \"\"\"Docstring for target method\"\"\" - y = HelperClass().helper_method() - -class HelperClass: - def __init__(self): - \"\"\"Initialize the HelperClass.\"\"\" - self.x = 1 - def helper_method(self): - return self.x -``` -""" - expected_read_only_context = f""" -```python:{file_path.relative_to(opt.args.project_root)} -class MyClass: - pass - -class HelperClass: - def __repr__(self): - return "HelperClass" + str(self.x) -``` -""" - expected_hashing_context = f""" -```python:{file_path.relative_to(opt.args.project_root)} -class MyClass: - - def target_method(self): - y = HelperClass().helper_method() - -class HelperClass: - - def helper_method(self): - return self.x -``` -""" - assert read_write_context.markdown.strip() == expected_read_write_context.strip() - assert read_only_context.strip() == expected_read_only_context.strip() - assert hashing_context.strip() == expected_hashing_context.strip() - - -def test_example_class_token_limit_2(tmp_path: Path) -> None: - string_filler = " ".join( - ["This is a long string that will be used to fill up the token limit." for _ in range(1000)] - ) - code = f""" -class MyClass: - \"\"\"A class with a helper method. \"\"\" - def __init__(self): - self.x = 1 - def target_method(self): - \"\"\"Docstring for target method\"\"\" - y = HelperClass().helper_method() -x = '{string_filler}' - -class HelperClass: - \"\"\"A helper class for MyClass.\"\"\" - def __init__(self): - \"\"\"Initialize the HelperClass.\"\"\" - self.x = 1 - def __repr__(self): - \"\"\"Return a string representation of the HelperClass.\"\"\" - return "HelperClass" + str(self.x) - def helper_method(self): - return self.x -""" - # Create a temporary Python file using pytest's tmp_path fixture - file_path = tmp_path / "test_code.py" - file_path.write_text(code, encoding="utf-8") - opt = Optimizer( - Namespace( - project_root=file_path.parent.resolve(), - disable_telemetry=True, - tests_root="tests", - test_framework="pytest", - pytest_cmd="pytest", - experiment_id=None, - test_project_root=Path().resolve(), - ) - ) - function_to_optimize = FunctionToOptimize( - function_name="target_method", - file_path=file_path, - parents=[FunctionParent(name="MyClass", type="ClassDef")], - starting_line=None, - ending_line=None, - ) - - code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000) - read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code - hashing_context = code_ctx.hashing_code_context - # In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely. - expected_read_write_context = f""" -```python:{file_path.relative_to(opt.args.project_root)} -class MyClass: - def __init__(self): - self.x = 1 - def target_method(self): - \"\"\"Docstring for target method\"\"\" - y = HelperClass().helper_method() - -class HelperClass: - def __init__(self): - \"\"\"Initialize the HelperClass.\"\"\" - self.x = 1 - def helper_method(self): - return self.x -``` -""" - expected_read_only_context = f'''```python:{file_path.relative_to(opt.args.project_root)} -class MyClass: - """A class with a helper method. """ - -class HelperClass: - """A helper class for MyClass.""" - def __repr__(self): - """Return a string representation of the HelperClass.""" - return "HelperClass" + str(self.x) -``` -''' - expected_hashing_context = f""" -```python:{file_path.relative_to(opt.args.project_root)} -class MyClass: - - def target_method(self): - y = HelperClass().helper_method() - -class HelperClass: - - def helper_method(self): - return self.x -``` -""" - assert read_write_context.markdown.strip() == expected_read_write_context.strip() - assert read_only_context.strip() == expected_read_only_context.strip() - assert hashing_context.strip() == expected_hashing_context.strip() - - def test_example_class_token_limit_3(tmp_path: Path) -> None: string_filler = " ".join( ["This is a long string that will be used to fill up the token limit." for _ in range(1000)] @@ -1009,7 +814,7 @@ def helper_method(self): ) # In this scenario, the read-writable code is too long, so we abort. with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"): - code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + get_code_optimization_context(function_to_optimize, opt.args.project_root, optim_token_limit=8000) def test_example_class_token_limit_4(tmp_path: Path) -> None: @@ -1062,7 +867,7 @@ def helper_method(self): # In this scenario, the read-writable code context becomes too large because the __init__ function is referencing the global x variable instead of the class attribute self.x, so we abort. with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"): - code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + get_code_optimization_context(function_to_optimize, opt.args.project_root, optim_token_limit=8000) def test_example_class_token_limit_5(tmp_path: Path) -> None: @@ -2422,7 +2227,7 @@ def nested_method(self): assert "__init__" not in hashing_context # Should not contain __init__ methods # Verify nested classes are excluded from the hashing context - # The prune_cst_for_code_hashing function should not recurse into nested classes + # The prune_cst function in hashing mode should not recurse into nested classes assert "class NestedClass:" not in hashing_context # Nested class definition should not be present # The target method will reference NestedClass, but the actual nested class definition should not be included @@ -3275,8 +3080,8 @@ def dump_layout(layout_type, layout): assert testgen_context.count("def __init__") >= 2, "Both __init__ methods should be in testgen context" -def test_get_imported_class_definitions_extracts_project_classes(tmp_path: Path) -> None: - """Test that get_imported_class_definitions extracts class definitions from project modules.""" +def test_enrich_testgen_context_extracts_project_classes(tmp_path: Path) -> None: + """Test that enrich_testgen_context extracts class definitions from project modules.""" # Create a package structure with two modules package_dir = tmp_path / "mypackage" package_dir.mkdir() @@ -3325,8 +3130,8 @@ def will_fit(self, chunk: PreChunk) -> bool: # Create CodeStringsMarkdown from the chunking module (simulating testgen context) context = CodeStringsMarkdown(code_strings=[CodeString(code=chunking_code, file_path=chunking_path)]) - # Call get_imported_class_definitions - result = get_imported_class_definitions(context, tmp_path) + # Call enrich_testgen_context + result = enrich_testgen_context(context, tmp_path) # Verify Element class was extracted assert len(result.code_strings) == 1, "Should extract exactly one class (Element)" @@ -3339,8 +3144,8 @@ def will_fit(self, chunk: PreChunk) -> bool: assert "import abc" in extracted_code, "Should include necessary imports for base class" -def test_get_imported_class_definitions_skips_existing_definitions(tmp_path: Path) -> None: - """Test that get_imported_class_definitions skips classes already defined in context.""" +def test_enrich_testgen_context_skips_existing_definitions(tmp_path: Path) -> None: + """Test that enrich_testgen_context skips classes already defined in context.""" # Create a package structure package_dir = tmp_path / "mypackage" package_dir.mkdir() @@ -3373,15 +3178,15 @@ def process(self, elem: Element): context = CodeStringsMarkdown(code_strings=[CodeString(code=code_with_local_def, file_path=code_path)]) - # Call get_imported_class_definitions - result = get_imported_class_definitions(context, tmp_path) + # Call enrich_testgen_context + result = enrich_testgen_context(context, tmp_path) # Should NOT extract Element since it's already defined locally assert len(result.code_strings) == 0, "Should not extract classes already defined in context" -def test_get_imported_class_definitions_skips_third_party(tmp_path: Path) -> None: - """Test that get_imported_class_definitions skips third-party/stdlib imports.""" +def test_enrich_testgen_context_skips_third_party(tmp_path: Path) -> None: + """Test that enrich_testgen_context skips third-party/stdlib imports.""" # Create a simple package package_dir = tmp_path / "mypackage" package_dir.mkdir() @@ -3402,15 +3207,15 @@ def __init__(self, path: Path): context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - # Call get_imported_class_definitions - result = get_imported_class_definitions(context, tmp_path) + # Call enrich_testgen_context + result = enrich_testgen_context(context, tmp_path) # Should not extract any classes (Path, Optional, dataclass are stdlib/third-party) assert len(result.code_strings) == 0, "Should not extract stdlib/third-party classes" -def test_get_imported_class_definitions_handles_multiple_imports(tmp_path: Path) -> None: - """Test that get_imported_class_definitions handles multiple class imports.""" +def test_enrich_testgen_context_handles_multiple_imports(tmp_path: Path) -> None: + """Test that enrich_testgen_context handles multiple class imports.""" # Create a package structure package_dir = tmp_path / "mypackage" package_dir.mkdir() @@ -3446,8 +3251,8 @@ def process(self, a: TypeA, b: TypeB): context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - # Call get_imported_class_definitions - result = get_imported_class_definitions(context, tmp_path) + # Call enrich_testgen_context + result = enrich_testgen_context(context, tmp_path) # Should extract both TypeA and TypeB (but not TypeC since it's not imported) assert len(result.code_strings) == 2, "Should extract exactly two classes (TypeA, TypeB)" @@ -3458,8 +3263,8 @@ def process(self, a: TypeA, b: TypeB): assert "class TypeC" not in all_extracted_code, "Should NOT contain TypeC (not imported)" -def test_get_imported_class_definitions_includes_dataclass_decorators(tmp_path: Path) -> None: - """Test that get_imported_class_definitions includes decorators when extracting dataclasses.""" +def test_enrich_testgen_context_includes_dataclass_decorators(tmp_path: Path) -> None: + """Test that enrich_testgen_context includes decorators when extracting dataclasses.""" # Create a package structure package_dir = tmp_path / "mypackage" package_dir.mkdir() @@ -3496,8 +3301,8 @@ def get_config(self) -> LLMConfig: context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - # Call get_imported_class_definitions - result = get_imported_class_definitions(context, tmp_path) + # Call enrich_testgen_context + result = enrich_testgen_context(context, tmp_path) # Should extract both LLMConfigBase (base class) and LLMConfig assert len(result.code_strings) == 2, "Should extract both LLMConfig and its base class LLMConfigBase" @@ -3521,7 +3326,7 @@ def get_config(self) -> LLMConfig: assert "from dataclasses import" in all_extracted_code, "Should include dataclasses import" -def test_get_imported_class_definitions_extracts_imports_for_decorated_classes(tmp_path: Path) -> None: +def test_enrich_testgen_context_extracts_imports_for_decorated_classes(tmp_path: Path) -> None: """Test that extract_imports_for_class includes decorator and type annotation imports.""" # Create a package structure package_dir = tmp_path / "mypackage" @@ -3552,7 +3357,7 @@ def create_config() -> Config: context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_imported_class_definitions(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert len(result.code_strings) == 1, "Should extract Config class" extracted_code = result.code_strings[0].code @@ -3724,7 +3529,7 @@ class MyClass: assert result.count("from typing import Optional") == 1 -def test_get_imported_class_definitions_multiple_decorators(tmp_path: Path) -> None: +def test_enrich_testgen_context_multiple_decorators(tmp_path: Path) -> None: """Test that classes with multiple decorators are extracted correctly.""" package_dir = tmp_path / "mypackage" package_dir.mkdir() @@ -3755,7 +3560,7 @@ def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]: context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_imported_class_definitions(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert len(result.code_strings) == 1 extracted_code = result.code_strings[0].code @@ -3766,7 +3571,7 @@ def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]: assert "class OrderedConfig" in extracted_code -def test_get_imported_class_definitions_extracts_multilevel_inheritance(tmp_path: Path) -> None: +def test_enrich_testgen_context_extracts_multilevel_inheritance(tmp_path: Path) -> None: """Test that base classes are recursively extracted for multi-level inheritance. This is critical for understanding dataclass constructor signatures, as fields @@ -3826,8 +3631,8 @@ def get_router_config(self) -> RouterConfig: context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - # Call get_imported_class_definitions - result = get_imported_class_definitions(context, tmp_path) + # Call enrich_testgen_context + result = enrich_testgen_context(context, tmp_path) # Should extract 4 classes: GrandParentConfig, ParentConfig, ChildConfig, RouterConfig # (all classes needed to understand the full inheritance hierarchy) @@ -3862,7 +3667,7 @@ def get_router_config(self) -> RouterConfig: assert "model_list: list" in all_extracted_code, "Should include model_list field from Router" -def test_get_external_base_class_inits_extracts_userdict(tmp_path: Path) -> None: +def test_enrich_testgen_context_extracts_userdict(tmp_path: Path) -> None: """Extracts __init__ from collections.UserDict when a class inherits from it.""" code = """from collections import UserDict @@ -3873,7 +3678,7 @@ class MyCustomDict(UserDict): code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_base_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert len(result.code_strings) == 1 code_string = result.code_strings[0] @@ -3891,8 +3696,8 @@ def __init__(self, dict=None, /, **kwargs): assert code_string.file_path.as_posix().endswith("collections/__init__.py") -def test_get_external_base_class_inits_skips_project_classes(tmp_path: Path) -> None: - """Returns empty when base class is from the project, not external.""" +def test_enrich_testgen_context_skips_unresolvable_base_classes(tmp_path: Path) -> None: + """Returns empty when base class module cannot be resolved.""" child_code = """from base import ProjectBase class Child(ProjectBase): @@ -3902,12 +3707,12 @@ class Child(ProjectBase): child_path.write_text(child_code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=child_code, file_path=child_path)]) - result = get_external_base_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert result.code_strings == [] -def test_get_external_base_class_inits_skips_builtins(tmp_path: Path) -> None: +def test_enrich_testgen_context_skips_builtin_base_classes(tmp_path: Path) -> None: """Returns empty for builtin classes like list that have no inspectable source.""" code = """class MyList(list): pass @@ -3916,12 +3721,12 @@ def test_get_external_base_class_inits_skips_builtins(tmp_path: Path) -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_base_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert result.code_strings == [] -def test_get_external_base_class_inits_deduplicates(tmp_path: Path) -> None: +def test_enrich_testgen_context_deduplicates(tmp_path: Path) -> None: """Extracts the same external base class only once even when inherited multiple times.""" code = """from collections import UserDict @@ -3935,7 +3740,7 @@ class MyDict2(UserDict): code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_base_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert len(result.code_strings) == 1 expected_code = """\ @@ -3950,7 +3755,7 @@ def __init__(self, dict=None, /, **kwargs): assert result.code_strings[0].code == expected_code -def test_get_external_base_class_inits_empty_when_no_inheritance(tmp_path: Path) -> None: +def test_enrich_testgen_context_empty_when_no_inheritance(tmp_path: Path) -> None: """Returns empty when there are no external base classes.""" code = """class SimpleClass: pass @@ -3959,7 +3764,7 @@ def test_get_external_base_class_inits_empty_when_no_inheritance(tmp_path: Path) code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_base_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert result.code_strings == [] @@ -4103,127 +3908,8 @@ def target_method(self): assert "self.data = {}" in testgen_context, "UserDict __init__ body should be included" -def test_read_only_code_removed_when_exceeds_limit(tmp_path: Path) -> None: - """Test read-only code is completely removed when it exceeds token limit even without docstrings. - - This covers lines 152-153 in code_context_extractor.py where read_only_context_code is set - to empty string when it still exceeds the token limit after docstring removal. - """ - # Create a second-degree helper with large implementation that has no docstrings - # Second-degree helpers go into read-only context - long_lines = [" x = 0"] - for i in range(150): - long_lines.append(f" x = x + {i}") - long_lines.append(" return x") - long_body = "\n".join(long_lines) - - code = f""" -class MyClass: - def __init__(self): - self.x = 1 - - def target_method(self): - return first_helper() - - -def first_helper(): - # First degree helper - calls second degree - return second_helper() - - -def second_helper(): - # Second degree helper - goes into read-only context -{long_body} -""" - file_path = tmp_path / "test_code.py" - file_path.write_text(code, encoding="utf-8") - - func_to_optimize = FunctionToOptimize( - function_name="target_method", file_path=file_path, parents=[FunctionParent(name="MyClass", type="ClassDef")] - ) - - # Use a small optim_token_limit that allows read-writable but not read-only - # Read-writable is ~48 tokens, read-only is ~600 tokens - code_ctx = get_code_optimization_context( - function_to_optimize=func_to_optimize, - project_root_path=tmp_path, - optim_token_limit=100, # Small limit to trigger read-only removal - ) - - # The read-only context should be empty because it exceeded the limit - assert code_ctx.read_only_context_code == "", "Read-only code should be removed when exceeding token limit" - - -def test_testgen_removes_imported_classes_on_overflow(tmp_path: Path) -> None: - """Test testgen context removes imported class definitions when exceeding token limit. - - This covers lines 176-186 in code_context_extractor.py where: - - Testgen context exceeds limit (line 175) - - Removing docstrings still exceeds (line 175 again) - - Removing imported classes succeeds (line 177-183) - """ - # Create a package structure with a large type class used only in type annotations - # This ensures get_imported_class_definitions extracts the full class - package_dir = tmp_path / "mypackage" - package_dir.mkdir() - (package_dir / "__init__.py").write_text("", encoding="utf-8") - - # Create a large class with methods that will be extracted via get_imported_class_definitions - # Use methods WITHOUT docstrings so removing docstrings won't help much - many_methods = "\n".join([f" def method_{i}(self):\n return {i}" for i in range(100)]) - type_class_code = f''' -class TypeClass: - """A type class for annotations.""" - - def __init__(self, value: int): - self.value = value - -{many_methods} -''' - type_class_path = package_dir / "types.py" - type_class_path.write_text(type_class_code, encoding="utf-8") - - # Main module uses TypeClass only in annotation (not instantiated) - # This triggers get_imported_class_definitions to extract the full class - main_code = """ -from mypackage.types import TypeClass - -def target_function(obj: TypeClass) -> int: - return obj.value -""" - main_path = package_dir / "main.py" - main_path.write_text(main_code, encoding="utf-8") - - func_to_optimize = FunctionToOptimize(function_name="target_function", file_path=main_path, parents=[]) - - # Use a testgen_token_limit that: - # - Is exceeded by full context with imported class (~1500 tokens) - # - Is exceeded even after removing docstrings - # - But fits when imported class is removed (~40 tokens) - code_ctx = get_code_optimization_context( - function_to_optimize=func_to_optimize, - project_root_path=tmp_path, - testgen_token_limit=200, # Small limit to trigger imported class removal - ) - - # The testgen context should exist (didn't raise ValueError) - testgen_context = code_ctx.testgen_context.markdown - assert testgen_context, "Testgen context should not be empty" - - # The target function should still be there - assert "def target_function" in testgen_context, "Target function should be in testgen context" - - # The large imported class should NOT be included (removed due to token limit) - assert "class TypeClass" not in testgen_context, ( - "TypeClass should be removed from testgen context when exceeding token limit" - ) - - -def test_testgen_raises_when_all_fallbacks_fail(tmp_path: Path) -> None: - """Test that ValueError is raised when testgen context exceeds limit even after all fallbacks. - - This covers line 186 in code_context_extractor.py. - """ +def test_testgen_raises_when_exceeds_limit(tmp_path: Path) -> None: + """Test that ValueError is raised when testgen context exceeds token limit.""" # Create a function with a very long body that exceeds limits even without imports/docstrings long_lines = [" x = 0"] for i in range(200): @@ -4249,7 +3935,7 @@ def target_function(): ) -def test_get_external_base_class_inits_attribute_base(tmp_path: Path) -> None: +def test_enrich_testgen_context_attribute_base(tmp_path: Path) -> None: """Test handling of base class accessed as module.ClassName (ast.Attribute). This covers line 616 in code_context_extractor.py. @@ -4265,7 +3951,7 @@ def custom_method(self): code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_base_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) # Should extract UserDict __init__ assert len(result.code_strings) == 1 @@ -4273,7 +3959,7 @@ def custom_method(self): assert "def __init__" in result.code_strings[0].code -def test_get_external_base_class_inits_no_init_method(tmp_path: Path) -> None: +def test_enrich_testgen_context_no_init_method(tmp_path: Path) -> None: """Test handling when base class has no __init__ method. This covers line 641 in code_context_extractor.py. @@ -4288,7 +3974,7 @@ class MyProtocol(Protocol): code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_base_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) # Protocol's __init__ can't be easily inspected, should handle gracefully # Result may be empty or contain Protocol based on implementation @@ -4377,7 +4063,7 @@ def target_method(self): def test_imported_class_definitions_module_path_none(tmp_path: Path) -> None: - """Test handling when module_path is None in get_imported_class_definitions. + """Test handling when module_path is None in enrich_testgen_context. This covers line 560 in code_context_extractor.py. """ @@ -4393,123 +4079,12 @@ def method(self, obj: SomeClass): code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_imported_class_definitions(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) # Should handle gracefully and return empty or partial results assert isinstance(result.code_strings, list) -def test_get_imported_names_import_star(tmp_path: Path) -> None: - """Test get_imported_names handles import * correctly. - - This covers lines 808-809 and 824-825 in code_context_extractor.py. - """ - import libcst as cst - - # Test regular import * - # Note: "import *" is not valid Python, but "from x import *" is - from_import_star = cst.parse_statement("from os import *") - assert isinstance(from_import_star, cst.SimpleStatementLine) - import_node = from_import_star.body[0] - assert isinstance(import_node, cst.ImportFrom) - - from codeflash.context.code_context_extractor import get_imported_names - - result = get_imported_names(import_node) - assert result == {"*"} - - -def test_get_imported_names_aliased_import(tmp_path: Path) -> None: - """Test get_imported_names handles aliased imports correctly. - - This covers lines 812-813 and 828-829 in code_context_extractor.py. - """ - import libcst as cst - - from codeflash.context.code_context_extractor import get_imported_names - - # Test import with alias - import_stmt = cst.parse_statement("import numpy as np") - assert isinstance(import_stmt, cst.SimpleStatementLine) - import_node = import_stmt.body[0] - assert isinstance(import_node, cst.Import) - - result = get_imported_names(import_node) - assert "np" in result - - # Test from import with alias - from_import_stmt = cst.parse_statement("from os import path as ospath") - assert isinstance(from_import_stmt, cst.SimpleStatementLine) - from_import_node = from_import_stmt.body[0] - assert isinstance(from_import_node, cst.ImportFrom) - - result2 = get_imported_names(from_import_node) - assert "ospath" in result2 - - -def test_get_imported_names_dotted_import(tmp_path: Path) -> None: - """Test get_imported_names handles dotted imports correctly. - - This covers lines 816-822 in code_context_extractor.py. - """ - import libcst as cst - - from codeflash.context.code_context_extractor import get_imported_names - - # Test dotted import like "import os.path" - import_stmt = cst.parse_statement("import os.path") - assert isinstance(import_stmt, cst.SimpleStatementLine) - import_node = import_stmt.body[0] - assert isinstance(import_node, cst.Import) - - result = get_imported_names(import_node) - assert "os" in result - - -def test_used_name_collector_comprehensive(tmp_path: Path) -> None: - """Test UsedNameCollector handles various node types. - - This covers lines 767-801 in code_context_extractor.py. - """ - import libcst as cst - - from codeflash.context.code_context_extractor import UsedNameCollector - - code = """ -import os -from typing import List - -x: int = 1 -y = os.path.join("a", "b") - -class MyClass: - z = 10 - -def my_func(): - pass -""" - module = cst.parse_module(code) - collector = UsedNameCollector() - # In libcst, the walker traverses the module - cst.MetadataWrapper(module).visit(collector) - - # Check used names - assert "os" in collector.used_names - assert "int" in collector.used_names - assert "List" in collector.used_names - - # Check defined names - assert "x" in collector.defined_names - assert "y" in collector.defined_names - assert "MyClass" in collector.defined_names - assert "my_func" in collector.defined_names - - # Check external names (used but not defined) - external = collector.get_external_names() - assert "os" in external - assert "x" not in external # x is defined - - def test_imported_class_with_base_in_same_module(tmp_path: Path) -> None: """Test that imported classes with bases in the same module are extracted correctly. @@ -4549,52 +4124,13 @@ def target_function(obj: DerivedClass) -> bool: main_path.write_text(main_code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=main_code, file_path=main_path)]) - result = get_imported_class_definitions(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) # Should extract the inheritance chain all_code = "\n".join(cs.code for cs in result.code_strings) assert "class BaseClass" in all_code or "class DerivedClass" in all_code -def test_get_imported_names_from_import_without_alias(tmp_path: Path) -> None: - """Test get_imported_names handles from imports without aliases. - - This covers lines 830-831 in code_context_extractor.py. - """ - import libcst as cst - - from codeflash.context.code_context_extractor import get_imported_names - - # Test from import without alias - from_import_stmt = cst.parse_statement("from os import path, getcwd") - assert isinstance(from_import_stmt, cst.SimpleStatementLine) - from_import_node = from_import_stmt.body[0] - assert isinstance(from_import_node, cst.ImportFrom) - - result = get_imported_names(from_import_node) - assert "path" in result - assert "getcwd" in result - - -def test_get_imported_names_regular_import(tmp_path: Path) -> None: - """Test get_imported_names handles regular imports. - - This covers lines 814-815 in code_context_extractor.py. - """ - import libcst as cst - - from codeflash.context.code_context_extractor import get_imported_names - - # Test regular import without alias - import_stmt = cst.parse_statement("import json") - assert isinstance(import_stmt, cst.SimpleStatementLine) - import_node = import_stmt.body[0] - assert isinstance(import_node, cst.Import) - - result = get_imported_names(import_node) - assert "json" in result - - def test_augmented_assignment_not_in_context(tmp_path: Path) -> None: """Test that augmented assignments are handled but not included unless used. @@ -4625,7 +4161,7 @@ def target_method(self): assert "counter" in read_writable -def test_get_external_class_inits_extracts_click_option(tmp_path: Path) -> None: +def test_enrich_testgen_context_extracts_click_option(tmp_path: Path) -> None: """Extracts __init__ from click.Option when directly imported.""" code = """from click import Option @@ -4636,7 +4172,7 @@ def my_func(opt: Option) -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert len(result.code_strings) == 1 code_string = result.code_strings[0] @@ -4645,8 +4181,8 @@ def my_func(opt: Option) -> None: assert code_string.file_path is not None and "click" in code_string.file_path.as_posix() -def test_get_external_class_inits_skips_project_classes(tmp_path: Path) -> None: - """Returns empty when imported class is from the project, not external.""" +def test_enrich_testgen_context_extracts_project_class_defs(tmp_path: Path) -> None: + """Extracts project class definitions via jedi resolution.""" # Create a project module with a class (tmp_path / "mymodule.py").write_text("class ProjectClass:\n pass\n", encoding="utf-8") @@ -4659,12 +4195,13 @@ def my_func(obj: ProjectClass) -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) - assert result.code_strings == [] + assert len(result.code_strings) == 1 + assert "class ProjectClass" in result.code_strings[0].code -def test_get_external_class_inits_skips_non_classes(tmp_path: Path) -> None: +def test_enrich_testgen_context_skips_non_classes(tmp_path: Path) -> None: """Returns empty when imported name is a function, not a class.""" code = """from collections import OrderedDict from os.path import join @@ -4676,7 +4213,7 @@ def my_func() -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) # join is a function, not a class — should be skipped # OrderedDict is a class and should be included @@ -4684,8 +4221,8 @@ def my_func() -> None: assert not any("join" in name for name in class_names) -def test_get_external_class_inits_skips_already_defined_classes(tmp_path: Path) -> None: - """Skips classes already defined in the context (e.g., added by get_imported_class_definitions).""" +def test_enrich_testgen_context_skips_already_defined_classes(tmp_path: Path) -> None: + """Skips classes already defined in the context (e.g., added by enrich_testgen_context).""" code = """from collections import UserDict class UserDict: @@ -4699,14 +4236,14 @@ def my_func(d: UserDict) -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) # UserDict is already defined in the context, so it should be skipped assert result.code_strings == [] -def test_get_external_class_inits_skips_builtins(tmp_path: Path) -> None: - """Returns empty for builtin classes like list/dict that have no inspectable source.""" +def test_enrich_testgen_context_skips_builtin_annotations(tmp_path: Path) -> None: + """Returns empty for builtin type annotations like list/dict that are not imported.""" code = """x: list = [] y: dict = {} @@ -4717,12 +4254,12 @@ def my_func() -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert result.code_strings == [] -def test_get_external_class_inits_skips_object_init(tmp_path: Path) -> None: +def test_enrich_testgen_context_skips_object_init(tmp_path: Path) -> None: """Skips classes whose __init__ is just object.__init__ (trivial).""" # enum.Enum has a metaclass-based __init__, but individual enum members # effectively use object.__init__. Use a class we know has object.__init__. @@ -4735,14 +4272,14 @@ def my_func(q: QName) -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) # QName has its own __init__, so it should be included if it's in site-packages. # But since it's stdlib (not site-packages), it should be skipped. assert result.code_strings == [] -def test_get_external_class_inits_empty_when_no_imports(tmp_path: Path) -> None: +def test_enrich_testgen_context_empty_when_no_imports(tmp_path: Path) -> None: """Returns empty when there are no from-imports.""" code = """def my_func() -> None: pass @@ -4751,7 +4288,7 @@ def test_get_external_class_inits_empty_when_no_imports(tmp_path: Path) -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert result.code_strings == [] @@ -4840,17 +4377,17 @@ def test_resolve_transitive_type_deps_handles_failure_gracefully() -> None: """Returns empty list for a class where get_type_hints fails.""" class BadClass: - def __init__(self, x: "NonexistentType") -> None: # type: ignore[name-defined] # noqa: F821 + def __init__(self, x: NonexistentType) -> None: # type: ignore[name-defined] # noqa: F821 pass result = resolve_transitive_type_deps(BadClass) assert result == [] -# --- Integration tests for transitive resolution in get_external_class_inits --- +# --- Integration tests for transitive resolution in enrich_testgen_context --- -def test_get_external_class_inits_transitive_deps(tmp_path: Path) -> None: +def test_enrich_testgen_context_transitive_deps(tmp_path: Path) -> None: """Extracts transitive type dependencies from __init__ annotations.""" code = """from click import Context @@ -4861,7 +4398,7 @@ def my_func(ctx: Context) -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) class_names = {cs.code.split("\n")[0].replace("class ", "").rstrip(":") for cs in result.code_strings} assert "Context" in class_names @@ -4869,7 +4406,7 @@ def my_func(ctx: Context) -> None: assert "Command" in class_names -def test_get_external_class_inits_no_infinite_loops(tmp_path: Path) -> None: +def test_enrich_testgen_context_no_infinite_loops(tmp_path: Path) -> None: """Handles classes with circular type references without infinite loops.""" # click.Context references Command, and Command references Context back # This should terminate without issues due to the processed_classes set @@ -4882,13 +4419,13 @@ def my_func(ctx: Context) -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) # Should complete without hanging; just verify we got results assert len(result.code_strings) >= 1 -def test_get_external_class_inits_no_duplicate_stubs(tmp_path: Path) -> None: +def test_enrich_testgen_context_no_duplicate_stubs(tmp_path: Path) -> None: """Does not emit duplicate stubs for the same class name.""" code = """from click import Context @@ -4899,7 +4436,7 @@ def my_func(ctx: Context) -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) class_names = [cs.code.split("\n")[0].replace("class ", "").rstrip(":") for cs in result.code_strings] assert len(class_names) == len(set(class_names)), f"Duplicate class stubs found: {class_names}" diff --git a/tests/test_get_read_only_code.py b/tests/test_get_read_only_code.py index 618e39767..c6de2cc27 100644 --- a/tests/test_get_read_only_code.py +++ b/tests/test_get_read_only_code.py @@ -2,7 +2,7 @@ import pytest -from codeflash.context.code_context_extractor import parse_code_and_prune_cst +from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst from codeflash.models.models import CodeContextType diff --git a/tests/test_get_read_writable_code.py b/tests/test_get_read_writable_code.py index 6de398a25..c6bbdd04b 100644 --- a/tests/test_get_read_writable_code.py +++ b/tests/test_get_read_writable_code.py @@ -2,7 +2,7 @@ import pytest -from codeflash.context.code_context_extractor import parse_code_and_prune_cst +from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst from codeflash.models.models import CodeContextType diff --git a/tests/test_get_testgen_code.py b/tests/test_get_testgen_code.py index c15005fa7..01c3ae153 100644 --- a/tests/test_get_testgen_code.py +++ b/tests/test_get_testgen_code.py @@ -2,7 +2,7 @@ import pytest -from codeflash.context.code_context_extractor import parse_code_and_prune_cst +from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst from codeflash.models.models import CodeContextType diff --git a/tests/test_instrument_async_tests.py b/tests/test_instrument_async_tests.py index 69552ba08..edd9c296b 100644 --- a/tests/test_instrument_async_tests.py +++ b/tests/test_instrument_async_tests.py @@ -6,7 +6,9 @@ import pytest from codeflash.code_utils.instrument_existing_tests import ( + ASYNC_HELPER_FILENAME, add_async_decorator_to_function, + get_decorator_name_for_mode, inject_profiling_into_existing_test, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -57,20 +59,6 @@ def test_async_decorator_application_behavior_mode(temp_dir): async_function_code = ''' import asyncio -async def async_function(x: int, y: int) -> int: - """Simple async function for testing.""" - await asyncio.sleep(0.01) - return x * y -''' - - expected_decorated_code = ''' -import asyncio - -from codeflash.code_utils.codeflash_wrap_decorator import \\ - codeflash_behavior_async - - -@codeflash_behavior_async async def async_function(x: int, y: int) -> int: """Simple async function for testing.""" await asyncio.sleep(0.01) @@ -86,7 +74,16 @@ async def async_function(x: int, y: int) -> int: assert decorator_added modified_code = test_file.read_text() - assert modified_code.strip() == expected_decorated_code.strip() + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + code_with_decorator = async_function_code.replace( + "async def async_function", f"@{decorator_name}\nasync def async_function" + ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert modified_code.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -94,20 +91,6 @@ def test_async_decorator_application_performance_mode(temp_dir): async_function_code = ''' import asyncio -async def async_function(x: int, y: int) -> int: - """Simple async function for testing.""" - await asyncio.sleep(0.01) - return x * y -''' - - expected_decorated_code = ''' -import asyncio - -from codeflash.code_utils.codeflash_wrap_decorator import \\ - codeflash_performance_async - - -@codeflash_performance_async async def async_function(x: int, y: int) -> int: """Simple async function for testing.""" await asyncio.sleep(0.01) @@ -123,7 +106,16 @@ async def async_function(x: int, y: int) -> int: assert decorator_added modified_code = test_file.read_text() - assert modified_code.strip() == expected_decorated_code.strip() + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE) + code_with_decorator = async_function_code.replace( + "async def async_function", f"@{decorator_name}\nasync def async_function" + ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert modified_code.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -132,20 +124,6 @@ def test_async_decorator_application_concurrency_mode(temp_dir): async_function_code = ''' import asyncio -async def async_function(x: int, y: int) -> int: - """Simple async function for testing.""" - await asyncio.sleep(0.01) - return x * y -''' - - expected_decorated_code = ''' -import asyncio - -from codeflash.code_utils.codeflash_wrap_decorator import \\ - codeflash_concurrency_async - - -@codeflash_concurrency_async async def async_function(x: int, y: int) -> int: """Simple async function for testing.""" await asyncio.sleep(0.01) @@ -161,7 +139,16 @@ async def async_function(x: int, y: int) -> int: assert decorator_added modified_code = test_file.read_text() - assert modified_code.strip() == expected_decorated_code.strip() + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.CONCURRENCY) + code_with_decorator = async_function_code.replace( + "async def async_function", f"@{decorator_name}\nasync def async_function" + ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert modified_code.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -182,27 +169,6 @@ def sync_method(self, a: int, b: int) -> int: return a - b ''' - expected_decorated_code = ''' -import asyncio - -from codeflash.code_utils.codeflash_wrap_decorator import \\ - codeflash_behavior_async - - -class Calculator: - """Test class with async methods.""" - - @codeflash_behavior_async - async def async_method(self, a: int, b: int) -> int: - """Async method in class.""" - await asyncio.sleep(0.005) - return a ** b - - def sync_method(self, a: int, b: int) -> int: - """Sync method in class.""" - return a - b -''' - test_file = temp_dir / "test_async.py" test_file.write_text(async_class_code) @@ -217,11 +183,21 @@ def sync_method(self, a: int, b: int) -> int: assert decorator_added modified_code = test_file.read_text() - assert modified_code.strip() == expected_decorated_code.strip() + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + code_with_decorator = async_class_code.replace( + " async def async_method", f" @{decorator_name}\n async def async_method" + ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert modified_code.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") def test_async_decorator_no_duplicate_application(temp_dir): + # Case 1: Old-style import already present — injector should detect and skip already_decorated_code = ''' from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async import asyncio @@ -243,6 +219,30 @@ async def async_function(x: int, y: int) -> int: # Should not add duplicate decorator assert not decorator_added + # Case 2: Inline definition already present — injector should detect and skip + already_inline_code = ''' +import asyncio + +def codeflash_behavior_async(func): + return func + +@codeflash_behavior_async +async def async_function(x: int, y: int) -> int: + """Already decorated async function.""" + await asyncio.sleep(0.01) + return x * y +''' + + test_file2 = temp_dir / "test_async2.py" + test_file2.write_text(already_inline_code) + + func2 = FunctionToOptimize(function_name="async_function", file_path=test_file2, parents=[], is_async=True) + + decorator_added2 = add_async_decorator_to_function(test_file2, func2, TestingMode.BEHAVIOR) + + # Should not add duplicate decorator + assert not decorator_added2 + @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") def test_inject_profiling_async_function_behavior_mode(temp_dir): @@ -285,11 +285,18 @@ async def test_async_function(): assert source_success is True - # Verify the file was modified + # Verify the file was modified with exact expected output instrumented_source = source_file.read_text() - assert "@codeflash_behavior_async" in instrumented_source - assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source - assert "codeflash_behavior_async" in instrumented_source + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + code_with_decorator = source_module_code.replace( + "async def async_function", f"@{decorator_name}\nasync def async_function" + ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert instrumented_source.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() success, instrumented_test_code = inject_profiling_into_existing_test( async_test_code, test_file, [CodePosition(8, 18), CodePosition(11, 19)], func, temp_dir, mode=TestingMode.BEHAVIOR @@ -340,12 +347,18 @@ async def test_async_function(): assert source_success is True - # Verify the file was modified + # Verify the file was modified with exact expected output instrumented_source = source_file.read_text() - assert "@codeflash_performance_async" in instrumented_source - # Check for the import with line continuation formatting - assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source - assert "codeflash_performance_async" in instrumented_source + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE) + code_with_decorator = source_module_code.replace( + "async def async_function", f"@{decorator_name}\nasync def async_function" + ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert instrumented_source.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() # Now test the full pipeline with source module path success, instrumented_test_code = inject_profiling_into_existing_test( @@ -406,11 +419,16 @@ async def test_mixed_functions(): # Verify the file was modified instrumented_source = source_file.read_text() - assert "@codeflash_behavior_async" in instrumented_source - assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source - assert "codeflash_behavior_async" in instrumented_source - # Sync function should remain unchanged - assert "def sync_function(x: int, y: int) -> int:" in instrumented_source + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + code_with_decorator = source_module_code.replace( + "async def async_function", f"@{decorator_name}\nasync def async_function" + ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert instrumented_source.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() success, instrumented_test_code = inject_profiling_into_existing_test( mixed_test_code, test_file, [CodePosition(8, 18), CodePosition(11, 19)], async_func, temp_dir, mode=TestingMode.BEHAVIOR @@ -446,24 +464,19 @@ async def nested_async_method(self, x: int) -> int: decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.BEHAVIOR) - expected_output = """import asyncio - -from codeflash.code_utils.codeflash_wrap_decorator import \\ - codeflash_behavior_async - - -class OuterClass: - class InnerClass: - @codeflash_behavior_async - async def nested_async_method(self, x: int) -> int: - \"\"\"Nested async method.\"\"\" - await asyncio.sleep(0.001) - return x * 2 -""" - assert decorator_added modified_code = test_file.read_text() - assert modified_code.strip() == expected_output.strip() + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + code_with_decorator = nested_async_code.replace( + " async def nested_async_method", + f" @{decorator_name}\n async def nested_async_method", + ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert modified_code.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") diff --git a/tests/test_languages/test_code_context_extraction.py b/tests/test_languages/test_code_context_extraction.py index 07946ddd3..b7b12a69c 100644 --- a/tests/test_languages/test_code_context_extraction.py +++ b/tests/test_languages/test_code_context_extraction.py @@ -20,14 +20,12 @@ from __future__ import annotations -from pathlib import Path - import pytest -from codeflash.context.code_context_extractor import get_code_optimization_context_for_language from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import Language from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport +from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context_for_language @pytest.fixture diff --git a/tests/test_languages/test_java_e2e.py b/tests/test_languages/test_java_e2e.py index 1b6aa3ace..c01865048 100644 --- a/tests/test_languages/test_java_e2e.py +++ b/tests/test_languages/test_java_e2e.py @@ -89,7 +89,7 @@ def java_project_dir(self): def test_extract_code_context_for_java(self, java_project_dir): """Test extracting code context for a Java method.""" - from codeflash.context.code_context_extractor import get_code_optimization_context + from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context from codeflash.languages import current as lang_current from codeflash.languages.base import Language diff --git a/tests/test_languages/test_javascript_e2e.py b/tests/test_languages/test_javascript_e2e.py index 017e8f66e..7b7e8503b 100644 --- a/tests/test_languages/test_javascript_e2e.py +++ b/tests/test_languages/test_javascript_e2e.py @@ -106,9 +106,9 @@ def js_project_dir(self): def test_extract_code_context_for_javascript(self, js_project_dir): """Test extracting code context for a JavaScript function.""" skip_if_js_not_supported() - from codeflash.context.code_context_extractor import get_code_optimization_context from codeflash.discovery.functions_to_optimize import find_all_functions_in_file from codeflash.languages import current as lang_current + from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context lang_current._current_language = Language.JAVASCRIPT diff --git a/tests/test_languages/test_javascript_optimization_flow.py b/tests/test_languages/test_javascript_optimization_flow.py index 26d2db140..89631565b 100644 --- a/tests/test_languages/test_javascript_optimization_flow.py +++ b/tests/test_languages/test_javascript_optimization_flow.py @@ -9,7 +9,6 @@ This is the JavaScript equivalent of test_instrument_tests.py for Python. """ -from pathlib import Path from unittest.mock import MagicMock, patch import pytest @@ -71,9 +70,9 @@ def test_function_to_optimize_has_correct_language_for_javascript(self, tmp_path def test_code_context_preserves_language(self, tmp_path): """Verify language is preserved in code context extraction.""" skip_if_js_not_supported() - from codeflash.context.code_context_extractor import get_code_optimization_context from codeflash.discovery.functions_to_optimize import find_all_functions_in_file from codeflash.languages import current as lang_current + from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context lang_current._current_language = Language.TYPESCRIPT @@ -164,7 +163,7 @@ def test_testgen_request_includes_correct_language(self, tmp_path): # Mock the AI service request ai_client = AiServiceClient() - with patch.object(ai_client, 'make_ai_service_request') as mock_request: + with patch.object(ai_client, "make_ai_service_request") as mock_request: mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = { @@ -191,8 +190,8 @@ def test_testgen_request_includes_correct_language(self, tmp_path): # Verify the request was made with correct language assert mock_request.called, "API request should have been made" call_args = mock_request.call_args - payload = call_args[1].get('payload', call_args[0][1] if len(call_args[0]) > 1 else {}) - assert payload.get('language') == 'typescript', \ + payload = call_args[1].get("payload", call_args[0][1] if len(call_args[0]) > 1 else {}) + assert payload.get("language") == "typescript", \ f"Expected language='typescript', got language='{payload.get('language')}'" @@ -462,7 +461,7 @@ def test_helper_functions_have_correct_language_javascript(self, tmp_path): """Verify helper functions have language='javascript' for .js files.""" skip_if_js_not_supported() from codeflash.discovery.functions_to_optimize import find_all_functions_in_file - from codeflash.languages import current as lang_current, get_language_support + from codeflash.languages import current as lang_current from codeflash.optimization.function_optimizer import FunctionOptimizer lang_current._current_language = Language.JAVASCRIPT diff --git a/tests/test_languages/test_typescript_e2e.py b/tests/test_languages/test_typescript_e2e.py index a638f01a1..87dc81269 100644 --- a/tests/test_languages/test_typescript_e2e.py +++ b/tests/test_languages/test_typescript_e2e.py @@ -69,7 +69,7 @@ def test_discover_functions_with_type_annotations(self): from codeflash.discovery.functions_to_optimize import find_all_functions_in_file with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f: - f.write(""" + f.write(r""" export function add(a: number, b: number): number { return a + b; } @@ -123,9 +123,9 @@ def ts_project_dir(self): def test_extract_code_context_for_typescript(self, ts_project_dir): """Test extracting code context for a TypeScript function.""" skip_if_ts_not_supported() - from codeflash.context.code_context_extractor import get_code_optimization_context from codeflash.discovery.functions_to_optimize import find_all_functions_in_file from codeflash.languages import current as lang_current + from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context lang_current._current_language = Language.TYPESCRIPT @@ -201,7 +201,7 @@ def test_replace_function_preserves_types(self): from codeflash.languages import get_language_support from codeflash.languages.base import FunctionInfo - original_source = """ + original_source = r""" interface Config { timeout: number; retries: number; @@ -212,7 +212,7 @@ def test_replace_function_preserves_types(self): } """ - new_function = """function processConfig(config: Config): string { + new_function = r"""function processConfig(config: Config): string { // Optimized with template caching const { timeout, retries } = config; return `timeout=\${timeout}, retries=\${retries}`; diff --git a/tests/test_languages/test_vitest_e2e.py b/tests/test_languages/test_vitest_e2e.py index 68448c1cf..fc3c285a4 100644 --- a/tests/test_languages/test_vitest_e2e.py +++ b/tests/test_languages/test_vitest_e2e.py @@ -117,10 +117,10 @@ def vitest_project_dir(self): def test_extract_code_context_for_typescript(self, vitest_project_dir): """Test extracting code context for a TypeScript function.""" skip_if_js_not_supported() - from codeflash.context.code_context_extractor import get_code_optimization_context from codeflash.discovery.functions_to_optimize import find_all_functions_in_file from codeflash.languages import current as lang_current from codeflash.languages.base import Language + from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context lang_current._current_language = Language.TYPESCRIPT diff --git a/tests/test_remove_unused_definitions.py b/tests/test_remove_unused_definitions.py index 8d272b2bb..5614e7283 100644 --- a/tests/test_remove_unused_definitions.py +++ b/tests/test_remove_unused_definitions.py @@ -1,6 +1,6 @@ -from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names +from codeflash.languages.python.context.unused_definition_remover import remove_unused_definitions_by_function_names def test_variable_removal_only() -> None: diff --git a/tests/test_unused_helper_revert.py b/tests/test_unused_helper_revert.py index 18d21de32..bfc75642c 100644 --- a/tests/test_unused_helper_revert.py +++ b/tests/test_unused_helper_revert.py @@ -5,8 +5,11 @@ import pytest -from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.languages.python.context.unused_definition_remover import ( + detect_unused_helper_functions, + revert_unused_helper_functions, +) from codeflash.models.models import CodeStringsMarkdown from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig