Skip to content

Commit cd1f353

Browse files
authored
Merge pull request #1397 from codeflash-ai/fix/path_resolution_for_esm
ESM config compatibility for vitest
2 parents 67b48be + ae4df22 commit cd1f353

8 files changed

Lines changed: 269 additions & 6 deletions

File tree

codeflash/languages/javascript/instrument.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,115 @@ def is_relevant_import(module_path: str) -> bool:
901901
return test_code
902902

903903

904+
def fix_import_path_for_test_location(
905+
test_code: str, source_file_path: Path, test_file_path: Path, module_root: Path
906+
) -> str:
907+
"""Fix import paths in generated test code to be relative to test file location.
908+
909+
The AI may generate tests with import paths that are relative to the module root
910+
(e.g., 'apps/web/app/file') instead of relative to where the test file is located
911+
(e.g., '../../app/file'). This function fixes such imports.
912+
913+
Args:
914+
test_code: The generated test code.
915+
source_file_path: Absolute path to the source file being tested.
916+
test_file_path: Absolute path to where the test file will be written.
917+
module_root: Root directory of the module/project.
918+
919+
Returns:
920+
Test code with corrected import paths.
921+
922+
"""
923+
import os
924+
925+
# Calculate the correct relative import path from test file to source file
926+
test_dir = test_file_path.parent
927+
try:
928+
correct_rel_path = os.path.relpath(source_file_path, test_dir)
929+
correct_rel_path = correct_rel_path.replace("\\", "/")
930+
# Remove file extension for JS/TS imports
931+
for ext in [".tsx", ".ts", ".jsx", ".js", ".mjs", ".cjs"]:
932+
if correct_rel_path.endswith(ext):
933+
correct_rel_path = correct_rel_path[: -len(ext)]
934+
break
935+
# Ensure it starts with ./ or ../
936+
if not correct_rel_path.startswith("."):
937+
correct_rel_path = "./" + correct_rel_path
938+
except ValueError:
939+
# Can't compute relative path (different drives on Windows)
940+
return test_code
941+
942+
# Try to compute what incorrect path the AI might have generated
943+
# The AI often uses module_root-relative paths like 'apps/web/app/...'
944+
try:
945+
source_rel_to_module = os.path.relpath(source_file_path, module_root)
946+
source_rel_to_module = source_rel_to_module.replace("\\", "/")
947+
# Remove extension
948+
for ext in [".tsx", ".ts", ".jsx", ".js", ".mjs", ".cjs"]:
949+
if source_rel_to_module.endswith(ext):
950+
source_rel_to_module = source_rel_to_module[: -len(ext)]
951+
break
952+
except ValueError:
953+
return test_code
954+
955+
# Also check for project root-relative paths (including module_root in path)
956+
try:
957+
project_root = module_root.parent if module_root.name in ["src", "lib", "app", "web", "apps"] else module_root
958+
source_rel_to_project = os.path.relpath(source_file_path, project_root)
959+
source_rel_to_project = source_rel_to_project.replace("\\", "/")
960+
for ext in [".tsx", ".ts", ".jsx", ".js", ".mjs", ".cjs"]:
961+
if source_rel_to_project.endswith(ext):
962+
source_rel_to_project = source_rel_to_project[: -len(ext)]
963+
break
964+
except ValueError:
965+
source_rel_to_project = None
966+
967+
# Source file name (for matching module paths that end with the file name)
968+
source_name = source_file_path.stem
969+
970+
# Patterns to find import statements
971+
# ESM: import { func } from 'path' or import func from 'path'
972+
esm_import_pattern = re.compile(r"(import\s+(?:{[^}]+}|\w+)\s+from\s+['\"])([^'\"]+)(['\"])")
973+
# CommonJS: const { func } = require('path') or const func = require('path')
974+
cjs_require_pattern = re.compile(
975+
r"((?:const|let|var)\s+(?:{[^}]+}|\w+)\s*=\s*require\s*\(\s*['\"])([^'\"]+)(['\"])"
976+
)
977+
978+
def should_fix_path(import_path: str) -> bool:
979+
"""Check if this import path looks like it should point to our source file."""
980+
# Skip relative imports that already look correct
981+
if import_path.startswith(("./", "../")):
982+
return False
983+
# Skip package imports (no path separators or start with @)
984+
if "/" not in import_path and "\\" not in import_path:
985+
return False
986+
if import_path.startswith("@") and "/" in import_path:
987+
# Could be an alias like @/utils - skip these
988+
return False
989+
# Check if it looks like it points to our source file
990+
if import_path == source_rel_to_module:
991+
return True
992+
if source_rel_to_project and import_path == source_rel_to_project:
993+
return True
994+
if import_path.endswith((source_name, "/" + source_name)):
995+
return True
996+
return False
997+
998+
def fix_import(match: re.Match[str]) -> str:
999+
"""Replace incorrect import path with correct relative path."""
1000+
prefix = match.group(1)
1001+
import_path = match.group(2)
1002+
suffix = match.group(3)
1003+
1004+
if should_fix_path(import_path):
1005+
logger.debug(f"Fixing import path: {import_path} -> {correct_rel_path}")
1006+
return f"{prefix}{correct_rel_path}{suffix}"
1007+
return match.group(0)
1008+
1009+
test_code = esm_import_pattern.sub(fix_import, test_code)
1010+
return cjs_require_pattern.sub(fix_import, test_code)
1011+
1012+
9041013
def get_instrumented_test_path(original_path: Path, mode: str) -> Path:
9051014
"""Generate path for instrumented test file.
9061015

codeflash/languages/javascript/parse.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,19 @@ def parse_jest_test_xml(
175175
logger.debug(f"Found {marker_count} timing start markers in Jest stdout")
176176
else:
177177
logger.debug(f"No timing start markers found in Jest stdout (len={len(global_stdout)})")
178+
# Check for END markers with duration (perf test markers)
179+
end_marker_count = len(jest_end_pattern.findall(global_stdout))
180+
if end_marker_count > 0:
181+
logger.debug(
182+
f"[PERF-DEBUG] Found {end_marker_count} END timing markers with duration in Jest stdout"
183+
)
184+
# Sample a few markers to verify loop indices
185+
end_samples = list(jest_end_pattern.finditer(global_stdout))[:5]
186+
for sample in end_samples:
187+
groups = sample.groups()
188+
logger.debug(f"[PERF-DEBUG] Sample END marker: loopIndex={groups[3]}, duration={groups[5]}")
189+
else:
190+
logger.debug("[PERF-DEBUG] No END markers with duration found in Jest stdout")
178191
except (AttributeError, UnicodeDecodeError):
179192
global_stdout = ""
180193

@@ -197,6 +210,14 @@ def parse_jest_test_xml(
197210
key = match.groups()[:5]
198211
end_matches_dict[key] = match
199212

213+
# Debug: log suite-level END marker parsing for perf tests
214+
if end_matches_dict:
215+
# Get unique loop indices from the parsed END markers
216+
loop_indices = sorted({int(k[3]) if k[3].isdigit() else 1 for k in end_matches_dict})
217+
logger.debug(
218+
f"[PERF-DEBUG] Suite {suite_count}: parsed {len(end_matches_dict)} END markers from suite_stdout, loop_index range: {min(loop_indices)}-{max(loop_indices)}"
219+
)
220+
200221
# Also collect timing markers from testcase-level system-out (Vitest puts output at testcase level)
201222
for tc in suite:
202223
tc_system_out = tc._elem.find("system-out") # noqa: SLF001
@@ -327,6 +348,13 @@ def parse_jest_test_xml(
327348
sanitized_test_name = re.sub(r"[!#: ()\[\]{}|\\/*?^$.+\-]", "_", test_name)
328349
matching_starts = [m for m in start_matches if sanitized_test_name in m.group(2)]
329350

351+
# Debug: log which branch we're taking
352+
logger.debug(
353+
f"[FLOW-DEBUG] Testcase '{test_name[:50]}': "
354+
f"total_start_matches={len(start_matches)}, matching_starts={len(matching_starts)}, "
355+
f"total_end_matches={len(end_matches_dict)}"
356+
)
357+
330358
# For performance tests (capturePerf), there are no START markers - only END markers with duration
331359
# Check for END markers directly if no START markers found
332360
matching_ends_direct = []
@@ -337,6 +365,28 @@ def parse_jest_test_xml(
337365
# end_key is (module, testName, funcName, loopIndex, invocationId)
338366
if len(end_key) >= 2 and sanitized_test_name in end_key[1]:
339367
matching_ends_direct.append(end_match)
368+
# Debug: log matching results for perf tests
369+
if matching_ends_direct:
370+
loop_indices = [int(m.groups()[3]) if m.groups()[3].isdigit() else 1 for m in matching_ends_direct]
371+
logger.debug(
372+
f"[PERF-MATCH] Testcase '{test_name[:40]}': matched {len(matching_ends_direct)} END markers, "
373+
f"loop_index range: {min(loop_indices)}-{max(loop_indices)}"
374+
)
375+
elif end_matches_dict:
376+
# No matches but we have END markers - check why
377+
sample_keys = list(end_matches_dict.keys())[:3]
378+
logger.debug(
379+
f"[PERF-MISMATCH] Testcase '{test_name[:40]}': no matches found. "
380+
f"sanitized_test_name='{sanitized_test_name[:50]}', "
381+
f"sample end_keys={[k[1][:30] if len(k) >= 2 else k for k in sample_keys]}"
382+
)
383+
384+
# Log if we're skipping the matching_ends_direct branch
385+
if matching_starts and end_matches_dict:
386+
logger.debug(
387+
f"[FLOW-SKIP] Testcase '{test_name[:40]}': has {len(matching_starts)} START markers, "
388+
f"skipping {len(end_matches_dict)} END markers (behavior test mode)"
389+
)
340390

341391
if not matching_starts and not matching_ends_direct:
342392
# No timing markers found - use JUnit XML time attribute as fallback
@@ -373,11 +423,13 @@ def parse_jest_test_xml(
373423
)
374424
elif matching_ends_direct:
375425
# Performance test format: process END markers directly (no START markers)
426+
loop_indices_found = []
376427
for end_match in matching_ends_direct:
377428
groups = end_match.groups()
378429
# groups: (module, testName, funcName, loopIndex, invocationId, durationNs)
379430
func_name = groups[2]
380431
loop_index = int(groups[3]) if groups[3].isdigit() else 1
432+
loop_indices_found.append(loop_index)
381433
line_id = groups[4]
382434
try:
383435
runtime = int(groups[5])
@@ -403,6 +455,12 @@ def parse_jest_test_xml(
403455
stdout="",
404456
)
405457
)
458+
if loop_indices_found:
459+
logger.debug(
460+
f"[LOOP-DEBUG] Testcase '{test_name}': processed {len(matching_ends_direct)} END markers, "
461+
f"loop_index range: {min(loop_indices_found)}-{max(loop_indices_found)}, "
462+
f"total results so far: {len(test_results.test_results)}"
463+
)
406464
else:
407465
# Process each timing marker
408466
for match in matching_starts:
@@ -454,5 +512,19 @@ def parse_jest_test_xml(
454512
f"Jest XML parsing complete: {len(test_results.test_results)} results "
455513
f"from {suite_count} suites, {testcase_count} testcases"
456514
)
515+
# Debug: show loop_index distribution for perf analysis
516+
if test_results.test_results:
517+
loop_indices = [r.loop_index for r in test_results.test_results]
518+
unique_loop_indices = sorted(set(loop_indices))
519+
min_idx, max_idx = min(unique_loop_indices), max(unique_loop_indices)
520+
logger.debug(
521+
f"[LOOP-SUMMARY] Results loop_index: min={min_idx}, max={max_idx}, "
522+
f"unique_count={len(unique_loop_indices)}, total_results={len(loop_indices)}"
523+
)
524+
if max_idx == 1 and len(loop_indices) > 1:
525+
logger.warning(
526+
f"[LOOP-WARNING] All {len(loop_indices)} results have loop_index=1. "
527+
"Perf test markers may not have been parsed correctly."
528+
)
457529

458530
return test_results

codeflash/languages/javascript/support.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2134,13 +2134,15 @@ def run_benchmarking_tests(
21342134
from codeflash.languages.test_framework import get_js_test_framework_or_default
21352135

21362136
framework = test_framework or get_js_test_framework_or_default()
2137+
logger.debug("run_benchmarking_tests called with framework=%s", framework)
21372138

21382139
# Use JS-specific high max_loops - actual loop count is limited by target_duration
21392140
effective_max_loops = self.JS_BENCHMARKING_MAX_LOOPS
21402141

21412142
if framework == "vitest":
21422143
from codeflash.languages.javascript.vitest_runner import run_vitest_benchmarking_tests
21432144

2145+
logger.debug("Dispatching to run_vitest_benchmarking_tests")
21442146
return run_vitest_benchmarking_tests(
21452147
test_paths=test_paths,
21462148
test_env=test_env,

codeflash/languages/javascript/vitest_runner.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def _ensure_codeflash_vitest_config(project_root: Path) -> Path | None:
192192
logger.debug("Detected vitest workspace configuration - skipping custom config")
193193
return None
194194

195-
codeflash_config_path = project_root / "codeflash.vitest.config.js"
195+
codeflash_config_path = project_root / "codeflash.vitest.config.mjs"
196196

197197
# If already exists, use it
198198
if codeflash_config_path.exists():
@@ -281,7 +281,7 @@ def _build_vitest_behavioral_command(
281281

282282
# For monorepos with restrictive vitest configs (e.g., include: test/**/*.test.ts),
283283
# we need to create a custom config that allows all test patterns.
284-
# This is done by creating a codeflash.vitest.config.js file.
284+
# This is done by creating a codeflash.vitest.config.mjs file.
285285
if project_root:
286286
codeflash_vitest_config = _ensure_codeflash_vitest_config(project_root)
287287
if codeflash_vitest_config:
@@ -520,6 +520,9 @@ def run_vitest_benchmarking_tests(
520520
) -> tuple[Path, subprocess.CompletedProcess]:
521521
"""Run Vitest benchmarking tests with external looping from Python.
522522
523+
NOTE: This function MUST use benchmarking_file_path (perf tests with capturePerf),
524+
NOT instrumented_behavior_file_path (behavior tests with capture).
525+
523526
Uses external process-level looping to run tests multiple times and
524527
collect timing data. This matches the Python pytest approach where
525528
looping is controlled externally for simplicity.
@@ -544,6 +547,26 @@ def run_vitest_benchmarking_tests(
544547
# Get performance test files
545548
test_files = [Path(file.benchmarking_file_path) for file in test_paths.test_files if file.benchmarking_file_path]
546549

550+
# Log test file selection
551+
total_test_files = len(test_paths.test_files)
552+
perf_test_files = len(test_files)
553+
logger.debug(
554+
f"Vitest benchmark test file selection: {perf_test_files}/{total_test_files} have benchmarking_file_path"
555+
)
556+
if perf_test_files == 0:
557+
logger.warning("No perf test files found! Cannot run benchmarking tests.")
558+
for tf in test_paths.test_files:
559+
logger.warning(
560+
f"Test file: behavior={tf.instrumented_behavior_file_path}, perf={tf.benchmarking_file_path}"
561+
)
562+
elif perf_test_files < total_test_files:
563+
for tf in test_paths.test_files:
564+
if not tf.benchmarking_file_path:
565+
logger.warning(f"Missing benchmarking_file_path: behavior={tf.instrumented_behavior_file_path}")
566+
else:
567+
for tf in test_files[:3]: # Log first 3 perf test files
568+
logger.debug(f"Using perf test file: {tf}")
569+
547570
# Use provided project_root, or detect it as fallback
548571
if project_root is None and test_files:
549572
project_root = _find_vitest_project_root(test_files[0])
@@ -574,14 +597,25 @@ def run_vitest_benchmarking_tests(
574597
vitest_env["CODEFLASH_PERF_STABILITY_CHECK"] = "true" if stability_check else "false"
575598
vitest_env["CODEFLASH_LOOP_INDEX"] = "1"
576599

600+
# Set test module for marker identification (use first test file as reference)
601+
if test_files:
602+
test_module_path = str(
603+
test_files[0].relative_to(effective_cwd)
604+
if test_files[0].is_relative_to(effective_cwd)
605+
else test_files[0].name
606+
)
607+
vitest_env["CODEFLASH_TEST_MODULE"] = test_module_path
608+
logger.debug(f"[VITEST-BENCH] Set CODEFLASH_TEST_MODULE={test_module_path}")
609+
577610
# Total timeout for the entire benchmark run
578611
total_timeout = max(120, (target_duration_ms // 1000) + 60, timeout or 120)
579612

580-
logger.debug(f"Running Vitest benchmarking tests: {' '.join(vitest_cmd)}")
613+
logger.debug(f"[VITEST-BENCH] Running Vitest benchmarking tests: {' '.join(vitest_cmd)}")
581614
logger.debug(
582-
f"Vitest benchmarking config: min_loops={min_loops}, max_loops={max_loops}, "
615+
f"[VITEST-BENCH] Config: min_loops={min_loops}, max_loops={max_loops}, "
583616
f"target_duration={target_duration_ms}ms, stability_check={stability_check}"
584617
)
618+
logger.debug(f"[VITEST-BENCH] Environment: CODEFLASH_PERF_LOOP_COUNT={vitest_env.get('CODEFLASH_PERF_LOOP_COUNT')}")
585619

586620
total_start_time = time.time()
587621

@@ -606,7 +640,27 @@ def run_vitest_benchmarking_tests(
606640
result = subprocess.CompletedProcess(args=vitest_cmd, returncode=-1, stdout="", stderr="Vitest not found")
607641

608642
wall_clock_seconds = time.time() - total_start_time
609-
logger.debug(f"Vitest benchmarking completed in {wall_clock_seconds:.2f}s")
643+
logger.debug(f"[VITEST-BENCH] Completed in {wall_clock_seconds:.2f}s, returncode={result.returncode}")
644+
645+
# Debug: Check for END markers with duration (perf test format)
646+
if result.stdout:
647+
import re
648+
649+
perf_end_pattern = re.compile(r"!######[^:]+:[^:]+:[^:]+:(\d+):[^:]+:(\d+)######!")
650+
perf_matches = list(perf_end_pattern.finditer(result.stdout))
651+
if perf_matches:
652+
loop_indices = [int(m.group(1)) for m in perf_matches]
653+
logger.debug(
654+
f"[VITEST-BENCH] Found {len(perf_matches)} perf END markers in stdout, "
655+
f"loop_index range: {min(loop_indices)}-{max(loop_indices)}"
656+
)
657+
else:
658+
logger.debug(f"[VITEST-BENCH] No perf END markers found in stdout (len={len(result.stdout)})")
659+
# Check if there are behavior END markers instead
660+
behavior_end_pattern = re.compile(r"!######[^:]+:[^:]+:[^:]+:\d+:[^#]+######!")
661+
behavior_matches = list(behavior_end_pattern.finditer(result.stdout))
662+
if behavior_matches:
663+
logger.debug(f"[VITEST-BENCH] Found {len(behavior_matches)} behavior END markers instead (no duration)")
610664

611665
return result_file_path, result
612666

codeflash/optimization/function_optimizer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2363,6 +2363,12 @@ def establish_original_code_baseline(
23632363
)
23642364
console.rule()
23652365
with progress_bar("Running performance benchmarks..."):
2366+
logger.debug(
2367+
f"[BENCHMARK-START] Starting benchmarking tests with {len(self.test_files.test_files)} test files"
2368+
)
2369+
for idx, tf in enumerate(self.test_files.test_files):
2370+
logger.debug(f"[BENCHMARK-FILES] Test file {idx}: perf_file={tf.benchmarking_file_path}")
2371+
23662372
if self.function_to_optimize.is_async and is_python():
23672373
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
23682374

@@ -2380,6 +2386,7 @@ def establish_original_code_baseline(
23802386
enable_coverage=False,
23812387
code_context=code_context,
23822388
)
2389+
logger.debug(f"[BENCHMARK-DONE] Got {len(benchmarking_results.test_results)} benchmark results")
23832390
finally:
23842391
if self.function_to_optimize.is_async:
23852392
self.write_code_and_helpers(

codeflash/verification/test_runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ def run_benchmarking_tests(
325325
pytest_max_loops: int = 100_000,
326326
js_project_root: Path | None = None,
327327
) -> tuple[Path, subprocess.CompletedProcess]:
328+
logger.debug(f"run_benchmarking_tests called: framework={test_framework}, num_files={len(test_paths.test_files)}")
328329
# Check if there's a language support for this test framework that implements run_benchmarking_tests
329330
language_support = get_language_support_by_framework(test_framework)
330331
if language_support is not None and hasattr(language_support, "run_benchmarking_tests"):

0 commit comments

Comments
 (0)