diff --git a/internal/proxy/streaming.go b/internal/proxy/streaming.go index 7b6ea343..51b4062b 100644 --- a/internal/proxy/streaming.go +++ b/internal/proxy/streaming.go @@ -5,12 +5,14 @@ import ( "bytes" "compress/gzip" "encoding/json" + "fmt" "io" "net/http" "strings" "github.com/lich0821/ccNexus/internal/config" "github.com/lich0821/ccNexus/internal/logger" + "github.com/lich0821/ccNexus/internal/tokencount" "github.com/lich0821/ccNexus/internal/transformer" "github.com/lich0821/ccNexus/internal/transformer/cc" "github.com/lich0821/ccNexus/internal/transformer/cx/chat" @@ -85,6 +87,43 @@ func (p *Proxy) handleStreamingResponse(w http.ResponseWriter, resp *http.Respon } if strings.Contains(line, "data: [DONE]") { + // Fallback: update output_tokens if not provided + if outputTokens == 0 && outputText.String() != "" { + outputTokens = tokencount.EstimateOutputTokens(outputText.String()) + } + + if streamCtx != nil { + streamCtx.OutputTokens = outputTokens + streamCtx.InputTokens = inputTokens + } + + // Create message_delta to output + deltaEvent := map[string]interface{}{ + "type": "message_delta", + "delta": map[string]interface{}{ + "stop_reason": "end_turn", + "stop_sequence": nil, + }, + "usage": map[string]interface{}{ + "output_tokens": outputTokens, + }, + } + deltaData, _ := json.Marshal(deltaEvent) + deltaSSE := fmt.Sprintf("data: %s\n\n", deltaData) + + // For cc_claude, send directly to avoid transformer modifying it + if transformerName == "cc_claude" { + w.Write([]byte(deltaSSE)) + flusher.Flush() + } else { + // Transform event for other transformers + transformedDelta, _ := p.transformStreamEvent([]byte(deltaSSE), trans, transformerName, streamCtx) + if len(transformedDelta) > 0 { + w.Write(transformedDelta) + flusher.Flush() + } + } + streamDone = true buffer.WriteString(line + "\n") eventData := buffer.Bytes() @@ -115,6 +154,47 @@ func (p *Proxy) handleStreamingResponse(w http.ResponseWriter, resp *http.Respon p.extractTokensFromEvent(transformedEvent, &inputTokens, &outputTokens) p.extractTextFromEvent(transformedEvent, &outputText) + // Check if this is message_stop event - send custom message_delta before it + isMessageStop := strings.Contains(string(transformedEvent), `"type":"message_stop"`) + if isMessageStop { + // Estimate output tokens if not provided + if outputTokens == 0 && outputText.String() != "" { + outputTokens = tokencount.EstimateOutputTokens(outputText.String()) + } + + if streamCtx != nil { + streamCtx.OutputTokens = outputTokens + streamCtx.InputTokens = inputTokens + } + + // Create message_delta event with output_tokens + deltaEvent := map[string]interface{}{ + "type": "message_delta", + "delta": map[string]interface{}{ + "stop_reason": "end_turn", + "stop_sequence": nil, + }, + "usage": map[string]interface{}{ + "output_tokens": outputTokens, + }, + } + deltaData, _ := json.Marshal(deltaEvent) + deltaSSE := fmt.Sprintf("data: %s\n\n", deltaData) + + // For cc_claude, send directly + if transformerName == "cc_claude" { + w.Write([]byte(deltaSSE)) + flusher.Flush() + } else { + // Transform event for other transformers + transformedDelta, _ := p.transformStreamEvent([]byte(deltaSSE), trans, transformerName, streamCtx) + if len(transformedDelta) > 0 { + w.Write(transformedDelta) + flusher.Flush() + } + } + } + if _, writeErr := w.Write(transformedEvent); writeErr != nil { // Client disconnected (broken pipe) is normal for cancelled requests if strings.Contains(writeErr.Error(), "broken pipe") || strings.Contains(writeErr.Error(), "connection reset") { @@ -223,6 +303,20 @@ func (p *Proxy) extractTextFromEvent(transformedEvent []byte, outputText *string continue } + eventType, _ := event["type"].(string) + + // Handle content_block_delta (new format) + if eventType == "content_block_delta" { + if delta, ok := event["delta"].(map[string]interface{}); ok { + if deltaType, _ := delta["type"].(string); deltaType == "text_delta" { + if text, ok := delta["text"].(string); ok { + outputText.WriteString(text) + } + } + } + } + + // Handle content_block (old format) if delta, ok := event["delta"].(map[string]interface{}); ok { if text, ok := delta["text"].(string); ok { outputText.WriteString(text) diff --git a/internal/transformer/cc/claude.go b/internal/transformer/cc/claude.go index 73088184..4136c1cd 100644 --- a/internal/transformer/cc/claude.go +++ b/internal/transformer/cc/claude.go @@ -76,11 +76,23 @@ func (t *ClaudeTransformer) TransformResponseWithContext(resp []byte, isStreamin } else if eventType == "message_delta" { // Fallback: fill input_tokens if 0 if usage, ok := event["usage"].(map[string]interface{}); ok { + modified := false + if input, ok := usage["input_tokens"].(float64); ok && int(input) == 0 && ctx.InputTokens > 0 { usage["input_tokens"] = ctx.InputTokens - modified, _ := json.Marshal(event) + modified = true + } + + // Fallback: fill output_tokens if 0 + if output, ok := usage["output_tokens"].(float64); ok && int(output) == 0 && ctx.OutputTokens > 0 { + usage["output_tokens"] = ctx.OutputTokens + modified = true + } + + if modified { + modifiedData, _ := json.Marshal(event) result.WriteString("data: ") - result.Write(modified) + result.Write(modifiedData) result.WriteString("\n") continue }