11from __future__ import annotations
22
33import inspect
4+ from dataclasses import dataclass
45from enum import Enum
56from pathlib import Path
67from typing import TYPE_CHECKING
@@ -19,6 +20,12 @@ class TestingMode(str, Enum):
1920 BENCHMARKING = "benchmarking"
2021
2122
23+ @dataclass (frozen = True )
24+ class _ResolvedTestFile :
25+ original_path : Path
26+ effective_path : Path
27+
28+
2229def build_test_env (project_root : Path ) -> dict [str , str ]:
2330 env = make_env_with_project_root (project_root )
2431 env ["CODEFLASH_TEST_ITERATION" ] = "0"
@@ -29,18 +36,19 @@ def build_test_env(project_root: Path) -> dict[str, str]:
2936 return env
3037
3138
32- def _build_test_files (test_file_paths : list [str ], mode : TestingMode ) -> TestFiles :
39+ def _build_test_files (test_files : list [_ResolvedTestFile ], mode : TestingMode ) -> TestFiles :
3340 from codeflash .models .models import TestFile , TestFiles
3441 from codeflash .models .test_type import TestType
3542
3643 test_files_objs = []
37- for path_str in test_file_paths :
38- p = Path (path_str ).resolve ()
44+ for test_file in test_files :
45+ effective_path = test_file .effective_path .resolve ()
46+ original_path = test_file .original_path .resolve ()
3947 test_files_objs .append (
4048 TestFile (
41- instrumented_behavior_file_path = p ,
42- benchmarking_file_path = p if mode == TestingMode .BENCHMARKING else None ,
43- original_file_path = p ,
49+ instrumented_behavior_file_path = effective_path ,
50+ benchmarking_file_path = effective_path if mode == TestingMode .BENCHMARKING else None ,
51+ original_file_path = original_path ,
4452 test_type = TestType .EXISTING_UNIT_TEST ,
4553 )
4654 )
@@ -138,8 +146,31 @@ def _invoke_with_optional_test_framework(run_callable: object, *, test_framework
138146 return run_callable (** kwargs )
139147
140148
def _resolve_test_files(test_file_paths: list[str]) -> list[_ResolvedTestFile]:
    """Wrap plain test file paths as ``_ResolvedTestFile`` entries that run from their own location.

    Args:
        test_file_paths: Test file paths as strings; relative paths are resolved against the
            current working directory.

    Returns:
        One entry per input path, in the same order, whose ``original_path`` and
        ``effective_path`` are the same resolved path.
    """
    # Resolve each path exactly once: this avoids a second filesystem round-trip and guarantees
    # both fields agree even if a symlink changes between two separate resolutions.
    resolved_paths = (Path(raw_path).resolve() for raw_path in test_file_paths)
    return [_ResolvedTestFile(original_path=resolved, effective_path=resolved) for resolved in resolved_paths]
151+
152+
153+ def _instrumented_test_path (test_path : Path , language : str , mode : TestingMode ) -> Path :
154+ if language != "java" :
155+ return test_path
156+
157+ suffix = "__perfinstrumented" if mode == TestingMode .BEHAVIORAL else "__perfonlyinstrumented"
158+ if test_path .stem .endswith (suffix ):
159+ return test_path
160+ return test_path .with_name (f"{ test_path .stem } { suffix } { test_path .suffix } " )
161+
162+
163+ def _reset_java_compilation_cache (language : str ) -> None :
164+ if language != "java" :
165+ return
166+
167+ from codeflash .languages .java .test_runner import CompilationCache
168+
169+ CompilationCache .clear ()
170+
171+
141172class _InstrumentedFiles :
142- """Context manager that instruments test files in-place and restores originals on exit."""
173+ """Context manager that instruments MCP test files and restores originals on exit."""
143174
144175 def __init__ (
145176 self ,
@@ -157,8 +188,17 @@ def __init__(
157188 self .language = language
158189 self .mode = mode
159190 self ._backups : dict [Path , str ] = {}
191+ self ._created_files : set [Path ] = set ()
160192
161- def __enter__ (self ) -> list [str ]:
193+ def _write_instrumented_source (self , target_path : Path , code : str ) -> None :
194+ if target_path .exists ():
195+ self ._backups [target_path ] = target_path .read_text (encoding = "utf-8" )
196+ else :
197+ self ._created_files .add (target_path )
198+
199+ target_path .write_text (code , encoding = "utf-8" )
200+
201+ def __enter__ (self ) -> list [_ResolvedTestFile ]:
162202 from codeflash .languages .current import set_current_language
163203 from codeflash .languages .registry import get_language_support
164204
@@ -174,13 +214,14 @@ def __enter__(self) -> list[str]:
174214
175215 instrument_mode = "behavior" if self .mode == TestingMode .BEHAVIORAL else "performance"
176216
177- instrumented_paths : list [str ] = []
217+ instrumented_paths : list [_ResolvedTestFile ] = []
178218 for test_file in self .test_file_paths :
179219 test_path = Path (test_file ).resolve ()
220+ instrumented_path = _instrumented_test_path (test_path , self .language , self .mode )
180221
181222 call_positions = _find_call_positions (test_path , func_to_optimize .function_name , self .language )
182223 if self .language == "python" and not call_positions :
183- instrumented_paths .append (test_file )
224+ instrumented_paths .append (_ResolvedTestFile ( original_path = test_path , effective_path = test_path ) )
184225 continue
185226
186227 success , code = lang_support .instrument_existing_test (
@@ -192,18 +233,23 @@ def __enter__(self) -> list[str]:
192233 )
193234
194235 if success and code :
195- self ._backups [test_path ] = test_path .read_text (encoding = "utf-8" )
196- test_path .write_text (code , encoding = "utf-8" )
197- instrumented_paths .append (str (test_path ))
236+ self ._write_instrumented_source (instrumented_path , code )
237+ instrumented_paths .append (_ResolvedTestFile (original_path = test_path , effective_path = instrumented_path ))
198238 else :
199- instrumented_paths .append (test_file )
239+ instrumented_paths .append (_ResolvedTestFile ( original_path = test_path , effective_path = test_path ) )
200240
201241 return instrumented_paths
202242
203243 def __exit__ (self , * _exc : object ) -> None :
244+ # restore original code for backup files
204245 for path , original_content in self ._backups .items ():
205246 path .write_text (original_content , encoding = "utf-8" )
247+
248+ # remove new files
249+ for path in self ._created_files :
250+ path .unlink (missing_ok = True )
206251 self ._backups .clear ()
252+ self ._created_files .clear ()
207253
208254
209255def run_and_parse (
@@ -225,11 +271,12 @@ def run_and_parse(
225271
226272 set_current_language (language )
227273 lang_support = get_language_support (language )
274+ _reset_java_compilation_cache (language )
228275
229276 test_env = build_test_env (project_root )
230277 test_config = _build_test_config (project_root )
231278
232- def _execute (effective_files : list [str ]) -> tuple [TestResults , subprocess .CompletedProcess [str ]]:
279+ def _execute (effective_files : list [_ResolvedTestFile ]) -> tuple [TestResults , subprocess .CompletedProcess [str ]]:
233280 test_files_obj = _build_test_files (effective_files , mode )
234281
235282 if mode == TestingMode .BEHAVIORAL :
@@ -281,4 +328,4 @@ def _execute(effective_files: list[str]) -> tuple[TestResults, subprocess.Comple
281328 ) as effective_files :
282329 return _execute (effective_files )
283330 else :
284- return _execute (test_files )
331+ return _execute (_resolve_test_files ( test_files ) )
0 commit comments