Skip to content

Commit 405317e

Browse files
committed
Add support for requirements checks to CDI
Signed-off-by: Arjun <agadiyar@nvidia.com>
1 parent 3cfea27 commit 405317e

3 files changed

Lines changed: 202 additions & 36 deletions

File tree

internal/modifier/cdi.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ func (f *Factory) newCDIModifier(isJitCDI bool) (oci.SpecModifier, error) {
5656
return nil, nil
5757
}
5858

59+
if err := checkRequirements(f.logger, f.image, f.driver); err != nil {
60+
return nil, fmt.Errorf("requirements not met: %w", err)
61+
}
62+
5963
automaticDevices := filterAutomaticDevices(devices)
6064
if len(automaticDevices) != len(devices) && len(automaticDevices) > 0 {
6165
return nil, fmt.Errorf("requesting a CDI device with vendor 'runtime.nvidia.com' is not supported when requesting other CDI devices")

internal/modifier/csv.go

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,7 @@ import (
2020
"fmt"
2121

2222
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
23-
"github.com/NVIDIA/nvidia-container-toolkit/internal/cuda"
24-
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2523
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
26-
"github.com/NVIDIA/nvidia-container-toolkit/internal/requirements"
2724
)
2825

2926
// newCSVModifier creates a modifier that applies modications to an OCI spec if required by the runtime wrapper.
@@ -36,45 +33,13 @@ func (f *Factory) newCSVModifier() (oci.SpecModifier, error) {
3633
}
3734
f.logger.Infof("Constructing modifier from config: %+v", *f.cfg)
3835

39-
if err := checkRequirements(f.logger, f.image); err != nil {
36+
if err := checkRequirements(f.logger, f.image, f.driver); err != nil {
4037
return nil, fmt.Errorf("requirements not met: %v", err)
4138
}
4239

4340
return f.newAutomaticCDISpecModifier(devices)
4441
}
4542

46-
func checkRequirements(logger logger.Interface, image *image.CUDA) error {
47-
if image == nil || image.HasDisableRequire() {
48-
// TODO: We could print the real value here instead
49-
logger.Debugf("NVIDIA_DISABLE_REQUIRE=%v; skipping requirement checks", true)
50-
return nil
51-
}
52-
53-
imageRequirements, err := image.GetRequirements()
54-
if err != nil {
55-
// TODO: Should we treat this as a failure, or just issue a warning?
56-
return fmt.Errorf("failed to get image requirements: %v", err)
57-
}
58-
59-
r := requirements.New(logger, imageRequirements)
60-
61-
cudaVersion, err := cuda.Version()
62-
if err != nil {
63-
logger.Warningf("Failed to get CUDA version: %v", err)
64-
} else {
65-
r.AddVersionProperty(requirements.CUDA, cudaVersion)
66-
}
67-
68-
compteCapability, err := cuda.ComputeCapability(0)
69-
if err != nil {
70-
logger.Warningf("Failed to get CUDA Compute Capability: %v", err)
71-
} else {
72-
r.AddVersionProperty(requirements.ARCH, compteCapability)
73-
}
74-
75-
return r.Assert()
76-
}
77-
7843
type csvDevices image.CUDA
7944

8045
func (d csvDevices) DeviceRequests() []string {
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
/**
2+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package modifier
18+
19+
import (
20+
"fmt"
21+
"strconv"
22+
"strings"
23+
24+
"github.com/NVIDIA/go-nvml/pkg/nvml"
25+
"golang.org/x/mod/semver"
26+
27+
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
28+
"github.com/NVIDIA/nvidia-container-toolkit/internal/cuda"
29+
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
30+
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
31+
"github.com/NVIDIA/nvidia-container-toolkit/internal/requirements"
32+
)
33+
34+
// checkRequirements evaluates NVIDIA_REQUIRE_* constraints using the host
35+
// CUDA driver API version from libcuda, the NVIDIA display driver version from
36+
// the driver root (libcuda / libnvidia-ml soname), the compute capability of
37+
// CUDA device 0, and (when requirements reference brand) the GPU product brand
38+
// from NVML. It is used for CSV and CDI / JIT-CDI modes.
39+
func checkRequirements(logger logger.Interface, image *image.CUDA, driver *root.Driver) error {
40+
if image == nil || image.HasDisableRequire() {
41+
logger.Debugf("NVIDIA_DISABLE_REQUIRE=%v; skipping requirement checks", true)
42+
return nil
43+
}
44+
45+
imageRequirements, err := image.GetRequirements()
46+
if err != nil {
47+
return fmt.Errorf("failed to get image requirements: %v", err)
48+
}
49+
50+
r := requirements.New(logger, imageRequirements)
51+
52+
cudaVersion, err := cuda.Version()
53+
if err != nil {
54+
logger.Warningf("Failed to get CUDA version: %v", err)
55+
} else {
56+
r.AddVersionProperty(requirements.CUDA, cudaVersion)
57+
}
58+
59+
compteCapability, err := cuda.ComputeCapability(0)
60+
if err != nil {
61+
logger.Warningf("Failed to get CUDA Compute Capability: %v", err)
62+
} else {
63+
r.AddVersionProperty(requirements.ARCH, compteCapability)
64+
}
65+
66+
driverVersion, err := driver.Version()
67+
if err != nil {
68+
logger.Warningf("Failed to get NVIDIA driver version: %v", err)
69+
} else {
70+
normalized, normErr := normalizeDriverVersionForSemver(driverVersion)
71+
if normErr != nil {
72+
logger.Warningf("NVIDIA driver version %q is not semver-normalizable: %v", driverVersion, normErr)
73+
} else {
74+
r.AddVersionProperty(requirements.DRIVER, normalized)
75+
}
76+
}
77+
78+
brand, err := getBrandFromNVML(driver)
79+
if err != nil {
80+
logger.Warningf("Failed to get GPU brand from NVML: %v", err)
81+
} else {
82+
r.AddStringProperty(requirements.BRAND, brand)
83+
}
84+
85+
return r.Assert()
86+
}
87+
88+
// normalizeDriverVersionForSemver converts a driver version taken from a
89+
// libcuda / libnvidia-ml soname suffix into a form accepted by
90+
// golang.org/x/mod/semver (no leading zeros in numeric segments)
91+
func normalizeDriverVersionForSemver(raw string) (string, error) {
92+
raw = strings.TrimSpace(raw)
93+
if raw == "" {
94+
return "", fmt.Errorf("empty driver version")
95+
}
96+
parts := strings.Split(raw, ".")
97+
out := make([]string, 0, len(parts))
98+
for _, p := range parts {
99+
if p == "" {
100+
return "", fmt.Errorf("empty version segment in %q", raw)
101+
}
102+
if strings.TrimLeft(p, "0123456789") != "" {
103+
return "", fmt.Errorf("non-numeric version segment %q in %q", p, raw)
104+
}
105+
n, err := strconv.ParseUint(p, 10, 64)
106+
if err != nil {
107+
return "", fmt.Errorf("invalid version segment %q in %q: %w", p, raw, err)
108+
}
109+
out = append(out, strconv.FormatUint(n, 10))
110+
}
111+
normalized := strings.Join(out, ".")
112+
if !semver.IsValid("v" + normalized) {
113+
return "", fmt.Errorf("normalized driver version %q is not valid semver", normalized)
114+
}
115+
return normalized, nil
116+
}
117+
118+
// getBrandFromNVML returns a lowercase brand token for the first visible GPU
119+
// (index 0), using NVML. When driver is non-nil, NVML is loaded from the
120+
// versioned libnvidia-ml under the driver root when possible.
121+
func getBrandFromNVML(driver *root.Driver) (string, error) {
122+
var lib nvml.Interface
123+
var opts []nvml.LibraryOption
124+
v, err := driver.Version()
125+
if err == nil && v != "" && v != "*.*" {
126+
paths, err := driver.Libraries().Locate("libnvidia-ml.so." + v)
127+
if err == nil && len(paths) > 0 {
128+
opts = append(opts, nvml.WithLibraryPath(paths[0]))
129+
}
130+
}
131+
132+
lib = nvml.New(opts...)
133+
if ret := lib.Init(); ret != nvml.SUCCESS {
134+
return "", fmt.Errorf("nvml.Init: %s", lib.ErrorString(ret))
135+
}
136+
defer func() {
137+
_ = lib.Shutdown()
138+
}()
139+
140+
device, ret := lib.DeviceGetHandleByIndex(0)
141+
if ret != nvml.SUCCESS {
142+
return "", fmt.Errorf("nvml.DeviceGetHandleByIndex(0): %s", lib.ErrorString(ret))
143+
}
144+
145+
brandType, ret := lib.DeviceGetBrand(device)
146+
if ret != nvml.SUCCESS {
147+
return "", fmt.Errorf("nvml.DeviceGetBrand: %s", lib.ErrorString(ret))
148+
}
149+
brand, ok := brandTypeToRequirementString(brandType)
150+
if !ok {
151+
return "", fmt.Errorf("unknown NVML brand type %v", brandType)
152+
}
153+
return brand, nil
154+
}
155+
156+
// brandTypeToRequirementString maps NVML brand enums to lowercase tokens
157+
// consistent with typical NVIDIA_REQUIRE_* image constraints.
158+
func brandTypeToRequirementString(b nvml.BrandType) (string, bool) {
159+
switch b {
160+
case nvml.BRAND_UNKNOWN:
161+
return "", false
162+
case nvml.BRAND_QUADRO:
163+
return "quadro", true
164+
case nvml.BRAND_TESLA:
165+
return "tesla", true
166+
case nvml.BRAND_NVS:
167+
return "nvs", true
168+
case nvml.BRAND_GRID:
169+
return "grid", true
170+
case nvml.BRAND_GEFORCE:
171+
return "geforce", true
172+
case nvml.BRAND_TITAN:
173+
return "titan", true
174+
case nvml.BRAND_NVIDIA_VAPPS:
175+
return "nvidiavapps", true
176+
case nvml.BRAND_NVIDIA_VPC:
177+
return "nvidiavpc", true
178+
case nvml.BRAND_NVIDIA_VCS:
179+
return "nvidiavcs", true
180+
case nvml.BRAND_NVIDIA_VWS:
181+
return "nvidiavws", true
182+
case nvml.BRAND_NVIDIA_CLOUD_GAMING:
183+
return "nvidiacloudgaming", true
184+
case nvml.BRAND_QUADRO_RTX:
185+
return "quadrortx", true
186+
case nvml.BRAND_NVIDIA_RTX:
187+
return "nvidiartx", true
188+
case nvml.BRAND_NVIDIA:
189+
return "nvidia", true
190+
case nvml.BRAND_GEFORCE_RTX:
191+
return "geforcertx", true
192+
case nvml.BRAND_TITAN_RTX:
193+
return "titanrtx", true
194+
default:
195+
return "", false
196+
}
197+
}

0 commit comments

Comments
 (0)