Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions pkg/epp/handlers/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,24 @@ func (s *StreamingServer) generateHeaders(reqCtx *RequestContext) []*configPb.He
},
},
}

// Add worker ID header if obtained from FrontEnd service
if reqCtx.WorkerInstanceID != "" {
headers = append(headers, &configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: "x-gateway-worker-id",
RawValue: []byte(reqCtx.WorkerInstanceID),
},
})
}

// Add host rewrite header
headers = append(headers, &configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: "host_rewrite_header",
RawValue: []byte("x-gateway-destination-endpoint"),
},
})
if reqCtx.RequestSize > 0 {
// We need to update the content length header if the body is mutated, see Envoy doc:
// https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto
Expand Down
98 changes: 98 additions & 0 deletions pkg/epp/handlers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ limitations under the License.
package handlers

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"

Expand Down Expand Up @@ -50,9 +53,18 @@ func NewStreamingServer(destinationEndpointHintMetadataNamespace, destinationEnd
destinationEndpointHintKey: destinationEndpointHintKey,
director: director,
datastore: datastore,
frontEndAddress: "localhost", // Default FrontEnd address (same sidecar)
frontEndPort: "8081", // Default FrontEnd port
httpClient: &http.Client{Timeout: 30 * time.Second},
}
}

// SetFrontEndConfig overrides the FrontEnd service endpoint stored on the
// server, replacing the defaults assigned by NewStreamingServer.
//
// address is stored as the FrontEnd host and port as the FrontEnd port; both
// are kept as plain strings and are not validated here.
func (s *StreamingServer) SetFrontEndConfig(address, port string) {
	s.frontEndAddress, s.frontEndPort = address, port
}

type Director interface {
HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
HandleResponse(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
Expand All @@ -74,6 +86,11 @@ type StreamingServer struct {
destinationEndpointHintMetadataNamespace string
datastore Datastore
director Director

// FrontEnd service configuration for worker ID requests
frontEndAddress string
frontEndPort string
httpClient *http.Client
}

// RequestContext stores context information during the life time of an HTTP request.
Expand Down Expand Up @@ -109,6 +126,8 @@ type RequestContext struct {
respHeaderResp *extProcPb.ProcessingResponse
respBodyResp []*extProcPb.ProcessingResponse
respTrailerResp *extProcPb.ProcessingResponse

WorkerInstanceID string // Worker ID from FrontEnd service
}

type Request struct {
Expand Down Expand Up @@ -225,6 +244,12 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
break
}

// Make blocking HTTP request to FrontEnd service to get worker_instance_id
if err = s.fetchWorkerIDFromFrontEnd(ctx, reqCtx); err != nil {
logger.V(logutil.DEFAULT).Error(err, "Failed to fetch worker ID from FrontEnd service")
// TODO (atchernych) if FrontEnd call fails ?
}

// Populate the ExtProc protocol responses for the request body.
requestBodyBytes, err := json.Marshal(reqCtx.Request.Body)
if err != nil {
Expand Down Expand Up @@ -523,3 +548,76 @@ func buildCommonResponses(bodyBytes []byte, byteLimit int, setEos bool) []*extPr

return responses
}

// fetchWorkerIDFromFrontEnd makes a blocking HTTP POST to the FrontEnd service's
// /v1/chat/completions endpoint and, on success, stores the returned
// worker_instance_id in reqCtx.WorkerInstanceID.
//
// The request body carries only the nvext "query_instance_id" annotation. The
// call is bound to ctx, so cancelling ctx aborts the request.
//
// It returns a non-nil error when the request cannot be built or sent, when the
// FrontEnd answers with a status other than 200, when the response is not valid
// JSON, or when the response has no non-empty string "worker_instance_id"
// field. In every error case reqCtx.WorkerInstanceID is left untouched.
//
// Errors are wrapped and returned without being logged here: the caller is
// expected to log them once, which avoids duplicate log entries per failure.
func (s *StreamingServer) fetchWorkerIDFromFrontEnd(ctx context.Context, reqCtx *RequestContext) error {
	logger := log.FromContext(ctx)

	// Honor the configured FrontEnd address instead of always dialing localhost.
	// Fall back to localhost when the server was built without an address so the
	// previous default behavior is preserved.
	host := s.frontEndAddress
	if host == "" {
		host = "localhost"
	}
	frontEndURL := fmt.Sprintf("http://%s:%s/v1/chat/completions", host, s.frontEndPort)

	// Guard against a StreamingServer constructed without NewStreamingServer;
	// calling Do on a nil *http.Client would panic. The request is still bound
	// to ctx, so cancellation keeps working with the default client.
	client := s.httpClient
	if client == nil {
		client = http.DefaultClient
	}

	// Create request body with nvext annotations.
	requestBody := map[string]interface{}{
		"nvext": map[string]interface{}{
			"annotations": []string{"query_instance_id"},
		},
	}

	requestJSON, err := json.Marshal(requestBody)
	if err != nil {
		return fmt.Errorf("failed to marshal request body: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, frontEndURL, bytes.NewReader(requestJSON))
	if err != nil {
		return fmt.Errorf("failed to create FrontEnd request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	logger.V(logutil.VERBOSE).Info("Making blocking HTTP request to FrontEnd",
		"url", frontEndURL,
		"body", string(requestJSON))
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("failed to call FrontEnd service: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("FrontEnd service returned status %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("failed to read FrontEnd response body: %w", err)
	}

	// Parse JSON response to extract worker_instance_id.
	var responseData map[string]interface{}
	if err := json.Unmarshal(body, &responseData); err != nil {
		return fmt.Errorf("failed to parse FrontEnd response JSON: %w", err)
	}

	// A missing key and a non-string value both yield "", which is rejected below.
	workerID, _ := responseData["worker_instance_id"].(string)
	if workerID == "" {
		// Logged here because the returned error does not carry the response body.
		logger.V(logutil.DEFAULT).Info("FrontEnd response does not contain valid worker_instance_id", "response_body", string(body))
		return fmt.Errorf("FrontEnd response does not contain valid worker_instance_id")
	}

	reqCtx.WorkerInstanceID = workerID
	logger.V(logutil.VERBOSE).Info("Successfully obtained worker ID from FrontEnd", "worker_instance_id", workerID)
	return nil
}