Skip to content

Commit b5a7300

Browse files
Fix MCP Java test instrumentation paths and use larger Java benchmark inputs
1 parent f18ee12 commit b5a7300

4 files changed

Lines changed: 144 additions & 165 deletions

File tree

code_to_optimize/java/src/main/java/com/example/BubbleSort.java

Lines changed: 0 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -36,119 +36,4 @@ public static int[] bubbleSort(int[] arr) {
3636
return result;
3737
}
3838

39-
/**
40-
* Sort an array in descending order using bubble sort.
41-
*
42-
* @param arr Array to sort
43-
* @return New sorted array (descending order)
44-
*/
45-
public static int[] bubbleSortDescending(int[] arr) {
46-
if (arr == null || arr.length == 0) {
47-
return arr;
48-
}
49-
50-
int[] result = new int[arr.length];
51-
for (int i = 0; i < arr.length; i++) {
52-
result[i] = arr[i];
53-
}
54-
55-
int n = result.length;
56-
57-
for (int i = 0; i < n - 1; i++) {
58-
for (int j = 0; j < n - i - 1; j++) {
59-
if (result[j] < result[j + 1]) {
60-
int temp = result[j];
61-
result[j] = result[j + 1];
62-
result[j + 1] = temp;
63-
}
64-
}
65-
}
66-
67-
return result;
68-
}
69-
70-
/**
71-
* Sort an array using insertion sort algorithm.
72-
*
73-
* @param arr Array to sort
74-
* @return New sorted array
75-
*/
76-
public static int[] insertionSort(int[] arr) {
77-
if (arr == null || arr.length == 0) {
78-
return arr;
79-
}
80-
81-
int[] result = new int[arr.length];
82-
for (int i = 0; i < arr.length; i++) {
83-
result[i] = arr[i];
84-
}
85-
86-
int n = result.length;
87-
88-
for (int i = 1; i < n; i++) {
89-
int key = result[i];
90-
int j = i - 1;
91-
92-
while (j >= 0 && result[j] > key) {
93-
result[j + 1] = result[j];
94-
j = j - 1;
95-
}
96-
result[j + 1] = key;
97-
}
98-
99-
return result;
100-
}
101-
102-
/**
103-
* Sort an array using selection sort algorithm.
104-
*
105-
* @param arr Array to sort
106-
* @return New sorted array
107-
*/
108-
public static int[] selectionSort(int[] arr) {
109-
if (arr == null || arr.length == 0) {
110-
return arr;
111-
}
112-
113-
int[] result = new int[arr.length];
114-
for (int i = 0; i < arr.length; i++) {
115-
result[i] = arr[i];
116-
}
117-
118-
int n = result.length;
119-
120-
for (int i = 0; i < n - 1; i++) {
121-
int minIdx = i;
122-
for (int j = i + 1; j < n; j++) {
123-
if (result[j] < result[minIdx]) {
124-
minIdx = j;
125-
}
126-
}
127-
128-
int temp = result[minIdx];
129-
result[minIdx] = result[i];
130-
result[i] = temp;
131-
}
132-
133-
return result;
134-
}
135-
136-
/**
137-
* Check if an array is sorted in ascending order.
138-
*
139-
* @param arr Array to check
140-
* @return true if sorted in ascending order
141-
*/
142-
public static boolean isSorted(int[] arr) {
143-
if (arr == null || arr.length <= 1) {
144-
return true;
145-
}
146-
147-
for (int i = 0; i < arr.length - 1; i++) {
148-
if (arr[i] > arr[i + 1]) {
149-
return false;
150-
}
151-
}
152-
return true;
153-
}
15439
}
Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
package com.example;
22

33
import org.junit.jupiter.api.Test;
4+
5+
import java.util.Arrays;
6+
47
import static org.junit.jupiter.api.Assertions.*;
58

