@@ -71,22 +71,52 @@ impl CudnnSdk {
7171 "/usr/local/include/x86_64-linux-gnu" ,
7272 "/usr/local/include/aarch64-linux-gnu" ,
7373 ] ;
74+
75+ #[ cfg( not( target_os = "windows" ) ) ]
76+ let mut cudnn_paths: Vec < path:: PathBuf > =
77+ CUDNN_DEFAULT_PATHS . iter ( ) . map ( Path :: new) . map ( path:: PathBuf :: from) . collect ( ) ;
78+
7479 #[ cfg( target_os = "windows" ) ]
75- const CUDNN_DEFAULT_PATHS : & [ & str ] = & [
76- // Standalone cuDNN installs following NVIDIA's documentation.
77- "C:/Program Files/NVIDIA/CUDNN/v9.x/include" ,
78- "C:/Program Files/NVIDIA/CUDNN/v8.x/include" ,
79- // CUDA Toolkit installs that bundle cuDNN headers.
80- // These are the default Windows install locations for recent CUDA versions.
81- "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.2/include" ,
82- "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8/include" ,
83- "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v12.6/include" ,
84- "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v13.0/include" ,
85- ] ;
80+ let mut cudnn_paths: Vec < path:: PathBuf > = {
81+ // Legacy standalone cuDNN installs following NVIDIA's documentation.
82+ let mut paths = vec ! [
83+ path:: PathBuf :: from( "C:/Program Files/NVIDIA/CUDNN/v9.x/include" ) ,
84+ path:: PathBuf :: from( "C:/Program Files/NVIDIA/CUDNN/v8.x/include" ) ,
85+ ] ;
86+
87+ // Dynamically discover CUDA and cuDNN installs by matching vX.Y-style directories.
88+ let bases = [
89+ Path :: new ( "C:/Program Files/NVIDIA/CUDNN" ) ,
90+ Path :: new ( "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA" ) ,
91+ ] ;
92+
93+ for base in bases {
94+ if let Ok ( entries) = fs:: read_dir ( base) {
95+ for entry in entries. flatten ( ) {
96+ if let Ok ( file_type) = entry. file_type ( ) {
97+ if file_type. is_dir ( ) {
98+ let name = entry. file_name ( ) ;
99+ if let Some ( name_str) = name. to_str ( ) {
100+ // Match directories like v9.0, v10.2, v13.0, etc.
101+ if name_str. starts_with ( 'v' )
102+ && name_str[ 1 ..]
103+ . split ( '.' )
104+ . all ( |part| !part. is_empty ( ) && part. chars ( ) . all ( |c| c. is_ascii_digit ( ) ) )
105+ {
106+ paths. push ( base. join ( name_str) . join ( "include" ) ) ;
107+ }
108+ }
109+ }
110+ }
111+ }
112+ }
113+ }
114+
115+ paths
116+ } ;
86117
87- let mut cudnn_paths: Vec < & Path > = CUDNN_DEFAULT_PATHS . iter ( ) . map ( Path :: new) . collect ( ) ;
88118 if let Some ( override_path) = & cudnn_include_dir {
89- cudnn_paths. push ( Path :: new ( override_path) ) ;
119+ cudnn_paths. push ( Path :: new ( override_path) . to_path_buf ( ) ) ;
90120 }
91121
92122 cudnn_paths
0 commit comments