diff --git a/pkg/inference/models/manager.go b/pkg/inference/models/manager.go index c87efa831..2f9b84589 100644 --- a/pkg/inference/models/manager.go +++ b/pkg/inference/models/manager.go @@ -256,6 +256,27 @@ func (m *Manager) handleGetModel(w http.ResponseWriter, r *http.Request) { } } +// ResolveModelID resolves a model reference to a model ID. If resolution fails, it returns the original ref. +func (m *Manager) ResolveModelID(modelRef string) string { + // Sanitize modelRef to prevent log forgery + sanitizedModelRef := strings.ReplaceAll(modelRef, "\n", "") + sanitizedModelRef = strings.ReplaceAll(sanitizedModelRef, "\r", "") + + model, err := m.GetModel(sanitizedModelRef) + if err != nil { + m.log.Warnf("Failed to resolve model ref %s to ID: %v", sanitizedModelRef, err) + return sanitizedModelRef + } + + modelID, err := model.ID() + if err != nil { + m.log.Warnf("Failed to get model ID for ref %s: %v", sanitizedModelRef, err) + return sanitizedModelRef + } + + return modelID +} + func getLocalModel(m *Manager, name string) (*Model, error) { if m.distributionClient == nil { return nil, errors.New("model distribution service unavailable") diff --git a/pkg/inference/scheduling/loader.go b/pkg/inference/scheduling/loader.go index 769eef27d..8efee1455 100644 --- a/pkg/inference/scheduling/loader.go +++ b/pkg/inference/scheduling/loader.go @@ -42,12 +42,20 @@ var ( type runnerKey struct { // backend is the backend associated with the runner. backend string - // model is the model associated with the runner. - model string + // modelID is the ID (digest) of the model associated with the runner. + modelID string // mode is the operation mode associated with the runner. mode inference.BackendMode } +// runnerInfo holds information about a runner including its slot and the original model reference used to load it. +type runnerInfo struct { + // slot is the slot index where the runner is stored. + slot int + // modelRef is the original model reference (tag) used to load the runner. + modelRef string +} + // loader manages the loading and unloading of backend runners. It regulates // active backends in a manner that avoids exhausting system resources. Loaders // assume that all of their backends have been installed, so no load requests @@ -80,7 +88,7 @@ type loader struct { // polling. Each signaling channel should be buffered (with size 1). waiters map[chan<- struct{}]bool // runners maps runner keys to their slot index. - runners map[runnerKey]int + runners map[runnerKey]runnerInfo // slots maps slot indices to associated runners. A slot is considered free // if the runner value in it is nil. slots []*runner @@ -151,7 +159,7 @@ func newLoader( guard: make(chan struct{}, 1), availableMemory: totalMemory, waiters: make(map[chan<- struct{}]bool), - runners: make(map[runnerKey]int, nSlots), + runners: make(map[runnerKey]runnerInfo, nSlots), slots: make([]*runner, nSlots), references: make([]uint, nSlots), allocations: make([]uint64, nSlots), @@ -196,24 +204,24 @@ func (l *loader) broadcast() { // lock. It returns the number of remaining runners. func (l *loader) evict(idleOnly bool) int { now := time.Now() - for r, slot := range l.runners { - unused := l.references[slot] == 0 - idle := unused && now.Sub(l.timestamps[slot]) > l.runnerIdleTimeout + for r, runnerInfo := range l.runners { + unused := l.references[runnerInfo.slot] == 0 + idle := unused && now.Sub(l.timestamps[runnerInfo.slot]) > l.runnerIdleTimeout defunct := false select { - case <-l.slots[slot].done: + case <-l.slots[runnerInfo.slot].done: defunct = true default: } if unused && (!idleOnly || idle || defunct) { - l.log.Infof("Evicting %s backend runner with model %s in %s mode", - r.backend, r.model, r.mode, + l.log.Infof("Evicting %s backend runner with model %s (%s) in %s mode", + r.backend, r.modelID, runnerInfo.modelRef, r.mode, ) - l.slots[slot].terminate() - l.slots[slot] = nil - l.availableMemory += l.allocations[slot] - l.allocations[slot] = 0 - l.timestamps[slot] = time.Time{} + l.slots[runnerInfo.slot].terminate() + l.slots[runnerInfo.slot] = nil + l.availableMemory += l.allocations[runnerInfo.slot] + l.allocations[runnerInfo.slot] = 0 + l.timestamps[runnerInfo.slot] = time.Time{} delete(l.runners, r) } } @@ -224,17 +232,17 @@ func (l *loader) evict(idleOnly bool) int { // It returns the number of remaining runners. func (l *loader) evictRunner(backend, model string, mode inference.BackendMode) int { allBackends := backend == "" - for r, slot := range l.runners { - unused := l.references[slot] == 0 - if unused && (allBackends || r.backend == backend) && r.model == model && r.mode == mode { - l.log.Infof("Evicting %s backend runner with model %s in %s mode", - r.backend, r.model, r.mode, + for r, runnerInfo := range l.runners { + unused := l.references[runnerInfo.slot] == 0 + if unused && (allBackends || r.backend == backend) && r.modelID == model && r.mode == mode { + l.log.Infof("Evicting %s backend runner with model %s (%s) in %s mode", + r.backend, r.modelID, runnerInfo.modelRef, r.mode, ) - l.slots[slot].terminate() - l.slots[slot] = nil - l.availableMemory += l.allocations[slot] - l.allocations[slot] = 0 - l.timestamps[slot] = time.Time{} + l.slots[runnerInfo.slot].terminate() + l.slots[runnerInfo.slot] = nil + l.availableMemory += l.allocations[runnerInfo.slot] + l.allocations[runnerInfo.slot] = 0 + l.timestamps[runnerInfo.slot] = time.Time{} delete(l.runners, r) } } @@ -254,11 +262,12 @@ func (l *loader) Unload(ctx context.Context, unload UnloadRequest) int { return l.evict(false) } else { for _, model := range unload.Models { + modelID := l.modelManager.ResolveModelID(model) delete(l.runnerConfigs, runnerKey{unload.Backend, model, inference.BackendModeCompletion}) // Evict both, completion and embedding models. We should consider // accepting a mode parameter in unload requests. - l.evictRunner(unload.Backend, model, inference.BackendModeCompletion) - l.evictRunner(unload.Backend, model, inference.BackendModeEmbedding) + l.evictRunner(unload.Backend, modelID, inference.BackendModeCompletion) + l.evictRunner(unload.Backend, modelID, inference.BackendModeEmbedding) } return len(l.runners) } @@ -282,15 +291,15 @@ func stopAndDrainTimer(timer *time.Timer) { func (l *loader) idleCheckDuration() time.Duration { // Compute the oldest usage time for any idle runner. var oldest time.Time - for _, slot := range l.runners { + for _, runnerInfo := range l.runners { select { - case <-l.slots[slot].done: + case <-l.slots[runnerInfo.slot].done: // Check immediately if a runner is defunct return 0 default: } - if l.references[slot] == 0 { - timestamp := l.timestamps[slot] + if l.references[runnerInfo.slot] == 0 { + timestamp := l.timestamps[runnerInfo.slot] if oldest.IsZero() || timestamp.Before(oldest) { oldest = timestamp } @@ -378,10 +387,10 @@ func (l *loader) run(ctx context.Context) { } } -// load allocates a runner using the specified backend and model. If allocated, +// load allocates a runner using the specified backend and modelID. If allocated, // it should be released by the caller using the release mechanism (once the // runner is no longer needed). -func (l *loader) load(ctx context.Context, backendName, model string, mode inference.BackendMode) (*runner, error) { +func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string, mode inference.BackendMode) (*runner, error) { // Grab the backend. backend, ok := l.backends[backendName] if !ok { @@ -426,20 +435,20 @@ func (l *loader) load(ctx context.Context, backendName, model string, mode infer } // See if we can satisfy the request with an existing runner. - existing, ok := l.runners[runnerKey{backendName, model, mode}] + existing, ok := l.runners[runnerKey{backendName, modelID, mode}] if ok { select { - case <-l.slots[existing].done: - l.log.Warnf("%s runner for %s is defunct. Waiting for it to be evicted.", backendName, model) - if l.references[existing] == 0 { - l.evictRunner(backendName, model, mode) + case <-l.slots[existing.slot].done: + l.log.Warnf("%s runner for %s is defunct. Waiting for it to be evicted.", backendName, existing.modelRef) + if l.references[existing.slot] == 0 { + l.evictRunner(backendName, modelID, mode) } else { goto WaitForChange } default: - l.references[existing] += 1 - l.timestamps[existing] = time.Time{} - return l.slots[existing], nil + l.references[existing.slot] += 1 + l.timestamps[existing.slot] = time.Time{} + return l.slots[existing.slot], nil } } @@ -462,15 +471,15 @@ func (l *loader) load(ctx context.Context, backendName, model string, mode infer // If we've identified a slot, then we're ready to start a runner. if slot >= 0 { var runnerConfig *inference.BackendConfiguration - if rc, ok := l.runnerConfigs[runnerKey{backendName, model, mode}]; ok { + if rc, ok := l.runnerConfigs[runnerKey{backendName, modelID, mode}]; ok { runnerConfig = &rc } // Create the runner. - l.log.Infof("Loading %s backend runner with model %s in %s mode", backendName, model, mode) - runner, err := run(l.log, backend, model, mode, slot, runnerConfig, l.openAIRecorder) + l.log.Infof("Loading %s backend runner with model %s in %s mode", backendName, modelID, mode) + runner, err := run(l.log, backend, modelID, mode, slot, runnerConfig, l.openAIRecorder) if err != nil { l.log.Warnf("Unable to start %s backend runner with model %s in %s mode: %v", - backendName, model, mode, err, + backendName, modelID, mode, err, ) return nil, fmt.Errorf("unable to start runner: %w", err) } @@ -484,14 +493,14 @@ func (l *loader) load(ctx context.Context, backendName, model string, mode infer if err := runner.wait(ctx); err != nil { runner.terminate() l.log.Warnf("Initialization for %s backend runner with model %s in %s mode failed: %v", - backendName, model, mode, err, + backendName, modelID, mode, err, ) return nil, fmt.Errorf("error waiting for runner to be ready: %w", err) } // Perform registration and return the runner. l.availableMemory -= memory - l.runners[runnerKey{backendName, model, mode}] = slot + l.runners[runnerKey{backendName, modelID, mode}] = runnerInfo{slot, modelRef} l.slots[slot] = runner l.references[slot] = 1 l.allocations[slot] = memory @@ -523,17 +532,17 @@ func (l *loader) release(runner *runner) { slot := l.runners[runnerKey{runner.backend.Name(), runner.model, runner.mode}] // Decrement the runner's reference count. - l.references[slot] -= 1 + l.references[slot.slot] -= 1 // If the runner's reference count is now zero, then check if it is still // active, and record now as its idle start time and signal the idle // checker. - if l.references[slot] == 0 { + if l.references[slot.slot] == 0 { select { case <-runner.done: l.evictRunner(runner.backend.Name(), runner.model, runner.mode) default: - l.timestamps[slot] = time.Now() + l.timestamps[slot.slot] = time.Now() select { case l.idleCheck <- struct{}{}: default: @@ -545,22 +554,22 @@ func (l *loader) release(runner *runner) { l.broadcast() } -func (l *loader) setRunnerConfig(ctx context.Context, backendName, model string, mode inference.BackendMode, runnerConfig inference.BackendConfiguration) error { +func (l *loader) setRunnerConfig(ctx context.Context, backendName, modelID string, mode inference.BackendMode, runnerConfig inference.BackendConfiguration) error { l.lock(ctx) defer l.unlock() - runnerId := runnerKey{backendName, model, mode} + runnerId := runnerKey{backendName, modelID, mode} // If the configuration hasn't changed, then just return. if existingConfig, ok := l.runnerConfigs[runnerId]; ok && reflect.DeepEqual(runnerConfig, existingConfig) { - l.log.Infof("Configuration for %s runner for model %s unchanged", backendName, model) + l.log.Infof("Configuration for %s runner for modelID %s unchanged", backendName, modelID) return nil } // If there's an active runner whose configuration we want to override, then // try evicting it (because it may not be in use). if _, ok := l.runners[runnerId]; ok { - l.evictRunner(backendName, model, mode) + l.evictRunner(backendName, modelID, mode) } // If there's still then active runner, then we can't (or at least @@ -569,7 +578,7 @@ func (l *loader) setRunnerConfig(ctx context.Context, backendName, model string, return errRunnerAlreadyActive } - l.log.Infof("Configuring %s runner for %s", backendName, model) + l.log.Infof("Configuring %s runner for %s", backendName, modelID) l.runnerConfigs[runnerId] = runnerConfig return nil } diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go index 10b7e6e6f..23477d12c 100644 --- a/pkg/inference/scheduling/scheduler.go +++ b/pkg/inference/scheduling/scheduler.go @@ -56,7 +56,7 @@ func NewScheduler( allowedOrigins []string, tracker *metrics.Tracker, ) *Scheduler { - openAIRecorder := metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder")) + openAIRecorder := metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder"), modelManager) // Create the scheduler. s := &Scheduler{ @@ -238,8 +238,10 @@ func (s *Scheduler) handleOpenAIInference(w http.ResponseWriter, r *http.Request s.tracker.TrackModel(model, r.UserAgent()) } + modelID := s.modelManager.ResolveModelID(request.Model) + // Request a runner to execute the request and defer its release. - runner, err := s.loader.load(r.Context(), backend.Name(), request.Model, backendMode) + runner, err := s.loader.load(r.Context(), backend.Name(), modelID, request.Model, backendMode) if err != nil { http.Error(w, fmt.Errorf("unable to load runner: %w", err).Error(), http.StatusInternalServerError) return @@ -295,17 +297,17 @@ func (s *Scheduler) getLoaderStatus(ctx context.Context) []BackendStatus { result := make([]BackendStatus, 0, len(s.loader.runners)) - for key, slot := range s.loader.runners { - if s.loader.slots[slot] != nil { + for key, runnerInfo := range s.loader.runners { + if s.loader.slots[runnerInfo.slot] != nil { status := BackendStatus{ BackendName: key.backend, - ModelName: key.model, + ModelName: runnerInfo.modelRef, Mode: key.mode.String(), LastUsed: time.Time{}, } - if s.loader.references[slot] == 0 { - status.LastUsed = s.loader.timestamps[slot] + if s.loader.references[runnerInfo.slot] == 0 { + status.LastUsed = s.loader.timestamps[runnerInfo.slot] } result = append(result, status) @@ -414,9 +416,9 @@ func (s *Scheduler) Configure(w http.ResponseWriter, r *http.Request) { // Configure is called by compose for each model. s.tracker.TrackModel(model, r.UserAgent()) } - - if err := s.loader.setRunnerConfig(r.Context(), backend.Name(), configureRequest.Model, inference.BackendModeCompletion, runnerConfig); err != nil { - s.log.Warnf("Failed to configure %s runner for %s: %s", backend.Name(), configureRequest.Model, err) + modelID := s.modelManager.ResolveModelID(configureRequest.Model) + if err := s.loader.setRunnerConfig(r.Context(), backend.Name(), modelID, inference.BackendModeCompletion, runnerConfig); err != nil { + s.log.Warnf("Failed to configure %s runner for %s (%s): %s", backend.Name(), configureRequest.Model, modelID, err) if errors.Is(err, errRunnerAlreadyActive) { http.Error(w, err.Error(), http.StatusConflict) } else { @@ -442,14 +444,14 @@ func (s *Scheduler) GetAllActiveRunners() []metrics.ActiveRunner { // Find the runner slot for this backend/model combination key := runnerKey{ backend: backend.BackendName, - model: backend.ModelName, + modelID: backend.ModelName, mode: parseBackendMode(backend.Mode), } - if slot, exists := s.loader.runners[key]; exists { - socket, err := RunnerSocketPath(slot) + if runnerInfo, exists := s.loader.runners[key]; exists { + socket, err := RunnerSocketPath(runnerInfo.slot) if err != nil { - s.log.Warnf("Failed to get socket path for runner %s/%s: %v", backend.BackendName, backend.ModelName, err) + s.log.Warnf("Failed to get socket path for runner %s/%s (%s): %v", backend.BackendName, backend.ModelName, key.modelID, err) continue } @@ -480,13 +482,13 @@ func (s *Scheduler) GetLlamaCppSocket() (string, error) { // Find the runner slot for this backend/model combination key := runnerKey{ backend: backend.BackendName, - model: backend.ModelName, + modelID: backend.ModelName, mode: parseBackendMode(backend.Mode), } - if slot, exists := s.loader.runners[key]; exists { + if runnerInfo, exists := s.loader.runners[key]; exists { // Use the RunnerSocketPath function to get the socket path - return RunnerSocketPath(slot) + return RunnerSocketPath(runnerInfo.slot) } } } diff --git a/pkg/metrics/openai_recorder.go b/pkg/metrics/openai_recorder.go index e1d5fc493..19fcfae70 100644 --- a/pkg/metrics/openai_recorder.go +++ b/pkg/metrics/openai_recorder.go @@ -10,6 +10,7 @@ import ( "time" "github.com/docker/model-runner/pkg/inference" + "github.com/docker/model-runner/pkg/inference/models" "github.com/docker/model-runner/pkg/logging" ) @@ -53,15 +54,17 @@ type ModelData struct { } type OpenAIRecorder struct { - log logging.Logger - records map[string]*ModelData - m sync.RWMutex + log logging.Logger + records map[string]*ModelData // key is model ID + modelManager *models.Manager // for resolving model tags to IDs + m sync.RWMutex } -func NewOpenAIRecorder(log logging.Logger) *OpenAIRecorder { +func NewOpenAIRecorder(log logging.Logger, modelManager *models.Manager) *OpenAIRecorder { return &OpenAIRecorder{ - log: log, - records: make(map[string]*ModelData), + log: log, + modelManager: modelManager, + records: make(map[string]*ModelData), } } @@ -71,28 +74,32 @@ func (r *OpenAIRecorder) SetConfigForModel(model string, config *inference.Backe return } + modelID := r.modelManager.ResolveModelID(model) + r.m.Lock() defer r.m.Unlock() - if r.records[model] == nil { - r.records[model] = &ModelData{ + if r.records[modelID] == nil { + r.records[modelID] = &ModelData{ Records: make([]*RequestResponsePair, 0, 10), Config: inference.BackendConfiguration{}, } } - r.records[model].Config = *config + r.records[modelID].Config = *config } func (r *OpenAIRecorder) RecordRequest(model string, req *http.Request, body []byte) string { + modelID := r.modelManager.ResolveModelID(model) + r.m.Lock() defer r.m.Unlock() - recordID := fmt.Sprintf("%s_%d", model, time.Now().UnixNano()) + recordID := fmt.Sprintf("%s_%d", modelID, time.Now().UnixNano()) record := &RequestResponsePair{ ID: recordID, - Model: model, + Model: modelID, Method: req.Method, URL: req.URL.Path, Request: string(body), @@ -100,17 +107,17 @@ func (r *OpenAIRecorder) RecordRequest(model string, req *http.Request, body []b UserAgent: req.UserAgent(), } - if r.records[model] == nil { - r.records[model] = &ModelData{ + if r.records[modelID] == nil { + r.records[modelID] = &ModelData{ Records: make([]*RequestResponsePair, 0, 10), Config: inference.BackendConfiguration{}, } } - r.records[model].Records = append(r.records[model].Records, record) + r.records[modelID].Records = append(r.records[modelID].Records, record) - if len(r.records[model].Records) > 10 { - r.records[model].Records = r.records[model].Records[1:] + if len(r.records[modelID].Records) > 10 { + r.records[modelID].Records = r.records[modelID].Records[1:] } return recordID @@ -138,10 +145,12 @@ func (r *OpenAIRecorder) RecordResponse(id, model string, rw http.ResponseWriter response = responseBody } + modelID := r.modelManager.ResolveModelID(model) + r.m.Lock() defer r.m.Unlock() - if modelData, exists := r.records[model]; exists { + if modelData, exists := r.records[modelID]; exists { for _, record := range modelData.Records { if record.ID == id { record.Response = response @@ -149,9 +158,9 @@ func (r *OpenAIRecorder) RecordResponse(id, model string, rw http.ResponseWriter return } } - r.log.Errorf("Matching request (id=%s) not found for model %s - %d\n%s", id, model, statusCode, response) + r.log.Errorf("Matching request (id=%s) not found for model %s - %d\n%s", id, modelID, statusCode, response) } else { - r.log.Errorf("Model %s not found in records - %d\n%s", model, statusCode, response) + r.log.Errorf("Model %s not found in records - %d\n%s", modelID, statusCode, response) } } @@ -237,11 +246,12 @@ func (r *OpenAIRecorder) GetRecordsByModelHandler() http.HandlerFunc { return } + modelID := r.modelManager.ResolveModelID(model) if err := json.NewEncoder(w).Encode(map[string]interface{}{ "model": model, "records": records, "count": len(records), - "config": r.records[model].Config, + "config": r.records[modelID].Config, }); err != nil { http.Error(w, fmt.Sprintf("Failed to encode records for model '%s': %v", model, err), http.StatusInternalServerError) @@ -252,10 +262,12 @@ func (r *OpenAIRecorder) GetRecordsByModelHandler() http.HandlerFunc { } func (r *OpenAIRecorder) GetRecordsByModel(model string) []*RequestResponsePair { + modelID := r.modelManager.ResolveModelID(model) + r.m.RLock() defer r.m.RUnlock() - if modelData, exists := r.records[model]; exists { + if modelData, exists := r.records[modelID]; exists { result := make([]*RequestResponsePair, len(modelData.Records)) copy(result, modelData.Records) return result @@ -265,13 +277,15 @@ func (r *OpenAIRecorder) GetRecordsByModel(model string) []*RequestResponsePair } func (r *OpenAIRecorder) RemoveModel(model string) { + modelID := r.modelManager.ResolveModelID(model) + r.m.Lock() defer r.m.Unlock() - if _, exists := r.records[model]; exists { - delete(r.records, model) - r.log.Infof("Removed records for model: %s", model) + if _, exists := r.records[modelID]; exists { + delete(r.records, modelID) + r.log.Infof("Removed records for model: %s", modelID) } else { - r.log.Warnf("No records found for model: %s", model) + r.log.Warnf("No records found for model: %s", modelID) } }