Skip to content

Commit 82ebe24

Browse files
authored
Merge pull request #2266 from DragonFSKY/fix/ws-compact-tool-output-mismatch
fix(websocket): skip stale state merge after client-side compact
2 parents 2753d9f + 4ca00f7 commit 82ebe24

2 files changed

Lines changed: 372 additions & 58 deletions

File tree

sdk/api/handlers/openai/openai_responses_websocket.go

Lines changed: 155 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,19 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
116116
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
117117
}
118118

119+
allowCompactionReplayBypass := false
120+
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
121+
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
122+
allowCompactionReplayBypass = responsesWebsocketAuthSupportsCompactionReplay(pinnedAuth)
123+
}
124+
} else {
125+
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
126+
if requestModelName == "" {
127+
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
128+
}
129+
allowCompactionReplayBypass = h.websocketUpstreamSupportsCompactionReplayForModel(requestModelName)
130+
}
131+
119132
var requestJSON []byte
120133
var updatedLastRequest []byte
121134
var errMsg *interfaces.ErrorMessage
@@ -124,6 +137,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
124137
lastRequest,
125138
lastResponseOutput,
126139
allowIncrementalInputWithPreviousResponseID,
140+
allowCompactionReplayBypass,
127141
)
128142
if errMsg != nil {
129143
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
@@ -222,21 +236,21 @@ func websocketUpgradeHeaders(req *http.Request) http.Header {
222236
}
223237

224238
func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
225-
return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true)
239+
return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true, true)
226240
}
227241

