Skip to content

Commit ca882b7

Browse files
make sure js support is working fine
1 parent c884d99 commit ca882b7

12 files changed

Lines changed: 215 additions & 113 deletions

File tree

code_to_optimize/js/code_to_optimize_js/bubble_sort.js

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,5 @@ function bubbleSort(arr) {
3131
return result;
3232
}
3333

34-
/**
35-
* Sort an array in descending order.
36-
* @param {number[]} arr - The array to sort
37-
* @returns {number[]} - The sorted array in descending order
38-
*/
39-
function bubbleSortDescending(arr) {
40-
const n = arr.length;
41-
const result = [...arr];
42-
43-
for (let i = 0; i < n - 1; i++) {
44-
for (let j = 0; j < n - i - 1; j++) {
45-
if (result[j] < result[j + 1]) {
46-
const temp = result[j];
47-
result[j] = result[j + 1];
48-
result[j + 1] = temp;
49-
}
50-
}
51-
}
52-
53-
return result;
54-
}
5534

56-
module.exports = { bubbleSort, bubbleSortDescending };
35+
module.exports = { bubbleSort };

code_to_optimize/js/code_to_optimize_js/tests/bubble_sort.test.js

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
const { bubbleSort, bubbleSortDescending } = require('../bubble_sort');
1+
const { bubbleSort } = require('../bubble_sort');
22

