Skip to content

Commit 4ac1c30

Browse files
committed
More fixes to thread leaks.
1 parent 216b896 commit 4ac1c30

4 files changed

Lines changed: 280 additions & 159 deletions

File tree

randomizedtesting-jupiter/src/main/java/com/carrotsearch/randomizedtesting/jupiter/DetectThreadLeaks.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
Scope scope() default Scope.SUITE;
2727

2828
enum Scope {
29+
/** Disable thread leak detection entirely. */
30+
NONE,
2931
/** Check for leaked threads once after all tests in the class complete. */
3032
SUITE,
3133
/** Check for leaked threads after each individual test method. */
@@ -51,11 +53,10 @@ enum Scope {
5153
* Excludes threads matched by any of the given {@link Predicate} classes from leak detection. A
5254
* thread is excluded when at least one predicate returns {@code true} for it.
5355
*
54-
* <p>Annotations are collected hierarchically: the test method, then the class, then each
55-
* superclass, and the filters from all levels are combined. Place on the same class or method as
56-
* {@link DetectThreadLeaks}.
56+
* <p>Annotations are collected hierarchically from the class and its superclasses, and the
57+
* filters from all levels are combined.
5758
*/
58-
@Target({ElementType.TYPE, ElementType.METHOD})
59+
@Target({ElementType.TYPE})
5960
@Retention(RetentionPolicy.RUNTIME)
6061
@Documented
6162
@interface ExcludeThreads {

randomizedtesting-jupiter/src/main/java/com/carrotsearch/randomizedtesting/jupiter/DetectThreadLeaksExtension.java

Lines changed: 52 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public class DetectThreadLeaksExtension
2323
private static final Logger LOGGER = Logger.getLogger(DetectThreadLeaksExtension.class.getName());
2424
private static final ExtensionContext.Namespace EXTENSION_NAMESPACE =
2525
ExtensionContext.Namespace.create(DetectThreadLeaksExtension.class);
26-
private static final String SNAPSHOT_KEY = "snapshot";
26+
private static final String THREAD_SNAPSHOT_KEY = "snapshot";
2727
private static final String CONCURRENT_KEY = "concurrent";
2828
private static final String UNCAUGHT_EXCEPTION_HANDLER_KEY = "uncaught-exception-handler";
2929

@@ -32,6 +32,10 @@ public class DetectThreadLeaksExtension
3232

3333
@Override
3434
public void beforeAll(ExtensionContext context) {
35+
if (scope(context) == DetectThreadLeaks.Scope.NONE) {
36+
return;
37+
}
38+
3539
if (context.getExecutionMode() != ExecutionMode.SAME_THREAD) {
3640
LOGGER.warning(
3741
"Thread leak detection is disabled: tests in ["
@@ -40,23 +44,36 @@ public void beforeAll(ExtensionContext context) {
4044
context.getStore(EXTENSION_NAMESPACE).put(CONCURRENT_KEY, Boolean.TRUE);
4145
return;
4246
}
43-
if (scope(context) == DetectThreadLeaks.Scope.SUITE) {
44-
var store = context.getStore(EXTENSION_NAMESPACE);
45-
var filter = buildFilter(context);
46-
store.put(UNCAUGHT_EXCEPTION_HANDLER_KEY, installUncaughtExceptionHandler());
47-
store.put(SNAPSHOT_KEY, liveThreads(filter));
47+
48+
var store = context.getStore(EXTENSION_NAMESPACE);
49+
var filter = buildFilter(context);
50+
store.put(UNCAUGHT_EXCEPTION_HANDLER_KEY, installUncaughtExceptionHandler());
51+
store.put(THREAD_SNAPSHOT_KEY, liveThreads(filter));
52+
}
53+
54+
@Override
55+
public void beforeEach(ExtensionContext context) {
56+
if (isConcurrentMode(context) || scope(context) != DetectThreadLeaks.Scope.TEST) {
57+
return;
4858
}
59+
60+
var store = context.getStore(EXTENSION_NAMESPACE);
61+
var filter = buildFilter(context);
62+
store.put(THREAD_SNAPSHOT_KEY, liveThreads(filter));
4963
}
5064

5165
@Override
52-
public void afterAll(ExtensionContext context) {
53-
if (isConcurrentMode(context) || scope(context) != DetectThreadLeaks.Scope.SUITE) return;
66+
public void afterEach(ExtensionContext context) {
67+
if (isConcurrentMode(context) || scope(context) != DetectThreadLeaks.Scope.TEST) {
68+
return;
69+
}
70+
5471
var store = context.getStore(EXTENSION_NAMESPACE);
5572
var handler = store.get(UNCAUGHT_EXCEPTION_HANDLER_KEY, UncaughtExceptionsHandler.class);
5673
try {
5774
checkLeaks(
5875
store,
59-
"suite [" + context.getDisplayName() + "]",
76+
"test [" + context.getDisplayName() + "]",
6077
linger(context),
6178
buildFilter(context),
6279
handler);
@@ -66,23 +83,17 @@ public void afterAll(ExtensionContext context) {
6683
}
6784

6885
@Override
69-
public void beforeEach(ExtensionContext context) {
70-
if (isConcurrentMode(context) || scope(context) != DetectThreadLeaks.Scope.TEST) return;
71-
var store = context.getStore(EXTENSION_NAMESPACE);
72-
var filter = buildFilter(context);
73-
store.put(UNCAUGHT_EXCEPTION_HANDLER_KEY, installUncaughtExceptionHandler());
74-
store.put(SNAPSHOT_KEY, liveThreads(filter));
75-
}
86+
public void afterAll(ExtensionContext context) {
87+
if (isConcurrentMode(context) || scope(context) == DetectThreadLeaks.Scope.NONE) {
88+
return;
89+
}
7690

77-
@Override
78-
public void afterEach(ExtensionContext context) {
79-
if (isConcurrentMode(context) || scope(context) != DetectThreadLeaks.Scope.TEST) return;
8091
var store = context.getStore(EXTENSION_NAMESPACE);
8192
var handler = store.get(UNCAUGHT_EXCEPTION_HANDLER_KEY, UncaughtExceptionsHandler.class);
8293
try {
8394
checkLeaks(
8495
store,
85-
"test [" + context.getDisplayName() + "]",
96+
"suite [" + context.getDisplayName() + "]",
8697
linger(context),
8798
buildFilter(context),
8899
handler);
@@ -119,41 +130,26 @@ private static int linger(ExtensionContext context) {
119130
* any filter matches it.
120131
*/
121132
private static Predicate<Thread> buildFilter(ExtensionContext context) {
122-
var filterClasses = new LinkedHashSet<Class<? extends Predicate<Thread>>>();
123-
124-
context
125-
.getTestMethod()
126-
.ifPresent(
127-
m -> {
128-
var ann = m.getAnnotation(DetectThreadLeaks.ExcludeThreads.class);
129-
if (ann != null) {
130-
for (var c : ann.value()) filterClasses.add(c);
131-
}
132-
});
133+
var filterClasses = new LinkedHashSet<Predicate<Thread>>();
133134

134135
for (Class<?> cls = context.getRequiredTestClass(); cls != null; cls = cls.getSuperclass()) {
135136
var ann = cls.getAnnotation(DetectThreadLeaks.ExcludeThreads.class);
136137
if (ann != null) {
137-
for (var c : ann.value()) filterClasses.add(c);
138+
for (var c : ann.value()) {
139+
try {
140+
filterClasses.add(c.getDeclaredConstructor().newInstance());
141+
} catch (ReflectiveOperationException e) {
142+
throw new RuntimeException("Cannot instantiate thread filter: " + cls.getName(), e);
143+
}
144+
}
138145
}
139146
}
140147

141-
if (filterClasses.isEmpty()) return t -> false;
142-
143-
var predicates =
144-
filterClasses.stream()
145-
.map(
146-
cls -> {
147-
try {
148-
return (Predicate<Thread>) cls.getDeclaredConstructor().newInstance();
149-
} catch (ReflectiveOperationException e) {
150-
throw new RuntimeException(
151-
"Cannot instantiate thread filter: " + cls.getName(), e);
152-
}
153-
})
154-
.toList();
155-
156-
return t -> predicates.stream().anyMatch(p -> p.test(t));
148+
if (filterClasses.isEmpty()) {
149+
return t -> false;
150+
}
151+
152+
return t -> filterClasses.stream().anyMatch(p -> p.test(t));
157153
}
158154

159155
private static boolean isConcurrentMode(ExtensionContext context) {
@@ -172,7 +168,7 @@ private static void checkLeaks(
172168
int lingerMs,
173169
Predicate<Thread> filter,
174170
UncaughtExceptionsHandler handler) {
175-
var snapshot = store.get(SNAPSHOT_KEY, HashSet.class);
171+
var snapshot = store.get(THREAD_SNAPSHOT_KEY, HashSet.class);
176172
AssertionError leakError = null;
177173

178174
if (snapshot != null) {
@@ -196,7 +192,10 @@ private static void checkLeaks(
196192
if (!leaked.isEmpty()) {
197193
// Suppress uncaught exception reporting during the interrupt/join phase to avoid
198194
// capturing expected InterruptedException-related exceptions from cleaned-up threads.
199-
if (handler != null) handler.stopReporting();
195+
if (handler != null) {
196+
handler.stopReporting();
197+
}
198+
200199
try {
201200
leaked.keySet().forEach(Thread::interrupt);
202201
long joinDeadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(INTERRUPT_JOIN_MS);
@@ -211,7 +210,9 @@ private static void checkLeaks(
211210
}
212211
}
213212
} finally {
214-
if (handler != null) handler.resumeReporting();
213+
if (handler != null) {
214+
handler.resumeReporting();
215+
}
215216
}
216217

217218
var sb = new StringBuilder(leaked.size() + " thread(s) leaked from " + description + ":");

0 commit comments

Comments
 (0)