diff --git a/cmd/epp/main.go b/cmd/epp/main.go index b5e06177bc..8592735d2c 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -22,6 +22,11 @@ import ( ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/gateway-api-inference-extension/cmd/epp/runner" + eppplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + + // Dynamo plugins + dynprereq "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid" + dynscorer "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/dynamo_kv_scorer" ) func main() { @@ -30,6 +35,9 @@ func main() { // For adding out-of-tree plugins to the plugins registry, use the following: // plugins.Register(my-out-of-tree-plugin-name, my-out-of-tree-plugin-factory-function) + eppplugins.Register("dynamo-inject-workerid", dynprereq.InjectWorkerIDPreRequestFactory) + eppplugins.Register("kv-aware-scorer", dynscorer.KVAwareScorerFactory) + if err := runner.NewRunner().Run(ctrl.SetupSignalHandler()); err != nil { os.Exit(1) } diff --git a/pkg/bbr/handlers/request.go b/pkg/bbr/handlers/request.go index 32fffc0217..1aa1b85268 100644 --- a/pkg/bbr/handlers/request.go +++ b/pkg/bbr/handlers/request.go @@ -18,8 +18,10 @@ package handlers import ( "context" + "encoding/base64" "encoding/json" "fmt" + "strings" basepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" eppb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" @@ -31,11 +33,49 @@ import ( const modelHeader = "X-Gateway-Model-Name" +// Dynamo-related +const ( + workerIDHeader = "x-worker-instance-id" + injectHintHeader = "x-epp-inject-nvext-worker-instance-id" + tokenDataHeader = "x-epp-inject-nvext-token-data" +) + // HandleRequestBody handles request bodies. func (s *Server) HandleRequestBody(ctx context.Context, data map[string]any) ([]*eppb.ProcessingResponse, error) { logger := log.FromContext(ctx) var ret []*eppb.ProcessingResponse + // If we captured a worker id hint in the headers phase, inject it into body JSON: + // nvext.backend_instance_id = + if wid := strings.TrimSpace(s.workerIDHint); wid != "" { + // ensure nvext is a map[string]any + if nv, ok := data["nvext"]; !ok || nv == nil { + data["nvext"] = map[string]any{"backend_instance_id": wid} + } else if m, ok := nv.(map[string]any); ok { + m["backend_instance_id"] = wid + } else { + // if nvext was some other type, replace with a clean map + data["nvext"] = map[string]any{"backend_instance_id": wid} + } + } + + // If we captured token_data in headers, decode and inject as nvext.token_data + if td := strings.TrimSpace(s.tokenDataHint); td != "" { + // header value is base64(JSON array) + if raw, err := base64.StdEncoding.DecodeString(td); err == nil { + var arr []int64 + if err := json.Unmarshal(raw, &arr); err == nil && len(arr) > 0 { + // ensure nvext map exists + nv, ok := data["nvext"].(map[string]any) + if !ok || nv == nil { + nv = map[string]any{} + data["nvext"] = nv + } + nv["token_data"] = arr + } + } + } + requestBodyBytes, err := json.Marshal(data) if err != nil { return nil, err @@ -46,6 +86,7 @@ func (s *Server) HandleRequestBody(ctx context.Context, data map[string]any) ([] metrics.RecordModelNotInBodyCounter() logger.V(logutil.DEFAULT).Info("Request body does not contain model parameter") if s.streaming { + // still stream the possibly mutated body ret = append(ret, &eppb.ProcessingResponse{ Response: &eppb.ProcessingResponse_RequestHeaders{ RequestHeaders: &eppb.HeadersResponse{}, @@ -53,14 +94,24 @@ func (s *Server) HandleRequestBody(ctx context.Context, data map[string]any) ([] }) ret = addStreamedBodyResponse(ret, requestBodyBytes) return ret, nil - } else { - ret = append(ret, &eppb.ProcessingResponse{ + } + + // non-streaming: return a body response with the (possibly) mutated body + return []*eppb.ProcessingResponse{ + { Response: &eppb.ProcessingResponse_RequestBody{ - RequestBody: &eppb.BodyResponse{}, + RequestBody: &eppb.BodyResponse{ + Response: &eppb.CommonResponse{ + BodyMutation: &eppb.BodyMutation{ + Mutation: &eppb.BodyMutation_Body{ + Body: requestBodyBytes, + }, + }, + }, + }, }, - }) - } - return ret, nil + }, + }, nil } modelStr, ok := modelVal.(string) @@ -73,6 +124,7 @@ func (s *Server) HandleRequestBody(ctx context.Context, data map[string]any) ([] metrics.RecordSuccessCounter() if s.streaming { + // set the model header, then stream the (possibly) mutated body ret = append(ret, &eppb.ProcessingResponse{ Response: &eppb.ProcessingResponse_RequestHeaders{ RequestHeaders: &eppb.HeadersResponse{ @@ -86,16 +138,42 @@ func (s *Server) HandleRequestBody(ctx context.Context, data map[string]any) ([] RawValue: []byte(modelStr), }, }, + // also keep the worker id header if we have one + func() *basepb.HeaderValueOption { + if strings.TrimSpace(s.workerIDHint) == "" { + return nil + } + return &basepb.HeaderValueOption{ + Header: &basepb.HeaderValue{ + Key: workerIDHeader, + RawValue: []byte(s.workerIDHint), + }, + } + }(), }, }, }, }, }, }) + + // prune nil entries if worker id not present + hm := ret[len(ret)-1].GetRequestHeaders().GetResponse().GetHeaderMutation() + if hm != nil && hm.SetHeaders != nil { + out := hm.SetHeaders[:0] + for _, h := range hm.SetHeaders { + if h != nil { + out = append(out, h) + } + } + hm.SetHeaders = out + } + ret = addStreamedBodyResponse(ret, requestBodyBytes) return ret, nil } + // Non-streaming: set model header and replace the body with our mutated JSON return []*eppb.ProcessingResponse{ { Response: &eppb.ProcessingResponse_RequestBody{ @@ -111,6 +189,22 @@ func (s *Server) HandleRequestBody(ctx context.Context, data map[string]any) ([] RawValue: []byte(modelStr), }, }, + func() *basepb.HeaderValueOption { + if strings.TrimSpace(s.workerIDHint) == "" { + return nil + } + return &basepb.HeaderValueOption{ + Header: &basepb.HeaderValue{ + Key: workerIDHeader, + RawValue: []byte(s.workerIDHint), + }, + } + }(), + }, + }, + BodyMutation: &eppb.BodyMutation{ + Mutation: &eppb.BodyMutation_Body{ + Body: requestBodyBytes, }, }, }, @@ -141,6 +235,32 @@ func addStreamedBodyResponse(responses []*eppb.ProcessingResponse, requestBodyBy // HandleRequestHeaders handles request headers. func (s *Server) HandleRequestHeaders(headers *eppb.HttpHeaders) ([]*eppb.ProcessingResponse, error) { + // reset per-request + s.workerIDHint = "" + s.tokenDataHint = "" + + if m := headers.GetHeaders(); m != nil { + for _, h := range m.GetHeaders() { + k := strings.ToLower(h.GetKey()) + + switch k { + case injectHintHeader, workerIDHeader: + if rv := h.GetRawValue(); len(rv) > 0 { + s.workerIDHint = strings.TrimSpace(string(rv)) + } else { + s.workerIDHint = strings.TrimSpace(h.GetValue()) + } + case tokenDataHeader: + if rv := h.GetRawValue(); len(rv) > 0 { + s.tokenDataHint = strings.TrimSpace(string(rv)) + } else { + s.tokenDataHint = strings.TrimSpace(h.GetValue()) + } + } + } + } + + // No header mutations needed here; body phase will do the JSON injection. return []*eppb.ProcessingResponse{ { Response: &eppb.ProcessingResponse_RequestHeaders{ diff --git a/pkg/bbr/handlers/server.go b/pkg/bbr/handlers/server.go index a5803806bc..eb2893fdc6 100644 --- a/pkg/bbr/handlers/server.go +++ b/pkg/bbr/handlers/server.go @@ -38,7 +38,9 @@ func NewServer(streaming bool) *Server { // Server implements the Envoy external processing server. // https://www.envoyproxy.io/docs/envoy/latest/api-v3/service/ext_proc/v3/external_processor.proto type Server struct { - streaming bool + streaming bool + workerIDHint string + tokenDataHint string } func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { diff --git a/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go b/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go new file mode 100644 index 0000000000..b6708fa4d4 --- /dev/null +++ b/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go @@ -0,0 +1,69 @@ +package dynamo_inject_workerid + +import ( + "context" + "encoding/json" + "strings" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + rc "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" + schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +const ( + typeString = "dynamo-inject-workerid" + pluginName = "dynamo-inject-workerid" + WorkerIDHeader = "x-worker-instance-id" + injectHintHeader = "x-epp-inject-nvext-worker-instance-id" + TokenDataHeader = "x-epp-inject-nvext-token-data" +) + +var _ plugins.Plugin = (*InjectWorkerIDPreRequest)(nil) +var _ rc.PreRequest = (*InjectWorkerIDPreRequest)(nil) + +type InjectWorkerIDPreRequest struct { + typedName plugins.TypedName +} + +func NewInjectWorkerIDPreRequest() *InjectWorkerIDPreRequest { + return &InjectWorkerIDPreRequest{ + typedName: plugins.TypedName{Type: typeString, Name: pluginName}, + } +} + +func (p *InjectWorkerIDPreRequest) WithName(name string) *InjectWorkerIDPreRequest { + p.typedName.Name = name + return p +} + +func InjectWorkerIDPreRequestFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { + return NewInjectWorkerIDPreRequest().WithName(name), nil +} + +func (p *InjectWorkerIDPreRequest) TypedName() plugins.TypedName { return p.typedName } + +func (p *InjectWorkerIDPreRequest) PreRequest( + _ context.Context, + req *schedtypes.LLMRequest, + _ *schedtypes.SchedulingResult, + _ int, +) { + if req == nil { + return + } + if req.Headers == nil { + req.Headers = map[string]string{} + } + wid := strings.TrimSpace(req.Headers[WorkerIDHeader]) + if wid == "" { + return + } + req.Headers[WorkerIDHeader] = wid + req.Headers[injectHintHeader] = wid + + // Pass through token-data header if scorer set it + if td := strings.TrimSpace(req.Headers[TokenDataHeader]); td != "" { + req.Headers[TokenDataHeader] = td + } + +} diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml new file mode 100644 index 0000000000..2d92be03b7 --- /dev/null +++ b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml @@ -0,0 +1,24 @@ +# This is an example for configuring the EPP to use the dynamo token-aware kv router for scoring the pods +apiVersion: inference.networking.x-k8s.io/v1alpha1 +kind: EndpointPickerConfig +plugins: + # Required: tells EPP which profile to use (even if you only have one) + - type: single-profile-handler + + # Picker: chooses the final endpoint after scoring + - name: picker + type: max-score-picker + - name: dyn-pre + type: dynamo-inject-workerid + parameters: {} + - name: dyn-kv + type: kv-aware-scorer + parameters: + frontendURL: http://127.0.0.1:8000/v1/chat/completions + timeoutMS: 10000 +schedulingProfiles: + - name: default + plugins: + - pluginRef: dyn-kv + weight: 1 + - pluginRef: picker diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go new file mode 100644 index 0000000000..50eb5f6907 --- /dev/null +++ b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go @@ -0,0 +1,431 @@ +package dynamo_kv_scorer + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + log "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" + schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +const ( + PluginName = "dynamo-kv-scorer" + KVAwareScorerType = "kv-aware-scorer" + StateKeyWorkerInstanceID = schedtypes.StateKey("dynamo/worker-instance-id") + WorkerIDHeader = "x-worker-instance-id" + TokenDataHeader = "x-epp-inject-nvext-token-data" +) + +type params struct { + FrontendURL string `json:"frontendURL"` + TimeoutMS int `json:"timeoutMS"` +} + +// tiny wrapper so we can store a string in CycleState +type stateString string + +func (s stateString) Clone() schedtypes.StateData { return s } + +type KVAwareScorer struct { + typedName plugins.TypedName + feURL string + feTimeout time.Duration +} + +// compile-time assertions +var _ plugins.Plugin = (*KVAwareScorer)(nil) +var _ framework.Scorer = (*KVAwareScorer)(nil) + +func NewKVAwareScorer() *KVAwareScorer { + return &KVAwareScorer{ + typedName: plugins.TypedName{Type: KVAwareScorerType, Name: PluginName}, + feURL: "http://127.0.0.1:8000/v1/chat/completions", + feTimeout: 10 * time.Second, + } +} + +func (k *KVAwareScorer) WithName(name string) *KVAwareScorer { k.typedName.Name = name; return k } +func (k *KVAwareScorer) WithFrontend(url string, timeout time.Duration) *KVAwareScorer { + if url != "" { + k.feURL = url + } + if timeout > 0 { + k.feTimeout = timeout + } + return k +} + +func KVAwareScorerFactory(name string, raw json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { + p := params{} + _ = json.Unmarshal(raw, &p) + timeout := time.Duration(p.TimeoutMS) * time.Millisecond + if timeout <= 0 { + timeout = 10 * time.Second + } + return NewKVAwareScorer().WithName(name).WithFrontend(p.FrontendURL, timeout), nil +} + +func (k *KVAwareScorer) TypedName() plugins.TypedName { return k.typedName } + +func (k *KVAwareScorer) Score( + ctx context.Context, + cycle *schedtypes.CycleState, + req *schedtypes.LLMRequest, + pods []schedtypes.Pod, +) map[schedtypes.Pod]float64 { + logger := log.FromContext(ctx) + + workerID, tokenData, err := k.callFrontEndForWorker(ctx, req) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "FrontEnd call failed; proceeding without worker id") + } else if workerID != "" { + cycle.Write(StateKeyWorkerInstanceID, stateString(workerID)) + if req.Headers == nil { + req.Headers = map[string]string{} + } + req.Headers[WorkerIDHeader] = workerID + if len(tokenData) > 0 { + if req.Headers == nil { + req.Headers = map[string]string{} + } + req.Headers[TokenDataHeader] = encodeTokenData(tokenData) + } + } + + // neutral/uniform scores – only your scorer runs in the profile, so this “wins” + out := make(map[schedtypes.Pod]float64, len(pods)) + for _, p := range pods { + out[p] = 1.0 + } + return out +} + +// Call the Dynamo FrontEnd and extract worker_instance_id via SSE. +func (k *KVAwareScorer) callFrontEndForWorker( + ctx context.Context, + req *schedtypes.LLMRequest, +) (string, []int64, error) { + logger := log.FromContext(ctx) + + feBody := buildFrontEndBodyFromLLMRequest(req) + payload, err := json.Marshal(feBody) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Dynamo FrontEnd marshal failed") + return "", nil, fmt.Errorf("marshal FrontEnd body: %w", err) + } + + reqCtx, cancel := context.WithTimeout(ctx, k.feTimeout) + defer cancel() + + httpReq, err := http.NewRequestWithContext(reqCtx, http.MethodPost, k.feURL, bytes.NewReader(payload)) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Dynamo FrontEnd request build failed") + return "", nil, fmt.Errorf("build FrontEnd request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "text/event-stream") + + client := &http.Client{Timeout: 0} + resp, err := client.Do(httpReq) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Dynamo FrontEnd POST failed") + return "", nil, fmt.Errorf("FrontEnd POST failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + errBody, _ := io.ReadAll(resp.Body) + logger.V(logutil.DEFAULT).Error(nil, "Dynamo FrontEnd non-2xx response", + "status_code", resp.StatusCode, "response_body", string(errBody)) + return "", nil, fmt.Errorf("Dynamo FrontEnd error: %d body=%s", resp.StatusCode, string(errBody)) + } + + ct := strings.ToLower(resp.Header.Get("Content-Type")) + if !strings.Contains(ct, "text/event-stream") { + logger.V(logutil.DEFAULT).Error(nil, "Unexpected non-SSE response") + return "", nil, fmt.Errorf("unexpected non-SSE response (Content-Type=%q)", resp.Header.Get("Content-Type")) + } + + // Parse SSE: expect `event: worker_instance_id`, a quoted id in a comment or data, and `data: [DONE]` + reader := bufio.NewReader(resp.Body) + workerID, tokenData, perr := parseSelectionFromSSE(ctx, reader) + if perr != nil { + return "", nil, perr + } + return workerID, tokenData, nil +} + +// Build the exact body we send to the FrontEnd, only from LLMRequest (no header merging). +func buildFrontEndBodyFromLLMRequest(req *schedtypes.LLMRequest) map[string]any { + feBody := make(map[string]any, 8) + + // We call /v1/chat/completions so must provide messages + userText := "" + if req != nil && strings.TrimSpace(req.Prompt) != "" { + userText = req.Prompt + } + feBody["messages"] = []map[string]any{ + {"role": "user", "content": userText}, + } + + if req != nil && strings.TrimSpace(req.TargetModel) != "" { + feBody["model"] = req.TargetModel + } + + // Force SSE so we can parse worker_instance_id + feBody["stream"] = true + + feBody["max_tokens"] = 1 + feBody["temperature"] = 0.0 + + // Ask the Dynamo to include worker id + feBody["nvext"] = map[string]any{ + "annotations": []string{"query_instance_id"}, + } + + return feBody +} + +// This function scans an SSE stream for a worker_instance_id and token_data. +// Expected pattern: +// +// event: worker_instance_id +// : "8303679623149182543" +// data: [DONE] + +// or with tokens: +// event: worker_instance_id\n: \"8228244551594056720\"\n\n +// event: token_data\n: \"[151644,872,198,151644,872,198,14990,151645,198,151645,198,151644,77091,198]\ +// "\n\ndata: [DONE]\n\n" +// Also supports JSON in data lines with either top-level worker_instance_id +// or annotations.worker_instance_id. +func parseSelectionFromSSE(ctx context.Context, reader *bufio.Reader) (string, []int64, error) { + logger := log.FromContext(ctx) + + var ( + eventName string + dataBuf strings.Builder // accumulates "data:" lines for one event + commentBuf strings.Builder // accumulates ":" comment lines + gotWID string + gotTD []int64 + ) + + // collect the exact SSE bytes for debugging + var rawBuf strings.Builder + + flushEvent := func() (bool, error) { + data := strings.TrimSpace(dataBuf.String()) + comment := strings.TrimSpace(commentBuf.String()) + dataBuf.Reset() + commentBuf.Reset() + + // [DONE] ends the stream + if data == "[DONE]" || comment == "[DONE]" { + logger.V(logutil.DEFAULT).Info("SSE stream DONE") + logger.V(logutil.DEFAULT).Info("SSE raw stream", "raw", rawBuf.String()) + if gotWID != "" && len(gotTD) == 0 { + logger.V(logutil.DEFAULT).Info("SSE DONE: worker_instance_id present, token_data missing") + } + return true, nil + } + + // Prefer the named event + if eventName == "worker_instance_id" { + candidate := data + if candidate == "" { + candidate = comment + } + if candidate != "" { + // Try JSON string + var s string + if json.Unmarshal([]byte(candidate), &s) == nil && s != "" { + logger.V(logutil.VERBOSE).Info("worker_instance_id extracted from named event", "worker_instance_id", s) + gotWID = s + return false, nil + } + // Fallback: strip quotes + clean := strings.Trim(candidate, "\"") + if clean != "" && clean != "[DONE]" { + logger.V(logutil.DEFAULT).Info("worker_instance_id extracted (raw) from named event", "worker_instance_id", clean) + gotWID = clean + return false, nil + } + } + } + + if eventName == "token_data" { + candidate := data + if candidate == "" { + candidate = comment + } + if candidate != "" { + if arr := toInt64SliceJSON(candidate); len(arr) > 0 { + gotTD = arr + logger.V(logutil.DEFAULT).Info("token_data extracted from named event", "count", len(arr)) + return false, nil + } + } + } + // Generic JSON in data: + if data != "" { + var msg map[string]any + if json.Unmarshal([]byte(data), &msg) == nil { + if wid, ok := msg["worker_instance_id"].(string); ok && wid != "" { + logger.V(logutil.DEFAULT).Info("worker_instance_id found in SSE payload root", "worker_instance_id", wid) + gotWID = wid + } + if ann, ok := msg["annotations"].(map[string]any); ok { + if wid, ok := ann["worker_instance_id"].(string); ok && wid != "" { + logger.V(logutil.DEFAULT).Info("worker_instance_id found in SSE annotations", "worker_instance_id", wid) + gotWID = wid + } + } + if td, ok := msg["token_data"]; ok { + if arr := toInt64Slice(td); len(arr) > 0 { + gotTD = arr + logger.V(logutil.DEFAULT).Info("token_data found in SSE payload root", "count", len(arr)) + } + } else if nv, ok := msg["nvext"].(map[string]any); ok { + if td, ok := nv["token_data"]; ok { + if arr := toInt64Slice(td); len(arr) > 0 { + gotTD = arr + logger.V(logutil.DEFAULT).Info("token_data found in SSE nvext", "count", len(arr)) + } + } + } + } + } + return false, nil + } + + for { + line, err := reader.ReadString('\n') + // capture the raw stream as-is for debugging + rawBuf.WriteString(line) + if err != nil { + if err == io.EOF { + _, _ = flushEvent() + logger.V(logutil.DEFAULT).Info("SSE raw stream (EOF)", "raw", rawBuf.String()) + if gotWID != "" && len(gotTD) == 0 { + logger.V(logutil.DEFAULT).Info("EOF: worker_instance_id present, token_data missing") + } + if gotWID != "" || len(gotTD) > 0 { + return gotWID, gotTD, nil + } + logger.V(logutil.DEFAULT).Error(nil, "EOF before selection fields present") + return "", nil, fmt.Errorf("selection not found in SSE stream (EOF)") + } + logger.V(logutil.DEFAULT).Error(err, "SSE read error") + return "", nil, fmt.Errorf("sse read error: %w", err) + } + + l := strings.TrimRight(line, "\r\n") + if l == "" { + // End of current event. + if done, _ := flushEvent(); done { + if gotWID != "" && len(gotTD) == 0 { + logger.V(logutil.DEFAULT).Info("SSE DONE: worker_instance_id present, token_data missing") + } + return gotWID, gotTD, nil + } + eventName = "" // reset for next event + continue + } + + // Comment line + if strings.HasPrefix(l, ":") { + commentLine := strings.TrimSpace(l[1:]) + if commentBuf.Len() > 0 { + commentBuf.WriteByte('\n') + } + commentBuf.WriteString(commentLine) + continue + } + + // "field: value" + if idx := strings.IndexByte(l, ':'); idx != -1 { + field := l[:idx] + val := strings.TrimSpace(l[idx+1:]) + switch field { + case "event": + eventName = val + case "data": + if dataBuf.Len() > 0 { + dataBuf.WriteByte('\n') + } + dataBuf.WriteString(val) + default: + // ignore id, retry, etc. + } + } + } +} + +// encodeTokenData turns []int64 into base64(JSON array) for a safe header value. +func encodeTokenData(tokens []int64) string { + b, _ := json.Marshal(tokens) + return base64.StdEncoding.EncodeToString(b) +} + +// Accepts interface{} from a parsed JSON map +func toInt64Slice(v any) []int64 { + xs, ok := v.([]any) + if !ok { + return nil + } + out := make([]int64, 0, len(xs)) + for _, it := range xs { + switch n := it.(type) { + case float64: + out = append(out, int64(n)) + case int64: + out = append(out, n) + case json.Number: + if i, err := n.Int64(); err == nil { + out = append(out, i) + } + } + } + return out +} + +// Accepts raw JSON (string) for events like: +// event: worker_instance_id\n: \"8228244551594056720\"\n\n +// event: token_data\n: \"[151644,872,198,151644,872,198,14990,151645,198,151645,198,151644,77091,198]\ +// "\n\ndata: [DONE]\n\n" +// replaces the old toInt64SliceJSON +func toInt64SliceJSON(s string) []int64 { + // case 1: direct JSON array + var arr []int64 + if err := json.Unmarshal([]byte(s), &arr); err == nil && len(arr) > 0 { + return arr + } + // case 2: s is a JSON string that itself contains a JSON array + var inner string + if err := json.Unmarshal([]byte(s), &inner); err == nil && inner != "" { + var arr2 []int64 + if err := json.Unmarshal([]byte(inner), &arr2); err == nil && len(arr2) > 0 { + return arr2 + } + } + // case 3: strip quotes and try once more + unquoted := strings.Trim(s, "\"") + if unquoted != s { + var arr3 []int64 + if err := json.Unmarshal([]byte(unquoted), &arr3); err == nil && len(arr3) > 0 { + return arr3 + } + } + return nil +}