Skip to content

Commit be8f3e6

Browse files
committed
Add OpenAIRecorder
Signed-off-by: Dorin Geman <dorin.geman@docker.com>
1 parent b8561e1 commit be8f3e6

2 files changed

Lines changed: 240 additions & 0 deletions

File tree

pkg/inference/scheduling/scheduler.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ type Scheduler struct {
4040
router *http.ServeMux
4141
// tracker is the metrics tracker.
4242
tracker *metrics.Tracker
43+
// openAIRecorder is used to record OpenAI API inference requests and responses.
44+
openAIRecorder *metrics.OpenAIRecorder
4345
// lock is used to synchronize access to the scheduler's router.
4446
lock sync.Mutex
4547
}
@@ -64,6 +66,7 @@ func NewScheduler(
6466
loader: newLoader(log, backends, modelManager),
6567
router: http.NewServeMux(),
6668
tracker: tracker,
69+
openAIRecorder: metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder")),
6770
}
6871

6972
// Register routes.
@@ -115,6 +118,7 @@ func (s *Scheduler) routeHandlers(allowedOrigins []string) map[string]http.Handl
115118
m["POST "+inference.InferencePrefix+"/unload"] = s.Unload
116119
m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = s.Configure
117120
m["POST "+inference.InferencePrefix+"/_configure"] = s.Configure
121+
m["GET "+inference.InferencePrefix+"/requests"] = s.openAIRecorder.GetRecordsByModelHandler()
118122
return m
119123
}
120124

@@ -232,6 +236,14 @@ func (s *Scheduler) handleOpenAIInference(w http.ResponseWriter, r *http.Request
232236
s.tracker.TrackModel(model)
233237
}
234238

239+
// Record the request in the OpenAI recorder.
240+
recordID := s.openAIRecorder.RecordRequest(request.Model, r, body)
241+
w = s.openAIRecorder.NewResponseRecorder(w)
242+
defer func() {
243+
// Record the response in the OpenAI recorder.
244+
s.openAIRecorder.RecordResponse(recordID, request.Model, w)
245+
}()
246+
235247
// Request a runner to execute the request and defer its release.
236248
runner, err := s.loader.load(r.Context(), backend.Name(), request.Model, backendMode)
237249
if err != nil {

pkg/metrics/openai_recorder.go

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
package metrics
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"fmt"
7+
"net/http"
8+
"strings"
9+
"sync"
10+
"time"
11+
12+
"github.com/docker/model-runner/pkg/logging"
13+
)
14+
15+
// responseRecorder wraps an http.ResponseWriter and keeps a copy of
// everything sent through it: the body bytes and the status code.
type responseRecorder struct {
	http.ResponseWriter
	// body accumulates a copy of every payload passed to Write.
	body *bytes.Buffer
	// statusCode is the last status passed to WriteHeader. It keeps whatever
	// value the recorder was constructed with if WriteHeader is never called.
	statusCode int
}

// Write appends b to the capture buffer and then forwards it to the wrapped
// writer, returning the wrapped writer's result. The copy is taken before
// forwarding, so it is kept even when the underlying write fails.
func (rr *responseRecorder) Write(b []byte) (int, error) {
	rr.body.Write(b)
	return rr.ResponseWriter.Write(b)
}

// WriteHeader remembers statusCode and forwards it to the wrapped writer.
func (rr *responseRecorder) WriteHeader(statusCode int) {
	rr.statusCode = statusCode
	rr.ResponseWriter.WriteHeader(statusCode)
}

// Flush implements http.Flusher by delegating to the wrapped writer when it
// supports flushing. Embedding the http.ResponseWriter interface promotes
// only Header, Write and WriteHeader, so without this method a handler that
// type-asserts its writer to http.Flusher (as streaming handlers do) would
// lose the ability to flush once the writer is wrapped.
func (rr *responseRecorder) Flush() {
	if flusher, ok := rr.ResponseWriter.(http.Flusher); ok {
		flusher.Flush()
	}
}

