Skip to content

Commit 35fab76

Browse files
committed
Add uncaught exception handler.
1 parent fd27f87 commit 35fab76

3 files changed

Lines changed: 235 additions & 65 deletions

File tree

randomizedtesting-jupiter/src/main/java/com/carrotsearch/randomizedtesting/jupiter/DetectThreadLeaksExtension.java

Lines changed: 112 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
package com.carrotsearch.randomizedtesting.jupiter;
22

3-
import java.util.Arrays;
3+
import java.util.ArrayList;
44
import java.util.HashSet;
55
import java.util.LinkedHashSet;
6+
import java.util.List;
67
import java.util.Map;
78
import java.util.concurrent.TimeUnit;
89
import java.util.function.Predicate;
@@ -24,6 +25,7 @@ public class DetectThreadLeaksExtension
2425
ExtensionContext.Namespace.create(DetectThreadLeaksExtension.class);
2526
private static final String SNAPSHOT_KEY = "snapshot";
2627
private static final String CONCURRENT_KEY = "concurrent";
28+
private static final String UNCAUGHT_EXCEPTION_HANDLER_KEY = "uncaught-exception-handler";
2729

2830
/** Total time budget (ms) to join interrupted threads before giving up. */
2931
private static final long INTERRUPT_JOIN_MS = 2_000L;
@@ -39,48 +41,67 @@ public void beforeAll(ExtensionContext context) {
3941
return;
4042
}
4143
if (scope(context) == DetectThreadLeaks.Scope.SUITE) {
42-
context.getStore(EXTENSION_NAMESPACE).put(SNAPSHOT_KEY, liveThreads(buildFilter(context)));
44+
var store = context.getStore(EXTENSION_NAMESPACE);
45+
var filter = buildFilter(context);
46+
store.put(UNCAUGHT_EXCEPTION_HANDLER_KEY, installUncaughtExceptionHandler());
47+
store.put(SNAPSHOT_KEY, liveThreads(filter));
4348
}
4449
}
4550

