Skip to content

Commit a5216d2

Browse files
committed
fix(cudnn): implement cuDNN 9 error codes, replace todo!() with proper mapping
cuDNN 9 restructured cudnnStatus_t into a hierarchical numeric system (2xxx=BAD_PARAM, 3xxx=NOT_SUPPORTED, 4xxx=INTERNAL_ERROR, 5xxx=EXECUTION_FAILED) and removed several codes present in cuDNN 8. Changes: - Add four new CudnnError variants behind #[cfg(cudnn9)]: SublibraryVersionMismatch, SerializationVersionMismatch, Deprecated, SublibraryLoadingFailed - Replace the _ => todo!() wildcard in IntoResult::into_result() with a category-based fallback that maps cuDNN 9 sub-codes (e.g. BAD_PARAM_NULL_POINTER) to their parent category variant using integer division, eliminating the runtime panic entirely - Add wire both new variants in into_raw() for round-trip correctness Verified against cudnn_graph.h from cuDNN 9.20 (anaconda distribution). The cudnn crate itself compiles cleanly; only pre-existing cust bindgen errors prevent a full cargo check -p cudnn from succeeding. Made-with: Cursor
1 parent e91b9ce commit a5216d2

File tree

1 file changed

+49
-4
lines changed

1 file changed

+49
-4
lines changed

crates/cudnn/src/error.rs

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{error::Error, ffi::CStr, fmt::Display};
1+
use std::{error::Error, ffi::CStr, fmt::Display};
22

33
/// Enum encapsulating function status returns. All cuDNN library functions return their status.
44
///
@@ -52,6 +52,18 @@ pub enum CudnnError {
5252
RuntimeFpOverflow,
5353
#[cfg(not(cudnn9))]
5454
VersionMismatch,
55+
/// A version mismatch was detected between cuDNN sub-libraries (cuDNN 9+).
56+
#[cfg(cudnn9)]
57+
SublibraryVersionMismatch,
58+
/// A serialization version mismatch was detected (cuDNN 9+).
59+
#[cfg(cudnn9)]
60+
SerializationVersionMismatch,
61+
/// A deprecated API was called (cuDNN 9+).
62+
#[cfg(cudnn9)]
63+
Deprecated,
64+
/// A required sub-library could not be loaded (cuDNN 9+).
65+
#[cfg(cudnn9)]
66+
SublibraryLoadingFailed,
5567
}
5668

5769
impl CudnnError {
@@ -78,6 +90,14 @@ impl CudnnError {
7890
CudnnError::RuntimeFpOverflow => CUDNN_STATUS_RUNTIME_FP_OVERFLOW,
7991
#[cfg(not(cudnn9))]
8092
CudnnError::VersionMismatch => CUDNN_STATUS_VERSION_MISMATCH,
93+
#[cfg(cudnn9)]
94+
CudnnError::SublibraryVersionMismatch => CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH,
95+
#[cfg(cudnn9)]
96+
CudnnError::SerializationVersionMismatch => CUDNN_STATUS_SERIALIZATION_VERSION_MISMATCH,
97+
#[cfg(cudnn9)]
98+
CudnnError::Deprecated => CUDNN_STATUS_DEPRECATED,
99+
#[cfg(cudnn9)]
100+
CudnnError::SublibraryLoadingFailed => CUDNN_STATUS_SUBLIBRARY_LOADING_FAILED,
81101
}
82102
}
83103
}
@@ -124,8 +144,33 @@ impl IntoResult for cudnn_sys::cudnnStatus_t {
124144
CUDNN_STATUS_RUNTIME_FP_OVERFLOW => CudnnError::RuntimeFpOverflow,
125145
#[cfg(not(cudnn9))]
126146
CUDNN_STATUS_VERSION_MISMATCH => CudnnError::VersionMismatch,
127-
// TODO(adamcavendish): implement cuDNN 9 error codes.
128-
_ => todo!(),
147+
// cuDNN 9 introduced a hierarchical status code system. Specific sub-codes
148+
// (e.g. CUDNN_STATUS_BAD_PARAM_NULL_POINTER = 2002) are mapped to their
149+
// parent category variant for backwards-compatible error handling.
150+
#[cfg(cudnn9)]
151+
CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH => CudnnError::SublibraryVersionMismatch,
152+
#[cfg(cudnn9)]
153+
CUDNN_STATUS_SERIALIZATION_VERSION_MISMATCH => CudnnError::SerializationVersionMismatch,
154+
#[cfg(cudnn9)]
155+
CUDNN_STATUS_DEPRECATED => CudnnError::Deprecated,
156+
#[cfg(cudnn9)]
157+
CUDNN_STATUS_SUBLIBRARY_LOADING_FAILED => CudnnError::SublibraryLoadingFailed,
158+
#[cfg(cudnn9)]
159+
s => {
160+
use cudnn_sys::cudnnStatus_t::*;
161+
// Map cuDNN 9 hierarchical sub-codes to their parent category variant.
162+
// Sub-codes share the same thousands digit as their parent:
163+
// 2xxx -> BAD_PARAM, 3xxx -> NOT_SUPPORTED,
164+
// 4xxx -> INTERNAL_ERROR, 5xxx -> EXECUTION_FAILED
165+
let category = (s as u32) / 1000 * 1000;
166+
match category {
167+
c if c == CUDNN_STATUS_BAD_PARAM as u32 => CudnnError::BadParam,
168+
c if c == CUDNN_STATUS_NOT_SUPPORTED as u32 => CudnnError::NotSupported,
169+
c if c == CUDNN_STATUS_INTERNAL_ERROR as u32 => CudnnError::InternalError,
170+
c if c == CUDNN_STATUS_EXECUTION_FAILED as u32 => CudnnError::ExecutionFailed,
171+
_ => CudnnError::InternalError,
172+
}
173+
}
129174
})
130175
}
131-
}
176+
}

0 commit comments

Comments
 (0)