99 "sync"
1010 "time"
1111
12+ "github.com/docker/model-runner/pkg/inference"
1213 "github.com/docker/model-runner/pkg/logging"
1314)
1415
@@ -39,17 +40,41 @@ type RequestResponsePair struct {
3940 StatusCode int `json:"status_code"`
4041}
4142
43+ type ModelData struct {
44+ Config inference.BackendConfiguration `json:"config"`
45+ Records []* RequestResponsePair `json:"records"`
46+ }
47+
4248type OpenAIRecorder struct {
4349 log logging.Logger
44- records map [string ][] * RequestResponsePair
50+ records map [string ]* ModelData
4551 m sync.RWMutex
4652}
4753
4854func NewOpenAIRecorder (log logging.Logger ) * OpenAIRecorder {
4955 return & OpenAIRecorder {
5056 log : log ,
51- records : make (map [string ][]* RequestResponsePair ),
57+ records : make (map [string ]* ModelData ),
58+ }
59+ }
60+
61+ func (r * OpenAIRecorder ) SetConfigForModel (model string , config * inference.BackendConfiguration ) {
62+ if config == nil {
63+ r .log .Warnf ("SetConfigForModel called with nil config for model %s" , model )
64+ return
65+ }
66+
67+ r .m .Lock ()
68+ defer r .m .Unlock ()
69+
70+ if r .records [model ] == nil {
71+ r .records [model ] = & ModelData {
72+ Records : make ([]* RequestResponsePair , 0 , 10 ),
73+ Config : inference.BackendConfiguration {},
74+ }
5275 }
76+
77+ r .records [model ].Config = * config
5378}
5479
5580func (r * OpenAIRecorder ) RecordRequest (model string , req * http.Request , body []byte ) string {
@@ -68,13 +93,16 @@ func (r *OpenAIRecorder) RecordRequest(model string, req *http.Request, body []b
6893 }
6994
7095 if r .records [model ] == nil {
71- r .records [model ] = make ([]* RequestResponsePair , 0 , 10 )
96+ r .records [model ] = & ModelData {
97+ Records : make ([]* RequestResponsePair , 0 , 10 ),
98+ Config : inference.BackendConfiguration {},
99+ }
72100 }
73101
74- r .records [model ] = append (r .records [model ], record )
102+ r .records [model ]. Records = append (r .records [model ]. Records , record )
75103
76- if len (r .records [model ]) > 10 {
77- r .records [model ] = r .records [model ][1 :]
104+ if len (r .records [model ]. Records ) > 10 {
105+ r .records [model ]. Records = r .records [model ]. Records [1 :]
78106 }
79107
80108 return recordID
@@ -105,8 +133,8 @@ func (r *OpenAIRecorder) RecordResponse(id, model string, rw http.ResponseWriter
105133 r .m .Lock ()
106134 defer r .m .Unlock ()
107135
108- if modelRecords , exists := r .records [model ]; exists {
109- for _ , record := range modelRecords {
136+ if modelData , exists := r .records [model ]; exists {
137+ for _ , record := range modelData . Records {
110138 if record .ID == id {
111139 record .Response = response
112140 record .StatusCode = statusCode
@@ -205,6 +233,7 @@ func (r *OpenAIRecorder) GetRecordsByModelHandler() http.HandlerFunc {
205233 "model" : model ,
206234 "records" : records ,
207235 "count" : len (records ),
236+ "config" : r .records [model ].Config ,
208237 }); err != nil {
209238 http .Error (w , fmt .Sprintf ("Failed to encode records for model '%s': %v" , model , err ),
210239 http .StatusInternalServerError )
@@ -218,9 +247,9 @@ func (r *OpenAIRecorder) GetRecordsByModel(model string) []*RequestResponsePair
218247 r .m .RLock ()
219248 defer r .m .RUnlock ()
220249
221- if modelRecords , exists := r .records [model ]; exists {
222- result := make ([]* RequestResponsePair , len (modelRecords ))
223- copy (result , modelRecords )
250+ if modelData , exists := r .records [model ]; exists {
251+ result := make ([]* RequestResponsePair , len (modelData . Records ))
252+ copy (result , modelData . Records )
224253 return result
225254 }
226255
0 commit comments