diff --git a/crates/cudnn/src/error.rs b/crates/cudnn/src/error.rs index 1c60e741..bc04a1b7 100644 --- a/crates/cudnn/src/error.rs +++ b/crates/cudnn/src/error.rs @@ -52,6 +52,18 @@ pub enum CudnnError { RuntimeFpOverflow, #[cfg(not(cudnn9))] VersionMismatch, + /// A version mismatch was detected between cuDNN sub-libraries (cuDNN 9+). + #[cfg(cudnn9)] + SublibraryVersionMismatch, + /// A serialization version mismatch was detected (cuDNN 9+). + #[cfg(cudnn9)] + SerializationVersionMismatch, + /// A deprecated API was called (cuDNN 9+). + #[cfg(cudnn9)] + Deprecated, + /// A required sub-library could not be loaded (cuDNN 9+). + #[cfg(cudnn9)] + SublibraryLoadingFailed, } impl CudnnError { @@ -78,6 +90,14 @@ impl CudnnError { CudnnError::RuntimeFpOverflow => CUDNN_STATUS_RUNTIME_FP_OVERFLOW, #[cfg(not(cudnn9))] CudnnError::VersionMismatch => CUDNN_STATUS_VERSION_MISMATCH, + #[cfg(cudnn9)] + CudnnError::SublibraryVersionMismatch => CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH, + #[cfg(cudnn9)] + CudnnError::SerializationVersionMismatch => CUDNN_STATUS_SERIALIZATION_VERSION_MISMATCH, + #[cfg(cudnn9)] + CudnnError::Deprecated => CUDNN_STATUS_DEPRECATED, + #[cfg(cudnn9)] + CudnnError::SublibraryLoadingFailed => CUDNN_STATUS_SUBLIBRARY_LOADING_FAILED, } } } @@ -124,8 +144,205 @@ impl IntoResult for cudnn_sys::cudnnStatus_t { CUDNN_STATUS_RUNTIME_FP_OVERFLOW => CudnnError::RuntimeFpOverflow, #[cfg(not(cudnn9))] CUDNN_STATUS_VERSION_MISMATCH => CudnnError::VersionMismatch, - // TODO(adamcavendish): implement cuDNN 9 error codes. - _ => todo!(), + // cuDNN 9 introduced a hierarchical status code system. Specific sub-codes + // (e.g. CUDNN_STATUS_BAD_PARAM_NULL_POINTER = 2002) are mapped to their + // parent category variant for backwards-compatible error handling. + #[cfg(cudnn9)] + CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH => CudnnError::SublibraryVersionMismatch, + #[cfg(cudnn9)] + CUDNN_STATUS_SERIALIZATION_VERSION_MISMATCH => CudnnError::SerializationVersionMismatch, + #[cfg(cudnn9)] + CUDNN_STATUS_DEPRECATED => CudnnError::Deprecated, + #[cfg(cudnn9)] + CUDNN_STATUS_SUBLIBRARY_LOADING_FAILED => CudnnError::SublibraryLoadingFailed, + #[cfg(cudnn9)] + s => { + use cudnn_sys::cudnnStatus_t::*; + // Map cuDNN 9 hierarchical sub-codes to their parent category variant. + // Sub-codes share the same thousands digit as their parent: + // 2xxx -> BAD_PARAM, 3xxx -> NOT_SUPPORTED, + // 4xxx -> INTERNAL_ERROR, 5xxx -> EXECUTION_FAILED + let category = (s as u32) / 1000 * 1000; + match category { + c if c == CUDNN_STATUS_BAD_PARAM as u32 => CudnnError::BadParam, + c if c == CUDNN_STATUS_NOT_SUPPORTED as u32 => CudnnError::NotSupported, + c if c == CUDNN_STATUS_INTERNAL_ERROR as u32 => CudnnError::InternalError, + c if c == CUDNN_STATUS_EXECUTION_FAILED as u32 => CudnnError::ExecutionFailed, + _ => CudnnError::InternalError, + } + } }) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn success_maps_to_ok() { + use cudnn_sys::cudnnStatus_t::*; + assert!(CUDNN_STATUS_SUCCESS.into_result().is_ok()); + } + + #[test] + fn common_status_codes_map() { + use cudnn_sys::cudnnStatus_t::*; + assert_eq!( + CUDNN_STATUS_NOT_INITIALIZED.into_result().unwrap_err(), + CudnnError::NotInitialized + ); + assert_eq!( + CUDNN_STATUS_BAD_PARAM.into_result().unwrap_err(), + CudnnError::BadParam + ); + assert_eq!( + CUDNN_STATUS_INTERNAL_ERROR.into_result().unwrap_err(), + CudnnError::InternalError + ); + assert_eq!( + CUDNN_STATUS_INVALID_VALUE.into_result().unwrap_err(), + CudnnError::InvalidValue + ); + assert_eq!( + CUDNN_STATUS_EXECUTION_FAILED.into_result().unwrap_err(), + CudnnError::ExecutionFailed + ); + assert_eq!( + CUDNN_STATUS_NOT_SUPPORTED.into_result().unwrap_err(), + CudnnError::NotSupported + ); + assert_eq!( + CUDNN_STATUS_LICENSE_ERROR.into_result().unwrap_err(), + CudnnError::LicenseError + ); + assert_eq!( + CUDNN_STATUS_RUNTIME_IN_PROGRESS.into_result().unwrap_err(), + CudnnError::RuntimeInProgress + ); + assert_eq!( + CUDNN_STATUS_RUNTIME_FP_OVERFLOW.into_result().unwrap_err(), + CudnnError::RuntimeFpOverflow + ); + } + + #[cfg(not(cudnn9))] + #[test] + fn cudnn8_only_status_codes_map() { + use cudnn_sys::cudnnStatus_t::*; + assert_eq!( + CUDNN_STATUS_ALLOC_FAILED.into_result().unwrap_err(), + CudnnError::AllocFailed + ); + assert_eq!( + CUDNN_STATUS_ARCH_MISMATCH.into_result().unwrap_err(), + CudnnError::ArchMismatch + ); + assert_eq!( + CUDNN_STATUS_MAPPING_ERROR.into_result().unwrap_err(), + CudnnError::MappingError + ); + assert_eq!( + CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING + .into_result() + .unwrap_err(), + CudnnError::RuntimePrerequisiteMissing + ); + assert_eq!( + CUDNN_STATUS_VERSION_MISMATCH.into_result().unwrap_err(), + CudnnError::VersionMismatch + ); + } + + #[cfg(cudnn9)] + #[test] + fn cudnn9_named_status_codes_map() { + use cudnn_sys::cudnnStatus_t::*; + assert_eq!( + CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH + .into_result() + .unwrap_err(), + CudnnError::SublibraryVersionMismatch + ); + assert_eq!( + CUDNN_STATUS_SERIALIZATION_VERSION_MISMATCH + .into_result() + .unwrap_err(), + CudnnError::SerializationVersionMismatch + ); + assert_eq!( + CUDNN_STATUS_DEPRECATED.into_result().unwrap_err(), + CudnnError::Deprecated + ); + assert_eq!( + CUDNN_STATUS_SUBLIBRARY_LOADING_FAILED + .into_result() + .unwrap_err(), + CudnnError::SublibraryLoadingFailed + ); + } + + /// cuDNN 9 hierarchical sub-codes (2xxx/3xxx/…) must map to the parent category, not panic. + #[cfg(cudnn9)] + #[test] + fn cudnn9_hierarchical_subcodes_map_to_parent_category() { + use cudnn_sys::cudnnStatus_t::*; + assert_eq!( + CUDNN_STATUS_BAD_PARAM_NULL_POINTER + .into_result() + .unwrap_err(), + CudnnError::BadParam + ); + assert_eq!( + CUDNN_STATUS_NOT_SUPPORTED_SHAPE.into_result().unwrap_err(), + CudnnError::NotSupported + ); + } + + #[cfg(cudnn9)] + #[test] + fn cudnn9_into_raw_round_trips_for_named_errors() { + let cases = [ + CudnnError::SublibraryVersionMismatch, + CudnnError::SerializationVersionMismatch, + CudnnError::Deprecated, + CudnnError::SublibraryLoadingFailed, + ]; + for err in cases { + assert_eq!(err.into_raw().into_result().unwrap_err(), err); + } + } + + #[test] + fn into_raw_round_trips_for_common_errors() { + let cases = [ + CudnnError::NotInitialized, + CudnnError::BadParam, + CudnnError::InternalError, + CudnnError::InvalidValue, + CudnnError::ExecutionFailed, + CudnnError::NotSupported, + CudnnError::LicenseError, + CudnnError::RuntimeInProgress, + CudnnError::RuntimeFpOverflow, + ]; + for err in cases { + assert_eq!(err.into_raw().into_result().unwrap_err(), err); + } + } + + #[cfg(not(cudnn9))] + #[test] + fn into_raw_round_trips_for_cudnn8_only_errors() { + let cases = [ + CudnnError::AllocFailed, + CudnnError::ArchMismatch, + CudnnError::MappingError, + CudnnError::RuntimePrerequisiteMissing, + CudnnError::VersionMismatch, + ]; + for err in cases { + assert_eq!(err.into_raw().into_result().unwrap_err(), err); + } + } +}