Skip to content

Commit 0a90770

Browse files
committed
Merge remote-tracking branch 'origin/claude/fix-blackwell-monai-tests-Apjwl' into 8587-test-erros-on-pytorch-release-2508-on-series-50
2 parents 80124e6 + 1b5ac46 commit 0a90770

File tree

2 files changed

+50
-1
lines changed

2 files changed

+50
-1
lines changed

runtests.sh

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ doMypyFormat=false
5353
doCleanup=false
5454
doDistTests=false
5555
doPrecommit=false
56+
testTimeout=0
5657

5758
NUM_PARALLEL=1
5859

@@ -109,6 +110,8 @@ function print_usage {
109110
echo " -v, --version : show MONAI and system version information and exit"
110111
echo " -p, --path : specify the path used for formatting, default is the current dir if unspecified"
111112
echo " --formatfix : format code using \"isort\" and \"black\" for user specified directories"
113+
echo " --timeout [secs] : per-test timeout in seconds; tests exceeding this are marked as errors and skipped"
114+
echo " (default: 180s when flag is given without a value; 0 = disabled)"
112115
echo ""
113116
echo "${separator}For bug reports and feature requests, please file an issue at:"
114117
echo " https://github.com/Project-MONAI/MONAI/issues/new/choose"
@@ -344,6 +347,15 @@ do
344347
testdir=$2
345348
shift
346349
;;
350+
--timeout)
351+
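# Accept an optional value: if the next argument is all digits it is consumed as the timeout in seconds; otherwise (missing or non-numeric) default to 180s and do not consume the next argument here.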
352+
if [[ -n "$2" ]] && [[ "$2" =~ ^[0-9]+$ ]]; then
353+
testTimeout=$2
354+
shift
355+
else
356+
testTimeout=180
357+
fi
358+
;;
347359
*)
348360
print_error_msg "Incorrect commandline provided, invalid key: $key"
349361
print_usage
@@ -695,7 +707,11 @@ if [ $doUnitTests = true ]
695707
then
696708
echo "${separator}${blue}unittests${noColor}"
697709
torch_validate
698-
${cmdPrefix}${cmd} ./tests/runner.py -p "^(?!test_integration|test_perceptual_loss|test_auto3dseg_ensemble).*(?<!_dist)$" # excluding integration/dist/perceptual_loss tests
710+
timeoutArg=""
711+
if [ "$testTimeout" -gt 0 ] 2>/dev/null; then
712+
timeoutArg="--timeout $testTimeout"
713+
fi
714+
${cmdPrefix}${cmd} ./tests/runner.py -p "^(?!test_integration|test_perceptual_loss|test_auto3dseg_ensemble).*(?<!_dist)$" $timeoutArg # excluding integration/dist/perceptual_loss tests
699715
fi
700716

701717
# distributed test only

tests/runner.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import inspect
1616
import os
1717
import re
18+
import signal
1819
import sys
1920
import time
2021
import unittest
@@ -24,10 +25,23 @@
2425

2526
results: dict = {}
2627

28+
_SIGALRM_AVAILABLE = hasattr(signal, "SIGALRM")
29+
30+
31+
class _TestTimeoutError(Exception):
32+
"""Raised when a single test exceeds the per-test timeout."""
33+
34+
35+
def _alarm_handler(signum, frame):
36+
raise _TestTimeoutError("Test timed out")
37+
2738

2839
class TimeLoggingTestResult(unittest.TextTestResult):
2940
"""Overload the default results so that we can store the results."""
3041

42+
# Per-test timeout in seconds, assigned on the class by the caller before the run; 0 disables the timeout.
43+
timeout: int = 0
44+
3145
def __init__(self, *args, **kwargs):
3246
super().__init__(*args, **kwargs)
3347
self.timed_tests = {}
@@ -37,10 +51,15 @@ def startTest(self, test): # noqa: N802
3751
self.start_time = time.time()
3852
name = self.getDescription(test)
3953
self.stream.write(f"Starting test: {name}...\n")
54+
if _SIGALRM_AVAILABLE and self.timeout > 0:
55+
signal.signal(signal.SIGALRM, _alarm_handler)
56+
signal.alarm(self.timeout)
4057
super().startTest(test)
4158

4259
def stopTest(self, test): # noqa: N802
4360
"""On test end, get time, print, store and do normal behaviour."""
61+
if _SIGALRM_AVAILABLE and self.timeout > 0:
62+
signal.alarm(0) # cancel any pending alarm
4463
elapsed = time.time() - self.start_time
4564
name = self.getDescription(test)
4665
self.stream.write(f"Finished test: {name} ({elapsed:.03}s)\n")
@@ -99,6 +118,13 @@ def parse_args():
99118
parser.add_argument(
100119
"-f", "--failfast", action="store_true", dest="failfast", default=False, help="Stop testing on first failure"
101120
)
121+
parser.add_argument(
122+
"--timeout",
123+
dest="timeout",
124+
default=0,
125+
type=int,
126+
help="Per-test timeout in seconds; 0 disables (default: %(default)d). Requires SIGALRM (Linux/macOS only).",
127+
)
102128
args = parser.parse_args()
103129
print(f"Running tests in folder: '{args.path}'")
104130
if args.pattern:
@@ -145,6 +171,13 @@ def get_default_pattern(loader):
145171
discovery_time = pc.total_time
146172
print(f"time to discover tests: {discovery_time}s, total cases: {tests.countTestCases()}.")
147173

174+
if args.timeout > 0:
175+
if _SIGALRM_AVAILABLE:
176+
TimeLoggingTestResult.timeout = args.timeout
177+
print(f"Per-test timeout enabled: {args.timeout}s")
178+
else:
179+
print("Warning: --timeout ignored; SIGALRM is not available on this platform.")
180+
148181
test_runner = unittest.runner.TextTestRunner(
149182
resultclass=TimeLoggingTestResult, verbosity=args.verbosity, failfast=args.failfast
150183
)

0 commit comments

Comments
 (0)