Skip to content

Commit 371fa42

Browse files
authored
[FIX]: cuda.core: simplify _check_runtime_error logic (#2003)
* cuda.core: prefer binding names for runtime errors Use the generated runtime error enum as the name source for known CUDA Runtime errors so error messages remain stable when the runtime name table differs from the installed bindings. Made-with: Cursor * cuda.core: simplify runtime error naming path `_check_error()` only routes `runtime.cudaError_t` instances into `_check_runtime_error()`, so consulting `cudaGetErrorName()` and keeping a fallback for unknown values does not improve the normal `cuda.core` path. The Windows hybrid cudart issue is that the runtime name table can lag the generated enum table, so using `error.name` directly is both simpler and a better match for the values the code already has. With the runtime path now relying on enum members, the runtime-side tests no longer need to account for `UNEXPECTED ERROR CODE` in this loop or keep a separate monkeypatch test for avoiding the runtime name lookup. Made-with: Cursor
1 parent 6e5fb12 commit 371fa42

2 files changed

Lines changed: 4 additions & 13 deletions

File tree

cuda_core/cuda/core/_utils/cuda_utils.pyx

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,9 @@ cpdef inline int _check_driver_error(cydriver.CUresult error) except?-1 nogil:
171171
cpdef inline int _check_runtime_error(error) except?-1:
172172
if error == _RUNTIME_SUCCESS:
173173
return 0
174-
name_err, name = runtime.cudaGetErrorName(error)
175-
if name_err != _RUNTIME_SUCCESS:
176-
raise CUDAError(f"UNEXPECTED ERROR CODE: {error}")
177-
name = name.decode()
174+
# `_check_error()` reaches this path only for `runtime.cudaError_t` values.
175+
# Use the enum name directly because Windows hybrid cudart can lag that table.
176+
name = error.name
178177
expl = RUNTIME_CUDA_ERROR_EXPLANATIONS.get(int(error))
179178
if expl is not None:
180179
raise CUDAError(f"{name}: {expl}")

cuda_core/tests/test_cuda_utils.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,14 @@ def test_check_driver_error():
4848

4949

5050
def test_check_runtime_error():
51-
num_unexpected = 0
5251
for error in runtime.cudaError_t:
5352
if error == runtime.cudaError_t.cudaSuccess:
5453
assert cuda_utils._check_runtime_error(error) == 0
5554
else:
5655
with pytest.raises(cuda_utils.CUDAError) as e:
5756
cuda_utils._check_runtime_error(error)
5857
msg = str(e)
59-
if "UNEXPECTED ERROR CODE" in msg:
60-
num_unexpected += 1
61-
else:
62-
# Example repr(error): <cudaError_t.cudaErrorUnknown: 999>
63-
enum_name = repr(error).split(".", 1)[1].split(":", 1)[0]
64-
assert enum_name in msg
65-
# Smoke test: We don't want most to be unexpected.
66-
assert num_unexpected < len(driver.CUresult) * 0.5
58+
assert error.name in msg
6759

6860

6961
def test_driver_error_enum_has_non_empty_docstring():

0 commit comments

Comments
 (0)