Skip to content

Commit 9f5bdfa

Browse files
authored
Merge pull request router-for-me#2531 from jamestut/openai-vertex-token-usage-fix
Fix missing `response.completed.usage` for late-usage OpenAI-compatible streams
2 parents 9eabdd0 + 65e9e89 commit 9f5bdfa

3 files changed

Lines changed: 279 additions & 140 deletions

File tree

internal/runtime/executor/openai_compat_executor.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,14 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
298298
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
299299
reporter.PublishFailure(ctx)
300300
out <- cliproxyexecutor.StreamChunk{Err: errScan}
301+
} else {
302+
// In case the upstream close the stream without a terminal [DONE] marker.
303+
// Feed a synthetic done marker through the translator so pending
304+
// response.completed events are still emitted exactly once.
305+
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("data: [DONE]"), &param)
306+
for i := range chunks {
307+
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
308+
}
301309
}
302310
// Ensure we record the request if no usage chunk was ever seen
303311
reporter.EnsurePublished(ctx)

internal/translator/openai/openai/responses/openai_openai-responses_response.go

Lines changed: 153 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@ type oaiToResponsesStateReasoning struct {
2020
OutputIndex int
2121
}
2222
type oaiToResponsesState struct {
23-
Seq int
24-
ResponseID string
25-
Created int64
26-
Started bool
27-
ReasoningID string
28-
ReasoningIndex int
23+
Seq int
24+
ResponseID string
25+
Created int64
26+
Started bool
27+
CompletionPending bool
28+
CompletedEmitted bool
29+
ReasoningID string
30+
ReasoningIndex int
2931
// aggregation buffers for response.output
3032
// Per-output message text buffers by index
3133
MsgTextBuf map[int]*strings.Builder
@@ -60,6 +62,141 @@ func emitRespEvent(event string, payload []byte) []byte {
6062
return translatorcommon.SSEEventData(event, payload)
6163
}
6264

65+
func buildResponsesCompletedEvent(st *oaiToResponsesState, requestRawJSON []byte, nextSeq func() int) []byte {
66+
completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
67+
completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq())
68+
completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID)
69+
completed, _ = sjson.SetBytes(completed, "response.created_at", st.Created)
70+
// Inject original request fields into response as per docs/response.completed.json
71+
if requestRawJSON != nil {
72+
req := gjson.ParseBytes(requestRawJSON)
73+
if v := req.Get("instructions"); v.Exists() {
74+
completed, _ = sjson.SetBytes(completed, "response.instructions", v.String())
75+
}
76+
if v := req.Get("max_output_tokens"); v.Exists() {
77+
completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int())
78+
}
79+
if v := req.Get("max_tool_calls"); v.Exists() {
80+
completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int())
81+
}
82+
if v := req.Get("model"); v.Exists() {
83+
completed, _ = sjson.SetBytes(completed, "response.model", v.String())
84+
}
85+
if v := req.Get("parallel_tool_calls"); v.Exists() {
86+
completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool())
87+
}
88+
if v := req.Get("previous_response_id"); v.Exists() {
89+
completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String())
90+
}
91+
if v := req.Get("prompt_cache_key"); v.Exists() {
92+
completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String())
93+
}
94+
if v := req.Get("reasoning"); v.Exists() {
95+
completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value())
96+
}
97+
if v := req.Get("safety_identifier"); v.Exists() {
98+
completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String())
99+
}
100+
if v := req.Get("service_tier"); v.Exists() {
101+
completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String())
102+
}
103+
if v := req.Get("store"); v.Exists() {
104+
completed, _ = sjson.SetBytes(completed, "response.store", v.Bool())
105+
}
106+
if v := req.Get("temperature"); v.Exists() {
107+
completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float())
108+
}
109+
if v := req.Get("text"); v.Exists() {
110+
completed, _ = sjson.SetBytes(completed, "response.text", v.Value())
111+
}
112+
if v := req.Get("tool_choice"); v.Exists() {
113+
completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value())
114+
}
115+
if v := req.Get("tools"); v.Exists() {
116+
completed, _ = sjson.SetBytes(completed, "response.tools", v.Value())
117+
}
118+
if v := req.Get("top_logprobs"); v.Exists() {
119+
completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int())
120+
}
121+
if v := req.Get("top_p"); v.Exists() {
122+
completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float())
123+
}
124+
if v := req.Get("truncation"); v.Exists() {
125+
completed, _ = sjson.SetBytes(completed, "response.truncation", v.String())
126+
}
127+
if v := req.Get("user"); v.Exists() {
128+
completed, _ = sjson.SetBytes(completed, "response.user", v.Value())
129+
}
130+
if v := req.Get("metadata"); v.Exists() {
131+
completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value())
132+
}
133+
}
134+
135+
outputsWrapper := []byte(`{"arr":[]}`)
136+
type completedOutputItem struct {
137+
index int
138+
raw []byte
139+
}
140+
outputItems := make([]completedOutputItem, 0, len(st.Reasonings)+len(st.MsgItemAdded)+len(st.FuncArgsBuf))
141+
if len(st.Reasonings) > 0 {
142+
for _, r := range st.Reasonings {
143+
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
144+
item, _ = sjson.SetBytes(item, "id", r.ReasoningID)
145+
item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData)
146+
outputItems = append(outputItems, completedOutputItem{index: r.OutputIndex, raw: item})
147+
}
148+
}
149+
if len(st.MsgItemAdded) > 0 {
150+
for i := range st.MsgItemAdded {
151+
txt := ""
152+
if b := st.MsgTextBuf[i]; b != nil {
153+
txt = b.String()
154+
}
155+
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
156+
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
157+
item, _ = sjson.SetBytes(item, "content.0.text", txt)
158+
outputItems = append(outputItems, completedOutputItem{index: st.MsgOutputIx[i], raw: item})
159+
}
160+
}
161+
if len(st.FuncArgsBuf) > 0 {
162+
for key := range st.FuncArgsBuf {
163+
args := ""
164+
if b := st.FuncArgsBuf[key]; b != nil {
165+
args = b.String()
166+
}
167+
callID := st.FuncCallIDs[key]
168+
name := st.FuncNames[key]
169+
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
170+
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
171+
item, _ = sjson.SetBytes(item, "arguments", args)
172+
item, _ = sjson.SetBytes(item, "call_id", callID)
173+
item, _ = sjson.SetBytes(item, "name", name)
174+
outputItems = append(outputItems, completedOutputItem{index: st.FuncOutputIx[key], raw: item})
175+
}
176+
}
177+
sort.Slice(outputItems, func(i, j int) bool { return outputItems[i].index < outputItems[j].index })
178+
for _, item := range outputItems {
179+
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item.raw)
180+
}
181+
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
182+
completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
183+
}
184+
if st.UsageSeen {
185+
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.PromptTokens)
186+
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", st.CachedTokens)
187+
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.CompletionTokens)
188+
if st.ReasoningTokens > 0 {
189+
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", st.ReasoningTokens)
190+
}
191+
total := st.TotalTokens
192+
if total == 0 {
193+
total = st.PromptTokens + st.CompletionTokens
194+
}
195+
completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total)
196+
}
197+
return emitRespEvent("response.completed", completed)
198+
}
199+
63200
// ConvertOpenAIChatCompletionsResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks
64201
// to OpenAI Responses SSE events (response.*).
65202
func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
@@ -90,6 +227,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
90227
return [][]byte{}
91228
}
92229
if bytes.Equal(rawJSON, []byte("[DONE]")) {
230+
if st.CompletionPending && !st.CompletedEmitted {
231+
st.CompletedEmitted = true
232+
return [][]byte{buildResponsesCompletedEvent(st, requestRawJSON, func() int { st.Seq++; return st.Seq })}
233+
}
93234
return [][]byte{}
94235
}
95236

