diff --git a/crates/cudnn-sys/build/cudnn_sdk.rs b/crates/cudnn-sys/build/cudnn_sdk.rs index 2974aa74..cfa9d995 100644 --- a/crates/cudnn-sys/build/cudnn_sdk.rs +++ b/crates/cudnn-sys/build/cudnn_sdk.rs @@ -1,6 +1,8 @@ +use std::env; use std::error; use std::fs; use std::path; +use std::path::Path; /// Represents the cuDNN SDK installation. #[derive(Debug, Clone)] @@ -57,6 +59,8 @@ impl CudnnSdk { } fn find_cudnn_include_dir() -> Result> { + let cudnn_include_dir = env::var_os("CUDNN_INCLUDE_DIR"); + #[cfg(not(target_os = "windows"))] const CUDNN_DEFAULT_PATHS: &[&str] = &["/usr/include", "/usr/local/include"]; #[cfg(target_os = "windows")] @@ -64,10 +68,16 @@ impl CudnnSdk { "C:/Program Files/NVIDIA/CUDNN/v9.x/include", "C:/Program Files/NVIDIA/CUDNN/v8.x/include", ]; - CUDNN_DEFAULT_PATHS + + let mut cudnn_paths: Vec<&Path> = CUDNN_DEFAULT_PATHS.iter().map(Path::new).collect(); + if let Some(override_path) = &cudnn_include_dir { + cudnn_paths.push(Path::new(override_path)); + } + + cudnn_paths .iter() - .find(|s| Self::is_cudnn_include_path(s)) - .map(path::PathBuf::from) + .find(|p| Self::is_cudnn_include_path(p)) + .map(|p| p.to_path_buf()) .ok_or("Cannot find cuDNN include directory.".into()) }