Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/epp/handlers/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ func (s *StreamingServer) generateHeaders(reqCtx *RequestContext) []*configPb.He
},
},
}

if reqCtx.RequestSize > 0 {
// We need to update the content length header if the body is mutated, see Envoy doc:
// https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto
Expand Down
44 changes: 44 additions & 0 deletions pkg/epp/handlers/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques
logger.V(logutil.DEFAULT).Error(err, "error marshalling responseBody")
return reqCtx, err
}

// Extract worker_instance_id from response body to include in response headers
if workerInstanceID, exists := response["worker_instance_id"]; exists {
if workerIDStr, ok := workerInstanceID.(string); ok && workerIDStr != "" {
reqCtx.WorkerInstanceID = workerIDStr
logger.V(logutil.VERBOSE).Info("Extracted worker instance ID from response", "worker_instance_id", workerIDStr)
}
}

if response["usage"] != nil {
usg := response["usage"].(map[string]any)
usage := Usage{
Expand All @@ -66,12 +75,37 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques

// HandleResponseBodyModelStreaming processes one piece of a streamed model
// server response.
//
// When responseText contains streamingEndMsg, the usage parsed from it is
// stored on reqCtx and the prompt/completion token counts are recorded as
// metrics. Independently of that, every line carrying streamingRespPrefix is
// inspected for a non-empty string "worker_instance_id" field; the first one
// found is stored in reqCtx.WorkerInstanceID and scanning stops. Lines that
// are not valid JSON are ignored on purpose (best-effort extraction).
func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) {
	logger := log.FromContext(ctx)

	// Usage accounting is only attempted on text that carries the end marker.
	if strings.Contains(responseText, streamingEndMsg) {
		parsed := parseRespForUsage(ctx, responseText)
		reqCtx.Usage = parsed.Usage
		metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, parsed.Usage.PromptTokens)
		metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, parsed.Usage.CompletionTokens)
	}

	// Look for worker_instance_id in each prefixed line of the text.
	for _, rawLine := range strings.Split(responseText, "\n") {
		if !strings.HasPrefix(rawLine, streamingRespPrefix) {
			continue
		}

		payload := rawLine[len(streamingRespPrefix):]
		if payload == "[DONE]" {
			continue
		}

		var chunk map[string]any
		if json.Unmarshal([]byte(payload), &chunk) != nil {
			// Not a JSON object; nothing to extract from this line.
			continue
		}

		// A missing key yields a nil interface, so the comma-ok assertion
		// covers both "absent" and "present but not a string".
		workerID, isString := chunk["worker_instance_id"].(string)
		if !isString || workerID == "" {
			continue
		}

		reqCtx.WorkerInstanceID = workerID
		logger.V(logutil.VERBOSE).Info("Extracted worker instance ID from streaming response", "worker_instance_id", workerID)
		break // Found worker_instance_id, no need to continue parsing
	}
}

func (s *StreamingServer) HandleResponseHeaders(ctx context.Context, reqCtx *RequestContext, resp *extProcPb.ProcessingRequest_ResponseHeaders) (*RequestContext, error) {
Expand Down Expand Up @@ -130,6 +164,16 @@ func (s *StreamingServer) generateResponseHeaders(reqCtx *RequestContext) []*con
},
}

// Add worker_instance_id to response headers if available
if reqCtx.WorkerInstanceID != "" {
headers = append(headers, &configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: "x-worker-instance-id",
RawValue: []byte(reqCtx.WorkerInstanceID),
},
})
}

// include all headers
for key, value := range reqCtx.Response.Headers {
headers = append(headers, &configPb.HeaderValueOption{
Expand Down
43 changes: 36 additions & 7 deletions pkg/epp/handlers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,15 @@ type RequestContext struct {
respHeaderResp *extProcPb.ProcessingResponse
respBodyResp []*extProcPb.ProcessingResponse
respTrailerResp *extProcPb.ProcessingResponse

WorkerInstanceID string
}

// Request is the per-request state carried on a RequestContext.
//
// Each field is declared exactly once; repeating a field name inside a struct
// is a compile error.
type Request struct {
	// Headers holds string header values keyed by header name.
	// NOTE(review): presumably the inbound request headers; confirm against the code that fills it.
	Headers map[string]string
	// Body is the decoded request body; Process marshals it back to JSON
	// before building the request body responses.
	Body map[string]any
	// Metadata holds arbitrary values keyed by name.
	// NOTE(review): producer and consumers are not visible in this section.
	Metadata map[string]any
	// Annotations is a list of annotation names attached to the request;
	// Process appends "query_instance_id" to it when not already present.
	Annotations []string
}
type Response struct {
Headers map[string]string
Expand Down Expand Up @@ -143,9 +146,10 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
reqCtx := &RequestContext{
RequestState: RequestReceived,
Request: &Request{
Headers: make(map[string]string),
Body: make(map[string]any),
Metadata: make(map[string]any),
Headers: make(map[string]string),
Body: make(map[string]any),
Metadata: make(map[string]any),
Annotations: []string{},
},
Response: &Response{
Headers: make(map[string]string),
Expand Down Expand Up @@ -221,16 +225,41 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)

reqCtx, err = s.director.HandleRequest(ctx, reqCtx)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Error handling request")
logger.V(logutil.DEFAULT).Error(err, "Error handling request!")
break
}

// Add the query_instance_id annotation to the request's Annotations list before sending to FrontEnd
if reqCtx.Request != nil {
// Ensure Annotations slice is initialized
if reqCtx.Request.Annotations == nil {
reqCtx.Request.Annotations = []string{}
}

// Add the annotation (if not already present)
found := false
for _, a := range reqCtx.Request.Annotations {
if a == "query_instance_id" {
found = true
break
}
}

if !found {
reqCtx.Request.Annotations = append(reqCtx.Request.Annotations, "query_instance_id")
logger.V(logutil.VERBOSE).Info("Added query_instance_id annotation to request")
}
}

// Populate the ExtProc protocol responses for the request body.
requestBodyBytes, err := json.Marshal(reqCtx.Request.Body)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Error marshalling request body")
break
}

// Log the complete request body being sent to FrontEnd for debugging
logger.V(logutil.VERBOSE).Info("Sending request body to FrontEnd", "request_body", string(requestBodyBytes))
reqCtx.RequestSize = len(requestBodyBytes)
reqCtx.reqHeaderResp = s.generateRequestHeaderResponse(reqCtx)
reqCtx.reqBodyResp = s.generateRequestBodyResponses(requestBodyBytes)
Expand Down