Skip to content

Commit 447e5c4

Browse files
committed
fix(transformer): add fallback for token usage in streaming responses of Claude
1 parent cbc0711 commit 447e5c4

2 files changed

Lines changed: 108 additions & 2 deletions

File tree

internal/proxy/streaming.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ import (
55
"bytes"
66
"compress/gzip"
77
"encoding/json"
8+
"fmt"
89
"io"
910
"net/http"
1011
"strings"
1112

1213
"github.com/lich0821/ccNexus/internal/config"
1314
"github.com/lich0821/ccNexus/internal/logger"
15+
"github.com/lich0821/ccNexus/internal/tokencount"
1416
"github.com/lich0821/ccNexus/internal/transformer"
1517
"github.com/lich0821/ccNexus/internal/transformer/cc"
1618
"github.com/lich0821/ccNexus/internal/transformer/cx/chat"
@@ -85,6 +87,43 @@ func (p *Proxy) handleStreamingResponse(w http.ResponseWriter, resp *http.Respon
8587
}
8688

8789
if strings.Contains(line, "data: [DONE]") {
90+
// Fallback: update output_tokens if not provided
91+
if outputTokens == 0 && outputText.String() != "" {
92+
outputTokens = tokencount.EstimateOutputTokens(outputText.String())
93+
}
94+
95+
if streamCtx != nil {
96+
streamCtx.OutputTokens = outputTokens
97+
streamCtx.InputTokens = inputTokens
98+
}
99+
100+
// Build the message_delta event to send to the client
101+
deltaEvent := map[string]interface{}{
102+
"type": "message_delta",
103+
"delta": map[string]interface{}{
104+
"stop_reason": "end_turn",
105+
"stop_sequence": nil,
106+
},
107+
"usage": map[string]interface{}{
108+
"output_tokens": outputTokens,
109+
},
110+
}
111+
deltaData, _ := json.Marshal(deltaEvent)
112+
deltaSSE := fmt.Sprintf("data: %s\n\n", deltaData)
113+
114+
// For cc_claude, send directly to avoid transformer modifying it
115+
if transformerName == "cc_claude" {
116+
w.Write([]byte(deltaSSE))
117+
flusher.Flush()
118+
} else {
119+
// Transform event for other transformers
120+
transformedDelta, _ := p.transformStreamEvent([]byte(deltaSSE), trans, transformerName, streamCtx)
121+
if len(transformedDelta) > 0 {
122+
w.Write(transformedDelta)
123+
flusher.Flush()
124+
}
125+
}
126+
88127
streamDone = true
89128
buffer.WriteString(line + "\n")
90129
eventData := buffer.Bytes()
@@ -115,6 +154,47 @@ func (p *Proxy) handleStreamingResponse(w http.ResponseWriter, resp *http.Respon
115154
p.extractTokensFromEvent(transformedEvent, &inputTokens, &outputTokens)
116155
p.extractTextFromEvent(transformedEvent, &outputText)
117156

157+
// If this is the message_stop event, send a synthesized message_delta before it
158+
isMessageStop := strings.Contains(string(transformedEvent), `"type":"message_stop"`)
159+
if isMessageStop {
160+
// Estimate output tokens if not provided
161+
if outputTokens == 0 && outputText.String() != "" {
162+
outputTokens = tokencount.EstimateOutputTokens(outputText.String())
163+
}
164+
165+
if streamCtx != nil {
166+
streamCtx.OutputTokens = outputTokens
167+
streamCtx.InputTokens = inputTokens
168+
}
169+
170+
// Create message_delta event with output_tokens
171+
deltaEvent := map[string]interface{}{
172+
"type": "message_delta",
173+
"delta": map[string]interface{}{
174+
"stop_reason": "end_turn",
175+
"stop_sequence": nil,
176+
},
177+
"usage": map[string]interface{}{
178+
"output_tokens": outputTokens,
179+
},
180+
}
181+
deltaData, _ := json.Marshal(deltaEvent)
182+
deltaSSE := fmt.Sprintf("data: %s\n\n", deltaData)
183+
184+
// For cc_claude, send directly
185+
if transformerName == "cc_claude" {
186+
w.Write([]byte(deltaSSE))
187+
flusher.Flush()
188+
} else {
189+
// Transform event for other transformers
190+
transformedDelta, _ := p.transformStreamEvent([]byte(deltaSSE), trans, transformerName, streamCtx)
191+
if len(transformedDelta) > 0 {
192+
w.Write(transformedDelta)
193+
flusher.Flush()
194+
}
195+
}
196+
}
197+
118198
if _, writeErr := w.Write(transformedEvent); writeErr != nil {
119199
// Client disconnected (broken pipe) is normal for cancelled requests
120200
if strings.Contains(writeErr.Error(), "broken pipe") || strings.Contains(writeErr.Error(), "connection reset") {
@@ -223,6 +303,20 @@ func (p *Proxy) extractTextFromEvent(transformedEvent []byte, outputText *string
223303
continue
224304
}
225305

306+
eventType, _ := event["type"].(string)
307+
308+
// Handle content_block_delta (new format)
309+
if eventType == "content_block_delta" {
310+
if delta, ok := event["delta"].(map[string]interface{}); ok {
311+
if deltaType, _ := delta["type"].(string); deltaType == "text_delta" {
312+
if text, ok := delta["text"].(string); ok {
313+
outputText.WriteString(text)
314+
}
315+
}
316+
}
317+
}
318+
319+
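// Handle delta.text on any event type (old format).
// NOTE(review): a text_delta already handled above also matches this check, so its text appears to be appended twice; confirm and skip this branch for events handled above.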
226320
if delta, ok := event["delta"].(map[string]interface{}); ok {
227321
if text, ok := delta["text"].(string); ok {
228322
outputText.WriteString(text)

internal/transformer/cc/claude.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,23 @@ func (t *ClaudeTransformer) TransformResponseWithContext(resp []byte, isStreamin
7676
} else if eventType == "message_delta" {
7777
// Fallback: fill input_tokens if 0
7878
if usage, ok := event["usage"].(map[string]interface{}); ok {
79+
modified := false
80+
7981
if input, ok := usage["input_tokens"].(float64); ok && int(input) == 0 && ctx.InputTokens > 0 {
8082
usage["input_tokens"] = ctx.InputTokens
81-
modified, _ := json.Marshal(event)
83+
modified = true
84+
}
85+
86+
// Fallback: fill output_tokens if 0
87+
if output, ok := usage["output_tokens"].(float64); ok && int(output) == 0 && ctx.OutputTokens > 0 {
88+
usage["output_tokens"] = ctx.OutputTokens
89+
modified = true
90+
}
91+
92+
if modified {
93+
modifiedData, _ := json.Marshal(event)
8294
result.WriteString("data: ")
83-
result.Write(modified)
95+
result.Write(modifiedData)
8496
result.WriteString("\n")
8597
continue
8698
}

0 commit comments

Comments
 (0)