69
/**
710
* Tests for BubbleSort sorting algorithms.
811
*/
912
class BubbleSortTest {
13+
private static final int LARGE_SORT_SIZE = 5000;
1014

1115
@Test
1216
void testBubbleSort() {
13-
assertArrayEquals(new int[]{1, 2, 3, 4, 5}, BubbleSort.bubbleSort(new int[]{5, 3, 1, 4, 2}));
17+
assertArrayEquals(ascendingRange(LARGE_SORT_SIZE), BubbleSort.bubbleSort(descendingRange(LARGE_SORT_SIZE)));
1418
assertArrayEquals(new int[]{1, 2, 3}, BubbleSort.bubbleSort(new int[]{3, 2, 1}));
1519
assertArrayEquals(new int[]{1}, BubbleSort.bubbleSort(new int[]{1}));
1620
assertArrayEquals(new int[]{}, BubbleSort.bubbleSort(new int[]{}));
@@ -19,56 +23,57 @@ void testBubbleSort() {
1923

2024
@Test
2125
void testBubbleSortAlreadySorted() {
22-
assertArrayEquals(new int[]{1, 2, 3, 4, 5}, BubbleSort.bubbleSort(new int[]{1, 2, 3, 4, 5}));
26+
int[] sorted = ascendingRange(LARGE_SORT_SIZE);
27+
assertArrayEquals(sorted, BubbleSort.bubbleSort(sorted));
2328
}
2429

2530
@Test
2631
void testBubbleSortWithDuplicates() {
27-
assertArrayEquals(new int[]{1, 2, 2, 3, 3, 4}, BubbleSort.bubbleSort(new int[]{3, 2, 4, 1, 3, 2}));
32+
int[] input = duplicateHeavyRange(LARGE_SORT_SIZE);
33+
assertArrayEquals(sortedCopy(input), BubbleSort.bubbleSort(input));
2834
}
2935

3036
@Test
3137
void testBubbleSortWithNegatives() {
32-
assertArrayEquals(new int[]{-5, -2, 0, 3, 7}, BubbleSort.bubbleSort(new int[]{3, -2, 7, 0, -5}));
38+
int[] input = mixedSignedRange(LARGE_SORT_SIZE);
39+
assertArrayEquals(sortedCopy(input), BubbleSort.bubbleSort(input));
3340
}
3441

35-
@Test
36-
void testBubbleSortDescending() {
37-
assertArrayEquals(new int[]{5, 4, 3, 2, 1}, BubbleSort.bubbleSortDescending(new int[]{1, 3, 5, 2, 4}));
38-
assertArrayEquals(new int[]{3, 2, 1}, BubbleSort.bubbleSortDescending(new int[]{1, 2, 3}));
39-
assertArrayEquals(new int[]{}, BubbleSort.bubbleSortDescending(new int[]{}));
42+
private static int[] ascendingRange(int size) {
43+
int[] arr = new int[size];
44+
for (int i = 0; i < size; i++) {
45+
arr[i] = i;
46+
}
47+
return arr;
4048
}
4149

42-
@Test
43-
void testInsertionSort() {
44-
assertArrayEquals(new int[]{1, 2, 3, 4, 5}, BubbleSort.insertionSort(new int[]{5, 3, 1, 4, 2}));
45-
assertArrayEquals(new int[]{1, 2, 3}, BubbleSort.insertionSort(new int[]{3, 2, 1}));
46-
assertArrayEquals(new int[]{1}, BubbleSort.insertionSort(new int[]{1}));
47-
assertArrayEquals(new int[]{}, BubbleSort.insertionSort(new int[]{}));
50+
private static int[] descendingRange(int size) {
51+
int[] arr = new int[size];
52+
for (int i = 0; i < size; i++) {
53+
arr[i] = size - i - 1;
54+
}
55+
return arr;
4856
}
4957

50-
@Test
51-
void testSelectionSort() {
52-
assertArrayEquals(new int[]{1, 2, 3, 4, 5}, BubbleSort.selectionSort(new int[]{5, 3, 1, 4, 2}));
53-
assertArrayEquals(new int[]{1, 2, 3}, BubbleSort.selectionSort(new int[]{3, 2, 1}));
54-
assertArrayEquals(new int[]{1}, BubbleSort.selectionSort(new int[]{1}));
58+
private static int[] duplicateHeavyRange(int size) {
59+
int[] arr = new int[size];
60+
for (int i = 0; i < size; i++) {
61+
arr[i] = (size - i - 1) % 32;
62+
}
63+
return arr;
5564
}
5665

57-
@Test
58-
void testIsSorted() {
59-
assertTrue(BubbleSort.isSorted(new int[]{1, 2, 3, 4, 5}));
60-
assertTrue(BubbleSort.isSorted(new int[]{1}));
61-
assertTrue(BubbleSort.isSorted(new int[]{}));
62-
assertTrue(BubbleSort.isSorted(null));
63-
assertFalse(BubbleSort.isSorted(new int[]{5, 3, 1}));
64-
assertFalse(BubbleSort.isSorted(new int[]{1, 3, 2}));
66+
private static int[] mixedSignedRange(int size) {
67+
int[] arr = new int[size];
68+
for (int i = 0; i < size; i++) {
69+
arr[i] = (i % 2 == 0) ? (size - i) : -(size - i);
70+
}
71+
return arr;
6572
}
6673

67-
@Test
68-
void testBubbleSortDoesNotMutateInput() {
69-
int[] original = {5, 3, 1, 4, 2};
70-
int[] copy = {5, 3, 1, 4, 2};
71-
BubbleSort.bubbleSort(original);
72-
assertArrayEquals(copy, original);
74+
private static int[] sortedCopy(int[] arr) {
75+
int[] expected = arr.clone();
76+
Arrays.sort(expected);
77+
return expected;
7378
}
7479
}

mcp_server/runner.py

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import inspect
4+
from dataclasses import dataclass
45
from enum import Enum
56
from pathlib import Path
67
from typing import TYPE_CHECKING
@@ -19,6 +20,12 @@ class TestingMode(str, Enum):
1920
BENCHMARKING = "benchmarking"
2021

2122

23+
@dataclass(frozen=True)
24+
class _ResolvedTestFile:
25+
original_path: Path
26+
effective_path: Path
27+
28+
2229
def build_test_env(project_root: Path) -> dict[str, str]:
2330
env = make_env_with_project_root(project_root)
2431
env["CODEFLASH_TEST_ITERATION"] = "0"
@@ -29,18 +36,19 @@ def build_test_env(project_root: Path) -> dict[str, str]:
2936
return env
3037

3138

32-
def _build_test_files(test_file_paths: list[str], mode: TestingMode) -> TestFiles:
39+
def _build_test_files(test_files: list[_ResolvedTestFile], mode: TestingMode) -> TestFiles:
3340
from codeflash.models.models import TestFile, TestFiles
3441
from codeflash.models.test_type import TestType
3542

3643
test_files_objs = []
37-
for path_str in test_file_paths:
38-
p = Path(path_str).resolve()
44+
for test_file in test_files:
45+
effective_path = test_file.effective_path.resolve()
46+
original_path = test_file.original_path.resolve()
3947
test_files_objs.append(
4048
TestFile(
41-
instrumented_behavior_file_path=p,
42-
benchmarking_file_path=p if mode == TestingMode.BENCHMARKING else None,
43-
original_file_path=p,
49+
instrumented_behavior_file_path=effective_path,
50+
benchmarking_file_path=effective_path if mode == TestingMode.BENCHMARKING else None,
51+
original_file_path=original_path,
4452
test_type=TestType.EXISTING_UNIT_TEST,
4553
)
4654
)
@@ -138,8 +146,31 @@ def _invoke_with_optional_test_framework(run_callable: object, *, test_framework
138146
return run_callable(**kwargs)
139147

140148

149+
def _resolve_test_files(test_file_paths: list[str]) -> list[_ResolvedTestFile]:
150+
return [_ResolvedTestFile(original_path=Path(path).resolve(), effective_path=Path(path).resolve()) for path in test_file_paths]
151+
152+
153+
def _instrumented_test_path(test_path: Path, language: str, mode: TestingMode) -> Path:
154+
if language != "java":
155+
return test_path
156+
157+
suffix = "__perfinstrumented" if mode == TestingMode.BEHAVIORAL else "__perfonlyinstrumented"
158+
if test_path.stem.endswith(suffix):
159+
return test_path
160+
return test_path.with_name(f"{test_path.stem}{suffix}{test_path.suffix}")
161+
162+
163+
def _reset_java_compilation_cache(language: str) -> None:
164+
if language != "java":
165+
return
166+
167+
from codeflash.languages.java.test_runner import CompilationCache
168+
169+
CompilationCache.clear()
170+
171+
141172
class _InstrumentedFiles:
142-
"""Context manager that instruments test files in-place and restores originals on exit."""
173+
"""Context manager that instruments MCP test files and restores originals on exit."""
143174

144175
def __init__(
145176
self,
@@ -157,8 +188,17 @@ def __init__(
157188
self.language = language
158189
self.mode = mode
159190
self._backups: dict[Path, str] = {}
191+
self._created_files: set[Path] = set()
160192

161-
def __enter__(self) -> list[str]:
193+
def _write_instrumented_source(self, target_path: Path, code: str) -> None:
194+
if target_path.exists():
195+
self._backups[target_path] = target_path.read_text(encoding="utf-8")
196+
else:
197+
self._created_files.add(target_path)
198+
199+
target_path.write_text(code, encoding="utf-8")
200+
201+
def __enter__(self) -> list[_ResolvedTestFile]:
162202
from codeflash.languages.current import set_current_language
163203
from codeflash.languages.registry import get_language_support
164204

@@ -174,13 +214,14 @@ def __enter__(self) -> list[str]:
174214

175215
instrument_mode = "behavior" if self.mode == TestingMode.BEHAVIORAL else "performance"
176216

177-
instrumented_paths: list[str] = []
217+
instrumented_paths: list[_ResolvedTestFile] = []
178218
for test_file in self.test_file_paths:
179219
test_path = Path(test_file).resolve()
220+
instrumented_path = _instrumented_test_path(test_path, self.language, self.mode)
180221

181222
call_positions = _find_call_positions(test_path, func_to_optimize.function_name, self.language)
182223
if self.language == "python" and not call_positions:
183-
instrumented_paths.append(test_file)
224+
instrumented_paths.append(_ResolvedTestFile(original_path=test_path, effective_path=test_path))
184225
continue
185226

186227
success, code = lang_support.instrument_existing_test(
@@ -192,18 +233,23 @@ def __enter__(self) -> list[str]:
192233
)
193234

194235
if success and code:
195-
self._backups[test_path] = test_path.read_text(encoding="utf-8")
196-
test_path.write_text(code, encoding="utf-8")
197-
instrumented_paths.append(str(test_path))
236+
self._write_instrumented_source(instrumented_path, code)
237+
instrumented_paths.append(_ResolvedTestFile(original_path=test_path, effective_path=instrumented_path))
198238
else:
199-
instrumented_paths.append(test_file)
239+
instrumented_paths.append(_ResolvedTestFile(original_path=test_path, effective_path=test_path))
200240

201241
return instrumented_paths
202242

203243
def __exit__(self, *_exc: object) -> None:
244+
# restore original code for backup files
204245
for path, original_content in self._backups.items():
205246
path.write_text(original_content, encoding="utf-8")
247+
248+
# remove new files
249+
for path in self._created_files:
250+
path.unlink(missing_ok=True)
206251
self._backups.clear()
252+
self._created_files.clear()
207253

208254

209255
def run_and_parse(
@@ -225,11 +271,12 @@ def run_and_parse(
225271

226272
set_current_language(language)
227273
lang_support = get_language_support(language)
274+
_reset_java_compilation_cache(language)
228275

229276
test_env = build_test_env(project_root)
230277
test_config = _build_test_config(project_root)
231278

232-
def _execute(effective_files: list[str]) -> tuple[TestResults, subprocess.CompletedProcess[str]]:
279+
def _execute(effective_files: list[_ResolvedTestFile]) -> tuple[TestResults, subprocess.CompletedProcess[str]]:
233280
test_files_obj = _build_test_files(effective_files, mode)
234281

235282
if mode == TestingMode.BEHAVIORAL:
@@ -281,4 +328,4 @@ def _execute(effective_files: list[str]) -> tuple[TestResults, subprocess.Comple
281328
) as effective_files:
282329
return _execute(effective_files)
283330
else:
284-
return _execute(test_files)
331+
return _execute(_resolve_test_files(test_files))

0 commit comments

Comments
 (0)