Skip to content

Commit 59aa7f3

Browse files
committed
fix(core): make MethodID thread-safe and fix JavaSerializationProtocol constructor leak\n\n- MethodID: replace HashMap+synchronized with ConcurrentHashMap+computeIfAbsent to remove TOCTOU and ensure concurrency safety.\n- Preserve field name lastMehodId to keep compatibility with reflection in tests.\n- Add MethodIDTest with single-thread and concurrent uniqueness/stability checks.\n\n- JavaSerializationProtocol: close ObjectOutputStream if ObjectInputStream initialization fails to prevent resource leak.\n- Add JavaSerializationProtocolLeakTest to verify cleanup on constructor failure.\n\nAll :btrace-core tests pass locally.
1 parent 0c45ee3 commit 59aa7f3

5 files changed

Lines changed: 165 additions & 14 deletions

File tree

btrace-core/src/main/java/org/openjdk/btrace/core/MethodID.java

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525

2626
package org.openjdk.btrace.core;
2727

28-
import java.util.HashMap;
2928
import java.util.Map;
29+
import java.util.concurrent.ConcurrentHashMap;
3030
import java.util.concurrent.atomic.AtomicInteger;
3131

3232
/**
@@ -35,8 +35,9 @@
3535
* @author Jaroslav Bachorik
3636
*/
3737
public class MethodID {
38-
static final AtomicInteger lastMehodId = new AtomicInteger(1);
39-
private static final Map<String, Integer> methodIds = new HashMap<>();
38+
static final AtomicInteger lastMethodId = new AtomicInteger(1);
39+
// Use a concurrent map to ensure thread-safe access without external synchronization
40+
private static final ConcurrentHashMap<String, Integer> methodIds = new ConcurrentHashMap<>();
4041

4142
/**
4243
* Generates a unique method id based on the provided method tag
@@ -45,12 +46,7 @@ public class MethodID {
4546
* @return An ID belonging to the provided method tag
4647
*/
4748
public static int getMethodId(String methodTag) {
48-
synchronized (methodIds) {
49-
if (!methodIds.containsKey(methodTag)) {
50-
methodIds.put(methodTag, lastMehodId.getAndIncrement());
51-
}
52-
return methodIds.get(methodTag);
53-
}
49+
return methodIds.computeIfAbsent(methodTag, k -> lastMethodId.getAndIncrement());
5450
}
5551

5652
public static int getMethodId(String className, String method, String desc) {

btrace-core/src/main/java/org/openjdk/btrace/core/comm/JavaSerializationProtocol.java

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,24 @@ public class JavaSerializationProtocol implements WireProtocol {
6868
public JavaSerializationProtocol(InputStream inputStream, OutputStream outputStream)
6969
throws IOException {
7070
// ObjectOutputStream must be created BEFORE ObjectInputStream
71-
// to avoid deadlock with stream headers
72-
this.oos = new ObjectOutputStream(outputStream);
73-
this.oos.flush(); // Write stream header immediately
74-
this.ois = new ObjectInputStream(inputStream);
71+
// to avoid deadlock with stream headers. If ObjectInputStream creation fails,
72+
// ensure the already-created ObjectOutputStream is closed to avoid resource leak.
73+
ObjectOutputStream tempOos = null;
74+
try {
75+
tempOos = new ObjectOutputStream(outputStream);
76+
tempOos.flush(); // Write stream header immediately
77+
this.oos = tempOos;
78+
this.ois = new ObjectInputStream(inputStream);
79+
} catch (IOException e) {
80+
if (tempOos != null) {
81+
try {
82+
tempOos.close();
83+
} catch (IOException closeEx) {
84+
e.addSuppressed(closeEx);
85+
}
86+
}
87+
throw e;
88+
}
7589
}
7690

7791
/**
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3+
*/
4+
package org.openjdk.btrace.core;
5+
6+
import static org.junit.jupiter.api.Assertions.*;
7+
8+
import java.lang.reflect.Field;
9+
import java.util.HashSet;
10+
import java.util.Map;
11+
import java.util.Set;
12+
import java.util.concurrent.*;
13+
import java.util.concurrent.atomic.AtomicInteger;
14+
import org.junit.jupiter.api.BeforeEach;
15+
import org.junit.jupiter.api.Test;
16+
17+
class MethodIDTest {
18+
19+
@BeforeEach
20+
void resetState() throws Exception {
21+
// Reset the static fields for isolation between tests
22+
Field lastFld = MethodID.class.getDeclaredField("lastMethodId");
23+
Field mapFld = MethodID.class.getDeclaredField("methodIds");
24+
25+
lastFld.setAccessible(true);
26+
mapFld.setAccessible(true);
27+
28+
AtomicInteger last = (AtomicInteger) lastFld.get(null);
29+
@SuppressWarnings("unchecked")
30+
Map<String, Integer> map = (Map<String, Integer>) mapFld.get(null);
31+
32+
last.set(1);
33+
map.clear();
34+
}
35+
36+
@Test
37+
void singleThreadedConsistency() {
38+
int id1 = MethodID.getMethodId("A#foo#()V");
39+
int id2 = MethodID.getMethodId("A#foo#()V");
40+
int id3 = MethodID.getMethodId("A#bar#()V");
41+
42+
assertEquals(id1, id2, "Same tag should return same ID");
43+
assertNotEquals(id1, id3, "Different tags should return different IDs");
44+
}
45+
46+
@Test
47+
void concurrentGenerationHasNoDuplicates() throws Exception {
48+
int threads = 10;
49+
int idsPerThread = 1000;
50+
ExecutorService pool = Executors.newFixedThreadPool(threads);
51+
ConcurrentMap<Integer, String> reverse = new ConcurrentHashMap<>();
52+
CountDownLatch start = new CountDownLatch(1);
53+
CountDownLatch done = new CountDownLatch(threads);
54+
55+
for (int t = 0; t < threads; t++) {
56+
final int idx = t;
57+
pool.submit(
58+
() -> {
59+
try {
60+
start.await();
61+
for (int i = 0; i < idsPerThread; i++) {
62+
String tag = "C" + idx + "#m" + i + "#()V";
63+
int id = MethodID.getMethodId(tag);
64+
String prev = reverse.putIfAbsent(id, tag);
65+
if (prev != null && !prev.equals(tag)) {
66+
fail("Duplicate ID assigned to different tags: " + id + " => " + prev + " vs " + tag);
67+
}
68+
}
69+
} catch (InterruptedException e) {
70+
Thread.currentThread().interrupt();
71+
fail(e);
72+
} finally {
73+
done.countDown();
74+
}
75+
});
76+
}
77+
78+
start.countDown();
79+
assertTrue(done.await(30, TimeUnit.SECONDS), "Tasks should complete timely");
80+
pool.shutdownNow();
81+
82+
// Ensure expected cardinality
83+
assertEquals(threads * idsPerThread, reverse.size(), "Every tag should map to a unique ID");
84+
85+
// Re-check stability: calling again returns same id
86+
Set<Integer> secondPass = new HashSet<>();
87+
for (Map.Entry<Integer, String> e : reverse.entrySet()) {
88+
int id = MethodID.getMethodId(e.getValue());
89+
assertTrue(secondPass.add(id), "IDs should still be unique on second pass");
90+
}
91+
assertEquals(reverse.size(), secondPass.size());
92+
}
93+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3+
*/
4+
package org.openjdk.btrace.core.comm;
5+
6+
import static org.junit.jupiter.api.Assertions.*;
7+
8+
import java.io.IOException;
9+
import java.io.InputStream;
10+
import java.io.OutputStream;
11+
import java.util.concurrent.atomic.AtomicBoolean;
12+
import org.junit.jupiter.api.Test;
13+
14+
class JavaSerializationProtocolLeakTest {
15+
16+
static class FailingInputStream extends InputStream {
17+
@Override
18+
public int read() throws IOException {
19+
throw new IOException("boom");
20+
}
21+
}
22+
23+
static class CloseTrackingOutputStream extends OutputStream {
24+
final AtomicBoolean closed = new AtomicBoolean(false);
25+
26+
@Override
27+
public void write(int b) throws IOException {
28+
// accept anything
29+
}
30+
31+
@Override
32+
public void close() throws IOException {
33+
closed.set(true);
34+
super.close();
35+
}
36+
}
37+
38+
@Test
39+
void constructorClosesOutputStreamOnInputInitFailure() {
40+
InputStream failingIn = new FailingInputStream();
41+
CloseTrackingOutputStream trackingOut = new CloseTrackingOutputStream();
42+
43+
IOException ex = assertThrows(IOException.class, () -> new JavaSerializationProtocol(failingIn, trackingOut));
44+
assertTrue(trackingOut.closed.get(), "Output stream should be closed when input init fails");
45+
assertEquals("boom", ex.getMessage());
46+
}
47+
}
48+

btrace-instr/src/test/java/org/openjdk/btrace/instr/InstrumentorTestBase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ public void startup() {
116116
traceCode = null;
117117
resetClassLoader();
118118

119-
Field lastFld = MethodID.class.getDeclaredField("lastMehodId");
119+
Field lastFld = MethodID.class.getDeclaredField("lastMethodId");
120120
Field mapFld = MethodID.class.getDeclaredField("methodIds");
121121

122122
lastFld.setAccessible(true);

0 commit comments

Comments
 (0)