Skip to content

Commit bc8cbe9

Browse files
fix(proxy): tighten streaming usage tracking and tool-call continuity (port upstream 4694c54)
1 parent 0ed96ae commit bc8cbe9

8 files changed

Lines changed: 352 additions & 23 deletions

File tree

proxy/handler.go

Lines changed: 70 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,16 @@ func (h *Handler) Responses(c *gin.Context) {
660660
}
661661
}()
662662

663+
// 上游 ctx 生命周期:每次 attempt 开始前用新的 drainable ctx 替换,
664+
// defer 兜底确保函数退出时上游被释放。
665+
// 目的:客户端断连后仍给上游 upstreamDrainTimeout 时间捞 response.completed 的 usage。
666+
var lastUpstreamCancel context.CancelFunc
667+
defer func() {
668+
if lastUpstreamCancel != nil {
669+
lastUpstreamCancel()
670+
}
671+
}()
672+
663673
for attempt := 0; attempt <= maxRetries; attempt++ {
664674
account, stickyProxyURL := h.nextAccountForSessionWithPreference(sessionID, excludeAccounts, preferPlan)
665675
if account == nil {
@@ -698,7 +708,15 @@ func (h *Handler) Responses(c *gin.Context) {
698708
// 透传下游请求头用于指纹学习
699709
downstreamHeaders := c.Request.Header.Clone()
700710

701-
resp, reqErr := ExecuteRequest(c.Request.Context(), account, codexBody, sessionID, proxyURL, apiKey, deviceCfg, downstreamHeaders, useWebsocket)
711+
// 上游使用与客户端解耦的 context:客户端中途断开时仍能继续读完
712+
// response.completed 拿到 usage(流式计费的关键)。
713+
// 重试前先 cancel 上一轮的上游 ctx。
714+
if lastUpstreamCancel != nil {
715+
lastUpstreamCancel()
716+
}
717+
upstreamCtx, upstreamCancel := newDrainableUpstreamContext(c.Request.Context(), upstreamDrainTimeout)
718+
lastUpstreamCancel = upstreamCancel
719+
resp, reqErr := ExecuteRequest(upstreamCtx, account, codexBody, sessionID, proxyURL, apiKey, deviceCfg, downstreamHeaders, useWebsocket)
702720
durationMs := int(time.Since(start).Milliseconds())
703721

704722
if reqErr != nil {
@@ -779,6 +797,9 @@ func (h *Handler) Responses(c *gin.Context) {
779797
var lastFailedErrMsg string // 上游 response.failed 的 error.message(debug 用,不论是否 capacity)
780798

781799
if isStream {
800+
// clientGone:客户端写失败后置位,后续事件不再写客户端,
801+
// 但继续读上游直到 response.completed/failed,以拿到准确 usage。
802+
clientGone := false
782803
// 流式透传 + TTFT 跟踪(headers 已在 SetupKeepalive 里设置)
783804
readErr = ReadSSEStream(resp.Body, func(data []byte) bool {
784805
parsed := gjson.ParseBytes(data)
@@ -794,8 +815,8 @@ func (h *Handler) Responses(c *gin.Context) {
794815
}
795816
}
796817

797-
// TTFT: 记录第一个 output_text.delta 事件的时间
798-
if !ttftRecorded && eventType == "response.output_text.delta" {
818+
// TTFT: 黑名单策略 —— 排除控制/终止事件,其余均视为首字(覆盖纯工具调用/图像/推理流)
819+
if !ttftRecorded && isFirstTokenEvent(eventType) {
799820
firstTokenMs = int(time.Since(start).Milliseconds())
800821
ttftRecorded = true
801822
}
@@ -821,11 +842,15 @@ func (h *Handler) Responses(c *gin.Context) {
821842

822843
// 画图场景下将 SSE 事件里的 response.model 改为 gpt-5.4
823844
dataToWrite := rewriteResponseModelIfDrawing(data, virtualHit, "response.model")
824-
if err := sseW.WriteEvent(dataToWrite); err != nil {
825-
writeErr = err
826-
return false
845+
if !clientGone {
846+
if err := sseW.WriteEvent(dataToWrite); err != nil {
847+
writeErr = err
848+
clientGone = true
849+
} else {
850+
wroteAnyBody = true
851+
}
827852
}
828-
wroteAnyBody = true
853+
// 客户端断开后仍继续读上游直到 terminal 事件,确保拿到 usage
829854
return eventType != "response.completed" && eventType != "response.failed"
830855
})
831856
} else {
@@ -838,7 +863,7 @@ func (h *Handler) Responses(c *gin.Context) {
838863
readErr = ReadSSEStream(resp.Body, func(data []byte) bool {
839864
parsed := gjson.ParseBytes(data)
840865
eventType := parsed.Get("type").String()
841-
if !ttftRecorded && eventType == "response.output_text.delta" {
866+
if !ttftRecorded && isFirstTokenEvent(eventType) {
842867
firstTokenMs = int(time.Since(start).Milliseconds())
843868
ttftRecorded = true
844869
}
@@ -1303,6 +1328,16 @@ func (h *Handler) ChatCompletions(c *gin.Context) {
13031328
}
13041329
}()
13051330

1331+
// 上游 ctx 生命周期:每次 attempt 开始前用新的 drainable ctx 替换,
1332+
// defer 兜底确保函数退出时上游被释放。
1333+
// 目的:客户端断连后仍给上游 upstreamDrainTimeout 时间捞 response.completed 的 usage。
1334+
var lastUpstreamCancel context.CancelFunc
1335+
defer func() {
1336+
if lastUpstreamCancel != nil {
1337+
lastUpstreamCancel()
1338+
}
1339+
}()
1340+
13061341
for attempt := 0; attempt <= maxRetries; attempt++ {
13071342
account, stickyProxyURL := h.nextAccountForSessionWithPreference(sessionID, excludeAccounts, preferPlan)
13081343
if account == nil {
@@ -1341,7 +1376,15 @@ func (h *Handler) ChatCompletions(c *gin.Context) {
13411376
// 透传下游请求头用于指纹学习
13421377
downstreamHeaders := c.Request.Header.Clone()
13431378

1344-
resp, reqErr := ExecuteRequest(c.Request.Context(), account, codexBody, sessionID, proxyURL, apiKey, deviceCfg, downstreamHeaders, useWebsocket)
1379+
// 上游使用与客户端解耦的 context:客户端中途断开时仍能继续读完
1380+
// response.completed 拿到 usage(流式计费的关键)。
1381+
// 重试前先 cancel 上一轮的上游 ctx。
1382+
if lastUpstreamCancel != nil {
1383+
lastUpstreamCancel()
1384+
}
1385+
upstreamCtx, upstreamCancel := newDrainableUpstreamContext(c.Request.Context(), upstreamDrainTimeout)
1386+
lastUpstreamCancel = upstreamCancel
1387+
resp, reqErr := ExecuteRequest(upstreamCtx, account, codexBody, sessionID, proxyURL, apiKey, deviceCfg, downstreamHeaders, useWebsocket)
13451388
durationMs := int(time.Since(start).Milliseconds())
13461389

13471390
if reqErr != nil {
@@ -1426,6 +1469,9 @@ func (h *Handler) ChatCompletions(c *gin.Context) {
14261469
if isStream {
14271470
streamTranslator := NewStreamTranslator(chunkID, responseModel, created)
14281471

1472+
// clientGone:客户端写失败后置位,后续事件不再写客户端,
1473+
// 但继续读上游直到 response.completed/failed,以拿到准确 usage。
1474+
clientGone := false
14291475
readErr = ReadSSEStream(resp.Body, func(data []byte) bool {
14301476
parsed := gjson.ParseBytes(data)
14311477
eventType := parsed.Get("type").String()
@@ -1444,7 +1490,7 @@ func (h *Handler) ChatCompletions(c *gin.Context) {
14441490

14451491
chunk, done := streamTranslator.Translate(data)
14461492

1447-
if !ttftRecorded && strings.Contains(eventType, ".delta") {
1493+
if !ttftRecorded && isFirstTokenEvent(eventType) {
14481494
firstTokenMs = int(time.Since(start).Milliseconds())
14491495
ttftRecorded = true
14501496
}
@@ -1463,19 +1509,27 @@ func (h *Handler) ChatCompletions(c *gin.Context) {
14631509
gotTerminal = true
14641510
}
14651511

1466-
if chunk != nil {
1512+
if !clientGone && chunk != nil {
14671513
if err := sseW.WriteEvent(chunk); err != nil {
14681514
writeErr = err
1469-
return false
1515+
clientGone = true
1516+
} else {
1517+
wroteAnyBody = true
14701518
}
1471-
wroteAnyBody = true
14721519
}
1473-
if done {
1520+
if !clientGone && done {
14741521
if err := sseW.WriteRaw("data: [DONE]\n\n"); err != nil {
14751522
writeErr = err
1523+
clientGone = true
1524+
} else {
1525+
wroteAnyBody = true
1526+
}
1527+
if !clientGone {
14761528
return false
14771529
}
1478-
wroteAnyBody = true
1530+
}
1531+
// 客户端断开后,要等到 terminal 事件才退出,确保拿到 usage。
1532+
if gotTerminal {
14791533
return false
14801534
}
14811535
return true
@@ -1487,7 +1541,7 @@ func (h *Handler) ChatCompletions(c *gin.Context) {
14871541
readErr = ReadSSEStream(resp.Body, func(data []byte) bool {
14881542
parsed := gjson.ParseBytes(data)
14891543
eventType := parsed.Get("type").String()
1490-
if !ttftRecorded && strings.Contains(eventType, ".delta") {
1544+
if !ttftRecorded && isFirstTokenEvent(eventType) {
14911545
firstTokenMs = int(time.Since(start).Milliseconds())
14921546
ttftRecorded = true
14931547
}

proxy/handler_anthropic.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -272,10 +272,8 @@ func (h *Handler) Messages(c *gin.Context) {
272272
}
273273
}
274274

275-
// TTFT 跟踪
276-
if !ttftRecorded && (eventType == "response.output_text.delta" ||
277-
eventType == "response.reasoning_summary_text.delta" ||
278-
eventType == "response.reasoning_text.delta") {
275+
// TTFT 跟踪:黑名单策略统一语义,覆盖纯工具调用/图像/推理流
276+
if !ttftRecorded && isFirstTokenEvent(eventType) {
279277
firstTokenMs = int(time.Since(start).Milliseconds())
280278
ttftRecorded = true
281279
}
@@ -331,7 +329,7 @@ func (h *Handler) Messages(c *gin.Context) {
331329
parsed := gjson.ParseBytes(data)
332330
eventType := parsed.Get("type").String()
333331

334-
if !ttftRecorded && strings.Contains(eventType, ".delta") {
332+
if !ttftRecorded && isFirstTokenEvent(eventType) {
335333
firstTokenMs = int(time.Since(start).Milliseconds())
336334
ttftRecorded = true
337335
}

proxy/response_cache.go

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,30 @@ func expandPreviousResponse(codexBody []byte) ([]byte, string) {
121121
return codexBody, ""
122122
}
123123

124+
currentInput := gjson.GetBytes(codexBody, "input")
125+
126+
// 客户端已经自带 function_call 等续链项时,跳过注入。
127+
// 缓存里只会存 function_call 项(见 cacheCompletedResponse),
128+
// 再注入会让同一 call_id 出现两次,上游会以 "duplicate call_id" 等 400 拒绝。
129+
// 仍返回 prevID,让 cacheCompletedResponse 能把这一轮响应链入缓存。
130+
if currentInput.IsArray() && inputHasToolCallContext(currentInput) {
131+
log.Printf("input 已自带工具续链项,跳过 previous_response_id=%s 的历史注入", prevID)
132+
return codexBody, prevID
133+
}
134+
124135
cached := getResponseCache(prevID)
125136
if cached == nil {
126-
// 缓存未命中(首次请求 / 过期 / 其他实例),无法展开,按原样继续
137+
// 缓存未命中(首次请求 / 过期 / 其他实例),无法展开,按原样继续。
138+
// 若 input 仅含 function_call_output 又拿不到对应的 function_call,
139+
// 上游通常会返回 "No tool call found for function call output" 400,
140+
// 这里打日志便于诊断(不阻断,让上游错误透传给客户端)。
141+
if currentInput.IsArray() && inputHasFunctionCallOutput(currentInput) {
142+
log.Printf("缓存未命中且 input 含 function_call_output,previous_response_id=%s,上游可能返回 400", prevID)
143+
}
127144
return codexBody, prevID
128145
}
129146

130147
// 构建新 input: 缓存的历史 items + 当前 input items
131-
currentInput := gjson.GetBytes(codexBody, "input")
132148
var merged []json.RawMessage
133149
merged = append(merged, cached...)
134150
if currentInput.IsArray() {
@@ -149,6 +165,36 @@ func expandPreviousResponse(codexBody []byte) ([]byte, string) {
149165
return codexBody, prevID
150166
}
151167

168+
// inputHasToolCallContext 判断 input 数组里是否已包含 function_call 续链项。
169+
// 这类项一旦同时出现在缓存里会造成 call_id 冲突(上游 400 duplicate call_id)。
170+
// 与本 fork 的 cacheCompletedResponse 对称:缓存仅存 function_call,此处也只查 function_call。
171+
func inputHasToolCallContext(input gjson.Result) bool {
172+
found := false
173+
input.ForEach(func(_, v gjson.Result) bool {
174+
if v.Get("type").String() == "function_call" {
175+
found = true
176+
return false
177+
}
178+
return true
179+
})
180+
return found
181+
}
182+
183+
// inputHasFunctionCallOutput 判断 input 数组里是否含 *_output 项(缺少配对的 function_call 时上游会 400)。
184+
func inputHasFunctionCallOutput(input gjson.Result) bool {
185+
found := false
186+
input.ForEach(func(_, v gjson.Result) bool {
187+
switch v.Get("type").String() {
188+
case "function_call_output", "tool_call_output", "local_shell_call_output",
189+
"tool_search_call_output", "custom_tool_call_output", "mcp_tool_call_output":
190+
found = true
191+
return false
192+
}
193+
return true
194+
})
195+
return found
196+
}
197+
152198
// cacheCompletedResponse 从 response.completed 事件中提取 response.id 和 response.output,
153199
// 与当前请求的 expanded input 合并后存入缓存。
154200
// 仅在响应包含 function_call 时才缓存,避免为普通对话浪费内存。

proxy/response_cache_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package proxy
2+
3+
import (
4+
"testing"
5+
6+
"github.com/tidwall/gjson"
7+
)
8+
9+
// resetResponseCacheForTest 清空响应上下文缓存,仅供测试使用。
10+
func resetResponseCacheForTest() {
11+
respCache.mu.Lock()
12+
defer respCache.mu.Unlock()
13+
for k := range respCache.store {
14+
delete(respCache.store, k)
15+
}
16+
}
17+
18+
func TestExpandPreviousResponseSkipsInjectionWhenInputHasFunctionCall(t *testing.T) {
19+
resetResponseCacheForTest()
20+
21+
cacheCompletedResponse(
22+
[]byte(`[{"type":"message","role":"user","content":"call tool"}]`),
23+
[]byte(`{"type":"response.completed","response":{"id":"resp_dup","output":[{"type":"function_call","call_id":"call_abc","name":"get_weather","arguments":"{}"}]}}`),
24+
)
25+
26+
// 客户端续链时同时自带 function_call 和 function_call_output,再注入缓存里的 function_call 会让 call_abc 重复。
27+
body := []byte(`{"model":"gpt-5.4","previous_response_id":"resp_dup","input":[` +
28+
`{"type":"function_call","call_id":"call_abc","name":"get_weather","arguments":"{}"},` +
29+
`{"type":"function_call_output","call_id":"call_abc","output":"sunny"}` +
30+
`]}`)
31+
got, prevID := expandPreviousResponse(body)
32+
33+
if prevID != "resp_dup" {
34+
t.Fatalf("prevID = %q, want resp_dup", prevID)
35+
}
36+
input := gjson.GetBytes(got, "input").Array()
37+
if len(input) != 2 {
38+
t.Fatalf("input count = %d, want 2 (no injection); body=%s", len(input), got)
39+
}
40+
if typ := input[0].Get("type").String(); typ != "function_call" {
41+
t.Fatalf("input[0].type = %q, want function_call", typ)
42+
}
43+
if callID := input[0].Get("call_id").String(); callID != "call_abc" {
44+
t.Fatalf("input[0].call_id = %q, want call_abc", callID)
45+
}
46+
}
47+
48+
func TestExpandPreviousResponseLeavesBodyUntouchedOnCacheMiss(t *testing.T) {
49+
resetResponseCacheForTest()
50+
51+
body := []byte(`{"model":"gpt-5.4","previous_response_id":"resp_missing","input":[` +
52+
`{"type":"function_call_output","call_id":"call_missing","output":"x"}` +
53+
`]}`)
54+
got, prevID := expandPreviousResponse(body)
55+
56+
if prevID != "resp_missing" {
57+
t.Fatalf("prevID = %q, want resp_missing (returned for downstream cache linkage)", prevID)
58+
}
59+
if string(got) != string(body) {
60+
t.Fatalf("body mutated on cache miss; got=%s want=%s", got, body)
61+
}
62+
}

proxy/ttft.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package proxy
2+
3+
// isFirstTokenEvent reports whether a codex SSE event type counts as the
// first piece of produced content, for time-to-first-token (TTFT) timing.
//
// TTFT means "the moment the upstream starts producing content for the
// client" and is deliberately not limited to text. Many requests never emit a
// text delta:
//   - pure tool calls: only function_call_arguments.delta / output_item.added
//   - image generation: image_generation_call.partial_image
//   - reasoning-only / reasoning models: reasoning_text.delta arrives first
//   - streams cut off before the first text delta
//
// A deny-list is therefore used: lifecycle/control events (created, queued,
// in_progress) and stream-termination events (completed, failed, incomplete,
// error) are excluded, and every other non-empty event type counts as the
// first token. The terminal and queued events carry no model output, so
// treating them as a first token would record a misleading TTFT for streams
// that end or error out before producing anything.
func isFirstTokenEvent(eventType string) bool {
	switch eventType {
	case "":
		// Events without a type cannot be classified as content.
		return false
	case "response.created",
		"response.queued",
		"response.in_progress":
		// Lifecycle/control events emitted before any output exists.
		return false
	case "response.completed",
		"response.failed",
		"response.incomplete",
		"error":
		// Stream-termination events; they end the stream rather than start it.
		return false
	}
	return true
}

0 commit comments

Comments
 (0)