@@ -19,13 +19,22 @@ package nvcdi
1919import (
2020 "fmt"
2121 "path/filepath"
22+ "slices"
2223
2324 "github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2425 "github.com/NVIDIA/nvidia-container-toolkit/internal/dxcore"
2526 "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2627 "github.com/NVIDIA/nvidia-container-toolkit/pkg/lookup"
2728)
2829
30+ const (
31+ libcudaSo = "libcuda.so.1.1"
32+ )
33+
34+ var dxcoreLibraries = []string {
35+ "libdxcore.so" ,
36+ }
37+
2938var requiredDriverStoreFiles = []string {
3039 "libcuda.so.1.1" , /* Core library for cuda support */
3140 "libcuda_loader.so" , /* Core library for cuda support on WSL */
@@ -56,40 +65,125 @@ func (l *wsllib) newWSLDriverDiscoverer() (discover.Discover, error) {
5665 if len (driverStorePaths ) > 1 {
5766 l .logger .Warningf ("Found multiple driver store paths: %v" , driverStorePaths )
5867 }
59- l .logger .Infof ("Using WSL driver store paths: %v" , driverStorePaths )
6068
61- driverStorePaths = append (driverStorePaths , "/usr/lib/wsl/lib" )
69+ nvDriverStorePath , err := l .getNVIDIADriverStorePath (driverStorePaths )
70+ if err != nil {
71+ return nil , fmt .Errorf ("failed to find NVIDIA driver store path: %w" , err )
72+ }
73+
74+ l .logger .Infof ("Using WSL driver store path: %v" , nvDriverStorePath )
6275
63- driverStoreMounts := discover .NewMounts (
76+ dxcoreMounts := discover .NewMounts (
6477 l .logger ,
6578 lookup .NewFileLocator (
6679 lookup .WithLogger (l .logger ),
6780 lookup .WithSearchPaths (
68- driverStorePaths ... ,
81+ "/usr/lib/wsl/lib" ,
82+ ),
83+ lookup .WithCount (1 ),
84+ ),
85+ l .driver .Root ,
86+ dxcoreLibraries ,
87+ )
88+
89+ requiredDriverStoreMounts := discover .NewMounts (
90+ l .logger ,
91+ lookup .NewFileLocator (
92+ lookup .WithLogger (l .logger ),
93+ lookup .WithSearchPaths (
94+ nvDriverStorePath ,
6995 ),
7096 lookup .WithCount (1 ),
7197 ),
7298 l .driver .Root ,
7399 requiredDriverStoreFiles ,
74100 )
75101
102+ additionalDriverStoreMounts , err := l .getAdditionalMountsFromDriverStore (nvDriverStorePath )
103+ if err != nil {
104+ return nil , fmt .Errorf ("failed to get additional mounts from driver store: %w" , err )
105+ }
106+
76107 symlinkHook := nvidiaSMISimlinkHook {
77108 logger : l .logger ,
78- mountsFrom : driverStoreMounts ,
109+ mountsFrom : requiredDriverStoreMounts ,
79110 hookCreator : l .hookCreator ,
80111 }
81112
82- ldcacheHook , _ := discover .NewLDCacheUpdateHook (l .logger , driverStoreMounts , l .hookCreator )
113+ ldcacheHook , _ := discover .NewLDCacheUpdateHook (l .logger , requiredDriverStoreMounts , l .hookCreator )
83114
84115 d := discover .Merge (
85- driverStoreMounts ,
116+ dxcoreMounts ,
117+ requiredDriverStoreMounts ,
118+ additionalDriverStoreMounts ,
86119 symlinkHook ,
87120 ldcacheHook ,
88121 )
89122
90123 return d , nil
91124}
92125
126+ // getNVIDIADriverStorePath returns the driver store path associated with NVIDIA GPUs
127+ func (l * wsllib ) getNVIDIADriverStorePath (driverStorePaths []string ) (string , error ) {
128+ fileLocator := lookup .NewFileLocator (
129+ lookup .WithLogger (l .logger ),
130+ lookup .WithSearchPaths (
131+ driverStorePaths ... ,
132+ ),
133+ lookup .WithCount (1 ),
134+ )
135+ matches , err := fileLocator .Locate (libcudaSo )
136+ if err != nil {
137+ return "" , fmt .Errorf ("failed to locate %s at WSL driver store paths: %w" , libcudaSo , err )
138+ }
139+ if len (matches ) == 0 {
140+ return "" , fmt .Errorf ("could not locate %s at WSL driver store paths" )
141+ }
142+
143+ return filepath .Dir (matches [0 ]), nil
144+ }
145+
146+ // getAdditionalMountsFromDriverStore discovers additional NVIDIA libraries (.so files) from the
147+ // driver store that are not in the required list of libraries.
148+ func (l * wsllib ) getAdditionalMountsFromDriverStore (driverStore string ) (discover.Discover , error ) {
149+ additionalLibs , err := l .getAdditionalLibsFromDriverStore (driverStore , requiredDriverStoreFiles )
150+ if err != nil {
151+ return nil , fmt .Errorf ("failed to lookup additional libraries in driver store: %w" , err )
152+ }
153+
154+ mounts := discover .NewMounts (
155+ l .logger ,
156+ lookup .NewFileLocator (
157+ lookup .WithLogger (l .logger ),
158+ lookup .WithRoot (l .driver .Root ),
159+ ),
160+ l .driver .Root ,
161+ additionalLibs ,
162+ )
163+
164+ return mounts , nil
165+ }
166+
167+ func (l * wsllib ) getAdditionalLibsFromDriverStore (driverStore string , excludeFiles []string ) ([]string , error ) {
168+ fileLocator := lookup .AsOptional (
169+ lookup .NewFileLocator (
170+ lookup .WithLogger (l .logger ),
171+ lookup .WithSearchPaths (driverStore ),
172+ lookup .WithFilter (func (s string ) error {
173+ if slices .Contains (excludeFiles , s ) {
174+ return fmt .Errorf ("file is excluded" )
175+ }
176+ return nil
177+ }),
178+ ))
179+
180+ found , err := fileLocator .Locate ("*.so*" )
181+ if err != nil {
182+ return nil , fmt .Errorf ("failed to find additional files at driver store: %w" , err )
183+ }
184+ return found , nil
185+ }
186+
93187type nvidiaSMISimlinkHook struct {
94188 discover.None
95189 logger logger.Interface
0 commit comments