Skip to content

Commit 15a9261

Browse files
Merge pull request #1880 from codeflash-ai/java-config-redesign
feat: zero-config Java projects + smart ReplayHelper for end-to-end optimization
2 parents c077997 + 823300f commit 15a9261

19 files changed

Lines changed: 990 additions & 119 deletions

File tree

.github/workflows/e2e-java-tracer.yaml

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,9 @@ name: E2E - Java Tracer
33
on:
44
pull_request:
55
paths:
6-
- 'codeflash/languages/java/**'
7-
- 'codeflash/languages/base.py'
8-
- 'codeflash/languages/registry.py'
9-
- 'codeflash/tracer.py'
10-
- 'codeflash/benchmarking/function_ranker.py'
11-
- 'codeflash/discovery/functions_to_optimize.py'
12-
- 'codeflash/optimization/**'
13-
- 'codeflash/verification/**'
6+
- 'codeflash/**'
147
- 'codeflash-java-runtime/**'
15-
- 'tests/test_languages/fixtures/java_tracer_e2e/**'
16-
- 'tests/scripts/end_to_end_test_java_tracer.py'
8+
- 'tests/**'
179
- '.github/workflows/e2e-java-tracer.yaml'
1810

1911
workflow_dispatch:

codeflash-java-runtime/src/main/java/com/codeflash/ReplayHelper.java

Lines changed: 200 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,181 @@
1212

