Skip to content

Commit 8507e2a

Browse files
authored
Add thread leak detection and unhandled exception capture functionality (#7)
* Initial implementation of @DetectThreadLeaks annotation. * More progress. * Thread leaks and thread filters/ exclusions. * Add documentation. * Add uncaught exception handler. * Add migration notes. * More fixes to thread leaks. * More minor tweaks and changes. * Add a comment about extensions. * Adding javadocs.
1 parent 4ef7b03 commit 8507e2a

12 files changed

Lines changed: 1162 additions & 54 deletions

File tree

etc/junit4-missing-features.txt

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,5 @@
11
[ai generated overview of junit4 features]
22

3-
8. Timeouts
4-
- Standard JUnit @Test(timeout=N) is honoured
5-
- @Timeout(millis=N) annotation provides an explicit alternative
6-
- Termination sequence: Thread.interrupt() → Thread.stop() → zombie
7-
detection; all attempts are logged with stack traces
8-
9-
9. Thread-leak detection
10-
- Threads that escape a test's ThreadGroup boundary are killed and cause
11-
a test failure
12-
- Encourages explicit Thread.join() before a test method returns
13-
14-
10. Lingering threads and advanced thread-leak control
15-
- @ThreadLeakLingering(linger=N) waits up to N ms for stray threads to
16-
finish naturally (useful for Executor pools or other uncontrolled threads)
17-
- Additional annotations for fine-grained policy:
18-
@ThreadLeakScope – suite vs. test scope
19-
@ThreadLeakAction – warn vs. fail
20-
@ThreadLeakZombies – ignore vs. fail on zombie threads
21-
223
11. Nightly / scaled tests
234
- @Nightly marks a test that only runs when nightly mode is active
245
(-Dtests.nightly=true)
@@ -38,4 +19,12 @@
3819

3920
- predictably shuffled test execution order
4021
- blowing up test reps using tests.iters
41-
-
22+
23+
[to check/ add tests of]
24+
25+
- is the seed stack trace frame injected for leaked threads + randomized testing ext?
26+
- can we enforce the order of extensions (randomized testing > leaked threads)
27+
- how are jupiter timeouts working together with leaked threads ext.?
28+
- maybe bring back thread leak zombies annotation (if we can't cleanly terminate leaked threads, ignore all remaining tests).
29+
- maybe move some of the implementation details to a non-exposed package?
30+
- regenerate the javadocs with public API only.
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package com.carrotsearch.randomizedtesting.jupiter;
2+
3+
import java.lang.annotation.Documented;
4+
import java.lang.annotation.ElementType;
5+
import java.lang.annotation.Inherited;
6+
import java.lang.annotation.Retention;
7+
import java.lang.annotation.RetentionPolicy;
8+
import java.lang.annotation.Target;
9+
import java.util.function.Predicate;
10+
import org.junit.jupiter.api.extension.ExtendWith;
11+
12+
/**
13+
* Detects threads started within the annotated test class that are still alive after the configured
14+
* scope ends.
15+
*
16+
* <p>Only functional in sequential (same-thread) execution mode. Emits a warning and skips
17+
* detection if tests run concurrently.
18+
*/
19+
@Target({ElementType.TYPE})
20+
@Retention(RetentionPolicy.RUNTIME)
21+
@Documented
22+
@ExtendWith(DetectThreadLeaksExtension.class)
23+
@Inherited
24+
public @interface DetectThreadLeaks {
25+
/** Scope at which thread leak detection is performed. */
26+
Scope scope() default Scope.SUITE;
27+
28+
enum Scope {
29+
/** Disable thread leak detection entirely. */
30+
NONE,
31+
/** Check for leaked threads once after all tests in the class complete. */
32+
SUITE,
33+
/** Check for leaked threads after each individual test method. */
34+
TEST
35+
}
36+
37+
/**
38+
* Milliseconds to wait for leaked threads to self-terminate before declaring a failure. If all
39+
* leaked threads terminate within this window, the test passes. Default is 0 (no lingering).
40+
*
41+
* <p>Place this annotation on the same class or method as {@link DetectThreadLeaks}. A
42+
* method-level annotation takes precedence over a class-level one.
43+
*/
44+
@Target({ElementType.TYPE, ElementType.METHOD})
45+
@Retention(RetentionPolicy.RUNTIME)
46+
@Documented
47+
@Inherited
48+
@interface LingerTime {
49+
int millis();
50+
}
51+
52+
/**
53+
* Excludes threads matched by any of the given {@link Predicate} classes from leak detection. A
54+
* thread is excluded when at least one predicate returns {@code true} for it.
55+
*
56+
* <p>Annotations are collected hierarchically from the class and its superclasses, and the
57+
* filters from all levels are combined.
58+
*
59+
* @see SystemThreadFilter
60+
*/
61+
@Target({ElementType.TYPE})
62+
@Retention(RetentionPolicy.RUNTIME)
63+
@Documented
64+
@interface ExcludeThreads {
65+
Class<? extends Predicate<Thread>>[] value() default {SystemThreadFilter.class};
66+
}
67+
}
Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
package com.carrotsearch.randomizedtesting.jupiter;
2+
3+
import java.time.Duration;
4+
import java.util.ArrayList;
5+
import java.util.HashSet;
6+
import java.util.LinkedHashSet;
7+
import java.util.List;
8+
import java.util.Map;
9+
import java.util.concurrent.TimeUnit;
10+
import java.util.function.Predicate;
11+
import java.util.logging.Logger;
12+
import java.util.stream.Collectors;
13+
import org.junit.jupiter.api.extension.AfterAllCallback;
14+
import org.junit.jupiter.api.extension.AfterEachCallback;
15+
import org.junit.jupiter.api.extension.BeforeAllCallback;
16+
import org.junit.jupiter.api.extension.BeforeEachCallback;
17+
import org.junit.jupiter.api.extension.ExtensionContext;
18+
import org.junit.jupiter.api.parallel.ExecutionMode;
19+
20+
/** JUnit Jupiter extension implementing {@link DetectThreadLeaks}. */
21+
public class DetectThreadLeaksExtension
22+
implements BeforeAllCallback, AfterAllCallback, BeforeEachCallback, AfterEachCallback {
23+
24+
private static final Logger LOGGER = Logger.getLogger(DetectThreadLeaksExtension.class.getName());
25+
private static final ExtensionContext.Namespace EXTENSION_NAMESPACE =
26+
ExtensionContext.Namespace.create(DetectThreadLeaksExtension.class);
27+
private static final String THREAD_SNAPSHOT_KEY = "snapshot";
28+
private static final String CONCURRENT_KEY = "concurrent";
29+
private static final String UNCAUGHT_EXCEPTION_HANDLER_KEY = "uncaught-exception-handler";
30+
31+
/** Total time budget to join interrupted threads before giving up. */
32+
private static final Duration INTERRUPT_JOIN_MS = Duration.ofSeconds(3);
33+
34+
@Override
35+
public void beforeAll(ExtensionContext context) {
36+
if (scope(context) == DetectThreadLeaks.Scope.NONE) {
37+
return;
38+
}
39+
40+
if (context.getExecutionMode() != ExecutionMode.SAME_THREAD) {
41+
LOGGER.warning(
42+
"Thread leak detection is disabled: tests in ["
43+
+ context.getDisplayName()
44+
+ "] run in concurrent execution mode.");
45+
context.getStore(EXTENSION_NAMESPACE).put(CONCURRENT_KEY, Boolean.TRUE);
46+
return;
47+
}
48+
49+
var store = context.getStore(EXTENSION_NAMESPACE);
50+
var filter = buildFilter(context);
51+
store.put(UNCAUGHT_EXCEPTION_HANDLER_KEY, installUncaughtExceptionHandler());
52+
store.put(THREAD_SNAPSHOT_KEY, liveThreads(filter));
53+
}
54+
55+
@Override
56+
public void beforeEach(ExtensionContext context) {
57+
if (isConcurrentMode(context) || scope(context) != DetectThreadLeaks.Scope.TEST) {
58+
return;
59+
}
60+
61+
var store = context.getStore(EXTENSION_NAMESPACE);
62+
var filter = buildFilter(context);
63+
store.put(THREAD_SNAPSHOT_KEY, liveThreads(filter));
64+
}
65+
66+
@Override
67+
public void afterEach(ExtensionContext context) {
68+
if (isConcurrentMode(context) || scope(context) != DetectThreadLeaks.Scope.TEST) {
69+
return;
70+
}
71+
72+
var store = context.getStore(EXTENSION_NAMESPACE);
73+
var handler = store.get(UNCAUGHT_EXCEPTION_HANDLER_KEY, UncaughtExceptionsHandler.class);
74+
try {
75+
checkLeaks(
76+
store,
77+
"test [" + context.getDisplayName() + "]",
78+
linger(context),
79+
buildFilter(context),
80+
handler);
81+
} finally {
82+
if (handler != null) handler.restore();
83+
}
84+
}
85+
86+
@Override
87+
public void afterAll(ExtensionContext context) {
88+
if (isConcurrentMode(context) || scope(context) == DetectThreadLeaks.Scope.NONE) {
89+
return;
90+
}
91+
92+
var store = context.getStore(EXTENSION_NAMESPACE);
93+
var handler = store.get(UNCAUGHT_EXCEPTION_HANDLER_KEY, UncaughtExceptionsHandler.class);
94+
try {
95+
checkLeaks(
96+
store,
97+
"suite [" + context.getDisplayName() + "]",
98+
linger(context),
99+
buildFilter(context),
100+
handler);
101+
} finally {
102+
if (handler != null) handler.restore();
103+
}
104+
}
105+
106+
private static UncaughtExceptionsHandler installUncaughtExceptionHandler() {
107+
var handler = new UncaughtExceptionsHandler(Thread.getDefaultUncaughtExceptionHandler());
108+
Thread.setDefaultUncaughtExceptionHandler(handler);
109+
return handler;
110+
}
111+
112+
private static DetectThreadLeaks.Scope scope(ExtensionContext context) {
113+
return context.getRequiredTestClass().getAnnotation(DetectThreadLeaks.class).scope();
114+
}
115+
116+
private static int linger(ExtensionContext context) {
117+
var methodAnn =
118+
context
119+
.getTestMethod()
120+
.map(m -> m.getAnnotation(DetectThreadLeaks.LingerTime.class))
121+
.orElse(null);
122+
if (methodAnn != null) return methodAnn.millis();
123+
124+
var classAnn = context.getRequiredTestClass().getAnnotation(DetectThreadLeaks.LingerTime.class);
125+
return classAnn == null ? 0 : classAnn.millis();
126+
}
127+
128+
@DetectThreadLeaks.ExcludeThreads()
129+
private static class AnnotationDefaultsSource {}
130+
131+
/**
132+
* Collects {@link DetectThreadLeaks.ExcludeThreads} filter classes from the entire hierarchy
133+
* (method to class to superclasses) and returns a combined predicate that excludes a thread when
134+
* any filter matches it.
135+
*/
136+
private static Predicate<Thread> buildFilter(ExtensionContext context) {
137+
List<DetectThreadLeaks.ExcludeThreads> excludeThreads = new ArrayList<>();
138+
139+
for (Class<?> cls = context.getRequiredTestClass(); cls != null; cls = cls.getSuperclass()) {
140+
var ann = cls.getAnnotation(DetectThreadLeaks.ExcludeThreads.class);
141+
if (ann != null) {
142+
excludeThreads.add(ann);
143+
}
144+
}
145+
146+
if (excludeThreads.isEmpty()) {
147+
excludeThreads.add(
148+
AnnotationDefaultsSource.class.getAnnotation(DetectThreadLeaks.ExcludeThreads.class));
149+
}
150+
151+
var filterClasses = new LinkedHashSet<Predicate<Thread>>();
152+
for (var ann : excludeThreads) {
153+
for (var cls : ann.value()) {
154+
try {
155+
filterClasses.add(cls.getDeclaredConstructor().newInstance());
156+
} catch (ReflectiveOperationException e) {
157+
throw new RuntimeException("Cannot instantiate thread filter: " + cls.getName(), e);
158+
}
159+
}
160+
}
161+
162+
if (filterClasses.isEmpty()) {
163+
return t -> false;
164+
} else {
165+
return t -> filterClasses.stream().anyMatch(p -> p.test(t));
166+
}
167+
}
168+
169+
private static boolean isConcurrentMode(ExtensionContext context) {
170+
return context
171+
.getParent()
172+
.map(
173+
p ->
174+
Boolean.TRUE.equals(
175+
p.getStore(EXTENSION_NAMESPACE).get(CONCURRENT_KEY, Boolean.class)))
176+
.orElse(false);
177+
}
178+
179+
private static void checkLeaks(
180+
ExtensionContext.Store store,
181+
String description,
182+
int lingerMs,
183+
Predicate<Thread> filter,
184+
UncaughtExceptionsHandler handler) {
185+
var snapshot = store.get(THREAD_SNAPSHOT_KEY, HashSet.class);
186+
AssertionError leakError = null;
187+
188+
if (snapshot != null) {
189+
var leaked = leakedSince(snapshot, filter);
190+
191+
// Linger: poll until threads self-terminate or the window expires.
192+
if (!leaked.isEmpty() && lingerMs > 0) {
193+
long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(lingerMs);
194+
while (!leaked.isEmpty() && System.nanoTime() < deadline) {
195+
try {
196+
long remainingMs = TimeUnit.NANOSECONDS.toMillis(deadline - System.nanoTime());
197+
Thread.sleep(Math.max(1L, Math.min(100L, remainingMs)));
198+
} catch (InterruptedException e) {
199+
Thread.currentThread().interrupt();
200+
break;
201+
}
202+
leaked = leakedSince(snapshot, filter);
203+
}
204+
}
205+
206+
if (!leaked.isEmpty()) {
207+
// Suppress uncaught exception reporting during the interrupt/join phase to avoid
208+
// capturing expected InterruptedException-related exceptions from cleaned-up threads.
209+
if (handler != null) {
210+
handler.stopReporting();
211+
}
212+
213+
try {
214+
// Send an interrupt to all threads.
215+
leaked.keySet().forEach(Thread::interrupt);
216+
217+
// Wait for all those threads.
218+
long joinDeadline = System.nanoTime() + INTERRUPT_JOIN_MS.toNanos();
219+
for (Thread t : leaked.keySet()) {
220+
long remaining = TimeUnit.NANOSECONDS.toMillis(joinDeadline - System.nanoTime());
221+
if (remaining <= 0) {
222+
break;
223+
}
224+
225+
try {
226+
t.join(remaining);
227+
} catch (InterruptedException e) {
228+
Thread.currentThread().interrupt();
229+
break;
230+
}
231+
}
232+
} finally {
233+
if (handler != null) {
234+
handler.resumeReporting();
235+
}
236+
}
237+
238+
var sb = new StringBuilder(leaked.size() + " thread(s) leaked from " + description + ":");
239+
int cnt = 1;
240+
for (var entry : leaked.entrySet()) {
241+
sb.append(String.format("%n %2d) %s", cnt++, Threads.threadName(entry.getKey())));
242+
for (var ste : entry.getValue()) {
243+
sb.append(String.format("%n at %s", ste));
244+
}
245+
}
246+
leakError = new AssertionError(sb.toString());
247+
}
248+
}
249+
250+
// Collect uncaught exceptions regardless of whether threads leaked.
251+
List<UncaughtExceptionsHandler.UncaughtException> uncaught =
252+
handler != null ? handler.getAndClear() : List.of();
253+
254+
if (leakError == null && uncaught.isEmpty()) return;
255+
256+
// Combine: leak error first (if any), uncaught exceptions after; all but the first
257+
// are attached as suppressed on the thrown error.
258+
var errors = new ArrayList<AssertionError>();
259+
if (leakError != null) errors.add(leakError);
260+
for (var ue : uncaught) {
261+
errors.add(
262+
new AssertionError("Uncaught exception in thread [" + ue.threadName() + "]", ue.error()));
263+
}
264+
var first = errors.get(0);
265+
errors.subList(1, errors.size()).forEach(first::addSuppressed);
266+
throw first;
267+
}
268+
269+
private static Map<Thread, StackTraceElement[]> leakedSince(
270+
HashSet<?> snapshot, Predicate<Thread> filter) {
271+
var current = liveThreadsWithStacks(filter);
272+
current.keySet().removeAll(snapshot);
273+
return current;
274+
}
275+
276+
private static HashSet<Thread> liveThreads(Predicate<Thread> filter) {
277+
return new HashSet<>(liveThreadsWithStacks(filter).keySet());
278+
}
279+
280+
private static Map<Thread, StackTraceElement[]> liveThreadsWithStacks(Predicate<Thread> filter) {
281+
return Thread.getAllStackTraces().entrySet().stream()
282+
.filter(e -> e.getKey().isAlive())
283+
.filter(e -> !filter.test(e.getKey()))
284+
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
285+
}
286+
}

0 commit comments

Comments
 (0)