Skip to content

Commit 9a67441

Browse files
committed
fix(transformer): add fallback for token usage in streaming responses of claude
1 parent cbc0711 commit 9a67441

2 files changed

Lines changed: 108 additions & 7 deletions

File tree

internal/proxy/streaming.go

Lines changed: 94 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ import (
55
"bytes"
66
"compress/gzip"
77
"encoding/json"
8+
"fmt"
89
"io"
910
"net/http"
1011
"strings"
1112

1213
"github.com/lich0821/ccNexus/internal/config"
1314
"github.com/lich0821/ccNexus/internal/logger"
15+
"github.com/lich0821/ccNexus/internal/tokencount"
1416
"github.com/lich0821/ccNexus/internal/transformer"
1517
"github.com/lich0821/ccNexus/internal/transformer/cc"
1618
"github.com/lich0821/ccNexus/internal/transformer/cx/chat"
@@ -85,14 +87,49 @@ func (p *Proxy) handleStreamingResponse(w http.ResponseWriter, resp *http.Respon
8587
}
8688

8789
if strings.Contains(line, "data: [DONE]") {
90+
// Fallback: update output_tokens if not provided
91+
if outputTokens == 0 && outputText.String() != "" {
92+
outputTokens = tokencount.EstimateOutputTokens(outputText.String())
93+
}
94+
95+
if streamCtx != nil {
96+
streamCtx.OutputTokens = outputTokens
97+
streamCtx.InputTokens = inputTokens
98+
}
99+
100+
// Create message_delta to output
101+
deltaEvent := map[string]interface{}{
102+
"type": "message_delta",
103+
"delta": map[string]interface{}{
104+
"stop_reason": "end_turn",
105+
"stop_sequence": nil,
106+
},
107+
"usage": map[string]interface{}{
108+
"output_tokens": outputTokens,
109+
},
110+
}
111+
deltaData, _ := json.Marshal(deltaEvent)
112+
deltaSSE := fmt.Sprintf("data: %s\n\n", deltaData)
113+
114+
// For cc_claude, send directly to avoid transformer modifying it
115+
if transformerName == "cc_claude" {
116+
w.Write([]byte(deltaSSE))
117+
flusher.Flush()
118+
} else {
119+
// Transform event for other transformers
120+
transformedDelta, _ := p.transformStreamEvent([]byte(deltaSSE), trans, transformerName, streamCtx)
121+
if len(transformedDelta) > 0 {
122+
w.Write(transformedDelta)
123+
flusher.Flush()
124+
}
125+
}
126+
88127
streamDone = true
89128
buffer.WriteString(line + "\n")
90129
eventData := buffer.Bytes()
91-
logger.DebugLog("[%s] SSE Event #%d (Original): %s", endpoint.Name, eventCount+1, string(eventData))
92130

93131
transformedEvent, err := p.transformStreamEvent(eventData, trans, transformerName, streamCtx)
94132
if err == nil && len(transformedEvent) > 0 {
95-
logger.DebugLog("[%s] SSE Event #%d (Transformed): %s", endpoint.Name, eventCount+1, string(transformedEvent))
96133
w.Write(transformedEvent)
97134
flusher.Flush()
98135
}
@@ -104,17 +141,55 @@ func (p *Proxy) handleStreamingResponse(w http.ResponseWriter, resp *http.Respon
104141
if line == "" {
105142
eventCount++
106143
eventData := buffer.Bytes()
107-
logger.DebugLog("[%s] SSE Event #%d (Original): %s", endpoint.Name, eventCount, string(eventData))
108144

109145
transformedEvent, err := p.transformStreamEvent(eventData, trans, transformerName, streamCtx)
110146
if err != nil {
111147
logger.Error("[%s] Failed to transform SSE event: %v", endpoint.Name, err)
112148
} else if len(transformedEvent) > 0 {
113-
logger.DebugLog("[%s] SSE Event #%d (Transformed): %s", endpoint.Name, eventCount, string(transformedEvent))
114-
115149
p.extractTokensFromEvent(transformedEvent, &inputTokens, &outputTokens)
116150
p.extractTextFromEvent(transformedEvent, &outputText)
117151

152+
// Check if this is message_stop event - send custom message_delta before it
153+
isMessageStop := strings.Contains(string(transformedEvent), `"type":"message_stop"`)
154+
if isMessageStop {
155+
// Estimate output tokens if not provided
156+
if outputTokens == 0 && outputText.String() != "" {
157+
outputTokens = tokencount.EstimateOutputTokens(outputText.String())
158+
}
159+
160+
if streamCtx != nil {
161+
streamCtx.OutputTokens = outputTokens
162+
streamCtx.InputTokens = inputTokens
163+
}
164+
165+
// Create message_delta event with output_tokens
166+
deltaEvent := map[string]interface{}{
167+
"type": "message_delta",
168+
"delta": map[string]interface{}{
169+
"stop_reason": "end_turn",
170+
"stop_sequence": nil,
171+
},
172+
"usage": map[string]interface{}{
173+
"output_tokens": outputTokens,
174+
},
175+
}
176+
deltaData, _ := json.Marshal(deltaEvent)
177+
deltaSSE := fmt.Sprintf("data: %s\n\n", deltaData)
178+
179+
// For cc_claude, send directly
180+
if transformerName == "cc_claude" {
181+
w.Write([]byte(deltaSSE))
182+
flusher.Flush()
183+
} else {
184+
// Transform event for other transformers
185+
transformedDelta, _ := p.transformStreamEvent([]byte(deltaSSE), trans, transformerName, streamCtx)
186+
if len(transformedDelta) > 0 {
187+
w.Write(transformedDelta)
188+
flusher.Flush()
189+
}
190+
}
191+
}
192+
118193
if _, writeErr := w.Write(transformedEvent); writeErr != nil {
119194
// Client disconnected (broken pipe) is normal for cancelled requests
120195
if strings.Contains(writeErr.Error(), "broken pipe") || strings.Contains(writeErr.Error(), "connection reset") {
@@ -223,6 +298,20 @@ func (p *Proxy) extractTextFromEvent(transformedEvent []byte, outputText *string
223298
continue
224299
}
225300

301+
eventType, _ := event["type"].(string)
302+
303+
// Handle content_block_delta (new format)
304+
if eventType == "content_block_delta" {
305+
if delta, ok := event["delta"].(map[string]interface{}); ok {
306+
if deltaType, _ := delta["type"].(string); deltaType == "text_delta" {
307+
if text, ok := delta["text"].(string); ok {
308+
outputText.WriteString(text)
309+
}
310+
}
311+
}
312+
}
313+
314+
// Handle content_block (old format)
226315
if delta, ok := event["delta"].(map[string]interface{}); ok {
227316
if text, ok := delta["text"].(string); ok {
228317
outputText.WriteString(text)

internal/transformer/cc/claude.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,23 @@ func (t *ClaudeTransformer) TransformResponseWithContext(resp []byte, isStreamin
7676
} else if eventType == "message_delta" {
7777
// Fallback: fill input_tokens if 0
7878
if usage, ok := event["usage"].(map[string]interface{}); ok {
79+
modified := false
80+
7981
if input, ok := usage["input_tokens"].(float64); ok && int(input) == 0 && ctx.InputTokens > 0 {
8082
usage["input_tokens"] = ctx.InputTokens
81-
modified, _ := json.Marshal(event)
83+
modified = true
84+
}
85+
86+
// Fallback: fill output_tokens if 0
87+
if output, ok := usage["output_tokens"].(float64); ok && int(output) == 0 && ctx.OutputTokens > 0 {
88+
usage["output_tokens"] = ctx.OutputTokens
89+
modified = true
90+
}
91+
92+
if modified {
93+
modifiedData, _ := json.Marshal(event)
8294
result.WriteString("data: ")
83-
result.Write(modified)
95+
result.Write(modifiedData)
8496
result.WriteString("\n")
8597
continue
8698
}

0 commit comments

Comments
 (0)