// Unwrap returns the wrapped writer so that http.ResponseController can reach
// optional interfaces implemented by it.
func (rr *responseRecorder) Unwrap() http.ResponseWriter {
	return rr.ResponseWriter
}
30+
31+
type RequestResponsePair struct {
32+
ID string `json:"id"`
33+
Model string `json:"model"`
34+
Method string `json:"method"`
35+
URL string `json:"url"`
36+
Request string `json:"request"`
37+
Response string `json:"response"`
38+
Timestamp time.Time `json:"timestamp"`
39+
StatusCode int `json:"status_code"`
40+
}
41+
42+
type OpenAIRecorder struct {
43+
log logging.Logger
44+
records map[string][]*RequestResponsePair
45+
m sync.RWMutex
46+
}
47+
48+
func NewOpenAIRecorder(log logging.Logger) *OpenAIRecorder {
49+
return &OpenAIRecorder{
50+
log: log,
51+
records: make(map[string][]*RequestResponsePair),
52+
}
53+
}
54+
55+
func (r *OpenAIRecorder) RecordRequest(model string, req *http.Request, body []byte) string {
56+
r.m.Lock()
57+
defer r.m.Unlock()
58+
59+
recordID := fmt.Sprintf("%s_%d", model, time.Now().UnixNano())
60+
61+
record := &RequestResponsePair{
62+
ID: recordID,
63+
Model: model,
64+
Method: req.Method,
65+
URL: req.URL.Path,
66+
Request: string(body),
67+
Timestamp: time.Now(),
68+
}
69+
70+
if r.records[model] == nil {
71+
r.records[model] = make([]*RequestResponsePair, 0, 10)
72+
}
73+
74+
r.records[model] = append(r.records[model], record)
75+
76+
if len(r.records[model]) > 10 {
77+
r.records[model] = r.records[model][1:]
78+
}
79+
80+
return recordID
81+
}
82+
83+
func (r *OpenAIRecorder) NewResponseRecorder(w http.ResponseWriter) http.ResponseWriter {
84+
rc := &responseRecorder{
85+
ResponseWriter: w,
86+
body: &bytes.Buffer{},
87+
statusCode: http.StatusOK,
88+
}
89+
return rc
90+
}
91+
92+
func (r *OpenAIRecorder) RecordResponse(id, model string, rw http.ResponseWriter) {
93+
rr := rw.(*responseRecorder)
94+
95+
responseBody := rr.body.String()
96+
statusCode := rr.statusCode
97+
98+
var response string
99+
if strings.Contains(responseBody, "data: ") {
100+
response = r.convertStreamingResponse(responseBody)
101+
} else {
102+
response = responseBody
103+
}
104+
105+
r.m.Lock()
106+
defer r.m.Unlock()
107+
108+
if modelRecords, exists := r.records[model]; exists {
109+
for _, record := range modelRecords {
110+
if record.ID == id {
111+
record.Response = response
112+
record.StatusCode = statusCode
113+
return
114+
}
115+
}
116+
r.log.Errorf("Matching request (id=%s) not found for model %s - %d\n%s", id, model, statusCode, response)
117+
} else {
118+
r.log.Errorf("Model %s not found in records - %d\n%s", model, statusCode, response)
119+
}
120+
}
121+
122+
func (r *OpenAIRecorder) convertStreamingResponse(streamingBody string) string {
123+
lines := strings.Split(streamingBody, "\n")
124+
var contentBuilder strings.Builder
125+
var lastChunk map[string]interface{}
126+
127+
for _, line := range lines {
128+
if strings.HasPrefix(line, "data: ") {
129+
data := strings.TrimPrefix(line, "data: ")
130+
if data == "[DONE]" {
131+
break
132+
}
133+
134+
var chunk map[string]interface{}
135+
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
136+
continue
137+
}
138+
139+
lastChunk = chunk
140+
141+
if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
142+
if choice, ok := choices[0].(map[string]interface{}); ok {
143+
if delta, ok := choice["delta"].(map[string]interface{}); ok {
144+
if content, ok := delta["content"].(string); ok {
145+
contentBuilder.WriteString(content)
146+
}
147+
}
148+
}
149+
}
150+
}
151+
}
152+
153+
if lastChunk == nil {
154+
return streamingBody
155+
}
156+
157+
finalResponse := make(map[string]interface{})
158+
159+
for key, value := range lastChunk {
160+
finalResponse[key] = value
161+
}
162+
163+
if choices, ok := finalResponse["choices"].([]interface{}); ok && len(choices) > 0 {
164+
if choice, ok := choices[0].(map[string]interface{}); ok {
165+
choice["message"] = map[string]interface{}{
166+
"role": "assistant",
167+
"content": contentBuilder.String(),
168+
}
169+
delete(choice, "delta")
170+
171+
if _, ok := choice["finish_reason"]; !ok {
172+
choice["finish_reason"] = "stop"
173+
}
174+
}
175+
}
176+
177+
finalResponse["object"] = "chat.completion"
178+
179+
jsonResult, err := json.Marshal(finalResponse)
180+
if err != nil {
181+
return streamingBody
182+
}
183+
184+
return string(jsonResult)
185+
}
186+
187+
func (r *OpenAIRecorder) GetRecordsByModelHandler() http.HandlerFunc {
188+
return func(w http.ResponseWriter, req *http.Request) {
189+
w.Header().Set("Content-Type", "application/json")
190+
191+
model := req.URL.Query().Get("model")
192+
193+
if model == "" {
194+
http.Error(w, "A 'model' query parameter is required", http.StatusBadRequest)
195+
} else {
196+
// Retrieve records for the specified model.
197+
records := r.GetRecordsByModel(model)
198+
if records == nil {
199+
// No records found for the specified model.
200+
http.Error(w, fmt.Sprintf("No records found for model '%s'", model), http.StatusNotFound)
201+
return
202+
}
203+
204+
if err := json.NewEncoder(w).Encode(map[string]interface{}{
205+
"model": model,
206+
"records": records,
207+
"count": len(records),
208+
}); err != nil {
209+
http.Error(w, fmt.Sprintf("Failed to encode records for model '%s': %v", model, err),
210+
http.StatusInternalServerError)
211+
return
212+
}
213+
}
214+
}
215+
}
216+
217+
func (r *OpenAIRecorder) GetRecordsByModel(model string) []*RequestResponsePair {
218+
r.m.RLock()
219+
defer r.m.RUnlock()
220+
221+
if modelRecords, exists := r.records[model]; exists {
222+
result := make([]*RequestResponsePair, len(modelRecords))
223+
copy(result, modelRecords)
224+
return result
225+
}
226+
227+
return nil
228+
}

0 commit comments

Comments
 (0)