11package com .carrotsearch .randomizedtesting .jupiter ;
22
3- import java .util .Arrays ;
3+ import java .util .ArrayList ;
44import java .util .HashSet ;
55import java .util .LinkedHashSet ;
6+ import java .util .List ;
67import java .util .Map ;
78import java .util .concurrent .TimeUnit ;
89import java .util .function .Predicate ;
@@ -24,6 +25,7 @@ public class DetectThreadLeaksExtension
2425 ExtensionContext .Namespace .create (DetectThreadLeaksExtension .class );
2526 private static final String SNAPSHOT_KEY = "snapshot" ;
2627 private static final String CONCURRENT_KEY = "concurrent" ;
28+ private static final String UNCAUGHT_EXCEPTION_HANDLER_KEY = "uncaught-exception-handler" ;
2729
2830 /** Total time budget (ms) to join interrupted threads before giving up. */
2931 private static final long INTERRUPT_JOIN_MS = 2_000L ;
@@ -39,48 +41,67 @@ public void beforeAll(ExtensionContext context) {
3941 return ;
4042 }
4143 if (scope (context ) == DetectThreadLeaks .Scope .SUITE ) {
42- context .getStore (EXTENSION_NAMESPACE ).put (SNAPSHOT_KEY , liveThreads (buildFilter (context )));
44+ var store = context .getStore (EXTENSION_NAMESPACE );
45+ var filter = buildFilter (context );
46+ store .put (UNCAUGHT_EXCEPTION_HANDLER_KEY , installUncaughtExceptionHandler ());
47+ store .put (SNAPSHOT_KEY , liveThreads (filter ));
4348 }
4449 }
4550
4651 @ Override
4752 public void afterAll (ExtensionContext context ) {
48- if (isConcurrentMode (context ) || scope (context ) != DetectThreadLeaks .Scope .SUITE ) {
49- return ;
53+ if (isConcurrentMode (context ) || scope (context ) != DetectThreadLeaks .Scope .SUITE ) return ;
54+ var store = context .getStore (EXTENSION_NAMESPACE );
55+ var handler = store .get (UNCAUGHT_EXCEPTION_HANDLER_KEY , UncaughtExceptionsHandler .class );
56+ try {
57+ checkLeaks (
58+ store ,
59+ "suite [" + context .getDisplayName () + "]" ,
60+ linger (context ),
61+ buildFilter (context ),
62+ handler );
63+ } finally {
64+ if (handler != null ) handler .restore ();
5065 }
51- checkLeaks (
52- context .getStore (EXTENSION_NAMESPACE ),
53- "suite [" + context .getDisplayName () + "]" ,
54- linger (context ),
55- buildFilter (context ));
5666 }
5767
5868 @ Override
5969 public void beforeEach (ExtensionContext context ) {
60- if (isConcurrentMode (context ) || scope (context ) != DetectThreadLeaks .Scope .TEST ) {
61- return ;
62- }
63- context .getStore (EXTENSION_NAMESPACE ).put (SNAPSHOT_KEY , liveThreads (buildFilter (context )));
70+ if (isConcurrentMode (context ) || scope (context ) != DetectThreadLeaks .Scope .TEST ) return ;
71+ var store = context .getStore (EXTENSION_NAMESPACE );
72+ var filter = buildFilter (context );
73+ store .put (UNCAUGHT_EXCEPTION_HANDLER_KEY , installUncaughtExceptionHandler ());
74+ store .put (SNAPSHOT_KEY , liveThreads (filter ));
6475 }
6576
6677 @ Override
6778 public void afterEach (ExtensionContext context ) {
68- if (isConcurrentMode (context ) || scope (context ) != DetectThreadLeaks .Scope .TEST ) {
69- return ;
79+ if (isConcurrentMode (context ) || scope (context ) != DetectThreadLeaks .Scope .TEST ) return ;
80+ var store = context .getStore (EXTENSION_NAMESPACE );
81+ var handler = store .get (UNCAUGHT_EXCEPTION_HANDLER_KEY , UncaughtExceptionsHandler .class );
82+ try {
83+ checkLeaks (
84+ store ,
85+ "test [" + context .getDisplayName () + "]" ,
86+ linger (context ),
87+ buildFilter (context ),
88+ handler );
89+ } finally {
90+ if (handler != null ) handler .restore ();
7091 }
71- checkLeaks (
72- context .getStore (EXTENSION_NAMESPACE ),
73- "test [" + context .getDisplayName () + "]" ,
74- linger (context ),
75- buildFilter (context ));
92+ }
93+
94+ private static UncaughtExceptionsHandler installUncaughtExceptionHandler () {
95+ var handler = new UncaughtExceptionsHandler (Thread .getDefaultUncaughtExceptionHandler ());
96+ Thread .setDefaultUncaughtExceptionHandler (handler );
97+ return handler ;
7698 }
7799
78100 private static DetectThreadLeaks .Scope scope (ExtensionContext context ) {
79101 return context .getRequiredTestClass ().getAnnotation (DetectThreadLeaks .class ).scope ();
80102 }
81103
82104 private static int linger (ExtensionContext context ) {
83- // Method-level annotation takes precedence over class-level.
84105 var methodAnn =
85106 context
86107 .getTestMethod ()
@@ -94,7 +115,7 @@ private static int linger(ExtensionContext context) {
94115
95116 /**
96117 * Collects {@link DetectThreadLeaks.ExcludeThreads} filter classes from the entire hierarchy
97- * (method → class → superclasses) and returns a combined predicate that excludes a thread when
118+ * (method to class to superclasses) and returns a combined predicate that excludes a thread when
98119 * any filter matches it.
99120 */
100121 private static Predicate <Thread > buildFilter (ExtensionContext context ) {
@@ -113,21 +134,19 @@ private static Predicate<Thread> buildFilter(ExtensionContext context) {
113134 for (Class <?> cls = context .getRequiredTestClass (); cls != null ; cls = cls .getSuperclass ()) {
114135 var ann = cls .getAnnotation (DetectThreadLeaks .ExcludeThreads .class );
115136 if (ann != null ) {
116- filterClasses . addAll ( Arrays . asList ( ann .value ()));
137+ for ( var c : ann .value ()) filterClasses . add ( c );
117138 }
118139 }
119140
120- if (filterClasses .isEmpty ()) {
121- return t -> false ;
122- }
141+ if (filterClasses .isEmpty ()) return t -> false ;
123142
124143 var predicates =
125144 filterClasses .stream ()
126145 .map (
127146 cls -> {
128147 try {
129148 return (Predicate <Thread >) cls .getDeclaredConstructor ().newInstance ();
130- } catch (Exception e ) {
149+ } catch (ReflectiveOperationException e ) {
131150 throw new RuntimeException (
132151 "Cannot instantiate thread filter: " + cls .getName (), e );
133152 }
@@ -138,7 +157,6 @@ private static Predicate<Thread> buildFilter(ExtensionContext context) {
138157 }
139158
140159 private static boolean isConcurrentMode (ExtensionContext context ) {
141- // Check the concurrent flag stored in beforeAll (class-level context = parent of method ctx).
142160 return context
143161 .getParent ()
144162 .map (
@@ -149,53 +167,82 @@ private static boolean isConcurrentMode(ExtensionContext context) {
149167 }
150168
151169 private static void checkLeaks (
152- ExtensionContext .Store store , String description , int lingerMs , Predicate <Thread > filter ) {
170+ ExtensionContext .Store store ,
171+ String description ,
172+ int lingerMs ,
173+ Predicate <Thread > filter ,
174+ UncaughtExceptionsHandler handler ) {
153175 var snapshot = store .get (SNAPSHOT_KEY , HashSet .class );
154- if ( snapshot == null ) return ;
176+ AssertionError leakError = null ;
155177
156- var leaked = leakedSince (snapshot , filter );
157- if ( leaked . isEmpty ()) return ;
178+ if (snapshot != null ) {
179+ var leaked = leakedSince ( snapshot , filter ) ;
158180
159- // Linger: poll until threads self-terminate or the window expires.
160- if (lingerMs > 0 ) {
161- long deadline = System .nanoTime () + TimeUnit .MILLISECONDS .toNanos (lingerMs );
162- while (!leaked .isEmpty () && System .nanoTime () < deadline ) {
163- try {
164- long remainingMs = TimeUnit .NANOSECONDS .toMillis (deadline - System .nanoTime ());
165- Thread .sleep (Math .max (1L , Math .min (100L , remainingMs )));
166- } catch (InterruptedException e ) {
167- Thread .currentThread ().interrupt ();
168- break ;
181+ // Linger: poll until threads self-terminate or the window expires.
182+ if (!leaked .isEmpty () && lingerMs > 0 ) {
183+ long deadline = System .nanoTime () + TimeUnit .MILLISECONDS .toNanos (lingerMs );
184+ while (!leaked .isEmpty () && System .nanoTime () < deadline ) {
185+ try {
186+ long remainingMs = TimeUnit .NANOSECONDS .toMillis (deadline - System .nanoTime ());
187+ Thread .sleep (Math .max (1L , Math .min (100L , remainingMs )));
188+ } catch (InterruptedException e ) {
189+ Thread .currentThread ().interrupt ();
190+ break ;
191+ }
192+ leaked = leakedSince (snapshot , filter );
169193 }
170- leaked = leakedSince (snapshot , filter );
171194 }
172- if (leaked .isEmpty ()) return ;
173- }
174195
175- // Interrupt leaked threads for cleanup, then wait briefly for them to terminate.
176- leaked .keySet ().forEach (Thread ::interrupt );
177- long joinDeadline = System .nanoTime () + TimeUnit .MILLISECONDS .toNanos (INTERRUPT_JOIN_MS );
178- for (Thread t : leaked .keySet ()) {
179- long remaining = TimeUnit .NANOSECONDS .toMillis (joinDeadline - System .nanoTime ());
180- if (remaining <= 0 ) break ;
181- try {
182- t .join (remaining );
183- } catch (InterruptedException e ) {
184- Thread .currentThread ().interrupt ();
185- break ;
196+ if (!leaked .isEmpty ()) {
197+ // Suppress uncaught exception reporting during the interrupt/join phase to avoid
198+ // capturing expected InterruptedException-related exceptions from cleaned-up threads.
199+ if (handler != null ) handler .stopReporting ();
200+ try {
201+ leaked .keySet ().forEach (Thread ::interrupt );
202+ long joinDeadline = System .nanoTime () + TimeUnit .MILLISECONDS .toNanos (INTERRUPT_JOIN_MS );
203+ for (Thread t : leaked .keySet ()) {
204+ long remaining = TimeUnit .NANOSECONDS .toMillis (joinDeadline - System .nanoTime ());
205+ if (remaining <= 0 ) break ;
206+ try {
207+ t .join (remaining );
208+ } catch (InterruptedException e ) {
209+ Thread .currentThread ().interrupt ();
210+ break ;
211+ }
212+ }
213+ } finally {
214+ if (handler != null ) handler .resumeReporting ();
215+ }
216+
217+ var sb = new StringBuilder (leaked .size () + " thread(s) leaked from " + description + ":" );
218+ int cnt = 1 ;
219+ for (var entry : leaked .entrySet ()) {
220+ sb .append (String .format ("%n %2d) %s" , cnt ++, Threads .threadName (entry .getKey ())));
221+ for (var ste : entry .getValue ()) {
222+ sb .append (String .format ("%n at %s" , ste ));
223+ }
224+ }
225+ leakError = new AssertionError (sb .toString ());
186226 }
187227 }
188228
189- // Report failure with stack traces captured before the interrupt.
190- var sb = new StringBuilder (leaked .size () + " thread(s) leaked from " + description + ":" );
191- int cnt = 1 ;
192- for (var entry : leaked .entrySet ()) {
193- sb .append (String .format ("%n %2d) %s" , cnt ++, Threads .threadName (entry .getKey ())));
194- for (var ste : entry .getValue ()) {
195- sb .append (String .format ("%n at %s" , ste ));
196- }
229+ // Collect uncaught exceptions regardless of whether threads leaked.
230+ List <UncaughtExceptionsHandler .UncaughtException > uncaught =
231+ handler != null ? handler .getAndClear () : List .of ();
232+
233+ if (leakError == null && uncaught .isEmpty ()) return ;
234+
235+ // Combine: leak error first (if any), uncaught exceptions after; all but the first
236+ // are attached as suppressed on the thrown error.
237+ var errors = new ArrayList <AssertionError >();
238+ if (leakError != null ) errors .add (leakError );
239+ for (var ue : uncaught ) {
240+ errors .add (
241+ new AssertionError ("Uncaught exception in thread [" + ue .threadName () + "]" , ue .error ()));
197242 }
198- throw new AssertionError (sb .toString ());
243+ var first = errors .get (0 );
244+ errors .subList (1 , errors .size ()).forEach (first ::addSuppressed );
245+ throw first ;
199246 }
200247
201248 private static Map <Thread , StackTraceElement []> leakedSince (
0 commit comments