4651
@Override
4752
public void afterAll(ExtensionContext context) {
48-
if (isConcurrentMode(context) || scope(context) != DetectThreadLeaks.Scope.SUITE) {
49-
return;
53+
if (isConcurrentMode(context) || scope(context) != DetectThreadLeaks.Scope.SUITE) return;
54+
var store = context.getStore(EXTENSION_NAMESPACE);
55+
var handler = store.get(UNCAUGHT_EXCEPTION_HANDLER_KEY, UncaughtExceptionsHandler.class);
56+
try {
57+
checkLeaks(
58+
store,
59+
"suite [" + context.getDisplayName() + "]",
60+
linger(context),
61+
buildFilter(context),
62+
handler);
63+
} finally {
64+
if (handler != null) handler.restore();
5065
}
51-
checkLeaks(
52-
context.getStore(EXTENSION_NAMESPACE),
53-
"suite [" + context.getDisplayName() + "]",
54-
linger(context),
55-
buildFilter(context));
5666
}
5767

5868
@Override
5969
public void beforeEach(ExtensionContext context) {
60-
if (isConcurrentMode(context) || scope(context) != DetectThreadLeaks.Scope.TEST) {
61-
return;
62-
}
63-
context.getStore(EXTENSION_NAMESPACE).put(SNAPSHOT_KEY, liveThreads(buildFilter(context)));
70+
if (isConcurrentMode(context) || scope(context) != DetectThreadLeaks.Scope.TEST) return;
71+
var store = context.getStore(EXTENSION_NAMESPACE);
72+
var filter = buildFilter(context);
73+
store.put(UNCAUGHT_EXCEPTION_HANDLER_KEY, installUncaughtExceptionHandler());
74+
store.put(SNAPSHOT_KEY, liveThreads(filter));
6475
}
6576

6677
@Override
6778
public void afterEach(ExtensionContext context) {
68-
if (isConcurrentMode(context) || scope(context) != DetectThreadLeaks.Scope.TEST) {
69-
return;
79+
if (isConcurrentMode(context) || scope(context) != DetectThreadLeaks.Scope.TEST) return;
80+
var store = context.getStore(EXTENSION_NAMESPACE);
81+
var handler = store.get(UNCAUGHT_EXCEPTION_HANDLER_KEY, UncaughtExceptionsHandler.class);
82+
try {
83+
checkLeaks(
84+
store,
85+
"test [" + context.getDisplayName() + "]",
86+
linger(context),
87+
buildFilter(context),
88+
handler);
89+
} finally {
90+
if (handler != null) handler.restore();
7091
}
71-
checkLeaks(
72-
context.getStore(EXTENSION_NAMESPACE),
73-
"test [" + context.getDisplayName() + "]",
74-
linger(context),
75-
buildFilter(context));
92+
}
93+
94+
private static UncaughtExceptionsHandler installUncaughtExceptionHandler() {
95+
var handler = new UncaughtExceptionsHandler(Thread.getDefaultUncaughtExceptionHandler());
96+
Thread.setDefaultUncaughtExceptionHandler(handler);
97+
return handler;
7698
}
7799

78100
private static DetectThreadLeaks.Scope scope(ExtensionContext context) {
79101
return context.getRequiredTestClass().getAnnotation(DetectThreadLeaks.class).scope();
80102
}
81103

82104
private static int linger(ExtensionContext context) {
83-
// Method-level annotation takes precedence over class-level.
84105
var methodAnn =
85106
context
86107
.getTestMethod()
@@ -94,7 +115,7 @@ private static int linger(ExtensionContext context) {
94115

95116
/**
96117
* Collects {@link DetectThreadLeaks.ExcludeThreads} filter classes from the entire hierarchy
97-
* (method class superclasses) and returns a combined predicate that excludes a thread when
118+
* (method to class to superclasses) and returns a combined predicate that excludes a thread when
98119
* any filter matches it.
99120
*/
100121
private static Predicate<Thread> buildFilter(ExtensionContext context) {
@@ -113,21 +134,19 @@ private static Predicate<Thread> buildFilter(ExtensionContext context) {
113134
for (Class<?> cls = context.getRequiredTestClass(); cls != null; cls = cls.getSuperclass()) {
114135
var ann = cls.getAnnotation(DetectThreadLeaks.ExcludeThreads.class);
115136
if (ann != null) {
116-
filterClasses.addAll(Arrays.asList(ann.value()));
137+
for (var c : ann.value()) filterClasses.add(c);
117138
}
118139
}
119140

120-
if (filterClasses.isEmpty()) {
121-
return t -> false;
122-
}
141+
if (filterClasses.isEmpty()) return t -> false;
123142

124143
var predicates =
125144
filterClasses.stream()
126145
.map(
127146
cls -> {
128147
try {
129148
return (Predicate<Thread>) cls.getDeclaredConstructor().newInstance();
130-
} catch (Exception e) {
149+
} catch (ReflectiveOperationException e) {
131150
throw new RuntimeException(
132151
"Cannot instantiate thread filter: " + cls.getName(), e);
133152
}
@@ -138,7 +157,6 @@ private static Predicate<Thread> buildFilter(ExtensionContext context) {
138157
}
139158

140159
private static boolean isConcurrentMode(ExtensionContext context) {
141-
// Check the concurrent flag stored in beforeAll (class-level context = parent of method ctx).
142160
return context
143161
.getParent()
144162
.map(
@@ -149,53 +167,82 @@ private static boolean isConcurrentMode(ExtensionContext context) {
149167
}
150168

151169
private static void checkLeaks(
152-
ExtensionContext.Store store, String description, int lingerMs, Predicate<Thread> filter) {
170+
ExtensionContext.Store store,
171+
String description,
172+
int lingerMs,
173+
Predicate<Thread> filter,
174+
UncaughtExceptionsHandler handler) {
153175
var snapshot = store.get(SNAPSHOT_KEY, HashSet.class);
154-
if (snapshot == null) return;
176+
AssertionError leakError = null;
155177

156-
var leaked = leakedSince(snapshot, filter);
157-
if (leaked.isEmpty()) return;
178+
if (snapshot != null) {
179+
var leaked = leakedSince(snapshot, filter);
158180

159-
// Linger: poll until threads self-terminate or the window expires.
160-
if (lingerMs > 0) {
161-
long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(lingerMs);
162-
while (!leaked.isEmpty() && System.nanoTime() < deadline) {
163-
try {
164-
long remainingMs = TimeUnit.NANOSECONDS.toMillis(deadline - System.nanoTime());
165-
Thread.sleep(Math.max(1L, Math.min(100L, remainingMs)));
166-
} catch (InterruptedException e) {
167-
Thread.currentThread().interrupt();
168-
break;
181+
// Linger: poll until threads self-terminate or the window expires.
182+
if (!leaked.isEmpty() && lingerMs > 0) {
183+
long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(lingerMs);
184+
while (!leaked.isEmpty() && System.nanoTime() < deadline) {
185+
try {
186+
long remainingMs = TimeUnit.NANOSECONDS.toMillis(deadline - System.nanoTime());
187+
Thread.sleep(Math.max(1L, Math.min(100L, remainingMs)));
188+
} catch (InterruptedException e) {
189+
Thread.currentThread().interrupt();
190+
break;
191+
}
192+
leaked = leakedSince(snapshot, filter);
169193
}
170-
leaked = leakedSince(snapshot, filter);
171194
}
172-
if (leaked.isEmpty()) return;
173-
}
174195

175-
// Interrupt leaked threads for cleanup, then wait briefly for them to terminate.
176-
leaked.keySet().forEach(Thread::interrupt);
177-
long joinDeadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(INTERRUPT_JOIN_MS);
178-
for (Thread t : leaked.keySet()) {
179-
long remaining = TimeUnit.NANOSECONDS.toMillis(joinDeadline - System.nanoTime());
180-
if (remaining <= 0) break;
181-
try {
182-
t.join(remaining);
183-
} catch (InterruptedException e) {
184-
Thread.currentThread().interrupt();
185-
break;
196+
if (!leaked.isEmpty()) {
197+
// Suppress uncaught exception reporting during the interrupt/join phase to avoid
198+
// capturing expected InterruptedException-related exceptions from cleaned-up threads.
199+
if (handler != null) handler.stopReporting();
200+
try {
201+
leaked.keySet().forEach(Thread::interrupt);
202+
long joinDeadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(INTERRUPT_JOIN_MS);
203+
for (Thread t : leaked.keySet()) {
204+
long remaining = TimeUnit.NANOSECONDS.toMillis(joinDeadline - System.nanoTime());
205+
if (remaining <= 0) break;
206+
try {
207+
t.join(remaining);
208+
} catch (InterruptedException e) {
209+
Thread.currentThread().interrupt();
210+
break;
211+
}
212+
}
213+
} finally {
214+
if (handler != null) handler.resumeReporting();
215+
}
216+
217+
var sb = new StringBuilder(leaked.size() + " thread(s) leaked from " + description + ":");
218+
int cnt = 1;
219+
for (var entry : leaked.entrySet()) {
220+
sb.append(String.format("%n %2d) %s", cnt++, Threads.threadName(entry.getKey())));
221+
for (var ste : entry.getValue()) {
222+
sb.append(String.format("%n at %s", ste));
223+
}
224+
}
225+
leakError = new AssertionError(sb.toString());
186226
}
187227
}
188228

189-
// Report failure with stack traces captured before the interrupt.
190-
var sb = new StringBuilder(leaked.size() + " thread(s) leaked from " + description + ":");
191-
int cnt = 1;
192-
for (var entry : leaked.entrySet()) {
193-
sb.append(String.format("%n %2d) %s", cnt++, Threads.threadName(entry.getKey())));
194-
for (var ste : entry.getValue()) {
195-
sb.append(String.format("%n at %s", ste));
196-
}
229+
// Collect uncaught exceptions regardless of whether threads leaked.
230+
List<UncaughtExceptionsHandler.UncaughtException> uncaught =
231+
handler != null ? handler.getAndClear() : List.of();
232+
233+
if (leakError == null && uncaught.isEmpty()) return;
234+
235+
// Combine: leak error first (if any), uncaught exceptions after; all but the first
236+
// are attached as suppressed on the thrown error.
237+
var errors = new ArrayList<AssertionError>();
238+
if (leakError != null) errors.add(leakError);
239+
for (var ue : uncaught) {
240+
errors.add(
241+
new AssertionError("Uncaught exception in thread [" + ue.threadName() + "]", ue.error()));
197242
}
198-
throw new AssertionError(sb.toString());
243+
var first = errors.get(0);
244+
errors.subList(1, errors.size()).forEach(first::addSuppressed);
245+
throw first;
199246
}
200247

201248
private static Map<Thread, StackTraceElement[]> leakedSince(
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package com.carrotsearch.randomizedtesting.jupiter;
2+
3+
import java.util.ArrayList;
4+
import java.util.List;
5+
import java.util.logging.Level;
6+
import java.util.logging.Logger;
7+
8+
/** Collects uncaught exceptions from threads during test execution. */
9+
class UncaughtExceptionsHandler implements Thread.UncaughtExceptionHandler {
10+
private static final Logger LOGGER = Logger.getLogger(UncaughtExceptionsHandler.class.getName());
11+
12+
record UncaughtException(String threadName, Throwable error) {}
13+
14+
private final Thread.UncaughtExceptionHandler previous;
15+
private final List<UncaughtException> exceptions = new ArrayList<>();
16+
private boolean reporting = true;
17+
18+
UncaughtExceptionsHandler(Thread.UncaughtExceptionHandler previous) {
19+
this.previous = previous;
20+
}
21+
22+
@Override
23+
public void uncaughtException(Thread t, Throwable e) {
24+
synchronized (exceptions) {
25+
if (reporting) {
26+
LOGGER.log(Level.SEVERE, "Uncaught exception in thread: " + Threads.threadName(t), e);
27+
exceptions.add(new UncaughtException(Threads.threadName(t), e));
28+
}
29+
}
30+
if (previous != null) previous.uncaughtException(t, e);
31+
}
32+
33+
void stopReporting() {
34+
synchronized (exceptions) {
35+
reporting = false;
36+
}
37+
}
38+
39+
void resumeReporting() {
40+
synchronized (exceptions) {
41+
reporting = true;
42+
}
43+
}
44+
45+
List<UncaughtException> getAndClear() {
46+
synchronized (exceptions) {
47+
var copy = new ArrayList<>(exceptions);
48+
exceptions.clear();
49+
return copy;
50+
}
51+
}
52+
53+
/** Restores the previous default uncaught exception handler. */
54+
void restore() {
55+
Thread.setDefaultUncaughtExceptionHandler(previous);
56+
}
57+
}

randomizedtesting-jupiter/src/test/java/com/carrotsearch/randomizedtesting/jupiter/F005_ThreadLeaks.java

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,72 @@ private static void startNamedThread(String name) {
343343
t.start();
344344
}
345345

346+
@Nested
347+
class TestUncaughtExceptions {
348+
@Test
349+
void uncaughtExceptionFailsTheTest() {
350+
collectExecutionResults(testKitBuilder(UncaughtInTestMethod.class))
351+
.results()
352+
.allEvents()
353+
.finished()
354+
.failed()
355+
.assertEventsMatchExactly(
356+
event(
357+
finishedWithFailure(
358+
instanceOf(AssertionError.class),
359+
new Condition<>(
360+
t ->
361+
t.getCause() instanceof RuntimeException rc
362+
&& "uncaught-test-exception".equals(rc.getMessage()),
363+
"cause is the original RuntimeException"))));
364+
}
365+
366+
@Test
367+
void uncaughtExceptionsWithThreadLeaksAreNotReported() {
368+
collectExecutionResults(testKitBuilder(UncaughtWithLeak.class))
369+
.results()
370+
.allEvents()
371+
.finished()
372+
.failed()
373+
.assertEventsMatchExactly(
374+
event(
375+
finishedWithFailure(
376+
instanceOf(AssertionError.class),
377+
new Condition<>(t -> t.getCause() == null, "cause is empty."))));
378+
}
379+
380+
@DetectThreadLeaks(scope = DetectThreadLeaks.Scope.SUITE)
381+
static class UncaughtInTestMethod extends IgnoreInStandaloneRuns {
382+
@Test
383+
void testMethod() throws InterruptedException {
384+
var t =
385+
new Thread(
386+
() -> {
387+
throw new RuntimeException("uncaught-test-exception");
388+
});
389+
t.start();
390+
t.join();
391+
}
392+
}
393+
394+
@DetectThreadLeaks(scope = DetectThreadLeaks.Scope.TEST)
395+
static class UncaughtWithLeak extends IgnoreInStandaloneRuns {
396+
@Test
397+
void testMethod() {
398+
var t1 =
399+
new Thread(
400+
() -> {
401+
try {
402+
Thread.sleep(TimeUnit.MINUTES.toMillis(1));
403+
} catch (InterruptedException ignored) {
404+
throw new RuntimeException("uncaught-test-exception");
405+
}
406+
});
407+
t1.start();
408+
}
409+
}
410+
}
411+
346412
/** Starts a daemon thread that sleeps long enough to be observable as a leak. */
347413
private static void startSleepingThread() {
348414
var t =

0 commit comments

Comments
 (0)