diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index d7678fadf6..591ddf2627 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -133,10 +133,6 @@ func (s *Server) HandleRequestBody( }, }, } - // Print headers for debugging - for _, header := range headers { - logger.V(logutil.DEBUG).Info("Request body header", "key", header.Header.Key, "value", header.Header.RawValue) - } targetEndpointValue := &structpb.Struct{ Fields: map[string]*structpb.Value{ @@ -184,14 +180,42 @@ func (s *Server) HandleRequestBody( return resp, nil } -func HandleRequestHeaders( +func (s *Server) HandleRequestHeaders( ctx context.Context, reqCtx *RequestContext, req *extProcPb.ProcessingRequest, ) *extProcPb.ProcessingResponse { + logger := log.FromContext(ctx) r := req.Request h := r.(*extProcPb.ProcessingRequest_RequestHeaders) - log.FromContext(ctx).V(logutil.VERBOSE).Info("Handling request headers", "headers", h) + logger.V(logutil.VERBOSE).Info("Handling request headers", "headers", h) + + // Call FrontEnd service to get worker instance ID + if err := s.fetchWorkerIDFromFrontEnd(ctx, reqCtx); err != nil { + logger.V(logutil.DEFAULT).Error(err, "Failed to fetch worker ID from FrontEnd service") + // Continue processing even if FrontEnd call fails + } + + // Create headers array starting with worker ID if available + var headers []*configPb.HeaderValueOption + + // Add worker ID header if obtained from FrontEnd service + if reqCtx.WorkerInstanceID != "" { + headers = append(headers, &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: "x-gateway-worker-id", + RawValue: []byte(reqCtx.WorkerInstanceID), + }, + AppendAction: configPb.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD, + }) + headers = append(headers, &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: "host_rewrite_header", + RawValue: []byte("x-gateway-destination-endpoint"), + }, + AppendAction: configPb.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD, + }) + } resp := &extProcPb.ProcessingResponse{ Response: &extProcPb.ProcessingResponse_RequestHeaders{ @@ -201,6 +225,9 @@ func HandleRequestHeaders( // based on the new "target-pod" header. // See https://www.envoyproxy.io/docs/envoy/latest/api-v3/service/ext_proc/v3/external_processor.proto#service-ext-proc-v3-commonresponse. ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: headers, + }, }, }, }, diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index a92f091c55..5036c12ede 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -18,9 +18,14 @@ package handlers import ( "context" + "encoding/json" + "fmt" "io" + "net/http" "time" + "bytes" + extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3" "google.golang.org/grpc/codes" @@ -40,6 +45,11 @@ func NewServer(scheduler Scheduler, destinationEndpointHintMetadataNamespace, de destinationEndpointHintMetadataNamespace: destinationEndpointHintMetadataNamespace, destinationEndpointHintKey: destinationEndpointHintKey, datastore: datastore, + frontEndAddress: "localhost", + frontEndPort: "8000", + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, } } @@ -54,6 +64,11 @@ type Server struct { // back the picked endpoints. destinationEndpointHintMetadataNamespace string datastore datastore.Datastore + + // FrontEnd service configuration + frontEndAddress string + frontEndPort string + httpClient *http.Client } type Scheduler interface { @@ -104,7 +119,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { switch v := req.Request.(type) { case *extProcPb.ProcessingRequest_RequestHeaders: reqCtx.RequestReceivedTimestamp = time.Now() - resp = HandleRequestHeaders(ctx, reqCtx, req) + resp = s.HandleRequestHeaders(ctx, reqCtx, req) loggerVerbose.Info("Request context after HandleRequestHeaders", "context", reqCtx) case *extProcPb.ProcessingRequest_RequestBody: resp, err = s.HandleRequestBody(ctx, reqCtx, req) @@ -229,6 +244,7 @@ type RequestContext struct { ResponseComplete bool ResponseStatusCode string RequestRunning bool + WorkerInstanceID string // Worker ID from FrontEnd service RequestState StreamRequestState modelServerStreaming bool @@ -254,3 +270,91 @@ const ( BodyResponseResponsesComplete StreamRequestState = 6 TrailerResponseResponsesComplete StreamRequestState = 7 ) + +// fetchWorkerIDFromFrontEnd makes a blocking HTTP request to the FrontEnd service +// to get the worker_instance_id for this request +func (s *Server) fetchWorkerIDFromFrontEnd(ctx context.Context, reqCtx *RequestContext) error { + logger := log.FromContext(ctx) + logger.V(logutil.DEFAULT).Info("Starting fetchWorkerIDFromFrontEnd function") + + // Construct FrontEnd URL + frontEndURL := fmt.Sprintf("http://%s:%s/v1/chat/completions", s.frontEndAddress, s.frontEndPort) + logger.V(logutil.DEFAULT).Info("Constructed FrontEnd URL", "url", frontEndURL) + + // Create request body with nvext annotations + requestBody := map[string]interface{}{ + "nvext": map[string]interface{}{ + "annotations": []string{"query_instance_id"}, + }, + } + logger.V(logutil.DEFAULT).Info("Created request body structure", "requestBody", requestBody) + + requestJSON, err := json.Marshal(requestBody) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Failed to marshal request body") + return fmt.Errorf("failed to marshal request body: %w", err) + } + logger.V(logutil.DEFAULT).Info("Successfully marshaled request body to JSON", "json", string(requestJSON)) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, frontEndURL, bytes.NewBuffer(requestJSON)) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Failed to create FrontEnd request") + return fmt.Errorf("failed to create FrontEnd request: %w", err) + } + logger.V(logutil.DEFAULT).Info("Successfully created HTTP request object") + + // Set appropriate headers + req.Header.Set("Content-Type", "application/json") + logger.V(logutil.DEFAULT).Info("Set Content-Type header to application/json") + + // Make the blocking HTTP request to fetch Routing + logger.V(logutil.DEFAULT).Info("Making blocking HTTP request to FrontEnd", + "url", frontEndURL, + "body", string(requestJSON)) + resp, err := s.httpClient.Do(req) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Failed to make FrontEnd request") + return fmt.Errorf("failed to make FrontEnd request: %w", err) + } + defer resp.Body.Close() + logger.V(logutil.DEFAULT).Info("Successfully received HTTP response from FrontEnd", "statusCode", resp.StatusCode) + + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Failed to read FrontEnd response body") + return fmt.Errorf("failed to read FrontEnd response body: %w", err) + } + logger.V(logutil.DEFAULT).Info("Successfully read response body", "bodyLength", len(body), "body", string(body)) + + // Parse JSON response + var responseData map[string]interface{} + if err := json.Unmarshal(body, &responseData); err != nil { + logger.V(logutil.DEFAULT).Error(err, "Failed to unmarshal FrontEnd response") + return fmt.Errorf("failed to unmarshal FrontEnd response: %w", err) + } + logger.V(logutil.DEFAULT).Info("Successfully unmarshaled JSON response", "responseData", responseData) + + // Extract worker_instance_id + if workerID, exists := responseData["worker_instance_id"]; exists { + logger.V(logutil.DEFAULT).Info("Found worker_instance_id key in response", "value", workerID) + if workerIDStr, ok := workerID.(string); ok && workerIDStr != "" { + reqCtx.WorkerInstanceID = workerIDStr + logger.V(logutil.DEFAULT).Info("Extracted worker instance ID from FrontEnd", "worker_instance_id", workerIDStr) + logger.V(logutil.DEFAULT).Info("Successfully completed fetchWorkerIDFromFrontEnd function") + return nil + } + logger.V(logutil.DEFAULT).Info("worker_instance_id exists but is not a valid non-empty string", "type", fmt.Sprintf("%T", workerID), "value", workerID) + } else { + logger.V(logutil.DEFAULT).Info("worker_instance_id key not found in response", "availableKeys", getMapKeys(responseData)) + } + + logger.V(logutil.DEFAULT).Info("worker_instance_id not found or empty in FrontEnd response") + return fmt.Errorf("worker_instance_id not found in FrontEnd response") +} + +// SetFrontEndConfig allows external configuration of the FrontEnd service +func (s *Server) SetFrontEndConfig(address, port string) { + s.frontEndAddress = address + s.frontEndPort = port +} diff --git a/pkg/epp/handlers/streamingserver.go b/pkg/epp/handlers/streamingserver.go index 874dd734f4..7cd8d4aabb 100644 --- a/pkg/epp/handlers/streamingserver.go +++ b/pkg/epp/handlers/streamingserver.go @@ -17,11 +17,13 @@ limitations under the License. package handlers import ( + "bytes" "context" "encoding/json" "fmt" "io" "math/rand" + "net/http" "strconv" "strings" "time" @@ -48,6 +50,12 @@ func NewStreamingServer(scheduler Scheduler, destinationEndpointHintMetadataName destinationEndpointHintMetadataNamespace: destinationEndpointHintMetadataNamespace, destinationEndpointHintKey: destinationEndpointHintKey, datastore: datastore, + + // Default FrontEnd configuration (same sidecar) + frontEndPort: "8000", + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, } } @@ -60,6 +68,10 @@ type StreamingServer struct { // back the picked endpoints. destinationEndpointHintMetadataNamespace string datastore datastore.Datastore + + // FrontEnd service configuration + frontEndPort string + httpClient *http.Client } func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { @@ -380,6 +392,12 @@ func (s *StreamingServer) HandleRequestBody( "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod, "endpoint metrics", fmt.Sprintf("%+v", target)) + // Call FrontEnd service to get worker instance ID + if err = s.fetchWorkerIDFromFrontEnd(ctx, reqCtx); err != nil { + logger.V(logutil.DEFAULT).Error(err, "Failed to fetch worker ID from FrontEnd service") + // Continue processing even if FrontEnd call fails + } + reqCtx.Model = llmReq.Model reqCtx.ResolvedTargetModel = llmReq.ResolvedTargetModel reqCtx.RequestSize = len(requestBodyBytes) @@ -495,6 +513,102 @@ func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *Requ return nil } +// SetFrontEndConfig allows configuring the FrontEnd service endpoint +func (s *StreamingServer) SetFrontEndConfig(port string) { + s.frontEndPort = port +} + +// fetchWorkerIDFromFrontEnd makes a blocking HTTP request to the FrontEnd service +// to get the worker_instance_id for this request +func (s *StreamingServer) fetchWorkerIDFromFrontEnd(ctx context.Context, reqCtx *RequestContext) error { + logger := log.FromContext(ctx) + logger.V(logutil.DEFAULT).Info("Starting fetchWorkerIDFromFrontEnd function") + + // Construct FrontEnd URL + frontEndURL := fmt.Sprintf("http://localhost:%s/v1/chat/completions", s.frontEndPort) + logger.V(logutil.DEFAULT).Info("Constructed FrontEnd URL", "url", frontEndURL) + + // Create request body with nvext annotations + requestBody := map[string]interface{}{ + "nvext": map[string]interface{}{ + "annotations": []string{"query_instance_id"}, + }, + } + logger.V(logutil.DEFAULT).Info("Created request body structure", "requestBody", requestBody) + + requestJSON, err := json.Marshal(requestBody) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Failed to marshal request body") + return fmt.Errorf("failed to marshal request body: %w", err) + } + logger.V(logutil.DEFAULT).Info("Successfully marshaled request body to JSON", "json", string(requestJSON)) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, frontEndURL, bytes.NewBuffer(requestJSON)) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Failed to create FrontEnd request") + return fmt.Errorf("failed to create FrontEnd request: %w", err) + } + logger.V(logutil.DEFAULT).Info("Successfully created HTTP request object") + + // Set appropriate headers + req.Header.Set("Content-Type", "application/json") + logger.V(logutil.DEFAULT).Info("Set Content-Type header to application/json") + + // Make the blocking HTTP request to fetch Routing + logger.V(logutil.DEFAULT).Info("Making blocking HTTP request to FrontEnd", + "url", frontEndURL, + "body", string(requestJSON)) + resp, err := s.httpClient.Do(req) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Failed to make FrontEnd request") + return fmt.Errorf("failed to make FrontEnd request: %w", err) + } + defer resp.Body.Close() + logger.V(logutil.DEFAULT).Info("Successfully received HTTP response from FrontEnd", "statusCode", resp.StatusCode) + + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Failed to read FrontEnd response body") + return fmt.Errorf("failed to read FrontEnd response body: %w", err) + } + logger.V(logutil.DEFAULT).Info("Successfully read response body", "bodyLength", len(body), "body", string(body)) + + // Parse JSON response + var responseData map[string]interface{} + if err := json.Unmarshal(body, &responseData); err != nil { + logger.V(logutil.DEFAULT).Error(err, "Failed to unmarshal FrontEnd response") + return fmt.Errorf("failed to unmarshal FrontEnd response: %w", err) + } + logger.V(logutil.DEFAULT).Info("Successfully unmarshaled JSON response", "responseData", responseData) + + // Extract worker_instance_id + if workerID, exists := responseData["worker_instance_id"]; exists { + logger.V(logutil.DEFAULT).Info("Found worker_instance_id key in response", "value", workerID) + if workerIDStr, ok := workerID.(string); ok && workerIDStr != "" { + reqCtx.WorkerInstanceID = workerIDStr + logger.V(logutil.DEFAULT).Info("Extracted worker instance ID from FrontEnd", "worker_instance_id", workerIDStr) + logger.V(logutil.DEFAULT).Info("Successfully completed fetchWorkerIDFromFrontEnd function") + return nil + } + logger.V(logutil.DEFAULT).Info("worker_instance_id exists but is not a valid non-empty string", "type", fmt.Sprintf("%T", workerID), "value", workerID) + } else { + logger.V(logutil.DEFAULT).Info("worker_instance_id key not found in response", "availableKeys", getMapKeys(responseData)) + } + + logger.V(logutil.DEFAULT).Info("worker_instance_id not found or empty in FrontEnd response") + return fmt.Errorf("worker_instance_id not found in FrontEnd response") +} + +// getMapKeys returns a slice of keys from a map[string]interface{} +func getMapKeys(m map[string]interface{}) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} + func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, endpoint string, requestBodyLength int) { headers := []*configPb.HeaderValueOption{ { @@ -504,6 +618,25 @@ func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, }, }, } + + // Add worker ID header if obtained from FrontEnd service + if reqCtx.WorkerInstanceID != "" { + headers = append(headers, &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: "x-gateway-worker-id", + RawValue: []byte(reqCtx.WorkerInstanceID), + }, + AppendAction: configPb.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD, + }) + headers = append(headers, &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: "host_rewrite_header", + RawValue: []byte("x-gateway-destination-endpoint"), + }, + AppendAction: configPb.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD, + }) + } + if requestBodyLength > 0 { // We need to update the content length header if the body is mutated, see Envoy doc: // https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto