Skip to content

Commit 051d1b6

Browse files
misrasaurabh1claude
andcommitted
feat: add inner loop and compile-once-run-many optimization for Java benchmarking
- Add inner loop in Java test instrumentation for JIT warmup within single JVM - Implement compile-once-run-many: compile tests once with Maven, then run directly via JUnit Console Launcher (~500ms vs ~5-10s per invocation) - Add fallback to Maven-based execution when direct execution fails - Update parsing to handle JUnit Console Launcher output format - Add inner_iterations parameter (default: 100) to control loop count - Add comprehensive E2E tests for inner loop benchmarking Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent c299d99 commit 051d1b6

6 files changed

Lines changed: 1123 additions & 267 deletions

File tree

codeflash/languages/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,7 @@ def run_benchmarking_tests(
653653
min_loops: int = 5,
654654
max_loops: int = 100_000,
655655
target_duration_seconds: float = 10.0,
656+
inner_iterations: int = 100,
656657
) -> tuple[Path, Any]:
657658
"""Run benchmarking tests for this language.
658659
@@ -665,6 +666,7 @@ def run_benchmarking_tests(
665666
min_loops: Minimum number of loops for benchmarking.
666667
max_loops: Maximum number of loops for benchmarking.
667668
target_duration_seconds: Target duration for benchmarking in seconds.
669+
inner_iterations: Number of inner loop iterations per test method (Java only).
668670
669671
Returns:
670672
Tuple of (result_file_path, subprocess_result).

codeflash/languages/java/instrumentation.py

Lines changed: 59 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import TYPE_CHECKING
2121

2222
from codeflash.languages.base import FunctionInfo
23-
from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer
23+
from codeflash.languages.java.parser import JavaAnalyzer
2424

2525
if TYPE_CHECKING:
2626
from collections.abc import Sequence
@@ -154,8 +154,8 @@ def instrument_existing_test(
154154

155155
# Rename the class declaration in the source
156156
# Pattern: "public class ClassName" or "class ClassName"
157-
pattern = rf'\b(public\s+)?class\s+{re.escape(original_class_name)}\b'
158-
replacement = rf'\1class {new_class_name}'
157+
pattern = rf"\b(public\s+)?class\s+{re.escape(original_class_name)}\b"
158+
replacement = rf"\1class {new_class_name}"
159159
modified_source = re.sub(pattern, replacement, source)
160160

161161
# Add timing instrumentation to test methods
@@ -214,7 +214,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
214214
]
215215

216216
# Find position to insert imports (after package, before class)
217-
lines = source.split('\n')
217+
lines = source.split("\n")
218218
result = []
219219
imports_added = False
220220
i = 0
@@ -225,11 +225,11 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
225225

226226
# Add imports after the last existing import or before the class declaration
227227
if not imports_added:
228-
if stripped.startswith('import '):
228+
if stripped.startswith("import "):
229229
result.append(line)
230230
i += 1
231231
# Find end of imports
232-
while i < len(lines) and lines[i].strip().startswith('import '):
232+
while i < len(lines) and lines[i].strip().startswith("import "):
233233
result.append(lines[i])
234234
i += 1
235235
# Add our imports
@@ -238,7 +238,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
238238
result.append(imp)
239239
imports_added = True
240240
continue
241-
elif stripped.startswith('public class') or stripped.startswith('class'):
241+
if stripped.startswith("public class") or stripped.startswith("class"):
242242
# No imports found, add before class
243243
for imp in import_statements:
244244
result.append(imp)
@@ -249,8 +249,8 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
249249
i += 1
250250

251251
# Now add timing and SQLite instrumentation to test methods
252-
source = '\n'.join(result)
253-
lines = source.split('\n')
252+
source = "\n".join(result)
253+
lines = source.split("\n")
254254
result = []
255255
i = 0
256256
iteration_counter = 0
@@ -260,20 +260,20 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
260260
stripped = line.strip()
261261

262262
# Look for @Test annotation
263-
if stripped.startswith('@Test'):
263+
if stripped.startswith("@Test"):
264264
result.append(line)
265265
i += 1
266266

267267
# Collect any additional annotations
268-
while i < len(lines) and lines[i].strip().startswith('@'):
268+
while i < len(lines) and lines[i].strip().startswith("@"):
269269
result.append(lines[i])
270270
i += 1
271271

272272
# Now find the method signature and opening brace
273273
method_lines = []
274274
while i < len(lines):
275275
method_lines.append(lines[i])
276-
if '{' in lines[i]:
276+
if "{" in lines[i]:
277277
break
278278
i += 1
279279

@@ -298,9 +298,9 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
298298
while i < len(lines) and brace_depth > 0:
299299
body_line = lines[i]
300300
for ch in body_line:
301-
if ch == '{':
301+
if ch == "{":
302302
brace_depth += 1
303-
elif ch == '}':
303+
elif ch == "}":
304304
brace_depth -= 1
305305

306306
if brace_depth > 0:
@@ -323,13 +323,13 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
323323
# - new ClassName(args)
324324
# - this
325325
method_call_pattern = re.compile(
326-
rf'((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)',
326+
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)",
327327
re.MULTILINE
328328
)
329329

330330
for body_line in body_lines:
331331
# Check if this line contains a call to the target function
332-
if func_name in body_line and '(' in body_line:
332+
if func_name in body_line and "(" in body_line:
333333
line_indent = len(body_line) - len(body_line.lstrip())
334334
line_indent_str = " " * line_indent
335335

@@ -360,7 +360,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
360360
# If we captured any calls, serialize the last one; otherwise serialize null
361361
if call_counter > 0:
362362
result_var = f"_cf_result{iter_id}_{call_counter}"
363-
serialize_expr = f'new GsonBuilder().serializeNulls().create().toJson({result_var})'
363+
serialize_expr = f"new GsonBuilder().serializeNulls().create().toJson({result_var})"
364364
else:
365365
serialize_expr = '"null"'
366366

@@ -399,8 +399,8 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
399399
f"{indent} // Write to SQLite if output file is set",
400400
f"{indent} if (_cf_outputFile{iter_id} != null && !_cf_outputFile{iter_id}.isEmpty()) {{",
401401
f"{indent} try {{",
402-
f"{indent} Class.forName(\"org.sqlite.JDBC\");",
403-
f"{indent} try (Connection _cf_conn{iter_id} = DriverManager.getConnection(\"jdbc:sqlite:\" + _cf_outputFile{iter_id})) {{",
402+
f'{indent} Class.forName("org.sqlite.JDBC");',
403+
f'{indent} try (Connection _cf_conn{iter_id} = DriverManager.getConnection("jdbc:sqlite:" + _cf_outputFile{iter_id})) {{',
404404
f"{indent} try (Statement _cf_stmt{iter_id} = _cf_conn{iter_id}.createStatement()) {{",
405405
f'{indent} _cf_stmt{iter_id}.execute("CREATE TABLE IF NOT EXISTS test_results (" +',
406406
f'{indent} "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " +',
@@ -433,20 +433,26 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
433433
result.append(line)
434434
i += 1
435435

436-
return '\n'.join(result)
436+
return "\n".join(result)
437437

438438

439439
def _add_timing_instrumentation(source: str, class_name: str, func_name: str) -> str:
440-
"""Add timing instrumentation to test methods.
440+
"""Add timing instrumentation to test methods with inner loop for JIT warmup.
441441
442442
For each @Test method, this adds:
443-
1. Start timing marker printed at the beginning
444-
2. End timing marker printed at the end (in a finally block)
443+
1. Inner loop that runs N iterations (controlled by CODEFLASH_INNER_ITERATIONS env var)
444+
2. Start timing marker printed at the beginning of each iteration
445+
3. End timing marker printed at the end of each iteration (in a finally block)
446+
447+
The inner loop allows JIT warmup within a single JVM invocation, avoiding
448+
expensive Maven restarts. Post-processing uses min runtime across all iterations.
445449
446450
Timing markers format:
447451
Start: !$######testModule:testClass:funcName:loopIndex:iterationId######$!
448452
End: !######testModule:testClass:funcName:loopIndex:iterationId:durationNs######!
449453
454+
Where iterationId is the inner iteration number (0, 1, 2, ..., N-1).
455+
450456
Args:
451457
source: The test source code.
452458
class_name: Name of the test class.
@@ -460,7 +466,7 @@ def _add_timing_instrumentation(source: str, class_name: str, func_name: str) ->
460466
# Pattern matches: @Test (with optional parameters) followed by method declaration
461467
# We process line by line for cleaner handling
462468

463-
lines = source.split('\n')
469+
lines = source.split("\n")
464470
result = []
465471
i = 0
466472
iteration_counter = 0
@@ -470,20 +476,20 @@ def _add_timing_instrumentation(source: str, class_name: str, func_name: str) ->
470476
stripped = line.strip()
471477

472478
# Look for @Test annotation
473-
if stripped.startswith('@Test'):
479+
if stripped.startswith("@Test"):
474480
result.append(line)
475481
i += 1
476482

477483
# Collect any additional annotations
478-
while i < len(lines) and lines[i].strip().startswith('@'):
484+
while i < len(lines) and lines[i].strip().startswith("@"):
479485
result.append(lines[i])
480486
i += 1
481487

482488
# Now find the method signature and opening brace
483489
method_lines = []
484490
while i < len(lines):
485491
method_lines.append(lines[i])
486-
if '{' in lines[i]:
492+
if "{" in lines[i]:
487493
break
488494
i += 1
489495

@@ -500,21 +506,24 @@ def _add_timing_instrumentation(source: str, class_name: str, func_name: str) ->
500506
method_sig_line = method_lines[-1] if method_lines else ""
501507
base_indent = len(method_sig_line) - len(method_sig_line.lstrip())
502508
indent = " " * (base_indent + 4) # Add one level of indentation
509+
inner_indent = " " * (base_indent + 8) # Two levels for inside inner loop
510+
inner_body_indent = " " * (base_indent + 12) # Three levels for try block body
503511

504-
# Add timing start code
512+
# Add timing instrumentation with inner loop
505513
# Note: CODEFLASH_LOOP_INDEX must always be set - no null check, crash if missing
506-
# Start marker is printed BEFORE timing starts
507-
# System.nanoTime() immediately precedes try block with test code
514+
# CODEFLASH_INNER_ITERATIONS controls inner loop count (default: 100)
508515
timing_start_code = [
509-
f"{indent}// Codeflash timing instrumentation",
516+
f"{indent}// Codeflash timing instrumentation with inner loop for JIT warmup",
510517
f'{indent}int _cf_loop{iter_id} = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));',
511-
f"{indent}int _cf_iter{iter_id} = {iter_id};",
518+
f'{indent}int _cf_innerIterations{iter_id} = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));',
512519
f'{indent}String _cf_mod{iter_id} = "{class_name}";',
513520
f'{indent}String _cf_cls{iter_id} = "{class_name}";',
514521
f'{indent}String _cf_fn{iter_id} = "{func_name}";',
515-
f'{indent}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_iter{iter_id} + "######$!");',
516-
f"{indent}long _cf_start{iter_id} = System.nanoTime();",
517-
f"{indent}try {{",
522+
"",
523+
f"{indent}for (int _cf_i{iter_id} = 0; _cf_i{iter_id} < _cf_innerIterations{iter_id}; _cf_i{iter_id}++) {{",
524+
f'{inner_indent}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_i{iter_id} + "######$!");',
525+
f"{inner_indent}long _cf_start{iter_id} = System.nanoTime();",
526+
f"{inner_indent}try {{",
518527
]
519528
result.extend(timing_start_code)
520529

@@ -526,28 +535,29 @@ def _add_timing_instrumentation(source: str, class_name: str, func_name: str) ->
526535
body_line = lines[i]
527536
# Count braces (simple approach - doesn't handle strings/comments perfectly)
528537
for ch in body_line:
529-
if ch == '{':
538+
if ch == "{":
530539
brace_depth += 1
531-
elif ch == '}':
540+
elif ch == "}":
532541
brace_depth -= 1
533542

534543
if brace_depth > 0:
535544
body_lines.append(body_line)
536545
i += 1
537546
else:
538547
# This line contains the closing brace, but we've hit depth 0
539-
# Add indented body lines
548+
# Add indented body lines (inside try block, inside for loop)
540549
for bl in body_lines:
541-
result.append(" " + bl)
550+
result.append(" " + bl) # 8 extra spaces for inner loop + try
542551

543-
# Add finally block
552+
# Add finally block and close inner loop
544553
method_close_indent = " " * base_indent # Same level as method signature
545554
timing_end_code = [
546-
f"{indent}}} finally {{",
547-
f"{indent} long _cf_end{iter_id} = System.nanoTime();",
548-
f"{indent} long _cf_dur{iter_id} = _cf_end{iter_id} - _cf_start{iter_id};",
549-
f'{indent} System.out.println("!######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_iter{iter_id} + ":" + _cf_dur{iter_id} + "######!");',
550-
f"{indent}}}",
555+
f"{inner_indent}}} finally {{",
556+
f"{inner_indent} long _cf_end{iter_id} = System.nanoTime();",
557+
f"{inner_indent} long _cf_dur{iter_id} = _cf_end{iter_id} - _cf_start{iter_id};",
558+
f'{inner_indent} System.out.println("!######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_i{iter_id} + ":" + _cf_dur{iter_id} + "######!");',
559+
f"{inner_indent}}}",
560+
f"{indent}}}", # Close for loop
551561
f"{method_close_indent}}}", # Method closing brace
552562
]
553563
result.extend(timing_end_code)
@@ -556,7 +566,7 @@ def _add_timing_instrumentation(source: str, class_name: str, func_name: str) ->
556566
result.append(line)
557567
i += 1
558568

559-
return '\n'.join(result)
569+
return "\n".join(result)
560570

561571

562572
def create_benchmark_test(
@@ -653,7 +663,7 @@ def instrument_generated_java_test(
653663
"""
654664
# Extract class name from the test code
655665
# Use pattern that starts at beginning of line to avoid matching words in comments
656-
class_match = re.search(r'^(?:public\s+)?class\s+(\w+)', test_code, re.MULTILINE)
666+
class_match = re.search(r"^(?:public\s+)?class\s+(\w+)", test_code, re.MULTILINE)
657667
if not class_match:
658668
logger.warning("Could not find class name in generated test")
659669
return test_code
@@ -668,8 +678,8 @@ def instrument_generated_java_test(
668678

669679
# Rename the class in the source
670680
modified_code = re.sub(
671-
rf'\b(public\s+)?class\s+{re.escape(original_class_name)}\b',
672-
rf'\1class {new_class_name}',
681+
rf"\b(public\s+)?class\s+{re.escape(original_class_name)}\b",
682+
rf"\1class {new_class_name}",
673683
test_code,
674684
)
675685

codeflash/languages/java/support.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -356,11 +356,12 @@ def run_benchmarking_tests(
356356
cwd: Path,
357357
timeout: int | None = None,
358358
project_root: Path | None = None,
359-
min_loops: int = 5,
360-
max_loops: int = 100_000,
359+
min_loops: int = 1,
360+
max_loops: int = 3,
361361
target_duration_seconds: float = 10.0,
362+
inner_iterations: int = 100,
362363
) -> tuple[Path, Any]:
363-
"""Run benchmarking tests for Java."""
364+
"""Run benchmarking tests for Java with inner loop for JIT warmup."""
364365
return run_benchmarking_tests(
365366
test_paths,
366367
test_env,
@@ -370,6 +371,7 @@ def run_benchmarking_tests(
370371
min_loops,
371372
max_loops,
372373
target_duration_seconds,
374+
inner_iterations,
373375
)
374376

375377

0 commit comments

Comments
 (0)