diff --git a/pkg/config/loader/configloader.go b/pkg/config/loader/configloader.go index 25cef3a7..9edf78c5 100644 --- a/pkg/config/loader/configloader.go +++ b/pkg/config/loader/configloader.go @@ -214,9 +214,7 @@ func buildProfiles(rawProfiles []configapi.Profile, handle plugin.Handle) (map[s return nil, fmt.Errorf("the profile %s must have one or both of the Request and Response sections", rawProfile.Name) } - theProfile := requesthandling.Profile{ - ResponsePlugins: make([]requesthandling.ResponseProcessor, len(rawProfile.Plugins.Response)), - } + theProfile := requesthandling.Profile{} for _, pluginRef := range rawProfile.Plugins.Request { rawPlugin := handle.Plugin(pluginRef.PluginRef) @@ -245,17 +243,25 @@ func buildProfiles(rawProfiles []configapi.Profile, handle plugin.Handle) (map[s } } - for idx, pluginRef := range rawProfile.Plugins.Response { + for _, pluginRef := range rawProfile.Plugins.Response { rawPlugin := handle.Plugin(pluginRef.PluginRef) if rawPlugin == nil { return nil, fmt.Errorf("there is no plugin named %s", pluginRef.PluginRef) } - thePlugin, ok := rawPlugin.(requesthandling.ResponseProcessor) - if !ok { - return nil, fmt.Errorf("the plugin named %s is not a ResponseProcessor", pluginRef.PluginRef) + if bodyPlugin, ok := rawPlugin.(requesthandling.ResponseProcessor); ok { + theProfile.ResponsePlugins = append(theProfile.ResponsePlugins, bodyPlugin) + continue } - theProfile.ResponsePlugins[idx] = thePlugin + if chunkPlugin, ok := rawPlugin.(requesthandling.ResponseChunkProcessor); ok { + theProfile.ResponseChunkProcessors = append(theProfile.ResponseChunkProcessors, chunkPlugin) + continue + } + return nil, fmt.Errorf("the plugin named %s is not a ResponseProcessor nor ResponseChunkProcessor", pluginRef.PluginRef) + } + if len(theProfile.ResponsePlugins) > 0 && len(theProfile.ResponseChunkProcessors) > 0 { + return nil, fmt.Errorf("profile %s mixes ResponseProcessor and ResponseChunkProcessor plugins — a profile must use one type exclusively", rawProfile.Name) } + theProfile.NeedsResponseBuffering = len(theProfile.ResponsePlugins) > 0 profiles[rawProfile.Name] = &theProfile } diff --git a/pkg/framework/interface/requesthandling/plugins.go b/pkg/framework/interface/requesthandling/plugins.go index dad4ebb6..da49b86b 100644 --- a/pkg/framework/interface/requesthandling/plugins.go +++ b/pkg/framework/interface/requesthandling/plugins.go @@ -43,13 +43,23 @@ type RequestProcessor interface { ProcessRequest(ctx context.Context, cycleState *plugin.CycleState, request *InferenceRequest) error } +// ResponseProcessor processes the complete buffered response body. +// If any plugin in a profile implements this interface, the framework buffers +// the entire response before calling ProcessResponse on each such plugin. type ResponseProcessor interface { plugin.Plugin - // ProcessResponse runs the ResponseProcessor plugin. - // ResponseProcessor can mutate the headers and/or the body of the response. ProcessResponse(ctx context.Context, cycleState *plugin.CycleState, response *InferenceResponse) error } +// ResponseChunkProcessor processes individual response body chunks as they +// stream through without buffering. The framework converts the raw chunk bytes +// to a string once and passes it to all chunk processors. Plugins receive the +// InferenceResponse to allow header mutation. +type ResponseChunkProcessor interface { + plugin.Plugin + ProcessResponseChunk(ctx context.Context, cycleState *plugin.CycleState, response *InferenceResponse, chunk string, isFinal bool) error +} + type PostProcessor interface { plugin.Plugin diff --git a/pkg/framework/interface/requesthandling/types.go b/pkg/framework/interface/requesthandling/types.go index 78afa58d..43c70fbb 100644 --- a/pkg/framework/interface/requesthandling/types.go +++ b/pkg/framework/interface/requesthandling/types.go @@ -93,6 +93,30 @@ type InferenceRequest struct { type InferenceResponse struct { InferenceMessage + + // CurrentChunk holds the current response body chunk during streaming. + // Set by the framework before calling ResponseChunkProcessor plugins. + // Plugins can read or mutate this field; the framework uses the final + // value when building the ext_proc response. + CurrentChunk string + chunkMutated bool +} + +// SetChunk sets the current chunk content and marks it as mutated. +func (r *InferenceResponse) SetChunk(chunk string) { + r.CurrentChunk = chunk + r.chunkMutated = true +} + +// ChunkMutated reports whether any plugin modified the chunk via SetChunk. +func (r *InferenceResponse) ChunkMutated() bool { + return r.chunkMutated +} + +// ResetChunkState prepares the response for a new chunk processing cycle. +func (r *InferenceResponse) ResetChunkState(chunk string) { + r.CurrentChunk = chunk + r.chunkMutated = false } // NewInferenceRequest returns a new request with initialized Headers, Body, and mutatedHeaders. @@ -114,9 +138,14 @@ type Profile struct { // RequestPlugins are the request processing plugin instances executed by the request handler, // in the same order provided in the configuration file. RequestPlugins []RequestProcessor - // ResponsePlugins are the response processing plugin instances executed by the response handler, - // in the same order provided in the configuration file. + // ResponsePlugins process the complete buffered response body. ResponsePlugins []ResponseProcessor + // ResponseChunkProcessors process individual response chunks without buffering. + ResponseChunkProcessors []ResponseChunkProcessor + // NeedsResponseBuffering is true when any ResponsePlugin is present. + // The framework uses this to decide whether to buffer the full response body + // or stream chunks through ResponseChunkProcessors. + NeedsResponseBuffering bool // ModelSelectorPlugins are the Filter, Scorer (including WeightedScorer), and Picker plugin // instances to be wired into any model-selector plugin present in RequestPlugins. ModelSelectorPlugins []plugin.Plugin diff --git a/pkg/framework/interface/requesthandling/types_test.go b/pkg/framework/interface/requesthandling/types_test.go index f9c4843d..49fcb1f7 100644 --- a/pkg/framework/interface/requesthandling/types_test.go +++ b/pkg/framework/interface/requesthandling/types_test.go @@ -90,6 +90,48 @@ func TestSetBody(t *testing.T) { } } +func TestChunkMutation(t *testing.T) { + resp := NewInferenceResponse() + + if resp.ChunkMutated() { + t.Error("new InferenceResponse should not be marked as chunk-mutated") + } + + resp.ResetChunkState("original chunk") + if resp.CurrentChunk != "original chunk" { + t.Errorf("CurrentChunk = %q; want %q", resp.CurrentChunk, "original chunk") + } + if resp.ChunkMutated() { + t.Error("ResetChunkState should not mark chunk as mutated") + } + + resp.SetChunk("modified chunk") + if resp.CurrentChunk != "modified chunk" { + t.Errorf("CurrentChunk = %q; want %q", resp.CurrentChunk, "modified chunk") + } + if !resp.ChunkMutated() { + t.Error("expected ChunkMutated() to return true after SetChunk") + } +} + +func TestChunkMutation_ResetClearsMutatedFlag(t *testing.T) { + resp := NewInferenceResponse() + resp.ResetChunkState("chunk 1") + resp.SetChunk("mutated chunk 1") + + if !resp.ChunkMutated() { + t.Error("expected ChunkMutated() true after SetChunk") + } + + resp.ResetChunkState("chunk 2") + if resp.ChunkMutated() { + t.Error("ResetChunkState should clear the mutated flag") + } + if resp.CurrentChunk != "chunk 2" { + t.Errorf("CurrentChunk = %q; want %q", resp.CurrentChunk, "chunk 2") + } +} + func TestBodyMutated_FalseByDefault(t *testing.T) { req := NewInferenceRequest() if req.BodyMutated() { diff --git a/pkg/handlers/response.go b/pkg/handlers/response.go index 03210f11..e3830cf2 100644 --- a/pkg/handlers/response.go +++ b/pkg/handlers/response.go @@ -131,6 +131,89 @@ func (s *Server) generateEmptyResponseBodyResponse(responseBodyBytes []byte) []* return responses } +// HandleResponseChunk runs ResponseChunkProcessors on a single response body chunk +// and wraps the result in the ext_proc streaming response format. +func (s *Server) HandleResponseChunk(ctx context.Context, reqCtx *RequestContext, chunkBytes []byte, endOfStream bool) ([]*eppb.ProcessingResponse, error) { + // Bodiless requests (e.g., GET /v1/models) may not have a profile set. + if reqCtx.Profile == nil || len(reqCtx.Profile.ResponseChunkProcessors) == 0 { + return s.buildStreamedChunkResponse(reqCtx, chunkBytes, endOfStream), nil + } + + logger := log.FromContext(ctx).V(logutil.DEFAULT) + + chunk := string(chunkBytes) + reqCtx.Response.ResetChunkState(chunk) + + if err := s.runResponseChunkProcessors(ctx, reqCtx.CycleState, reqCtx.Response, chunk, endOfStream, reqCtx.Profile.ResponseChunkProcessors); err != nil { + logger.Error(err, "Failed to run response chunk processors") + return nil, err + } + + outBytes := chunkBytes + if reqCtx.Response.ChunkMutated() { + outBytes = []byte(reqCtx.Response.CurrentChunk) + } + + return s.buildStreamedChunkResponse(reqCtx, outBytes, endOfStream), nil +} + +// runResponseChunkProcessors executes chunk processors in the order they were registered. +// Each plugin receives response.CurrentChunk so mutations from earlier plugins are visible +// to later ones in the chain. +func (s *Server) runResponseChunkProcessors(ctx context.Context, cycleState *plugin.CycleState, response *requesthandling.InferenceResponse, chunk string, isFinal bool, processors []requesthandling.ResponseChunkProcessor) error { + logger := log.FromContext(ctx).V(logutil.DEFAULT) + verboseLogger := logger.V(logutil.VERBOSE) + + for _, cp := range processors { + if verboseLogger.Enabled() { + verboseLogger.Info("Executing response chunk plugin", "plugin", cp.TypedName()) + } + before := time.Now() + err := cp.ProcessResponseChunk(ctx, cycleState, response, response.CurrentChunk, isFinal) + metrics.RecordPluginProcessingLatency(responsePluginExtensionPoint, cp.TypedName().Type, cp.TypedName().Name, time.Since(before)) + if err != nil { + return err + } + } + return nil +} + +// buildStreamedChunkResponse wraps a chunk in the ext_proc streaming response format. +// On the first call (responseHeadersSent=false), it prepends a HeadersResponse to answer +// the deferred response headers — envoy requires this before it accepts body responses. +func (s *Server) buildStreamedChunkResponse(reqCtx *RequestContext, chunk []byte, endOfStream bool) []*eppb.ProcessingResponse { + responses := []*eppb.ProcessingResponse{ + { + Response: &eppb.ProcessingResponse_ResponseBody{ + ResponseBody: &eppb.BodyResponse{ + Response: &eppb.CommonResponse{ + BodyMutation: &eppb.BodyMutation{ + Mutation: &eppb.BodyMutation_StreamedResponse{ + StreamedResponse: &eppb.StreamedBodyResponse{ + Body: chunk, + EndOfStream: endOfStream, + }, + }, + }, + }, + }, + }, + }, + } + + if !reqCtx.ResponseHeadersSent { + headerResp := &eppb.ProcessingResponse{ + Response: &eppb.ProcessingResponse_ResponseHeaders{ + ResponseHeaders: &eppb.HeadersResponse{}, + }, + } + responses = append([]*eppb.ProcessingResponse{headerResp}, responses...) + reqCtx.ResponseHeadersSent = true + } + + return responses +} + // HandleResponseTrailers handles response trailers. func (s *Server) HandleResponseTrailers(trailers *eppb.HttpTrailers) ([]*eppb.ProcessingResponse, error) { return []*eppb.ProcessingResponse{ diff --git a/pkg/handlers/server.go b/pkg/handlers/server.go index 3cab9c26..b4288726 100644 --- a/pkg/handlers/server.go +++ b/pkg/handlers/server.go @@ -76,6 +76,7 @@ type RequestContext struct { RequestSentTimestamp time.Time ResponseFirstChunkTimestamp time.Time ResponseCompleteTimestamp time.Time + ResponseHeadersSent bool Profile *requesthandling.Profile CycleState *plugin.CycleState Request *requesthandling.InferenceRequest @@ -176,15 +177,25 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { if reqCtx.ResponseFirstChunkTimestamp.IsZero() { reqCtx.ResponseFirstChunkTimestamp = time.Now() } - responseBody = append(responseBody, v.ResponseBody.Body...) - if !v.ResponseBody.EndOfStream { - continue + + if reqCtx.Profile.NeedsResponseBuffering { + responseBody = append(responseBody, v.ResponseBody.Body...) + if !v.ResponseBody.EndOfStream { + // Keep accumulating — don't send responses or record metrics yet. + break + } + responses, err = s.HandleResponseBody(ctx, reqCtx, responseBody) + loggerVerbose.Info("processing response body complete") + } else { + responses, err = s.HandleResponseChunk(ctx, reqCtx, v.ResponseBody.Body, v.ResponseBody.EndOfStream) + loggerVerbose.Info("response chunk processing complete") + } + + if v.ResponseBody.EndOfStream { + reqCtx.ResponseCompleteTimestamp = time.Now() + model, _ := reqCtx.Request.Body["model"].(string) + metrics.RecordRequestTTFT(model, reqCtx.ResponseFirstChunkTimestamp.Sub(reqCtx.RequestReceivedTimestamp)) } - reqCtx.ResponseCompleteTimestamp = time.Now() - model, _ := reqCtx.Request.Body["model"].(string) - metrics.RecordRequestTTFT(model, reqCtx.ResponseFirstChunkTimestamp.Sub(reqCtx.RequestReceivedTimestamp)) - responses, err = s.HandleResponseBody(ctx, reqCtx, responseBody) - loggerVerbose.Info("processing response body complete") case *extProcPb.ProcessingRequest_ResponseTrailers: responses, err = s.HandleResponseTrailers(v.ResponseTrailers) default: diff --git a/pkg/handlers/server_test.go b/pkg/handlers/server_test.go index 332d71d4..75087c58 100644 --- a/pkg/handlers/server_test.go +++ b/pkg/handlers/server_test.go @@ -123,6 +123,7 @@ func TestHandleResponseBody_Streaming(t *testing.T) { wantFullBody := []byte(`{"choices":[{"text":"Hello!"}]}`) profiles := newTestProfiles() + profiles[testProfileName].NeedsResponseBuffering = true ref := newServerForTest(profiles) want, err := ref.HandleResponseBody(ctx, newTestRequestContext(profiles), wantFullBody) if err != nil { @@ -164,6 +165,7 @@ func TestHandleResponseBody_Streaming(t *testing.T) { t.Run(tc.name, func(t *testing.T) { streamCtx, cancel := context.WithCancel(logutil.NewTestLoggerIntoContext(context.Background())) profiles := newTestProfiles() + profiles[testProfileName].NeedsResponseBuffering = true srv := newServerForTest(profiles) testListener, errChan := utils.SetupTestStreamingServer(t, streamCtx, srv) process, conn := utils.GetStreamingServerClient(streamCtx, t)