From 2e77f85834c6bb0bd39190b568f415a252bf0b95 Mon Sep 17 00:00:00 2001 From: ali Date: Mon, 16 Feb 2026 14:31:08 +0200 Subject: [PATCH 01/39] fix: resolve jest-runner from project's node_modules for Jest 30 compatibility The loop-runner was loading jest-runner from codeflash's node_modules (v29) instead of the project's (v30), causing "runtime.enterTestCode is not a function" errors. This fix: - Adds recursive search to find jest-runner in any node_modules structure - Works with npm, yarn, and pnpm (including non-hoisted deps) - Prefers higher versions when multiple are found - Removes internal looping in capturePerf when using external loop-runner - Creates fresh TestRunner per batch to avoid Jest 30 state corruption Co-Authored-By: Claude Opus 4.5 --- codeflash/languages/javascript/parse.py | 5 - codeflash/languages/javascript/test_runner.py | 2 - packages/codeflash/runtime/capture.js | 29 +-- packages/codeflash/runtime/loop-runner.js | 199 +++++++++++------- 4 files changed, 135 insertions(+), 100 deletions(-) diff --git a/codeflash/languages/javascript/parse.py b/codeflash/languages/javascript/parse.py index a5e7ae8c6..e3eee4831 100644 --- a/codeflash/languages/javascript/parse.py +++ b/codeflash/languages/javascript/parse.py @@ -527,10 +527,5 @@ def parse_jest_test_xml( f"[LOOP-SUMMARY] Results loop_index: min={min_idx}, max={max_idx}, " f"unique_count={len(unique_loop_indices)}, total_results={len(loop_indices)}" ) - if max_idx == 1 and len(loop_indices) > 1: - logger.warning( - f"[LOOP-WARNING] All {len(loop_indices)} results have loop_index=1. " - "Perf test markers may not have been parsed correctly." - ) return test_results diff --git a/codeflash/languages/javascript/test_runner.py b/codeflash/languages/javascript/test_runner.py index 1d79ad382..bcc3a74de 100644 --- a/codeflash/languages/javascript/test_runner.py +++ b/codeflash/languages/javascript/test_runner.py @@ -803,8 +803,6 @@ def run_jest_behavioral_tests( wall_clock_ns = time.perf_counter_ns() - start_time_ns logger.debug(f"Jest behavioral tests completed in {wall_clock_ns / 1e9:.2f}s") - print(result.stdout) - return result_file_path, result, coverage_json_path, None diff --git a/packages/codeflash/runtime/capture.js b/packages/codeflash/runtime/capture.js index 0fdcc5784..4ff9623fc 100644 --- a/packages/codeflash/runtime/capture.js +++ b/packages/codeflash/runtime/capture.js @@ -113,21 +113,26 @@ function checkSharedTimeLimit() { /** * Get the current loop index for a specific invocation. - * The loop index represents how many times ALL test files have been run through. - * This is the batch count from the loop-runner. + * When using external loop-runner (Jest), returns the batch number directly. + * When using internal looping (Vitest), tracks and returns the invocation count. + * * @param {string} invocationKey - Unique key for this test invocation - * @returns {number} The current batch number (loop index) + * @returns {number} The loop index for timing markers (1-based) */ function getInvocationLoopIndex(invocationKey) { - // Track local loop count for stopping logic (increments on each call) + // When using external loop-runner, use the batch number directly + // This is reliable because Jest resets module state between batches + const currentBatch = process.env.CODEFLASH_PERF_CURRENT_BATCH; + if (currentBatch !== undefined) { + return parseInt(currentBatch, 10); + } + + // For internal looping (Vitest), track the count locally if (!sharedPerfState.invocationLoopCounts[invocationKey]) { sharedPerfState.invocationLoopCounts[invocationKey] = 0; } ++sharedPerfState.invocationLoopCounts[invocationKey]; - - // Return the batch number as the loop index for timing markers - // This represents how many times all test files have been run through - return parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '1', 10); + return sharedPerfState.invocationLoopCounts[invocationKey]; } /** @@ -693,11 +698,9 @@ function capturePerf(funcName, lineId, fn, ...args) { // If not set, we're in Vitest mode and need to do all loops internally const hasExternalLoopRunner = process.env.CODEFLASH_PERF_CURRENT_BATCH !== undefined; - // Batched looping: run BATCH_SIZE loops per capturePerf call when using loop-runner + // When using external loop-runner (Jest), execute only once per call - the loop-runner handles batching // For Vitest (no loop-runner), do all loops internally in a single call - const batchSize = shouldLoop - ? (hasExternalLoopRunner ? getPerfBatchSize() : getPerfLoopCount()) - : 1; + const batchSize = hasExternalLoopRunner ? 1 : (shouldLoop ? getPerfLoopCount() : 1); // Initialize runtime tracking for this invocation if needed if (!sharedPerfState.invocationRuntimes[invocationKey]) { @@ -719,7 +722,7 @@ function capturePerf(funcName, lineId, fn, ...args) { break; } - // Get the loop index (batch number) for timing markers + // Get the loop index for timing markers const loopIndex = getInvocationLoopIndex(invocationKey); // Check if we've exceeded max loops for this invocation diff --git a/packages/codeflash/runtime/loop-runner.js b/packages/codeflash/runtime/loop-runner.js index c6d476f1f..1cd0803c9 100644 --- a/packages/codeflash/runtime/loop-runner.js +++ b/packages/codeflash/runtime/loop-runner.js @@ -35,69 +35,113 @@ const path = require('path'); const fs = require('fs'); /** - * Validates that a jest-runner path is valid by checking for package.json. - * @param {string} jestRunnerPath - Path to check - * @returns {boolean} True if valid jest-runner package + * Recursively find jest-runner package in node_modules. + * Works with any package manager (npm, yarn, pnpm) by searching for + * jest-runner/package.json anywhere in the tree. + * + * @param {string} nodeModulesPath - Path to node_modules directory + * @param {number} maxDepth - Maximum recursion depth (default: 5) + * @returns {string|null} Path to jest-runner or null if not found */ -function isValidJestRunnerPath(jestRunnerPath) { - if (!fs.existsSync(jestRunnerPath)) { - return false; +function findJestRunnerRecursive(nodeModulesPath, maxDepth = 5) { + function search(dir, depth) { + if (depth > maxDepth || !fs.existsSync(dir)) return null; + + try { + let entries = fs.readdirSync(dir, { withFileTypes: true }); + + // Sort entries: prefer higher versions for jest-runner@X.Y.Z directories + entries = entries.slice().sort((a, b) => { + const aMatch = a.name.match(/^jest-runner@(\d+)/); + const bMatch = b.name.match(/^jest-runner@(\d+)/); + if (aMatch && bMatch) { + return parseInt(bMatch[1], 10) - parseInt(aMatch[1], 10); + } + return a.name.localeCompare(b.name); + }); + + for (const entry of entries) { + if (!entry.isDirectory()) continue; + + const entryPath = path.join(dir, entry.name); + + // Found jest-runner directory - check if it's a valid package + if (entry.name === 'jest-runner') { + const pkgJsonPath = path.join(entryPath, 'package.json'); + if (fs.existsSync(pkgJsonPath)) { + try { + const pkgJson = JSON.parse(fs.readFileSync(pkgJsonPath, 'utf8')); + if (pkgJson.name === 'jest-runner') { + return entryPath; + } + } catch (e) { + // Ignore JSON parse errors + } + } + } + + // Recurse into: + // - node_modules subdirectories + // - scoped packages (@org/pkg) + // - hidden directories (.pnpm, .yarn, etc.) + // - pnpm versioned directories (jest-runner@30.0.5) + const shouldRecurse = entry.name === 'node_modules' || + entry.name.startsWith('@') || + entry.name.startsWith('.') || + entry.name.startsWith('jest-runner@'); + + if (shouldRecurse) { + const result = search(entryPath, depth + 1); + if (result) return result; + } + } + } catch (e) { + // Ignore permission errors + } + + return null; } - const packageJsonPath = path.join(jestRunnerPath, 'package.json'); - return fs.existsSync(packageJsonPath); + + return search(nodeModulesPath, 0); } /** - * Resolve jest-runner with monorepo support. - * Uses CODEFLASH_MONOREPO_ROOT environment variable if available, - * otherwise walks up the directory tree looking for node_modules/jest-runner. + * Resolve jest-runner from the PROJECT's node_modules (not codeflash's). + * + * Uses recursive search to find jest-runner anywhere in node_modules, + * working with any package manager (npm, yarn, pnpm). * * @returns {string} Path to jest-runner package * @throws {Error} If jest-runner cannot be found */ function resolveJestRunner() { - // Try standard resolution first (works in simple projects) - try { - return require.resolve('jest-runner'); - } catch (e) { - // Standard resolution failed - try monorepo-aware resolution - } + const monorepoMarkers = ['yarn.lock', 'pnpm-workspace.yaml', 'lerna.json', 'package-lock.json']; + + // Walk up from cwd to find all potential node_modules locations + let currentDir = process.cwd(); + const visitedDirs = new Set(); // If Python detected a monorepo root, check there first const monorepoRoot = process.env.CODEFLASH_MONOREPO_ROOT; - if (monorepoRoot) { - const jestRunnerPath = path.join(monorepoRoot, 'node_modules', 'jest-runner'); - if (isValidJestRunnerPath(jestRunnerPath)) { - return jestRunnerPath; - } + if (monorepoRoot && !visitedDirs.has(monorepoRoot)) { + visitedDirs.add(monorepoRoot); + const result = findJestRunnerRecursive(path.join(monorepoRoot, 'node_modules')); + if (result) return result; } - // Fallback: Walk up from cwd looking for node_modules/jest-runner - const monorepoMarkers = ['yarn.lock', 'pnpm-workspace.yaml', 'lerna.json', 'package-lock.json']; - let currentDir = process.cwd(); - const visitedDirs = new Set(); - while (currentDir !== path.dirname(currentDir)) { - // Avoid infinite loops if (visitedDirs.has(currentDir)) break; visitedDirs.add(currentDir); - // Try node_modules/jest-runner at this level - const jestRunnerPath = path.join(currentDir, 'node_modules', 'jest-runner'); - if (isValidJestRunnerPath(jestRunnerPath)) { - return jestRunnerPath; - } + const result = findJestRunnerRecursive(path.join(currentDir, 'node_modules')); + if (result) return result; - // Check if this is a workspace root (has monorepo markers) + // Check if this is a workspace root - stop after this const isWorkspaceRoot = monorepoMarkers.some(marker => fs.existsSync(path.join(currentDir, marker)) ); - if (isWorkspaceRoot) { - // Found workspace root but no jest-runner - stop searching - break; - } - + if (isWorkspaceRoot) break; currentDir = path.dirname(currentDir); } @@ -120,10 +164,15 @@ let jestVersion = 0; try { const jestRunnerPath = resolveJestRunner(); - const internalRequire = createRequire(jestRunnerPath); - // Try to get the TestRunner class (Jest 30+) - const jestRunner = internalRequire(jestRunnerPath); + // Read the package.json to find the actual entry point and version + const pkgJsonPath = path.join(jestRunnerPath, 'package.json'); + const pkgJson = JSON.parse(fs.readFileSync(pkgJsonPath, 'utf8')); + + // Require using the full path to the entry point + const entryPoint = path.join(jestRunnerPath, pkgJson.main || 'build/index.js'); + const jestRunner = require(entryPoint); + TestRunner = jestRunner.default || jestRunner.TestRunner; if (TestRunner && TestRunner.prototype && typeof TestRunner.prototype.runTests === 'function') { @@ -131,9 +180,11 @@ try { jestVersion = 30; jestRunnerAvailable = true; } else { - // Try Jest 29 style import + // Try Jest 29 style import - runTest is in build/runTest.js try { - runTest = internalRequire('./runTest').default; + const runTestPath = path.join(jestRunnerPath, 'build', 'runTest.js'); + const runTestModule = require(runTestPath); + runTest = runTestModule.default; if (typeof runTest === 'function') { // Jest 29 - use direct runTest function jestVersion = 29; @@ -141,10 +192,6 @@ try { } } catch (e29) { // Neither Jest 29 nor 30 style import worked - const errorMsg = `Found jest-runner at ${jestRunnerPath} but could not load it. ` + - `This may indicate an unsupported Jest version. ` + - `Supported versions: Jest 29.x and Jest 30.x`; - console.error(errorMsg); jestRunnerAvailable = false; } } @@ -233,15 +280,12 @@ class CodeflashLoopRunner { this._context = context || {}; this._eventEmitter = new SimpleEventEmitter(); - // For Jest 30+, create an instance of the base TestRunner for delegation - if (jestVersion >= 30) { - if (!TestRunner) { - throw new Error( - `Jest ${jestVersion} detected but TestRunner class not available. ` + - `This indicates an internal error in loop-runner initialization.` - ); - } - this._baseRunner = new TestRunner(globalConfig, context); + // For Jest 30+, verify TestRunner is available (we create fresh instances per batch) + if (jestVersion >= 30 && !TestRunner) { + throw new Error( + `Jest ${jestVersion} detected but TestRunner class not available. ` + + `This indicates an internal error in loop-runner initialization.` + ); } } @@ -270,7 +314,7 @@ class CodeflashLoopRunner { * @param {Object} options - Jest runner options * @returns {Promise} */ - async runTests(tests, watcher, options) { + async runTests(tests, watcher, ...rest) { const startTime = Date.now(); let batchCount = 0; let hasFailure = false; @@ -289,13 +333,11 @@ class CodeflashLoopRunner { // Check time limit BEFORE each batch if (batchCount > MIN_BATCHES && checkTimeLimit()) { - console.log(`[codeflash] Time limit reached after ${batchCount - 1} batches (${Date.now() - startTime}ms elapsed)`); break; } // Check if interrupted if (watcher.isInterrupted()) { - console.log(`[codeflash] Watcher is interrupted`) break; } @@ -303,57 +345,54 @@ class CodeflashLoopRunner { process.env.CODEFLASH_PERF_CURRENT_BATCH = String(batchCount); // Run all test files in this batch - const batchResult = await this._runAllTestsOnce(tests, watcher, options); + const batchResult = await this._runAllTestsOnce(tests, watcher, ...rest); allConsoleOutput += batchResult.consoleOutput; - // if (batchResult.hasFailure) { - // hasFailure = true; - // break; - // } - // Check time limit AFTER each batch if (checkTimeLimit()) { - console.log(`[codeflash] Time limit reached after ${batchCount} batches (${Date.now() - startTime}ms elapsed)`); break; } } const totalTimeMs = Date.now() - startTime; - console.log(`[codeflash] now: ${Date.now()}`) // Output all collected console logs - this is critical for timing marker extraction // The console output contains the !######...######! timing markers from capturePerf if (allConsoleOutput) { process.stdout.write(allConsoleOutput); } - - console.log(`[codeflash] Batched runner completed: ${batchCount} batches, ${tests.length} test files, ${totalTimeMs}ms total`); } /** * Run all test files once (one batch). * Uses different approaches for Jest 29 vs Jest 30. */ - async _runAllTestsOnce(tests, watcher, options) { + async _runAllTestsOnce(tests, watcher, ...args) { if (jestVersion >= 30) { - return this._runAllTestsOnceJest30(tests, watcher, options); + return this._runAllTestsOnceJest30(tests, watcher, ...args); } else { return this._runAllTestsOnceJest29(tests, watcher); } } /** - * Jest 30+ implementation - delegates to base TestRunner and collects results. + * Jest 30+ implementation - creates a fresh TestRunner for each batch to avoid + * state corruption issues that occur when reusing runners across batches. */ - async _runAllTestsOnceJest30(tests, watcher, options) { + async _runAllTestsOnceJest30(tests, watcher, ...args) { let hasFailure = false; let allConsoleOutput = ''; // For Jest 30, we need to collect results through event listeners const resultsCollector = []; - // Subscribe to events from the base runner - const unsubscribeSuccess = this._baseRunner.on('test-file-success', (testData) => { + // Create a FRESH TestRunner instance for each batch + // Jest 30's TestRunner corrupts its internal state after running tests, + // so we cannot reuse the same instance across multiple batches + const batchRunner = new TestRunner(this._globalConfig, this._context); + + // Subscribe to events from the batch runner + const unsubscribeSuccess = batchRunner.on('test-file-success', (testData) => { const [test, result] = testData; resultsCollector.push({ test, result, success: true }); @@ -369,7 +408,7 @@ class CodeflashLoopRunner { this._eventEmitter.emit('test-file-success', testData); }); - const unsubscribeFailure = this._baseRunner.on('test-file-failure', (testData) => { + const unsubscribeFailure = batchRunner.on('test-file-failure', (testData) => { const [test, error] = testData; resultsCollector.push({ test, error, success: false }); hasFailure = true; @@ -378,14 +417,14 @@ class CodeflashLoopRunner { this._eventEmitter.emit('test-file-failure', testData); }); - const unsubscribeStart = this._baseRunner.on('test-file-start', (testData) => { + const unsubscribeStart = batchRunner.on('test-file-start', (testData) => { // Forward to our event emitter this._eventEmitter.emit('test-file-start', testData); }); try { - // Run tests using the base runner (always serial for benchmarking) - await this._baseRunner.runTests(tests, watcher, { ...options, serial: true }); + // Run tests using the fresh batch runner (always serial for benchmarking) + await batchRunner.runTests(tests, watcher, ...args); } finally { // Cleanup subscriptions if (typeof unsubscribeSuccess === 'function') unsubscribeSuccess(); From 56941357c98bb7cd9237f5145adab929ec0dc9d0 Mon Sep 17 00:00:00 2001 From: mohammed ahmed <64513301+mohammedahmed18@users.noreply.github.com> Date: Mon, 16 Feb 2026 14:35:33 +0200 Subject: [PATCH 02/39] Update packages/codeflash/runtime/loop-runner.js Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> --- packages/codeflash/runtime/loop-runner.js | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/packages/codeflash/runtime/loop-runner.js b/packages/codeflash/runtime/loop-runner.js index 1cd0803c9..ffdaa3757 100644 --- a/packages/codeflash/runtime/loop-runner.js +++ b/packages/codeflash/runtime/loop-runner.js @@ -87,7 +87,10 @@ function findJestRunnerRecursive(nodeModulesPath, maxDepth = 5) { // - pnpm versioned directories (jest-runner@30.0.5) const shouldRecurse = entry.name === 'node_modules' || entry.name.startsWith('@') || - entry.name.startsWith('.') || + const shouldRecurse = entry.name === 'node_modules' || + entry.name.startsWith('@') || + entry.name === '.pnpm' || entry.name === '.yarn' || + entry.name.startsWith('jest-runner@'); entry.name.startsWith('jest-runner@'); if (shouldRecurse) { From 2fb4b2dbfdcf292d9beb7c808cbd4f24b66525b9 Mon Sep 17 00:00:00 2001 From: ali Date: Mon, 16 Feb 2026 14:36:39 +0200 Subject: [PATCH 03/39] cleaning up --- .github/workflows/js-tests.yml | 50 ---------------------------------- codeflash/version.py | 2 +- 2 files changed, 1 insertion(+), 51 deletions(-) delete mode 100644 .github/workflows/js-tests.yml diff --git a/.github/workflows/js-tests.yml b/.github/workflows/js-tests.yml deleted file mode 100644 index 0d56e8831..000000000 --- a/.github/workflows/js-tests.yml +++ /dev/null @@ -1,50 +0,0 @@ -name: JavaScript/TypeScript Integration Tests - -on: - push: - branches: - - main - pull_request: - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.ref_name }} - cancel-in-progress: true - -jobs: - js-integration-tests: - name: JS/TS Integration Tests - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - token: ${{ secrets.GITHUB_TOKEN }} - - - name: Setup Node.js - uses: actions/setup-node@v4 - with: - node-version: '20' - - - name: Install uv - uses: astral-sh/setup-uv@v6 - - - name: Install Python dependencies - run: | - uv venv --seed - uv sync - - - name: Install npm dependencies for test projects - run: | - npm install --prefix code_to_optimize/js/code_to_optimize_js - npm install --prefix code_to_optimize/js/code_to_optimize_ts - npm install --prefix code_to_optimize/js/code_to_optimize_vitest - - - name: Run JavaScript integration tests - run: | - uv run pytest tests/languages/javascript/ -v - uv run pytest tests/test_languages/test_vitest_e2e.py -v - uv run pytest tests/test_languages/test_javascript_e2e.py -v - uv run pytest tests/test_languages/test_javascript_support.py -v - uv run pytest tests/code_utils/test_config_js.py -v diff --git a/codeflash/version.py b/codeflash/version.py index 6d60ab0c2..6225467e3 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ -1,2 +1,2 @@ # These version placeholders will be replaced by uv-dynamic-versioning during build. -__version__ = "0.20.0.post510.dev0+b8932209" +__version__ = "0.20.0" From 2d73cf88bb199b0c6fa5d15c0fa90f3c6139a6da Mon Sep 17 00:00:00 2001 From: ali Date: Mon, 16 Feb 2026 14:55:49 +0200 Subject: [PATCH 04/39] typo --- packages/codeflash/runtime/loop-runner.js | 2 -- 1 file changed, 2 deletions(-) diff --git a/packages/codeflash/runtime/loop-runner.js b/packages/codeflash/runtime/loop-runner.js index ffdaa3757..994397044 100644 --- a/packages/codeflash/runtime/loop-runner.js +++ b/packages/codeflash/runtime/loop-runner.js @@ -85,8 +85,6 @@ function findJestRunnerRecursive(nodeModulesPath, maxDepth = 5) { // - scoped packages (@org/pkg) // - hidden directories (.pnpm, .yarn, etc.) // - pnpm versioned directories (jest-runner@30.0.5) - const shouldRecurse = entry.name === 'node_modules' || - entry.name.startsWith('@') || const shouldRecurse = entry.name === 'node_modules' || entry.name.startsWith('@') || entry.name === '.pnpm' || entry.name === '.yarn' || From 5e25b7f3b629b4c21c604cee0fc9d7e3e5a370c5 Mon Sep 17 00:00:00 2001 From: ali Date: Mon, 16 Feb 2026 18:29:17 +0200 Subject: [PATCH 05/39] debugging for failed workflow --- codeflash/languages/javascript/test_runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/codeflash/languages/javascript/test_runner.py b/codeflash/languages/javascript/test_runner.py index bcc3a74de..d33beee66 100644 --- a/codeflash/languages/javascript/test_runner.py +++ b/codeflash/languages/javascript/test_runner.py @@ -1044,6 +1044,10 @@ def run_jest_benchmarking_tests( # Create result with combined stdout result = subprocess.CompletedProcess(args=result.args, returncode=result.returncode, stdout=stdout, stderr="") + if result.returncode != 0: + logger.debug(f"Jest benchmarking failed with return code {result.returncode}") + logger.debug(f"Jest benchmarking stdout: {result.stdout}") + logger.debug(f"Jest benchmarking stderr: {result.stderr}") except subprocess.TimeoutExpired: logger.warning(f"Jest benchmarking timed out after {total_timeout}s") From bfe4224de880a98ab7247b99e0b442567e6ab2ed Mon Sep 17 00:00:00 2001 From: ali Date: Mon, 16 Feb 2026 18:35:31 +0200 Subject: [PATCH 06/39] just for testing --- codeflash/languages/javascript/test_runner.py | 6 +++--- codeflash/version.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/codeflash/languages/javascript/test_runner.py b/codeflash/languages/javascript/test_runner.py index d33beee66..3a193602b 100644 --- a/codeflash/languages/javascript/test_runner.py +++ b/codeflash/languages/javascript/test_runner.py @@ -1045,9 +1045,9 @@ def run_jest_benchmarking_tests( # Create result with combined stdout result = subprocess.CompletedProcess(args=result.args, returncode=result.returncode, stdout=stdout, stderr="") if result.returncode != 0: - logger.debug(f"Jest benchmarking failed with return code {result.returncode}") - logger.debug(f"Jest benchmarking stdout: {result.stdout}") - logger.debug(f"Jest benchmarking stderr: {result.stderr}") + logger.info(f"Jest benchmarking failed with return code {result.returncode}") + logger.info(f"Jest benchmarking stdout: {result.stdout}") + logger.info(f"Jest benchmarking stderr: {result.stderr}") except subprocess.TimeoutExpired: logger.warning(f"Jest benchmarking timed out after {total_timeout}s") diff --git a/codeflash/version.py b/codeflash/version.py index 6225467e3..ca6d7615e 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ -1,2 +1,2 @@ # These version placeholders will be replaced by uv-dynamic-versioning during build. -__version__ = "0.20.0" +__version__ = "0.20.0.post634.dev0+2d73cf88" From d13cdb559b39d1f2b0ce8b0ec5802fa5f8ede709 Mon Sep 17 00:00:00 2001 From: ali Date: Mon, 16 Feb 2026 19:11:27 +0200 Subject: [PATCH 07/39] fallback to directly require the jest-runner module inside the loop runner --- packages/codeflash/runtime/loop-runner.js | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/packages/codeflash/runtime/loop-runner.js b/packages/codeflash/runtime/loop-runner.js index 994397044..43c167f32 100644 --- a/packages/codeflash/runtime/loop-runner.js +++ b/packages/codeflash/runtime/loop-runner.js @@ -89,7 +89,6 @@ function findJestRunnerRecursive(nodeModulesPath, maxDepth = 5) { entry.name.startsWith('@') || entry.name === '.pnpm' || entry.name === '.yarn' || entry.name.startsWith('jest-runner@'); - entry.name.startsWith('jest-runner@'); if (shouldRecurse) { const result = search(entryPath, depth + 1); @@ -197,9 +196,14 @@ try { } } } catch (e) { - // jest-runner not installed - this is expected for Vitest projects - // The runner will throw a helpful error if someone tries to use it without jest-runner - jestRunnerAvailable = false; + // try to directly import jest-runner + try { + const jestRunner = require('jest-runner'); + TestRunner = jestRunner.default || jestRunner.TestRunner; + jestRunnerAvailable = true; + } catch (e2) { + jestRunnerAvailable = false; + } } // Configuration From b4ea8b6bd694fc822e01b940027d7cce3794f36e Mon Sep 17 00:00:00 2001 From: mohammed ahmed <64513301+mohammedahmed18@users.noreply.github.com> Date: Mon, 16 Feb 2026 19:26:51 +0200 Subject: [PATCH 08/39] Update packages/codeflash/runtime/loop-runner.js Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> --- packages/codeflash/runtime/loop-runner.js | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/packages/codeflash/runtime/loop-runner.js b/packages/codeflash/runtime/loop-runner.js index 43c167f32..fc0b88f32 100644 --- a/packages/codeflash/runtime/loop-runner.js +++ b/packages/codeflash/runtime/loop-runner.js @@ -200,7 +200,12 @@ try { try { const jestRunner = require('jest-runner'); TestRunner = jestRunner.default || jestRunner.TestRunner; - jestRunnerAvailable = true; + if (TestRunner && TestRunner.prototype && typeof TestRunner.prototype.runTests === 'function') { + jestVersion = 30; + jestRunnerAvailable = true; + } else { + jestRunnerAvailable = false; + } } catch (e2) { jestRunnerAvailable = false; } From fa00422feaf0eb526606fcf434d11c3e1973beea Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 16 Feb 2026 13:34:07 -0500 Subject: [PATCH 09/39] refactor: simplify and deduplicate code_context_extractor Consolidate three enricher functions (get_imported_class_definitions, get_external_base_class_inits, get_external_class_inits) into a single enrich_testgen_context that parses code context once. Extract shared helpers, unify prune_cst variants, deduplicate loop bodies, and remove dead UsedNameCollector class. --- .gitignore | 2 + codeflash/code_utils/config_consts.py | 4 +- codeflash/context/code_context_extractor.py | 1136 +++++++------------ codeflash/languages/current.py | 2 +- tests/test_code_context_extractor.py | 619 ++-------- 5 files changed, 469 insertions(+), 1294 deletions(-) diff --git a/.gitignore b/.gitignore index b80ab3816..bf2a23e4d 100644 --- a/.gitignore +++ b/.gitignore @@ -268,3 +268,5 @@ tessl.json # Tessl auto-generates AGENTS.md on install; ignore to avoid cluttering git status AGENTS.md +.serena/ +.codeflash/ diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py index e344fad8a..b84a136d8 100644 --- a/codeflash/code_utils/config_consts.py +++ b/codeflash/code_utils/config_consts.py @@ -4,8 +4,8 @@ from typing import Any, Union MAX_TEST_RUN_ITERATIONS = 5 -OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 16000 -TESTGEN_CONTEXT_TOKEN_LIMIT = 16000 +OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 48000 +TESTGEN_CONTEXT_TOKEN_LIMIT = 48000 INDIVIDUAL_TESTCASE_TIMEOUT = 15 MAX_FUNCTION_TEST_SECONDS = 60 MIN_IMPROVEMENT_THRESHOLD = 0.05 diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index a77cc29e6..0220a642d 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -6,7 +6,7 @@ from collections import defaultdict from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING import libcst as cst @@ -34,47 +34,41 @@ from codeflash.optimization.function_context import belongs_to_function_qualified if TYPE_CHECKING: + from collections.abc import Callable + from jedi.api.classes import Name - from libcst import CSTNode from codeflash.context.unused_definition_remover import UsageInfo from codeflash.languages.base import HelperFunction +# Error message constants +READ_WRITABLE_LIMIT_ERROR = "Read-writable code has exceeded token limit, cannot proceed" +TESTGEN_LIMIT_ERROR = "Testgen code context has exceeded token limit, cannot proceed" + + +def safe_relative_to(path: Path, root: Path) -> Path: + try: + return path.resolve().relative_to(root.resolve()) + except ValueError: + return path + def build_testgen_context( helpers_of_fto_dict: dict[Path, set[FunctionSource]], helpers_of_helpers_dict: dict[Path, set[FunctionSource]], project_root_path: Path, - remove_docstrings: bool, - include_imported_classes: bool, ) -> CodeStringsMarkdown: - """Build testgen context with optional imported class definitions and external base inits.""" testgen_context = extract_code_markdown_context_from_files( helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, - remove_docstrings=remove_docstrings, + remove_docstrings=False, code_context_type=CodeContextType.TESTGEN, ) - if include_imported_classes: - imported_class_context = get_imported_class_definitions(testgen_context, project_root_path) - if imported_class_context.code_strings: - testgen_context = CodeStringsMarkdown( - code_strings=testgen_context.code_strings + imported_class_context.code_strings - ) - - external_base_inits = get_external_base_class_inits(testgen_context, project_root_path) - if external_base_inits.code_strings: - testgen_context = CodeStringsMarkdown( - code_strings=testgen_context.code_strings + external_base_inits.code_strings - ) - - external_class_inits = get_external_class_inits(testgen_context, project_root_path) - if external_class_inits.code_strings: - testgen_context = CodeStringsMarkdown( - code_strings=testgen_context.code_strings + external_class_inits.code_strings - ) + enrichment = enrich_testgen_context(testgen_context, project_root_path) + if enrichment.code_strings: + testgen_context = CodeStringsMarkdown(code_strings=testgen_context.code_strings + enrichment.code_strings) return testgen_context @@ -142,7 +136,7 @@ def get_code_optimization_context( # Handle token limits final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.markdown) if final_read_writable_tokens > optim_token_limit: - raise ValueError("Read-writable code has exceeded token limit, cannot proceed") + raise ValueError(READ_WRITABLE_LIMIT_ERROR) # Setup preexisting objects for code replacer preexisting_objects = set( @@ -153,53 +147,10 @@ def get_code_optimization_context( ) read_only_context_code = read_only_code_markdown.markdown - read_only_code_markdown_tokens = encoded_tokens_len(read_only_context_code) - total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens - if total_tokens > optim_token_limit: - logger.debug("Code context has exceeded token limit, removing docstrings from read-only code") - # Extract read only code without docstrings - read_only_code_no_docstring_markdown = extract_code_markdown_context_from_files( - helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True - ) - read_only_context_code = read_only_code_no_docstring_markdown.markdown - read_only_code_no_docstring_markdown_tokens = encoded_tokens_len(read_only_context_code) - total_tokens = final_read_writable_tokens + read_only_code_no_docstring_markdown_tokens - if total_tokens > optim_token_limit: - logger.debug("Code context has exceeded token limit, removing read-only code") - read_only_context_code = "" - - # Extract code context for testgen with progressive fallback for token limits - # Try in order: full context -> remove docstrings -> remove imported classes - testgen_context = build_testgen_context( - helpers_of_fto_dict, - helpers_of_helpers_dict, - project_root_path, - remove_docstrings=False, - include_imported_classes=True, - ) + testgen_context = build_testgen_context(helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path) if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: - logger.debug("Testgen context exceeded token limit, removing docstrings") - testgen_context = build_testgen_context( - helpers_of_fto_dict, - helpers_of_helpers_dict, - project_root_path, - remove_docstrings=True, - include_imported_classes=True, - ) - - if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: - logger.debug("Testgen context still exceeded token limit, removing imported class definitions") - testgen_context = build_testgen_context( - helpers_of_fto_dict, - helpers_of_helpers_dict, - project_root_path, - remove_docstrings=True, - include_imported_classes=False, - ) - - if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: - raise ValueError("Testgen code context has exceeded token limit, cannot proceed") + raise ValueError(TESTGEN_LIMIT_ERROR) code_hash_context = hashing_code_context.markdown code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest() @@ -251,10 +202,7 @@ def get_code_optimization_context_for_language( imports_code = "\n".join(code_context.imports) if code_context.imports else "" # Get relative path for target file - try: - target_relative_path = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve()) - except ValueError: - target_relative_path = function_to_optimize.file_path + target_relative_path = safe_relative_to(function_to_optimize.file_path, project_root_path) # Group helpers by file path helpers_by_file: dict[Path, list[HelperFunction]] = defaultdict(list) @@ -302,10 +250,7 @@ def get_code_optimization_context_for_language( if file_path == function_to_optimize.file_path: continue # Already included in target file - try: - helper_relative_path = file_path.resolve().relative_to(project_root_path.resolve()) - except ValueError: - helper_relative_path = file_path + helper_relative_path = safe_relative_to(file_path, project_root_path) # Combine all helpers from this file combined_helper_code = "\n\n".join(h.source_code for h in file_helpers) @@ -328,11 +273,11 @@ def get_code_optimization_context_for_language( # Check token limits read_writable_tokens = encoded_tokens_len(read_writable_code.markdown) if read_writable_tokens > optim_token_limit: - raise ValueError("Read-writable code has exceeded token limit, cannot proceed") + raise ValueError(READ_WRITABLE_LIMIT_ERROR) testgen_tokens = encoded_tokens_len(testgen_context.markdown) if testgen_tokens > testgen_token_limit: - raise ValueError("Testgen code context has exceeded token limit, cannot proceed") + raise ValueError(TESTGEN_LIMIT_ERROR) # Generate code hash from all read-writable code code_hash = hashlib.sha256(read_writable_code.flat.encode("utf-8")).hexdigest() @@ -350,6 +295,49 @@ def get_code_optimization_context_for_language( ) +def process_file_context( + file_path: Path, + primary_qualified_names: set[str], + secondary_qualified_names: set[str], + code_context_type: CodeContextType, + remove_docstrings: bool, + project_root_path: Path, + helper_functions: list[FunctionSource], +) -> CodeString | None: + try: + original_code = file_path.read_text("utf8") + except Exception as e: + logger.exception(f"Error while parsing {file_path}: {e}") + return None + + try: + all_names = primary_qualified_names | secondary_qualified_names + code_without_unused_defs = remove_unused_definitions_by_function_names(original_code, all_names) + code_context = parse_code_and_prune_cst( + code_without_unused_defs, + code_context_type, + primary_qualified_names, + secondary_qualified_names, + remove_docstrings, + ) + except ValueError as e: + logger.debug(f"Error while getting read-only code: {e}") + return None + + if code_context.strip(): + if code_context_type != CodeContextType.HASHING: + code_context = add_needed_imports_from_module( + src_module_code=original_code, + dst_module_code=code_context, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=helper_functions, + ) + return CodeString(code=code_context, file_path=safe_relative_to(file_path, project_root_path)) + return None + + def extract_code_markdown_context_from_files( helpers_of_fto: dict[Path, set[FunctionSource]], helpers_of_helpers: dict[Path, set[FunctionSource]], @@ -391,79 +379,39 @@ def extract_code_markdown_context_from_files( code_context_markdown = CodeStringsMarkdown() # Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files for file_path, function_sources in helpers_of_fto.items(): - try: - original_code = file_path.read_text("utf8") - except Exception as e: - logger.exception(f"Error while parsing {file_path}: {e}") - continue - try: - qualified_function_names = {func.qualified_name for func in function_sources} - helpers_of_helpers_qualified_names = { - func.qualified_name for func in helpers_of_helpers.get(file_path, set()) - } - code_without_unused_defs = remove_unused_definitions_by_function_names( - original_code, qualified_function_names | helpers_of_helpers_qualified_names - ) - code_context = parse_code_and_prune_cst( - code_without_unused_defs, - code_context_type, - qualified_function_names, - helpers_of_helpers_qualified_names, - remove_docstrings, - ) + qualified_function_names = {func.qualified_name for func in function_sources} + helpers_of_helpers_qualified_names = {func.qualified_name for func in helpers_of_helpers.get(file_path, set())} + helper_functions = list(helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())) + + result = process_file_context( + file_path=file_path, + primary_qualified_names=qualified_function_names, + secondary_qualified_names=helpers_of_helpers_qualified_names, + code_context_type=code_context_type, + remove_docstrings=remove_docstrings, + project_root_path=project_root_path, + helper_functions=helper_functions, + ) - except ValueError as e: - logger.debug(f"Error while getting read-only code: {e}") - continue - if code_context.strip(): - if code_context_type != CodeContextType.HASHING: - code_context = add_needed_imports_from_module( - src_module_code=original_code, - dst_module_code=code_context, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions=list( - helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set()) - ), - ) - code_string_context = CodeString( - code=code_context, file_path=file_path.resolve().relative_to(project_root_path.resolve()) - ) - code_context_markdown.code_strings.append(code_string_context) + if result is not None: + code_context_markdown.code_strings.append(result) # Extract code from file paths containing helpers of helpers for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items(): - try: - original_code = file_path.read_text("utf8") - except Exception as e: - logger.exception(f"Error while parsing {file_path}: {e}") - continue - try: - qualified_helper_function_names = {func.qualified_name for func in helper_function_sources} - code_without_unused_defs = remove_unused_definitions_by_function_names( - original_code, qualified_helper_function_names - ) - code_context = parse_code_and_prune_cst( - code_without_unused_defs, code_context_type, set(), qualified_helper_function_names, remove_docstrings - ) - except ValueError as e: - logger.debug(f"Error while getting read-only code: {e}") - continue + qualified_helper_function_names = {func.qualified_name for func in helper_function_sources} + helper_functions = list(helpers_of_helpers_no_overlap.get(file_path, set())) + + result = process_file_context( + file_path=file_path, + primary_qualified_names=set(), + secondary_qualified_names=qualified_helper_function_names, + code_context_type=code_context_type, + remove_docstrings=remove_docstrings, + project_root_path=project_root_path, + helper_functions=helper_functions, + ) - if code_context.strip(): - if code_context_type != CodeContextType.HASHING: - code_context = add_needed_imports_from_module( - src_module_code=original_code, - dst_module_code=code_context, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())), - ) - code_string_context = CodeString( - code=code_context, file_path=file_path.resolve().relative_to(project_root_path.resolve()) - ) - code_context_markdown.code_strings.append(code_string_context) + if result is not None: + code_context_markdown.code_strings.append(result) return code_context_markdown @@ -534,39 +482,28 @@ def get_function_sources_from_jedi( # The definition is part of this project and not defined within the original function is_valid_definition = ( - str(definition_path).startswith(str(project_root_path) + os.sep) - and not path_belongs_to_site_packages(definition_path) + is_project_path(definition_path, project_root_path) and definition.full_name and not belongs_to_function_qualified(definition, qualified_function_name) and definition.full_name.startswith(definition.module_name) ) - if is_valid_definition and definition.type == "function": - qualified_name = get_qualified_name(definition.module_name, definition.full_name) + if is_valid_definition and definition.type in ("function", "class"): + if definition.type == "function": + fqn = definition.full_name + func_name = definition.name + else: + # When a class is instantiated (e.g., MyClass()), track its __init__ as a helper + # This ensures the class definition with constructor is included in testgen context + fqn = f"{definition.full_name}.__init__" + func_name = "__init__" + qualified_name = get_qualified_name(definition.module_name, fqn) # Avoid nested functions or classes. Only class.function is allowed if len(qualified_name.split(".")) <= 2: function_source = FunctionSource( file_path=definition_path, qualified_name=qualified_name, - fully_qualified_name=definition.full_name, - only_function_name=definition.name, - source_code=definition.get_line_code(), - jedi_definition=definition, - ) - file_path_to_function_source[definition_path].add(function_source) - function_source_list.append(function_source) - # When a class is instantiated (e.g., MyClass()), track its __init__ as a helper - # This ensures the class definition with constructor is included in testgen context - elif is_valid_definition and definition.type == "class": - init_qualified_name = get_qualified_name( - definition.module_name, f"{definition.full_name}.__init__" - ) - # Only include if it's a top-level class (not nested) - if len(init_qualified_name.split(".")) <= 2: - function_source = FunctionSource( - file_path=definition_path, - qualified_name=init_qualified_name, - fully_qualified_name=f"{definition.full_name}.__init__", - only_function_name="__init__", + fully_qualified_name=fqn, + only_function_name=func_name, source_code=definition.get_line_code(), jedi_definition=definition, ) @@ -576,60 +513,66 @@ def get_function_sources_from_jedi( return file_path_to_function_source, function_source_list -def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown: - """Extract class definitions for imported types from project modules. - - This function analyzes the imports in the extracted code context and fetches - class definitions for any classes imported from project modules. This helps - the LLM understand the actual class structure (constructors, methods, inheritance) - rather than just seeing import statements. - - Also recursively extracts base classes when a class inherits from another class - in the same module, ensuring the full inheritance chain is available for - understanding constructor signatures. - - Args: - code_context: The already extracted code context containing imports - project_root_path: Root path of the project - - Returns: - CodeStringsMarkdown containing class definitions from imported project modules - - """ - import jedi - - # Collect all code from the context +def _parse_and_collect_imports(code_context: CodeStringsMarkdown) -> tuple[ast.Module, dict[str, str]] | None: all_code = "\n".join(cs.code for cs in code_context.code_strings) - - # Parse to find import statements try: tree = ast.parse(all_code) except SyntaxError: - return CodeStringsMarkdown(code_strings=[]) - - # Collect imported names and their source modules - imported_names: dict[str, str] = {} # name -> module_path + return None + imported_names: dict[str, str] = {} for node in ast.walk(tree): if isinstance(node, ast.ImportFrom) and node.module: for alias in node.names: if alias.name != "*": imported_name = alias.asname if alias.asname else alias.name imported_names[imported_name] = node.module + return tree, imported_names + + +def collect_existing_class_names(tree: ast.Module) -> set[str]: + return {node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)} + + +def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown: + import jedi + + result = _parse_and_collect_imports(code_context) + if result is None: + return CodeStringsMarkdown(code_strings=[]) + tree, imported_names = result if not imported_names: return CodeStringsMarkdown(code_strings=[]) - # Track which classes we've already extracted to avoid duplicates - extracted_classes: set[tuple[Path, str]] = set() # (file_path, class_name) + existing_classes = collect_existing_class_names(tree) - # Also track what's already defined in the context - existing_definitions: set[str] = set() + # Collect base class names from ClassDef nodes (single walk) + base_class_names: set[str] = set() for node in ast.walk(tree): if isinstance(node, ast.ClassDef): - existing_definitions.add(node.name) + for base in node.bases: + if isinstance(base, ast.Name): + base_class_names.add(base.id) + elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name): + base_class_names.add(base.attr) - class_code_strings: list[CodeString] = [] + # Classify external imports using importlib-based check + is_project_cache: dict[str, bool] = {} + external_base_classes: set[tuple[str, str]] = set() + external_direct_imports: set[tuple[str, str]] = set() + + for name, module_name in imported_names.items(): + if not _is_project_module_cached(module_name, project_root_path, is_project_cache): + if name in base_class_names: + external_base_classes.add((name, module_name)) + if name not in existing_classes: + external_direct_imports.add((name, module_name)) + + code_strings: list[CodeString] = [] + emitted_class_names: set[str] = set() + # --- Step 1: Project class definitions (jedi resolution + recursive base extraction) --- + extracted_classes: set[tuple[Path, str]] = set() module_cache: dict[Path, tuple[str, ast.Module]] = {} def get_module_source_and_tree(module_path: Path) -> tuple[str, ast.Module] | None: @@ -647,12 +590,9 @@ def get_module_source_and_tree(module_path: Path) -> tuple[str, ast.Module] | No def extract_class_and_bases( class_name: str, module_path: Path, module_source: str, module_tree: ast.Module ) -> None: - """Extract a class and its base classes recursively from the same module.""" - # Skip if already extracted if (module_path, class_name) in extracted_classes: return - # Find the class definition in the module class_node = None for node in ast.walk(module_tree): if isinstance(node, ast.ClassDef) and node.name == class_name: @@ -662,22 +602,18 @@ def extract_class_and_bases( if class_node is None: return - # First, recursively extract base classes from the same module for base in class_node.bases: base_name = None if isinstance(base, ast.Name): base_name = base.id elif isinstance(base, ast.Attribute): - # For module.ClassName, we skip (cross-module inheritance) continue - if base_name and base_name not in existing_definitions: - # Check if base class is defined in the same module + if base_name and base_name not in existing_classes: extract_class_and_bases(base_name, module_path, module_source, module_tree) - # Now extract this class (after its bases, so base classes appear first) if (module_path, class_name) in extracted_classes: - return # Already added by another path + return lines = module_source.split("\n") start_line = class_node.lineno @@ -685,21 +621,17 @@ def extract_class_and_bases( start_line = min(d.lineno for d in class_node.decorator_list) class_source = "\n".join(lines[start_line - 1 : class_node.end_lineno]) - # Extract imports for the class class_imports = extract_imports_for_class(module_tree, class_node, module_source) full_source = class_imports + "\n\n" + class_source if class_imports else class_source - class_code_strings.append(CodeString(code=full_source, file_path=module_path)) + code_strings.append(CodeString(code=full_source, file_path=module_path)) extracted_classes.add((module_path, class_name)) + emitted_class_names.add(class_name) for name, module_name in imported_names.items(): - # Skip if already defined in context - if name in existing_definitions: + if name in existing_classes: continue - - # Try to find the module file using Jedi try: - # Create a script that imports the module to resolve it test_code = f"import {module_name}" script = jedi.Script(test_code, project=jedi.Project(path=project_root_path)) completions = script.goto(1, len(test_code)) @@ -711,123 +643,85 @@ def extract_class_and_bases( if not module_path: continue - # Check if this is a project module (not stdlib/third-party) - if not str(module_path).startswith(str(project_root_path) + os.sep): - continue - if path_belongs_to_site_packages(module_path): + if not is_project_path(module_path, project_root_path): continue - # Get module source and tree - result = get_module_source_and_tree(module_path) - if result is None: + mod_result = get_module_source_and_tree(module_path) + if mod_result is None: continue - module_source, module_tree = result + module_source, module_tree = mod_result - # Extract the class and its base classes extract_class_and_bases(name, module_path, module_source, module_tree) except Exception: logger.debug(f"Error extracting class definition for {name} from {module_name}") continue - return CodeStringsMarkdown(code_strings=class_code_strings) - + # --- Step 2: External base class __init__ stubs --- + if external_base_classes: + for cls, name in resolve_classes_from_modules(external_base_classes): + if name in emitted_class_names: + continue + stub = extract_init_stub(cls, name, require_site_packages=False) + if stub is not None: + code_strings.append(stub) + emitted_class_names.add(name) + + # --- Step 3: External direct import __init__ stubs with BFS --- + if external_direct_imports: + processed_classes: set[type] = set() + worklist: list[tuple[type, str, int]] = [ + (cls, name, 0) for cls, name in resolve_classes_from_modules(external_direct_imports) + ] + + while worklist: + cls, class_name, depth = worklist.pop(0) + + if cls in processed_classes: + continue + processed_classes.add(cls) -def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown: - """Extract __init__ methods from external library base classes. + stub = extract_init_stub(cls, class_name) + if stub is None: + continue - Scans the code context for classes that inherit from external libraries and extracts - just their __init__ methods. This helps the LLM understand constructor signatures - for mocking or instantiation. - """ - import importlib - import inspect - import textwrap + if class_name not in emitted_class_names: + code_strings.append(stub) + emitted_class_names.add(class_name) - all_code = "\n".join(cs.code for cs in code_context.code_strings) + if depth < MAX_TRANSITIVE_DEPTH: + for dep_cls in resolve_transitive_type_deps(cls): + if dep_cls not in processed_classes: + worklist.append((dep_cls, dep_cls.__name__, depth + 1)) - try: - tree = ast.parse(all_code) - except SyntaxError: - return CodeStringsMarkdown(code_strings=[]) + return CodeStringsMarkdown(code_strings=code_strings) - imported_names: dict[str, str] = {} - # Use a set to deduplicate external base entries to avoid repeated expensive checks/imports. - external_bases_set: set[tuple[str, str]] = set() - # Local cache to avoid repeated _is_project_module calls for the same module_name. - is_project_cache: dict[str, bool] = {} - for node in ast.walk(tree): - if isinstance(node, ast.ImportFrom) and node.module: - for alias in node.names: - if alias.name != "*": - imported_name = alias.asname if alias.asname else alias.name - imported_names[imported_name] = node.module - elif isinstance(node, ast.ClassDef): - for base in node.bases: - base_name = None - if isinstance(base, ast.Name): - base_name = base.id - elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name): - base_name = base.attr - - if base_name and base_name in imported_names: - module_name = imported_names[base_name] - # Check cache first to avoid repeated expensive checks. - cached = is_project_cache.get(module_name) - if cached is None: - is_project = _is_project_module(module_name, project_root_path) - is_project_cache[module_name] = is_project - else: - is_project = cached - - if not is_project: - external_bases_set.add((base_name, module_name)) - - if not external_bases_set: - return CodeStringsMarkdown(code_strings=[]) +def resolve_classes_from_modules(candidates: set[tuple[str, str]]) -> list[tuple[type, str]]: + """Import modules and resolve candidate (class_name, module_name) pairs to class objects.""" + import importlib + import inspect - code_strings: list[CodeString] = [] - # Cache imported modules to avoid repeated importlib.import_module calls. - imported_module_cache: dict[str, object] = {} + resolved: list[tuple[type, str]] = [] + module_cache: dict[str, object] = {} - for base_name, module_name in external_bases_set: + for class_name, module_name in candidates: try: - module = imported_module_cache.get(module_name) + module = module_cache.get(module_name) if module is None: module = importlib.import_module(module_name) - imported_module_cache[module_name] = module - - base_class = getattr(module, base_name, None) - if base_class is None: - continue - - init_method = getattr(base_class, "__init__", None) - if init_method is None: - continue - - try: - init_source = inspect.getsource(init_method) - init_source = textwrap.dedent(init_source) - class_file = Path(inspect.getfile(base_class)) - parts = class_file.parts - if "site-packages" in parts: - idx = parts.index("site-packages") - class_file = Path(*parts[idx + 1 :]) - except (OSError, TypeError): - continue - - class_source = f"class {base_name}:\n" + textwrap.indent(init_source, " ") - code_strings.append(CodeString(code=class_source, file_path=class_file)) + module_cache[module_name] = module + cls = getattr(module, class_name, None) + if cls is not None and inspect.isclass(cls): + resolved.append((cls, class_name)) except (ImportError, ModuleNotFoundError, AttributeError): - logger.debug(f"Failed to extract __init__ for {module_name}.{base_name}") - continue + logger.debug(f"Failed to import {module_name}.{class_name}") - return CodeStringsMarkdown(code_strings=code_strings) + return resolved -MAX_TRANSITIVE_DEPTH = 2 +MAX_TRANSITIVE_DEPTH = 5 def extract_classes_from_type_hint(hint: object) -> list[type]: @@ -897,8 +791,15 @@ def resolve_transitive_type_deps(cls: type) -> list[type]: return deps -def extract_init_stub_for_class(cls: type, class_name: str) -> CodeString | None: - """Extract a stub containing the class definition with only its __init__ method.""" +def extract_init_stub(cls: type, class_name: str, require_site_packages: bool = True) -> CodeString | None: + """Extract a stub containing the class definition with only its __init__ method. + + Args: + cls: The class object to extract __init__ from + class_name: Name to use for the class in the stub + require_site_packages: If True, only extract from site-packages. If False, include stdlib too. + + """ import inspect import textwrap @@ -911,7 +812,7 @@ def extract_init_stub_for_class(cls: type, class_name: str) -> CodeString | None except (OSError, TypeError): return None - if not path_belongs_to_site_packages(class_file): + if require_site_packages and not path_belongs_to_site_packages(class_file): return None try: @@ -929,106 +830,22 @@ def extract_init_stub_for_class(cls: type, class_name: str) -> CodeString | None return CodeString(code=class_source, file_path=class_file) -def get_external_class_inits(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown: - """Extract __init__ methods from directly imported external library classes. - - Scans the code context for classes imported from external packages (site-packages) and extracts - their __init__ methods, including transitive type dependencies found in __init__ annotations. - This helps the LLM understand constructor signatures for instantiation in generated tests. - """ - import importlib - import inspect - - all_code = "\n".join(cs.code for cs in code_context.code_strings) - - try: - tree = ast.parse(all_code) - except SyntaxError: - return CodeStringsMarkdown(code_strings=[]) - - # Collect all from X import Y statements - imported_names: dict[str, str] = {} - is_project_cache: dict[str, bool] = {} - - # Track classes already defined in the context to avoid duplicates - existing_classes: set[str] = set() - - for node in ast.walk(tree): - if isinstance(node, ast.ImportFrom) and node.module: - for alias in node.names: - if alias.name != "*": - imported_name = alias.asname if alias.asname else alias.name - imported_names[imported_name] = node.module - elif isinstance(node, ast.ClassDef): - existing_classes.add(node.name) - - if not imported_names: - return CodeStringsMarkdown(code_strings=[]) - - # Filter to external-only imports - external_imports: set[tuple[str, str]] = set() - for name, module_name in imported_names.items(): - if name in existing_classes: - continue - cached = is_project_cache.get(module_name) - if cached is None: - is_project = _is_project_module(module_name, project_root_path) - is_project_cache[module_name] = is_project - else: - is_project = cached - if not is_project: - external_imports.add((name, module_name)) - - if not external_imports: - return CodeStringsMarkdown(code_strings=[]) - - code_strings: list[CodeString] = [] - imported_module_cache: dict[str, object] = {} - processed_classes: set[type] = set() - emitted_names: set[str] = set() - - # BFS worklist: (class_object, class_name, depth) - worklist: list[tuple[type, str, int]] = [] - - # Seed the worklist with directly imported classes - for class_name, module_name in external_imports: - try: - module = imported_module_cache.get(module_name) - if module is None: - module = importlib.import_module(module_name) - imported_module_cache[module_name] = module - - cls = getattr(module, class_name, None) - if cls is None or not inspect.isclass(cls): - continue - - worklist.append((cls, class_name, 0)) - except (ImportError, ModuleNotFoundError, AttributeError): - logger.debug(f"Failed to import {module_name}.{class_name}") - continue - - while worklist: - cls, class_name, depth = worklist.pop(0) - - if cls in processed_classes: - continue - processed_classes.add(cls) - - stub = extract_init_stub_for_class(cls, class_name) - if stub is None: - continue +def _is_project_module_cached(module_name: str, project_root_path: Path, cache: dict[str, bool]) -> bool: + cached = cache.get(module_name) + if cached is not None: + return cached + is_project = _is_project_module(module_name, project_root_path) + cache[module_name] = is_project + return is_project - if class_name not in emitted_names: - code_strings.append(stub) - emitted_names.add(class_name) - # Resolve transitive type dependencies up to MAX_TRANSITIVE_DEPTH - if depth < MAX_TRANSITIVE_DEPTH: - for dep_cls in resolve_transitive_type_deps(cls): - if dep_cls not in processed_classes: - worklist.append((dep_cls, dep_cls.__name__, depth + 1)) - - return CodeStringsMarkdown(code_strings=code_strings) +def is_project_path(module_path: Path | None, project_root_path: Path) -> bool: + if module_path is None: + return False + # site-packages must be checked first because .venv/site-packages is under project root + if path_belongs_to_site_packages(module_path): + return False + return str(module_path).startswith(str(project_root_path) + os.sep) def _is_project_module(module_name: str, project_root_path: Path) -> bool: @@ -1042,13 +859,7 @@ def _is_project_module(module_name: str, project_root_path: Path) -> bool: else: if spec is None or spec.origin is None: return False - module_path = Path(spec.origin) - # Check if the module is in site-packages (external dependency) - # This must be checked first because .venv/site-packages is under project root - if path_belongs_to_site_packages(module_path): - return False - # Check if the module is within the project root - return str(module_path).startswith(str(project_root_path) + os.sep) + return is_project_path(Path(spec.origin), project_root_path) def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str: @@ -1130,78 +941,6 @@ def is_dunder_method(name: str) -> bool: return len(name) > 4 and name.isascii() and name.startswith("__") and name.endswith("__") -class UsedNameCollector(cst.CSTVisitor): - """Collects all base names referenced in code (for import preservation).""" - - def __init__(self) -> None: - self.used_names: set[str] = set() - self.defined_names: set[str] = set() - - def visit_Name(self, node: cst.Name) -> None: - self.used_names.add(node.value) - - def visit_Attribute(self, node: cst.Attribute) -> bool | None: - base = node.value - while isinstance(base, cst.Attribute): - base = base.value - if isinstance(base, cst.Name): - self.used_names.add(base.value) - return True - - def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None: - self.defined_names.add(node.name.value) - return True - - def visit_ClassDef(self, node: cst.ClassDef) -> bool | None: - self.defined_names.add(node.name.value) - return True - - def visit_Assign(self, node: cst.Assign) -> bool | None: - for target in node.targets: - names = extract_names_from_targets(target.target) - self.defined_names.update(names) - return True - - def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None: - names = extract_names_from_targets(node.target) - self.defined_names.update(names) - return True - - def get_external_names(self) -> set[str]: - return self.used_names - self.defined_names - {"self", "cls"} - - -def get_imported_names(import_node: cst.Import | cst.ImportFrom) -> set[str]: - """Extract the names made available by an import statement.""" - names: set[str] = set() - if isinstance(import_node, cst.Import): - if isinstance(import_node.names, cst.ImportStar): - return {"*"} - for alias in import_node.names: - if isinstance(alias, cst.ImportAlias): - if alias.asname and isinstance(alias.asname.name, cst.Name): - names.add(alias.asname.name.value) - elif isinstance(alias.name, cst.Name): - names.add(alias.name.value) - elif isinstance(alias.name, cst.Attribute): - # import foo.bar -> accessible as "foo" - base: cst.BaseExpression = alias.name - while isinstance(base, cst.Attribute): - base = base.value - if isinstance(base, cst.Name): - names.add(base.value) - elif isinstance(import_node, cst.ImportFrom): - if isinstance(import_node.names, cst.ImportStar): - return {"*"} - for alias in import_node.names: - if isinstance(alias, cst.ImportAlias): - if alias.asname and isinstance(alias.asname.name, cst.Name): - names.add(alias.asname.name.value) - elif isinstance(alias.name, cst.Name): - names.add(alias.name.value) - return names - - def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode: """Removes the docstring from an indented block if it exists.""" if not isinstance(indented_block.body[0], cst.SimpleStatementLine): @@ -1224,27 +963,31 @@ def parse_code_and_prune_cst( defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions) if code_context_type == CodeContextType.READ_WRITABLE: - filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions, defs_with_usages) + filtered_node, found_target = prune_cst( + module, target_functions, defs_with_usages=defs_with_usages, keep_class_init=True + ) elif code_context_type == CodeContextType.READ_ONLY: - filtered_node, found_target = prune_cst_for_context( + filtered_node, found_target = prune_cst( module, target_functions, - helpers_of_helper_functions, + helpers=helpers_of_helper_functions, remove_docstrings=remove_docstrings, include_target_in_output=False, - include_init_dunder=False, + include_dunder_methods=True, ) elif code_context_type == CodeContextType.TESTGEN: - filtered_node, found_target = prune_cst_for_context( + filtered_node, found_target = prune_cst( module, target_functions, - helpers_of_helper_functions, + helpers=helpers_of_helper_functions, remove_docstrings=remove_docstrings, - include_target_in_output=True, + include_dunder_methods=True, include_init_dunder=True, ) elif code_context_type == CodeContextType.HASHING: - filtered_node, found_target = prune_cst_for_code_hashing(module, target_functions) + filtered_node, found_target = prune_cst( + module, target_functions, remove_docstrings=True, exclude_init_from_targets=True + ) else: raise ValueError(f"Unknown code_context_type: {code_context_type}") # noqa: EM102 @@ -1258,234 +1001,90 @@ def parse_code_and_prune_cst( return "" -def prune_cst_for_read_writable_code( - node: cst.CSTNode, target_functions: set[str], defs_with_usages: dict[str, UsageInfo], prefix: str = "" -) -> tuple[cst.CSTNode | None, bool]: - """Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions. - - Returns - ------- - (filtered_node, found_target): - filtered_node: The modified CST node or None if it should be removed. - found_target: True if a target function was found in this node's subtree. - - """ - if isinstance(node, (cst.Import, cst.ImportFrom)): - return None, False - - if isinstance(node, cst.FunctionDef): - qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value - if qualified_name in target_functions: - return node, True - return None, False - - if isinstance(node, cst.ClassDef): - # Do not recurse into nested classes - if prefix: - return None, False - - class_name = node.name.value - - # Assuming always an IndentedBlock - if not isinstance(node.body, cst.IndentedBlock): - raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004 - class_prefix = f"{prefix}.{class_name}" if prefix else class_name - - # Check if this class contains any target functions - has_target_functions = any( - isinstance(stmt, cst.FunctionDef) and f"{class_prefix}.{stmt.name.value}" in target_functions - for stmt in node.body.body - ) - - # If the class is used as a dependency (not containing target functions), keep it entirely - # This handles cases like enums, dataclasses, and other types used by the target function - if ( - not has_target_functions - and class_name in defs_with_usages - and defs_with_usages[class_name].used_by_qualified_function - ): - return node, True - - new_body = [] - found_target = False +def _qualified_name(prefix: str, name: str) -> str: + return f"{prefix}.{name}" if prefix else name - for stmt in node.body.body: - if isinstance(stmt, cst.FunctionDef): - qualified_name = f"{class_prefix}.{stmt.name.value}" - if qualified_name in target_functions: - new_body.append(stmt) - found_target = True - elif stmt.name.value == "__init__": - new_body.append(stmt) # enable __init__ optimizations - # If no target functions found, remove the class entirely - if not new_body or not found_target: - return None, False - - return node.with_changes(body=cst.IndentedBlock(body=new_body)), found_target - - if isinstance(node, cst.Assign): - for target in node.targets: - names = extract_names_from_targets(target.target) - for name in names: - if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function: - return node, True - return None, False - if isinstance(node, (cst.AnnAssign, cst.AugAssign)): - names = extract_names_from_targets(node.target) - for name in names: - if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function: - return node, True - return None, False +def _validate_classdef(node: cst.ClassDef, prefix: str) -> tuple[str, cst.IndentedBlock] | None: + if prefix: + return None + if not isinstance(node.body, cst.IndentedBlock): + raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004 + return _qualified_name(prefix, node.name.value), node.body - # For other nodes, we preserve them only if they contain target functions in their children. - section_names = get_section_names(node) - if not section_names: - return node, False +def _recurse_sections( + node: cst.CSTNode, + section_names: list[str], + prune_fn: Callable[[cst.CSTNode], tuple[cst.CSTNode | None, bool]], + keep_non_target_children: bool = False, +) -> tuple[cst.CSTNode | None, bool]: updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {} found_any_target = False - for section in section_names: original_content = getattr(node, section, None) if isinstance(original_content, (list, tuple)): new_children = [] section_found_target = False for child in original_content: - filtered, found_target = prune_cst_for_read_writable_code( - child, target_functions, defs_with_usages, prefix - ) + filtered, found_target = prune_fn(child) if filtered: new_children.append(filtered) section_found_target |= found_target - - if section_found_target: + if keep_non_target_children: + if section_found_target or new_children: + found_any_target |= section_found_target + updates[section] = new_children + elif section_found_target: found_any_target = True updates[section] = new_children elif original_content is not None: - filtered, found_target = prune_cst_for_read_writable_code( - original_content, target_functions, defs_with_usages, prefix - ) - if found_target: - found_any_target = True + filtered, found_target = prune_fn(original_content) + if keep_non_target_children: + found_any_target |= found_target if filtered: updates[section] = filtered - - if not found_any_target: - return None, False - return (node.with_changes(**updates) if updates else node), True - - -def prune_cst_for_code_hashing( - node: cst.CSTNode, target_functions: set[str], prefix: str = "" -) -> tuple[cst.CSTNode | None, bool]: - """Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions. - - Returns - ------- - (filtered_node, found_target): - filtered_node: The modified CST node or None if it should be removed. - found_target: True if a target function was found in this node's subtree. - - """ - if isinstance(node, (cst.Import, cst.ImportFrom)): - return None, False - - if isinstance(node, cst.FunctionDef): - qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value - # For hashing, exclude __init__ methods even if in target_functions - # because they don't affect the semantic behavior being hashed - # But include other dunder methods like __call__ which do affect behavior - if qualified_name in target_functions and node.name.value != "__init__": - new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body - return node.with_changes(body=new_body), True - return None, False - - if isinstance(node, cst.ClassDef): - # Do not recurse into nested classes - if prefix: - return None, False - # Assuming always an IndentedBlock - if not isinstance(node.body, cst.IndentedBlock): - raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004 - class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value - new_class_body: list[cst.CSTNode] = [] - found_target = False - - for stmt in node.body.body: - if isinstance(stmt, cst.FunctionDef): - qualified_name = f"{class_prefix}.{stmt.name.value}" - # For hashing, exclude __init__ methods even if in target_functions - # but include other methods like __call__ which affect behavior - if qualified_name in target_functions and stmt.name.value != "__init__": - stmt_with_changes = stmt.with_changes( - body=remove_docstring_from_body(cast("cst.IndentedBlock", stmt.body)) - ) - new_class_body.append(stmt_with_changes) - found_target = True - # If no target functions found, remove the class entirely - if not new_class_body or not found_target: - return None, False - return node.with_changes( - body=cst.IndentedBlock(cast("list[cst.BaseStatement]", new_class_body)) - ) if new_class_body else None, found_target - - # For other nodes, we preserve them only if they contain target functions in their children. - section_names = get_section_names(node) - if not section_names: - return node, False - - updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {} - found_any_target = False - - for section in section_names: - original_content = getattr(node, section, None) - if isinstance(original_content, (list, tuple)): - new_children = [] - section_found_target = False - for child in original_content: - filtered, found_target = prune_cst_for_code_hashing(child, target_functions, prefix) - if filtered: - new_children.append(filtered) - section_found_target |= found_target - - if section_found_target: - found_any_target = True - updates[section] = new_children - elif original_content is not None: - filtered, found_target = prune_cst_for_code_hashing(original_content, target_functions, prefix) - if found_target: + elif found_target: found_any_target = True if filtered: updates[section] = filtered - + if keep_non_target_children: + if updates: + return node.with_changes(**updates), found_any_target + return None, False if not found_any_target: return None, False - return (node.with_changes(**updates) if updates else node), True -def prune_cst_for_context( +def prune_cst( node: cst.CSTNode, target_functions: set[str], - helpers_of_helper_functions: set[str], prefix: str = "", + *, + defs_with_usages: dict[str, UsageInfo] | None = None, + helpers: set[str] | None = None, remove_docstrings: bool = False, - include_target_in_output: bool = False, + include_target_in_output: bool = True, + exclude_init_from_targets: bool = False, + keep_class_init: bool = False, + include_dunder_methods: bool = False, include_init_dunder: bool = False, ) -> tuple[cst.CSTNode | None, bool]: - """Recursively filter the node for code context extraction. + """Unified function to prune CST nodes based on various filtering criteria. Args: node: The CST node to filter target_functions: Set of qualified function names that are targets - helpers_of_helper_functions: Set of helper function qualified names prefix: Current qualified name prefix (for class methods) + defs_with_usages: Dict of definitions with usage info (for READ_WRITABLE mode) + helpers: Set of helper function qualified names (for READ_ONLY/TESTGEN modes) remove_docstrings: Whether to remove docstrings from output - include_target_in_output: If True, include target functions in output (testgen mode) - If False, exclude target functions (read-only mode) - include_init_dunder: If True, include __init__ in dunder methods (testgen mode) - If False, exclude __init__ from dunder methods (read-only mode) + include_target_in_output: Whether to include target functions in output + exclude_init_from_targets: Whether to exclude __init__ from targets (HASHING mode) + keep_class_init: Whether to keep __init__ methods in classes (READ_WRITABLE mode) + include_dunder_methods: Whether to include dunder methods (READ_ONLY/TESTGEN modes) + include_init_dunder: Whether to include __init__ in dunder methods Returns: (filtered_node, found_target): @@ -1497,25 +1096,34 @@ def prune_cst_for_context( return None, False if isinstance(node, cst.FunctionDef): - qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value + qualified_name = _qualified_name(prefix, node.name.value) - # Check if it's a helper of helper function - if qualified_name in helpers_of_helper_functions: + # Check if it's a helper function (higher priority than target) + if helpers and qualified_name in helpers: if remove_docstrings and isinstance(node.body, cst.IndentedBlock): return node.with_changes(body=remove_docstring_from_body(node.body)), True return node, True # Check if it's a target function if qualified_name in target_functions: + # Handle exclude_init_from_targets for HASHING mode + if exclude_init_from_targets and node.name.value == "__init__": + return None, False + if include_target_in_output: if remove_docstrings and isinstance(node.body, cst.IndentedBlock): return node.with_changes(body=remove_docstring_from_body(node.body)), True return node, True return None, True - # Check dunder methods - # For read-only mode, exclude __init__; for testgen mode, include all dunders - if is_dunder_method(node.name.value) and (include_init_dunder or node.name.value != "__init__"): + # Handle class __init__ for READ_WRITABLE mode + if keep_class_init and node.name.value == "__init__": + return node, False + + # Handle dunder methods for READ_ONLY/TESTGEN modes + if include_dunder_methods and is_dunder_method(node.name.value): + if not include_init_dunder and node.name.value == "__init__": + return None, False if remove_docstrings and isinstance(node.body, cst.IndentedBlock): return node.with_changes(body=remove_docstring_from_body(node.body)), False return node, False @@ -1523,26 +1131,44 @@ def prune_cst_for_context( return None, False if isinstance(node, cst.ClassDef): - # Do not recurse into nested classes - if prefix: + result = _validate_classdef(node, prefix) + if result is None: return None, False - # Assuming always an IndentedBlock - if not isinstance(node.body, cst.IndentedBlock): - raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004 + class_prefix, _ = result + class_name = node.name.value - class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value + # Handle dependency classes for READ_WRITABLE mode + if defs_with_usages: + # Check if this class contains any target functions + has_target_functions = any( + isinstance(stmt, cst.FunctionDef) and _qualified_name(class_prefix, stmt.name.value) in target_functions + for stmt in node.body.body + ) - # First pass: detect if there is a target function in the class + # If the class is used as a dependency (not containing target functions), keep it entirely + if ( + not has_target_functions + and class_name in defs_with_usages + and defs_with_usages[class_name].used_by_qualified_function + ): + return node, True + + # Recursively filter each statement in the class body + new_class_body: list[cst.CSTNode] = [] found_in_class = False - new_class_body: list[CSTNode] = [] + for stmt in node.body.body: - filtered, found_target = prune_cst_for_context( + filtered, found_target = prune_cst( stmt, target_functions, - helpers_of_helper_functions, class_prefix, + defs_with_usages=defs_with_usages, + helpers=helpers, remove_docstrings=remove_docstrings, include_target_in_output=include_target_in_output, + exclude_init_from_targets=exclude_init_from_targets, + keep_class_init=keep_class_init, + include_dunder_methods=include_dunder_methods, include_init_dunder=include_init_dunder, ) found_in_class |= found_target @@ -1552,57 +1178,67 @@ def prune_cst_for_context( if not found_in_class: return None, False - if remove_docstrings: - return node.with_changes( - body=remove_docstring_from_body(node.body.with_changes(body=new_class_body)) - ) if new_class_body else None, True + # Apply docstring removal to class if needed + if remove_docstrings and new_class_body: + return node.with_changes(body=remove_docstring_from_body(node.body.with_changes(body=new_class_body))), True + return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True - # For other nodes, keep the node and recursively filter children + # Handle assignments for READ_WRITABLE mode + if defs_with_usages is not None: + if isinstance(node, cst.Assign): + for target in node.targets: + names = extract_names_from_targets(target.target) + for name in names: + if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function: + return node, True + return None, False + + if isinstance(node, (cst.AnnAssign, cst.AugAssign)): + names = extract_names_from_targets(node.target) + for name in names: + if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function: + return node, True + return None, False + + # For other nodes, recursively process children section_names = get_section_names(node) if not section_names: return node, False - updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {} - found_any_target = False - - for section in section_names: - original_content = getattr(node, section, None) - if isinstance(original_content, (list, tuple)): - new_children = [] - section_found_target = False - for child in original_content: - filtered, found_target = prune_cst_for_context( - child, - target_functions, - helpers_of_helper_functions, - prefix, - remove_docstrings=remove_docstrings, - include_target_in_output=include_target_in_output, - include_init_dunder=include_init_dunder, - ) - if filtered: - new_children.append(filtered) - section_found_target |= found_target - - if section_found_target or new_children: - found_any_target |= section_found_target - updates[section] = new_children - elif original_content is not None: - filtered, found_target = prune_cst_for_context( - original_content, + if helpers is not None: + return _recurse_sections( + node, + section_names, + lambda child: prune_cst( + child, target_functions, - helpers_of_helper_functions, prefix, + defs_with_usages=defs_with_usages, + helpers=helpers, remove_docstrings=remove_docstrings, include_target_in_output=include_target_in_output, + exclude_init_from_targets=exclude_init_from_targets, + keep_class_init=keep_class_init, + include_dunder_methods=include_dunder_methods, include_init_dunder=include_init_dunder, - ) - found_any_target |= found_target - if filtered: - updates[section] = filtered - - if updates: - return (node.with_changes(**updates), found_any_target) - - return None, False + ), + keep_non_target_children=True, + ) + return _recurse_sections( + node, + section_names, + lambda child: prune_cst( + child, + target_functions, + prefix, + defs_with_usages=defs_with_usages, + helpers=helpers, + remove_docstrings=remove_docstrings, + include_target_in_output=include_target_in_output, + exclude_init_from_targets=exclude_init_from_targets, + keep_class_init=keep_class_init, + include_dunder_methods=include_dunder_methods, + include_init_dunder=include_init_dunder, + ), + ) diff --git a/codeflash/languages/current.py b/codeflash/languages/current.py index ecdb7315a..005249669 100644 --- a/codeflash/languages/current.py +++ b/codeflash/languages/current.py @@ -34,7 +34,7 @@ from codeflash.languages.base import LanguageSupport # Module-level singleton for the current language -_current_language: Language | None = None +_current_language: Language = Language.PYTHON def current_language() -> Language: diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 7088e6f1f..cfa1f5d2b 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -12,12 +12,10 @@ from codeflash.code_utils.code_replacer import replace_functions_and_add_imports from codeflash.context.code_context_extractor import ( collect_names_from_annotation, + enrich_testgen_context, extract_classes_from_type_hint, extract_imports_for_class, get_code_optimization_context, - get_external_base_class_inits, - get_external_class_inits, - get_imported_class_definitions, resolve_transitive_type_deps, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -769,199 +767,6 @@ def helper_method(self): assert hashing_context.strip() == expected_hashing_context.strip() -def test_example_class_token_limit_1(tmp_path: Path) -> None: - docstring_filler = " ".join( - ["This is a long docstring that will be used to fill up the token limit." for _ in range(1000)] - ) - code = f""" -class MyClass: - \"\"\"A class with a helper method. -{docstring_filler}\"\"\" - def __init__(self): - self.x = 1 - def target_method(self): - \"\"\"Docstring for target method\"\"\" - y = HelperClass().helper_method() - -class HelperClass: - \"\"\"A helper class for MyClass.\"\"\" - def __init__(self): - \"\"\"Initialize the HelperClass.\"\"\" - self.x = 1 - def __repr__(self): - \"\"\"Return a string representation of the HelperClass.\"\"\" - return "HelperClass" + str(self.x) - def helper_method(self): - return self.x -""" - # Create a temporary Python file using pytest's tmp_path fixture - file_path = tmp_path / "test_code.py" - file_path.write_text(code, encoding="utf-8") - opt = Optimizer( - Namespace( - project_root=file_path.parent.resolve(), - disable_telemetry=True, - tests_root="tests", - test_framework="pytest", - pytest_cmd="pytest", - experiment_id=None, - test_project_root=Path().resolve(), - ) - ) - function_to_optimize = FunctionToOptimize( - function_name="target_method", - file_path=file_path, - parents=[FunctionParent(name="MyClass", type="ClassDef")], - starting_line=None, - ending_line=None, - ) - - code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) - read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code - hashing_context = code_ctx.hashing_code_context - # In this scenario, the read-only code context is too long, so the read-only docstrings are removed. - expected_read_write_context = f""" -```python:{file_path.relative_to(opt.args.project_root)} -class MyClass: - def __init__(self): - self.x = 1 - def target_method(self): - \"\"\"Docstring for target method\"\"\" - y = HelperClass().helper_method() - -class HelperClass: - def __init__(self): - \"\"\"Initialize the HelperClass.\"\"\" - self.x = 1 - def helper_method(self): - return self.x -``` -""" - expected_read_only_context = f""" -```python:{file_path.relative_to(opt.args.project_root)} -class MyClass: - pass - -class HelperClass: - def __repr__(self): - return "HelperClass" + str(self.x) -``` -""" - expected_hashing_context = f""" -```python:{file_path.relative_to(opt.args.project_root)} -class MyClass: - - def target_method(self): - y = HelperClass().helper_method() - -class HelperClass: - - def helper_method(self): - return self.x -``` -""" - assert read_write_context.markdown.strip() == expected_read_write_context.strip() - assert read_only_context.strip() == expected_read_only_context.strip() - assert hashing_context.strip() == expected_hashing_context.strip() - - -def test_example_class_token_limit_2(tmp_path: Path) -> None: - string_filler = " ".join( - ["This is a long string that will be used to fill up the token limit." for _ in range(1000)] - ) - code = f""" -class MyClass: - \"\"\"A class with a helper method. \"\"\" - def __init__(self): - self.x = 1 - def target_method(self): - \"\"\"Docstring for target method\"\"\" - y = HelperClass().helper_method() -x = '{string_filler}' - -class HelperClass: - \"\"\"A helper class for MyClass.\"\"\" - def __init__(self): - \"\"\"Initialize the HelperClass.\"\"\" - self.x = 1 - def __repr__(self): - \"\"\"Return a string representation of the HelperClass.\"\"\" - return "HelperClass" + str(self.x) - def helper_method(self): - return self.x -""" - # Create a temporary Python file using pytest's tmp_path fixture - file_path = tmp_path / "test_code.py" - file_path.write_text(code, encoding="utf-8") - opt = Optimizer( - Namespace( - project_root=file_path.parent.resolve(), - disable_telemetry=True, - tests_root="tests", - test_framework="pytest", - pytest_cmd="pytest", - experiment_id=None, - test_project_root=Path().resolve(), - ) - ) - function_to_optimize = FunctionToOptimize( - function_name="target_method", - file_path=file_path, - parents=[FunctionParent(name="MyClass", type="ClassDef")], - starting_line=None, - ending_line=None, - ) - - code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000) - read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code - hashing_context = code_ctx.hashing_code_context - # In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely. - expected_read_write_context = f""" -```python:{file_path.relative_to(opt.args.project_root)} -class MyClass: - def __init__(self): - self.x = 1 - def target_method(self): - \"\"\"Docstring for target method\"\"\" - y = HelperClass().helper_method() - -class HelperClass: - def __init__(self): - \"\"\"Initialize the HelperClass.\"\"\" - self.x = 1 - def helper_method(self): - return self.x -``` -""" - expected_read_only_context = f'''```python:{file_path.relative_to(opt.args.project_root)} -class MyClass: - """A class with a helper method. """ - -class HelperClass: - """A helper class for MyClass.""" - def __repr__(self): - """Return a string representation of the HelperClass.""" - return "HelperClass" + str(self.x) -``` -''' - expected_hashing_context = f""" -```python:{file_path.relative_to(opt.args.project_root)} -class MyClass: - - def target_method(self): - y = HelperClass().helper_method() - -class HelperClass: - - def helper_method(self): - return self.x -``` -""" - assert read_write_context.markdown.strip() == expected_read_write_context.strip() - assert read_only_context.strip() == expected_read_only_context.strip() - assert hashing_context.strip() == expected_hashing_context.strip() - - def test_example_class_token_limit_3(tmp_path: Path) -> None: string_filler = " ".join( ["This is a long string that will be used to fill up the token limit." for _ in range(1000)] @@ -1009,7 +814,7 @@ def helper_method(self): ) # In this scenario, the read-writable code is too long, so we abort. with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"): - code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + get_code_optimization_context(function_to_optimize, opt.args.project_root, optim_token_limit=8000) def test_example_class_token_limit_4(tmp_path: Path) -> None: @@ -1062,7 +867,7 @@ def helper_method(self): # In this scenario, the read-writable code context becomes too large because the __init__ function is referencing the global x variable instead of the class attribute self.x, so we abort. with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"): - code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + get_code_optimization_context(function_to_optimize, opt.args.project_root, optim_token_limit=8000) def test_example_class_token_limit_5(tmp_path: Path) -> None: @@ -2422,7 +2227,7 @@ def nested_method(self): assert "__init__" not in hashing_context # Should not contain __init__ methods # Verify nested classes are excluded from the hashing context - # The prune_cst_for_code_hashing function should not recurse into nested classes + # The prune_cst function in hashing mode should not recurse into nested classes assert "class NestedClass:" not in hashing_context # Nested class definition should not be present # The target method will reference NestedClass, but the actual nested class definition should not be included @@ -3275,8 +3080,8 @@ def dump_layout(layout_type, layout): assert testgen_context.count("def __init__") >= 2, "Both __init__ methods should be in testgen context" -def test_get_imported_class_definitions_extracts_project_classes(tmp_path: Path) -> None: - """Test that get_imported_class_definitions extracts class definitions from project modules.""" +def test_enrich_testgen_context_extracts_project_classes(tmp_path: Path) -> None: + """Test that enrich_testgen_context extracts class definitions from project modules.""" # Create a package structure with two modules package_dir = tmp_path / "mypackage" package_dir.mkdir() @@ -3325,8 +3130,8 @@ def will_fit(self, chunk: PreChunk) -> bool: # Create CodeStringsMarkdown from the chunking module (simulating testgen context) context = CodeStringsMarkdown(code_strings=[CodeString(code=chunking_code, file_path=chunking_path)]) - # Call get_imported_class_definitions - result = get_imported_class_definitions(context, tmp_path) + # Call enrich_testgen_context + result = enrich_testgen_context(context, tmp_path) # Verify Element class was extracted assert len(result.code_strings) == 1, "Should extract exactly one class (Element)" @@ -3339,8 +3144,8 @@ def will_fit(self, chunk: PreChunk) -> bool: assert "import abc" in extracted_code, "Should include necessary imports for base class" -def test_get_imported_class_definitions_skips_existing_definitions(tmp_path: Path) -> None: - """Test that get_imported_class_definitions skips classes already defined in context.""" +def test_enrich_testgen_context_skips_existing_definitions(tmp_path: Path) -> None: + """Test that enrich_testgen_context skips classes already defined in context.""" # Create a package structure package_dir = tmp_path / "mypackage" package_dir.mkdir() @@ -3373,15 +3178,15 @@ def process(self, elem: Element): context = CodeStringsMarkdown(code_strings=[CodeString(code=code_with_local_def, file_path=code_path)]) - # Call get_imported_class_definitions - result = get_imported_class_definitions(context, tmp_path) + # Call enrich_testgen_context + result = enrich_testgen_context(context, tmp_path) # Should NOT extract Element since it's already defined locally assert len(result.code_strings) == 0, "Should not extract classes already defined in context" -def test_get_imported_class_definitions_skips_third_party(tmp_path: Path) -> None: - """Test that get_imported_class_definitions skips third-party/stdlib imports.""" +def test_enrich_testgen_context_skips_third_party(tmp_path: Path) -> None: + """Test that enrich_testgen_context skips third-party/stdlib imports.""" # Create a simple package package_dir = tmp_path / "mypackage" package_dir.mkdir() @@ -3402,15 +3207,15 @@ def __init__(self, path: Path): context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - # Call get_imported_class_definitions - result = get_imported_class_definitions(context, tmp_path) + # Call enrich_testgen_context + result = enrich_testgen_context(context, tmp_path) # Should not extract any classes (Path, Optional, dataclass are stdlib/third-party) assert len(result.code_strings) == 0, "Should not extract stdlib/third-party classes" -def test_get_imported_class_definitions_handles_multiple_imports(tmp_path: Path) -> None: - """Test that get_imported_class_definitions handles multiple class imports.""" +def test_enrich_testgen_context_handles_multiple_imports(tmp_path: Path) -> None: + """Test that enrich_testgen_context handles multiple class imports.""" # Create a package structure package_dir = tmp_path / "mypackage" package_dir.mkdir() @@ -3446,8 +3251,8 @@ def process(self, a: TypeA, b: TypeB): context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - # Call get_imported_class_definitions - result = get_imported_class_definitions(context, tmp_path) + # Call enrich_testgen_context + result = enrich_testgen_context(context, tmp_path) # Should extract both TypeA and TypeB (but not TypeC since it's not imported) assert len(result.code_strings) == 2, "Should extract exactly two classes (TypeA, TypeB)" @@ -3458,8 +3263,8 @@ def process(self, a: TypeA, b: TypeB): assert "class TypeC" not in all_extracted_code, "Should NOT contain TypeC (not imported)" -def test_get_imported_class_definitions_includes_dataclass_decorators(tmp_path: Path) -> None: - """Test that get_imported_class_definitions includes decorators when extracting dataclasses.""" +def test_enrich_testgen_context_includes_dataclass_decorators(tmp_path: Path) -> None: + """Test that enrich_testgen_context includes decorators when extracting dataclasses.""" # Create a package structure package_dir = tmp_path / "mypackage" package_dir.mkdir() @@ -3496,8 +3301,8 @@ def get_config(self) -> LLMConfig: context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - # Call get_imported_class_definitions - result = get_imported_class_definitions(context, tmp_path) + # Call enrich_testgen_context + result = enrich_testgen_context(context, tmp_path) # Should extract both LLMConfigBase (base class) and LLMConfig assert len(result.code_strings) == 2, "Should extract both LLMConfig and its base class LLMConfigBase" @@ -3521,7 +3326,7 @@ def get_config(self) -> LLMConfig: assert "from dataclasses import" in all_extracted_code, "Should include dataclasses import" -def test_get_imported_class_definitions_extracts_imports_for_decorated_classes(tmp_path: Path) -> None: +def test_enrich_testgen_context_extracts_imports_for_decorated_classes(tmp_path: Path) -> None: """Test that extract_imports_for_class includes decorator and type annotation imports.""" # Create a package structure package_dir = tmp_path / "mypackage" @@ -3552,7 +3357,7 @@ def create_config() -> Config: context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_imported_class_definitions(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert len(result.code_strings) == 1, "Should extract Config class" extracted_code = result.code_strings[0].code @@ -3724,7 +3529,7 @@ class MyClass: assert result.count("from typing import Optional") == 1 -def test_get_imported_class_definitions_multiple_decorators(tmp_path: Path) -> None: +def test_enrich_testgen_context_multiple_decorators(tmp_path: Path) -> None: """Test that classes with multiple decorators are extracted correctly.""" package_dir = tmp_path / "mypackage" package_dir.mkdir() @@ -3755,7 +3560,7 @@ def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]: context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_imported_class_definitions(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert len(result.code_strings) == 1 extracted_code = result.code_strings[0].code @@ -3766,7 +3571,7 @@ def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]: assert "class OrderedConfig" in extracted_code -def test_get_imported_class_definitions_extracts_multilevel_inheritance(tmp_path: Path) -> None: +def test_enrich_testgen_context_extracts_multilevel_inheritance(tmp_path: Path) -> None: """Test that base classes are recursively extracted for multi-level inheritance. This is critical for understanding dataclass constructor signatures, as fields @@ -3826,8 +3631,8 @@ def get_router_config(self) -> RouterConfig: context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - # Call get_imported_class_definitions - result = get_imported_class_definitions(context, tmp_path) + # Call enrich_testgen_context + result = enrich_testgen_context(context, tmp_path) # Should extract 4 classes: GrandParentConfig, ParentConfig, ChildConfig, RouterConfig # (all classes needed to understand the full inheritance hierarchy) @@ -3862,7 +3667,7 @@ def get_router_config(self) -> RouterConfig: assert "model_list: list" in all_extracted_code, "Should include model_list field from Router" -def test_get_external_base_class_inits_extracts_userdict(tmp_path: Path) -> None: +def test_enrich_testgen_context_extracts_userdict(tmp_path: Path) -> None: """Extracts __init__ from collections.UserDict when a class inherits from it.""" code = """from collections import UserDict @@ -3873,7 +3678,7 @@ class MyCustomDict(UserDict): code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_base_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert len(result.code_strings) == 1 code_string = result.code_strings[0] @@ -3891,8 +3696,8 @@ def __init__(self, dict=None, /, **kwargs): assert code_string.file_path.as_posix().endswith("collections/__init__.py") -def test_get_external_base_class_inits_skips_project_classes(tmp_path: Path) -> None: - """Returns empty when base class is from the project, not external.""" +def test_enrich_testgen_context_skips_unresolvable_base_classes(tmp_path: Path) -> None: + """Returns empty when base class module cannot be resolved.""" child_code = """from base import ProjectBase class Child(ProjectBase): @@ -3902,12 +3707,12 @@ class Child(ProjectBase): child_path.write_text(child_code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=child_code, file_path=child_path)]) - result = get_external_base_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert result.code_strings == [] -def test_get_external_base_class_inits_skips_builtins(tmp_path: Path) -> None: +def test_enrich_testgen_context_skips_builtin_base_classes(tmp_path: Path) -> None: """Returns empty for builtin classes like list that have no inspectable source.""" code = """class MyList(list): pass @@ -3916,12 +3721,12 @@ def test_get_external_base_class_inits_skips_builtins(tmp_path: Path) -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_base_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert result.code_strings == [] -def test_get_external_base_class_inits_deduplicates(tmp_path: Path) -> None: +def test_enrich_testgen_context_deduplicates(tmp_path: Path) -> None: """Extracts the same external base class only once even when inherited multiple times.""" code = """from collections import UserDict @@ -3935,7 +3740,7 @@ class MyDict2(UserDict): code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_base_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert len(result.code_strings) == 1 expected_code = """\ @@ -3950,7 +3755,7 @@ def __init__(self, dict=None, /, **kwargs): assert result.code_strings[0].code == expected_code -def test_get_external_base_class_inits_empty_when_no_inheritance(tmp_path: Path) -> None: +def test_enrich_testgen_context_empty_when_no_inheritance(tmp_path: Path) -> None: """Returns empty when there are no external base classes.""" code = """class SimpleClass: pass @@ -3959,7 +3764,7 @@ def test_get_external_base_class_inits_empty_when_no_inheritance(tmp_path: Path) code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_base_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert result.code_strings == [] @@ -4103,127 +3908,8 @@ def target_method(self): assert "self.data = {}" in testgen_context, "UserDict __init__ body should be included" -def test_read_only_code_removed_when_exceeds_limit(tmp_path: Path) -> None: - """Test read-only code is completely removed when it exceeds token limit even without docstrings. - - This covers lines 152-153 in code_context_extractor.py where read_only_context_code is set - to empty string when it still exceeds the token limit after docstring removal. - """ - # Create a second-degree helper with large implementation that has no docstrings - # Second-degree helpers go into read-only context - long_lines = [" x = 0"] - for i in range(150): - long_lines.append(f" x = x + {i}") - long_lines.append(" return x") - long_body = "\n".join(long_lines) - - code = f""" -class MyClass: - def __init__(self): - self.x = 1 - - def target_method(self): - return first_helper() - - -def first_helper(): - # First degree helper - calls second degree - return second_helper() - - -def second_helper(): - # Second degree helper - goes into read-only context -{long_body} -""" - file_path = tmp_path / "test_code.py" - file_path.write_text(code, encoding="utf-8") - - func_to_optimize = FunctionToOptimize( - function_name="target_method", file_path=file_path, parents=[FunctionParent(name="MyClass", type="ClassDef")] - ) - - # Use a small optim_token_limit that allows read-writable but not read-only - # Read-writable is ~48 tokens, read-only is ~600 tokens - code_ctx = get_code_optimization_context( - function_to_optimize=func_to_optimize, - project_root_path=tmp_path, - optim_token_limit=100, # Small limit to trigger read-only removal - ) - - # The read-only context should be empty because it exceeded the limit - assert code_ctx.read_only_context_code == "", "Read-only code should be removed when exceeding token limit" - - -def test_testgen_removes_imported_classes_on_overflow(tmp_path: Path) -> None: - """Test testgen context removes imported class definitions when exceeding token limit. - - This covers lines 176-186 in code_context_extractor.py where: - - Testgen context exceeds limit (line 175) - - Removing docstrings still exceeds (line 175 again) - - Removing imported classes succeeds (line 177-183) - """ - # Create a package structure with a large type class used only in type annotations - # This ensures get_imported_class_definitions extracts the full class - package_dir = tmp_path / "mypackage" - package_dir.mkdir() - (package_dir / "__init__.py").write_text("", encoding="utf-8") - - # Create a large class with methods that will be extracted via get_imported_class_definitions - # Use methods WITHOUT docstrings so removing docstrings won't help much - many_methods = "\n".join([f" def method_{i}(self):\n return {i}" for i in range(100)]) - type_class_code = f''' -class TypeClass: - """A type class for annotations.""" - - def __init__(self, value: int): - self.value = value - -{many_methods} -''' - type_class_path = package_dir / "types.py" - type_class_path.write_text(type_class_code, encoding="utf-8") - - # Main module uses TypeClass only in annotation (not instantiated) - # This triggers get_imported_class_definitions to extract the full class - main_code = """ -from mypackage.types import TypeClass - -def target_function(obj: TypeClass) -> int: - return obj.value -""" - main_path = package_dir / "main.py" - main_path.write_text(main_code, encoding="utf-8") - - func_to_optimize = FunctionToOptimize(function_name="target_function", file_path=main_path, parents=[]) - - # Use a testgen_token_limit that: - # - Is exceeded by full context with imported class (~1500 tokens) - # - Is exceeded even after removing docstrings - # - But fits when imported class is removed (~40 tokens) - code_ctx = get_code_optimization_context( - function_to_optimize=func_to_optimize, - project_root_path=tmp_path, - testgen_token_limit=200, # Small limit to trigger imported class removal - ) - - # The testgen context should exist (didn't raise ValueError) - testgen_context = code_ctx.testgen_context.markdown - assert testgen_context, "Testgen context should not be empty" - - # The target function should still be there - assert "def target_function" in testgen_context, "Target function should be in testgen context" - - # The large imported class should NOT be included (removed due to token limit) - assert "class TypeClass" not in testgen_context, ( - "TypeClass should be removed from testgen context when exceeding token limit" - ) - - -def test_testgen_raises_when_all_fallbacks_fail(tmp_path: Path) -> None: - """Test that ValueError is raised when testgen context exceeds limit even after all fallbacks. - - This covers line 186 in code_context_extractor.py. - """ +def test_testgen_raises_when_exceeds_limit(tmp_path: Path) -> None: + """Test that ValueError is raised when testgen context exceeds token limit.""" # Create a function with a very long body that exceeds limits even without imports/docstrings long_lines = [" x = 0"] for i in range(200): @@ -4249,7 +3935,7 @@ def target_function(): ) -def test_get_external_base_class_inits_attribute_base(tmp_path: Path) -> None: +def test_enrich_testgen_context_attribute_base(tmp_path: Path) -> None: """Test handling of base class accessed as module.ClassName (ast.Attribute). This covers line 616 in code_context_extractor.py. @@ -4265,7 +3951,7 @@ def custom_method(self): code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_base_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) # Should extract UserDict __init__ assert len(result.code_strings) == 1 @@ -4273,7 +3959,7 @@ def custom_method(self): assert "def __init__" in result.code_strings[0].code -def test_get_external_base_class_inits_no_init_method(tmp_path: Path) -> None: +def test_enrich_testgen_context_no_init_method(tmp_path: Path) -> None: """Test handling when base class has no __init__ method. This covers line 641 in code_context_extractor.py. @@ -4288,7 +3974,7 @@ class MyProtocol(Protocol): code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_base_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) # Protocol's __init__ can't be easily inspected, should handle gracefully # Result may be empty or contain Protocol based on implementation @@ -4377,7 +4063,7 @@ def target_method(self): def test_imported_class_definitions_module_path_none(tmp_path: Path) -> None: - """Test handling when module_path is None in get_imported_class_definitions. + """Test handling when module_path is None in enrich_testgen_context. This covers line 560 in code_context_extractor.py. """ @@ -4393,123 +4079,12 @@ def method(self, obj: SomeClass): code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_imported_class_definitions(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) # Should handle gracefully and return empty or partial results assert isinstance(result.code_strings, list) -def test_get_imported_names_import_star(tmp_path: Path) -> None: - """Test get_imported_names handles import * correctly. - - This covers lines 808-809 and 824-825 in code_context_extractor.py. - """ - import libcst as cst - - # Test regular import * - # Note: "import *" is not valid Python, but "from x import *" is - from_import_star = cst.parse_statement("from os import *") - assert isinstance(from_import_star, cst.SimpleStatementLine) - import_node = from_import_star.body[0] - assert isinstance(import_node, cst.ImportFrom) - - from codeflash.context.code_context_extractor import get_imported_names - - result = get_imported_names(import_node) - assert result == {"*"} - - -def test_get_imported_names_aliased_import(tmp_path: Path) -> None: - """Test get_imported_names handles aliased imports correctly. - - This covers lines 812-813 and 828-829 in code_context_extractor.py. - """ - import libcst as cst - - from codeflash.context.code_context_extractor import get_imported_names - - # Test import with alias - import_stmt = cst.parse_statement("import numpy as np") - assert isinstance(import_stmt, cst.SimpleStatementLine) - import_node = import_stmt.body[0] - assert isinstance(import_node, cst.Import) - - result = get_imported_names(import_node) - assert "np" in result - - # Test from import with alias - from_import_stmt = cst.parse_statement("from os import path as ospath") - assert isinstance(from_import_stmt, cst.SimpleStatementLine) - from_import_node = from_import_stmt.body[0] - assert isinstance(from_import_node, cst.ImportFrom) - - result2 = get_imported_names(from_import_node) - assert "ospath" in result2 - - -def test_get_imported_names_dotted_import(tmp_path: Path) -> None: - """Test get_imported_names handles dotted imports correctly. - - This covers lines 816-822 in code_context_extractor.py. - """ - import libcst as cst - - from codeflash.context.code_context_extractor import get_imported_names - - # Test dotted import like "import os.path" - import_stmt = cst.parse_statement("import os.path") - assert isinstance(import_stmt, cst.SimpleStatementLine) - import_node = import_stmt.body[0] - assert isinstance(import_node, cst.Import) - - result = get_imported_names(import_node) - assert "os" in result - - -def test_used_name_collector_comprehensive(tmp_path: Path) -> None: - """Test UsedNameCollector handles various node types. - - This covers lines 767-801 in code_context_extractor.py. - """ - import libcst as cst - - from codeflash.context.code_context_extractor import UsedNameCollector - - code = """ -import os -from typing import List - -x: int = 1 -y = os.path.join("a", "b") - -class MyClass: - z = 10 - -def my_func(): - pass -""" - module = cst.parse_module(code) - collector = UsedNameCollector() - # In libcst, the walker traverses the module - cst.MetadataWrapper(module).visit(collector) - - # Check used names - assert "os" in collector.used_names - assert "int" in collector.used_names - assert "List" in collector.used_names - - # Check defined names - assert "x" in collector.defined_names - assert "y" in collector.defined_names - assert "MyClass" in collector.defined_names - assert "my_func" in collector.defined_names - - # Check external names (used but not defined) - external = collector.get_external_names() - assert "os" in external - assert "x" not in external # x is defined - - def test_imported_class_with_base_in_same_module(tmp_path: Path) -> None: """Test that imported classes with bases in the same module are extracted correctly. @@ -4549,52 +4124,13 @@ def target_function(obj: DerivedClass) -> bool: main_path.write_text(main_code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=main_code, file_path=main_path)]) - result = get_imported_class_definitions(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) # Should extract the inheritance chain all_code = "\n".join(cs.code for cs in result.code_strings) assert "class BaseClass" in all_code or "class DerivedClass" in all_code -def test_get_imported_names_from_import_without_alias(tmp_path: Path) -> None: - """Test get_imported_names handles from imports without aliases. - - This covers lines 830-831 in code_context_extractor.py. - """ - import libcst as cst - - from codeflash.context.code_context_extractor import get_imported_names - - # Test from import without alias - from_import_stmt = cst.parse_statement("from os import path, getcwd") - assert isinstance(from_import_stmt, cst.SimpleStatementLine) - from_import_node = from_import_stmt.body[0] - assert isinstance(from_import_node, cst.ImportFrom) - - result = get_imported_names(from_import_node) - assert "path" in result - assert "getcwd" in result - - -def test_get_imported_names_regular_import(tmp_path: Path) -> None: - """Test get_imported_names handles regular imports. - - This covers lines 814-815 in code_context_extractor.py. - """ - import libcst as cst - - from codeflash.context.code_context_extractor import get_imported_names - - # Test regular import without alias - import_stmt = cst.parse_statement("import json") - assert isinstance(import_stmt, cst.SimpleStatementLine) - import_node = import_stmt.body[0] - assert isinstance(import_node, cst.Import) - - result = get_imported_names(import_node) - assert "json" in result - - def test_augmented_assignment_not_in_context(tmp_path: Path) -> None: """Test that augmented assignments are handled but not included unless used. @@ -4625,7 +4161,7 @@ def target_method(self): assert "counter" in read_writable -def test_get_external_class_inits_extracts_click_option(tmp_path: Path) -> None: +def test_enrich_testgen_context_extracts_click_option(tmp_path: Path) -> None: """Extracts __init__ from click.Option when directly imported.""" code = """from click import Option @@ -4636,7 +4172,7 @@ def my_func(opt: Option) -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert len(result.code_strings) == 1 code_string = result.code_strings[0] @@ -4645,8 +4181,8 @@ def my_func(opt: Option) -> None: assert code_string.file_path is not None and "click" in code_string.file_path.as_posix() -def test_get_external_class_inits_skips_project_classes(tmp_path: Path) -> None: - """Returns empty when imported class is from the project, not external.""" +def test_enrich_testgen_context_extracts_project_class_defs(tmp_path: Path) -> None: + """Extracts project class definitions via jedi resolution.""" # Create a project module with a class (tmp_path / "mymodule.py").write_text("class ProjectClass:\n pass\n", encoding="utf-8") @@ -4659,12 +4195,13 @@ def my_func(obj: ProjectClass) -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) - assert result.code_strings == [] + assert len(result.code_strings) == 1 + assert "class ProjectClass" in result.code_strings[0].code -def test_get_external_class_inits_skips_non_classes(tmp_path: Path) -> None: +def test_enrich_testgen_context_skips_non_classes(tmp_path: Path) -> None: """Returns empty when imported name is a function, not a class.""" code = """from collections import OrderedDict from os.path import join @@ -4676,7 +4213,7 @@ def my_func() -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) # join is a function, not a class — should be skipped # OrderedDict is a class and should be included @@ -4684,8 +4221,8 @@ def my_func() -> None: assert not any("join" in name for name in class_names) -def test_get_external_class_inits_skips_already_defined_classes(tmp_path: Path) -> None: - """Skips classes already defined in the context (e.g., added by get_imported_class_definitions).""" +def test_enrich_testgen_context_skips_already_defined_classes(tmp_path: Path) -> None: + """Skips classes already defined in the context (e.g., added by enrich_testgen_context).""" code = """from collections import UserDict class UserDict: @@ -4699,14 +4236,14 @@ def my_func(d: UserDict) -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) # UserDict is already defined in the context, so it should be skipped assert result.code_strings == [] -def test_get_external_class_inits_skips_builtins(tmp_path: Path) -> None: - """Returns empty for builtin classes like list/dict that have no inspectable source.""" +def test_enrich_testgen_context_skips_builtin_annotations(tmp_path: Path) -> None: + """Returns empty for builtin type annotations like list/dict that are not imported.""" code = """x: list = [] y: dict = {} @@ -4717,12 +4254,12 @@ def my_func() -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert result.code_strings == [] -def test_get_external_class_inits_skips_object_init(tmp_path: Path) -> None: +def test_enrich_testgen_context_skips_object_init(tmp_path: Path) -> None: """Skips classes whose __init__ is just object.__init__ (trivial).""" # enum.Enum has a metaclass-based __init__, but individual enum members # effectively use object.__init__. Use a class we know has object.__init__. @@ -4735,14 +4272,14 @@ def my_func(q: QName) -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) # QName has its own __init__, so it should be included if it's in site-packages. # But since it's stdlib (not site-packages), it should be skipped. assert result.code_strings == [] -def test_get_external_class_inits_empty_when_no_imports(tmp_path: Path) -> None: +def test_enrich_testgen_context_empty_when_no_imports(tmp_path: Path) -> None: """Returns empty when there are no from-imports.""" code = """def my_func() -> None: pass @@ -4751,7 +4288,7 @@ def test_get_external_class_inits_empty_when_no_imports(tmp_path: Path) -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) assert result.code_strings == [] @@ -4840,17 +4377,17 @@ def test_resolve_transitive_type_deps_handles_failure_gracefully() -> None: """Returns empty list for a class where get_type_hints fails.""" class BadClass: - def __init__(self, x: "NonexistentType") -> None: # type: ignore[name-defined] # noqa: F821 + def __init__(self, x: NonexistentType) -> None: # type: ignore[name-defined] # noqa: F821 pass result = resolve_transitive_type_deps(BadClass) assert result == [] -# --- Integration tests for transitive resolution in get_external_class_inits --- +# --- Integration tests for transitive resolution in enrich_testgen_context --- -def test_get_external_class_inits_transitive_deps(tmp_path: Path) -> None: +def test_enrich_testgen_context_transitive_deps(tmp_path: Path) -> None: """Extracts transitive type dependencies from __init__ annotations.""" code = """from click import Context @@ -4861,7 +4398,7 @@ def my_func(ctx: Context) -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) class_names = {cs.code.split("\n")[0].replace("class ", "").rstrip(":") for cs in result.code_strings} assert "Context" in class_names @@ -4869,7 +4406,7 @@ def my_func(ctx: Context) -> None: assert "Command" in class_names -def test_get_external_class_inits_no_infinite_loops(tmp_path: Path) -> None: +def test_enrich_testgen_context_no_infinite_loops(tmp_path: Path) -> None: """Handles classes with circular type references without infinite loops.""" # click.Context references Command, and Command references Context back # This should terminate without issues due to the processed_classes set @@ -4882,13 +4419,13 @@ def my_func(ctx: Context) -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) # Should complete without hanging; just verify we got results assert len(result.code_strings) >= 1 -def test_get_external_class_inits_no_duplicate_stubs(tmp_path: Path) -> None: +def test_enrich_testgen_context_no_duplicate_stubs(tmp_path: Path) -> None: """Does not emit duplicate stubs for the same class name.""" code = """from click import Context @@ -4899,7 +4436,7 @@ def my_func(ctx: Context) -> None: code_path.write_text(code, encoding="utf-8") context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)]) - result = get_external_class_inits(context, tmp_path) + result = enrich_testgen_context(context, tmp_path) class_names = [cs.code.split("\n")[0].replace("class ", "").rstrip(":") for cs in result.code_strings] assert len(class_names) == len(set(class_names)), f"Duplicate class stubs found: {class_names}" From 547c02e8bc4820c7dbc4a253b58d7d23ec497e70 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 16 Feb 2026 14:49:04 -0500 Subject: [PATCH 10/39] refactor: move context extraction modules to languages/python/context/ Move code_context_extractor.py and unused_definition_remover.py from codeflash/context/ to codeflash/languages/python/context/ and update all import sites. --- codeflash/{ => languages/python}/context/__init__.py | 0 .../python}/context/code_context_extractor.py | 12 ++++++------ .../python}/context/unused_definition_remover.py | 0 codeflash/optimization/function_optimizer.py | 7 +++++-- .../test_benchmark_code_extract_code_context.py | 2 +- tests/test_code_context_extractor.py | 4 ++-- tests/test_get_read_only_code.py | 2 +- tests/test_get_read_writable_code.py | 2 +- tests/test_get_testgen_code.py | 2 +- tests/test_languages/test_code_context_extraction.py | 4 +--- tests/test_languages/test_javascript_e2e.py | 2 +- .../test_javascript_optimization_flow.py | 11 +++++------ tests/test_languages/test_typescript_e2e.py | 8 ++++---- tests/test_languages/test_vitest_e2e.py | 2 +- tests/test_remove_unused_definitions.py | 2 +- tests/test_unused_helper_revert.py | 5 ++++- 16 files changed, 34 insertions(+), 31 deletions(-) rename codeflash/{ => languages/python}/context/__init__.py (100%) rename codeflash/{ => languages/python}/context/code_context_extractor.py (99%) rename codeflash/{ => languages/python}/context/unused_definition_remover.py (100%) diff --git a/codeflash/context/__init__.py b/codeflash/languages/python/context/__init__.py similarity index 100% rename from codeflash/context/__init__.py rename to codeflash/languages/python/context/__init__.py diff --git a/codeflash/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py similarity index 99% rename from codeflash/context/code_context_extractor.py rename to codeflash/languages/python/context/code_context_extractor.py index 0220a642d..a28b12ac8 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -14,16 +14,16 @@ from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages from codeflash.code_utils.config_consts import OPTIMIZATION_CONTEXT_TOKEN_LIMIT, TESTGEN_CONTEXT_TOKEN_LIMIT -from codeflash.context.unused_definition_remover import ( +from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001 + +# Language support imports for multi-language code context extraction +from codeflash.languages import Language, is_python +from codeflash.languages.python.context.unused_definition_remover import ( collect_top_level_defs_with_usages, extract_names_from_targets, get_section_names, remove_unused_definitions_by_function_names, ) -from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001 - -# Language support imports for multi-language code context extraction -from codeflash.languages import Language, is_python from codeflash.models.models import ( CodeContextType, CodeOptimizationContext, @@ -38,8 +38,8 @@ from jedi.api.classes import Name - from codeflash.context.unused_definition_remover import UsageInfo from codeflash.languages.base import HelperFunction + from codeflash.languages.python.context.unused_definition_remover import UsageInfo # Error message constants READ_WRITABLE_LIMIT_ERROR = "Read-writable code has exceeded token limit, cannot proceed" diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/languages/python/context/unused_definition_remover.py similarity index 100% rename from codeflash/context/unused_definition_remover.py rename to codeflash/languages/python/context/unused_definition_remover.py diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index bb824468e..5e3a8a00f 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -72,8 +72,6 @@ from codeflash.code_utils.shell_utils import make_env_with_project_root from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast from codeflash.code_utils.time_utils import humanize_runtime -from codeflash.context import code_context_extractor -from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions from codeflash.discovery.functions_to_optimize import was_function_previously_optimized from codeflash.either import Failure, Success, is_successful from codeflash.languages import is_python @@ -81,6 +79,11 @@ from codeflash.languages.current import current_language_support, is_typescript from codeflash.languages.javascript.module_system import detect_module_system from codeflash.languages.javascript.test_runner import clear_created_config_files, get_created_config_files +from codeflash.languages.python.context import code_context_extractor +from codeflash.languages.python.context.unused_definition_remover import ( + detect_unused_helper_functions, + revert_unused_helper_functions, +) from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId from codeflash.models.ExperimentMetadata import ExperimentMetadata diff --git a/tests/benchmarks/test_benchmark_code_extract_code_context.py b/tests/benchmarks/test_benchmark_code_extract_code_context.py index bb6140916..77c435720 100644 --- a/tests/benchmarks/test_benchmark_code_extract_code_context.py +++ b/tests/benchmarks/test_benchmark_code_extract_code_context.py @@ -1,8 +1,8 @@ from argparse import Namespace from pathlib import Path -from codeflash.context.code_context_extractor import get_code_optimization_context from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context from codeflash.models.models import FunctionParent from codeflash.optimization.optimizer import Optimizer diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index cfa1f5d2b..add427f32 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -10,7 +10,8 @@ from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_global_assignments from codeflash.code_utils.code_replacer import replace_functions_and_add_imports -from codeflash.context.code_context_extractor import ( +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.languages.python.context.code_context_extractor import ( collect_names_from_annotation, enrich_testgen_context, extract_classes_from_type_hint, @@ -18,7 +19,6 @@ get_code_optimization_context, resolve_transitive_type_deps, ) -from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent from codeflash.optimization.optimizer import Optimizer diff --git a/tests/test_get_read_only_code.py b/tests/test_get_read_only_code.py index 618e39767..c6de2cc27 100644 --- a/tests/test_get_read_only_code.py +++ b/tests/test_get_read_only_code.py @@ -2,7 +2,7 @@ import pytest -from codeflash.context.code_context_extractor import parse_code_and_prune_cst +from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst from codeflash.models.models import CodeContextType diff --git a/tests/test_get_read_writable_code.py b/tests/test_get_read_writable_code.py index 6de398a25..c6bbdd04b 100644 --- a/tests/test_get_read_writable_code.py +++ b/tests/test_get_read_writable_code.py @@ -2,7 +2,7 @@ import pytest -from codeflash.context.code_context_extractor import parse_code_and_prune_cst +from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst from codeflash.models.models import CodeContextType diff --git a/tests/test_get_testgen_code.py b/tests/test_get_testgen_code.py index c15005fa7..01c3ae153 100644 --- a/tests/test_get_testgen_code.py +++ b/tests/test_get_testgen_code.py @@ -2,7 +2,7 @@ import pytest -from codeflash.context.code_context_extractor import parse_code_and_prune_cst +from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst from codeflash.models.models import CodeContextType diff --git a/tests/test_languages/test_code_context_extraction.py b/tests/test_languages/test_code_context_extraction.py index 07946ddd3..b7b12a69c 100644 --- a/tests/test_languages/test_code_context_extraction.py +++ b/tests/test_languages/test_code_context_extraction.py @@ -20,14 +20,12 @@ from __future__ import annotations -from pathlib import Path - import pytest -from codeflash.context.code_context_extractor import get_code_optimization_context_for_language from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import Language from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport +from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context_for_language @pytest.fixture diff --git a/tests/test_languages/test_javascript_e2e.py b/tests/test_languages/test_javascript_e2e.py index 017e8f66e..7b7e8503b 100644 --- a/tests/test_languages/test_javascript_e2e.py +++ b/tests/test_languages/test_javascript_e2e.py @@ -106,9 +106,9 @@ def js_project_dir(self): def test_extract_code_context_for_javascript(self, js_project_dir): """Test extracting code context for a JavaScript function.""" skip_if_js_not_supported() - from codeflash.context.code_context_extractor import get_code_optimization_context from codeflash.discovery.functions_to_optimize import find_all_functions_in_file from codeflash.languages import current as lang_current + from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context lang_current._current_language = Language.JAVASCRIPT diff --git a/tests/test_languages/test_javascript_optimization_flow.py b/tests/test_languages/test_javascript_optimization_flow.py index 26d2db140..89631565b 100644 --- a/tests/test_languages/test_javascript_optimization_flow.py +++ b/tests/test_languages/test_javascript_optimization_flow.py @@ -9,7 +9,6 @@ This is the JavaScript equivalent of test_instrument_tests.py for Python. """ -from pathlib import Path from unittest.mock import MagicMock, patch import pytest @@ -71,9 +70,9 @@ def test_function_to_optimize_has_correct_language_for_javascript(self, tmp_path def test_code_context_preserves_language(self, tmp_path): """Verify language is preserved in code context extraction.""" skip_if_js_not_supported() - from codeflash.context.code_context_extractor import get_code_optimization_context from codeflash.discovery.functions_to_optimize import find_all_functions_in_file from codeflash.languages import current as lang_current + from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context lang_current._current_language = Language.TYPESCRIPT @@ -164,7 +163,7 @@ def test_testgen_request_includes_correct_language(self, tmp_path): # Mock the AI service request ai_client = AiServiceClient() - with patch.object(ai_client, 'make_ai_service_request') as mock_request: + with patch.object(ai_client, "make_ai_service_request") as mock_request: mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = { @@ -191,8 +190,8 @@ def test_testgen_request_includes_correct_language(self, tmp_path): # Verify the request was made with correct language assert mock_request.called, "API request should have been made" call_args = mock_request.call_args - payload = call_args[1].get('payload', call_args[0][1] if len(call_args[0]) > 1 else {}) - assert payload.get('language') == 'typescript', \ + payload = call_args[1].get("payload", call_args[0][1] if len(call_args[0]) > 1 else {}) + assert payload.get("language") == "typescript", \ f"Expected language='typescript', got language='{payload.get('language')}'" @@ -462,7 +461,7 @@ def test_helper_functions_have_correct_language_javascript(self, tmp_path): """Verify helper functions have language='javascript' for .js files.""" skip_if_js_not_supported() from codeflash.discovery.functions_to_optimize import find_all_functions_in_file - from codeflash.languages import current as lang_current, get_language_support + from codeflash.languages import current as lang_current from codeflash.optimization.function_optimizer import FunctionOptimizer lang_current._current_language = Language.JAVASCRIPT diff --git a/tests/test_languages/test_typescript_e2e.py b/tests/test_languages/test_typescript_e2e.py index a638f01a1..87dc81269 100644 --- a/tests/test_languages/test_typescript_e2e.py +++ b/tests/test_languages/test_typescript_e2e.py @@ -69,7 +69,7 @@ def test_discover_functions_with_type_annotations(self): from codeflash.discovery.functions_to_optimize import find_all_functions_in_file with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f: - f.write(""" + f.write(r""" export function add(a: number, b: number): number { return a + b; } @@ -123,9 +123,9 @@ def ts_project_dir(self): def test_extract_code_context_for_typescript(self, ts_project_dir): """Test extracting code context for a TypeScript function.""" skip_if_ts_not_supported() - from codeflash.context.code_context_extractor import get_code_optimization_context from codeflash.discovery.functions_to_optimize import find_all_functions_in_file from codeflash.languages import current as lang_current + from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context lang_current._current_language = Language.TYPESCRIPT @@ -201,7 +201,7 @@ def test_replace_function_preserves_types(self): from codeflash.languages import get_language_support from codeflash.languages.base import FunctionInfo - original_source = """ + original_source = r""" interface Config { timeout: number; retries: number; @@ -212,7 +212,7 @@ def test_replace_function_preserves_types(self): } """ - new_function = """function processConfig(config: Config): string { + new_function = r"""function processConfig(config: Config): string { // Optimized with template caching const { timeout, retries } = config; return `timeout=\${timeout}, retries=\${retries}`; diff --git a/tests/test_languages/test_vitest_e2e.py b/tests/test_languages/test_vitest_e2e.py index 68448c1cf..fc3c285a4 100644 --- a/tests/test_languages/test_vitest_e2e.py +++ b/tests/test_languages/test_vitest_e2e.py @@ -117,10 +117,10 @@ def vitest_project_dir(self): def test_extract_code_context_for_typescript(self, vitest_project_dir): """Test extracting code context for a TypeScript function.""" skip_if_js_not_supported() - from codeflash.context.code_context_extractor import get_code_optimization_context from codeflash.discovery.functions_to_optimize import find_all_functions_in_file from codeflash.languages import current as lang_current from codeflash.languages.base import Language + from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context lang_current._current_language = Language.TYPESCRIPT diff --git a/tests/test_remove_unused_definitions.py b/tests/test_remove_unused_definitions.py index 8d272b2bb..5614e7283 100644 --- a/tests/test_remove_unused_definitions.py +++ b/tests/test_remove_unused_definitions.py @@ -1,6 +1,6 @@ -from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names +from codeflash.languages.python.context.unused_definition_remover import remove_unused_definitions_by_function_names def test_variable_removal_only() -> None: diff --git a/tests/test_unused_helper_revert.py b/tests/test_unused_helper_revert.py index 18d21de32..bfc75642c 100644 --- a/tests/test_unused_helper_revert.py +++ b/tests/test_unused_helper_revert.py @@ -5,8 +5,11 @@ import pytest -from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.languages.python.context.unused_definition_remover import ( + detect_unused_helper_functions, + revert_unused_helper_functions, +) from codeflash.models.models import CodeStringsMarkdown from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig From b1ec82413ef6b8b063413fe2f8468df246e7c921 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 16 Feb 2026 15:02:44 -0500 Subject: [PATCH 11/39] refactor: delegate PythonSupport context methods to canonical pipeline Replace duplicate implementations in extract_code_context() and find_helper_functions() with calls to get_code_optimization_context() and get_function_sources_from_jedi() from the canonical context module. --- codeflash/languages/python/support.py | 138 +++++++------------------- 1 file changed, 35 insertions(+), 103 deletions(-) diff --git a/codeflash/languages/python/support.py b/codeflash/languages/python/support.py index 58f66d0b8..4b79b8d91 100644 --- a/codeflash/languages/python/support.py +++ b/codeflash/languages/python/support.py @@ -171,127 +171,59 @@ def discover_tests( # === Code Analysis === def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext: - """Extract function code and its dependencies. + """Extract function code and its dependencies via the canonical context pipeline.""" + from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context - Uses jedi and libcst for Python code analysis. - - Args: - function: The function to extract context for. - project_root: Root of the project. - module_root: Root of the module containing the function. - - Returns: - CodeContext with target code and dependencies. - - """ try: - source = function.file_path.read_text() + result = get_code_optimization_context(function, project_root) except Exception as e: - logger.exception("Failed to read %s: %s", function.file_path, e) + logger.warning("Failed to extract code context for %s: %s", function.function_name, e) return CodeContext(target_code="", target_file=function.file_path, language=Language.PYTHON) - # Extract the function source - lines = source.splitlines(keepends=True) - if function.starting_line and function.ending_line: - target_lines = lines[function.starting_line - 1 : function.ending_line] - target_code = "".join(target_lines) - else: - target_code = "" - - # Find helper functions - helpers = self.find_helper_functions(function, project_root) - - # Extract imports - import_lines = [] - for line in lines: - stripped = line.strip() - if stripped.startswith(("import ", "from ")): - import_lines.append(stripped) - elif stripped and not stripped.startswith("#"): - # Stop at first non-import, non-comment line - break + helpers = [ + HelperFunction( + name=fs.only_function_name, + qualified_name=fs.qualified_name, + file_path=fs.file_path, + source_code=fs.source_code, + start_line=fs.jedi_definition.line if fs.jedi_definition else 1, + end_line=fs.jedi_definition.line if fs.jedi_definition else 1, + ) + for fs in result.helper_functions + ] return CodeContext( - target_code=target_code, + target_code=result.read_writable_code.markdown, target_file=function.file_path, helper_functions=helpers, - read_only_context="", - imports=import_lines, + read_only_context=result.read_only_context_code, + imports=[], language=Language.PYTHON, ) def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]: - """Find helper functions called by the target function. - - Uses jedi for Python code analysis. - - Args: - function: The target function to analyze. - project_root: Root of the project. - - Returns: - List of HelperFunction objects. - - """ - helpers: list[HelperFunction] = [] + """Find helper functions called by the target function via the canonical jedi pipeline.""" + from codeflash.languages.python.context.code_context_extractor import get_function_sources_from_jedi try: - import jedi - - from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages - from codeflash.optimization.function_context import belongs_to_function_qualified - - script = jedi.Script(path=function.file_path, project=jedi.Project(path=project_root)) - file_refs = script.get_names(all_scopes=True, definitions=False, references=True) - - qualified_name = function.qualified_name - - for ref in file_refs: - if not ref.full_name or not belongs_to_function_qualified(ref, qualified_name): - continue - - try: - definitions = ref.goto(follow_imports=True, follow_builtin_imports=False) - except Exception: - continue - - for definition in definitions: - definition_path = definition.module_path - if definition_path is None: - continue - - # Check if it's a valid helper (in project, not in target function) - is_valid = ( - str(definition_path).startswith(str(project_root)) - and not path_belongs_to_site_packages(definition_path) - and definition.full_name - and not belongs_to_function_qualified(definition, qualified_name) - and definition.type == "function" - ) - - if is_valid: - helper_qualified_name = get_qualified_name(definition.module_name, definition.full_name) - # Get source code - try: - helper_source = definition.get_line_code() - except Exception: - helper_source = "" - - helpers.append( - HelperFunction( - name=definition.name, - qualified_name=helper_qualified_name, - file_path=definition_path, - source_code=helper_source, - start_line=definition.line or 1, - end_line=definition.line or 1, - ) - ) - + _dict, sources = get_function_sources_from_jedi( + {function.file_path: {function.qualified_name}}, project_root + ) except Exception as e: logger.warning("Failed to find helpers for %s: %s", function.function_name, e) + return [] - return helpers + return [ + HelperFunction( + name=fs.only_function_name, + qualified_name=fs.qualified_name, + file_path=fs.file_path, + source_code=fs.source_code, + start_line=fs.jedi_definition.line if fs.jedi_definition else 1, + end_line=fs.jedi_definition.line if fs.jedi_definition else 1, + ) + for fs in sources + ] def find_references( self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 500 From 8566cf051025aa4c450f080b94133b12ac031bfd Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 16 Feb 2026 15:10:58 -0500 Subject: [PATCH 12/39] fix: update mypy allowlist paths and fix BaseSuite type narrowing Update stale context/ paths in mypy_allowlist.txt to match the languages/python/context/ move. Add assert to narrow BaseSuite to IndentedBlock in prune_cst for mypy. --- codeflash/languages/python/context/code_context_extractor.py | 4 +++- mypy_allowlist.txt | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index a28b12ac8..acab7e2fe 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -1180,7 +1180,9 @@ def prune_cst( # Apply docstring removal to class if needed if remove_docstrings and new_class_body: - return node.with_changes(body=remove_docstring_from_body(node.body.with_changes(body=new_class_body))), True + updated_body = node.body.with_changes(body=new_class_body) + assert isinstance(updated_body, cst.IndentedBlock) + return node.with_changes(body=remove_docstring_from_body(updated_body)), True return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True diff --git a/mypy_allowlist.txt b/mypy_allowlist.txt index 6a070b606..e08b14e22 100644 --- a/mypy_allowlist.txt +++ b/mypy_allowlist.txt @@ -6,8 +6,8 @@ codeflash/result/explanation.py codeflash/result/critic.py codeflash/version.py codeflash/optimization/__init__.py -codeflash/context/__init__.py -codeflash/context/code_context_extractor.py +codeflash/languages/python/context/__init__.py +codeflash/languages/python/context/code_context_extractor.py codeflash/discovery/__init__.py codeflash/__init__.py codeflash/models/ExperimentMetadata.py From fadf6d41399d616b352f87ce01fa016a7ae1d525 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 16 Feb 2026 15:18:38 -0500 Subject: [PATCH 13/39] fix: restore progressive fallback for context token limits Re-add graceful degradation when context exceeds token limits instead of raising ValueError immediately. Read-only context falls back to removing docstrings then removing entirely. Testgen context falls back to removing docstrings then removing enrichment before raising. --- .../python/context/code_context_extractor.py | 43 ++++++++++++++++--- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index acab7e2fe..9f904efbc 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -57,18 +57,22 @@ def build_testgen_context( helpers_of_fto_dict: dict[Path, set[FunctionSource]], helpers_of_helpers_dict: dict[Path, set[FunctionSource]], project_root_path: Path, + *, + remove_docstrings: bool = False, + include_enrichment: bool = True, ) -> CodeStringsMarkdown: testgen_context = extract_code_markdown_context_from_files( helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, - remove_docstrings=False, + remove_docstrings=remove_docstrings, code_context_type=CodeContextType.TESTGEN, ) - enrichment = enrich_testgen_context(testgen_context, project_root_path) - if enrichment.code_strings: - testgen_context = CodeStringsMarkdown(code_strings=testgen_context.code_strings + enrichment.code_strings) + if include_enrichment: + enrichment = enrich_testgen_context(testgen_context, project_root_path) + if enrichment.code_strings: + testgen_context = CodeStringsMarkdown(code_strings=testgen_context.code_strings + enrichment.code_strings) return testgen_context @@ -147,10 +151,39 @@ def get_code_optimization_context( ) read_only_context_code = read_only_code_markdown.markdown + # Progressive fallback for read-only context token limits + read_only_tokens = encoded_tokens_len(read_only_context_code) + if final_read_writable_tokens + read_only_tokens > optim_token_limit: + logger.debug("Code context has exceeded token limit, removing docstrings from read-only code") + read_only_code_no_docstrings = extract_code_markdown_context_from_files( + helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True + ) + read_only_context_code = read_only_code_no_docstrings.markdown + if final_read_writable_tokens + encoded_tokens_len(read_only_context_code) > optim_token_limit: + logger.debug("Code context has exceeded token limit, removing read-only code") + read_only_context_code = "" + + # Progressive fallback for testgen context token limits testgen_context = build_testgen_context(helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path) if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: - raise ValueError(TESTGEN_LIMIT_ERROR) + logger.debug("Testgen context exceeded token limit, removing docstrings") + testgen_context = build_testgen_context( + helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True + ) + + if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: + logger.debug("Testgen context still exceeded token limit, removing enrichment") + testgen_context = build_testgen_context( + helpers_of_fto_dict, + helpers_of_helpers_dict, + project_root_path, + remove_docstrings=True, + include_enrichment=False, + ) + + if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit: + raise ValueError(TESTGEN_LIMIT_ERROR) code_hash_context = hashing_code_context.markdown code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest() From bace6112a46aa5cfaf23c3b82e77483c18773d6e Mon Sep 17 00:00:00 2001 From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com> Date: Mon, 16 Feb 2026 20:49:37 +0000 Subject: [PATCH 14/39] Optimize _parse_and_collect_imports MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The optimization achieves a **68% runtime improvement** (23.5ms → 14.0ms) by replacing the expensive `ast.walk()` traversal with a targeted recursive collection strategy. **Key Performance Improvement:** The original code uses `ast.walk(tree)` which visits **every single node** in the AST tree (12,947 hits shown in line profiler), consuming 71.7% of total runtime. This includes unnecessary nodes like expressions, literals, and operators that can never contain `ImportFrom` statements. The optimized version implements a custom `collect_imports()` function that: 1. **Only traverses module body and control flow structures** where imports can legally appear (function/class definitions, if/while/for blocks, try/except) 2. **Skips irrelevant AST nodes** like expressions, literals, and operators entirely 3. **Recursively processes nested bodies** (body, orelse, finalbody, handlers) in a depth-first manner **Why This Works:** In Python, `from X import Y` statements can only appear: - At module level - Inside function/class definitions - Within control flow blocks (if/while/for/try) By checking `isinstance()` for only these container node types and recursively descending into their body attributes, we avoid traversing the entire AST subtree for each construct. This dramatically reduces the number of nodes visited while maintaining correctness. **Test Case Performance:** The optimization excels across all scales: - **Small imports** (single statements): 60-77% faster - **Large import lists** (100-500 items): 74-104% faster - **Many code blocks** (500-1000 lines): 70-77% faster - **Mixed code/imports** at scale: 70% faster The performance gain is particularly pronounced when the AST contains large amounts of non-import code (functions, classes, expressions), as shown by the `test_mixed_imports_and_code_large_scale` case improving from 9.31ms to 5.45ms (70.8% faster). **Impact on Workloads:** Given the function_references show this is used in code context extraction benchmarks, this optimization will significantly speed up any workflow that analyzes Python imports from large codebases or performs repeated import analysis during development workflows. --- .../python/context/code_context_extractor.py | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index 9f904efbc..173dc8021 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -553,12 +553,31 @@ def _parse_and_collect_imports(code_context: CodeStringsMarkdown) -> tuple[ast.M except SyntaxError: return None imported_names: dict[str, str] = {} - for node in ast.walk(tree): - if isinstance(node, ast.ImportFrom) and node.module: - for alias in node.names: - if alias.name != "*": - imported_name = alias.asname if alias.asname else alias.name - imported_names[imported_name] = node.module + + # Directly iterate over the module body and nested structures instead of ast.walk + # This avoids traversing every single node in the tree + def collect_imports(nodes): + for node in nodes: + if isinstance(node, ast.ImportFrom) and node.module: + for alias in node.names: + if alias.name != "*": + imported_name = alias.asname if alias.asname else alias.name + imported_names[imported_name] = node.module + # Recursively check nested structures (function defs, class defs, if statements, etc.) + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, + ast.If, ast.For, ast.AsyncFor, ast.While, ast.With, + ast.AsyncWith, ast.Try, ast.ExceptHandler)): + if hasattr(node, 'body'): + collect_imports(node.body) + if hasattr(node, 'orelse'): + collect_imports(node.orelse) + if hasattr(node, 'finalbody'): + collect_imports(node.finalbody) + if hasattr(node, 'handlers'): + for handler in node.handlers: + collect_imports(handler.body) + + collect_imports(tree.body) return tree, imported_names From 73e71d00e7bc7d39260a2b0577056617c9810a01 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Mon, 16 Feb 2026 20:51:51 +0000 Subject: [PATCH 15/39] style: auto-fix linting issues --- .../python/context/code_context_extractor.py | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index 173dc8021..f5d4d4a43 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -553,7 +553,7 @@ def _parse_and_collect_imports(code_context: CodeStringsMarkdown) -> tuple[ast.M except SyntaxError: return None imported_names: dict[str, str] = {} - + # Directly iterate over the module body and nested structures instead of ast.walk # This avoids traversing every single node in the tree def collect_imports(nodes): @@ -564,19 +564,32 @@ def collect_imports(nodes): imported_name = alias.asname if alias.asname else alias.name imported_names[imported_name] = node.module # Recursively check nested structures (function defs, class defs, if statements, etc.) - elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, - ast.If, ast.For, ast.AsyncFor, ast.While, ast.With, - ast.AsyncWith, ast.Try, ast.ExceptHandler)): - if hasattr(node, 'body'): + elif isinstance( + node, + ( + ast.FunctionDef, + ast.AsyncFunctionDef, + ast.ClassDef, + ast.If, + ast.For, + ast.AsyncFor, + ast.While, + ast.With, + ast.AsyncWith, + ast.Try, + ast.ExceptHandler, + ), + ): + if hasattr(node, "body"): collect_imports(node.body) - if hasattr(node, 'orelse'): + if hasattr(node, "orelse"): collect_imports(node.orelse) - if hasattr(node, 'finalbody'): + if hasattr(node, "finalbody"): collect_imports(node.finalbody) - if hasattr(node, 'handlers'): + if hasattr(node, "handlers"): for handler in node.handlers: collect_imports(handler.body) - + collect_imports(tree.body) return tree, imported_names From 29c0a66a9bb490ce5f80155cc1e9abcb49f1b81b Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Mon, 16 Feb 2026 20:52:37 +0000 Subject: [PATCH 16/39] fix: resolve mypy type errors in collect_imports --- codeflash/languages/python/context/code_context_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index f5d4d4a43..79d9c2959 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -556,7 +556,7 @@ def _parse_and_collect_imports(code_context: CodeStringsMarkdown) -> tuple[ast.M # Directly iterate over the module body and nested structures instead of ast.walk # This avoids traversing every single node in the tree - def collect_imports(nodes): + def collect_imports(nodes: list[ast.stmt]) -> None: for node in nodes: if isinstance(node, ast.ImportFrom) and node.module: for alias in node.names: From 4ff98658c2b0f061fb45c05f8da801d9ffb57f8a Mon Sep 17 00:00:00 2001 From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com> Date: Mon, 16 Feb 2026 20:53:44 +0000 Subject: [PATCH 17/39] Optimize collect_existing_class_names MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The optimized code achieves a **350% speedup** (2.36ms → 523μs) by replacing the generic `ast.walk()` traversal with a targeted stack-based iteration that only visits nodes where class definitions can appear. **Key Performance Improvement:** The original implementation uses `ast.walk(tree)`, which performs an exhaustive depth-first traversal of **every single node** in the AST—including expressions, literals, operators, and other leaf nodes that can never contain class definitions. For a typical Python module, this means checking thousands of irrelevant nodes. The optimized version uses a stack-based approach that only descends into structural nodes (ClassDef, FunctionDef, If, For, While, With, Try blocks) where classes can actually be defined. This dramatically reduces the number of nodes visited and `isinstance()` checks performed. **Why This Matters:** From the test results, we see consistent 200-700% speedups across all scenarios: - Empty modules: 579% faster (5.37μs → 791ns) - minimal traversal overhead - Simple cases: 200-400% faster - fewer nodes to check - Complex nested structures: 405% faster (37.2μs → 7.37μs) - targeted descent pays off - Large modules (500 classes): 280% faster (869μs → 228μs) - scales better - Mixed workloads: 558% faster (799μs → 121μs) - avoids non-class nodes **Impact on Workloads:** Based on the function references showing this is called from `build_testgen_context`, this optimization benefits test generation workflows that analyze Python code structure. Since class extraction is likely performed repeatedly during code analysis, the 4x speedup directly improves overall test generation throughput. The optimization is particularly effective for large codebases with many classes and complex nesting patterns, as demonstrated by the benchmark results. --- .../python/context/code_context_extractor.py | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index 9f904efbc..b710f044d 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -563,7 +563,28 @@ def _parse_and_collect_imports(code_context: CodeStringsMarkdown) -> tuple[ast.M def collect_existing_class_names(tree: ast.Module) -> set[str]: - return {node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)} + class_names = set() + stack = list(tree.body) + + while stack: + node = stack.pop() + if isinstance(node, ast.ClassDef): + class_names.add(node.name) + stack.extend(node.body) + elif isinstance(node, ast.FunctionDef): + stack.extend(node.body) + elif isinstance(node, (ast.If, ast.For, ast.While, ast.With)): + stack.extend(node.body) + if hasattr(node, 'orelse'): + stack.extend(node.orelse) + elif isinstance(node, ast.Try): + stack.extend(node.body) + stack.extend(node.orelse) + stack.extend(node.finalbody) + for handler in node.handlers: + stack.extend(handler.body) + + return class_names def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown: From 69d32681f786ae3b7c6fb615dfb34d98cbfbe91c Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Mon, 16 Feb 2026 20:55:52 +0000 Subject: [PATCH 18/39] style: auto-fix linting issues --- .../languages/python/context/code_context_extractor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index b710f044d..a07ba918d 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -565,7 +565,7 @@ def _parse_and_collect_imports(code_context: CodeStringsMarkdown) -> tuple[ast.M def collect_existing_class_names(tree: ast.Module) -> set[str]: class_names = set() stack = list(tree.body) - + while stack: node = stack.pop() if isinstance(node, ast.ClassDef): @@ -575,7 +575,7 @@ def collect_existing_class_names(tree: ast.Module) -> set[str]: stack.extend(node.body) elif isinstance(node, (ast.If, ast.For, ast.While, ast.With)): stack.extend(node.body) - if hasattr(node, 'orelse'): + if hasattr(node, "orelse"): stack.extend(node.orelse) elif isinstance(node, ast.Try): stack.extend(node.body) @@ -583,7 +583,7 @@ def collect_existing_class_names(tree: ast.Module) -> set[str]: stack.extend(node.finalbody) for handler in node.handlers: stack.extend(handler.body) - + return class_names From ea14b2f5484a772f17b088970af5a92661dd9291 Mon Sep 17 00:00:00 2001 From: Kevin Turcios <106575910+KRRT7@users.noreply.github.com> Date: Mon, 16 Feb 2026 15:59:22 -0500 Subject: [PATCH 19/39] Update codeflash/languages/python/context/code_context_extractor.py Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> --- codeflash/languages/python/context/code_context_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index a07ba918d..3c5f80424 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -571,7 +571,7 @@ def collect_existing_class_names(tree: ast.Module) -> set[str]: if isinstance(node, ast.ClassDef): class_names.add(node.name) stack.extend(node.body) - elif isinstance(node, ast.FunctionDef): + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): stack.extend(node.body) elif isinstance(node, (ast.If, ast.For, ast.While, ast.With)): stack.extend(node.body) From bfa55cb12856c6800e5c965b8299295c7d8b6c4e Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Mon, 16 Feb 2026 21:02:03 +0000 Subject: [PATCH 20/39] fix: handle ast.Match (Python 3.10+) in collect_imports traversal The optimized collect_imports missed match/case statements where imports can legally appear. Add hasattr-guarded handling for ast.Match nodes. Co-authored-by: Kevin Turcios --- codeflash/languages/python/context/code_context_extractor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index 79d9c2959..0116687f9 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -589,6 +589,10 @@ def collect_imports(nodes: list[ast.stmt]) -> None: if hasattr(node, "handlers"): for handler in node.handlers: collect_imports(handler.body) + # Handle match/case statements (Python 3.10+) + elif hasattr(ast, "Match") and isinstance(node, ast.Match): + for case in node.cases: + collect_imports(case.body) collect_imports(tree.body) return tree, imported_names From 707703ca59cc818181ac160dcecc423901741994 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 16 Feb 2026 16:55:01 -0500 Subject: [PATCH 21/39] refactor: deduplicate Python language support code Extract shared helpers and remove dead code across the language support area: - Extract `is_assignment_used()` and move `recurse_sections` to unused_definition_remover.py, replacing duplicated logic in both context files - Extract `function_sources_to_helpers()` in support.py to unify identical HelperFunction construction - Remove dead `get_comment_prefix()` method from protocol and all implementations (comment_prefix property serves all callers) --- codeflash/languages/base.py | 9 -- codeflash/languages/javascript/support.py | 9 -- .../python/context/code_context_extractor.py | 69 +-------- .../context/unused_definition_remover.py | 146 ++++++++++-------- codeflash/languages/python/support.py | 49 +++--- 5 files changed, 106 insertions(+), 176 deletions(-) diff --git a/codeflash/languages/base.py b/codeflash/languages/base.py index 99cefdf46..4253798bc 100644 --- a/codeflash/languages/base.py +++ b/codeflash/languages/base.py @@ -519,15 +519,6 @@ def get_test_file_suffix(self) -> str: """ ... - def get_comment_prefix(self) -> str: - """Get the comment prefix for this language. - - Returns: - Comment prefix (e.g., "//" for JS, "#" for Python). - - """ - ... - def find_test_root(self, project_root: Path) -> Path | None: """Find the test root directory for a project. diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index 20fe29573..724dc066e 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -1805,15 +1805,6 @@ def get_test_file_suffix(self) -> str: """ return ".test.js" - def get_comment_prefix(self) -> str: - """Get the comment prefix for JavaScript. - - Returns: - JavaScript single-line comment prefix. - - """ - return "//" - def find_test_root(self, project_root: Path) -> Path | None: """Find the test root directory for a JavaScript project. diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index c20032e8a..c7078e995 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -20,8 +20,9 @@ from codeflash.languages import Language, is_python from codeflash.languages.python.context.unused_definition_remover import ( collect_top_level_defs_with_usages, - extract_names_from_targets, get_section_names, + is_assignment_used, + recurse_sections, remove_unused_definitions_by_function_names, ) from codeflash.models.models import ( @@ -34,8 +35,6 @@ from codeflash.optimization.function_context import belongs_to_function_qualified if TYPE_CHECKING: - from collections.abc import Callable - from jedi.api.classes import Name from codeflash.languages.base import HelperFunction @@ -1103,50 +1102,6 @@ def _validate_classdef(node: cst.ClassDef, prefix: str) -> tuple[str, cst.Indent return _qualified_name(prefix, node.name.value), node.body -def _recurse_sections( - node: cst.CSTNode, - section_names: list[str], - prune_fn: Callable[[cst.CSTNode], tuple[cst.CSTNode | None, bool]], - keep_non_target_children: bool = False, -) -> tuple[cst.CSTNode | None, bool]: - updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {} - found_any_target = False - for section in section_names: - original_content = getattr(node, section, None) - if isinstance(original_content, (list, tuple)): - new_children = [] - section_found_target = False - for child in original_content: - filtered, found_target = prune_fn(child) - if filtered: - new_children.append(filtered) - section_found_target |= found_target - if keep_non_target_children: - if section_found_target or new_children: - found_any_target |= section_found_target - updates[section] = new_children - elif section_found_target: - found_any_target = True - updates[section] = new_children - elif original_content is not None: - filtered, found_target = prune_fn(original_content) - if keep_non_target_children: - found_any_target |= found_target - if filtered: - updates[section] = filtered - elif found_target: - found_any_target = True - if filtered: - updates[section] = filtered - if keep_non_target_children: - if updates: - return node.with_changes(**updates), found_any_target - return None, False - if not found_any_target: - return None, False - return (node.with_changes(**updates) if updates else node), True - - def prune_cst( node: cst.CSTNode, target_functions: set[str], @@ -1278,19 +1233,9 @@ def prune_cst( # Handle assignments for READ_WRITABLE mode if defs_with_usages is not None: - if isinstance(node, cst.Assign): - for target in node.targets: - names = extract_names_from_targets(target.target) - for name in names: - if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function: - return node, True - return None, False - - if isinstance(node, (cst.AnnAssign, cst.AugAssign)): - names = extract_names_from_targets(node.target) - for name in names: - if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function: - return node, True + if isinstance(node, (cst.Assign, cst.AnnAssign, cst.AugAssign)): + if is_assignment_used(node, defs_with_usages): + return node, True return None, False # For other nodes, recursively process children @@ -1299,7 +1244,7 @@ def prune_cst( return node, False if helpers is not None: - return _recurse_sections( + return recurse_sections( node, section_names, lambda child: prune_cst( @@ -1317,7 +1262,7 @@ def prune_cst( ), keep_non_target_children=True, ) - return _recurse_sections( + return recurse_sections( node, section_names, lambda child: prune_cst( diff --git a/codeflash/languages/python/context/unused_definition_remover.py b/codeflash/languages/python/context/unused_definition_remover.py index f4eec94e8..a016f32d3 100644 --- a/codeflash/languages/python/context/unused_definition_remover.py +++ b/codeflash/languages/python/context/unused_definition_remover.py @@ -15,6 +15,8 @@ from codeflash.models.models import CodeString, CodeStringsMarkdown if TYPE_CHECKING: + from collections.abc import Callable + from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import CodeOptimizationContext, FunctionSource @@ -49,6 +51,73 @@ def extract_names_from_targets(target: cst.CSTNode) -> list[str]: return names +def is_assignment_used( + node: cst.CSTNode, + definitions: dict[str, UsageInfo], + name_prefix: str = "", +) -> bool: + if isinstance(node, cst.Assign): + for target in node.targets: + names = extract_names_from_targets(target.target) + for name in names: + lookup = f"{name_prefix}{name}" if name_prefix else name + if lookup in definitions and definitions[lookup].used_by_qualified_function: + return True + return False + if isinstance(node, (cst.AnnAssign, cst.AugAssign)): + names = extract_names_from_targets(node.target) + for name in names: + lookup = f"{name_prefix}{name}" if name_prefix else name + if lookup in definitions and definitions[lookup].used_by_qualified_function: + return True + return False + return False + + +def recurse_sections( + node: cst.CSTNode, + section_names: list[str], + prune_fn: Callable[[cst.CSTNode], tuple[cst.CSTNode | None, bool]], + keep_non_target_children: bool = False, +) -> tuple[cst.CSTNode | None, bool]: + updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {} + found_any_target = False + for section in section_names: + original_content = getattr(node, section, None) + if isinstance(original_content, (list, tuple)): + new_children = [] + section_found_target = False + for child in original_content: + filtered, found_target = prune_fn(child) + if filtered: + new_children.append(filtered) + section_found_target |= found_target + if keep_non_target_children: + if section_found_target or new_children: + found_any_target |= section_found_target + updates[section] = new_children + elif section_found_target: + found_any_target = True + updates[section] = new_children + elif original_content is not None: + filtered, found_target = prune_fn(original_content) + if keep_non_target_children: + found_any_target |= found_target + if filtered: + updates[section] = filtered + elif found_target: + found_any_target = True + if filtered: + updates[section] = filtered + if keep_non_target_children: + if updates: + return node.with_changes(**updates), found_any_target + return None, False + if not found_any_target: + return None, False + return (node.with_changes(**updates) if updates else node), True + + def collect_top_level_definitions( node: cst.CSTNode, definitions: Optional[dict[str, UsageInfo]] = None ) -> dict[str, UsageInfo]: @@ -423,27 +492,9 @@ def remove_unused_definitions_recursively( elif isinstance(statement, (cst.Assign, cst.AnnAssign, cst.AugAssign)): var_used = False - # Check if any variable in this assignment is used - if isinstance(statement, cst.Assign): - for target in statement.targets: - names = extract_names_from_targets(target.target) - for name in names: - class_var_name = f"{class_name}.{name}" - if ( - class_var_name in definitions - and definitions[class_var_name].used_by_qualified_function - ): - var_used = True - method_or_var_used = True - break - elif isinstance(statement, (cst.AnnAssign, cst.AugAssign)): - names = extract_names_from_targets(statement.target) - for name in names: - class_var_name = f"{class_name}.{name}" - if class_var_name in definitions and definitions[class_var_name].used_by_qualified_function: - var_used = True - method_or_var_used = True - break + if is_assignment_used(statement, definitions, name_prefix=f"{class_name}."): + var_used = True + method_or_var_used = True if var_used or class_has_dependencies: new_statements.append(statement) @@ -459,56 +510,21 @@ def remove_unused_definitions_recursively( return node, method_or_var_used or class_has_dependencies - # Handle assignments (Assign and AnnAssign) - if isinstance(node, cst.Assign): - for target in node.targets: - names = extract_names_from_targets(target.target) - for name in names: - if name in definitions and definitions[name].used_by_qualified_function: - return node, True - return None, False - - if isinstance(node, (cst.AnnAssign, cst.AugAssign)): - names = extract_names_from_targets(node.target) - for name in names: - if name in definitions and definitions[name].used_by_qualified_function: - return node, True + # Handle assignments (Assign, AnnAssign, AugAssign) + if isinstance(node, (cst.Assign, cst.AnnAssign, cst.AugAssign)): + if is_assignment_used(node, definitions): + return node, True return None, False # For other nodes, recursively process children section_names = get_section_names(node) if not section_names: return node, False - - updates = {} - found_used = False - - for section in section_names: - original_content = getattr(node, section, None) - if isinstance(original_content, (list, tuple)): - new_children = [] - section_found_used = False - - for child in original_content: - filtered, used = remove_unused_definitions_recursively(child, definitions) - if filtered: - new_children.append(filtered) - section_found_used |= used - - if new_children or section_found_used: - found_used |= section_found_used - updates[section] = new_children - elif original_content is not None: - filtered, used = remove_unused_definitions_recursively(original_content, definitions) - found_used |= used - if filtered: - updates[section] = filtered - if not found_used: - return None, False - if updates: - return node.with_changes(**updates), found_used - - return node, False + return recurse_sections( + node, + section_names, + lambda child: remove_unused_definitions_recursively(child, definitions), + ) def collect_top_level_defs_with_usages( diff --git a/codeflash/languages/python/support.py b/codeflash/languages/python/support.py index 4b79b8d91..51624adf0 100644 --- a/codeflash/languages/python/support.py +++ b/codeflash/languages/python/support.py @@ -21,9 +21,25 @@ if TYPE_CHECKING: from collections.abc import Sequence + from codeflash.models.models import FunctionSource + logger = logging.getLogger(__name__) +def function_sources_to_helpers(sources: list[FunctionSource]) -> list[HelperFunction]: + return [ + HelperFunction( + name=fs.only_function_name, + qualified_name=fs.qualified_name, + file_path=fs.file_path, + source_code=fs.source_code, + start_line=fs.jedi_definition.line if fs.jedi_definition else 1, + end_line=fs.jedi_definition.line if fs.jedi_definition else 1, + ) + for fs in sources + ] + + @register_language class PythonSupport: """Python language support implementation. @@ -180,17 +196,7 @@ def extract_code_context(self, function: FunctionToOptimize, project_root: Path, logger.warning("Failed to extract code context for %s: %s", function.function_name, e) return CodeContext(target_code="", target_file=function.file_path, language=Language.PYTHON) - helpers = [ - HelperFunction( - name=fs.only_function_name, - qualified_name=fs.qualified_name, - file_path=fs.file_path, - source_code=fs.source_code, - start_line=fs.jedi_definition.line if fs.jedi_definition else 1, - end_line=fs.jedi_definition.line if fs.jedi_definition else 1, - ) - for fs in result.helper_functions - ] + helpers = function_sources_to_helpers(result.helper_functions) return CodeContext( target_code=result.read_writable_code.markdown, @@ -213,17 +219,7 @@ def find_helper_functions(self, function: FunctionToOptimize, project_root: Path logger.warning("Failed to find helpers for %s: %s", function.function_name, e) return [] - return [ - HelperFunction( - name=fs.only_function_name, - qualified_name=fs.qualified_name, - file_path=fs.file_path, - source_code=fs.source_code, - start_line=fs.jedi_definition.line if fs.jedi_definition else 1, - end_line=fs.jedi_definition.line if fs.jedi_definition else 1, - ) - for fs in sources - ] + return function_sources_to_helpers(sources) def find_references( self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 500 @@ -660,15 +656,6 @@ def get_test_file_suffix(self) -> str: """ return ".py" - def get_comment_prefix(self) -> str: - """Get the comment prefix for Python. - - Returns: - Python single-line comment prefix. - - """ - return "#" - def find_test_root(self, project_root: Path) -> Path | None: """Find the test root directory for a Python project. From 633acce4366c38ad2104c45d0db25a075d5f4eed Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Mon, 16 Feb 2026 21:58:47 +0000 Subject: [PATCH 22/39] style: auto-fix linting issues --- .../python/context/unused_definition_remover.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/codeflash/languages/python/context/unused_definition_remover.py b/codeflash/languages/python/context/unused_definition_remover.py index a016f32d3..38b58f63e 100644 --- a/codeflash/languages/python/context/unused_definition_remover.py +++ b/codeflash/languages/python/context/unused_definition_remover.py @@ -51,11 +51,7 @@ def extract_names_from_targets(target: cst.CSTNode) -> list[str]: return names -def is_assignment_used( - node: cst.CSTNode, - definitions: dict[str, UsageInfo], - name_prefix: str = "", -) -> bool: +def is_assignment_used(node: cst.CSTNode, definitions: dict[str, UsageInfo], name_prefix: str = "") -> bool: if isinstance(node, cst.Assign): for target in node.targets: names = extract_names_from_targets(target.target) @@ -521,9 +517,7 @@ def remove_unused_definitions_recursively( if not section_names: return node, False return recurse_sections( - node, - section_names, - lambda child: remove_unused_definitions_recursively(child, definitions), + node, section_names, lambda child: remove_unused_definitions_recursively(child, definitions) ) From fa452f2f31537fb20c69dd21846bf1b5b4a3a343 Mon Sep 17 00:00:00 2001 From: KRRT7 Date: Tue, 17 Feb 2026 05:54:21 +0000 Subject: [PATCH 23/39] fix: update license format to use license-files Replace deprecated license table format with modern license-files array in both main package and codeflash-benchmark subpackage. This resolves the setuptools deprecation warning about TOML table license format. Changes: - Use license-files = ["LICENSE"] instead of license = {text = "BSL-1.1"} - Add LICENSE file to root directory - Add LICENSE and README.md to codeflash-benchmark/ --- LICENSE | 98 ++++ codeflash-benchmark/LICENSE | 98 ++++ codeflash-benchmark/README.md | 15 + codeflash-benchmark/pyproject.toml | 64 +-- pyproject.toml | 716 ++++++++++++++--------------- 5 files changed, 601 insertions(+), 390 deletions(-) create mode 100644 LICENSE create mode 100644 codeflash-benchmark/LICENSE create mode 100644 codeflash-benchmark/README.md diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..6d6a48b5f --- /dev/null +++ b/LICENSE @@ -0,0 +1,98 @@ +Business Source License 1.1 + +Parameters + +Licensor: CodeFlash Inc. +Licensed Work: Codeflash Client version 0.20.x + The Licensed Work is (c) 2024 CodeFlash Inc. + +Additional Use Grant: None. Production use of the Licensed Work is only permitted + if you have entered into a separate written agreement + with CodeFlash Inc. for production use in connection + with a subscription to CodeFlash's Code Optimization + Platform. Please visit codeflash.ai for further + information. + +Change Date: 2030-01-26 + +Change License: MIT + +Notice + +The Business Source License (this document, or the “License”) is not an Open +Source license. However, the Licensed Work will eventually be made available +under an Open Source License, as stated in this License. + +License text copyright (c) 2017 MariaDB Corporation Ab, All Rights Reserved. +“Business Source License” is a trademark of MariaDB Corporation Ab. + +----------------------------------------------------------------------------- + +Business Source License 1.1 + +Terms + +The Licensor hereby grants you the right to copy, modify, create derivative +works, redistribute, and make non-production use of the Licensed Work. The +Licensor may make an Additional Use Grant, above, permitting limited +production use. + +Effective on the Change Date, or the fourth anniversary of the first publicly +available distribution of a specific version of the Licensed Work under this +License, whichever comes first, the Licensor hereby grants you rights under +the terms of the Change License, and the rights granted in the paragraph +above terminate. + +If your use of the Licensed Work does not comply with the requirements +currently in effect as described in this License, you must purchase a +commercial license from the Licensor, its affiliated entities, or authorized +resellers, or you must refrain from using the Licensed Work. + +All copies of the original and modified Licensed Work, and derivative works +of the Licensed Work, are subject to this License. This License applies +separately for each version of the Licensed Work and the Change Date may vary +for each version of the Licensed Work released by Licensor. + +You must conspicuously display this License on each original or modified copy +of the Licensed Work. If you receive the Licensed Work in original or +modified form from a third party, the terms and conditions set forth in this +License apply to your use of that work. + +Any use of the Licensed Work in violation of this License will automatically +terminate your rights under this License for the current and all other +versions of the Licensed Work. + +This License does not grant you any right in any trademark or logo of +Licensor or its affiliates (provided that you may use a trademark or logo of +Licensor as expressly required by this License). + +TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON +AN “AS IS” BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS, +EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND +TITLE. + +MariaDB hereby grants you permission to use this License’s text to license +your works, and to refer to it using the trademark “Business Source License”, +as long as you comply with the Covenants of Licensor below. + +Covenants of Licensor + +In consideration of the right to use this License’s text and the “Business +Source License” name and trademark, Licensor covenants to MariaDB, and to all +other recipients of the licensed work to be provided by Licensor: + +1. To specify as the Change License the GPL Version 2.0 or any later version, + or a license that is compatible with GPL Version 2.0 or a later version, + where “compatible” means that software provided under the Change License can + be included in a program with software provided under GPL Version 2.0 or a + later version. Licensor may specify additional Change Licenses without + limitation. + +2. To either: (a) specify an additional grant of rights to use that does not + impose any additional restriction on the right granted in this License, as + the Additional Use Grant; or (b) insert the text “None”. + +3. To specify a Change Date. + +4. Not to modify this License in any other way. \ No newline at end of file diff --git a/codeflash-benchmark/LICENSE b/codeflash-benchmark/LICENSE new file mode 100644 index 000000000..6d6a48b5f --- /dev/null +++ b/codeflash-benchmark/LICENSE @@ -0,0 +1,98 @@ +Business Source License 1.1 + +Parameters + +Licensor: CodeFlash Inc. +Licensed Work: Codeflash Client version 0.20.x + The Licensed Work is (c) 2024 CodeFlash Inc. + +Additional Use Grant: None. Production use of the Licensed Work is only permitted + if you have entered into a separate written agreement + with CodeFlash Inc. for production use in connection + with a subscription to CodeFlash's Code Optimization + Platform. Please visit codeflash.ai for further + information. + +Change Date: 2030-01-26 + +Change License: MIT + +Notice + +The Business Source License (this document, or the “License”) is not an Open +Source license. However, the Licensed Work will eventually be made available +under an Open Source License, as stated in this License. + +License text copyright (c) 2017 MariaDB Corporation Ab, All Rights Reserved. +“Business Source License” is a trademark of MariaDB Corporation Ab. + +----------------------------------------------------------------------------- + +Business Source License 1.1 + +Terms + +The Licensor hereby grants you the right to copy, modify, create derivative +works, redistribute, and make non-production use of the Licensed Work. The +Licensor may make an Additional Use Grant, above, permitting limited +production use. + +Effective on the Change Date, or the fourth anniversary of the first publicly +available distribution of a specific version of the Licensed Work under this +License, whichever comes first, the Licensor hereby grants you rights under +the terms of the Change License, and the rights granted in the paragraph +above terminate. + +If your use of the Licensed Work does not comply with the requirements +currently in effect as described in this License, you must purchase a +commercial license from the Licensor, its affiliated entities, or authorized +resellers, or you must refrain from using the Licensed Work. + +All copies of the original and modified Licensed Work, and derivative works +of the Licensed Work, are subject to this License. This License applies +separately for each version of the Licensed Work and the Change Date may vary +for each version of the Licensed Work released by Licensor. + +You must conspicuously display this License on each original or modified copy +of the Licensed Work. If you receive the Licensed Work in original or +modified form from a third party, the terms and conditions set forth in this +License apply to your use of that work. + +Any use of the Licensed Work in violation of this License will automatically +terminate your rights under this License for the current and all other +versions of the Licensed Work. + +This License does not grant you any right in any trademark or logo of +Licensor or its affiliates (provided that you may use a trademark or logo of +Licensor as expressly required by this License). + +TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON +AN “AS IS” BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS, +EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND +TITLE. + +MariaDB hereby grants you permission to use this License’s text to license +your works, and to refer to it using the trademark “Business Source License”, +as long as you comply with the Covenants of Licensor below. + +Covenants of Licensor + +In consideration of the right to use this License’s text and the “Business +Source License” name and trademark, Licensor covenants to MariaDB, and to all +other recipients of the licensed work to be provided by Licensor: + +1. To specify as the Change License the GPL Version 2.0 or any later version, + or a license that is compatible with GPL Version 2.0 or a later version, + where “compatible” means that software provided under the Change License can + be included in a program with software provided under GPL Version 2.0 or a + later version. Licensor may specify additional Change Licenses without + limitation. + +2. To either: (a) specify an additional grant of rights to use that does not + impose any additional restriction on the right granted in this License, as + the Additional Use Grant; or (b) insert the text “None”. + +3. To specify a Change Date. + +4. Not to modify this License in any other way. \ No newline at end of file diff --git a/codeflash-benchmark/README.md b/codeflash-benchmark/README.md new file mode 100644 index 000000000..91d79ae0d --- /dev/null +++ b/codeflash-benchmark/README.md @@ -0,0 +1,15 @@ +# CodeFlash Benchmark + +A pytest benchmarking plugin for [CodeFlash](https://codeflash.ai) - automatic code performance optimization. + +## Installation + +```bash +pip install codeflash-benchmark +``` + +## Usage + +This plugin provides benchmarking capabilities for pytest tests used by CodeFlash's optimization pipeline. + +For more information, visit [codeflash.ai](https://codeflash.ai). diff --git a/codeflash-benchmark/pyproject.toml b/codeflash-benchmark/pyproject.toml index f068f7367..bc5e9040d 100644 --- a/codeflash-benchmark/pyproject.toml +++ b/codeflash-benchmark/pyproject.toml @@ -1,32 +1,32 @@ -[project] -name = "codeflash-benchmark" -version = "0.2.0" -description = "Pytest benchmarking plugin for codeflash.ai - automatic code performance optimization" -authors = [{ name = "CodeFlash Inc.", email = "contact@codeflash.ai" }] -requires-python = ">=3.9" -readme = "README.md" -license = {text = "BSL-1.1"} -keywords = [ - "codeflash", - "benchmark", - "pytest", - "performance", - "testing", -] -dependencies = [ - "pytest>=7.0.0,!=8.3.4", -] - -[project.urls] -Homepage = "https://codeflash.ai" -Repository = "https://github.com/codeflash-ai/codeflash-benchmark" - -[project.entry-points.pytest11] -codeflash-benchmark = "codeflash_benchmark.plugin" - -[build-system] -requires = ["setuptools>=45", "wheel"] -build-backend = "setuptools.build_meta" - -[tool.setuptools] -packages = ["codeflash_benchmark"] +[project] +name = "codeflash-benchmark" +version = "0.2.0" +description = "Pytest benchmarking plugin for codeflash.ai - automatic code performance optimization" +authors = [{ name = "CodeFlash Inc.", email = "contact@codeflash.ai" }] +requires-python = ">=3.9" +readme = "README.md" +license-files = ["LICENSE"] +keywords = [ + "codeflash", + "benchmark", + "pytest", + "performance", + "testing", +] +dependencies = [ + "pytest>=7.0.0,!=8.3.4", +] + +[project.urls] +Homepage = "https://codeflash.ai" +Repository = "https://github.com/codeflash-ai/codeflash-benchmark" + +[project.entry-points.pytest11] +codeflash-benchmark = "codeflash_benchmark.plugin" + +[build-system] +requires = ["setuptools>=45", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +packages = ["codeflash_benchmark"] diff --git a/pyproject.toml b/pyproject.toml index 6af1d1435..f996d2a34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,358 +1,358 @@ -[project] -name = "codeflash" -dynamic = ["version"] -description = "Client for codeflash.ai - automatic code performance optimization, powered by AI" -authors = [{ name = "CodeFlash Inc.", email = "contact@codeflash.ai" }] -requires-python = ">=3.9" -readme = "README.md" -license = {text = "BSL-1.1"} -keywords = [ - "codeflash", - "performance", - "optimization", - "ai", - "code", - "machine learning", - "LLM", -] -dependencies = [ - "unidiff>=0.7.4", - "pytest>=7.0.0", - "gitpython>=3.1.31", - "libcst>=1.0.1", - "jedi>=0.19.1", - # Tree-sitter for multi-language support - "tree-sitter>=0.23.0", - "tree-sitter-javascript>=0.23.0", - "tree-sitter-typescript>=0.23.0", - "pytest-timeout>=2.1.0", - "tomlkit>=0.11.7", - "junitparser>=3.1.0", - "pydantic>=1.10.1", - "humanize>=4.0.0", - "posthog>=3.0.0", - "click>=8.1.0", - "inquirer>=3.0.0", - "sentry-sdk>=1.40.6,<3.0.0", - "parameterized>=0.9.0", - "isort>=5.11.0", - "dill>=0.3.8", - "rich>=13.8.1", - "lxml>=5.3.0", - "crosshair-tool>=0.0.78", - "coverage>=7.6.4", - "line_profiler>=4.2.0", - "platformdirs>=4.3.7", - "pygls>=2.0.0,<3.0.0", - "codeflash-benchmark", - "filelock", - "pytest-asyncio>=0.18.0", -] - -[project.urls] -Homepage = "https://codeflash.ai" - -[project.scripts] -codeflash = "codeflash.main:main" - -[project.optional-dependencies] - -[dependency-groups] -dev = [ - "ipython>=8.12.0", - "mypy>=1.13", - "ruff>=0.7.0", - "lxml-stubs>=0.5.1", - "pandas-stubs>=2.2.2.240807, <2.2.3.241009", - "types-Pygments>=2.18.0.20240506", - "types-colorama>=0.4.15.20240311", - "types-decorator>=5.1.8.20240310", - "types-jsonschema>=4.23.0.20240813", - "types-requests>=2.32.0.20241016", - "types-six>=1.16.21.20241009", - "types-cffi>=1.16.0.20240331", - "types-openpyxl>=3.1.5.20241020", - "types-regex>=2024.9.11.20240912", - "types-python-dateutil>=2.9.0.20241003", - "types-gevent>=24.11.0.20241230,<25", - "types-greenlet>=3.1.0.20241221,<4", - "types-pexpect>=4.9.0.20241208,<5", - "types-unidiff>=0.7.0.20240505,<0.8", - "prek>=0.2.25", - "ty>=0.0.14", - "uv>=0.9.29", -] -tests = [ - "black>=25.9.0", - "jax>=0.4.30", - "numpy>=2.0.2", - "pandas>=2.3.3", - "pyarrow>=15.0.0", - "pyrsistent>=0.20.0", - "scipy>=1.13.1", - "torch>=2.8.0", - "xarray>=2024.7.0", - "eval_type_backport", - "numba>=0.60.0", - "tensorflow>=2.20.0", -] - -[tool.hatch.build.targets.sdist] -include = ["codeflash"] -exclude = [ - "docs/*", - "experiments/*", - "tests/*", - "*.pyc", - "__pycache__", - "*.pyo", - "*.pyd", - "*.so", - "*.dylib", - "*.dll", - "*.exe", - "*.log", - "*.tmp", - ".env", - ".env.*", - "**/.env", - "**/.env.*", - ".env.example", - "*.pem", - "*.key", - "secrets.*", - "config.yaml", - "config.json", - ".git", - ".gitignore", - ".gitattributes", - ".github", - "Dockerfile", - "docker-compose.yml", - "*.md", - "*.txt", - "*.csv", - "*.db", - "*.sqlite3", - "*.pdf", - "*.docx", - "*.xlsx", - "*.pptx", - "*.iml", - ".idea", - ".vscode", - ".DS_Store", - "Thumbs.db", - "venv", - "env", -] - -[tool.hatch.build.targets.wheel] -exclude = [ - "docs/*", - "experiments/*", - "tests/*", - "*.pyc", - "__pycache__", - "*.pyo", - "*.pyd", - "*.so", - "*.dylib", - "*.dll", - "*.exe", - "*.log", - "*.tmp", - ".env", - ".env.*", - "**/.env", - "**/.env.*", - ".env.example", - "*.pem", - "*.key", - "secrets.*", - "config.yaml", - "config.json", - ".git", - ".gitignore", - ".gitattributes", - ".github", - "Dockerfile", - "docker-compose.yml", - "*.md", - "*.txt", - "*.csv", - "*.db", - "*.sqlite3", - "*.pdf", - "*.docx", - "*.xlsx", - "*.pptx", - "*.iml", - ".idea", - ".vscode", - ".DS_Store", - "Thumbs.db", - "venv", - "env", -] - -[tool.mypy] -show_error_code_links = true -pretty = true -show_absolute_path = true -show_error_context = true -show_error_end = true -strict = true -warn_unreachable = true -install_types = true -plugins = ["pydantic.mypy"] - -[[tool.mypy.overrides]] -module = ["jedi", "jedi.api.classes", "inquirer", "inquirer.themes", "numba"] -ignore_missing_imports = true - -[tool.pydantic-mypy] -init_forbid_extra = true -init_typed = true -warn_required_dynamic_aliases = true - -[tool.ruff] -target-version = "py39" -line-length = 120 -fix = true -show-fixes = true -extend-exclude = ["code_to_optimize/", "pie_test_set/", "tests/", "experiments/"] - -[tool.ruff.lint] -select = ["ALL"] -ignore = [ - "N802", - "C901", - "D100", - "D101", - "D102", - "D103", - "D105", - "D107", - "D203", # incorrect-blank-line-before-class (incompatible with D211) - "D213", # multi-line-summary-second-line (incompatible with D212) - "S101", - "S603", - "S607", - "COM812", - "FIX002", - "PLR0912", - "PLR0913", - "PLR0915", - "TD002", - "TD003", - "TD004", - "PLR2004", - "UP007", # remove once we drop 3.9 support. - "E501", - "BLE001", - "ERA001", - "TRY003", - "EM101", - "T201", - "PGH004", - "S301", - "D104", - "PERF203", - "LOG015", - "PLC0415", - "UP045", - "TD007", - "D417", - "D401", - "S110", # try-except-pass - we do this a lot - "ARG002", # Unused method argument - # Added for multi-language branch - "FBT001", # Boolean positional argument - "FBT002", # Boolean default positional argument - "ANN401", # typing.Any disallowed - "ARG001", # Unused function argument (common in abstract/interface methods) - "TRY300", # Consider moving to else block - "FURB110", # if-exp-instead-of-or-operator - we prefer explicit if-else over "or" - "TRY401", # Redundant exception in logging.exception - "PLR0911", # Too many return statements - "PLW0603", # Global statement - "PLW2901", # Loop variable overwritten - "SIM102", # Nested if statements - "SIM103", # Return negated condition - "ANN001", # Missing type annotation - "PLC0206", # Dictionary items - "S314", # XML parsing (acceptable for dev tool) - "S608", # SQL injection (internal use only) - "S112", # try-except-continue - "PERF401", # List comprehension suggestion - "SIM108", # Ternary operator suggestion - "F841", # Unused variable (often intentional) - "ANN202", # Missing return type for private functions - "B009", # getattr-with-constant - needed to avoid mypy [misc] on dunder access -] - -[tool.ruff.lint.flake8-type-checking] -strict = true -runtime-evaluated-base-classes = ["pydantic.BaseModel"] -runtime-evaluated-decorators = ["pydantic.validate_call", "pydantic.dataclasses.dataclass"] - -[tool.ruff.lint.pep8-naming] -classmethod-decorators = [ - # Allow Pydantic's `@validator` decorator to trigger class method treatment. - "pydantic.validator", -] - -[tool.ruff.lint.isort] -split-on-trailing-comma = false - -[tool.ruff.format] -docstring-code-format = true -skip-magic-trailing-comma = true - -[tool.hatch.version] -source = "uv-dynamic-versioning" - -[tool.uv] -workspace = { members = ["codeflash-benchmark"] } - -[tool.uv.sources] -codeflash-benchmark = { workspace = true } - -[tool.uv-dynamic-versioning] -enable = true -style = "pep440" -vcs = "git" - -[tool.hatch.build.hooks.version] -path = "codeflash/version.py" -template = """# These version placeholders will be replaced by uv-dynamic-versioning during build. -__version__ = "{version}" -""" - - -#[tool.hatch.build.hooks.custom] -#path = "codeflash/update_license_version.py" - - -[tool.codeflash] -# All paths are relative to this pyproject.toml's directory. -module-root = "codeflash" -tests-root = "codeflash" -benchmarks-root = "tests/benchmarks" -ignore-paths = [] -formatter-cmds = ["disabled"] - -[tool.pytest.ini_options] -filterwarnings = [ - "ignore::pytest.PytestCollectionWarning", -] -markers = [ - "ci_skip: mark test to skip in CI environment", -] - - -[build-system] -requires = ["hatchling", "uv-dynamic-versioning"] -build-backend = "hatchling.build" - +[project] +name = "codeflash" +dynamic = ["version"] +description = "Client for codeflash.ai - automatic code performance optimization, powered by AI" +authors = [{ name = "CodeFlash Inc.", email = "contact@codeflash.ai" }] +requires-python = ">=3.9" +readme = "README.md" +license-files = ["LICENSE"] +keywords = [ + "codeflash", + "performance", + "optimization", + "ai", + "code", + "machine learning", + "LLM", +] +dependencies = [ + "unidiff>=0.7.4", + "pytest>=7.0.0", + "gitpython>=3.1.31", + "libcst>=1.0.1", + "jedi>=0.19.1", + # Tree-sitter for multi-language support + "tree-sitter>=0.23.0", + "tree-sitter-javascript>=0.23.0", + "tree-sitter-typescript>=0.23.0", + "pytest-timeout>=2.1.0", + "tomlkit>=0.11.7", + "junitparser>=3.1.0", + "pydantic>=1.10.1", + "humanize>=4.0.0", + "posthog>=3.0.0", + "click>=8.1.0", + "inquirer>=3.0.0", + "sentry-sdk>=1.40.6,<3.0.0", + "parameterized>=0.9.0", + "isort>=5.11.0", + "dill>=0.3.8", + "rich>=13.8.1", + "lxml>=5.3.0", + "crosshair-tool>=0.0.78", + "coverage>=7.6.4", + "line_profiler>=4.2.0", + "platformdirs>=4.3.7", + "pygls>=2.0.0,<3.0.0", + "codeflash-benchmark", + "filelock", + "pytest-asyncio>=0.18.0", +] + +[project.urls] +Homepage = "https://codeflash.ai" + +[project.scripts] +codeflash = "codeflash.main:main" + +[project.optional-dependencies] + +[dependency-groups] +dev = [ + "ipython>=8.12.0", + "mypy>=1.13", + "ruff>=0.7.0", + "lxml-stubs>=0.5.1", + "pandas-stubs>=2.2.2.240807, <2.2.3.241009", + "types-Pygments>=2.18.0.20240506", + "types-colorama>=0.4.15.20240311", + "types-decorator>=5.1.8.20240310", + "types-jsonschema>=4.23.0.20240813", + "types-requests>=2.32.0.20241016", + "types-six>=1.16.21.20241009", + "types-cffi>=1.16.0.20240331", + "types-openpyxl>=3.1.5.20241020", + "types-regex>=2024.9.11.20240912", + "types-python-dateutil>=2.9.0.20241003", + "types-gevent>=24.11.0.20241230,<25", + "types-greenlet>=3.1.0.20241221,<4", + "types-pexpect>=4.9.0.20241208,<5", + "types-unidiff>=0.7.0.20240505,<0.8", + "prek>=0.2.25", + "ty>=0.0.14", + "uv>=0.9.29", +] +tests = [ + "black>=25.9.0", + "jax>=0.4.30", + "numpy>=2.0.2", + "pandas>=2.3.3", + "pyarrow>=15.0.0", + "pyrsistent>=0.20.0", + "scipy>=1.13.1", + "torch>=2.8.0", + "xarray>=2024.7.0", + "eval_type_backport", + "numba>=0.60.0", + "tensorflow>=2.20.0", +] + +[tool.hatch.build.targets.sdist] +include = ["codeflash"] +exclude = [ + "docs/*", + "experiments/*", + "tests/*", + "*.pyc", + "__pycache__", + "*.pyo", + "*.pyd", + "*.so", + "*.dylib", + "*.dll", + "*.exe", + "*.log", + "*.tmp", + ".env", + ".env.*", + "**/.env", + "**/.env.*", + ".env.example", + "*.pem", + "*.key", + "secrets.*", + "config.yaml", + "config.json", + ".git", + ".gitignore", + ".gitattributes", + ".github", + "Dockerfile", + "docker-compose.yml", + "*.md", + "*.txt", + "*.csv", + "*.db", + "*.sqlite3", + "*.pdf", + "*.docx", + "*.xlsx", + "*.pptx", + "*.iml", + ".idea", + ".vscode", + ".DS_Store", + "Thumbs.db", + "venv", + "env", +] + +[tool.hatch.build.targets.wheel] +exclude = [ + "docs/*", + "experiments/*", + "tests/*", + "*.pyc", + "__pycache__", + "*.pyo", + "*.pyd", + "*.so", + "*.dylib", + "*.dll", + "*.exe", + "*.log", + "*.tmp", + ".env", + ".env.*", + "**/.env", + "**/.env.*", + ".env.example", + "*.pem", + "*.key", + "secrets.*", + "config.yaml", + "config.json", + ".git", + ".gitignore", + ".gitattributes", + ".github", + "Dockerfile", + "docker-compose.yml", + "*.md", + "*.txt", + "*.csv", + "*.db", + "*.sqlite3", + "*.pdf", + "*.docx", + "*.xlsx", + "*.pptx", + "*.iml", + ".idea", + ".vscode", + ".DS_Store", + "Thumbs.db", + "venv", + "env", +] + +[tool.mypy] +show_error_code_links = true +pretty = true +show_absolute_path = true +show_error_context = true +show_error_end = true +strict = true +warn_unreachable = true +install_types = true +plugins = ["pydantic.mypy"] + +[[tool.mypy.overrides]] +module = ["jedi", "jedi.api.classes", "inquirer", "inquirer.themes", "numba"] +ignore_missing_imports = true + +[tool.pydantic-mypy] +init_forbid_extra = true +init_typed = true +warn_required_dynamic_aliases = true + +[tool.ruff] +target-version = "py39" +line-length = 120 +fix = true +show-fixes = true +extend-exclude = ["code_to_optimize/", "pie_test_set/", "tests/", "experiments/"] + +[tool.ruff.lint] +select = ["ALL"] +ignore = [ + "N802", + "C901", + "D100", + "D101", + "D102", + "D103", + "D105", + "D107", + "D203", # incorrect-blank-line-before-class (incompatible with D211) + "D213", # multi-line-summary-second-line (incompatible with D212) + "S101", + "S603", + "S607", + "COM812", + "FIX002", + "PLR0912", + "PLR0913", + "PLR0915", + "TD002", + "TD003", + "TD004", + "PLR2004", + "UP007", # remove once we drop 3.9 support. + "E501", + "BLE001", + "ERA001", + "TRY003", + "EM101", + "T201", + "PGH004", + "S301", + "D104", + "PERF203", + "LOG015", + "PLC0415", + "UP045", + "TD007", + "D417", + "D401", + "S110", # try-except-pass - we do this a lot + "ARG002", # Unused method argument + # Added for multi-language branch + "FBT001", # Boolean positional argument + "FBT002", # Boolean default positional argument + "ANN401", # typing.Any disallowed + "ARG001", # Unused function argument (common in abstract/interface methods) + "TRY300", # Consider moving to else block + "FURB110", # if-exp-instead-of-or-operator - we prefer explicit if-else over "or" + "TRY401", # Redundant exception in logging.exception + "PLR0911", # Too many return statements + "PLW0603", # Global statement + "PLW2901", # Loop variable overwritten + "SIM102", # Nested if statements + "SIM103", # Return negated condition + "ANN001", # Missing type annotation + "PLC0206", # Dictionary items + "S314", # XML parsing (acceptable for dev tool) + "S608", # SQL injection (internal use only) + "S112", # try-except-continue + "PERF401", # List comprehension suggestion + "SIM108", # Ternary operator suggestion + "F841", # Unused variable (often intentional) + "ANN202", # Missing return type for private functions + "B009", # getattr-with-constant - needed to avoid mypy [misc] on dunder access +] + +[tool.ruff.lint.flake8-type-checking] +strict = true +runtime-evaluated-base-classes = ["pydantic.BaseModel"] +runtime-evaluated-decorators = ["pydantic.validate_call", "pydantic.dataclasses.dataclass"] + +[tool.ruff.lint.pep8-naming] +classmethod-decorators = [ + # Allow Pydantic's `@validator` decorator to trigger class method treatment. + "pydantic.validator", +] + +[tool.ruff.lint.isort] +split-on-trailing-comma = false + +[tool.ruff.format] +docstring-code-format = true +skip-magic-trailing-comma = true + +[tool.hatch.version] +source = "uv-dynamic-versioning" + +[tool.uv] +workspace = { members = ["codeflash-benchmark"] } + +[tool.uv.sources] +codeflash-benchmark = { workspace = true } + +[tool.uv-dynamic-versioning] +enable = true +style = "pep440" +vcs = "git" + +[tool.hatch.build.hooks.version] +path = "codeflash/version.py" +template = """# These version placeholders will be replaced by uv-dynamic-versioning during build. +__version__ = "{version}" +""" + + +#[tool.hatch.build.hooks.custom] +#path = "codeflash/update_license_version.py" + + +[tool.codeflash] +# All paths are relative to this pyproject.toml's directory. +module-root = "codeflash" +tests-root = "codeflash" +benchmarks-root = "tests/benchmarks" +ignore-paths = [] +formatter-cmds = ["disabled"] + +[tool.pytest.ini_options] +filterwarnings = [ + "ignore::pytest.PytestCollectionWarning", +] +markers = [ + "ci_skip: mark test to skip in CI environment", +] + + +[build-system] +requires = ["hatchling", "uv-dynamic-versioning"] +build-backend = "hatchling.build" + From e1a45dd0c81b96ef3ddfc0ec40907e38ea056ac3 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Tue, 17 Feb 2026 19:02:00 +0530 Subject: [PATCH 24/39] chore: switch Claude workflows from Foundry to AWS Bedrock Replace Azure Foundry authentication with AWS Bedrock OIDC in all Claude Code GitHub Actions workflows. Co-Authored-By: Claude Opus 4.6 --- .github/workflows/claude.yml | 26 ++++++++++++------- .github/workflows/duplicate-code-detector.yml | 12 +++++---- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index d691072aa..edb861183 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -42,11 +42,17 @@ jobs: uv venv --seed uv sync + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.AWS_ROLE_TO_ASSUME }} + aws-region: ${{ secrets.AWS_REGION }} + - name: Run Claude Code id: claude uses: anthropics/claude-code-action@v1 with: - use_foundry: "true" + use_bedrock: "true" use_sticky_comment: true allowed_bots: "claude[bot],codeflash-ai[bot]" prompt: | @@ -173,12 +179,9 @@ jobs: 2. For each optimization PR: - Check if CI is passing: `gh pr checks ` - If all checks pass, merge it: `gh pr merge --squash --delete-branch` - claude_args: '--model claude-opus-4-6 --allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh pr checks:*),Bash(gh pr merge:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep,Edit"' + claude_args: '--model us.anthropic.claude-opus-4-6-v1:0 --allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh pr checks:*),Bash(gh pr merge:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep,Edit"' additional_permissions: | actions: read - env: - ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }} - ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }} # @claude mentions (can edit and push) - restricted to maintainers only claude-mention: @@ -240,14 +243,17 @@ jobs: uv venv --seed uv sync + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.AWS_ROLE_TO_ASSUME }} + aws-region: ${{ secrets.AWS_REGION }} + - name: Run Claude Code id: claude uses: anthropics/claude-code-action@v1 with: - use_foundry: "true" - claude_args: '--model claude-opus-4-6 --allowedTools "Read,Edit,Write,Glob,Grep,Bash(git status*),Bash(git diff*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git log*),Bash(git merge*),Bash(git fetch*),Bash(git checkout*),Bash(git branch*),Bash(uv run prek *),Bash(prek *),Bash(uv run ruff *),Bash(uv run pytest *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(gh pr comment*),Bash(gh pr view*),Bash(gh pr diff*),Bash(gh pr merge*),Bash(gh pr close*)"' + use_bedrock: "true" + claude_args: '--model us.anthropic.claude-opus-4-6-v1:0 --allowedTools "Read,Edit,Write,Glob,Grep,Bash(git status*),Bash(git diff*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git log*),Bash(git merge*),Bash(git fetch*),Bash(git checkout*),Bash(git branch*),Bash(uv run prek *),Bash(prek *),Bash(uv run ruff *),Bash(uv run pytest *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(gh pr comment*),Bash(gh pr view*),Bash(gh pr diff*),Bash(gh pr merge*),Bash(gh pr close*)"' additional_permissions: | actions: read - env: - ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }} - ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }} diff --git a/.github/workflows/duplicate-code-detector.yml b/.github/workflows/duplicate-code-detector.yml index ea36bf54d..83896d1ea 100644 --- a/.github/workflows/duplicate-code-detector.yml +++ b/.github/workflows/duplicate-code-detector.yml @@ -42,10 +42,16 @@ jobs: } EOF + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.AWS_ROLE_TO_ASSUME }} + aws-region: ${{ secrets.AWS_REGION }} + - name: Run Claude Code uses: anthropics/claude-code-action@v1 with: - use_foundry: "true" + use_bedrock: "true" use_sticky_comment: true allowed_bots: "claude[bot],codeflash-ai[bot]" claude_args: '--mcp-config /tmp/mcp-config/mcp-servers.json --allowedTools "Read,Glob,Grep,Bash(git diff:*),Bash(git log:*),Bash(git show:*),Bash(wc *),Bash(find *),mcp__serena__*"' @@ -105,10 +111,6 @@ jobs: - Concrete refactoring suggestion If no significant duplication is found, say so briefly. Do not create issues — just comment on the PR. - env: - ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }} - ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }} - - name: Stop Serena if: always() run: docker stop serena && docker rm serena || true From 09c026a7b91cccbc7192bdb3a91a04ae5759391c Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Tue, 17 Feb 2026 20:34:49 +0530 Subject: [PATCH 25/39] fix: use correct Bedrock inference profile ID (no :0 suffix) The cross-region inference profile for Claude Opus 4.6 on Bedrock is `us.anthropic.claude-opus-4-6-v1`, not `us.anthropic.claude-opus-4-6-v1:0`. Co-Authored-By: Claude Opus 4.6 --- .github/workflows/claude.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index edb861183..6b17da886 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -179,7 +179,7 @@ jobs: 2. For each optimization PR: - Check if CI is passing: `gh pr checks ` - If all checks pass, merge it: `gh pr merge --squash --delete-branch` - claude_args: '--model us.anthropic.claude-opus-4-6-v1:0 --allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh pr checks:*),Bash(gh pr merge:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep,Edit"' + claude_args: '--model us.anthropic.claude-opus-4-6-v1 --allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh pr checks:*),Bash(gh pr merge:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep,Edit"' additional_permissions: | actions: read @@ -254,6 +254,6 @@ jobs: uses: anthropics/claude-code-action@v1 with: use_bedrock: "true" - claude_args: '--model us.anthropic.claude-opus-4-6-v1:0 --allowedTools "Read,Edit,Write,Glob,Grep,Bash(git status*),Bash(git diff*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git log*),Bash(git merge*),Bash(git fetch*),Bash(git checkout*),Bash(git branch*),Bash(uv run prek *),Bash(prek *),Bash(uv run ruff *),Bash(uv run pytest *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(gh pr comment*),Bash(gh pr view*),Bash(gh pr diff*),Bash(gh pr merge*),Bash(gh pr close*)"' + claude_args: '--model us.anthropic.claude-opus-4-6-v1 --allowedTools "Read,Edit,Write,Glob,Grep,Bash(git status*),Bash(git diff*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git log*),Bash(git merge*),Bash(git fetch*),Bash(git checkout*),Bash(git branch*),Bash(uv run prek *),Bash(prek *),Bash(uv run ruff *),Bash(uv run pytest *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(gh pr comment*),Bash(gh pr view*),Bash(gh pr diff*),Bash(gh pr merge*),Bash(gh pr close*)"' additional_permissions: | actions: read From d3074096e8b0ed45e9974c1a9c6a55e1947fdc77 Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 17 Feb 2026 23:10:01 +0200 Subject: [PATCH 26/39] fix always execute capture perf for external runner --- codeflash/version.py | 2 +- packages/codeflash/runtime/capture.js | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/codeflash/version.py b/codeflash/version.py index 6d60ab0c2..6225467e3 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ -1,2 +1,2 @@ # These version placeholders will be replaced by uv-dynamic-versioning during build. -__version__ = "0.20.0.post510.dev0+b8932209" +__version__ = "0.20.0" diff --git a/packages/codeflash/runtime/capture.js b/packages/codeflash/runtime/capture.js index 0fdcc5784..0b6180130 100644 --- a/packages/codeflash/runtime/capture.js +++ b/packages/codeflash/runtime/capture.js @@ -710,12 +710,12 @@ function capturePerf(funcName, lineId, fn, ...args) { for (let batchIndex = 0; batchIndex < batchSize; batchIndex++) { // Check shared time limit BEFORE each iteration - if (shouldLoop && checkSharedTimeLimit()) { + if (!hasExternalLoopRunner && shouldLoop && checkSharedTimeLimit()) { break; } // Check if this invocation has already reached stability - if (getPerfStabilityCheck() && sharedPerfState.stableInvocations[invocationKey]) { + if (!hasExternalLoopRunner && getPerfStabilityCheck() && sharedPerfState.stableInvocations[invocationKey]) { break; } @@ -724,7 +724,7 @@ function capturePerf(funcName, lineId, fn, ...args) { // Check if we've exceeded max loops for this invocation const totalIterations = getTotalIterations(invocationKey); - if (totalIterations > getPerfLoopCount()) { + if (!hasExternalLoopRunner && totalIterations > getPerfLoopCount()) { break; } @@ -776,7 +776,7 @@ function capturePerf(funcName, lineId, fn, ...args) { } // Check stability after accumulating enough samples - if (getPerfStabilityCheck() && runtimes.length >= getPerfMinLoops()) { + if (!hasExternalLoopRunner && getPerfStabilityCheck() && runtimes.length >= getPerfMinLoops()) { const window = getStabilityWindow(); if (shouldStopStability(runtimes, window, getPerfMinLoops())) { sharedPerfState.stableInvocations[invocationKey] = true; @@ -785,7 +785,7 @@ function capturePerf(funcName, lineId, fn, ...args) { } // If we had an error, stop looping - if (lastError) { + if (!hasExternalLoopRunner && lastError) { break; } } From 7b2692feab88da24cc2ddc36f9b22a08197fb9e5 Mon Sep 17 00:00:00 2001 From: ali Date: Wed, 18 Feb 2026 03:39:46 +0200 Subject: [PATCH 27/39] fix path mismatch bug --- codeflash/optimization/function_optimizer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 5e3a8a00f..ed7f7f1fe 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -2077,7 +2077,10 @@ def process_review( generated_tests_str = "" code_lang = self.function_to_optimize.language for test in generated_tests.generated_tests: - if map_gen_test_file_to_no_of_tests[test.behavior_file_path] > 0: + if any( + test_file.name == test.behavior_file_path.name and count > 0 + for test_file, count in map_gen_test_file_to_no_of_tests.items() + ): formatted_generated_test = format_generated_code( test.generated_original_test_source, self.args.formatter_cmds ) From 325ec7d7417bb8af3e53fe392ac072b5b555d2d1 Mon Sep 17 00:00:00 2001 From: KRRT7 Date: Wed, 18 Feb 2026 06:30:16 +0000 Subject: [PATCH 28/39] refactor: inline async decorators to remove codeflash import dependency Instead of injecting `from codeflash.code_utils.codeflash_wrap_decorator import ...` into instrumented source files, inject the decorator function definitions directly. This removes the hard dependency on the codeflash package being importable at runtime in the target environment, matching the pattern already used for sync instrumentation. --- .../code_utils/instrument_existing_tests.py | 235 ++++++++++++++++-- tests/test_async_run_and_parse_tests.py | 84 ++----- tests/test_instrument_async_tests.py | 220 ++++++++-------- 3 files changed, 363 insertions(+), 176 deletions(-) diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 4366468d0..8c53c8e01 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -1497,15 +1497,218 @@ def _is_target_decorator(self, decorator_node: cst.Name | cst.Attribute | cst.Ca return False -class AsyncDecoratorImportAdder(cst.CSTTransformer): - """Transformer that adds the import for async decorators.""" +def get_behavior_async_inline_code() -> str: + return """import asyncio +import gc +import os +import sqlite3 +from functools import wraps +from pathlib import Path +from tempfile import TemporaryDirectory + +import dill as pickle + + +def get_run_tmp_file(file_path): + if not hasattr(get_run_tmp_file, "tmpdir"): + get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_") + return Path(get_run_tmp_file.tmpdir.name) / file_path + + +def extract_test_context_from_env(): + test_module = os.environ["CODEFLASH_TEST_MODULE"] + test_class = os.environ.get("CODEFLASH_TEST_CLASS", None) + test_function = os.environ["CODEFLASH_TEST_FUNCTION"] + if test_module and test_function: + return (test_module, test_class if test_class else None, test_function) + raise RuntimeError( + "Test context environment variables not set - ensure tests are run through codeflash test runner" + ) + + +def codeflash_behavior_async(func): + @wraps(func) + async def async_wrapper(*args, **kwargs): + loop = asyncio.get_running_loop() + function_name = func.__name__ + line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"] + loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"]) + test_module_name, test_class_name, test_name = extract_test_context_from_env() + test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}" + if not hasattr(async_wrapper, "index"): + async_wrapper.index = {} + if test_id in async_wrapper.index: + async_wrapper.index[test_id] += 1 + else: + async_wrapper.index[test_id] = 0 + codeflash_test_index = async_wrapper.index[test_id] + invocation_id = f"{line_id}_{codeflash_test_index}" + class_prefix = (test_class_name + ".") if test_class_name else "" + test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}" + print(f"!$######{test_stdout_tag}######$!") + iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0") + db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite")) + codeflash_con = sqlite3.connect(db_path) + codeflash_cur = codeflash_con.cursor() + codeflash_cur.execute( + "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, " + "test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + "runtime INTEGER, return_value BLOB, verification_type TEXT)" + ) + exception = None + counter = loop.time() + gc.disable() + try: + ret = func(*args, **kwargs) + counter = loop.time() + return_value = await ret + codeflash_duration = int((loop.time() - counter) * 1_000_000_000) + except Exception as e: + codeflash_duration = int((loop.time() - counter) * 1_000_000_000) + exception = e + finally: + gc.enable() + print(f"!######{test_stdout_tag}######!") + pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps((args, kwargs, return_value)) + codeflash_cur.execute( + "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + test_module_name, + test_class_name, + test_name, + function_name, + loop_index, + invocation_id, + codeflash_duration, + pickled_return_value, + "function_call", + ), + ) + codeflash_con.commit() + codeflash_con.close() + if exception: + raise exception + return return_value + return async_wrapper +""" + + +def get_performance_async_inline_code() -> str: + return """import asyncio +import gc +import os +from functools import wraps + + +def extract_test_context_from_env(): + test_module = os.environ["CODEFLASH_TEST_MODULE"] + test_class = os.environ.get("CODEFLASH_TEST_CLASS", None) + test_function = os.environ["CODEFLASH_TEST_FUNCTION"] + if test_module and test_function: + return (test_module, test_class if test_class else None, test_function) + raise RuntimeError( + "Test context environment variables not set - ensure tests are run through codeflash test runner" + ) + + +def codeflash_performance_async(func): + @wraps(func) + async def async_wrapper(*args, **kwargs): + loop = asyncio.get_running_loop() + function_name = func.__name__ + line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"] + loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"]) + test_module_name, test_class_name, test_name = extract_test_context_from_env() + test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}" + if not hasattr(async_wrapper, "index"): + async_wrapper.index = {} + if test_id in async_wrapper.index: + async_wrapper.index[test_id] += 1 + else: + async_wrapper.index[test_id] = 0 + codeflash_test_index = async_wrapper.index[test_id] + invocation_id = f"{line_id}_{codeflash_test_index}" + class_prefix = (test_class_name + ".") if test_class_name else "" + test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}" + print(f"!$######{test_stdout_tag}######$!") + exception = None + counter = loop.time() + gc.disable() + try: + ret = func(*args, **kwargs) + counter = loop.time() + return_value = await ret + codeflash_duration = int((loop.time() - counter) * 1_000_000_000) + except Exception as e: + codeflash_duration = int((loop.time() - counter) * 1_000_000_000) + exception = e + finally: + gc.enable() + print(f"!######{test_stdout_tag}:{codeflash_duration}######!") + if exception: + raise exception + return return_value + return async_wrapper +""" + + +def get_concurrency_async_inline_code() -> str: + return """import asyncio +import gc +import os +import time +from functools import wraps + + +def codeflash_concurrency_async(func): + @wraps(func) + async def async_wrapper(*args, **kwargs): + function_name = func.__name__ + concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10")) + test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "") + test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "") + test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "") + loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0") + gc.disable() + try: + seq_start = time.perf_counter_ns() + for _ in range(concurrency_factor): + result = await func(*args, **kwargs) + sequential_time = time.perf_counter_ns() - seq_start + finally: + gc.enable() + gc.disable() + try: + conc_start = time.perf_counter_ns() + tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)] + await asyncio.gather(*tasks) + concurrent_time = time.perf_counter_ns() - conc_start + finally: + gc.enable() + tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}" + print(f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!") + return result + return async_wrapper +""" + + +def get_async_inline_code(mode: TestingMode) -> str: + if mode == TestingMode.BEHAVIOR: + return get_behavior_async_inline_code() + if mode == TestingMode.CONCURRENCY: + return get_concurrency_async_inline_code() + return get_performance_async_inline_code() + + +class AsyncInlineCodeInjector(cst.CSTTransformer): + """Injects async decorator function definitions inline instead of importing from codeflash.""" def __init__(self, mode: TestingMode = TestingMode.BEHAVIOR) -> None: self.mode = mode - self.has_import = False + self.has_inline_definition = False + self.has_old_import = False def _get_decorator_name(self) -> str: - """Get the decorator name based on the testing mode.""" if self.mode == TestingMode.BEHAVIOR: return "codeflash_behavior_async" if self.mode == TestingMode.CONCURRENCY: @@ -1513,7 +1716,6 @@ def _get_decorator_name(self) -> str: return "codeflash_performance_async" def visit_ImportFrom(self, node: cst.ImportFrom) -> None: - # Check if the async decorator import is already present if ( isinstance(node.module, cst.Attribute) and isinstance(node.module.value, cst.Attribute) @@ -1526,21 +1728,18 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: decorator_name = self._get_decorator_name() for import_alias in node.names: if import_alias.name.value == decorator_name: - self.has_import = True + self.has_old_import = True + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + if node.name.value == self._get_decorator_name(): + self.has_inline_definition = True def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: - # If the import is already there, don't add it again - if self.has_import: + if self.has_inline_definition or self.has_old_import: return updated_node - - # Choose import based on mode - decorator_name = self._get_decorator_name() - - # Parse the import statement into a CST node - import_node = cst.parse_statement(f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}") - - # Add the import to the module's body - return updated_node.with_changes(body=[import_node, *list(updated_node.body)]) + inline_code = get_async_inline_code(self.mode) + inline_stmts = cst.parse_module(inline_code).body + return updated_node.with_changes(body=[*inline_stmts, *list(updated_node.body)]) def add_async_decorator_to_function( @@ -1575,7 +1774,7 @@ def add_async_decorator_to_function( # Add the import if decorator was added if decorator_transformer.added_decorator: - import_transformer = AsyncDecoratorImportAdder(mode) + import_transformer = AsyncInlineCodeInjector(mode) module = module.visit(import_transformer) modified_code = sort_imports(code=module.code, float_to_top=True) diff --git a/tests/test_async_run_and_parse_tests.py b/tests/test_async_run_and_parse_tests.py index 1eb667b3f..5750a015f 100644 --- a/tests/test_async_run_and_parse_tests.py +++ b/tests/test_async_run_and_parse_tests.py @@ -9,6 +9,7 @@ from codeflash.code_utils.instrument_existing_tests import ( add_async_decorator_to_function, + get_async_inline_code, inject_profiling_into_existing_test, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -59,12 +60,16 @@ async def test_async_sort(): assert source_success - # Verify the file was modified + # Verify the file was modified with exact expected output instrumented_source = fto_path.read_text("utf-8") - assert ( - '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_behavior_async\n\n\n@codeflash_behavior_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n''' - in instrumented_source + from codeflash.code_utils.formatter import sort_imports + + inline_code = get_async_inline_code(TestingMode.BEHAVIOR) + decorated_original = original_code.replace( + "async def async_sorter", "@codeflash_behavior_async\nasync def async_sorter" ) + expected = sort_imports(code=inline_code + decorated_original, float_to_top=True) + assert instrumented_source.strip() == expected.strip() # Add codeflash capture instrument_codeflash_capture(func, {}, tests_root) @@ -300,10 +305,14 @@ async def test_async_perf(): # Verify the file was modified instrumented_source = fto_path.read_text("utf-8") - assert ( - instrumented_source - == '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_performance_async\n\n\n@codeflash_performance_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n''' + from codeflash.code_utils.formatter import sort_imports + + inline_code = get_async_inline_code(TestingMode.PERFORMANCE) + decorated_original = original_code.replace( + "async def async_sorter", "@codeflash_performance_async\nasync def async_sorter" ) + expected = sort_imports(code=inline_code + decorated_original, float_to_top=True) + assert instrumented_source.strip() == expected.strip() instrument_codeflash_capture(func, {}, tests_root) @@ -411,61 +420,14 @@ async def async_error_function(lst): # Verify the file was modified instrumented_source = fto_path.read_text("utf-8") - expected_instrumented_source = """import asyncio -from typing import List, Union - -from codeflash.code_utils.codeflash_wrap_decorator import \\ - codeflash_behavior_async - - -async def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]: - \"\"\" - Async bubble sort implementation for testing. - \"\"\" - print("codeflash stdout: Async sorting list") - - await asyncio.sleep(0.01) - - n = len(lst) - for i in range(n): - for j in range(0, n - i - 1): - if lst[j] > lst[j + 1]: - lst[j], lst[j + 1] = lst[j + 1], lst[j] - - result = lst.copy() - print(f"result: {result}") - return result - + from codeflash.code_utils.formatter import sort_imports -class AsyncBubbleSorter: - \"\"\"Class with async sorting method for testing.\"\"\" - - async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]: - \"\"\" - Async bubble sort implementation within a class. - \"\"\" - print("codeflash stdout: AsyncBubbleSorter.sorter() called") - - # Add some async delay - await asyncio.sleep(0.005) - - n = len(lst) - for i in range(n): - for j in range(0, n - i - 1): - if lst[j] > lst[j + 1]: - lst[j], lst[j + 1] = lst[j + 1], lst[j] - - result = lst.copy() - return result - - -@codeflash_behavior_async -async def async_error_function(lst): - \"\"\"Async function that raises an error for testing.\"\"\" - await asyncio.sleep(0.001) # Small delay - raise ValueError("Test error") -""" - assert expected_instrumented_source == instrumented_source + inline_code = get_async_inline_code(TestingMode.BEHAVIOR) + decorated_modified = modified_code.replace( + "async def async_error_function", "@codeflash_behavior_async\nasync def async_error_function" + ) + expected = sort_imports(code=inline_code + decorated_modified, float_to_top=True) + assert instrumented_source.strip() == expected.strip() instrument_codeflash_capture(func, {}, tests_root) opt = Optimizer( diff --git a/tests/test_instrument_async_tests.py b/tests/test_instrument_async_tests.py index 29e65ad06..d700d91d5 100644 --- a/tests/test_instrument_async_tests.py +++ b/tests/test_instrument_async_tests.py @@ -7,6 +7,7 @@ from codeflash.code_utils.instrument_existing_tests import ( add_async_decorator_to_function, + get_async_inline_code, inject_profiling_into_existing_test, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -57,20 +58,6 @@ def test_async_decorator_application_behavior_mode(temp_dir): async_function_code = ''' import asyncio -async def async_function(x: int, y: int) -> int: - """Simple async function for testing.""" - await asyncio.sleep(0.01) - return x * y -''' - - expected_decorated_code = ''' -import asyncio - -from codeflash.code_utils.codeflash_wrap_decorator import \\ - codeflash_behavior_async - - -@codeflash_behavior_async async def async_function(x: int, y: int) -> int: """Simple async function for testing.""" await asyncio.sleep(0.01) @@ -86,7 +73,15 @@ async def async_function(x: int, y: int) -> int: assert decorator_added modified_code = test_file.read_text() - assert modified_code.strip() == expected_decorated_code.strip() + from codeflash.code_utils.formatter import sort_imports + + inline_code = get_async_inline_code(TestingMode.BEHAVIOR) + expected = sort_imports( + code=inline_code + "\n@codeflash_behavior_async\nasync def async_function(x: int, y: int) -> int:\n" + ' """Simple async function for testing."""\n await asyncio.sleep(0.01)\n return x * y\n', + float_to_top=True, + ) + assert modified_code.strip() == expected.strip() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -94,20 +89,6 @@ def test_async_decorator_application_performance_mode(temp_dir): async_function_code = ''' import asyncio -async def async_function(x: int, y: int) -> int: - """Simple async function for testing.""" - await asyncio.sleep(0.01) - return x * y -''' - - expected_decorated_code = ''' -import asyncio - -from codeflash.code_utils.codeflash_wrap_decorator import \\ - codeflash_performance_async - - -@codeflash_performance_async async def async_function(x: int, y: int) -> int: """Simple async function for testing.""" await asyncio.sleep(0.01) @@ -123,7 +104,15 @@ async def async_function(x: int, y: int) -> int: assert decorator_added modified_code = test_file.read_text() - assert modified_code.strip() == expected_decorated_code.strip() + from codeflash.code_utils.formatter import sort_imports + + inline_code = get_async_inline_code(TestingMode.PERFORMANCE) + expected = sort_imports( + code=inline_code + "\n@codeflash_performance_async\nasync def async_function(x: int, y: int) -> int:\n" + ' """Simple async function for testing."""\n await asyncio.sleep(0.01)\n return x * y\n', + float_to_top=True, + ) + assert modified_code.strip() == expected.strip() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -132,20 +121,6 @@ def test_async_decorator_application_concurrency_mode(temp_dir): async_function_code = ''' import asyncio -async def async_function(x: int, y: int) -> int: - """Simple async function for testing.""" - await asyncio.sleep(0.01) - return x * y -''' - - expected_decorated_code = ''' -import asyncio - -from codeflash.code_utils.codeflash_wrap_decorator import \\ - codeflash_concurrency_async - - -@codeflash_concurrency_async async def async_function(x: int, y: int) -> int: """Simple async function for testing.""" await asyncio.sleep(0.01) @@ -161,7 +136,15 @@ async def async_function(x: int, y: int) -> int: assert decorator_added modified_code = test_file.read_text() - assert modified_code.strip() == expected_decorated_code.strip() + from codeflash.code_utils.formatter import sort_imports + + inline_code = get_async_inline_code(TestingMode.CONCURRENCY) + expected = sort_imports( + code=inline_code + "\n@codeflash_concurrency_async\nasync def async_function(x: int, y: int) -> int:\n" + ' """Simple async function for testing."""\n await asyncio.sleep(0.01)\n return x * y\n', + float_to_top=True, + ) + assert modified_code.strip() == expected.strip() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -182,27 +165,6 @@ def sync_method(self, a: int, b: int) -> int: return a - b ''' - expected_decorated_code = ''' -import asyncio - -from codeflash.code_utils.codeflash_wrap_decorator import \\ - codeflash_behavior_async - - -class Calculator: - """Test class with async methods.""" - - @codeflash_behavior_async - async def async_method(self, a: int, b: int) -> int: - """Async method in class.""" - await asyncio.sleep(0.005) - return a ** b - - def sync_method(self, a: int, b: int) -> int: - """Sync method in class.""" - return a - b -''' - test_file = temp_dir / "test_async.py" test_file.write_text(async_class_code) @@ -217,11 +179,31 @@ def sync_method(self, a: int, b: int) -> int: assert decorator_added modified_code = test_file.read_text() - assert modified_code.strip() == expected_decorated_code.strip() + from codeflash.code_utils.formatter import sort_imports + + inline_code = get_async_inline_code(TestingMode.BEHAVIOR) + expected = sort_imports( + code=inline_code + + "\nclass Calculator:\n" + ' """Test class with async methods."""\n' + " \n" + " @codeflash_behavior_async\n" + " async def async_method(self, a: int, b: int) -> int:\n" + ' """Async method in class."""\n' + " await asyncio.sleep(0.005)\n" + " return a ** b\n" + " \n" + " def sync_method(self, a: int, b: int) -> int:\n" + ' """Sync method in class."""\n' + " return a - b\n", + float_to_top=True, + ) + assert modified_code.strip() == expected.strip() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") def test_async_decorator_no_duplicate_application(temp_dir): + # Case 1: Old-style import already present — injector should detect and skip already_decorated_code = ''' from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async import asyncio @@ -243,6 +225,30 @@ async def async_function(x: int, y: int) -> int: # Should not add duplicate decorator assert not decorator_added + # Case 2: Inline definition already present — injector should detect and skip + already_inline_code = ''' +import asyncio + +def codeflash_behavior_async(func): + return func + +@codeflash_behavior_async +async def async_function(x: int, y: int) -> int: + """Already decorated async function.""" + await asyncio.sleep(0.01) + return x * y +''' + + test_file2 = temp_dir / "test_async2.py" + test_file2.write_text(already_inline_code) + + func2 = FunctionToOptimize(function_name="async_function", file_path=test_file2, parents=[], is_async=True) + + decorator_added2 = add_async_decorator_to_function(test_file2, func2, TestingMode.BEHAVIOR) + + # Should not add duplicate decorator + assert not decorator_added2 + @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") def test_inject_profiling_async_function_behavior_mode(temp_dir): @@ -285,11 +291,17 @@ async def test_async_function(): assert source_success is True - # Verify the file was modified + # Verify the file was modified with exact expected output instrumented_source = source_file.read_text() - assert "@codeflash_behavior_async" in instrumented_source - assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source - assert "codeflash_behavior_async" in instrumented_source + from codeflash.code_utils.formatter import sort_imports + + inline_code = get_async_inline_code(TestingMode.BEHAVIOR) + expected = sort_imports( + code=inline_code + "\n@codeflash_behavior_async\nasync def async_function(x: int, y: int) -> int:\n" + ' """Simple async function for testing."""\n await asyncio.sleep(0.01)\n return x * y\n', + float_to_top=True, + ) + assert instrumented_source.strip() == expected.strip() success, instrumented_test_code = inject_profiling_into_existing_test( test_file, [CodePosition(8, 18), CodePosition(11, 19)], func, temp_dir, mode=TestingMode.BEHAVIOR @@ -340,12 +352,17 @@ async def test_async_function(): assert source_success is True - # Verify the file was modified + # Verify the file was modified with exact expected output instrumented_source = source_file.read_text() - assert "@codeflash_performance_async" in instrumented_source - # Check for the import with line continuation formatting - assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source - assert "codeflash_performance_async" in instrumented_source + from codeflash.code_utils.formatter import sort_imports + + inline_code = get_async_inline_code(TestingMode.PERFORMANCE) + expected = sort_imports( + code=inline_code + "\n@codeflash_performance_async\nasync def async_function(x: int, y: int) -> int:\n" + ' """Simple async function for testing."""\n await asyncio.sleep(0.01)\n return x * y\n', + float_to_top=True, + ) + assert instrumented_source.strip() == expected.strip() # Now test the full pipeline with source module path success, instrumented_test_code = inject_profiling_into_existing_test( @@ -406,11 +423,21 @@ async def test_mixed_functions(): # Verify the file was modified instrumented_source = source_file.read_text() - assert "@codeflash_behavior_async" in instrumented_source - assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source - assert "codeflash_behavior_async" in instrumented_source - # Sync function should remain unchanged - assert "def sync_function(x: int, y: int) -> int:" in instrumented_source + from codeflash.code_utils.formatter import sort_imports + + inline_code = get_async_inline_code(TestingMode.BEHAVIOR) + expected = sort_imports( + code=inline_code + + "\ndef sync_function(x: int, y: int) -> int:\n" + ' """Regular sync function."""\n' + " return x * y\n" + "\n@codeflash_behavior_async\nasync def async_function(x: int, y: int) -> int:\n" + ' """Simple async function."""\n' + " await asyncio.sleep(0.01)\n" + " return x * y\n", + float_to_top=True, + ) + assert instrumented_source.strip() == expected.strip() success, instrumented_test_code = inject_profiling_into_existing_test( test_file, [CodePosition(8, 18), CodePosition(11, 19)], async_func, temp_dir, mode=TestingMode.BEHAVIOR @@ -446,24 +473,23 @@ async def nested_async_method(self, x: int) -> int: decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.BEHAVIOR) - expected_output = """import asyncio - -from codeflash.code_utils.codeflash_wrap_decorator import \\ - codeflash_behavior_async - - -class OuterClass: - class InnerClass: - @codeflash_behavior_async - async def nested_async_method(self, x: int) -> int: - \"\"\"Nested async method.\"\"\" - await asyncio.sleep(0.001) - return x * 2 -""" - assert decorator_added modified_code = test_file.read_text() - assert modified_code.strip() == expected_output.strip() + from codeflash.code_utils.formatter import sort_imports + + inline_code = get_async_inline_code(TestingMode.BEHAVIOR) + expected = sort_imports( + code=inline_code + + "\nclass OuterClass: \n" + " class InnerClass: \n" + " @codeflash_behavior_async\n" + " async def nested_async_method(self, x: int) -> int:\n" + ' """Nested async method."""\n' + " await asyncio.sleep(0.001)\n" + " return x * 2\n", + float_to_top=True, + ) + assert modified_code.strip() == expected.strip() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") From 64a18c9870f2c785f27ceb85abc32c1d9d8db129 Mon Sep 17 00:00:00 2001 From: KRRT7 Date: Wed, 18 Feb 2026 08:17:08 +0000 Subject: [PATCH 29/39] refactor: use helper file for async decorator instrumentation Replace inline code injection with a helper file approach that writes decorator implementations to a separate codeflash_async_wrapper.py file. This removes the codeflash package import dependency from instrumented source files while keeping line numbers stable (only 1 import + 1 decorator line added, same as before). Co-Authored-By: Claude Opus 4.6 --- .../code_directories/async_e2e/main.py | 6 +- .../code_utils/instrument_existing_tests.py | 81 +++++------- codeflash/optimization/function_optimizer.py | 25 +++- tests/test_async_run_and_parse_tests.py | 73 ++++++++--- tests/test_instrument_async_tests.py | 115 ++++++++---------- 5 files changed, 161 insertions(+), 139 deletions(-) diff --git a/code_to_optimize/code_directories/async_e2e/main.py b/code_to_optimize/code_directories/async_e2e/main.py index 317068a1c..8ab92ccdc 100644 --- a/code_to_optimize/code_directories/async_e2e/main.py +++ b/code_to_optimize/code_directories/async_e2e/main.py @@ -1,4 +1,3 @@ -import time import asyncio @@ -6,11 +5,14 @@ async def retry_with_backoff(func, max_retries=3): if max_retries < 1: raise ValueError("max_retries must be at least 1") last_exception = None + _sleep = asyncio.sleep for attempt in range(max_retries): try: return await func() except Exception as e: last_exception = e if attempt < max_retries - 1: - time.sleep(0.0001 * attempt) + delay = 0.0001 * attempt + if delay: + await _sleep(delay) raise last_exception diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 8c53c8e01..cea455cf0 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -1700,69 +1700,44 @@ def get_async_inline_code(mode: TestingMode) -> str: return get_performance_async_inline_code() -class AsyncInlineCodeInjector(cst.CSTTransformer): - """Injects async decorator function definitions inline instead of importing from codeflash.""" +ASYNC_HELPER_FILENAME = "codeflash_async_wrapper.py" - def __init__(self, mode: TestingMode = TestingMode.BEHAVIOR) -> None: - self.mode = mode - self.has_inline_definition = False - self.has_old_import = False - - def _get_decorator_name(self) -> str: - if self.mode == TestingMode.BEHAVIOR: - return "codeflash_behavior_async" - if self.mode == TestingMode.CONCURRENCY: - return "codeflash_concurrency_async" - return "codeflash_performance_async" - - def visit_ImportFrom(self, node: cst.ImportFrom) -> None: - if ( - isinstance(node.module, cst.Attribute) - and isinstance(node.module.value, cst.Attribute) - and isinstance(node.module.value.value, cst.Name) - and node.module.value.value.value == "codeflash" - and node.module.value.attr.value == "code_utils" - and node.module.attr.value == "codeflash_wrap_decorator" - and not isinstance(node.names, cst.ImportStar) - ): - decorator_name = self._get_decorator_name() - for import_alias in node.names: - if import_alias.name.value == decorator_name: - self.has_old_import = True - def visit_FunctionDef(self, node: cst.FunctionDef) -> None: - if node.name.value == self._get_decorator_name(): - self.has_inline_definition = True +def get_decorator_name_for_mode(mode: TestingMode) -> str: + if mode == TestingMode.BEHAVIOR: + return "codeflash_behavior_async" + if mode == TestingMode.CONCURRENCY: + return "codeflash_concurrency_async" + return "codeflash_performance_async" + - def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: - if self.has_inline_definition or self.has_old_import: - return updated_node - inline_code = get_async_inline_code(self.mode) - inline_stmts = cst.parse_module(inline_code).body - return updated_node.with_changes(body=[*inline_stmts, *list(updated_node.body)]) +def write_async_helper_file(target_dir: Path, mode: TestingMode) -> Path: + """Write the async decorator helper file to the target directory.""" + helper_path = target_dir / ASYNC_HELPER_FILENAME + if helper_path.exists(): + decorator_name = get_decorator_name_for_mode(mode) + if f"def {decorator_name}" in helper_path.read_text("utf-8"): + return helper_path + helper_path.write_text(get_async_inline_code(mode), "utf-8") + return helper_path def add_async_decorator_to_function( - source_path: Path, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR + source_path: Path, + function: FunctionToOptimize, + mode: TestingMode = TestingMode.BEHAVIOR, + project_root: Path | None = None, ) -> bool: """Add async decorator to an async function definition and write back to file. - Args: - ---- - source_path: Path to the source file to modify in-place. - function: The FunctionToOptimize object representing the target async function. - mode: The testing mode to determine which decorator to apply. - - Returns: - ------- - Boolean indicating whether the decorator was successfully added. + Writes a helper file containing the decorator implementation to project_root (or source directory + as fallback) and adds a standard import + decorator to the source file. """ if not function.is_async: return False try: - # Read source code with source_path.open(encoding="utf8") as f: source_code = f.read() @@ -1772,10 +1747,14 @@ def add_async_decorator_to_function( decorator_transformer = AsyncDecoratorAdder(function, mode) module = module.visit(decorator_transformer) - # Add the import if decorator was added if decorator_transformer.added_decorator: - import_transformer = AsyncInlineCodeInjector(mode) - module = module.visit(import_transformer) + # Write the helper file to project_root (on sys.path) or source dir as fallback + helper_dir = project_root if project_root is not None else source_path.parent + write_async_helper_file(helper_dir, mode) + # Add the import via CST so sort_imports can place it correctly + decorator_name = get_decorator_name_for_mode(mode) + import_node = cst.parse_statement(f"from codeflash_async_wrapper import {decorator_name}") + module = module.with_changes(body=[import_node, *list(module.body)]) modified_code = sort_imports(code=module.code, float_to_top=True) except Exception as e: diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index ed7f7f1fe..c4c68bf85 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -2296,7 +2296,10 @@ def establish_original_code_baseline( from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function success = add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR + self.function_to_optimize.file_path, + self.function_to_optimize, + TestingMode.BEHAVIOR, + project_root=self.project_root, ) # Instrument codeflash capture @@ -2361,7 +2364,10 @@ def establish_original_code_baseline( from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE + self.function_to_optimize.file_path, + self.function_to_optimize, + TestingMode.PERFORMANCE, + project_root=self.project_root, ) try: @@ -2535,7 +2541,10 @@ def run_optimized_candidate( from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR + self.function_to_optimize.file_path, + self.function_to_optimize, + TestingMode.BEHAVIOR, + project_root=self.project_root, ) try: @@ -2611,7 +2620,10 @@ def run_optimized_candidate( from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE + self.function_to_optimize.file_path, + self.function_to_optimize, + TestingMode.PERFORMANCE, + project_root=self.project_root, ) try: @@ -2974,7 +2986,10 @@ def run_concurrency_benchmark( try: # Add concurrency decorator to the source function add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.CONCURRENCY + self.function_to_optimize.file_path, + self.function_to_optimize, + TestingMode.CONCURRENCY, + project_root=self.project_root, ) # Run the concurrency benchmark tests diff --git a/tests/test_async_run_and_parse_tests.py b/tests/test_async_run_and_parse_tests.py index 5750a015f..1777a1c73 100644 --- a/tests/test_async_run_and_parse_tests.py +++ b/tests/test_async_run_and_parse_tests.py @@ -8,8 +8,9 @@ import pytest from codeflash.code_utils.instrument_existing_tests import ( + ASYNC_HELPER_FILENAME, add_async_decorator_to_function, - get_async_inline_code, + get_decorator_name_for_mode, inject_profiling_into_existing_test, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -56,7 +57,9 @@ async def test_async_sort(): func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True) # For async functions, instrument the source module directly with decorators - source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR) + source_success = add_async_decorator_to_function( + fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path + ) assert source_success @@ -64,11 +67,12 @@ async def test_async_sort(): instrumented_source = fto_path.read_text("utf-8") from codeflash.code_utils.formatter import sort_imports - inline_code = get_async_inline_code(TestingMode.BEHAVIOR) + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) decorated_original = original_code.replace( - "async def async_sorter", "@codeflash_behavior_async\nasync def async_sorter" + "async def async_sorter", f"@{decorator_name}\nasync def async_sorter" ) - expected = sort_imports(code=inline_code + decorated_original, float_to_top=True) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_original}" + expected = sort_imports(code=code_with_import, float_to_top=True) assert instrumented_source.strip() == expected.strip() # Add codeflash capture @@ -147,6 +151,9 @@ async def test_async_sort(): test_path.unlink() if test_path_perf.exists(): test_path_perf.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -187,7 +194,9 @@ async def test_async_class_sort(): is_async=True, ) - source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR) + source_success = add_async_decorator_to_function( + fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path + ) assert source_success @@ -269,6 +278,9 @@ async def test_async_class_sort(): test_path.unlink() if test_path_perf.exists(): test_path_perf.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -299,7 +311,9 @@ async def test_async_perf(): func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True) # Instrument the source module with async performance decorators - source_success = add_async_decorator_to_function(fto_path, func, TestingMode.PERFORMANCE) + source_success = add_async_decorator_to_function( + fto_path, func, TestingMode.PERFORMANCE, project_root=project_root_path + ) assert source_success @@ -307,11 +321,12 @@ async def test_async_perf(): instrumented_source = fto_path.read_text("utf-8") from codeflash.code_utils.formatter import sort_imports - inline_code = get_async_inline_code(TestingMode.PERFORMANCE) + decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE) decorated_original = original_code.replace( - "async def async_sorter", "@codeflash_performance_async\nasync def async_sorter" + "async def async_sorter", f"@{decorator_name}\nasync def async_sorter" ) - expected = sort_imports(code=inline_code + decorated_original, float_to_top=True) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_original}" + expected = sort_imports(code=code_with_import, float_to_top=True) assert instrumented_source.strip() == expected.strip() instrument_codeflash_capture(func, {}, tests_root) @@ -368,6 +383,9 @@ async def test_async_perf(): # Clean up test files if test_path.exists(): test_path.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -413,7 +431,9 @@ async def async_error_function(lst): function_name="async_error_function", parents=[], file_path=Path(fto_path), is_async=True ) - source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR) + source_success = add_async_decorator_to_function( + fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path + ) assert source_success @@ -422,11 +442,12 @@ async def async_error_function(lst): from codeflash.code_utils.formatter import sort_imports - inline_code = get_async_inline_code(TestingMode.BEHAVIOR) + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) decorated_modified = modified_code.replace( - "async def async_error_function", "@codeflash_behavior_async\nasync def async_error_function" + "async def async_error_function", f"@{decorator_name}\nasync def async_error_function" ) - expected = sort_imports(code=inline_code + decorated_modified, float_to_top=True) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_modified}" + expected = sort_imports(code=code_with_import, float_to_top=True) assert instrumented_source.strip() == expected.strip() instrument_codeflash_capture(func, {}, tests_root) @@ -488,6 +509,9 @@ async def async_error_function(lst): test_path.unlink() if test_path_perf.exists(): test_path_perf.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -525,7 +549,9 @@ async def test_async_multi(): func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True) - source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR) + source_success = add_async_decorator_to_function( + fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path + ) assert source_success instrument_codeflash_capture(func, {}, tests_root) @@ -598,6 +624,9 @@ async def test_async_multi(): test_path.unlink() if test_path_perf.exists(): test_path_perf.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -640,7 +669,9 @@ async def test_async_edge_cases(): func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True) - source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR) + source_success = add_async_decorator_to_function( + fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path + ) assert source_success instrument_codeflash_capture(func, {}, tests_root) @@ -715,6 +746,9 @@ async def test_async_edge_cases(): test_path.unlink() if test_path_perf.exists(): test_path_perf.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -949,7 +983,9 @@ async def test_mixed_sorting(): function_name="async_merge_sort", parents=[], file_path=Path(mixed_fto_path), is_async=True ) - source_success = add_async_decorator_to_function(mixed_fto_path, async_func, TestingMode.BEHAVIOR) + source_success = add_async_decorator_to_function( + mixed_fto_path, async_func, TestingMode.BEHAVIOR, project_root=project_root_path + ) assert source_success @@ -1022,3 +1058,6 @@ async def test_mixed_sorting(): test_path.unlink() if test_path_perf.exists(): test_path_perf.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() diff --git a/tests/test_instrument_async_tests.py b/tests/test_instrument_async_tests.py index d700d91d5..0e57ec209 100644 --- a/tests/test_instrument_async_tests.py +++ b/tests/test_instrument_async_tests.py @@ -6,8 +6,9 @@ import pytest from codeflash.code_utils.instrument_existing_tests import ( + ASYNC_HELPER_FILENAME, add_async_decorator_to_function, - get_async_inline_code, + get_decorator_name_for_mode, inject_profiling_into_existing_test, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -75,13 +76,14 @@ async def async_function(x: int, y: int) -> int: modified_code = test_file.read_text() from codeflash.code_utils.formatter import sort_imports - inline_code = get_async_inline_code(TestingMode.BEHAVIOR) - expected = sort_imports( - code=inline_code + "\n@codeflash_behavior_async\nasync def async_function(x: int, y: int) -> int:\n" - ' """Simple async function for testing."""\n await asyncio.sleep(0.01)\n return x * y\n', - float_to_top=True, + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + code_with_decorator = async_function_code.replace( + "async def async_function", f"@{decorator_name}\nasync def async_function" ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) assert modified_code.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -106,13 +108,14 @@ async def async_function(x: int, y: int) -> int: modified_code = test_file.read_text() from codeflash.code_utils.formatter import sort_imports - inline_code = get_async_inline_code(TestingMode.PERFORMANCE) - expected = sort_imports( - code=inline_code + "\n@codeflash_performance_async\nasync def async_function(x: int, y: int) -> int:\n" - ' """Simple async function for testing."""\n await asyncio.sleep(0.01)\n return x * y\n', - float_to_top=True, + decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE) + code_with_decorator = async_function_code.replace( + "async def async_function", f"@{decorator_name}\nasync def async_function" ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) assert modified_code.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -138,13 +141,14 @@ async def async_function(x: int, y: int) -> int: modified_code = test_file.read_text() from codeflash.code_utils.formatter import sort_imports - inline_code = get_async_inline_code(TestingMode.CONCURRENCY) - expected = sort_imports( - code=inline_code + "\n@codeflash_concurrency_async\nasync def async_function(x: int, y: int) -> int:\n" - ' """Simple async function for testing."""\n await asyncio.sleep(0.01)\n return x * y\n', - float_to_top=True, + decorator_name = get_decorator_name_for_mode(TestingMode.CONCURRENCY) + code_with_decorator = async_function_code.replace( + "async def async_function", f"@{decorator_name}\nasync def async_function" ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) assert modified_code.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -181,24 +185,14 @@ def sync_method(self, a: int, b: int) -> int: modified_code = test_file.read_text() from codeflash.code_utils.formatter import sort_imports - inline_code = get_async_inline_code(TestingMode.BEHAVIOR) - expected = sort_imports( - code=inline_code - + "\nclass Calculator:\n" - ' """Test class with async methods."""\n' - " \n" - " @codeflash_behavior_async\n" - " async def async_method(self, a: int, b: int) -> int:\n" - ' """Async method in class."""\n' - " await asyncio.sleep(0.005)\n" - " return a ** b\n" - " \n" - " def sync_method(self, a: int, b: int) -> int:\n" - ' """Sync method in class."""\n' - " return a - b\n", - float_to_top=True, + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + code_with_decorator = async_class_code.replace( + " async def async_method", f" @{decorator_name}\n async def async_method" ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) assert modified_code.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -295,13 +289,14 @@ async def test_async_function(): instrumented_source = source_file.read_text() from codeflash.code_utils.formatter import sort_imports - inline_code = get_async_inline_code(TestingMode.BEHAVIOR) - expected = sort_imports( - code=inline_code + "\n@codeflash_behavior_async\nasync def async_function(x: int, y: int) -> int:\n" - ' """Simple async function for testing."""\n await asyncio.sleep(0.01)\n return x * y\n', - float_to_top=True, + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + code_with_decorator = source_module_code.replace( + "async def async_function", f"@{decorator_name}\nasync def async_function" ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) assert instrumented_source.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() success, instrumented_test_code = inject_profiling_into_existing_test( test_file, [CodePosition(8, 18), CodePosition(11, 19)], func, temp_dir, mode=TestingMode.BEHAVIOR @@ -356,13 +351,14 @@ async def test_async_function(): instrumented_source = source_file.read_text() from codeflash.code_utils.formatter import sort_imports - inline_code = get_async_inline_code(TestingMode.PERFORMANCE) - expected = sort_imports( - code=inline_code + "\n@codeflash_performance_async\nasync def async_function(x: int, y: int) -> int:\n" - ' """Simple async function for testing."""\n await asyncio.sleep(0.01)\n return x * y\n', - float_to_top=True, + decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE) + code_with_decorator = source_module_code.replace( + "async def async_function", f"@{decorator_name}\nasync def async_function" ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) assert instrumented_source.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() # Now test the full pipeline with source module path success, instrumented_test_code = inject_profiling_into_existing_test( @@ -425,19 +421,14 @@ async def test_mixed_functions(): instrumented_source = source_file.read_text() from codeflash.code_utils.formatter import sort_imports - inline_code = get_async_inline_code(TestingMode.BEHAVIOR) - expected = sort_imports( - code=inline_code - + "\ndef sync_function(x: int, y: int) -> int:\n" - ' """Regular sync function."""\n' - " return x * y\n" - "\n@codeflash_behavior_async\nasync def async_function(x: int, y: int) -> int:\n" - ' """Simple async function."""\n' - " await asyncio.sleep(0.01)\n" - " return x * y\n", - float_to_top=True, + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + code_with_decorator = source_module_code.replace( + "async def async_function", f"@{decorator_name}\nasync def async_function" ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) assert instrumented_source.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() success, instrumented_test_code = inject_profiling_into_existing_test( test_file, [CodePosition(8, 18), CodePosition(11, 19)], async_func, temp_dir, mode=TestingMode.BEHAVIOR @@ -477,19 +468,15 @@ async def nested_async_method(self, x: int) -> int: modified_code = test_file.read_text() from codeflash.code_utils.formatter import sort_imports - inline_code = get_async_inline_code(TestingMode.BEHAVIOR) - expected = sort_imports( - code=inline_code - + "\nclass OuterClass: \n" - " class InnerClass: \n" - " @codeflash_behavior_async\n" - " async def nested_async_method(self, x: int) -> int:\n" - ' """Nested async method."""\n' - " await asyncio.sleep(0.001)\n" - " return x * 2\n", - float_to_top=True, + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + code_with_decorator = nested_async_code.replace( + " async def nested_async_method", + f" @{decorator_name}\n async def nested_async_method", ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) assert modified_code.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") From e0861214ec08485361617ea7f83a0855f642db66 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 18 Feb 2026 03:46:02 -0500 Subject: [PATCH 30/39] fix: clean up async helper file and combine all decorators into single file Write all three async decorator implementations into one helper file to avoid overwrite issues when switching modes. Clean up the helper file in revert_code_and_helpers and early-exit paths so it doesn't persist in the user's project root after optimization. --- .../code_utils/instrument_existing_tests.py | 52 +++---------------- codeflash/optimization/function_optimizer.py | 9 ++++ 2 files changed, 15 insertions(+), 46 deletions(-) diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index cea455cf0..9486fc677 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -1497,11 +1497,11 @@ def _is_target_decorator(self, decorator_node: cst.Name | cst.Attribute | cst.Ca return False -def get_behavior_async_inline_code() -> str: - return """import asyncio +ASYNC_HELPER_INLINE_CODE = """import asyncio import gc import os import sqlite3 +import time from functools import wraps from pathlib import Path from tempfile import TemporaryDirectory @@ -1590,25 +1590,6 @@ async def async_wrapper(*args, **kwargs): raise exception return return_value return async_wrapper -""" - - -def get_performance_async_inline_code() -> str: - return """import asyncio -import gc -import os -from functools import wraps - - -def extract_test_context_from_env(): - test_module = os.environ["CODEFLASH_TEST_MODULE"] - test_class = os.environ.get("CODEFLASH_TEST_CLASS", None) - test_function = os.environ["CODEFLASH_TEST_FUNCTION"] - if test_module and test_function: - return (test_module, test_class if test_class else None, test_function) - raise RuntimeError( - "Test context environment variables not set - ensure tests are run through codeflash test runner" - ) def codeflash_performance_async(func): @@ -1649,15 +1630,6 @@ async def async_wrapper(*args, **kwargs): raise exception return return_value return async_wrapper -""" - - -def get_concurrency_async_inline_code() -> str: - return """import asyncio -import gc -import os -import time -from functools import wraps def codeflash_concurrency_async(func): @@ -1691,15 +1663,6 @@ async def async_wrapper(*args, **kwargs): return async_wrapper """ - -def get_async_inline_code(mode: TestingMode) -> str: - if mode == TestingMode.BEHAVIOR: - return get_behavior_async_inline_code() - if mode == TestingMode.CONCURRENCY: - return get_concurrency_async_inline_code() - return get_performance_async_inline_code() - - ASYNC_HELPER_FILENAME = "codeflash_async_wrapper.py" @@ -1711,14 +1674,11 @@ def get_decorator_name_for_mode(mode: TestingMode) -> str: return "codeflash_performance_async" -def write_async_helper_file(target_dir: Path, mode: TestingMode) -> Path: +def write_async_helper_file(target_dir: Path) -> Path: """Write the async decorator helper file to the target directory.""" helper_path = target_dir / ASYNC_HELPER_FILENAME - if helper_path.exists(): - decorator_name = get_decorator_name_for_mode(mode) - if f"def {decorator_name}" in helper_path.read_text("utf-8"): - return helper_path - helper_path.write_text(get_async_inline_code(mode), "utf-8") + if not helper_path.exists(): + helper_path.write_text(ASYNC_HELPER_INLINE_CODE, "utf-8") return helper_path @@ -1750,7 +1710,7 @@ def add_async_decorator_to_function( if decorator_transformer.added_decorator: # Write the helper file to project_root (on sys.path) or source dir as fallback helper_dir = project_root if project_root is not None else source_path.parent - write_async_helper_file(helper_dir, mode) + write_async_helper_file(helper_dir) # Add the import via CST so sort_imports can place it correctly decorator_name = get_decorator_name_for_mode(mode) import_node = cst.parse_statement(f"from codeflash_async_wrapper import {decorator_name}") diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index c4c68bf85..0a515076c 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1897,6 +1897,7 @@ def setup_and_establish_baseline( if self.args.override_fixtures: restore_conftest(original_conftest_content) cleanup_paths(paths_to_cleanup) + self.cleanup_async_helper_file() return Failure(baseline_result.failure()) original_code_baseline, test_functions_to_remove = baseline_result.unwrap() @@ -1908,6 +1909,7 @@ def setup_and_establish_baseline( if self.args.override_fixtures: restore_conftest(original_conftest_content) cleanup_paths(paths_to_cleanup) + self.cleanup_async_helper_file() return Failure("The threshold for test confidence was not met.") return Success( @@ -2279,6 +2281,13 @@ def revert_code_and_helpers(self, original_helper_code: dict[Path, str]) -> None self.write_code_and_helpers( self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path ) + self.cleanup_async_helper_file() + + def cleanup_async_helper_file(self) -> None: + from codeflash.code_utils.instrument_existing_tests import ASYNC_HELPER_FILENAME + + helper_path = self.project_root / ASYNC_HELPER_FILENAME + helper_path.unlink(missing_ok=True) def establish_original_code_baseline( self, From 950545119447c30dd79ca3414d3c0fe3ae603279 Mon Sep 17 00:00:00 2001 From: Kevin Turcios <106575910+KRRT7@users.noreply.github.com> Date: Wed, 18 Feb 2026 09:25:33 +0000 Subject: [PATCH 31/39] fix: revert async e2e fixture to use time.sleep() for optimization target The e2e test expects codeflash to detect and fix the intentional use of blocking time.sleep() in an async function. Using asyncio.sleep() removes the optimization opportunity and causes the CI job to fail. Co-Authored-By: Claude Opus 4.6 --- code_to_optimize/code_directories/async_e2e/main.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/code_to_optimize/code_directories/async_e2e/main.py b/code_to_optimize/code_directories/async_e2e/main.py index 8ab92ccdc..317068a1c 100644 --- a/code_to_optimize/code_directories/async_e2e/main.py +++ b/code_to_optimize/code_directories/async_e2e/main.py @@ -1,3 +1,4 @@ +import time import asyncio @@ -5,14 +6,11 @@ async def retry_with_backoff(func, max_retries=3): if max_retries < 1: raise ValueError("max_retries must be at least 1") last_exception = None - _sleep = asyncio.sleep for attempt in range(max_retries): try: return await func() except Exception as e: last_exception = e if attempt < max_retries - 1: - delay = 0.0001 * attempt - if delay: - await _sleep(delay) + time.sleep(0.0001 * attempt) raise last_exception From 6c092b5e7f73ccfc9ece386f2e7555c1ca8468dc Mon Sep 17 00:00:00 2001 From: Kevin Turcios <106575910+KRRT7@users.noreply.github.com> Date: Wed, 18 Feb 2026 09:45:51 +0000 Subject: [PATCH 32/39] fix: update expected coverage lines for optimized async e2e code The optimized code removes `import time`, shifting all function lines up by 1. Update expected_lines from [10-20] to [9-19] to match. Co-Authored-By: Claude Opus 4.6 --- tests/scripts/end_to_end_test_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/scripts/end_to_end_test_async.py b/tests/scripts/end_to_end_test_async.py index 0b4bf8957..0e38ae797 100644 --- a/tests/scripts/end_to_end_test_async.py +++ b/tests/scripts/end_to_end_test_async.py @@ -13,7 +13,7 @@ def run_test(expected_improvement_pct: int) -> bool: CoverageExpectation( function_name="retry_with_backoff", expected_coverage=100.0, - expected_lines=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], + expected_lines=[9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], ) ], ) From 0749621bee492f29128ca447e0c2c69c9e3f92e4 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 18 Feb 2026 05:57:17 -0500 Subject: [PATCH 33/39] fix: use original repo roots for filtering in worktree diff mode In --worktree mode, get_git_diff resolves file paths from cwd (the original repo), but module_root/project_root are mirrored to the worktree. This caused filter_functions to reject all diff-discovered functions as "outside module-root". Use the pre-mirror roots for filtering, then remap file paths to the worktree for downstream use. --- codeflash/optimization/optimizer.py | 38 ++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index b8f42010f..15284ba6f 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -183,18 +183,50 @@ def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize] """Discover functions to optimize.""" from codeflash.discovery.functions_to_optimize import get_functions_to_optimize - return get_functions_to_optimize( + # In worktree mode for git-diff discovery, file paths come from the original repo + # (via get_git_diff using cwd), but module_root/project_root have been mirrored to + # the worktree. Use the original roots for filtering so path comparisons match, + # then remap the discovered file paths to the worktree. + project_root = self.args.project_root + module_root = self.args.module_root + use_original_roots = ( + self.current_worktree and self.original_args_and_test_cfg and not self.args.all and not self.args.file + ) + if use_original_roots: + original_args, _ = self.original_args_and_test_cfg + project_root = original_args.project_root + module_root = original_args.module_root + + result = get_functions_to_optimize( optimize_all=self.args.all, replay_test=self.args.replay_test, file=self.args.file, only_get_this_function=self.args.function, test_cfg=self.test_cfg, ignore_paths=self.args.ignore_paths, - project_root=self.args.project_root, - module_root=self.args.module_root, + project_root=project_root, + module_root=module_root, previous_checkpoint_functions=self.args.previous_checkpoint_functions, ) + # Remap discovered file paths from the original repo to the worktree so + # downstream optimization reads/writes happen in the worktree. + if use_original_roots: + import dataclasses + + original_git_root = git_root_dir() + file_to_funcs, count, trace = result + remapped: dict[Path, list[FunctionToOptimize]] = {} + for file_path, funcs in file_to_funcs.items(): + new_path = mirror_path(Path(file_path), original_git_root, self.current_worktree) + remapped[new_path] = [ + dataclasses.replace(func, file_path=mirror_path(func.file_path, original_git_root, self.current_worktree)) + for func in funcs + ] + return remapped, count, trace + + return result + def create_function_optimizer( self, function_to_optimize: FunctionToOptimize, From d43d9aeb4b8875144c677c09ab6f010c07573508 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Wed, 18 Feb 2026 11:00:58 +0000 Subject: [PATCH 34/39] style: auto-fix ruff formatting for long line --- codeflash/optimization/optimizer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 15284ba6f..7964e31a7 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -220,7 +220,9 @@ def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize] for file_path, funcs in file_to_funcs.items(): new_path = mirror_path(Path(file_path), original_git_root, self.current_worktree) remapped[new_path] = [ - dataclasses.replace(func, file_path=mirror_path(func.file_path, original_git_root, self.current_worktree)) + dataclasses.replace( + func, file_path=mirror_path(func.file_path, original_git_root, self.current_worktree) + ) for func in funcs ] return remapped, count, trace From 6a19b9d4b8232b8c034e8b424f21d4846ecfd7cb Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Wed, 18 Feb 2026 11:02:09 +0000 Subject: [PATCH 35/39] fix: add type assertions for mypy narrowing in worktree path remapping --- codeflash/optimization/optimizer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 7964e31a7..06540ca85 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -193,6 +193,7 @@ def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize] self.current_worktree and self.original_args_and_test_cfg and not self.args.all and not self.args.file ) if use_original_roots: + assert self.original_args_and_test_cfg is not None original_args, _ = self.original_args_and_test_cfg project_root = original_args.project_root module_root = original_args.module_root @@ -214,6 +215,7 @@ def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize] if use_original_roots: import dataclasses + assert self.current_worktree is not None original_git_root = git_root_dir() file_to_funcs, count, trace = result remapped: dict[Path, list[FunctionToOptimize]] = {} From 305210b1f7b4ba9b5042658db9d10638393cc15b Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 18 Feb 2026 06:32:12 -0500 Subject: [PATCH 36/39] fix: resolve git root from module_root for worktree PR creation git_root_dir() searches from CWD (original repo), but in worktree mode file paths have been remapped to the worktree. This caused relative_to() to raise ValueError when creating PRs. Search from module_root instead so root_dir is always in the same path space as the file paths. --- codeflash/optimization/function_optimizer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 0a515076c..372d7892f 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Callable import libcst as cst +from git import Repo as GitRepo from rich.console import Group from rich.panel import Panel from rich.syntax import Syntax @@ -2222,11 +2223,11 @@ def process_review( console.print(Panel(panel_content, title="Optimization Review", border_style=display_info[1])) if raise_pr or staging_review: - data["root_dir"] = git_root_dir() + data["root_dir"] = git_root_dir(GitRepo(str(self.args.module_root), search_parent_directories=True)) if raise_pr and not staging_review and opt_review_result.review != "low": # Ensure root_dir is set for PR creation (needed for async functions that skip opt_review) if "root_dir" not in data: - data["root_dir"] = git_root_dir() + data["root_dir"] = git_root_dir(GitRepo(str(self.args.module_root), search_parent_directories=True)) data["git_remote"] = self.args.git_remote # Remove language from data dict as check_create_pr doesn't accept it pr_data = {k: v for k, v in data.items() if k != "language"} From c76acaeba69a437736b93f2f89d9c3cf750ceffa Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 18 Feb 2026 15:24:49 -0500 Subject: [PATCH 37/39] chore: bump version to 0.20.1 --- codeflash/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/version.py b/codeflash/version.py index 6225467e3..5c0c09b55 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ -1,2 +1,2 @@ # These version placeholders will be replaced by uv-dynamic-versioning during build. -__version__ = "0.20.0" +__version__ = "0.20.1" From 7c7eeb5bc9db6c39f90c03775c44a0c408c80785 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Thu, 19 Feb 2026 20:39:42 -0500 Subject: [PATCH 38/39] fix: update test import for moved code_context_extractor module --- tests/test_languages/test_java_e2e.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_languages/test_java_e2e.py b/tests/test_languages/test_java_e2e.py index 1b6aa3ace..c01865048 100644 --- a/tests/test_languages/test_java_e2e.py +++ b/tests/test_languages/test_java_e2e.py @@ -89,7 +89,7 @@ def java_project_dir(self): def test_extract_code_context_for_java(self, java_project_dir): """Test extracting code context for a Java method.""" - from codeflash.context.code_context_extractor import get_code_optimization_context + from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context from codeflash.languages import current as lang_current from codeflash.languages.base import Language From ea48939787cf56ef4cc28015c2e00535fe01d7c7 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Fri, 20 Feb 2026 01:45:01 +0000 Subject: [PATCH 39/39] style: auto-fix linting and formatting issues --- codeflash/cli_cmds/console.py | 11 ++++++- codeflash/cli_cmds/logging_config.py | 20 +++++++++++-- codeflash/languages/java/context.py | 12 ++------ codeflash/languages/java/instrumentation.py | 18 ++++++----- codeflash/languages/java/replacement.py | 8 +---- .../parse_line_profile_test_output.py | 30 +++++++------------ 6 files changed, 53 insertions(+), 46 deletions(-) diff --git a/codeflash/cli_cmds/console.py b/codeflash/cli_cmds/console.py index b1e4b45d8..5ca7f9eea 100644 --- a/codeflash/cli_cmds/console.py +++ b/codeflash/cli_cmds/console.py @@ -40,7 +40,16 @@ logging.basicConfig( level=logging.INFO, - handlers=[RichHandler(rich_tracebacks=True, markup=False, highlighter=NullHighlighter(), console=console, show_path=False, show_time=False)], + handlers=[ + RichHandler( + rich_tracebacks=True, + markup=False, + highlighter=NullHighlighter(), + console=console, + show_path=False, + show_time=False, + ) + ], format=BARE_LOGGING_FORMAT, ) diff --git a/codeflash/cli_cmds/logging_config.py b/codeflash/cli_cmds/logging_config.py index c2f339abd..dbb3663bd 100644 --- a/codeflash/cli_cmds/logging_config.py +++ b/codeflash/cli_cmds/logging_config.py @@ -14,7 +14,16 @@ def set_level(level: int, *, echo_setting: bool = True) -> None: logging.basicConfig( level=level, - handlers=[RichHandler(rich_tracebacks=True, markup=False, highlighter=NullHighlighter(), console=console, show_path=False, show_time=False)], + handlers=[ + RichHandler( + rich_tracebacks=True, + markup=False, + highlighter=NullHighlighter(), + console=console, + show_path=False, + show_time=False, + ) + ], format=BARE_LOGGING_FORMAT, ) logging.getLogger().setLevel(level) @@ -23,7 +32,14 @@ def set_level(level: int, *, echo_setting: bool = True) -> None: logging.basicConfig( format=VERBOSE_LOGGING_FORMAT, handlers=[ - RichHandler(rich_tracebacks=True, markup=False, highlighter=NullHighlighter(), console=console, show_path=False, show_time=False) + RichHandler( + rich_tracebacks=True, + markup=False, + highlighter=NullHighlighter(), + console=console, + show_path=False, + show_time=False, + ) ], force=True, ) diff --git a/codeflash/languages/java/context.py b/codeflash/languages/java/context.py index 29067f23f..394f52037 100644 --- a/codeflash/languages/java/context.py +++ b/codeflash/languages/java/context.py @@ -887,11 +887,7 @@ def collect_type_identifiers(node: Node) -> None: def get_java_imported_type_skeletons( - imports: list, - project_root: Path, - module_root: Path | None, - analyzer: JavaAnalyzer, - target_code: str = "", + imports: list, project_root: Path, module_root: Path | None, analyzer: JavaAnalyzer, target_code: str = "" ) -> str: """Extract type skeletons for project-internal imported types. @@ -1011,9 +1007,7 @@ def _extract_constructor_summaries(skeleton: TypeSkeleton) -> list[str]: return summaries -def _format_skeleton_for_context( - skeleton: TypeSkeleton, source: str, class_name: str, analyzer: JavaAnalyzer -) -> str: +def _format_skeleton_for_context(skeleton: TypeSkeleton, source: str, class_name: str, analyzer: JavaAnalyzer) -> str: """Format a TypeSkeleton into a context string with method signatures. Includes: type declaration, fields, constructors, and public method signatures @@ -1094,7 +1088,7 @@ def _extract_public_method_signatures(source: str, class_name: str, analyzer: Ja sig_parts_bytes.append(mod_slice) continue - if ctype == "block" or ctype == "constructor_body": + if ctype in {"block", "constructor_body"}: break sig_parts_bytes.append(source_bytes[child.start_byte : child.end_byte]) diff --git a/codeflash/languages/java/instrumentation.py b/codeflash/languages/java/instrumentation.py index 7cad460dd..18fdb1409 100644 --- a/codeflash/languages/java/instrumentation.py +++ b/codeflash/languages/java/instrumentation.py @@ -730,11 +730,17 @@ def split_var_declaration(stmt_node, source_bytes_ref: bytes) -> tuple[str, str] # The variable is assigned inside a for/try block which Java considers # conditionally executed, so an uninitialized declaration would cause # "variable might not have been initialized" errors. - _PRIMITIVE_DEFAULTS = { - "byte": "0", "short": "0", "int": "0", "long": "0L", - "float": "0.0f", "double": "0.0", "char": "'\\0'", "boolean": "false", + primitive_defaults = { + "byte": "0", + "short": "0", + "int": "0", + "long": "0L", + "float": "0.0f", + "double": "0.0", + "char": "'\\0'", + "boolean": "false", } - default_val = _PRIMITIVE_DEFAULTS.get(type_text, "null") + default_val = primitive_defaults.get(type_text, "null") hoisted = f"{type_text} {name_text} = {default_val};" assignment = f"{name_text} = {value_text};" return hoisted, assignment @@ -918,9 +924,7 @@ def build_instrumented_body(body_text: str, next_wrapper_id: int, base_indent: s replacements: list[tuple[int, int, bytes]] = [] wrapper_id = 0 - method_ordinal = 0 - for method_node, body_node in test_methods: - method_ordinal += 1 + for method_ordinal, (method_node, body_node) in enumerate(test_methods, start=1): body_start = body_node.start_byte + 1 # skip '{' body_end = body_node.end_byte - 1 # skip '}' body_text = source_bytes[body_start:body_end].decode("utf8") diff --git a/codeflash/languages/java/replacement.py b/codeflash/languages/java/replacement.py index 23e3c9232..a374043e5 100644 --- a/codeflash/languages/java/replacement.py +++ b/codeflash/languages/java/replacement.py @@ -374,13 +374,7 @@ def replace_function( class_name, ) source = _insert_class_members( - source, - class_name, - new_fields_to_add, - new_helpers_before, - new_helpers_after, - func_name, - analyzer, + source, class_name, new_fields_to_add, new_helpers_before, new_helpers_after, func_name, analyzer ) # Re-find the target method after modifications diff --git a/codeflash/verification/parse_line_profile_test_output.py b/codeflash/verification/parse_line_profile_test_output.py index 34b27bdb3..4ef799425 100644 --- a/codeflash/verification/parse_line_profile_test_output.py +++ b/codeflash/verification/parse_line_profile_test_output.py @@ -6,16 +6,14 @@ import json import linecache import os -from typing import TYPE_CHECKING, Optional +from pathlib import Path +from typing import Optional import dill as pickle from codeflash.code_utils.tabulate import tabulate from codeflash.languages import is_python -if TYPE_CHECKING: - from pathlib import Path - def show_func( filename: str, start_lineno: int, func_name: str, timings: list[tuple[int, int, float]], unit: float @@ -80,9 +78,7 @@ def show_text(stats: dict) -> str: return out_table -def show_text_non_python( - stats: dict, line_contents: dict[tuple[str, int], str] -) -> str: +def show_text_non_python(stats: dict, line_contents: dict[tuple[str, int], str]) -> str: """Show text for non-Python timings using profiler-provided line contents.""" out_table = "" out_table += "# Timer unit: {:g} s\n".format(stats["unit"]) @@ -100,13 +96,13 @@ def show_text_non_python( table_rows = [] for lineno, nhits, time in timings: percent = "" if total_time == 0 else "%5.1f" % (100 * time / total_time) - time_disp = "%5.1f" % time + time_disp = f"{time:5.1f}" if len(time_disp) > default_column_sizes["time"]: - time_disp = "%5.1g" % time + time_disp = f"{time:5.1g}" perhit = (float(time) / nhits) if nhits > 0 else 0.0 - perhit_disp = "%5.1f" % perhit + perhit_disp = f"{perhit:5.1f}" if len(perhit_disp) > default_column_sizes["perhit"]: - perhit_disp = "%5.1g" % perhit + perhit_disp = f"{perhit:5.1g}" nhits_disp = "%d" % nhits # noqa: UP031 if len(nhits_disp) > default_column_sizes["hits"]: nhits_disp = f"{nhits:g}" @@ -115,11 +111,7 @@ def show_text_non_python( table_cols = ("Hits", "Time", "Per Hit", "% Time", "Line Contents") out_table += tabulate( - headers=table_cols, - tabular_data=table_rows, - tablefmt="pipe", - colglobalalign=None, - preserve_whitespace=True, + headers=table_cols, tabular_data=table_rows, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True ) out_table += "\n" return out_table @@ -159,9 +151,7 @@ def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> dic line_num = int(line_str) line_num = int(line_num) - lines_by_file.setdefault(file_path, []).append( - (line_num, int(stats.get("hits", 0)), int(stats.get("time", 0))) - ) + lines_by_file.setdefault(file_path, []).append((line_num, int(stats.get("hits", 0)), int(stats.get("time", 0)))) line_contents[(file_path, line_num)] = stats.get("content", "") for file_path, line_stats in lines_by_file.items(): @@ -169,7 +159,7 @@ def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> dic if not sorted_line_stats: continue start_lineno = sorted_line_stats[0][0] - grouped_timings[(file_path, start_lineno, os.path.basename(file_path))] = sorted_line_stats + grouped_timings[(file_path, start_lineno, Path(file_path).name)] = sorted_line_stats stats_dict["timings"] = grouped_timings stats_dict["unit"] = 1e-9