Skip to content

Commit 5ef6b93

Browse files
ytl0623ericspod
andauthored
Replace deprecated cuda.cudart with cuda.bindings.runtime. (#8790)
Fixes #8789 ### Description Replaces the deprecated `cuda.cudart` import in `trt_compiler.py` with `cuda.bindings.runtime`, which is the current API introduced in `cuda-bindings` >= 12.6 (pulled in by PyTorch 2.10+). ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: ytl0623 <david89062388@gmail.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent b3fff92 commit 5ef6b93

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

monai/networks/trt_compiler.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,22 @@ def cuassert(cuda_ret):
8585
"""
8686
Error reporting method for CUDA calls.
8787
Args:
88-
cuda_ret: CUDA return code.
88+
cuda_ret: Tuple returned by CUDA runtime calls, where the first element
89+
is a cudaError_t enum and subsequent elements are return values.
90+
91+
Raises:
92+
RuntimeError: If the CUDA call returned an error.
8993
"""
9094
err = cuda_ret[0]
91-
if err != 0:
92-
raise RuntimeError(f"CUDA ERROR: {err}")
95+
if err.value != 0:
96+
err_msg = f"CUDA ERROR: {err.value}"
97+
try:
98+
_, err_name = cudart.cudaGetErrorName(err)
99+
_, err_str = cudart.cudaGetErrorString(err)
100+
err_msg = f"CUDA ERROR {err.value}: {err_name}{err_str}"
101+
except Exception as e:
102+
get_logger("monai.networks.trt_compiler").debug(f"Failed to retrieve CUDA error details: {e}")
103+
raise RuntimeError(err_msg)
93104
if len(cuda_ret) > 1:
94105
return cuda_ret[1]
95106
return None

0 commit comments

Comments
 (0)