Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions pkg/inference/models/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,24 @@
}
}

// ResolveModelID resolves a model reference to a model ID. If resolution fails, it returns the original ref.
func (m *Manager) ResolveModelID(modelRef string) string {

model, err := m.GetModel(modelRef)
if err != nil {
m.log.Warnf("Failed to resolve model ref %s to ID: %v", modelRef, err)
Comment thread Fixed
return modelRef
}

modelID, err := model.ID()
if err != nil {
m.log.Warnf("Failed to get model ID for ref %s: %v", modelRef, err)
Comment thread Fixed
return modelRef
}

return modelID
}

func getLocalModel(m *Manager, name string) (*Model, error) {
if m.distributionClient == nil {
return nil, errors.New("model distribution service unavailable")
Expand Down
7 changes: 4 additions & 3 deletions pkg/inference/scheduling/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ var (
type runnerKey struct {
// backend is the backend associated with the runner.
backend string
// model is the model associated with the runner.
// model is the modelID associated with the runner.
model string
// mode is the operation mode associated with the runner.
mode inference.BackendMode
Expand Down Expand Up @@ -254,11 +254,12 @@ func (l *loader) Unload(ctx context.Context, unload UnloadRequest) int {
return l.evict(false)
} else {
for _, model := range unload.Models {
modelID := l.modelManager.ResolveModelID(model)
delete(l.runnerConfigs, runnerKey{unload.Backend, model, inference.BackendModeCompletion})
// Evict both, completion and embedding models. We should consider
// accepting a mode parameter in unload requests.
l.evictRunner(unload.Backend, model, inference.BackendModeCompletion)
l.evictRunner(unload.Backend, model, inference.BackendModeEmbedding)
l.evictRunner(unload.Backend, modelID, inference.BackendModeCompletion)
l.evictRunner(unload.Backend, modelID, inference.BackendModeEmbedding)
}
return len(l.runners)
}
Expand Down
11 changes: 7 additions & 4 deletions pkg/inference/scheduling/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
allowedOrigins []string,
tracker *metrics.Tracker,
) *Scheduler {
openAIRecorder := metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder"))
openAIRecorder := metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder"), modelManager)

// Create the scheduler.
s := &Scheduler{
Expand Down Expand Up @@ -238,8 +238,10 @@
s.tracker.TrackModel(model)
}

modelID := s.modelManager.ResolveModelID(request.Model)

// Request a runner to execute the request and defer its release.
runner, err := s.loader.load(r.Context(), backend.Name(), request.Model, backendMode)
runner, err := s.loader.load(r.Context(), backend.Name(), modelID, backendMode)
if err != nil {
http.Error(w, fmt.Errorf("unable to load runner: %w", err).Error(), http.StatusInternalServerError)
return
Expand Down Expand Up @@ -410,8 +412,9 @@
runnerConfig.ContextSize = configureRequest.ContextSize
runnerConfig.RuntimeFlags = runtimeFlags

if err := s.loader.setRunnerConfig(r.Context(), backend.Name(), configureRequest.Model, inference.BackendModeCompletion, runnerConfig); err != nil {
s.log.Warnf("Failed to configure %s runner for %s: %s", backend.Name(), configureRequest.Model, err)
modelID := s.modelManager.ResolveModelID(configureRequest.Model)
if err := s.loader.setRunnerConfig(r.Context(), backend.Name(), modelID, inference.BackendModeCompletion, runnerConfig); err != nil {
s.log.Warnf("Failed to configure %s runner for %s: %s", backend.Name(), modelID, err)
Comment thread Fixed
if errors.Is(err, errRunnerAlreadyActive) {
http.Error(w, err.Error(), http.StatusConflict)
} else {
Expand Down
64 changes: 39 additions & 25 deletions pkg/metrics/openai_recorder.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"time"

"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/models"
"github.com/docker/model-runner/pkg/logging"
)

Expand Down Expand Up @@ -53,15 +54,17 @@
}

type OpenAIRecorder struct {
log logging.Logger
records map[string]*ModelData
m sync.RWMutex
log logging.Logger
records map[string]*ModelData // key is model ID
modelManager *models.Manager // for resolving model tags to IDs
m sync.RWMutex
}

func NewOpenAIRecorder(log logging.Logger) *OpenAIRecorder {
func NewOpenAIRecorder(log logging.Logger, modelManager *models.Manager) *OpenAIRecorder {
return &OpenAIRecorder{
log: log,
records: make(map[string]*ModelData),
log: log,
modelManager: modelManager,
records: make(map[string]*ModelData),
}
}

Expand All @@ -71,46 +74,50 @@
return
}

modelID := r.modelManager.ResolveModelID(model)

r.m.Lock()
defer r.m.Unlock()

if r.records[model] == nil {
r.records[model] = &ModelData{
if r.records[modelID] == nil {
r.records[modelID] = &ModelData{
Records: make([]*RequestResponsePair, 0, 10),
Config: inference.BackendConfiguration{},
}
}

r.records[model].Config = *config
r.records[modelID].Config = *config
}

func (r *OpenAIRecorder) RecordRequest(model string, req *http.Request, body []byte) string {
modelID := r.modelManager.ResolveModelID(model)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does using the model ID in the OpenAIRecorder mean that the GUI that displays these requests will also display the SHA256? Or do we resolve it back to the "friendly" name? (Or something like ai/gemma3@sha256:a484b...?)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently the UI does not use or show tags when using or inspecting a model. Behind the scenes it uses the model ID, but the UI does not show that.
@doringeman and I discussed that we will eventually need a ResolveModelTags(model reference) to also show all the available tags of a local model.


r.m.Lock()
defer r.m.Unlock()

recordID := fmt.Sprintf("%s_%d", model, time.Now().UnixNano())
recordID := fmt.Sprintf("%s_%d", modelID, time.Now().UnixNano())

record := &RequestResponsePair{
ID: recordID,
Model: model,
Model: modelID,
Method: req.Method,
URL: req.URL.Path,
Request: string(body),
Timestamp: time.Now(),
UserAgent: req.UserAgent(),
}

if r.records[model] == nil {
r.records[model] = &ModelData{
if r.records[modelID] == nil {
r.records[modelID] = &ModelData{
Records: make([]*RequestResponsePair, 0, 10),
Config: inference.BackendConfiguration{},
}
}

r.records[model].Records = append(r.records[model].Records, record)
r.records[modelID].Records = append(r.records[modelID].Records, record)

if len(r.records[model].Records) > 10 {
r.records[model].Records = r.records[model].Records[1:]
if len(r.records[modelID].Records) > 10 {
r.records[modelID].Records = r.records[modelID].Records[1:]
}

return recordID
Expand Down Expand Up @@ -138,20 +145,22 @@
response = responseBody
}

modelID := r.modelManager.ResolveModelID(model)

r.m.Lock()
defer r.m.Unlock()

if modelData, exists := r.records[model]; exists {
if modelData, exists := r.records[modelID]; exists {
for _, record := range modelData.Records {
if record.ID == id {
record.Response = response
record.StatusCode = statusCode
return
}
}
r.log.Errorf("Matching request (id=%s) not found for model %s - %d\n%s", id, model, statusCode, response)
r.log.Errorf("Matching request (id=%s) not found for model %s - %d\n%s", id, modelID, statusCode, response)
Comment thread Fixed
Comment thread Fixed
} else {
r.log.Errorf("Model %s not found in records - %d\n%s", model, statusCode, response)
r.log.Errorf("Model %s not found in records - %d\n%s", modelID, statusCode, response)
Comment thread Fixed
}
}

Expand Down Expand Up @@ -237,11 +246,12 @@
return
}

modelID := r.modelManager.ResolveModelID(model)
if err := json.NewEncoder(w).Encode(map[string]interface{}{
"model": model,
"records": records,
"count": len(records),
"config": r.records[model].Config,
"config": r.records[modelID].Config,
}); err != nil {
http.Error(w, fmt.Sprintf("Failed to encode records for model '%s': %v", model, err),
http.StatusInternalServerError)
Expand All @@ -252,10 +262,12 @@
}

func (r *OpenAIRecorder) GetRecordsByModel(model string) []*RequestResponsePair {
modelID := r.modelManager.ResolveModelID(model)

r.m.RLock()
defer r.m.RUnlock()

if modelData, exists := r.records[model]; exists {
if modelData, exists := r.records[modelID]; exists {
result := make([]*RequestResponsePair, len(modelData.Records))
copy(result, modelData.Records)
return result
Expand All @@ -265,13 +277,15 @@
}

func (r *OpenAIRecorder) RemoveModel(model string) {
modelID := r.modelManager.ResolveModelID(model)

r.m.Lock()
defer r.m.Unlock()

if _, exists := r.records[model]; exists {
delete(r.records, model)
r.log.Infof("Removed records for model: %s", model)
if _, exists := r.records[modelID]; exists {
delete(r.records, modelID)
r.log.Infof("Removed records for model: %s", modelID)
Comment thread Fixed
} else {
r.log.Warnf("No records found for model: %s", model)
r.log.Warnf("No records found for model: %s", modelID)
Comment thread Fixed
}
}