Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 109 additions & 8 deletions pkg/nvcdi/driver-wsl.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,28 @@ package nvcdi
import (
"fmt"
"path/filepath"
"slices"

"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/dxcore"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/lookup"
)

// libcudaSo is the file name of the CUDA driver library. It is the file that
// getNVIDIADriverStorePath looks up in order to pick the driver store path.
const libcudaSo = "libcuda.so.1.1"

// dxcoreLibraries lists the dxcore library file names that are passed to
// their own mounts discoverer.
var dxcoreLibraries = []string{
	// Core library for dxcore support.
	"libdxcore.so",
}

var requiredDriverStoreFiles = []string{
"libcuda.so.1.1", /* Core library for cuda support */
"libcuda_loader.so", /* Core library for cuda support on WSL */
"libnvidia-ptxjitcompiler.so.1", /* Core library for PTX Jit support */
"libnvidia-ml.so.1", /* Core library for nvml */
"libnvidia-ml_loader.so", /* Core library for nvml on WSL */
"libdxcore.so", /* Core library for dxcore support */
"libnvdxgdmal.so.1", /* dxgdmal library for cuda */
"nvcubins.bin", /* Binary containing GPU code for cuda */
"nvidia-smi", /* nvidia-smi binary*/
Expand All @@ -56,40 +64,133 @@ func (l *wsllib) newWSLDriverDiscoverer() (discover.Discover, error) {
if len(driverStorePaths) > 1 {
l.logger.Warningf("Found multiple driver store paths: %v", driverStorePaths)
}
l.logger.Infof("Using WSL driver store paths: %v", driverStorePaths)

driverStorePaths = append(driverStorePaths, "/usr/lib/wsl/lib")
nvDriverStorePath, err := l.getNVIDIADriverStorePath(driverStorePaths)
if err != nil {
return nil, fmt.Errorf("failed to find NVIDIA driver store path: %w", err)
}

l.logger.Infof("Using WSL driver store path: %v", nvDriverStorePath)

driverStoreMounts := discover.NewMounts(
dxcoreMounts := discover.NewMounts(
l.logger,
lookup.NewFileLocator(
lookup.WithLogger(l.logger),
lookup.WithSearchPaths(
driverStorePaths...,
"/usr/lib/wsl/lib",
),
lookup.WithCount(1),
),
l.driver.Root,
dxcoreLibraries,
)

requiredDriverStoreMounts := discover.NewMounts(
l.logger,
lookup.NewFileLocator(
lookup.WithLogger(l.logger),
lookup.WithSearchPaths(
nvDriverStorePath,
),
lookup.WithCount(1),
),
l.driver.Root,
requiredDriverStoreFiles,
)

additionalDriverStoreMounts, err := l.getAdditionalMountsFromDriverStore(nvDriverStorePath)
if err != nil {
return nil, fmt.Errorf("failed to get additional mounts from driver store: %w", err)
}

symlinkHook := nvidiaSMISimlinkHook{
logger: l.logger,
mountsFrom: driverStoreMounts,
mountsFrom: requiredDriverStoreMounts,
hookCreator: l.hookCreator,
}

ldcacheHook, _ := discover.NewLDCacheUpdateHook(l.logger, driverStoreMounts, l.hookCreator)
ldcacheHook, _ := discover.NewLDCacheUpdateHook(l.logger, discover.Merge(requiredDriverStoreMounts, dxcoreMounts), l.hookCreator)

d := discover.Merge(
driverStoreMounts,
dxcoreMounts,
requiredDriverStoreMounts,
additionalDriverStoreMounts,
symlinkHook,
ldcacheHook,
)

return d, nil
}

// getNVIDIADriverStorePath looks for libcudaSo in the supplied driver store
// paths and returns the directory that contains the first match.
//
// An error is returned if the lookup itself fails or if it yields no matches.
func (l *wsllib) getNVIDIADriverStorePath(driverStorePaths []string) (string, error) {
	// Only the first match is consulted below, hence WithCount(1).
	locator := lookup.NewFileLocator(
		lookup.WithLogger(l.logger),
		lookup.WithSearchPaths(driverStorePaths...),
		lookup.WithCount(1),
	)

	candidates, err := locator.Locate(libcudaSo)
	switch {
	case err != nil:
		return "", fmt.Errorf("failed to locate %s at WSL driver store paths: %w", libcudaSo, err)
	case len(candidates) == 0:
		return "", fmt.Errorf("could not locate %s at WSL driver store paths", libcudaSo)
	}

	// The driver store is the directory holding the located library.
	return filepath.Dir(candidates[0]), nil
}

// getAdditionalMountsFromDriverStore builds a mounts discoverer for the files
// in the given driver store that getAdditionalFilesFromDriverStore reports
// (files matching *.so*, *.bin and *.dll) once the entries of
// requiredDriverStoreFiles have been excluded.
func (l *wsllib) getAdditionalMountsFromDriverStore(driverStore string) (discover.Discover, error) {
	extraFiles, err := l.getAdditionalFilesFromDriverStore(driverStore, requiredDriverStoreFiles)
	if err != nil {
		return nil, fmt.Errorf("failed to lookup additional files in driver store: %w", err)
	}

	// The locator for the mounts is rooted at the driver root.
	locator := lookup.NewFileLocator(
		lookup.WithLogger(l.logger),
		lookup.WithRoot(l.driver.Root),
	)

	return discover.NewMounts(l.logger, locator, l.driver.Root, extraFiles), nil
}

// getAdditionalFilesFromDriverStore returns the files in driverStore that match
// "*.so*", "*.bin" or "*.dll" (in that order) and whose base name is not listed
// in excludeFiles.
//
// The lookup stops at the first pattern whose Locate call returns an error,
// and that error is returned wrapped.
func (l *wsllib) getAdditionalFilesFromDriverStore(driverStore string, excludeFiles []string) ([]string, error) {
	// rejectExcluded filters out any candidate whose base name is in excludeFiles.
	rejectExcluded := func(candidate string) error {
		if !slices.Contains(excludeFiles, filepath.Base(candidate)) {
			return nil
		}
		return fmt.Errorf("file is excluded")
	}

	locator := lookup.AsOptional(
		lookup.NewFileLocator(
			lookup.WithLogger(l.logger),
			lookup.WithSearchPaths(driverStore),
			lookup.WithFilter(rejectExcluded),
		),
	)

	// Each search pairs a glob pattern with the message reported on failure.
	searches := []struct {
		pattern string
		failure string
	}{
		{"*.so*", "failed to find additional '.so' files in driver store"},
		{"*.bin", "failed to find additional '.bin' files in driver store"},
		{"*.dll", "failed to find additional '.dll' files in driver store"},
	}

	var found []string
	for _, search := range searches {
		matches, err := locator.Locate(search.pattern)
		if err != nil {
			return nil, fmt.Errorf("%s: %w", search.failure, err)
		}
		found = append(found, matches...)
	}
	return found, nil
}

type nvidiaSMISimlinkHook struct {
discover.None
logger logger.Interface
Expand Down
Loading