Skip to content

Commit 1b8ef80

Browse files
authored
Merge branch 'rocketride-org:develop' into feat/RR-1045-add-minimax-m3-model
2 parents 90581fd + da2a7d2 commit 1b8ef80

1 file changed

Lines changed: 31 additions & 8 deletions

File tree

packages/ai/src/ai/common/models/audio/whisper.py

Lines changed: 31 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -85,15 +85,29 @@ def _check_gpu_compatible(cls) -> bool:
8585
import subprocess
8686
import sys
8787

88-
# StorageView has no (shape, dtype, device) Python constructor — must use
89-
# from_array() on a CUDA-resident tensor to exercise the CUDA path.
88+
# Probe script checks two things:
89+
# 1. Version guard: ctranslate2 4.7.x + CUDA 12.8 causes a
90+
# tcache_thread_shutdown() SIGABRT during GPU transcription on H200
91+
# (heap corruption in cuBLAS 12.8.4). Exit non-zero to force CPU.
92+
# Upper bound at 4.8 so the guard lifts automatically once
93+
# ctranslate2 ships a fix (expected in 4.8+).
94+
# 2. StorageView sanity: verify a CUDA StorageView can be created via
95+
# the documented from_array() API (no direct (shape,dtype,device)
96+
# constructor exists in the Python bindings).
9097
probe_script = (
91-
'import ctranslate2, torch; '
92-
'v = ctranslate2.get_supported_compute_types("cuda"); '
93-
'assert v, "no cuda types"; '
94-
't = torch.zeros(1, dtype=torch.float32, device="cuda"); '
95-
'sv = ctranslate2.StorageView.from_array(t); '
96-
'print("ok")'
98+
'import sys, ctranslate2, torch\n'
99+
'v = ctranslate2.get_supported_compute_types("cuda")\n'
100+
'assert v, "no cuda types"\n'
101+
'try:\n'
102+
' ct2 = tuple(int(x) for x in ctranslate2.__version__.split(".")[:2])\n'
103+
'except (ValueError, AttributeError):\n'
104+
' ct2 = (999, 999)\n'
105+
'cuda = torch.version.cuda or ""\n'
106+
'if (4, 7) <= ct2 < (4, 8) and cuda.startswith("12.8"):\n'
107+
' sys.exit(1)\n'
108+
't = torch.zeros(1, dtype=torch.float32, device="cuda")\n'
109+
'sv = ctranslate2.StorageView.from_array(t)\n'
110+
'print("ok")\n'
97111
)
98112
result = None
99113
try:
@@ -200,6 +214,15 @@ def load(
200214
device = 'cuda'
201215
else:
202216
device = 'cpu'
217+
elif device != 'cpu' and not WhisperLoader._check_gpu_compatible():
218+
# Explicit cuda / cuda:N requested but probe failed — fall back to CPU
219+
# so the same SIGABRT protection applies regardless of how the caller
220+
# specified the device.
221+
logger.warning(
222+
'ctranslate2 CUDA probe failed for explicit device=%r — Whisper will use CPU instead.',
223+
device,
224+
)
225+
device = 'cpu'
203226

204227
if device == 'cpu':
205228
gpu_index = -1

0 commit comments

Comments
 (0)