@@ -2,6 +2,7 @@ use std::env;
22use std:: error;
33use std:: fs;
44use std:: path;
5+ use std:: path:: Path ;
56
67/// Represents the cuDNN SDK installation.
78#[ derive( Debug , Clone ) ]
@@ -58,6 +59,8 @@ impl CudnnSdk {
5859 }
5960
6061 fn find_cudnn_include_dir ( ) -> Result < path:: PathBuf , Box < dyn error:: Error > > {
62+ let cudnn_include_dir = env:: var_os ( "CUDNN_INCLUDE_DIR" ) ;
63+
6164 #[ cfg( not( target_os = "windows" ) ) ]
6265 const CUDNN_DEFAULT_PATHS : & [ & str ] = & [ "/usr/include" , "/usr/local/include" ] ;
6366 #[ cfg( target_os = "windows" ) ]
@@ -66,20 +69,15 @@ impl CudnnSdk {
6669 "C:/Program Files/NVIDIA/CUDNN/v8.x/include" ,
6770 ] ;
6871
69- let mut cudnn_paths: Vec < String > =
70- CUDNN_DEFAULT_PATHS . iter ( ) . map ( |s| s. to_string ( ) ) . collect ( ) ;
71- if let Some ( override_path) = env:: var_os ( "CUDNN_INCLUDE_DIR" ) {
72- cudnn_paths. push (
73- override_path
74- . into_string ( )
75- . expect ( "CUDNN_INCLUDE_DIR to be a Unicode string" ) ,
76- ) ;
72+ let mut cudnn_paths: Vec < & Path > = CUDNN_DEFAULT_PATHS . iter ( ) . map ( Path :: new) . collect ( ) ;
73+ if let Some ( override_path) = & cudnn_include_dir {
74+ cudnn_paths. push ( Path :: new ( override_path) ) ;
7775 }
7876
7977 cudnn_paths
8078 . iter ( )
81- . find ( |s | Self :: is_cudnn_include_path ( s ) )
82- . map ( path :: PathBuf :: from )
79+ . find ( |p | Self :: is_cudnn_include_path ( p ) )
80+ . map ( |p| p . to_path_buf ( ) )
8381 . ok_or ( "Cannot find cuDNN include directory." . into ( ) )
8482 }
8583
0 commit comments