Skip to content

Commit 29b3596

Browse files
authored
Merge pull request #1666 from elezar/ensure-that-edit-factory-is-propagated
fix: Reuse instantiated editsFactory in CDI
2 parents 08bf583 + 79477c9 commit 29b3596

6 files changed

Lines changed: 41 additions & 24 deletions

File tree

internal/modifier/cdi.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ package modifier
1818

1919
import (
2020
"fmt"
21-
"slices"
2221
"strings"
2322

2423
"tags.cncf.io/container-device-interface/pkg/parser"
@@ -176,11 +175,6 @@ func filterAutomaticDevices(devices []string) []string {
176175
func (f *Factory) newAutomaticCDISpecModifier(devices []string) (oci.SpecModifier, error) {
177176
f.logger.Debugf("Generating in-memory CDI specs for devices %v", devices)
178177

179-
nvcdiFeatureFlags := slices.Clone(f.cfg.NVIDIAContainerRuntimeConfig.Modes.JitCDI.NVCDIFeatureFlags)
180-
if f.cfg.Features.NoAdditionalGIDsForDeviceNodes.IsEnabled() {
181-
nvcdiFeatureFlags = append(nvcdiFeatureFlags, nvcdi.FeatureNoAdditionalGIDsForDeviceNodes)
182-
}
183-
184178
csvFiles, err := csv.GetFileList(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath)
185179
if err != nil {
186180
f.logger.Warningf("Failed to get the list of CSV files: %v", err)
@@ -198,10 +192,11 @@ func (f *Factory) newAutomaticCDISpecModifier(devices []string) (oci.SpecModifie
198192
nvcdi.WithNVIDIACDIHookPath(f.cfg.NVIDIACTKConfig.Path),
199193
nvcdi.WithDriverRoot(f.driver.Root),
200194
nvcdi.WithDevRoot(f.driver.DevRoot),
195+
nvcdi.WithEditsFactory(f.editsFactory),
201196
nvcdi.WithVendor(automaticDeviceVendor),
202197
nvcdi.WithClass(cdiModeIdentifiers.deviceClassByMode[mode]),
203198
nvcdi.WithMode(mode),
204-
nvcdi.WithFeatureFlags(nvcdiFeatureFlags...),
199+
nvcdi.WithFeatureFlags(f.cfg.NVIDIAContainerRuntimeConfig.Modes.JitCDI.NVCDIFeatureFlags...),
205200
nvcdi.WithCSVCompatContainerRoot(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.CompatContainerRoot),
206201
nvcdi.WithCSVFiles(csvFiles),
207202
)

internal/modifier/csv_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ func TestNewCSVModifier(t *testing.T) {
124124
f := createFactory(
125125
WithLogger(logger),
126126
WithDriver(driver),
127-
WithLogger(logger),
128127
WithConfig(&tc.cfg),
129128
WithImage(&image),
130129
)

internal/modifier/discover_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
testlog "github.com/sirupsen/logrus/hooks/test"
2525
"github.com/stretchr/testify/require"
2626

27+
"github.com/NVIDIA/nvidia-container-toolkit/api/config/v1"
2728
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2829
)
2930

@@ -132,6 +133,7 @@ func TestDiscoverModifier(t *testing.T) {
132133

133134
factory := createFactory(
134135
WithLogger(logger),
136+
WithConfig(&config.Config{}),
135137
)
136138
for _, tc := range testCases {
137139
t.Run(tc.description, func(t *testing.T) {

internal/modifier/factory.go

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,20 @@ import (
3232
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
3333
)
3434

35-
type Factory struct {
35+
// factoryOptions define the set of options that must be set when constructing
36+
// a modifier factory.
37+
type factoryOptions struct {
3638
logger logger.Interface
3739
cfg *config.Config
3840
driver *root.Driver
3941
hookCreator discover.HookCreator
4042
image *image.CUDA
4143
runtimeMode info.RuntimeMode
44+
}
4245

46+
type Factory struct {
47+
factoryOptions
48+
// An editsFactory is created at construction.
4349
editsFactory edits.Factory
4450
}
4551

@@ -60,12 +66,14 @@ func New(opts ...Option) (oci.SpecModifier, error) {
6066
func createFactory(opts ...Option) *Factory {
6167
f := &Factory{}
6268
for _, opt := range opts {
63-
opt(f)
64-
}
65-
if f.editsFactory == nil {
66-
f.editsFactory = edits.NewFactory(edits.WithLogger(f.logger))
69+
opt(&f.factoryOptions)
6770
}
6871

72+
f.editsFactory = edits.NewFactory(
73+
edits.WithLogger(f.logger),
74+
edits.WithNoAdditionalGIDsForDeviceNodes(f.cfg.Features.NoAdditionalGIDsForDeviceNodes.IsEnabled()),
75+
)
76+
6977
return f
7078
}
7179

@@ -125,39 +133,39 @@ func (f *Factory) create() (oci.SpecModifier, error) {
125133
return modifiers, nil
126134
}
127135

128-
type Option func(*Factory)
136+
type Option func(*factoryOptions)
129137

130138
func WithConfig(cfg *config.Config) Option {
131-
return func(f *Factory) {
139+
return func(f *factoryOptions) {
132140
f.cfg = cfg
133141
}
134142
}
135143

136144
func WithDriver(driver *root.Driver) Option {
137-
return func(f *Factory) {
145+
return func(f *factoryOptions) {
138146
f.driver = driver
139147
}
140148
}
141149
func WithHookCreator(hookCreator discover.HookCreator) Option {
142-
return func(f *Factory) {
150+
return func(f *factoryOptions) {
143151
f.hookCreator = hookCreator
144152
}
145153
}
146154

147155
func WithImage(image *image.CUDA) Option {
148-
return func(f *Factory) {
156+
return func(f *factoryOptions) {
149157
f.image = image
150158
}
151159
}
152160

153161
func WithLogger(logger logger.Interface) Option {
154-
return func(f *Factory) {
162+
return func(f *factoryOptions) {
155163
f.logger = logger
156164
}
157165
}
158166

159167
func WithRuntimeMode(runtimeMode info.RuntimeMode) Option {
160-
return func(f *Factory) {
168+
return func(f *factoryOptions) {
161169
f.runtimeMode = runtimeMode
162170
}
163171
}

pkg/nvcdi/lib.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,7 @@ func New(opts ...Option) (Interface, error) {
7171
discover.WithLdconfigPath(o.ldconfigPath),
7272
discover.WithDisabledHooks(o.disabledHooks...),
7373
),
74-
editsFactory: edits.NewFactory(
75-
edits.WithLogger(o.logger),
76-
edits.WithNoAdditionalGIDsForDeviceNodes(o.featureFlags[FeatureNoAdditionalGIDsForDeviceNodes]),
77-
),
74+
editsFactory: o.editsFactory,
7875
}
7976

8077
var factory deviceSpecGeneratorFactory

pkg/nvcdi/options.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"github.com/NVIDIA/go-nvml/pkg/nvml"
2323

2424
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
25+
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
2526
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2627
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils"
2728
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
@@ -52,6 +53,8 @@ type options struct {
5253

5354
disabledHooks []discover.HookName
5455
enabledHooks []discover.HookName
56+
57+
editsFactory edits.Factory
5558
}
5659

5760
type platformlibs struct {
@@ -116,6 +119,13 @@ func populateOptions(opts ...Option) *options {
116119
o.disabledHooks = append(o.disabledHooks, HookEnableCudaCompat, DisableDeviceNodeModificationHook)
117120
}
118121

122+
if o.editsFactory == nil {
123+
o.editsFactory = edits.NewFactory(
124+
edits.WithLogger(o.logger),
125+
edits.WithNoAdditionalGIDsForDeviceNodes(o.featureFlags[FeatureNoAdditionalGIDsForDeviceNodes]),
126+
)
127+
}
128+
119129
return o
120130
}
121131

@@ -191,6 +201,12 @@ func WithDevRoot(root string) Option {
191201
}
192202
}
193203

204+
func WithEditsFactory(editsFactory edits.Factory) Option {
205+
return func(l *options) {
206+
l.editsFactory = editsFactory
207+
}
208+
}
209+
194210
// WithLogger sets the logger for the library
195211
func WithLogger(logger logger.Interface) Option {
196212
return func(l *options) {

0 commit comments

Comments
 (0)