33
describe('bubbleSort', () => {
44
test('sorts an empty array', () => {
@@ -54,17 +54,3 @@ describe('bubbleSort', () => {
5454
expect(result[result.length - 1]).toBe(96);
5555
});
5656
});
57-
58-
describe('bubbleSortDescending', () => {
59-
test('sorts in descending order', () => {
60-
expect(bubbleSortDescending([1, 3, 2, 5, 4])).toEqual([5, 4, 3, 2, 1]);
61-
});
62-
63-
test('handles empty array', () => {
64-
expect(bubbleSortDescending([])).toEqual([]);
65-
});
66-
67-
test('handles single element', () => {
68-
expect(bubbleSortDescending([42])).toEqual([42]);
69-
});
70-
});

codeflash/languages/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,7 @@ def run_behavioral_tests(
956956
project_root: Path | None = None,
957957
enable_coverage: bool = False,
958958
candidate_index: int = 0,
959+
test_framework: str | None = None,
959960
) -> tuple[Path, Any, Path | None, Path | None]:
960961
"""Run behavioral tests for this language.
961962
@@ -967,6 +968,7 @@ def run_behavioral_tests(
967968
project_root: Project root directory.
968969
enable_coverage: Whether to collect coverage information.
969970
candidate_index: Index of the candidate being tested.
971+
test_framework: Test framework to use
970972
971973
Returns:
972974
Tuple of (result_file_path, subprocess_result, coverage_path, config_path).

codeflash/languages/javascript/test_runner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,7 @@ def run_jest_behavioral_tests(
796796
# Build Jest command
797797
jest_cmd = [
798798
"npx",
799+
"-y",
799800
"jest",
800801
"--reporters=default",
801802
f"--reporters={CODEFLASH_JEST_REPORTER}",
@@ -1050,6 +1051,7 @@ def run_jest_benchmarking_tests(
10501051
# Build Jest command for performance tests
10511052
jest_cmd = [
10521053
"npx",
1054+
"-y",
10531055
"jest",
10541056
"--reporters=default",
10551057
f"--reporters={CODEFLASH_JEST_REPORTER}",
@@ -1220,6 +1222,7 @@ def run_jest_line_profile_tests(
12201222
# Build Jest command for line profiling - simple run without benchmarking loops
12211223
jest_cmd = [
12221224
"npx",
1225+
"-y",
12231226
"jest",
12241227
"--reporters=default",
12251228
f"--reporters={CODEFLASH_JEST_REPORTER}",

mcp_server/db.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _create_tables(conn: sqlite3.Connection) -> None:
3838
created_at TEXT NOT NULL,
3939
project_root TEXT NOT NULL,
4040
test_files TEXT NOT NULL,
41-
total_runtime_ns INTEGER,
41+
best_summed_runtime_ns INTEGER,
4242
total_tests INTEGER,
4343
passed INTEGER,
4444
failed INTEGER,
@@ -79,23 +79,23 @@ def store_run(
7979
raw_stdout: str = "",
8080
raw_stderr: str = "",
8181
) -> None:
82-
total_runtime_ns = test_results.total_passed_runtime() if test_results else 0
82+
best_summed_runtime_ns = test_results.total_passed_runtime() if test_results else 0
8383
total_tests = len(test_results)
8484
passed = sum(1 for r in test_results if r.did_pass)
8585
failed = total_tests - passed
8686
loops_executed = test_results.effective_loop_count() if test_results else 0
8787

8888
conn.execute(
8989
"INSERT INTO runs (run_id, run_type, created_at, project_root, test_files, "
90-
"total_runtime_ns, total_tests, passed, failed, loops_executed, raw_stdout, raw_stderr) "
90+
"best_summed_runtime_ns, total_tests, passed, failed, loops_executed, raw_stdout, raw_stderr) "
9191
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
9292
(
9393
run_id,
9494
run_type,
9595
datetime.now(timezone.utc).isoformat(),
9696
project_root,
9797
json.dumps(test_files),
98-
total_runtime_ns,
98+
best_summed_runtime_ns,
9999
total_tests,
100100
passed,
101101
failed,
@@ -200,7 +200,7 @@ def load_test_results(conn: sqlite3.Connection, run_id: str) -> TestResults:
200200

201201
def load_run_metadata(conn: sqlite3.Connection, run_id: str) -> dict[str, Any] | None:
202202
row = conn.execute(
203-
"SELECT run_type, created_at, project_root, test_files, total_runtime_ns, "
203+
"SELECT run_type, created_at, project_root, test_files, best_summed_runtime_ns, "
204204
"total_tests, passed, failed, loops_executed FROM runs WHERE run_id = ?",
205205
(run_id,),
206206
).fetchone()
@@ -211,7 +211,7 @@ def load_run_metadata(conn: sqlite3.Connection, run_id: str) -> dict[str, Any] |
211211
"created_at": row[1],
212212
"project_root": row[2],
213213
"test_files": json.loads(row[3]),
214-
"total_runtime_ns": row[4],
214+
"best_summed_runtime_ns": row[4],
215215
"total_tests": row[5],
216216
"passed": row[6],
217217
"failed": row[7],

mcp_server/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class BehavioralRunResult:
2323
total_tests: int
2424
passed: int
2525
failed: int
26-
total_runtime_ns: int
26+
best_summed_runtime_ns: int
2727
test_results: list[TestInvocationResult]
2828
errors: list[str] = field(default_factory=list)
2929

@@ -58,7 +58,7 @@ class SpeedupInfo:
5858
@dataclass
5959
class BenchmarkRunResult:
6060
run_id: str
61-
total_runtime_ns: int
61+
best_summed_runtime_ns: int
6262
loops_executed: int
6363
test_results: list[TestInvocationResult]
6464
speedup: SpeedupInfo | None = None

mcp_server/runner.py

Lines changed: 81 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import inspect
34
from enum import Enum
45
from pathlib import Path
56
from typing import TYPE_CHECKING
@@ -53,12 +54,63 @@ def _build_test_config(project_root: Path, tests_dir: Path | None = None) -> Tes
5354
return TestConfig(tests_root=effective_tests_dir, project_root_path=project_root, tests_project_rootdir=effective_tests_dir)
5455

5556

56-
def _find_call_positions(test_path: Path, function_name: str) -> list:
57-
"""Scan a test file's AST to find all call sites of the target function."""
57+
def _build_fallback_function_to_optimize(module_path: Path, function_name: str, language: str):
58+
from codeflash.models.function_types import FunctionParent, FunctionToOptimize
59+
60+
qualified_name_parts = function_name.split(".")
61+
simple_name = qualified_name_parts[-1]
62+
parents = [FunctionParent(name=part, type="ClassDef") for part in qualified_name_parts[:-1]]
63+
return FunctionToOptimize(
64+
function_name=simple_name,
65+
file_path=module_path,
66+
parents=parents,
67+
is_method=bool(parents),
68+
language=language,
69+
)
70+
71+
72+
def _resolve_function_to_optimize(lang_support: object, module_path: str, function_name: str, language: str):
73+
from codeflash.languages.base import FunctionFilterCriteria
74+
75+
source_path = Path(module_path).resolve()
76+
fallback = _build_fallback_function_to_optimize(source_path, function_name, language)
77+
78+
try:
79+
source = source_path.read_text(encoding="utf-8")
80+
except OSError:
81+
return fallback
82+
83+
criteria = FunctionFilterCriteria(require_return=False, require_export=False)
84+
discovered_functions = lang_support.discover_functions(source, source_path, criteria)
85+
if not discovered_functions:
86+
return fallback
87+
88+
requested_name = function_name.rsplit(".", 1)[-1]
89+
90+
qualified_matches = [func for func in discovered_functions if func.qualified_name == function_name]
91+
if len(qualified_matches) == 1:
92+
return qualified_matches[0]
93+
94+
top_level_matches = [func for func in discovered_functions if func.qualified_name == requested_name]
95+
if len(top_level_matches) == 1:
96+
return top_level_matches[0]
97+
98+
simple_matches = [func for func in discovered_functions if func.function_name == requested_name]
99+
if len(simple_matches) == 1:
100+
return simple_matches[0]
101+
102+
return fallback
103+
104+
105+
def _find_call_positions(test_path: Path, function_name: str, language: str) -> list:
106+
"""Scan a Python test file's AST to find all call sites of the target function."""
58107
import ast
59108

60109
from codeflash.models.models import CodePosition
61110

111+
if language != "python":
112+
return []
113+
62114
try:
63115
source = test_path.read_text(encoding="utf-8")
64116
tree = ast.parse(source)
@@ -77,6 +129,15 @@ def _find_call_positions(test_path: Path, function_name: str) -> list:
77129
return positions
78130

79131

132+
def _invoke_with_optional_test_framework(run_callable: object, *, test_framework: str | None = None, **kwargs: object):
133+
try:
134+
if test_framework is not None and "test_framework" in inspect.signature(run_callable).parameters:
135+
kwargs["test_framework"] = test_framework
136+
except (TypeError, ValueError):
137+
pass
138+
return run_callable(**kwargs)
139+
140+
80141
class _InstrumentedFiles:
81142
"""Context manager that instruments test files in-place and restores originals on exit."""
82143

@@ -100,16 +161,15 @@ def __init__(
100161
def __enter__(self) -> list[str]:
101162
from codeflash.languages.current import set_current_language
102163
from codeflash.languages.registry import get_language_support
103-
from codeflash.models.function_types import FunctionToOptimize
104164

105165
set_current_language(self.language)
106166
lang_support = get_language_support(self.language)
107167

108-
func_to_optimize = FunctionToOptimize(
168+
func_to_optimize = _resolve_function_to_optimize(
169+
lang_support=lang_support,
170+
module_path=self.module_path,
109171
function_name=self.function_name,
110-
file_path=Path(self.module_path),
111-
parents=(),
112-
qualified_name=self.function_name,
172+
language=self.language,
113173
)
114174

115175
instrument_mode = "behavior" if self.mode == TestingMode.BEHAVIORAL else "performance"
@@ -118,8 +178,8 @@ def __enter__(self) -> list[str]:
118178
for test_file in self.test_file_paths:
119179
test_path = Path(test_file).resolve()
120180

121-
call_positions = _find_call_positions(test_path, self.function_name)
122-
if not call_positions:
181+
call_positions = _find_call_positions(test_path, func_to_optimize.function_name, self.language)
182+
if self.language == "python" and not call_positions:
123183
instrumented_paths.append(test_file)
124184
continue
125185

@@ -157,6 +217,7 @@ def run_and_parse(
157217
target_duration_seconds: float = 0.5,
158218
function_name: str | None = None,
159219
module_path: str | None = None,
220+
test_framework: str | None = None,
160221
) -> tuple[TestResults, subprocess.CompletedProcess[str]]:
161222
from codeflash.languages.current import set_current_language
162223
from codeflash.languages.registry import get_language_support
@@ -172,11 +233,19 @@ def _execute(effective_files: list[str]) -> tuple[TestResults, subprocess.Comple
172233
test_files_obj = _build_test_files(effective_files, mode)
173234

174235
if mode == TestingMode.BEHAVIORAL:
175-
result_file_path, run_result, _, _ = lang_support.run_behavioral_tests(
176-
test_paths=test_files_obj, test_env=test_env, cwd=project_root, timeout=timeout, project_root=project_root
236+
result_file_path, run_result, _, _ = _invoke_with_optional_test_framework(
237+
lang_support.run_behavioral_tests,
238+
test_framework=test_framework,
239+
test_paths=test_files_obj,
240+
test_env=test_env,
241+
cwd=project_root,
242+
timeout=timeout,
243+
project_root=project_root,
177244
)
178245
else:
179-
result_file_path, run_result = lang_support.run_benchmarking_tests(
246+
result_file_path, run_result = _invoke_with_optional_test_framework(
247+
lang_support.run_benchmarking_tests,
248+
test_framework=test_framework,
180249
test_paths=test_files_obj,
181250
test_env=test_env,
182251
cwd=project_root,

mcp_server/server.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def run_behavioral_tests(
1919
run_id: str | None = None,
2020
function_name: str | None = None,
2121
module_path: str | None = None,
22+
test_framework: str | None = None,
2223
) -> dict[str, Any]:
2324
"""Run tests and capture function return values + timing for each test invocation.
2425
@@ -39,9 +40,10 @@ def run_behavioral_tests(
3940
run_id: Identifier for this run. Use descriptive IDs like "baseline-exp-1". Auto-generated UUID if omitted.
4041
function_name: Name of the function being optimized. When provided with module_path, enables automatic instrumentation of test files to capture return values and precise timing.
4142
module_path: Absolute path to the source file containing the function being optimized. Required together with function_name for instrumentation.
43+
test_framework: Optional test framework override. If omitted, codeflash will try to detect the framework automatically. For Python, the supported value is `pytest`. For JavaScript/TypeScript, supported values are `jest`, `vitest`, and `mocha`. For Java and Go, leave this unset because it is not used.
4244
4345
Returns:
44-
run_id, total_tests, passed, failed, total_runtime_ns, test_results (per-test detail), errors.
46+
run_id, total_tests, passed, failed, best_summed_runtime_ns, test_results (per-test detail), errors.
4547
4648
"""
4749
from mcp_server.tools.behavioral import run_behavioral_tests as impl
@@ -54,6 +56,7 @@ def run_behavioral_tests(
5456
run_id=run_id,
5557
function_name=function_name,
5658
module_path=module_path,
59+
test_framework=test_framework,
5760
)
5861

5962

@@ -98,6 +101,7 @@ def run_benchmarking_tests(
98101
baseline_run_id: str | None = None,
99102
function_name: str | None = None,
100103
module_path: str | None = None,
104+
test_framework: str | None = None,
101105
) -> dict[str, Any]:
102106
"""Run tests in multi-loop mode for stable timing, then compute speedup against a baseline.
103107
@@ -132,9 +136,10 @@ def run_benchmarking_tests(
132136
baseline_run_id: Run ID of a previous benchmark to compare against. Omit for baseline capture.
133137
function_name: Name of the function being benchmarked. When provided with module_path, enables automatic instrumentation with performance-mode timing capture.
134138
module_path: Absolute path to the source file containing the function. Required together with function_name.
139+
test_framework: Optional test framework override. If omitted, codeflash will try to detect the framework automatically. For Python, the supported value is `pytest`. For JavaScript/TypeScript, supported values are `jest`, `vitest`, and `mocha`. For Java and Go, leave this unset because it is not used.
135140
136141
Returns:
137-
run_id, total_runtime_ns, loops_executed, test_results, speedup (null if no baseline_run_id).
142+
run_id, best_summed_runtime_ns, loops_executed, test_results, speedup (null if no baseline_run_id).
138143
139144
"""
140145
from mcp_server.tools.benchmarking import run_benchmarking_tests as impl
@@ -151,6 +156,7 @@ def run_benchmarking_tests(
151156
baseline_run_id=baseline_run_id,
152157
function_name=function_name,
153158
module_path=module_path,
159+
test_framework=test_framework,
154160
)
155161

156162

0 commit comments

Comments
 (0)