diff --git a/Makefile b/Makefile index dee7e99e0e..d3f9ec74ac 100644 --- a/Makefile +++ b/Makefile @@ -170,6 +170,48 @@ verify-all: ##@ Build +##@ Dynamo EPP with FFI + +# Build the Dynamo EPP image with CGO static library support +.PHONY: dynamo-image-local-build +dynamo-image-local-build: ## Build the Dynamo EPP image using Docker Buildx for local development. + BUILDER=$(shell $(DOCKER_BUILDX_CMD) create --use) + $(MAKE) dynamo-image-build PUSH=$(PUSH) + $(MAKE) dynamo-image-build LOAD=$(LOAD) + $(DOCKER_BUILDX_CMD) rm $$BUILDER + +.PHONY: dynamo-image-local-push +dynamo-image-local-push: PUSH=--push ## Build the Dynamo EPP image for local development and push it to $IMAGE_REPO. +dynamo-image-local-push: dynamo-image-local-build + +.PHONY: dynamo-image-local-load +dynamo-image-local-load: LOAD=--load ## Build the Dynamo EPP image for local development and load it in the local Docker registry. +dynamo-image-local-load: dynamo-image-local-build + +.PHONY: dynamo-image-build +dynamo-image-build: ## Build the Dynamo EPP image using Docker Buildx with CGO support. + $(IMAGE_BUILD_CMD) -f Dockerfile.dynamo -t $(IMAGE_TAG) \ + --platform=$(PLATFORMS) \ + --build-arg BASE_IMAGE=ubuntu:24.04 \ + --build-arg BUILDER_IMAGE=$(BUILDER_IMAGE) \ + --build-arg COMMIT_SHA=${GIT_COMMIT_SHA} \ + --build-arg BUILD_REF=${BUILD_REF} \ + $(PUSH) \ + $(LOAD) \ + $(IMAGE_BUILD_EXTRA_OPTS) ./ + +.PHONY: dynamo-image-push +dynamo-image-push: PUSH=--push ## Build the Dynamo EPP image and push it to $IMAGE_REPO. +dynamo-image-push: dynamo-image-build + +.PHONY: dynamo-image-load +dynamo-image-load: LOAD=--load ## Build the Dynamo EPP image and load it in the local Docker registry. +dynamo-image-load: dynamo-image-build + +.PHONY: dynamo-image-kind +dynamo-image-kind: dynamo-image-build ## Build the Dynamo EPP image and load it to kind cluster $KIND_CLUSTER ("kind" by default). + kind load docker-image $(IMAGE_TAG) --name $(KIND_CLUSTER) + # Build the container image .PHONY: image-local-build image-local-build: ## Build the EPP image using Docker Buildx for local development. diff --git a/cmd/epp/main.go b/cmd/epp/main.go index b5e06177bc..8592735d2c 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -22,6 +22,11 @@ import ( ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/gateway-api-inference-extension/cmd/epp/runner" + eppplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + + // Dynamo plugins + dynprereq "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid" + dynscorer "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/dynamo_kv_scorer" ) func main() { @@ -30,6 +35,9 @@ func main() { // For adding out-of-tree plugins to the plugins registry, use the following: // plugins.Register(my-out-of-tree-plugin-name, my-out-of-tree-plugin-factory-function) + eppplugins.Register("dynamo-inject-workerid", dynprereq.InjectWorkerIDPreRequestFactory) + eppplugins.Register("kv-aware-scorer", dynscorer.KVAwareScorerFactory) + if err := runner.NewRunner().Run(ctrl.SetupSignalHandler()); err != nil { os.Exit(1) } diff --git a/pkg/bbr/handlers/request.go b/pkg/bbr/handlers/request.go index 32fffc0217..1aa1b85268 100644 --- a/pkg/bbr/handlers/request.go +++ b/pkg/bbr/handlers/request.go @@ -18,8 +18,10 @@ package handlers import ( "context" + "encoding/base64" "encoding/json" "fmt" + "strings" basepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" eppb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" @@ -31,11 +33,49 @@ import ( const modelHeader = "X-Gateway-Model-Name" +// Dynamo-related +const ( + workerIDHeader = "x-worker-instance-id" + injectHintHeader = "x-epp-inject-nvext-worker-instance-id" + tokenDataHeader = "x-epp-inject-nvext-token-data" +) + // HandleRequestBody handles request bodies. func (s *Server) HandleRequestBody(ctx context.Context, data map[string]any) ([]*eppb.ProcessingResponse, error) { logger := log.FromContext(ctx) var ret []*eppb.ProcessingResponse + // If we captured a worker id hint in the headers phase, inject it into body JSON: + // nvext.backend_instance_id = + if wid := strings.TrimSpace(s.workerIDHint); wid != "" { + // ensure nvext is a map[string]any + if nv, ok := data["nvext"]; !ok || nv == nil { + data["nvext"] = map[string]any{"backend_instance_id": wid} + } else if m, ok := nv.(map[string]any); ok { + m["backend_instance_id"] = wid + } else { + // if nvext was some other type, replace with a clean map + data["nvext"] = map[string]any{"backend_instance_id": wid} + } + } + + // If we captured token_data in headers, decode and inject as nvext.token_data + if td := strings.TrimSpace(s.tokenDataHint); td != "" { + // header value is base64(JSON array) + if raw, err := base64.StdEncoding.DecodeString(td); err == nil { + var arr []int64 + if err := json.Unmarshal(raw, &arr); err == nil && len(arr) > 0 { + // ensure nvext map exists + nv, ok := data["nvext"].(map[string]any) + if !ok || nv == nil { + nv = map[string]any{} + data["nvext"] = nv + } + nv["token_data"] = arr + } + } + } + requestBodyBytes, err := json.Marshal(data) if err != nil { return nil, err @@ -46,6 +86,7 @@ func (s *Server) HandleRequestBody(ctx context.Context, data map[string]any) ([] metrics.RecordModelNotInBodyCounter() logger.V(logutil.DEFAULT).Info("Request body does not contain model parameter") if s.streaming { + // still stream the possibly mutated body ret = append(ret, &eppb.ProcessingResponse{ Response: &eppb.ProcessingResponse_RequestHeaders{ RequestHeaders: &eppb.HeadersResponse{}, @@ -53,14 +94,24 @@ func (s *Server) HandleRequestBody(ctx context.Context, data map[string]any) ([] }) ret = addStreamedBodyResponse(ret, requestBodyBytes) return ret, nil - } else { - ret = append(ret, &eppb.ProcessingResponse{ + } + + // non-streaming: return a body response with the (possibly) mutated body + return []*eppb.ProcessingResponse{ + { Response: &eppb.ProcessingResponse_RequestBody{ - RequestBody: &eppb.BodyResponse{}, + RequestBody: &eppb.BodyResponse{ + Response: &eppb.CommonResponse{ + BodyMutation: &eppb.BodyMutation{ + Mutation: &eppb.BodyMutation_Body{ + Body: requestBodyBytes, + }, + }, + }, + }, }, - }) - } - return ret, nil + }, + }, nil } modelStr, ok := modelVal.(string) @@ -73,6 +124,7 @@ func (s *Server) HandleRequestBody(ctx context.Context, data map[string]any) ([] metrics.RecordSuccessCounter() if s.streaming { + // set the model header, then stream the (possibly) mutated body ret = append(ret, &eppb.ProcessingResponse{ Response: &eppb.ProcessingResponse_RequestHeaders{ RequestHeaders: &eppb.HeadersResponse{ @@ -86,16 +138,42 @@ func (s *Server) HandleRequestBody(ctx context.Context, data map[string]any) ([] RawValue: []byte(modelStr), }, }, + // also keep the worker id header if we have one + func() *basepb.HeaderValueOption { + if strings.TrimSpace(s.workerIDHint) == "" { + return nil + } + return &basepb.HeaderValueOption{ + Header: &basepb.HeaderValue{ + Key: workerIDHeader, + RawValue: []byte(s.workerIDHint), + }, + } + }(), }, }, }, }, }, }) + + // prune nil entries if worker id not present + hm := ret[len(ret)-1].GetRequestHeaders().GetResponse().GetHeaderMutation() + if hm != nil && hm.SetHeaders != nil { + out := hm.SetHeaders[:0] + for _, h := range hm.SetHeaders { + if h != nil { + out = append(out, h) + } + } + hm.SetHeaders = out + } + ret = addStreamedBodyResponse(ret, requestBodyBytes) return ret, nil } + // Non-streaming: set model header and replace the body with our mutated JSON return []*eppb.ProcessingResponse{ { Response: &eppb.ProcessingResponse_RequestBody{ @@ -111,6 +189,22 @@ func (s *Server) HandleRequestBody(ctx context.Context, data map[string]any) ([] RawValue: []byte(modelStr), }, }, + func() *basepb.HeaderValueOption { + if strings.TrimSpace(s.workerIDHint) == "" { + return nil + } + return &basepb.HeaderValueOption{ + Header: &basepb.HeaderValue{ + Key: workerIDHeader, + RawValue: []byte(s.workerIDHint), + }, + } + }(), + }, + }, + BodyMutation: &eppb.BodyMutation{ + Mutation: &eppb.BodyMutation_Body{ + Body: requestBodyBytes, }, }, }, @@ -141,6 +235,32 @@ func addStreamedBodyResponse(responses []*eppb.ProcessingResponse, requestBodyBy // HandleRequestHeaders handles request headers. func (s *Server) HandleRequestHeaders(headers *eppb.HttpHeaders) ([]*eppb.ProcessingResponse, error) { + // reset per-request + s.workerIDHint = "" + s.tokenDataHint = "" + + if m := headers.GetHeaders(); m != nil { + for _, h := range m.GetHeaders() { + k := strings.ToLower(h.GetKey()) + + switch k { + case injectHintHeader, workerIDHeader: + if rv := h.GetRawValue(); len(rv) > 0 { + s.workerIDHint = strings.TrimSpace(string(rv)) + } else { + s.workerIDHint = strings.TrimSpace(h.GetValue()) + } + case tokenDataHeader: + if rv := h.GetRawValue(); len(rv) > 0 { + s.tokenDataHint = strings.TrimSpace(string(rv)) + } else { + s.tokenDataHint = strings.TrimSpace(h.GetValue()) + } + } + } + } + + // No header mutations needed here; body phase will do the JSON injection. return []*eppb.ProcessingResponse{ { Response: &eppb.ProcessingResponse_RequestHeaders{ diff --git a/pkg/bbr/handlers/server.go b/pkg/bbr/handlers/server.go index a5803806bc..eb2893fdc6 100644 --- a/pkg/bbr/handlers/server.go +++ b/pkg/bbr/handlers/server.go @@ -38,7 +38,9 @@ func NewServer(streaming bool) *Server { // Server implements the Envoy external processing server. // https://www.envoyproxy.io/docs/envoy/latest/api-v3/service/ext_proc/v3/external_processor.proto type Server struct { - streaming bool + streaming bool + workerIDHint string + tokenDataHint string } func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { diff --git a/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go b/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go new file mode 100644 index 0000000000..b6708fa4d4 --- /dev/null +++ b/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go @@ -0,0 +1,69 @@ +package dynamo_inject_workerid + +import ( + "context" + "encoding/json" + "strings" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + rc "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" + schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +const ( + typeString = "dynamo-inject-workerid" + pluginName = "dynamo-inject-workerid" + WorkerIDHeader = "x-worker-instance-id" + injectHintHeader = "x-epp-inject-nvext-worker-instance-id" + TokenDataHeader = "x-epp-inject-nvext-token-data" +) + +var _ plugins.Plugin = (*InjectWorkerIDPreRequest)(nil) +var _ rc.PreRequest = (*InjectWorkerIDPreRequest)(nil) + +type InjectWorkerIDPreRequest struct { + typedName plugins.TypedName +} + +func NewInjectWorkerIDPreRequest() *InjectWorkerIDPreRequest { + return &InjectWorkerIDPreRequest{ + typedName: plugins.TypedName{Type: typeString, Name: pluginName}, + } +} + +func (p *InjectWorkerIDPreRequest) WithName(name string) *InjectWorkerIDPreRequest { + p.typedName.Name = name + return p +} + +func InjectWorkerIDPreRequestFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { + return NewInjectWorkerIDPreRequest().WithName(name), nil +} + +func (p *InjectWorkerIDPreRequest) TypedName() plugins.TypedName { return p.typedName } + +func (p *InjectWorkerIDPreRequest) PreRequest( + _ context.Context, + req *schedtypes.LLMRequest, + _ *schedtypes.SchedulingResult, + _ int, +) { + if req == nil { + return + } + if req.Headers == nil { + req.Headers = map[string]string{} + } + wid := strings.TrimSpace(req.Headers[WorkerIDHeader]) + if wid == "" { + return + } + req.Headers[WorkerIDHeader] = wid + req.Headers[injectHintHeader] = wid + + // Pass through token-data header if scorer set it + if td := strings.TrimSpace(req.Headers[TokenDataHeader]); td != "" { + req.Headers[TokenDataHeader] = td + } + +} diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml new file mode 100644 index 0000000000..b689c00171 --- /dev/null +++ b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml @@ -0,0 +1,21 @@ +# This is an example for configuring the EPP to use the dynamo token-aware kv router for scoring the pods +apiVersion: inference.networking.x-k8s.io/v1alpha1 +kind: EndpointPickerConfig +plugins: + # Required: tells EPP which profile to use (even if you only have one) + - type: single-profile-handler + + # Picker: chooses the final endpoint after scoring + - name: picker + type: max-score-picker + - name: dyn-pre + type: dynamo-inject-workerid + parameters: {} + - name: dyn-kv + type: kv-aware-scorer +schedulingProfiles: + - name: default + plugins: + - pluginRef: dyn-kv + weight: 1 + - pluginRef: picker diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go new file mode 100644 index 0000000000..83a4ace264 --- /dev/null +++ b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go @@ -0,0 +1,428 @@ +package dynamo_kv_scorer + +/* +#cgo CPPFLAGS: -I${SRCDIR}/include +#cgo CXXFLAGS: -std=c++17 +#cgo LDFLAGS: ${SRCDIR}/lib/libdynamo_llm_capi.a -lstdc++ -ldl -lpthread -lm + +#include +#include +#include // for free +#include + +// enum underlying type is uint32_t; matches cbindgen output +typedef uint32_t dynamo_llm_result_t; +enum { DYNAMO_OK = 0, DYNAMO_ERR = 1 }; + +// opaque handle forward-decl +struct WorkerSelectionPipeline; +typedef struct WorkerSelectionPipeline WorkerSelectionPipeline; + +// Prototypes (C-compatible) +dynamo_llm_result_t dynamo_llm_init(const char *namespace_c_str, + const char *component_c_str, + int64_t worker_id, + uint32_t kv_block_size); + +dynamo_llm_result_t dynamo_llm_shutdown(void); +dynamo_llm_result_t dynamo_llm_load_publisher_create(void); + +dynamo_llm_result_t dynamo_kv_event_publish_stored(uint64_t event_id, + const uint32_t *token_ids, + const uintptr_t *num_block_tokens, + const uint64_t *block_ids, + size_t num_blocks, + const uint64_t *parent_hash, + uint64_t lora_id); + +dynamo_llm_result_t dynamo_kv_event_publish_removed(uint64_t event_id, + const uint64_t *block_ids, + size_t num_blocks); + +dynamo_llm_result_t dynamo_create_worker_selection_pipeline(const char *namespace_c_str, + const char *component_c_str, + const char *model_name_c_str, + bool use_kv_routing, + double busy_threshold, + double overlap_score_weight, + double router_temperature, + bool use_kv_events, + bool router_replica_sync, + WorkerSelectionPipeline **pipeline_out); + +dynamo_llm_result_t dynamo_destroy_worker_selection_pipeline(WorkerSelectionPipeline *pipeline); + +dynamo_llm_result_t dynamo_query_worker_selection_and_annotate(WorkerSelectionPipeline *pipeline, + const char *request_json_c_str, + int64_t *worker_instance_id_out, + uint32_t **token_ids_out, + size_t *token_count_out, + char **annotated_request_json_out); + +dynamo_llm_result_t dynamo_free_worker_selection_result(uint32_t *token_ids, + size_t token_count, + char *annotated_request_json); +*/ +import "C" + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "os" + "strings" + "sync" + "unsafe" + + log "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" + schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +const ( + PluginName = "dynamo-kv-scorer" + KVAwareScorerType = "kv-aware-scorer" + StateKeyWorkerInstanceID = schedtypes.StateKey("dynamo/worker-instance-id") + WorkerIDHeader = "x-worker-instance-id" + TokenDataHeader = "x-epp-inject-nvext-token-data" +) + +// --------------------------- config / env --------------------------- + +var warmupOnce sync.Once +var warmupErr error + +type stateString string +type params struct { +} + +func (s stateString) Clone() schedtypes.StateData { return s } + +type KVAwareScorer struct { + typedName plugins.TypedName +} + +var _ plugins.Plugin = (*KVAwareScorer)(nil) +var _ framework.Scorer = (*KVAwareScorer)(nil) + +func NewKVAwareScorer() *KVAwareScorer { + return &KVAwareScorer{ + typedName: plugins.TypedName{Type: KVAwareScorerType, Name: PluginName}, + } +} + +func (k *KVAwareScorer) WithName(name string) *KVAwareScorer { k.typedName.Name = name; return k } + +func KVAwareScorerFactory(name string, raw json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { + p := params{} + _ = json.Unmarshal(raw, &p) + + s := NewKVAwareScorer().WithName(name) + + // one-time FFI init (runtime + persistent pipeline) + warmupOnce.Do(func() { + defer func() { + if r := recover(); r != nil { + warmupErr = fmt.Errorf("Dynamo configuration error: %v", r) + } + }() + warmupErr = initFFI() + }) + if warmupErr != nil { + return nil, fmt.Errorf("Dynamo FFI init for the Router failed: %w", warmupErr) + } + + return s, nil +} + +func (k *KVAwareScorer) TypedName() plugins.TypedName { return k.typedName } + +// --------------------------- FFI integration --------------------------- + +var ( + ffiOnce sync.Once + ffiErr error + + ffiNamespace string + ffiComponent string + ffiModel string + ffiOverlapScoreWeight float64 + ffiRouterTemperature float64 + ffiKvBlockSize uint32 + ffiWorkerID int64 + + runtimeInitialized bool + + // Boxed pipeline handle (owned on the Rust side, opaque here) + pipeline *C.struct_WorkerSelectionPipeline + pipelineMutex sync.RWMutex +) + +func loadDynamoConfig() { + ffiNamespace = getEnvOrDefault("DYNAMO_NAMESPACE", "vllm-agg") + ffiComponent = getEnvOrDefault("DYNAMO_COMPONENT", "backend") + ffiModel = getEnvOrDefault("DYNAMO_MODEL", "Qwen/Qwen3-0.6B") + ffiWorkerID = getEnvInt64OrDefault("DYNAMO_WORKER_ID", 1) + + ffiOverlapScoreWeight = getEnvFloatOrDefault("DYNAMO_OVERLAP_SCORE_WEIGHT", -1.0) + ffiRouterTemperature = getEnvFloatOrDefault("DYNAMO_ROUTER_TEMPERATURE", -1.0) + + kvBlockSizeStr := os.Getenv("DYNAMO_KV_BLOCK_SIZE") + if kvBlockSizeStr == "" { + panic("DYNAMO_KV_BLOCK_SIZE is required and must match the model card's kv_cache_block_size") + } + var tmp int64 + if n, err := fmt.Sscanf(kvBlockSizeStr, "%d", &tmp); err != nil || n != 1 { + panic(fmt.Sprintf("DYNAMO_KV_BLOCK_SIZE='%s' is not a valid integer", kvBlockSizeStr)) + } + ffiKvBlockSize = uint32(tmp) + if ffiKvBlockSize < 16 || ffiKvBlockSize > 8192 { + panic(fmt.Sprintf("DYNAMO_KV_BLOCK_SIZE=%d outside [16,8192]", ffiKvBlockSize)) + } + if (ffiKvBlockSize & (ffiKvBlockSize - 1)) != 0 { + panic(fmt.Sprintf("DYNAMO_KV_BLOCK_SIZE=%d must be a power of 2", ffiKvBlockSize)) + } + fmt.Printf("Dynamo KV Scorer: Loaded DYNAMO_KV_BLOCK_SIZE=%d\n", ffiKvBlockSize) +} + +func getEnvOrDefault(key, def string) string { + if v := os.Getenv(key); v != "" { + return v + } + return def +} +func getEnvInt64OrDefault(key string, def int64) int64 { + if v := os.Getenv(key); v != "" { + var p int64 + if n, err := fmt.Sscanf(v, "%d", &p); err == nil && n == 1 { + return p + } + } + return def +} +func getEnvFloatOrDefault(key string, def float64) float64 { + if v := os.Getenv(key); v != "" { + var p float64 + if n, err := fmt.Sscanf(v, "%f", &p); err == nil && n == 1 { + return p + } + } + return def +} +func getEnvBoolOrDefault(key string, def bool) bool { + if v := os.Getenv(key); v != "" { + switch strings.ToLower(v) { + case "true", "1", "yes", "on": + return true + case "false", "0", "no", "off": + return false + } + } + return def +} + +// initFFI: initialize runtime and create a persistent boxed pipeline. +func initFFI() error { + ffiOnce.Do(func() { + loadDynamoConfig() + + ns := C.CString(ffiNamespace) + cm := C.CString(ffiComponent) + model := C.CString(ffiModel) + defer C.free(unsafe.Pointer(ns)) + defer C.free(unsafe.Pointer(cm)) + defer C.free(unsafe.Pointer(model)) + + // Init Dynamo runtime + if rc := C.dynamo_llm_init(ns, cm, C.int64_t(ffiWorkerID), C.uint32_t(ffiKvBlockSize)); rc != C.DYNAMO_OK { + ffiErr = fmt.Errorf("dynamo_llm_init failed") + return + } + runtimeInitialized = true + + // Create persistent pipeline + pipelineMutex.Lock() + defer pipelineMutex.Unlock() + + rc := C.dynamo_create_worker_selection_pipeline( + ns, + cm, + model, + C.bool(getEnvBoolOrDefault("DYNAMO_USE_KV_ROUTING", true)), + C.double(getEnvFloatOrDefault("DYNAMO_BUSY_THRESHOLD", -1.0)), + C.double(ffiOverlapScoreWeight), + C.double(ffiRouterTemperature), + C.bool(getEnvBoolOrDefault("DYNAMO_USE_KV_EVENTS", true)), + C.bool(getEnvBoolOrDefault("DYNAMO_ROUTER_REPLICA_SYNC", true)), + &pipeline, + ) + if rc != C.DYNAMO_OK { + ffiErr = fmt.Errorf("dynamo_create_worker_selection_pipeline failed") + return + } + }) + return ffiErr +} + +// --------------------------- scoring --------------------------- + +func encodeTokenData(tokens []int64) string { + b, _ := json.Marshal(tokens) + return base64.StdEncoding.EncodeToString(b) +} + +func (k *KVAwareScorer) Score( + ctx context.Context, + cycle *schedtypes.CycleState, + req *schedtypes.LLMRequest, + pods []schedtypes.Pod, +) map[schedtypes.Pod]float64 { + logger := log.FromContext(ctx) + + workerID, tokenData, err := k.callDynamoRouter(ctx, req) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Dynamo call failed; proceeding without worker id") + } else if workerID != "" { + logger.V(logutil.DEFAULT).Info( + "Dynamo router selected worker", + "workerID", workerID, + "tokenDataCount", len(tokenData), + "tokenData", tokenData, + ) + cycle.Write(StateKeyWorkerInstanceID, stateString(workerID)) + if req.Headers == nil { + req.Headers = map[string]string{} + } + req.Headers[WorkerIDHeader] = workerID + if len(tokenData) > 0 { + if req.Headers == nil { + req.Headers = map[string]string{} + } + req.Headers[TokenDataHeader] = encodeTokenData(tokenData) + } + } + + out := make(map[schedtypes.Pod]float64, len(pods)) + for _, p := range pods { + out[p] = 1.0 + } + return out +} + +// --------------------------- router call (persistent only) --------------------------- + +func (k *KVAwareScorer) callDynamoRouter( + ctx context.Context, + req *schedtypes.LLMRequest, +) (string, []int64, error) { + logger := log.FromContext(ctx) + + if err := initFFI(); err != nil { + logger.V(logutil.DEFAULT).Error(err, "FFI init failed") + return "", nil, err + } + if !runtimeInitialized { + return "", nil, fmt.Errorf("dynamo runtime not initialized") + } + + pipelineMutex.RLock() + currentPipeline := pipeline + pipelineMutex.RUnlock() + + if currentPipeline == nil { + return "", nil, fmt.Errorf("dynamo worker selection pipeline not created") + } + + // Build OpenAI-compatible JSON request + requestBody := buildOpenAIRequest(req) + requestJSON, err := json.Marshal(requestBody) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Failed to marshal OpenAI request") + return "", nil, fmt.Errorf("marshal OpenAI request: %w", err) + } + cRequestJSON := C.CString(string(requestJSON)) + defer C.free(unsafe.Pointer(cRequestJSON)) + + // Output variables + var cWorkerID C.int64_t + var cTokens *C.uint32_t + var cTokenCount C.size_t + var cAnnotatedJSON *C.char + + // Call the worker selection pipeline + rc := C.dynamo_query_worker_selection_and_annotate( + currentPipeline, + cRequestJSON, + &cWorkerID, + &cTokens, + &cTokenCount, + &cAnnotatedJSON, + ) + if rc != C.DYNAMO_OK { + return "", nil, fmt.Errorf("dynamo_query_worker_selection_and_annotate failed") + } + + // Copy tokens into Go memory and free C memory + count := int(uintptr(cTokenCount)) + var tokens64 []int64 + if count > 0 && cTokens != nil { + src := unsafe.Slice((*uint32)(unsafe.Pointer(cTokens)), count) + tokens64 = make([]int64, count) + for i := 0; i < count; i++ { + tokens64[i] = int64(src[i]) + } + } + C.dynamo_free_worker_selection_result(cTokens, cTokenCount, cAnnotatedJSON) + + workerID := fmt.Sprintf("%d", int64(cWorkerID)) + logger.V(logutil.DEFAULT).Info("Worker selection completed", + "workerID", workerID, "tokenCount", count) + + return workerID, tokens64, nil +} + +func buildOpenAIRequest(req *schedtypes.LLMRequest) map[string]any { + requestBody := make(map[string]any) + userText := "default prompt" + if req != nil && strings.TrimSpace(req.Prompt) != "" { + userText = req.Prompt + } + requestBody["messages"] = []map[string]any{{"role": "user", "content": userText}} + if req != nil && strings.TrimSpace(req.TargetModel) != "" { + requestBody["model"] = req.TargetModel + } else { + requestBody["model"] = ffiModel + } + requestBody["max_tokens"] = 1 + requestBody["temperature"] = 0.0 + requestBody["stream"] = true + requestBody["nvext"] = map[string]any{ + "annotations": []string{"query_instance_id"}, + } + return requestBody +} + +// --------------------------- shutdown --------------------------- + +func cleanupDynamo() error { + pipelineMutex.Lock() + defer pipelineMutex.Unlock() + + if pipeline != nil { + if rc := C.dynamo_destroy_worker_selection_pipeline(pipeline); rc != C.DYNAMO_OK { + fmt.Printf("Warning: dynamo_destroy_worker_selection_pipeline failed\n") + } + pipeline = nil + } + + if runtimeInitialized { + if rc := C.dynamo_llm_shutdown(); rc != C.DYNAMO_OK { + return fmt.Errorf("dynamo_llm_shutdown failed") + } + runtimeInitialized = false + } + return nil +}