Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions pkg/epp/handlers/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,6 @@ func (s *Server) HandleRequestBody(
},
},
}
// Print headers for debugging
for _, header := range headers {
logger.V(logutil.DEBUG).Info("Request body header", "key", header.Header.Key, "value", header.Header.RawValue)
}

targetEndpointValue := &structpb.Struct{
Fields: map[string]*structpb.Value{
Expand Down Expand Up @@ -184,14 +180,42 @@ func (s *Server) HandleRequestBody(
return resp, nil
}

func HandleRequestHeaders(
func (s *Server) HandleRequestHeaders(
ctx context.Context,
reqCtx *RequestContext,
req *extProcPb.ProcessingRequest,
) *extProcPb.ProcessingResponse {
logger := log.FromContext(ctx)
r := req.Request
h := r.(*extProcPb.ProcessingRequest_RequestHeaders)
log.FromContext(ctx).V(logutil.VERBOSE).Info("Handling request headers", "headers", h)
logger.V(logutil.VERBOSE).Info("Handling request headers", "headers", h)

// Call FrontEnd service to get worker instance ID
if err := s.fetchWorkerIDFromFrontEnd(ctx, reqCtx); err != nil {
logger.V(logutil.DEFAULT).Error(err, "Failed to fetch worker ID from FrontEnd service")
// Continue processing even if FrontEnd call fails
}

// Create headers array starting with worker ID if available
var headers []*configPb.HeaderValueOption

// Add worker ID header if obtained from FrontEnd service
if reqCtx.WorkerInstanceID != "" {
headers = append(headers, &configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: "x-gateway-worker-id",
RawValue: []byte(reqCtx.WorkerInstanceID),
},
AppendAction: configPb.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD,
})
headers = append(headers, &configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: "host_rewrite_header",
RawValue: []byte("x-gateway-destination-endpoint"),
},
AppendAction: configPb.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD,
})
}

resp := &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_RequestHeaders{
Expand All @@ -201,6 +225,9 @@ func HandleRequestHeaders(
// based on the new "target-pod" header.
// See https://www.envoyproxy.io/docs/envoy/latest/api-v3/service/ext_proc/v3/external_processor.proto#service-ext-proc-v3-commonresponse.
ClearRouteCache: true,
HeaderMutation: &extProcPb.HeaderMutation{
SetHeaders: headers,
},
},
},
},
Expand Down
106 changes: 105 additions & 1 deletion pkg/epp/handlers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@ package handlers

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"

"bytes"

extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
"google.golang.org/grpc/codes"
Expand All @@ -40,6 +45,11 @@ func NewServer(scheduler Scheduler, destinationEndpointHintMetadataNamespace, de
destinationEndpointHintMetadataNamespace: destinationEndpointHintMetadataNamespace,
destinationEndpointHintKey: destinationEndpointHintKey,
datastore: datastore,
frontEndAddress: "localhost",
frontEndPort: "8000",
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
}
}

Expand All @@ -54,6 +64,11 @@ type Server struct {
// back the picked endpoints.
destinationEndpointHintMetadataNamespace string
datastore datastore.Datastore

// FrontEnd service configuration
frontEndAddress string
frontEndPort string
httpClient *http.Client
}

type Scheduler interface {
Expand Down Expand Up @@ -104,7 +119,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
switch v := req.Request.(type) {
case *extProcPb.ProcessingRequest_RequestHeaders:
reqCtx.RequestReceivedTimestamp = time.Now()
resp = HandleRequestHeaders(ctx, reqCtx, req)
resp = s.HandleRequestHeaders(ctx, reqCtx, req)
loggerVerbose.Info("Request context after HandleRequestHeaders", "context", reqCtx)
case *extProcPb.ProcessingRequest_RequestBody:
resp, err = s.HandleRequestBody(ctx, reqCtx, req)
Expand Down Expand Up @@ -229,6 +244,7 @@ type RequestContext struct {
ResponseComplete bool
ResponseStatusCode string
RequestRunning bool
WorkerInstanceID string // Worker ID from FrontEnd service

RequestState StreamRequestState
modelServerStreaming bool
Expand All @@ -254,3 +270,91 @@ const (
BodyResponseResponsesComplete StreamRequestState = 6
TrailerResponseResponsesComplete StreamRequestState = 7
)

// fetchWorkerIDFromFrontEnd makes a blocking HTTP request to the FrontEnd service
// to get the worker_instance_id for this request
func (s *Server) fetchWorkerIDFromFrontEnd(ctx context.Context, reqCtx *RequestContext) error {
logger := log.FromContext(ctx)
logger.V(logutil.DEFAULT).Info("Starting fetchWorkerIDFromFrontEnd function")

// Construct FrontEnd URL
frontEndURL := fmt.Sprintf("http://%s:%s/v1/chat/completions", s.frontEndAddress, s.frontEndPort)
logger.V(logutil.DEFAULT).Info("Constructed FrontEnd URL", "url", frontEndURL)

// Create request body with nvext annotations
requestBody := map[string]interface{}{
"nvext": map[string]interface{}{
"annotations": []string{"query_instance_id"},
},
}
logger.V(logutil.DEFAULT).Info("Created request body structure", "requestBody", requestBody)

requestJSON, err := json.Marshal(requestBody)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Failed to marshal request body")
return fmt.Errorf("failed to marshal request body: %w", err)
}
logger.V(logutil.DEFAULT).Info("Successfully marshaled request body to JSON", "json", string(requestJSON))

req, err := http.NewRequestWithContext(ctx, http.MethodPost, frontEndURL, bytes.NewBuffer(requestJSON))
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Failed to create FrontEnd request")
return fmt.Errorf("failed to create FrontEnd request: %w", err)
}
logger.V(logutil.DEFAULT).Info("Successfully created HTTP request object")

// Set appropriate headers
req.Header.Set("Content-Type", "application/json")
logger.V(logutil.DEFAULT).Info("Set Content-Type header to application/json")

// Make the blocking HTTP request to fetch Routing
logger.V(logutil.DEFAULT).Info("Making blocking HTTP request to FrontEnd",
"url", frontEndURL,
"body", string(requestJSON))
resp, err := s.httpClient.Do(req)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Failed to make FrontEnd request")
return fmt.Errorf("failed to make FrontEnd request: %w", err)
}
defer resp.Body.Close()
logger.V(logutil.DEFAULT).Info("Successfully received HTTP response from FrontEnd", "statusCode", resp.StatusCode)

// Read response body
body, err := io.ReadAll(resp.Body)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Failed to read FrontEnd response body")
return fmt.Errorf("failed to read FrontEnd response body: %w", err)
}
logger.V(logutil.DEFAULT).Info("Successfully read response body", "bodyLength", len(body), "body", string(body))