1313
public class ReplayHelper {
1414

15-
private final Connection db;
15+
private final Connection traceDb;
16+
17+
// Codeflash instrumentation state — read from environment variables once
18+
private final String mode; // "behavior", "performance", or null
19+
private final int loopIndex;
20+
private final String testIteration;
21+
private final String outputFile; // SQLite path for behavior capture
22+
private final int innerIterations; // for performance looping
23+
24+
// Behavior mode: lazily opened SQLite connection for writing results
25+
private Connection behaviorDb;
26+
private boolean behaviorDbInitialized;
1627

1728
public ReplayHelper(String traceDbPath) {
1829
try {
19-
this.db = DriverManager.getConnection("jdbc:sqlite:" + traceDbPath);
30+
this.traceDb = DriverManager.getConnection("jdbc:sqlite:" + traceDbPath);
2031
} catch (SQLException e) {
2132
throw new RuntimeException("Failed to open trace database: " + traceDbPath, e);
2233
}
34+
35+
// Read codeflash instrumentation env vars (set by the test runner)
36+
this.mode = System.getenv("CODEFLASH_MODE");
37+
this.loopIndex = parseIntEnv("CODEFLASH_LOOP_INDEX", 1);
38+
this.testIteration = getEnvOrDefault("CODEFLASH_TEST_ITERATION", "0");
39+
this.outputFile = System.getenv("CODEFLASH_OUTPUT_FILE");
40+
this.innerIterations = parseIntEnv("CODEFLASH_INNER_ITERATIONS", 10);
2341
}
2442

2543
public void replay(String className, String methodName, String descriptor, int invocationIndex) throws Exception {
26-
// Query the function_calls table for this method at the given index
44+
// Deserialize args and resolve method (done once, outside timing)
45+
Object[] allArgs = loadArgs(className, methodName, descriptor, invocationIndex);
46+
Class<?> targetClass = Class.forName(className);
47+
48+
Type[] paramTypes = Type.getArgumentTypes(descriptor);
49+
Class<?>[] paramClasses = new Class<?>[paramTypes.length];
50+
for (int i = 0; i < paramTypes.length; i++) {
51+
paramClasses[i] = typeToClass(paramTypes[i]);
52+
}
53+
54+
Method method = targetClass.getDeclaredMethod(methodName, paramClasses);
55+
method.setAccessible(true);
56+
boolean isStatic = Modifier.isStatic(method.getModifiers());
57+
58+
Object instance = null;
59+
if (!isStatic) {
60+
try {
61+
java.lang.reflect.Constructor<?> ctor = targetClass.getDeclaredConstructor();
62+
ctor.setAccessible(true);
63+
instance = ctor.newInstance();
64+
} catch (NoSuchMethodException e) {
65+
instance = new org.objenesis.ObjenesisStd().newInstance(targetClass);
66+
}
67+
}
68+
69+
// Get the calling test method name from the stack trace
70+
String testMethodName = getCallingTestMethodName();
71+
// Module name = the test class that called us
72+
String testClassName = getCallingTestClassName();
73+
74+
if ("behavior".equals(mode)) {
75+
replayBehavior(method, instance, allArgs, className, methodName, testClassName, testMethodName);
76+
} else if ("performance".equals(mode)) {
77+
replayPerformance(method, instance, allArgs, className, methodName, testClassName, testMethodName);
78+
} else {
79+
// No codeflash mode — just invoke (trace-only or manual testing)
80+
method.invoke(instance, allArgs);
81+
}
82+
}
83+
84+
private void replayBehavior(Method method, Object instance, Object[] args,
85+
String className, String methodName,
86+
String testClassName, String testMethodName) throws Exception {
87+
// testIteration goes at the END so the Comparator's lastUnderscore stripping
88+
// removes it, making baseline (iteration=0) and candidate (iteration=N) keys match.
89+
String invId = testMethodName + "_" + testIteration;
90+
91+
// Print start marker (same format as behavior instrumentation)
92+
System.out.println("!$######" + testClassName + ":" + testClassName + "." + testMethodName
93+
+ ":" + methodName + ":" + loopIndex + ":" + invId + "######$!");
94+
95+
long startNs = System.nanoTime();
96+
Object result;
97+
try {
98+
result = method.invoke(instance, args);
99+
} catch (java.lang.reflect.InvocationTargetException e) {
100+
throw (Exception) e.getCause();
101+
}
102+
long durationNs = System.nanoTime() - startNs;
103+
104+
// Print end marker
105+
System.out.println("!######" + testClassName + ":" + testClassName + "." + testMethodName
106+
+ ":" + methodName + ":" + loopIndex + ":" + invId + ":" + durationNs + "######!");
107+
108+
// Write return value to SQLite for correctness comparison
109+
if (outputFile != null && !outputFile.isEmpty()) {
110+
writeBehaviorResult(testClassName, testMethodName, methodName, invId, durationNs, result);
111+
}
112+
}
113+
114+
private void replayPerformance(Method method, Object instance, Object[] args,
115+
String className, String methodName,
116+
String testClassName, String testMethodName) throws Exception {
117+
// Performance mode: run inner loop for JIT warmup, print timing for each iteration
118+
int maxInner = innerIterations;
119+
for (int inner = 0; inner < maxInner; inner++) {
120+
int loopId = (loopIndex - 1) * maxInner + inner;
121+
String invId = testMethodName;
122+
123+
// Print start marker
124+
System.out.println("!$######" + testClassName + ":" + testClassName + "." + testMethodName
125+
+ ":" + methodName + ":" + loopId + ":" + invId + "######$!");
126+
127+
long startNs = System.nanoTime();
128+
try {
129+
method.invoke(instance, args);
130+
} catch (java.lang.reflect.InvocationTargetException e) {
131+
// Swallow — performance mode doesn't check correctness
132+
}
133+
long durationNs = System.nanoTime() - startNs;
134+
135+
// Print end marker
136+
System.out.println("!######" + testClassName + ":" + testClassName + "." + testMethodName
137+
+ ":" + methodName + ":" + loopId + ":" + invId + ":" + durationNs + "######!");
138+
}
139+
}
140+
141+
private void writeBehaviorResult(String testClassName, String testMethodName,
142+
String functionName, String invId,
143+
long durationNs, Object result) {
144+
try {
145+
ensureBehaviorDb();
146+
String sql = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";
147+
try (PreparedStatement ps = behaviorDb.prepareStatement(sql)) {
148+
ps.setString(1, testClassName); // test_module_path
149+
ps.setString(2, testClassName); // test_class_name
150+
ps.setString(3, testMethodName); // test_function_name
151+
ps.setString(4, functionName); // function_getting_tested
152+
ps.setInt(5, loopIndex); // loop_index
153+
ps.setString(6, invId); // iteration_id
154+
ps.setLong(7, durationNs); // runtime
155+
ps.setBytes(8, serializeResult(result)); // return_value
156+
ps.setString(9, "function_call"); // verification_type
157+
ps.executeUpdate();
158+
}
159+
} catch (Exception e) {
160+
System.err.println("ReplayHelper: SQLite behavior write error: " + e.getMessage());
161+
}
162+
}
163+
164+
private void ensureBehaviorDb() throws SQLException {
165+
if (behaviorDbInitialized) return;
166+
behaviorDbInitialized = true;
167+
behaviorDb = DriverManager.getConnection("jdbc:sqlite:" + outputFile);
168+
try (java.sql.Statement stmt = behaviorDb.createStatement()) {
169+
stmt.execute("CREATE TABLE IF NOT EXISTS test_results (" +
170+
"test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " +
171+
"function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " +
172+
"runtime INTEGER, return_value BLOB, verification_type TEXT)");
173+
}
174+
}
175+
176+
private byte[] serializeResult(Object result) {
177+
if (result == null) return null;
178+
try {
179+
return Serializer.serialize(result);
180+
} catch (Exception e) {
181+
// Fall back to String.valueOf if Kryo fails
182+
return String.valueOf(result).getBytes(java.nio.charset.StandardCharsets.UTF_8);
183+
}
184+
}
185+
186+
private Object[] loadArgs(String className, String methodName, String descriptor, int invocationIndex)
187+
throws SQLException {
27188
byte[] argsBlob;
28-
try (PreparedStatement stmt = db.prepareStatement(
189+
try (PreparedStatement stmt = traceDb.prepareStatement(
29190
"SELECT args FROM function_calls " +
30191
"WHERE classname = ? AND function = ? AND descriptor = ? " +
31192
"ORDER BY time_ns LIMIT 1 OFFSET ?")) {
@@ -43,46 +204,35 @@ public void replay(String className, String methodName, String descriptor, int i
43204
}
44205
}
45206

46-
// Deserialize args
47207
Object deserialized = Serializer.deserialize(argsBlob);
48208
if (!(deserialized instanceof Object[])) {
49209
throw new RuntimeException("Deserialized args is not Object[], got: "
50210
+ (deserialized == null ? "null" : deserialized.getClass().getName()));
51211
}
52-
Object[] allArgs = (Object[]) deserialized;
53-
54-
// Load the target class
55-
Class<?> targetClass = Class.forName(className);
212+
return (Object[]) deserialized;
213+
}
56214

57-
// Parse descriptor to find parameter types
58-
Type[] paramTypes = Type.getArgumentTypes(descriptor);
59-
Class<?>[] paramClasses = new Class<?>[paramTypes.length];
60-
for (int i = 0; i < paramTypes.length; i++) {
61-
paramClasses[i] = typeToClass(paramTypes[i]);
215+
private static String getCallingTestMethodName() {
216+
StackTraceElement[] stack = Thread.currentThread().getStackTrace();
217+
// Walk up: [0]=getStackTrace, [1]=this method, [2]=replay(), [3]=calling test method
218+
for (int i = 3; i < stack.length; i++) {
219+
String method = stack[i].getMethodName();
220+
if (method.startsWith("replay_")) {
221+
return method;
222+
}
62223
}
224+
return stack.length > 3 ? stack[3].getMethodName() : "unknown";
225+
}
63226

64-
// Find the method
65-
Method method = targetClass.getDeclaredMethod(methodName, paramClasses);
66-
method.setAccessible(true);
67-
68-
boolean isStatic = Modifier.isStatic(method.getModifiers());
69-
70-
if (isStatic) {
71-
method.invoke(null, allArgs);
72-
} else {
73-
// Args contain only explicit parameters (no 'this').
74-
// Create a default instance via no-arg constructor or Kryo.
75-
Object instance;
76-
try {
77-
java.lang.reflect.Constructor<?> ctor = targetClass.getDeclaredConstructor();
78-
ctor.setAccessible(true);
79-
instance = ctor.newInstance();
80-
} catch (NoSuchMethodException e) {
81-
// Fall back to Objenesis instantiation (no constructor needed)
82-
instance = new org.objenesis.ObjenesisStd().newInstance(targetClass);
227+
private static String getCallingTestClassName() {
228+
StackTraceElement[] stack = Thread.currentThread().getStackTrace();
229+
for (int i = 3; i < stack.length; i++) {
230+
String cls = stack[i].getClassName();
231+
if (cls.contains("ReplayTest") || cls.contains("replay")) {
232+
return cls;
83233
}
84-
method.invoke(instance, allArgs);
85234
}
235+
return stack.length > 3 ? stack[3].getClassName() : "unknown";
86236
}
87237

88238
private static Class<?> typeToClass(Type type) throws ClassNotFoundException {
@@ -106,11 +256,23 @@ private static Class<?> typeToClass(Type type) throws ClassNotFoundException {
106256
}
107257
}
108258

259+
private static int parseIntEnv(String name, int defaultValue) {
260+
String val = System.getenv(name);
261+
if (val == null || val.isEmpty()) return defaultValue;
262+
try { return Integer.parseInt(val); } catch (NumberFormatException e) { return defaultValue; }
263+
}
264+
265+
private static String getEnvOrDefault(String name, String defaultValue) {
266+
String val = System.getenv(name);
267+
return (val != null && !val.isEmpty()) ? val : defaultValue;
268+
}
269+
109270
public void close() {
110-
try {
111-
if (db != null) db.close();
112-
} catch (SQLException e) {
113-
System.err.println("Error closing ReplayHelper: " + e.getMessage());
271+
try { if (traceDb != null) traceDb.close(); } catch (SQLException e) {
272+
System.err.println("Error closing ReplayHelper trace db: " + e.getMessage());
273+
}
274+
try { if (behaviorDb != null) behaviorDb.close(); } catch (SQLException e) {
275+
System.err.println("Error closing ReplayHelper behavior db: " + e.getMessage());
114276
}
115277
}
116278
}

codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ public final class TraceRecorder {
2222
private final TracerConfig config;
2323
private final TraceWriter writer;
2424
private final ConcurrentHashMap<String, AtomicInteger> functionCounts = new ConcurrentHashMap<>();
25+
private final AtomicInteger droppedCaptures = new AtomicInteger(0);
2526
private final int maxFunctionCount;
2627
private final ExecutorService serializerExecutor;
2728

@@ -82,11 +83,13 @@ private void onEntryImpl(String className, String methodName, String descriptor,
8283
argsBlob = future.get(SERIALIZATION_TIMEOUT_MS, TimeUnit.MILLISECONDS);
8384
} catch (TimeoutException e) {
8485
future.cancel(true);
86+
droppedCaptures.incrementAndGet();
8587
System.err.println("[codeflash-tracer] Serialization timed out for " + className + "."
8688
+ methodName);
8789
return;
8890
} catch (Exception e) {
8991
Throwable cause = e.getCause() != null ? e.getCause() : e;
92+
droppedCaptures.incrementAndGet();
9093
System.err.println("[codeflash-tracer] Serialization failed for " + className + "."
9194
+ methodName + ": " + cause.getClass().getSimpleName() + ": " + cause.getMessage());
9295
return;
@@ -113,11 +116,15 @@ public void flush() {
113116
}
114117
metadata.put("totalCaptures", String.valueOf(totalCaptures));
115118

119+
int dropped = droppedCaptures.get();
120+
metadata.put("droppedCaptures", String.valueOf(dropped));
121+
116122
writer.writeMetadata(metadata);
117123
writer.flush();
118124
writer.close();
119125

120126
System.err.println("[codeflash-tracer] Captured " + totalCaptures
121-
+ " invocations across " + functionCounts.size() + " methods");
127+
+ " invocations across " + functionCounts.size() + " methods"
128+
+ (dropped > 0 ? " (" + dropped + " dropped due to serialization timeout/failure)" : ""));
122129
}
123130
}

codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingClassVisitor.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,20 @@
44
import org.objectweb.asm.MethodVisitor;
55
import org.objectweb.asm.Opcodes;
66

7+
import java.util.Collections;
8+
import java.util.Map;
9+
710
public class TracingClassVisitor extends ClassVisitor {
811

912
private final String internalClassName;
13+
private final Map<String, Integer> methodLineNumbers;
1014
private String sourceFile;
1115

12-
public TracingClassVisitor(ClassVisitor classVisitor, String internalClassName) {
16+
public TracingClassVisitor(ClassVisitor classVisitor, String internalClassName,
17+
Map<String, Integer> methodLineNumbers) {
1318
super(Opcodes.ASM9, classVisitor);
1419
this.internalClassName = internalClassName;
20+
this.methodLineNumbers = methodLineNumbers != null ? methodLineNumbers : Collections.emptyMap();
1521
}
1622

1723
@Override
@@ -37,7 +43,8 @@ public MethodVisitor visitMethod(int access, String name, String descriptor,
3743
return mv;
3844
}
3945

46+
int lineNumber = methodLineNumbers.getOrDefault(name + descriptor, 0);
4047
return new TracingMethodAdapter(mv, access, name, descriptor,
41-
internalClassName, 0, sourceFile != null ? sourceFile : "");
48+
internalClassName, lineNumber, sourceFile != null ? sourceFile : "");
4249
}
4350
}

0 commit comments

Comments
 (0)