diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 5789b3cc95..823ca43b51 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -106,6 +106,7 @@ func (s *StreamingServer) generateHeaders(reqCtx *RequestContext) []*configPb.He }, }, } + if reqCtx.RequestSize > 0 { // We need to update the content length header if the body is mutated, see Envoy doc: // https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto diff --git a/pkg/epp/handlers/response.go b/pkg/epp/handlers/response.go index 47c0b9d742..97369b3d32 100644 --- a/pkg/epp/handlers/response.go +++ b/pkg/epp/handlers/response.go @@ -42,6 +42,15 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques logger.V(logutil.DEFAULT).Error(err, "error marshalling responseBody") return reqCtx, err } + + // Extract worker_instance_id from response body to include in response headers + if workerInstanceID, exists := response["worker_instance_id"]; exists { + if workerIDStr, ok := workerInstanceID.(string); ok && workerIDStr != "" { + reqCtx.WorkerInstanceID = workerIDStr + logger.V(logutil.VERBOSE).Info("Extracted worker instance ID from response", "worker_instance_id", workerIDStr) + } + } + if response["usage"] != nil { usg := response["usage"].(map[string]any) usage := Usage{ @@ -66,12 +75,37 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques // The function is to handle streaming response if the modelServer is streaming. func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) { + logger := log.FromContext(ctx) + if strings.Contains(responseText, streamingEndMsg) { resp := parseRespForUsage(ctx, responseText) reqCtx.Usage = resp.Usage metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, resp.Usage.PromptTokens) metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, resp.Usage.CompletionTokens) } + + // Extract worker_instance_id from streaming response text + lines := strings.Split(responseText, "\n") + for _, line := range lines { + if !strings.HasPrefix(line, streamingRespPrefix) { + continue + } + content := strings.TrimPrefix(line, streamingRespPrefix) + if content == "[DONE]" { + continue + } + + var responseData map[string]any + if err := json.Unmarshal([]byte(content), &responseData); err == nil { + if workerInstanceID, exists := responseData["worker_instance_id"]; exists { + if workerIDStr, ok := workerInstanceID.(string); ok && workerIDStr != "" { + reqCtx.WorkerInstanceID = workerIDStr + logger.V(logutil.VERBOSE).Info("Extracted worker instance ID from streaming response", "worker_instance_id", workerIDStr) + break // Found worker_instance_id, no need to continue parsing + } + } + } + } } func (s *StreamingServer) HandleResponseHeaders(ctx context.Context, reqCtx *RequestContext, resp *extProcPb.ProcessingRequest_ResponseHeaders) (*RequestContext, error) { @@ -130,6 +164,16 @@ func (s *StreamingServer) generateResponseHeaders(reqCtx *RequestContext) []*con }, } + // Add worker_instance_id to response headers if available + if reqCtx.WorkerInstanceID != "" { + headers = append(headers, &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: "x-worker-instance-id", + RawValue: []byte(reqCtx.WorkerInstanceID), + }, + }) + } + // include all headers for key, value := range reqCtx.Response.Headers { headers = append(headers, &configPb.HeaderValueOption{ diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 30596606ec..3a2841fe28 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -109,12 +109,15 @@ type RequestContext struct { respHeaderResp *extProcPb.ProcessingResponse respBodyResp []*extProcPb.ProcessingResponse respTrailerResp *extProcPb.ProcessingResponse + + WorkerInstanceID string } type Request struct { - Headers map[string]string - Body map[string]any - Metadata map[string]any + Headers map[string]string + Body map[string]any + Metadata map[string]any + Annotations []string } type Response struct { Headers map[string]string @@ -143,9 +146,10 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) reqCtx := &RequestContext{ RequestState: RequestReceived, Request: &Request{ - Headers: make(map[string]string), - Body: make(map[string]any), - Metadata: make(map[string]any), + Headers: make(map[string]string), + Body: make(map[string]any), + Metadata: make(map[string]any), + Annotations: []string{}, }, Response: &Response{ Headers: make(map[string]string), @@ -221,16 +225,41 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) reqCtx, err = s.director.HandleRequest(ctx, reqCtx) if err != nil { - logger.V(logutil.DEFAULT).Error(err, "Error handling request") + logger.V(logutil.DEFAULT).Error(err, "Error handling request!") break } + // Add query_instance_id annotation to the request metadata before sending to FrontEnd + if reqCtx.Request != nil { + // Ensure Annotations slice is initialized + if reqCtx.Request.Annotations == nil { + reqCtx.Request.Annotations = []string{} + } + + // Add the annotation (if not already present) + found := false + for _, a := range reqCtx.Request.Annotations { + if a == "query_instance_id" { + found = true + break + } + } + + if !found { + reqCtx.Request.Annotations = append(reqCtx.Request.Annotations, "query_instance_id") + logger.V(logutil.VERBOSE).Info("Added query_instance_id annotation to request") + } + } + // Populate the ExtProc protocol responses for the request body. requestBodyBytes, err := json.Marshal(reqCtx.Request.Body) if err != nil { logger.V(logutil.DEFAULT).Error(err, "Error marshalling request body") break } + + // Log the complete request body being sent to FrontEnd for debugging + logger.V(logutil.VERBOSE).Info("Sending request body to FrontEnd", "request_body", string(requestBodyBytes)) reqCtx.RequestSize = len(requestBodyBytes) reqCtx.reqHeaderResp = s.generateRequestHeaderResponse(reqCtx) reqCtx.reqBodyResp = s.generateRequestBodyResponses(requestBodyBytes)