Skip to content

Commit d6d0822

Browse files
committed
fix: import PR #102
1 parent 7fcb49d commit d6d0822

1 file changed

Lines changed: 73 additions & 0 deletions

File tree

internal/proxy/streaming.go

Lines changed: 73 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ import (
1212

1313
"github.com/lich0821/ccNexus/internal/config"
1414
"github.com/lich0821/ccNexus/internal/logger"
15+
"github.com/lich0821/ccNexus/internal/tokencount"
1516
"github.com/lich0821/ccNexus/internal/transformer"
1617
)
1718

@@ -85,6 +86,24 @@ func (p *Proxy) handleStreamingResponse(w http.ResponseWriter, resp *http.Respon
8586

8687
if strings.Contains(line, "data: [DONE]") {
8788
streamDone = true
89+
90+
// Token Usage Fallback: Inject message_delta with estimated output_tokens before [DONE]
91+
if outputTokens == 0 && outputText.Len() > 0 {
92+
outputTokens = tokencount.EstimateOutputTokens(outputText.String())
93+
logger.Debug("[%s] Token fallback before [DONE]: estimated output_tokens=%d", endpoint.Name, outputTokens)
94+
95+
// Update stream context for transformer fallback
96+
if streamCtx != nil {
97+
streamCtx.OutputTokens = outputTokens
98+
}
99+
100+
// Inject message_delta event with usage
101+
deltaEvent := fmt.Sprintf("event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"output_tokens\":%d}}\n\n", outputTokens)
102+
if _, writeErr := w.Write([]byte(deltaEvent)); writeErr == nil {
103+
flusher.Flush()
104+
}
105+
}
106+
88107
buffer.WriteString(line + "\n")
89108
eventData := buffer.Bytes()
90109
logger.DebugLog("[%s] SSE Event #%d (Original): %s", endpoint.Name, eventCount+1, string(eventData))
@@ -105,6 +124,24 @@ func (p *Proxy) handleStreamingResponse(w http.ResponseWriter, resp *http.Respon
105124
eventData := buffer.Bytes()
106125
logger.DebugLog("[%s] SSE Event #%d (Original): %s", endpoint.Name, eventCount, string(eventData))
107126

127+
// Check if this is a message_stop event (Token Usage Fallback)
128+
isMessageStop := p.isMessageStopEvent(eventData)
129+
if isMessageStop && outputTokens == 0 && outputText.Len() > 0 {
130+
outputTokens = tokencount.EstimateOutputTokens(outputText.String())
131+
logger.Debug("[%s] Token fallback before message_stop: estimated output_tokens=%d", endpoint.Name, outputTokens)
132+
133+
// Update stream context for transformer fallback
134+
if streamCtx != nil {
135+
streamCtx.OutputTokens = outputTokens
136+
}
137+
138+
// Inject message_delta event with usage before message_stop
139+
deltaEvent := fmt.Sprintf("event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"output_tokens\":%d}}\n\n", outputTokens)
140+
if _, writeErr := w.Write([]byte(deltaEvent)); writeErr == nil {
141+
flusher.Flush()
142+
}
143+
}
144+
108145
transformedEvent, err := p.transformStreamEvent(eventData, trans, transformerName, streamCtx)
109146
if err != nil {
110147
logger.Error("[%s] Failed to transform SSE event: %v", endpoint.Name, err)
@@ -211,6 +248,7 @@ func (p *Proxy) extractTokensFromEvent(eventData []byte, inputTokens, outputToke
211248
}
212249

213250
// extractTextFromEvent extracts text content from transformed event
251+
// Enhanced to support both delta.text and content_block_delta formats
214252
func (p *Proxy) extractTextFromEvent(transformedEvent []byte, outputText *strings.Builder) {
215253
scanner := bufio.NewScanner(bytes.NewReader(transformedEvent))
216254
for scanner.Scan() {
@@ -225,6 +263,18 @@ func (p *Proxy) extractTextFromEvent(transformedEvent []byte, outputText *string
225263
continue
226264
}
227265

266+
eventType, _ := event["type"].(string)
267+
268+
// Handle content_block_delta format (from some third-party APIs)
269+
if eventType == "content_block_delta" {
270+
if delta, ok := event["delta"].(map[string]interface{}); ok {
271+
if text, ok := delta["text"].(string); ok {
272+
outputText.WriteString(text)
273+
}
274+
}
275+
}
276+
277+
// Handle standard delta.text format
228278
if delta, ok := event["delta"].(map[string]interface{}); ok {
229279
if text, ok := delta["text"].(string); ok {
230280
outputText.WriteString(text)
@@ -233,6 +283,29 @@ func (p *Proxy) extractTextFromEvent(transformedEvent []byte, outputText *string
233283
}
234284
}
235285

286+
// isMessageStopEvent checks if the event is a message_stop event
287+
func (p *Proxy) isMessageStopEvent(eventData []byte) bool {
288+
scanner := bufio.NewScanner(bytes.NewReader(eventData))
289+
for scanner.Scan() {
290+
line := scanner.Text()
291+
if !strings.HasPrefix(line, "data:") {
292+
continue
293+
}
294+
295+
jsonData := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
296+
var event map[string]interface{}
297+
if err := json.Unmarshal([]byte(jsonData), &event); err != nil {
298+
continue
299+
}
300+
301+
eventType, _ := event["type"].(string)
302+
if eventType == "message_stop" {
303+
return true
304+
}
305+
}
306+
return false
307+
}
308+
236309
// decompressGzip decompresses gzip-encoded response body
237310
func decompressGzip(body io.ReadCloser) ([]byte, error) {
238311
gzipReader, err := gzip.NewReader(body)

0 commit comments

Comments
 (0)