Skip to content

Commit 2cfca20

Browse files
committed
more idiomatic
1 parent 26e0b97 commit 2cfca20

1 file changed

Lines changed: 8 additions & 10 deletions

File tree

crates/cudnn-sys/build/cudnn_sdk.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::env;
22
use std::error;
33
use std::fs;
44
use std::path;
5+
use std::path::Path;
56

67
/// Represents the cuDNN SDK installation.
78
#[derive(Debug, Clone)]
@@ -58,6 +59,8 @@ impl CudnnSdk {
5859
}
5960

6061
fn find_cudnn_include_dir() -> Result<path::PathBuf, Box<dyn error::Error>> {
62+
let cudnn_include_dir = env::var_os("CUDNN_INCLUDE_DIR");
63+
6164
#[cfg(not(target_os = "windows"))]
6265
const CUDNN_DEFAULT_PATHS: &[&str] = &["/usr/include", "/usr/local/include"];
6366
#[cfg(target_os = "windows")]
@@ -66,20 +69,15 @@ impl CudnnSdk {
6669
"C:/Program Files/NVIDIA/CUDNN/v8.x/include",
6770
];
6871

69-
let mut cudnn_paths: Vec<String> =
70-
CUDNN_DEFAULT_PATHS.iter().map(|s| s.to_string()).collect();
71-
if let Some(override_path) = env::var_os("CUDNN_INCLUDE_DIR") {
72-
cudnn_paths.push(
73-
override_path
74-
.into_string()
75-
.expect("CUDNN_INCLUDE_DIR to be a Unicode string"),
76-
);
72+
let mut cudnn_paths: Vec<&Path> = CUDNN_DEFAULT_PATHS.iter().map(Path::new).collect();
73+
if let Some(override_path) = &cudnn_include_dir {
74+
cudnn_paths.push(Path::new(override_path));
7775
}
7876

7977
cudnn_paths
8078
.iter()
81-
.find(|s| Self::is_cudnn_include_path(s))
82-
.map(path::PathBuf::from)
79+
.find(|p| Self::is_cudnn_include_path(p))
80+
.map(|p| p.to_path_buf())
8381
.ok_or("Cannot find cuDNN include directory.".into())
8482
}
8583

0 commit comments

Comments
 (0)