// Parse JSON response
var responseData map[string]interface{}
if err := json.Unmarshal(body, &responseData); err != nil {
logger.V(logutil.DEFAULT).Error(err, "Failed to unmarshal FrontEnd response")
return fmt.Errorf("failed to unmarshal FrontEnd response: %w", err)
}
logger.V(logutil.DEFAULT).Info("Successfully unmarshaled JSON response", "responseData", responseData)

// Extract worker_instance_id
if workerID, exists := responseData["worker_instance_id"]; exists {
logger.V(logutil.DEFAULT).Info("Found worker_instance_id key in response", "value", workerID)
if workerIDStr, ok := workerID.(string); ok && workerIDStr != "" {
reqCtx.WorkerInstanceID = workerIDStr
logger.V(logutil.DEFAULT).Info("Extracted worker instance ID from FrontEnd", "worker_instance_id", workerIDStr)
logger.V(logutil.DEFAULT).Info("Successfully completed fetchWorkerIDFromFrontEnd function")
return nil
}
logger.V(logutil.DEFAULT).Info("worker_instance_id exists but is not a valid non-empty string", "type", fmt.Sprintf("%T", workerID), "value", workerID)
} else {
logger.V(logutil.DEFAULT).Info("worker_instance_id key not found in response", "availableKeys", getMapKeys(responseData))
}

logger.V(logutil.DEFAULT).Info("worker_instance_id not found or empty in FrontEnd response")
return fmt.Errorf("worker_instance_id not found in FrontEnd response")
}

// SetFrontEndConfig allows external configuration of the FrontEnd service
func (s *Server) SetFrontEndConfig(address, port string) {
s.frontEndAddress = address
s.frontEndPort = port
}
133 changes: 133 additions & 0 deletions pkg/epp/handlers/streamingserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ limitations under the License.
package handlers

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"strconv"
"strings"
"time"
Expand All @@ -48,6 +50,12 @@ func NewStreamingServer(scheduler Scheduler, destinationEndpointHintMetadataName
destinationEndpointHintMetadataNamespace: destinationEndpointHintMetadataNamespace,
destinationEndpointHintKey: destinationEndpointHintKey,
datastore: datastore,

// Default FrontEnd configuration (same sidecar)
frontEndPort: "8000",
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
}
}

Expand All @@ -60,6 +68,10 @@ type StreamingServer struct {
// back the picked endpoints.
destinationEndpointHintMetadataNamespace string
datastore datastore.Datastore

// FrontEnd service configuration
frontEndPort string
httpClient *http.Client
}

func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
Expand Down Expand Up @@ -380,6 +392,12 @@ func (s *StreamingServer) HandleRequestBody(
"model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod, "endpoint metrics",
fmt.Sprintf("%+v", target))

// Call FrontEnd service to get worker instance ID
if err = s.fetchWorkerIDFromFrontEnd(ctx, reqCtx); err != nil {
logger.V(logutil.DEFAULT).Error(err, "Failed to fetch worker ID from FrontEnd service")
// Continue processing even if FrontEnd call fails
}

reqCtx.Model = llmReq.Model
reqCtx.ResolvedTargetModel = llmReq.ResolvedTargetModel
reqCtx.RequestSize = len(requestBodyBytes)
Expand Down Expand Up @@ -495,6 +513,102 @@ func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *Requ
return nil
}

