Skip to content

Commit 79477c9

Browse files
committed
fix: Reuse instantiated editsFactory in CDI
This change ensures that the editsFactory is instantiated once and passed to the nvcdi constructor. This makes it unnecessary to reprocess optional arguments and configs. Signed-off-by: Evan Lezar <elezar@nvidia.com>
1 parent 722b541 commit 79477c9

6 files changed

Lines changed: 41 additions & 24 deletions

File tree

internal/modifier/cdi.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ package modifier
1818

1919
import (
2020
"fmt"
21-
"slices"
2221
"strings"
2322

2423
"tags.cncf.io/container-device-interface/pkg/parser"
@@ -176,11 +175,6 @@ func filterAutomaticDevices(devices []string) []string {
176175
func (f *Factory) newAutomaticCDISpecModifier(devices []string) (oci.SpecModifier, error) {
177176
f.logger.Debugf("Generating in-memory CDI specs for devices %v", devices)
178177

179-
nvcdiFeatureFlags := slices.Clone(f.cfg.NVIDIAContainerRuntimeConfig.Modes.JitCDI.NVCDIFeatureFlags)
180-
if f.cfg.Features.NoAdditionalGIDsForDeviceNodes.IsEnabled() {
181-
nvcdiFeatureFlags = append(nvcdiFeatureFlags, nvcdi.FeatureNoAdditionalGIDsForDeviceNodes)
182-
}
183-
184178
csvFiles, err := csv.GetFileList(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath)
185179
if err != nil {
186180
f.logger.Warningf("Failed to get the list of CSV files: %v", err)
@@ -198,10 +192,11 @@ func (f *Factory) newAutomaticCDISpecModifier(devices []string) (oci.SpecModifie
198192
nvcdi.WithNVIDIACDIHookPath(f.cfg.NVIDIACTKConfig.Path),
199193
nvcdi.WithDriverRoot(f.driver.Root),
200194
nvcdi.WithDevRoot(f.driver.DevRoot),
195+
nvcdi.WithEditsFactory(f.editsFactory),
201196
nvcdi.WithVendor(automaticDeviceVendor),
202197
nvcdi.WithClass(cdiModeIdentifiers.deviceClassByMode[mode]),
203198
nvcdi.WithMode(mode),
204-
nvcdi.WithFeatureFlags(nvcdiFeatureFlags...),
199+
nvcdi.WithFeatureFlags(f.cfg.NVIDIAContainerRuntimeConfig.Modes.JitCDI.NVCDIFeatureFlags...),
205200
nvcdi.WithCSVCompatContainerRoot(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.CompatContainerRoot),
206201
nvcdi.WithCSVFiles(csvFiles),
207202
)

internal/modifier/csv_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ func TestNewCSVModifier(t *testing.T) {
124124
f := createFactory(
125125
WithLogger(logger),
126126
WithDriver(driver),
127-
WithLogger(logger),
128127
WithConfig(&tc.cfg),
129128
WithImage(&image),
130129
)

internal/modifier/discover_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
testlog "github.com/sirupsen/logrus/hooks/test"
2525
"github.com/stretchr/testify/require"
2626

27+
"github.com/NVIDIA/nvidia-container-toolkit/api/config/v1"
2728
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2829
)
2930

@@ -132,6 +133,7 @@ func TestDiscoverModifier(t *testing.T) {
132133

133134
factory := createFactory(
134135
WithLogger(logger),
136+
WithConfig(&config.Config{}),
135137
)
136138
for _, tc := range testCases {
137139
t.Run(tc.description, func(t *testing.T) {

internal/modifier/factory.go

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,20 @@ import (
3232
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
3333
)
3434

35-
type Factory struct {
35+
// factoryOptions define the set of options that must be set when constructing
36+
// a modifier factory.
37+
type factoryOptions struct {
3638
logger logger.Interface
3739
cfg *config.Config
3840
driver *root.Driver
3941
hookCreator discover.HookCreator
4042
image *image.CUDA
4143
runtimeMode info.RuntimeMode
44+
}
4245

46+
type Factory struct {
47+
factoryOptions
48+
// An editsFactory is created at construction.
4349
editsFactory edits.Factory
4450
}
4551

@@ -60,12 +66,14 @@ func New(opts ...Option) (oci.SpecModifier, error) {
6066
func createFactory(opts ...Option) *Factory {
6167
f := &Factory{}
6268
for _, opt := range opts {
63-
opt(f)
64-
}
65-
if f.editsFactory == nil {
66-
f.editsFactory = edits.NewFactory(edits.WithLogger(f.logger))
69+
opt(&f.factoryOptions)
6770
}
6871

72+
f.editsFactory = edits.NewFactory(
73+
edits.WithLogger(f.logger),
74+
edits.WithNoAdditionalGIDsForDeviceNodes(f.cfg.Features.NoAdditionalGIDsForDeviceNodes.IsEnabled()),
75+
)
76+
6977
return f
7078
}
7179

@@ -125,39 +133,39 @@ func (f *Factory) create() (oci.SpecModifier, error) {
125133
return modifiers, nil
126134
}
127135

128-
type Option func(*Factory)
136+
type Option func(*factoryOptions)
129137

130138
func WithConfig(cfg *config.Config) Option {
131-
return func(f *Factory) {
139+
return func(f *factoryOptions) {
132140
f.cfg = cfg
133141
}
134142
}
135143

136144
func WithDriver(driver *root.Driver) Option {
137-
return func(f *Factory) {
145+
return func(f *factoryOptions) {
138146
f.driver = driver
139147
}
140148
}
141149
func WithHookCreator(hookCreator discover.HookCreator) Option {
142-
return func(f *Factory) {
150+
return func(f *factoryOptions) {
143151
f.hookCreator = hookCreator
144152
}
145153
}
146154

147155
func WithImage(image *image.CUDA) Option {
148-
return func(f *Factory) {
156+
return func(f *factoryOptions) {
149157
f.image = image
150158
}
151159
}
152160

153161
func WithLogger(logger logger.Interface) Option {
154-
return func(f *Factory) {
162+
return func(f *factoryOptions) {
155163
f.logger = logger
156164
}
157165
}
158166

159167
func WithRuntimeMode(runtimeMode info.RuntimeMode) Option {
160-
return func(f *Factory) {
168+
return func(f *factoryOptions) {
161169
f.runtimeMode = runtimeMode
162170
}
163171
}

pkg/nvcdi/lib.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,7 @@ func New(opts ...Option) (Interface, error) {
7171
discover.WithLdconfigPath(o.ldconfigPath),
7272
discover.WithDisabledHooks(o.disabledHooks...),
7373
),
74-
editsFactory: edits.NewFactory(
75-
edits.WithLogger(o.logger),
76-
edits.WithNoAdditionalGIDsForDeviceNodes(o.featureFlags[FeatureNoAdditionalGIDsForDeviceNodes]),
77-
),
74+
editsFactory: o.editsFactory,
7875
}
7976

8077
var factory deviceSpecGeneratorFactory

pkg/nvcdi/options.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"github.com/NVIDIA/go-nvml/pkg/nvml"
2323

2424
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
25+
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
2526
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2627
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils"
2728
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
@@ -52,6 +53,8 @@ type options struct {
5253

5354
disabledHooks []discover.HookName
5455
enabledHooks []discover.HookName
56+
57+
editsFactory edits.Factory
5558
}
5659

5760
type platformlibs struct {
@@ -116,6 +119,13 @@ func populateOptions(opts ...Option) *options {
116119
o.disabledHooks = append(o.disabledHooks, HookEnableCudaCompat, DisableDeviceNodeModificationHook)
117120
}
118121

122+
if o.editsFactory == nil {
123+
o.editsFactory = edits.NewFactory(
124+
edits.WithLogger(o.logger),
125+
edits.WithNoAdditionalGIDsForDeviceNodes(o.featureFlags[FeatureNoAdditionalGIDsForDeviceNodes]),
126+
)
127+
}
128+
119129
return o
120130
}
121131

@@ -191,6 +201,12 @@ func WithDevRoot(root string) Option {
191201
}
192202
}
193203

204+
func WithEditsFactory(editsFactory edits.Factory) Option {
205+
return func(l *options) {
206+
l.editsFactory = editsFactory
207+
}
208+
}
209+
194210
// WithLogger sets the logger for the library
195211
func WithLogger(logger logger.Interface) Option {
196212
return func(l *options) {

0 commit comments

Comments
 (0)