Skip to content

Commit 366bad2

Browse files
CharryWuLegNeato
authored andcommitted
feat(cudnn-sys): discover Windows cuDNN paths via vX.Y directories
Made-with: Cursor
1 parent 029c225 commit 366bad2

1 file changed

Lines changed: 43 additions & 13 deletions

File tree

crates/cudnn-sys/build/cudnn_sdk.rs

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,22 +71,52 @@ impl CudnnSdk {
7171
"/usr/local/include/x86_64-linux-gnu",
7272
"/usr/local/include/aarch64-linux-gnu",
7373
];
74+
75+
#[cfg(not(target_os = "windows"))]
76+
let mut cudnn_paths: Vec<path::PathBuf> =
77+
CUDNN_DEFAULT_PATHS.iter().map(Path::new).map(path::PathBuf::from).collect();
78+
7479
#[cfg(target_os = "windows")]
75-
const CUDNN_DEFAULT_PATHS: &[&str] = &[
76-
// Standalone cuDNN installs following NVIDIA's documentation.
77-
"C:/Program Files/NVIDIA/CUDNN/v9.x/include",
78-
"C:/Program Files/NVIDIA/CUDNN/v8.x/include",
79-
// CUDA Toolkit installs that bundle cuDNN headers.
80-
// These are the default Windows install locations for recent CUDA versions.
81-
"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.2/include",
82-
"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8/include",
83-
"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v12.6/include",
84-
"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v13.0/include",
85-
];
80+
let mut cudnn_paths: Vec<path::PathBuf> = {
81+
// Legacy standalone cuDNN installs following NVIDIA's documentation.
82+
let mut paths = vec![
83+
path::PathBuf::from("C:/Program Files/NVIDIA/CUDNN/v9.x/include"),
84+
path::PathBuf::from("C:/Program Files/NVIDIA/CUDNN/v8.x/include"),
85+
];
86+
87+
// Dynamically discover CUDA and cuDNN installs by matching vX.Y-style directories.
88+
let bases = [
89+
Path::new("C:/Program Files/NVIDIA/CUDNN"),
90+
Path::new("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA"),
91+
];
92+
93+
for base in bases {
94+
if let Ok(entries) = fs::read_dir(base) {
95+
for entry in entries.flatten() {
96+
if let Ok(file_type) = entry.file_type() {
97+
if file_type.is_dir() {
98+
let name = entry.file_name();
99+
if let Some(name_str) = name.to_str() {
100+
// Match directories like v9.0, v10.2, v13.0, etc.
101+
if name_str.starts_with('v')
102+
&& name_str[1..]
103+
.split('.')
104+
.all(|part| !part.is_empty() && part.chars().all(|c| c.is_ascii_digit()))
105+
{
106+
paths.push(base.join(name_str).join("include"));
107+
}
108+
}
109+
}
110+
}
111+
}
112+
}
113+
}
114+
115+
paths
116+
};
86117

87-
let mut cudnn_paths: Vec<&Path> = CUDNN_DEFAULT_PATHS.iter().map(Path::new).collect();
88118
if let Some(override_path) = &cudnn_include_dir {
89-
cudnn_paths.push(Path::new(override_path));
119+
cudnn_paths.push(Path::new(override_path).to_path_buf());
90120
}
91121

92122
cudnn_paths

0 commit comments

Comments
 (0)