From 26e0b975ed893796ed3654c271ab9c441f4492a2 Mon Sep 17 00:00:00 2001 From: Kayo Date: Mon, 12 May 2025 23:20:27 -0700 Subject: [PATCH 1/2] build(cudnn-sys): Add CUDNN_INCLUDE_DIR Enables users to specify a non-standard cuDNN install path. This seems to be needed for the newer editions of the CUDA toolkit, as cuDNN isn't included by default (at least in the Fedora repo's, you have to install from a tarball) --- crates/cudnn-sys/build/cudnn_sdk.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/crates/cudnn-sys/build/cudnn_sdk.rs b/crates/cudnn-sys/build/cudnn_sdk.rs index 2974aa74..59f5436c 100644 --- a/crates/cudnn-sys/build/cudnn_sdk.rs +++ b/crates/cudnn-sys/build/cudnn_sdk.rs @@ -1,3 +1,4 @@ +use std::env; use std::error; use std::fs; use std::path; @@ -64,7 +65,18 @@ impl CudnnSdk { "C:/Program Files/NVIDIA/CUDNN/v9.x/include", "C:/Program Files/NVIDIA/CUDNN/v8.x/include", ]; - CUDNN_DEFAULT_PATHS + + let mut cudnn_paths: Vec = + CUDNN_DEFAULT_PATHS.iter().map(|s| s.to_string()).collect(); + if let Some(override_path) = env::var_os("CUDNN_INCLUDE_DIR") { + cudnn_paths.push( + override_path + .into_string() + .expect("CUDNN_INCLUDE_DIR to be a Unicode string"), + ); + } + + cudnn_paths .iter() .find(|s| Self::is_cudnn_include_path(s)) .map(path::PathBuf::from) From 2cfca20a84f39f2bafe354b2c48016ee05698003 Mon Sep 17 00:00:00 2001 From: Kayo Date: Mon, 19 May 2025 23:00:32 -0700 Subject: [PATCH 2/2] more idiomatic --- crates/cudnn-sys/build/cudnn_sdk.rs | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/crates/cudnn-sys/build/cudnn_sdk.rs b/crates/cudnn-sys/build/cudnn_sdk.rs index 59f5436c..cfa9d995 100644 --- a/crates/cudnn-sys/build/cudnn_sdk.rs +++ b/crates/cudnn-sys/build/cudnn_sdk.rs @@ -2,6 +2,7 @@ use std::env; use std::error; use std::fs; use std::path; +use std::path::Path; /// Represents the cuDNN SDK installation. #[derive(Debug, Clone)] @@ -58,6 +59,8 @@ impl CudnnSdk { } fn find_cudnn_include_dir() -> Result> { + let cudnn_include_dir = env::var_os("CUDNN_INCLUDE_DIR"); + #[cfg(not(target_os = "windows"))] const CUDNN_DEFAULT_PATHS: &[&str] = &["/usr/include", "/usr/local/include"]; #[cfg(target_os = "windows")] @@ -66,20 +69,15 @@ impl CudnnSdk { "C:/Program Files/NVIDIA/CUDNN/v8.x/include", ]; - let mut cudnn_paths: Vec = - CUDNN_DEFAULT_PATHS.iter().map(|s| s.to_string()).collect(); - if let Some(override_path) = env::var_os("CUDNN_INCLUDE_DIR") { - cudnn_paths.push( - override_path - .into_string() - .expect("CUDNN_INCLUDE_DIR to be a Unicode string"), - ); + let mut cudnn_paths: Vec<&Path> = CUDNN_DEFAULT_PATHS.iter().map(Path::new).collect(); + if let Some(override_path) = &cudnn_include_dir { + cudnn_paths.push(Path::new(override_path)); } cudnn_paths .iter() - .find(|s| Self::is_cudnn_include_path(s)) - .map(path::PathBuf::from) + .find(|p| Self::is_cudnn_include_path(p)) + .map(|p| p.to_path_buf()) .ok_or("Cannot find cuDNN include directory.".into()) }