11package com .carrotsearch .randomizedtesting .jupiter ;
22
33import java .util .HashSet ;
4+ import java .util .Map ;
5+ import java .util .concurrent .TimeUnit ;
46import java .util .logging .Logger ;
57import java .util .stream .Collectors ;
68import org .junit .jupiter .api .extension .AfterAllCallback ;
@@ -20,6 +22,9 @@ public class DetectThreadLeaksExtension
2022 private static final String SNAPSHOT_KEY = "snapshot" ;
2123 private static final String CONCURRENT_KEY = "concurrent" ;
2224
25+ /** Total time budget (ms) to join interrupted threads before giving up. */
26+ private static final long INTERRUPT_JOIN_MS = 2_000L ;
27+
2328 @ Override
2429 public void beforeAll (ExtensionContext context ) {
2530 if (context .getExecutionMode () != ExecutionMode .SAME_THREAD ) {
@@ -40,7 +45,10 @@ public void afterAll(ExtensionContext context) {
4045 if (isConcurrentMode (context ) || scope (context ) != DetectThreadLeaks .Scope .SUITE ) {
4146 return ;
4247 }
43- checkLeaks (context .getStore (EXTENSION_NAMESPACE ), "suite [" + context .getDisplayName () + "]" );
48+ checkLeaks (
49+ context .getStore (EXTENSION_NAMESPACE ),
50+ "suite [" + context .getDisplayName () + "]" ,
51+ linger (context ));
4452 }
4553
4654 @ Override
@@ -56,13 +64,29 @@ public void afterEach(ExtensionContext context) {
5664 if (isConcurrentMode (context ) || scope (context ) != DetectThreadLeaks .Scope .TEST ) {
5765 return ;
5866 }
59- checkLeaks (context .getStore (EXTENSION_NAMESPACE ), "test [" + context .getDisplayName () + "]" );
67+ checkLeaks (
68+ context .getStore (EXTENSION_NAMESPACE ),
69+ "test [" + context .getDisplayName () + "]" ,
70+ linger (context ));
6071 }
6172
6273 private static DetectThreadLeaks .Scope scope (ExtensionContext context ) {
6374 return context .getRequiredTestClass ().getAnnotation (DetectThreadLeaks .class ).scope ();
6475 }
6576
77+ private static int linger (ExtensionContext context ) {
78+ // Method-level annotation takes precedence over class-level.
79+ var methodAnn =
80+ context
81+ .getTestMethod ()
82+ .map (m -> m .getAnnotation (DetectThreadLeaks .LingerTime .class ))
83+ .orElse (null );
84+ if (methodAnn != null ) return methodAnn .millis ();
85+
86+ var classAnn = context .getRequiredTestClass ().getAnnotation (DetectThreadLeaks .LingerTime .class );
87+ return classAnn == null ? 0 : classAnn .millis ();
88+ }
89+
6690 private static boolean isConcurrentMode (ExtensionContext context ) {
6791 // Check the concurrent flag stored in beforeAll (class-level context = parent of method ctx).
6892 return context
@@ -74,26 +98,70 @@ private static boolean isConcurrentMode(ExtensionContext context) {
7498 .orElse (false );
7599 }
76100
77- private static void checkLeaks (ExtensionContext .Store store , String description ) {
101+ private static void checkLeaks (ExtensionContext .Store store , String description , int lingerMs ) {
78102 var snapshot = store .get (SNAPSHOT_KEY , HashSet .class );
79103 if (snapshot == null ) return ;
80104
81- var leaked = liveThreads ();
82- leaked .removeAll (snapshot );
83- leaked .removeIf (t -> !t .isAlive ());
105+ var leaked = leakedSince (snapshot );
106+ if (leaked .isEmpty ()) return ;
107+
108+ // Linger: poll until threads self-terminate or the window expires.
109+ if (lingerMs > 0 ) {
110+ long deadline = System .nanoTime () + TimeUnit .MILLISECONDS .toNanos (lingerMs );
111+ while (!leaked .isEmpty () && System .nanoTime () < deadline ) {
112+ try {
113+ long remainingMs = TimeUnit .NANOSECONDS .toMillis (deadline - System .nanoTime ());
114+ Thread .sleep (Math .max (1L , Math .min (100L , remainingMs )));
115+ } catch (InterruptedException e ) {
116+ Thread .currentThread ().interrupt ();
117+ break ;
118+ }
119+ leaked = leakedSince (snapshot );
120+ }
121+ if (leaked .isEmpty ()) return ;
122+ }
123+
124+ // Interrupt leaked threads for cleanup, then wait briefly for them to terminate.
125+ leaked .keySet ().forEach (Thread ::interrupt );
126+ long joinDeadline = System .nanoTime () + TimeUnit .MILLISECONDS .toNanos (INTERRUPT_JOIN_MS );
127+ for (Thread t : leaked .keySet ()) {
128+ long remaining = TimeUnit .NANOSECONDS .toMillis (joinDeadline - System .nanoTime ());
129+ if (remaining <= 0 ) break ;
130+ try {
131+ t .join (remaining );
132+ } catch (InterruptedException e ) {
133+ Thread .currentThread ().interrupt ();
134+ break ;
135+ }
136+ }
84137
85- if (!leaked .isEmpty ()) {
86- var sb = new StringBuilder (leaked .size () + " thread(s) leaked from " + description + ":" );
87- leaked .forEach (t -> sb .append ("\n " ).append (Threads .threadName (t )));
88- throw new AssertionError (sb .toString ());
138+ // Report failure with stack traces captured before the interrupt.
139+ var sb = new StringBuilder (leaked .size () + " thread(s) leaked from " + description + ":" );
140+ int cnt = 1 ;
141+ for (var entry : leaked .entrySet ()) {
142+ sb .append (String .format ("%n %2d) %s" , cnt ++, Threads .threadName (entry .getKey ())));
143+ for (var ste : entry .getValue ()) {
144+ sb .append (String .format ("%n at %s" , ste ));
145+ }
89146 }
147+ throw new AssertionError (sb .toString ());
148+ }
149+
150+ private static Map <Thread , StackTraceElement []> leakedSince (HashSet <?> snapshot ) {
151+ var current = liveThreadsWithStacks ();
152+ current .keySet ().removeAll (snapshot );
153+ return current ;
90154 }
91155
92156 private static HashSet <Thread > liveThreads () {
93- return Thread .getAllStackTraces ().keySet ().stream ()
94- .filter (Thread ::isAlive )
95- .filter (t -> !isKnownSystemThread (t ))
96- .collect (Collectors .toCollection (HashSet ::new ));
157+ return new HashSet <>(liveThreadsWithStacks ().keySet ());
158+ }
159+
160+ private static Map <Thread , StackTraceElement []> liveThreadsWithStacks () {
161+ return Thread .getAllStackTraces ().entrySet ().stream ()
162+ .filter (e -> e .getKey ().isAlive ())
163+ .filter (e -> !isKnownSystemThread (e .getKey ()))
164+ .collect (Collectors .toMap (Map .Entry ::getKey , Map .Entry ::getValue ));
97165 }
98166
99167 private static boolean isKnownSystemThread (Thread t ) {
0 commit comments