@@ -58,6 +58,38 @@ impl CudnnSdk {
5858 p. join ( "cudnn.h" ) . is_file ( ) && p. join ( "cudnn_version.h" ) . is_file ( )
5959 }
6060
61+ #[ cfg( target_os = "windows" ) ]
62+ fn is_vxy_dir_name ( name : & str ) -> bool {
63+ name. starts_with ( 'v' )
64+ && name[ 1 ..]
65+ . split ( '.' )
66+ . all ( |part| !part. is_empty ( ) && part. chars ( ) . all ( |c| c. is_ascii_digit ( ) ) )
67+ }
68+
69+ #[ cfg( target_os = "windows" ) ]
70+ fn collect_windows_cudnn_include_paths ( bases : & [ & Path ] ) -> Vec < path:: PathBuf > {
71+ let mut paths = Vec :: new ( ) ;
72+
73+ for base in bases {
74+ if let Ok ( entries) = fs:: read_dir ( base) {
75+ for entry in entries. flatten ( ) {
76+ if let Ok ( file_type) = entry. file_type ( ) {
77+ if file_type. is_dir ( ) {
78+ let name = entry. file_name ( ) ;
79+ if let Some ( name_str) = name. to_str ( ) {
80+ if Self :: is_vxy_dir_name ( name_str) {
81+ paths. push ( base. join ( name_str) . join ( "include" ) ) ;
82+ }
83+ }
84+ }
85+ }
86+ }
87+ }
88+ }
89+
90+ paths
91+ }
92+
6193 fn find_cudnn_include_dir ( ) -> Result < path:: PathBuf , Box < dyn error:: Error > > {
6294 let cudnn_include_dir = env:: var_os ( "CUDNN_INCLUDE_DIR" ) ;
6395
@@ -90,27 +122,7 @@ impl CudnnSdk {
90122 Path :: new ( "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA" ) ,
91123 ] ;
92124
93- for base in bases {
94- if let Ok ( entries) = fs:: read_dir ( base) {
95- for entry in entries. flatten ( ) {
96- if let Ok ( file_type) = entry. file_type ( ) {
97- if file_type. is_dir ( ) {
98- let name = entry. file_name ( ) ;
99- if let Some ( name_str) = name. to_str ( ) {
100- // Match directories like v9.0, v10.2, v13.0, etc.
101- if name_str. starts_with ( 'v' )
102- && name_str[ 1 ..]
103- . split ( '.' )
104- . all ( |part| !part. is_empty ( ) && part. chars ( ) . all ( |c| c. is_ascii_digit ( ) ) )
105- {
106- paths. push ( base. join ( name_str) . join ( "include" ) ) ;
107- }
108- }
109- }
110- }
111- }
112- }
113- }
125+ paths. extend ( Self :: collect_windows_cudnn_include_paths ( & bases) ) ;
114126
115127 paths
116128 } ;
@@ -145,3 +157,66 @@ impl CudnnSdk {
145157 Ok ( [ major, minor, patch] )
146158 }
147159}
160+
161+ #[ cfg( test) ]
162+ mod tests {
163+ use super :: * ;
164+
165+ #[ cfg( target_os = "windows" ) ]
166+ #[ test]
167+ fn is_vxy_dir_name_accepts_valid_versions ( ) {
168+ assert ! ( CudnnSdk :: is_vxy_dir_name( "v9.0" ) ) ;
169+ assert ! ( CudnnSdk :: is_vxy_dir_name( "v10.6" ) ) ;
170+ assert ! ( CudnnSdk :: is_vxy_dir_name( "v12.1" ) ) ;
171+ assert ! ( CudnnSdk :: is_vxy_dir_name( "v13.0" ) ) ;
172+ assert ! ( CudnnSdk :: is_vxy_dir_name( "v9.10.3" ) ) ;
173+ }
174+
175+ #[ cfg( target_os = "windows" ) ]
176+ #[ test]
177+ fn is_vxy_dir_name_rejects_invalid_versions ( ) {
178+ assert ! ( !CudnnSdk :: is_vxy_dir_name( "v" ) ) ;
179+ assert ! ( !CudnnSdk :: is_vxy_dir_name( "v9." ) ) ;
180+ assert ! ( !CudnnSdk :: is_vxy_dir_name( "v.9" ) ) ;
181+ assert ! ( !CudnnSdk :: is_vxy_dir_name( "v9.a" ) ) ;
182+ assert ! ( !CudnnSdk :: is_vxy_dir_name( "9.0" ) ) ;
183+ assert ! ( !CudnnSdk :: is_vxy_dir_name( "vx.y" ) ) ;
184+ assert ! ( !CudnnSdk :: is_vxy_dir_name( "random" ) ) ;
185+ }
186+
187+ #[ cfg( target_os = "windows" ) ]
188+ #[ test]
189+ fn collect_windows_cudnn_include_paths_discovers_multiple_versions ( ) {
190+ use std:: fs:: create_dir_all;
191+
192+ let tmp_dir = env:: temp_dir ( ) . join ( "cudnn_sdk_tests" ) ;
193+ if tmp_dir. exists ( ) {
194+ fs:: remove_dir_all ( & tmp_dir) . ok ( ) ;
195+ }
196+
197+ let cuda_base = tmp_dir. join ( "CUDA" ) ;
198+ let v10_6 = cuda_base. join ( "v10.6" ) ;
199+ let v12_1 = cuda_base. join ( "v12.1" ) ;
200+ let v13_0 = cuda_base. join ( "v13.0" ) ;
201+
202+ for ver in [ & v10_6, & v12_1, & v13_0] {
203+ let include_dir = ver. join ( "include" ) ;
204+ create_dir_all ( & include_dir) . unwrap ( ) ;
205+ fs:: write ( include_dir. join ( "cudnn.h" ) , "// stub" ) . unwrap ( ) ;
206+ fs:: write ( include_dir. join ( "cudnn_version.h" ) , "// stub" ) . unwrap ( ) ;
207+ }
208+
209+ // Also add some junk directories that should be ignored.
210+ create_dir_all ( cuda_base. join ( "vbad" ) ) . unwrap ( ) ;
211+ create_dir_all ( cuda_base. join ( "not-a-version" ) ) . unwrap ( ) ;
212+
213+ let bases: [ & Path ; 1 ] = [ & cuda_base] ;
214+ let mut paths = CudnnSdk :: collect_windows_cudnn_include_paths ( & bases) ;
215+ paths. sort ( ) ;
216+
217+ assert ! ( paths. contains( & v10_6. join( "include" ) ) ) ;
218+ assert ! ( paths. contains( & v12_1. join( "include" ) ) ) ;
219+ assert ! ( paths. contains( & v13_0. join( "include" ) ) ) ;
220+ assert_eq ! ( paths. len( ) , 3 ) ;
221+ }
222+ }
0 commit comments