@@ -165,6 +306,8 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
165306
st.TotalTokens = 0
166307
st.ReasoningTokens = 0
167308
st.UsageSeen = false
309+
st.CompletionPending = false
310+
st.CompletedEmitted = false
168311
// response.created
169312
created := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
170313
created, _ = sjson.SetBytes(created, "sequence_number", nextSeq())
@@ -374,8 +517,9 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
374517
}
375518
}
376519

377-
// finish_reason triggers finalization, including text done/content done/item done,
378-
// reasoning done/part.done, function args done/item done, and completed
520+
// finish_reason triggers item-level finalization. response.completed is
521+
// deferred until the terminal [DONE] marker so late usage-only chunks can
522+
// still populate response.usage.
379523
if fr := choice.Get("finish_reason"); fr.Exists() && fr.String() != "" {
380524
// Emit message done events for all indices that started a message
381525
if len(st.MsgItemAdded) > 0 {
@@ -464,138 +608,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
464608
st.FuncArgsDone[key] = true
465609
}
466610
}
467-
completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
468-
completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq())
469-
completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID)
470-
completed, _ = sjson.SetBytes(completed, "response.created_at", st.Created)
471-
// Inject original request fields into response as per docs/response.completed.json
472-
if requestRawJSON != nil {
473-
req := gjson.ParseBytes(requestRawJSON)
474-
if v := req.Get("instructions"); v.Exists() {
475-
completed, _ = sjson.SetBytes(completed, "response.instructions", v.String())
476-
}
477-
if v := req.Get("max_output_tokens"); v.Exists() {
478-
completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int())
479-
}
480-
if v := req.Get("max_tool_calls"); v.Exists() {
481-
completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int())
482-
}
483-
if v := req.Get("model"); v.Exists() {
484-
completed, _ = sjson.SetBytes(completed, "response.model", v.String())
485-
}
486-
if v := req.Get("parallel_tool_calls"); v.Exists() {
487-
completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool())
488-
}
489-
if v := req.Get("previous_response_id"); v.Exists() {
490-
completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String())
491-
}
492-
if v := req.Get("prompt_cache_key"); v.Exists() {
493-
completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String())
494-
}
495-
if v := req.Get("reasoning"); v.Exists() {
496-
completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value())
497-
}
498-
if v := req.Get("safety_identifier"); v.Exists() {
499-
completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String())
500-
}
501-
if v := req.Get("service_tier"); v.Exists() {
502-
completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String())
503-
}
504-
if v := req.Get("store"); v.Exists() {
505-
completed, _ = sjson.SetBytes(completed, "response.store", v.Bool())
506-
}
507-
if v := req.Get("temperature"); v.Exists() {
508-
completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float())
509-
}
510-
if v := req.Get("text"); v.Exists() {
511-
completed, _ = sjson.SetBytes(completed, "response.text", v.Value())
512-
}
513-
if v := req.Get("tool_choice"); v.Exists() {
514-
completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value())
515-
}
516-
if v := req.Get("tools"); v.Exists() {
517-
completed, _ = sjson.SetBytes(completed, "response.tools", v.Value())
518-
}
519-
if v := req.Get("top_logprobs"); v.Exists() {
520-
completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int())
521-
}
522-
if v := req.Get("top_p"); v.Exists() {
523-
completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float())
524-
}
525-
if v := req.Get("truncation"); v.Exists() {
526-
completed, _ = sjson.SetBytes(completed, "response.truncation", v.String())
527-
}
528-
if v := req.Get("user"); v.Exists() {
529-
completed, _ = sjson.SetBytes(completed, "response.user", v.Value())
530-
}
531-
if v := req.Get("metadata"); v.Exists() {
532-
completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value())
533-
}
534-
}
535-
// Build response.output using aggregated buffers
536-
outputsWrapper := []byte(`{"arr":[]}`)
537-
type completedOutputItem struct {
538-
index int
539-
raw []byte
540-
}
541-
outputItems := make([]completedOutputItem, 0, len(st.Reasonings)+len(st.MsgItemAdded)+len(st.FuncArgsBuf))
542-
if len(st.Reasonings) > 0 {
543-
for _, r := range st.Reasonings {
544-
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
545-
item, _ = sjson.SetBytes(item, "id", r.ReasoningID)
546-
item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData)
547-
outputItems = append(outputItems, completedOutputItem{index: r.OutputIndex, raw: item})
548-
}
549-
}
550-
if len(st.MsgItemAdded) > 0 {
551-
for i := range st.MsgItemAdded {
552-
txt := ""
553-
if b := st.MsgTextBuf[i]; b != nil {
554-
txt = b.String()
555-
}
556-
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
557-
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
558-
item, _ = sjson.SetBytes(item, "content.0.text", txt)
559-
outputItems = append(outputItems, completedOutputItem{index: st.MsgOutputIx[i], raw: item})
560-
}
561-
}
562-
if len(st.FuncArgsBuf) > 0 {
563-
for key := range st.FuncArgsBuf {
564-
args := ""
565-
if b := st.FuncArgsBuf[key]; b != nil {
566-
args = b.String()
567-
}
568-
callID := st.FuncCallIDs[key]
569-
name := st.FuncNames[key]
570-
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
571-
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
572-
item, _ = sjson.SetBytes(item, "arguments", args)
573-
item, _ = sjson.SetBytes(item, "call_id", callID)
574-
item, _ = sjson.SetBytes(item, "name", name)
575-
outputItems = append(outputItems, completedOutputItem{index: st.FuncOutputIx[key], raw: item})
576-
}
577-
}
578-
sort.Slice(outputItems, func(i, j int) bool { return outputItems[i].index < outputItems[j].index })
579-
for _, item := range outputItems {
580-
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item.raw)
581-
}
582-
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
583-
completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
584-
}
585-
if st.UsageSeen {
586-
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.PromptTokens)
587-
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", st.CachedTokens)
588-
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.CompletionTokens)
589-
if st.ReasoningTokens > 0 {
590-
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", st.ReasoningTokens)
591-
}
592-
total := st.TotalTokens
593-
if total == 0 {
594-
total = st.PromptTokens + st.CompletionTokens
595-
}
596-
completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total)
597-
}
598-
out = append(out, emitRespEvent("response.completed", completed))
611+
st.CompletionPending = true
599612
}
600613

601614
return true

0 commit comments

Comments
 (0)