228-
func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
242+
func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool, allowCompactionReplayBypass bool) ([]byte, []byte, *interfaces.ErrorMessage) {
229243
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
230244
switch requestType {
231245
case wsRequestTypeCreate:
232246
// log.Infof("responses websocket: response.create request")
233247
if len(lastRequest) == 0 {
234248
return normalizeResponseCreateRequest(rawJSON)
235249
}
236-
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
250+
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
237251
case wsRequestTypeAppend:
238252
// log.Infof("responses websocket: response.append request")
239-
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
253+
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
240254
default:
241255
return nil, lastRequest, &interfaces.ErrorMessage{
242256
StatusCode: http.StatusBadRequest,
@@ -265,7 +279,7 @@ func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces
265279
return normalized, bytes.Clone(normalized), nil
266280
}
267281

268-
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
282+
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool, allowCompactionReplayBypass bool) ([]byte, []byte, *interfaces.ErrorMessage) {
269283
if len(lastRequest) == 0 {
270284
return nil, lastRequest, &interfaces.ErrorMessage{
271285
StatusCode: http.StatusBadRequest,
@@ -315,20 +329,37 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
315329
}
316330
}
317331

318-
existingInput := gjson.GetBytes(lastRequest, "input")
319-
mergedInput, errMerge := mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
320-
if errMerge != nil {
321-
return nil, lastRequest, &interfaces.ErrorMessage{
322-
StatusCode: http.StatusBadRequest,
323-
Error: fmt.Errorf("invalid previous response output: %w", errMerge),
332+
// When the client sends a compact replay for a downstream that can consume it
333+
// directly, the input already carries the canonical history. In that case,
334+
// skip merging with stale lastRequest/lastResponseOutput to avoid breaking
335+
// function_call / function_call_output pairings.
336+
// See: https://github.com/router-for-me/CLIProxyAPI/issues/2207
337+
var mergedInput string
338+
if allowCompactionReplayBypass && inputContainsFullTranscript(nextInput) {
339+
log.Infof("responses websocket: full transcript detected, skipping stale merge (input items=%d)", len(nextInput.Array()))
340+
mergedInput = nextInput.Raw
341+
} else {
342+
appendInputRaw := nextInput.Raw
343+
if inputContainsFullTranscript(nextInput) {
344+
appendInputRaw = inputWithoutCompactionItems(nextInput)
324345
}
325-
}
326346

327-
mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, nextInput.Raw)
328-
if errMerge != nil {
329-
return nil, lastRequest, &interfaces.ErrorMessage{
330-
StatusCode: http.StatusBadRequest,
331-
Error: fmt.Errorf("invalid request input: %w", errMerge),
347+
existingInput := gjson.GetBytes(lastRequest, "input")
348+
var errMerge error
349+
mergedInput, errMerge = mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
350+
if errMerge != nil {
351+
return nil, lastRequest, &interfaces.ErrorMessage{
352+
StatusCode: http.StatusBadRequest,
353+
Error: fmt.Errorf("invalid previous response output: %w", errMerge),
354+
}
355+
}
356+
357+
mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, appendInputRaw)
358+
if errMerge != nil {
359+
return nil, lastRequest, &interfaces.ErrorMessage{
360+
StatusCode: http.StatusBadRequest,
361+
Error: fmt.Errorf("invalid request input: %w", errMerge),
362+
}
332363
}
333364
}
334365
dedupedInput, errDedupeFunctionCalls := dedupeFunctionCallsByCallID(mergedInput)
@@ -480,72 +511,104 @@ func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, met
480511
}
481512

482513
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool {
483-
if h == nil || h.AuthManager == nil {
514+
auths, _ := h.responsesWebsocketAvailableAuthsForModel(modelName)
515+
for _, auth := range auths {
516+
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
517+
return true
518+
}
519+
}
520+
return false
521+
}
522+
523+
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsCompactionReplayForModel(modelName string) bool {
524+
auths, _ := h.responsesWebsocketAvailableAuthsForModel(modelName)
525+
if len(auths) == 0 {
484526
return false
485527
}
528+
for _, auth := range auths {
529+
if !responsesWebsocketAuthSupportsCompactionReplay(auth) {
530+
return false
531+
}
532+
}
533+
return true
534+
}
535+
536+
func (h *OpenAIResponsesAPIHandler) responsesWebsocketAvailableAuthsForModel(modelName string) ([]*coreauth.Auth, string) {
537+
if h == nil || h.AuthManager == nil {
538+
return nil, ""
539+
}
540+
resolvedModelName := responsesWebsocketResolvedModelName(modelName)
541+
providerSet, modelKey := responsesWebsocketProviderSetForModel(resolvedModelName)
542+
if len(providerSet) == 0 {
543+
return nil, modelKey
544+
}
486545

487-
resolvedModelName := modelName
546+
registryRef := registry.GetGlobalRegistry()
547+
now := time.Now()
548+
auths := h.AuthManager.List()
549+
available := make([]*coreauth.Auth, 0, len(auths))
550+
for _, auth := range auths {
551+
if !responsesWebsocketAuthMatchesModel(auth, providerSet, modelKey, registryRef, now) {
552+
continue
553+
}
554+
available = append(available, auth)
555+
}
556+
return available, modelKey
557+
}
558+
559+
func responsesWebsocketResolvedModelName(modelName string) string {
488560
initialSuffix := thinking.ParseSuffix(modelName)
489561
if initialSuffix.ModelName == "auto" {
490562
resolvedBase := util.ResolveAutoModel(initialSuffix.ModelName)
491563
if initialSuffix.HasSuffix {
492-
resolvedModelName = fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
493-
} else {
494-
resolvedModelName = resolvedBase
564+
return fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
495565
}
496-
} else {
497-
resolvedModelName = util.ResolveAutoModel(modelName)
566+
return resolvedBase
498567
}
568+
return util.ResolveAutoModel(modelName)
569+
}
499570

571+
func responsesWebsocketProviderSetForModel(resolvedModelName string) (map[string]struct{}, string) {
500572
parsed := thinking.ParseSuffix(resolvedModelName)
501573
baseModel := strings.TrimSpace(parsed.ModelName)
502574
providers := util.GetProviderName(baseModel)
503575
if len(providers) == 0 && baseModel != resolvedModelName {
504576
providers = util.GetProviderName(resolvedModelName)
505577
}
506-
if len(providers) == 0 {
507-
return false
508-
}
509-
510578
providerSet := make(map[string]struct{}, len(providers))
511-
for i := 0; i < len(providers); i++ {
512-
providerKey := strings.TrimSpace(strings.ToLower(providers[i]))
579+
for _, provider := range providers {
580+
providerKey := strings.TrimSpace(strings.ToLower(provider))
513581
if providerKey == "" {
514582
continue
515583
}
516584
providerSet[providerKey] = struct{}{}
517585
}
518-
if len(providerSet) == 0 {
519-
return false
520-
}
521-
522586
modelKey := baseModel
523587
if modelKey == "" {
524588
modelKey = strings.TrimSpace(resolvedModelName)
525589
}
526-
registryRef := registry.GetGlobalRegistry()
527-
now := time.Now()
528-
auths := h.AuthManager.List()
529-
for i := 0; i < len(auths); i++ {
530-
auth := auths[i]
531-
if auth == nil {
532-
continue
533-
}
534-
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
535-
if _, ok := providerSet[providerKey]; !ok {
536-
continue
537-
}
538-
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
539-
continue
540-
}
541-
if !responsesWebsocketAuthAvailableForModel(auth, modelKey, now) {
542-
continue
543-
}
544-
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
545-
return true
546-
}
590+
return providerSet, modelKey
591+
}
592+
593+
func responsesWebsocketAuthMatchesModel(auth *coreauth.Auth, providerSet map[string]struct{}, modelKey string, registryRef *registry.ModelRegistry, now time.Time) bool {
594+
if auth == nil {
595+
return false
547596
}
548-
return false
597+
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
598+
if _, ok := providerSet[providerKey]; !ok {
599+
return false
600+
}
601+
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
602+
return false
603+
}
604+
return responsesWebsocketAuthAvailableForModel(auth, modelKey, now)
605+
}
606+
607+
func responsesWebsocketAuthSupportsCompactionReplay(auth *coreauth.Auth) bool {
608+
if auth == nil {
609+
return false
610+
}
611+
return strings.EqualFold(strings.TrimSpace(auth.Provider), "codex")
549612
}
550613

551614
func responsesWebsocketAuthAvailableForModel(auth *coreauth.Auth, modelName string, now time.Time) bool {
@@ -691,6 +754,42 @@ func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
691754
return string(out), nil
692755
}
693756

757+
// inputContainsFullTranscript returns true when the input array carries compact
758+
// replay markers that indicate the client already sent the full conversation
759+
// transcript. Merging that input with stale lastRequest/lastResponseOutput
760+
// would duplicate or break function_call/function_call_output pairings, so the
761+
// caller should use the input as-is.
762+
//
763+
// Assistant messages alone are not enough to classify the payload as a replay:
764+
// incremental websocket requests may legitimately append assistant items.
765+
func inputContainsFullTranscript(input gjson.Result) bool {
766+
if !input.IsArray() {
767+
return false
768+
}
769+
for _, item := range input.Array() {
770+
t := item.Get("type").String()
771+
if t == "compaction" || t == "compaction_summary" {
772+
return true
773+
}
774+
}
775+
return false
776+
}
777+
778+
func inputWithoutCompactionItems(input gjson.Result) string {
779+
if !input.IsArray() {
780+
return normalizeJSONArrayRaw([]byte(input.Raw))
781+
}
782+
filtered := make([]string, 0, len(input.Array()))
783+
for _, item := range input.Array() {
784+
t := item.Get("type").String()
785+
if t == "compaction" || t == "compaction_summary" {
786+
continue
787+
}
788+
filtered = append(filtered, item.Raw)
789+
}
790+
return "[" + strings.Join(filtered, ",") + "]"
791+
}
792+
694793
func normalizeJSONArrayRaw(raw []byte) string {
695794
trimmed := strings.TrimSpace(string(raw))
696795
if trimmed == "" {

0 commit comments

Comments
 (0)