Skip to content

Commit bb00cd5

Browse files
CharryWuLegNeato
authored andcommitted
test(cudnn-sys): add windows vX.Y path discovery tests
Made-with: Cursor
1 parent 366bad2 commit bb00cd5

1 file changed

Lines changed: 96 additions & 21 deletions

File tree

crates/cudnn-sys/build/cudnn_sdk.rs

Lines changed: 96 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,38 @@ impl CudnnSdk {
5858
p.join("cudnn.h").is_file() && p.join("cudnn_version.h").is_file()
5959
}
6060

61+
#[cfg(target_os = "windows")]
62+
fn is_vxy_dir_name(name: &str) -> bool {
63+
name.starts_with('v')
64+
&& name[1..]
65+
.split('.')
66+
.all(|part| !part.is_empty() && part.chars().all(|c| c.is_ascii_digit()))
67+
}
68+
69+
#[cfg(target_os = "windows")]
70+
fn collect_windows_cudnn_include_paths(bases: &[&Path]) -> Vec<path::PathBuf> {
71+
let mut paths = Vec::new();
72+
73+
for base in bases {
74+
if let Ok(entries) = fs::read_dir(base) {
75+
for entry in entries.flatten() {
76+
if let Ok(file_type) = entry.file_type() {
77+
if file_type.is_dir() {
78+
let name = entry.file_name();
79+
if let Some(name_str) = name.to_str() {
80+
if Self::is_vxy_dir_name(name_str) {
81+
paths.push(base.join(name_str).join("include"));
82+
}
83+
}
84+
}
85+
}
86+
}
87+
}
88+
}
89+
90+
paths
91+
}
92+
6193
fn find_cudnn_include_dir() -> Result<path::PathBuf, Box<dyn error::Error>> {
6294
let cudnn_include_dir = env::var_os("CUDNN_INCLUDE_DIR");
6395

@@ -90,27 +122,7 @@ impl CudnnSdk {
90122
Path::new("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA"),
91123
];
92124

93-
for base in bases {
94-
if let Ok(entries) = fs::read_dir(base) {
95-
for entry in entries.flatten() {
96-
if let Ok(file_type) = entry.file_type() {
97-
if file_type.is_dir() {
98-
let name = entry.file_name();
99-
if let Some(name_str) = name.to_str() {
100-
// Match directories like v9.0, v10.2, v13.0, etc.
101-
if name_str.starts_with('v')
102-
&& name_str[1..]
103-
.split('.')
104-
.all(|part| !part.is_empty() && part.chars().all(|c| c.is_ascii_digit()))
105-
{
106-
paths.push(base.join(name_str).join("include"));
107-
}
108-
}
109-
}
110-
}
111-
}
112-
}
113-
}
125+
paths.extend(Self::collect_windows_cudnn_include_paths(&bases));
114126

115127
paths
116128
};
@@ -145,3 +157,66 @@ impl CudnnSdk {
145157
Ok([major, minor, patch])
146158
}
147159
}
160+
161+
#[cfg(test)]
162+
mod tests {
163+
use super::*;
164+
165+
#[cfg(target_os = "windows")]
166+
#[test]
167+
fn is_vxy_dir_name_accepts_valid_versions() {
168+
assert!(CudnnSdk::is_vxy_dir_name("v9.0"));
169+
assert!(CudnnSdk::is_vxy_dir_name("v10.6"));
170+
assert!(CudnnSdk::is_vxy_dir_name("v12.1"));
171+
assert!(CudnnSdk::is_vxy_dir_name("v13.0"));
172+
assert!(CudnnSdk::is_vxy_dir_name("v9.10.3"));
173+
}
174+
175+
#[cfg(target_os = "windows")]
176+
#[test]
177+
fn is_vxy_dir_name_rejects_invalid_versions() {
178+
assert!(!CudnnSdk::is_vxy_dir_name("v"));
179+
assert!(!CudnnSdk::is_vxy_dir_name("v9."));
180+
assert!(!CudnnSdk::is_vxy_dir_name("v.9"));
181+
assert!(!CudnnSdk::is_vxy_dir_name("v9.a"));
182+
assert!(!CudnnSdk::is_vxy_dir_name("9.0"));
183+
assert!(!CudnnSdk::is_vxy_dir_name("vx.y"));
184+
assert!(!CudnnSdk::is_vxy_dir_name("random"));
185+
}
186+
187+
#[cfg(target_os = "windows")]
188+
#[test]
189+
fn collect_windows_cudnn_include_paths_discovers_multiple_versions() {
190+
use std::fs::create_dir_all;
191+
192+
let tmp_dir = env::temp_dir().join("cudnn_sdk_tests");
193+
if tmp_dir.exists() {
194+
fs::remove_dir_all(&tmp_dir).ok();
195+
}
196+
197+
let cuda_base = tmp_dir.join("CUDA");
198+
let v10_6 = cuda_base.join("v10.6");
199+
let v12_1 = cuda_base.join("v12.1");
200+
let v13_0 = cuda_base.join("v13.0");
201+
202+
for ver in [&v10_6, &v12_1, &v13_0] {
203+
let include_dir = ver.join("include");
204+
create_dir_all(&include_dir).unwrap();
205+
fs::write(include_dir.join("cudnn.h"), "// stub").unwrap();
206+
fs::write(include_dir.join("cudnn_version.h"), "// stub").unwrap();
207+
}
208+
209+
// Also add some junk directories that should be ignored.
210+
create_dir_all(cuda_base.join("vbad")).unwrap();
211+
create_dir_all(cuda_base.join("not-a-version")).unwrap();
212+
213+
let bases: [&Path; 1] = [&cuda_base];
214+
let mut paths = CudnnSdk::collect_windows_cudnn_include_paths(&bases);
215+
paths.sort();
216+
217+
assert!(paths.contains(&v10_6.join("include")));
218+
assert!(paths.contains(&v12_1.join("include")));
219+
assert!(paths.contains(&v13_0.join("include")));
220+
assert_eq!(paths.len(), 3);
221+
}
222+
}

0 commit comments

Comments
 (0)