Skip to content

Commit ea9a321

Browse files
committed
OpenAIRecorder: Include BackendConfiguration
Signed-off-by: Dorin Geman <dorin.geman@docker.com>
1 parent 3904f23 commit ea9a321

3 files changed

Lines changed: 44 additions & 13 deletions

File tree

pkg/inference/backend.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ func (m BackendMode) String() string {
3030
}
3131

3232
type BackendConfiguration struct {
33-
ContextSize int64
34-
RawFlags []string
33+
ContextSize int64 `json:"context_size,omitempty"`
34+
RawFlags []string `json:"flags,omitempty"`
3535
}
3636

3737
// Backend is the interface implemented by inference engine backends. Backend

pkg/inference/scheduling/runner.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ func run(
169169
}
170170
}
171171

172+
r.openAIRecorder.SetConfigForModel(model, runnerConfig)
173+
172174
// Start the backend run loop.
173175
go func() {
174176
if err := backend.Run(runCtx, socket, model, mode, runnerConfig); err != nil {

pkg/metrics/openai_recorder.go

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"sync"
1010
"time"
1111

12+
"github.com/docker/model-runner/pkg/inference"
1213
"github.com/docker/model-runner/pkg/logging"
1314
)
1415

@@ -39,17 +40,41 @@ type RequestResponsePair struct {
3940
StatusCode int `json:"status_code"`
4041
}
4142

43+
// ModelData is the per-model entry held by OpenAIRecorder: the backend
// configuration recorded for the model together with its request/response
// records.
type ModelData struct {
44+
// Config is the backend configuration stored via SetConfigForModel; it is
// the zero value until that method has been called for the model.
Config inference.BackendConfiguration `json:"config"`
45+
// Records holds the recorded request/response pairs for the model.
Records []*RequestResponsePair `json:"records"`
46+
}
47+
4248
type OpenAIRecorder struct {
4349
log logging.Logger
44-
records map[string][]*RequestResponsePair
50+
records map[string]*ModelData
4551
m sync.RWMutex
4652
}
4753

4854
func NewOpenAIRecorder(log logging.Logger) *OpenAIRecorder {
4955
return &OpenAIRecorder{
5056
log: log,
51-
records: make(map[string][]*RequestResponsePair),
57+
records: make(map[string]*ModelData),
58+
}
59+
}
60+
61+
func (r *OpenAIRecorder) SetConfigForModel(model string, config *inference.BackendConfiguration) {
62+
if config == nil {
63+
r.log.Warnf("SetConfigForModel called with nil config for model %s", model)
64+
return
65+
}
66+
67+
r.m.Lock()
68+
defer r.m.Unlock()
69+
70+
if r.records[model] == nil {
71+
r.records[model] = &ModelData{
72+
Records: make([]*RequestResponsePair, 0, 10),
73+
Config: inference.BackendConfiguration{},
74+
}
5275
}
76+
77+
r.records[model].Config = *config
5378
}
5479

5580
func (r *OpenAIRecorder) RecordRequest(model string, req *http.Request, body []byte) string {
@@ -68,13 +93,16 @@ func (r *OpenAIRecorder) RecordRequest(model string, req *http.Request, body []b
6893
}
6994

7095
if r.records[model] == nil {
71-
r.records[model] = make([]*RequestResponsePair, 0, 10)
96+
r.records[model] = &ModelData{
97+
Records: make([]*RequestResponsePair, 0, 10),
98+
Config: inference.BackendConfiguration{},
99+
}
72100
}
73101

74-
r.records[model] = append(r.records[model], record)
102+
r.records[model].Records = append(r.records[model].Records, record)
75103

76-
if len(r.records[model]) > 10 {
77-
r.records[model] = r.records[model][1:]
104+
if len(r.records[model].Records) > 10 {
105+
r.records[model].Records = r.records[model].Records[1:]
78106
}
79107

80108
return recordID
@@ -105,8 +133,8 @@ func (r *OpenAIRecorder) RecordResponse(id, model string, rw http.ResponseWriter
105133
r.m.Lock()
106134
defer r.m.Unlock()
107135

108-
if modelRecords, exists := r.records[model]; exists {
109-
for _, record := range modelRecords {
136+
if modelData, exists := r.records[model]; exists {
137+
for _, record := range modelData.Records {
110138
if record.ID == id {
111139
record.Response = response
112140
record.StatusCode = statusCode
@@ -205,6 +233,7 @@ func (r *OpenAIRecorder) GetRecordsByModelHandler() http.HandlerFunc {
205233
"model": model,
206234
"records": records,
207235
"count": len(records),
236+
"config": r.records[model].Config,
208237
}); err != nil {
209238
http.Error(w, fmt.Sprintf("Failed to encode records for model '%s': %v", model, err),
210239
http.StatusInternalServerError)
@@ -218,9 +247,9 @@ func (r *OpenAIRecorder) GetRecordsByModel(model string) []*RequestResponsePair
218247
r.m.RLock()
219248
defer r.m.RUnlock()
220249

221-
if modelRecords, exists := r.records[model]; exists {
222-
result := make([]*RequestResponsePair, len(modelRecords))
223-
copy(result, modelRecords)
250+
if modelData, exists := r.records[model]; exists {
251+
result := make([]*RequestResponsePair, len(modelData.Records))
252+
copy(result, modelData.Records)
224253
return result
225254
}
226255

0 commit comments

Comments
 (0)