// SetFrontEndConfig allows configuring the FrontEnd service endpoint
func (s *StreamingServer) SetFrontEndConfig(port string) {
s.frontEndPort = port
}

// fetchWorkerIDFromFrontEnd makes a blocking HTTP request to the FrontEnd service
// to get the worker_instance_id for this request
func (s *StreamingServer) fetchWorkerIDFromFrontEnd(ctx context.Context, reqCtx *RequestContext) error {
logger := log.FromContext(ctx)
logger.V(logutil.DEFAULT).Info("Starting fetchWorkerIDFromFrontEnd function")

// Construct FrontEnd URL
frontEndURL := fmt.Sprintf("http://localhost:%s/v1/chat/completions", s.frontEndPort)
logger.V(logutil.DEFAULT).Info("Constructed FrontEnd URL", "url", frontEndURL)

// Create request body with nvext annotations
requestBody := map[string]interface{}{
"nvext": map[string]interface{}{
"annotations": []string{"query_instance_id"},
},
}
logger.V(logutil.DEFAULT).Info("Created request body structure", "requestBody", requestBody)

requestJSON, err := json.Marshal(requestBody)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Failed to marshal request body")
return fmt.Errorf("failed to marshal request body: %w", err)
}
logger.V(logutil.DEFAULT).Info("Successfully marshaled request body to JSON", "json", string(requestJSON))

req, err := http.NewRequestWithContext(ctx, http.MethodPost, frontEndURL, bytes.NewBuffer(requestJSON))
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Failed to create FrontEnd request")
return fmt.Errorf("failed to create FrontEnd request: %w", err)
}
logger.V(logutil.DEFAULT).Info("Successfully created HTTP request object")

// Set appropriate headers
req.Header.Set("Content-Type", "application/json")
logger.V(logutil.DEFAULT).Info("Set Content-Type header to application/json")

// Make the blocking HTTP request to fetch Routing
logger.V(logutil.DEFAULT).Info("Making blocking HTTP request to FrontEnd",
"url", frontEndURL,
"body", string(requestJSON))
resp, err := s.httpClient.Do(req)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Failed to make FrontEnd request")
return fmt.Errorf("failed to make FrontEnd request: %w", err)
}
defer resp.Body.Close()
logger.V(logutil.DEFAULT).Info("Successfully received HTTP response from FrontEnd", "statusCode", resp.StatusCode)

// Read response body
body, err := io.ReadAll(resp.Body)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Failed to read FrontEnd response body")
return fmt.Errorf("failed to read FrontEnd response body: %w", err)
}
logger.V(logutil.DEFAULT).Info("Successfully read response body", "bodyLength", len(body), "body", string(body))

// Parse JSON response
var responseData map[string]interface{}
if err := json.Unmarshal(body, &responseData); err != nil {
logger.V(logutil.DEFAULT).Error(err, "Failed to unmarshal FrontEnd response")
return fmt.Errorf("failed to unmarshal FrontEnd response: %w", err)
}
logger.V(logutil.DEFAULT).Info("Successfully unmarshaled JSON response", "responseData", responseData)

// Extract worker_instance_id
if workerID, exists := responseData["worker_instance_id"]; exists {
logger.V(logutil.DEFAULT).Info("Found worker_instance_id key in response", "value", workerID)
if workerIDStr, ok := workerID.(string); ok && workerIDStr != "" {
reqCtx.WorkerInstanceID = workerIDStr
logger.V(logutil.DEFAULT).Info("Extracted worker instance ID from FrontEnd", "worker_instance_id", workerIDStr)
logger.V(logutil.DEFAULT).Info("Successfully completed fetchWorkerIDFromFrontEnd function")
return nil
}
logger.V(logutil.DEFAULT).Info("worker_instance_id exists but is not a valid non-empty string", "type", fmt.Sprintf("%T", workerID), "value", workerID)
} else {
logger.V(logutil.DEFAULT).Info("worker_instance_id key not found in response", "availableKeys", getMapKeys(responseData))
}

logger.V(logutil.DEFAULT).Info("worker_instance_id not found or empty in FrontEnd response")
return fmt.Errorf("worker_instance_id not found in FrontEnd response")
}

// getMapKeys returns a slice of keys from a map[string]interface{}
func getMapKeys(m map[string]interface{}) []string {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
return keys
}

func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, endpoint string, requestBodyLength int) {
headers := []*configPb.HeaderValueOption{
{
Expand All @@ -504,6 +618,25 @@ func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext,
},
},
}

// Add worker ID header if obtained from FrontEnd service
if reqCtx.WorkerInstanceID != "" {
headers = append(headers, &configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: "x-gateway-worker-id",
RawValue: []byte(reqCtx.WorkerInstanceID),
},
AppendAction: configPb.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD,
})
headers = append(headers, &configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: "host_rewrite_header",
RawValue: []byte("x-gateway-destination-endpoint"),
},
AppendAction: configPb.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD,
})
}

if requestBodyLength > 0 {
// We need to update the content length header if the body is mutated, see Envoy doc:
// https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto
Expand Down