diff --git a/pkg/handlers/response.go b/pkg/handlers/response.go index 7d2d7cc..7e35eb4 100644 --- a/pkg/handlers/response.go +++ b/pkg/handlers/response.go @@ -17,8 +17,10 @@ limitations under the License. package handlers import ( + "bytes" "context" "encoding/json" + "errors" "fmt" "strconv" "time" @@ -44,9 +46,9 @@ func (s *Server) HandleResponseHeaders(ctx context.Context, reqCtx *RequestConte if !headers.GetEndOfStream() { log.FromContext(ctx).V(logutil.VERBOSE).Info("captured response headers, deferring response until body arrives...") - return nil } - // EndOfStream means no body is expected, return HeadersResponse immediately + // Always respond to response headers so Envoy proceeds with body chunks. + // In STREAMED/FULL_DUPLEX_STREAMED mode, Envoy blocks until we respond. return []*eppb.ProcessingResponse{ { Response: &eppb.ProcessingResponse_ResponseHeaders{ @@ -64,8 +66,15 @@ func (s *Server) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, } if err := json.Unmarshal(responseBodyBytes, &reqCtx.Response.Body); err != nil { - logger.Error(err, "Failed to parse response body as JSON, skipping response plugins") - return s.generateEmptyResponseBodyResponse(responseBodyBytes), nil + // Try parsing as SSE (Server-Sent Events) — streaming responses from providers + // like Anthropic use SSE format which isn't valid JSON. 
// parseSSEResponseBody extracts a composite response body from an SSE (Server-Sent Events)
// stream so that response plugins can process streaming responses that are not a single
// JSON document.
//
// The stream is parsed by event boundary rather than line by line: an event ends at a
// blank line, and one event may carry several consecutive `data:` lines that must be
// joined with "\n" before JSON decoding. Lines that do not start with `data:` (for
// example `event:`, `id:` or `:` comments) are ignored. Both "\n" and "\r\n" line
// endings are accepted, and a final event that is not followed by a blank line is
// still processed.
//
// From every event whose data decodes as a JSON object, the parser collects:
//   - "model": a non-empty top-level string, or, when the event has no top-level
//     "usage" object, a non-empty string nested under "response"; later events
//     overwrite earlier ones.
//   - "usage": a top-level object, or one nested under "response" when the top level
//     has none; keys from all events are merged into one map, later values winning.
//
// Events that are empty, equal to the "[DONE]" sentinel, or not valid JSON objects are
// skipped. If no event contributes a model or usage, a nil map and an error are returned.
func parseSSEResponseBody(body []byte) (map[string]any, error) {
	result := map[string]any{}
	var eventDataLines [][]byte

	// flushEvent decodes the `data:` lines buffered for the current event and merges the
	// model/usage information into result. It resets the buffer so that it can be called
	// at every event boundary.
	flushEvent := func() {
		if len(eventDataLines) == 0 {
			return
		}

		// bytes.Join copies, so the buffer can be truncated and reused right away.
		data := bytes.TrimSpace(bytes.Join(eventDataLines, []byte("\n")))
		eventDataLines = eventDataLines[:0]

		if len(data) == 0 || string(data) == "[DONE]" {
			return
		}

		var event map[string]any
		if err := json.Unmarshal(data, &event); err != nil {
			// Best effort: a single undecodable event must not discard the usage
			// information carried by the other events of the stream.
			return
		}

		if model, ok := event["model"].(string); ok && model != "" {
			result["model"] = model
		}

		// Check for usage at top level (Anthropic) or nested in response (OpenAI Responses API).
		usage, _ := event["usage"].(map[string]any)
		if usage == nil {
			if resp, ok := event["response"].(map[string]any); ok {
				usage, _ = resp["usage"].(map[string]any)
				if m, ok := resp["model"].(string); ok && m != "" {
					result["model"] = m
				}
			}
		}
		if usage == nil {
			return
		}

		merged, _ := result["usage"].(map[string]any)
		if merged == nil {
			merged = make(map[string]any, len(usage))
		}
		for k, v := range usage {
			merged[k] = v
		}
		result["usage"] = merged
	}

	dataPrefix := []byte("data:")
	for _, line := range bytes.Split(body, []byte("\n")) {
		// Tolerate CRLF line endings.
		line = bytes.TrimSuffix(line, []byte("\r"))

		// A blank line terminates the current event.
		if len(line) == 0 {
			flushEvent()
			continue
		}
		if !bytes.HasPrefix(line, dataPrefix) {
			continue
		}

		// Drop the field name and the single optional space that may follow the colon.
		payload := bytes.TrimPrefix(line[len(dataPrefix):], []byte(" "))
		eventDataLines = append(eventDataLines, payload)
	}
	// The last event may not be followed by a blank line.
	flushEvent()

	if len(result) == 0 {
		return nil, errors.New("no parseable SSE data events found")
	}

	return result, nil
}
+ if sendErr := srv.Send(&extProcPb.ProcessingResponse{ + Response: &extProcPb.ProcessingResponse_ResponseBody{ + ResponseBody: &extProcPb.BodyResponse{}, + }, + }); sendErr != nil { + return status.Errorf(codes.Unknown, "failed to send streaming response ack: %v", sendErr) + } continue } reqCtx.ResponseCompleteTimestamp = time.Now() diff --git a/pkg/handlers/server_test.go b/pkg/handlers/server_test.go index 406195a..421b12c 100644 --- a/pkg/handlers/server_test.go +++ b/pkg/handlers/server_test.go @@ -177,6 +177,10 @@ func TestHandleResponseBody_Streaming(t *testing.T) { if err := process.Send(request); err != nil { t.Fatalf("send response headers: %v", err) } + // Discard the immediate header ack (HandleResponseHeaders always responds now). + if _, err := process.Recv(); err != nil { + t.Fatalf("recv header ack: %v", err) + } for _, c := range tc.chunks { request = &extProcPb.ProcessingRequest{ @@ -190,6 +194,12 @@ func TestHandleResponseBody_Streaming(t *testing.T) { if err := process.Send(request); err != nil { t.Fatalf("send response body chunk: %v", err) } + // Discard the immediate ack for non-EoS chunks. + if !c.endOfStream { + if _, err := process.Recv(); err != nil { + t.Fatalf("recv chunk ack: %v", err) + } + } } got := make([]*extProcPb.ProcessingResponse, 0, len(want))