Skip to content

Commit 83b886c

Browse files
committed
fix(scheduling): acquire lock before reading runnerConfigs in load
Signed-off-by: Dorin Geman <dorin.geman@docker.com>
1 parent b57c1e0 commit 83b886c

1 file changed

Lines changed: 11 additions & 10 deletions

File tree

pkg/inference/scheduling/loader.go

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -415,13 +415,22 @@ func (l *loader) run(ctx context.Context) {
415415
// it should be released by the caller using the release mechanism (once the
416416
// runner is no longer needed).
417417
func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string, mode inference.BackendMode) (*runner, error) {
418-
// Grab the backend.
418+
// Grab the backend. The backends map is immutable after construction,
419+
// so it is safe to read without holding the lock.
419420
backend, ok := l.backends[backendName]
420421
if !ok {
421422
return nil, ErrBackendNotFound
422423
}
423424

424-
// Get runner configuration if available
425+
l.log.Info("Loading backend runner", "backend", backendName, "model", modelID, "mode", mode)
426+
427+
if !l.lock(ctx) {
428+
return nil, context.Canceled
429+
}
430+
defer l.unlock()
431+
432+
// Get runner configuration if available (must be done under lock since
433+
// runnerConfigs can be modified concurrently by setRunnerConfig).
425434
var runnerConfig *inference.BackendConfiguration
426435
draftModelID := ""
427436
if rc, ok := l.runnerConfigs[makeConfigKey(backendName, modelID, mode)]; ok {
@@ -455,14 +464,6 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
455464
runnerConfig = &defaultConfig
456465
}
457466

458-
l.log.Info("Loading backend runner", "backend", backendName, "model", modelID, "mode", mode)
459-
460-
// Acquire the loader lock and defer its release.
461-
if !l.lock(ctx) {
462-
return nil, context.Canceled
463-
}
464-
defer l.unlock()
465-
466467
// Create a polling channel that we can use to detect state changes and
467468
// ensure that it's deregistered by the time we return.
468469
poll := make(chan struct{}, 1)

0 commit comments

Comments
 (0)