1
1
package org .junit .internal .runners .statements ;
2
2
3
+ import java .lang .management .ManagementFactory ;
4
+ import java .lang .management .ThreadMXBean ;
3
5
import java .util .concurrent .Callable ;
4
6
import java .util .concurrent .ExecutionException ;
5
7
import java .util .concurrent .FutureTask ;
6
8
import java .util .concurrent .TimeUnit ;
7
9
import java .util .concurrent .TimeoutException ;
8
10
11
+ import org .junit .internal .runners .ExceptionWithThread ;
9
12
import org .junit .runners .model .Statement ;
10
13
11
14
public class FailOnTimeout extends Statement {
12
15
private final Statement fOriginalStatement ;
13
16
private final TimeUnit fTimeUnit ;
14
17
private final long fTimeout ;
18
+ private ThreadGroup fThreadGroup = null ;
15
19
16
20
public FailOnTimeout (Statement originalStatement , long millis ) {
17
21
this (originalStatement , millis , TimeUnit .MILLISECONDS );
@@ -26,7 +30,8 @@ public FailOnTimeout(Statement originalStatement, long timeout, TimeUnit unit) {
26
30
@ Override
27
31
public void evaluate () throws Throwable {
28
32
FutureTask <Throwable > task = new FutureTask <Throwable >(new CallableStatement ());
29
- Thread thread = new Thread (task , "Time-limited test" );
33
+ fThreadGroup = new ThreadGroup ("FailOnTimeoutGroup" );
34
+ Thread thread = new Thread (fThreadGroup , task , "Time-limited test" );
30
35
thread .setDaemon (true );
31
36
thread .start ();
32
37
Throwable throwable = getResult (task , thread );
@@ -55,17 +60,82 @@ private Throwable getResult(FutureTask<Throwable> task, Thread thread) {
55
60
56
61
private Exception createTimeoutException (Thread thread ) {
57
62
StackTraceElement [] stackTrace = thread .getStackTrace ();
58
- Exception exception = new Exception (String .format (
59
- "test timed out after %d %s" , fTimeout , fTimeUnit .name ().toLowerCase ()));
63
+ final Thread stuckThread = getStuckThread (thread );
64
+ String message = String .format (
65
+ "test timed out after %d %s" , fTimeout , fTimeUnit .name ().toLowerCase ());
66
+ Exception exception = (stuckThread == null )
67
+ ? new Exception (message )
68
+ : new ExceptionWithThread (message , stuckThread ,
69
+ "Appears to be stuck in thread {0}" );
60
70
if (stackTrace != null ) {
61
71
exception .setStackTrace (stackTrace );
62
72
thread .interrupt ();
63
73
}
64
74
return exception ;
65
75
}
66
76
67
- private class CallableStatement implements Callable <Throwable > {
77
+ /**
78
+ * Determines whether the test appears to be stuck in some thread other than
79
+ * the "main thread" (the one created to run the test).
80
+ * @param mainThread The main thread created by {@code evaluate()}
81
+ * @return The thread which appears to be causing the problem, if different from
82
+ * {@code mainThread}, or {@code null} if the main thread appears to be the
83
+ * problem or if the thread cannot be determined. The return value is never equal
84
+ * to {@code mainThread}.
85
+ */
86
+ private Thread getStuckThread (Thread mainThread ) {
87
+ if (fThreadGroup == null ) return null ;
88
+ final int count = fThreadGroup .activeCount (); // this is just an estimate
89
+ int enumSize = Math .max (count * 2 , 100 );
90
+ int enumCount ;
91
+ Thread [] threads ;
92
+ ThreadMXBean mxBean = ManagementFactory .getThreadMXBean ();
93
+ int loopCount = 0 ;
94
+ while (true ) {
95
+ threads = new Thread [enumSize ];
96
+ enumCount = fThreadGroup .enumerate (threads );
97
+ // if there are too many threads to fit into the array, enumerate's result
98
+ // is >= the array's length; therefore we can't trust that it returned all
99
+ // the threads. Try again.
100
+ if (enumCount < enumSize ) break ;
101
+ enumSize += 100 ;
102
+ if (++loopCount >= 5 ) return null ;
103
+ // threads are proliferating too fast for us. Bail before we get into
104
+ // trouble.
105
+ }
106
+
107
+ // Now that we have all the threads in the test's thread group: Assume that
108
+ // any thread we're "stuck" in is RUNNABLE. Look for all RUNNABLE threads.
109
+ // If just one, we return that (unless it equals threadMain). If there's more
110
+ // than one, pick the one that's using the most CPU time, if this feature is
111
+ // supported.
112
+ Thread firstRunnable = null ;
113
+ Thread mostCpu = null ;
114
+ long maxCpuTime = 0 ;
115
+ int runnableCount = 0 ;
116
+ for (int i = 0 ; i < enumCount ; i ++) {
117
+ if (threads [i ].getState () == Thread .State .RUNNABLE ) {
118
+ runnableCount ++;
119
+ if (firstRunnable == null ) firstRunnable = threads [i ];
120
+ if (mxBean .isThreadCpuTimeSupported ()) {
121
+ try {
122
+ long cpuTime = mxBean .getThreadCpuTime (threads [i ].getId ());
123
+ if (mostCpu == null || cpuTime > maxCpuTime ) {
124
+ mostCpu = threads [i ];
125
+ maxCpuTime = cpuTime ;
126
+ }
127
+ } catch (UnsupportedOperationException e ) {
128
+ }
129
+ }
130
+ }
131
+ }
132
+ Thread stuckThread =
133
+ (runnableCount == 1 ) ? firstRunnable :
134
+ ((mostCpu != null ) ? mostCpu : firstRunnable );
135
+ return (stuckThread == mainThread ) ? null : stuckThread ;
136
+ }
68
137
138
+ private class CallableStatement implements Callable <Throwable > {
69
139
public Throwable call () throws Exception {
70
140
try {
71
141
fOriginalStatement .evaluate ();
0 commit comments