|
| 1 | +/* |
| 2 | +Copyright 2025 The vLLM Production Stack Authors. |
| 3 | +
|
| 4 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +you may not use this file except in compliance with the License. |
| 6 | +You may obtain a copy of the License at |
| 7 | +
|
| 8 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +*/ |
| 10 | + |
| 11 | +package picker |
| 12 | + |
| 13 | +import ( |
| 14 | + "math/rand" |
| 15 | + "strings" |
| 16 | + "sync" |
| 17 | + "time" |
| 18 | + |
| 19 | + "github.com/cespare/xxhash/v2" |
| 20 | + |
| 21 | + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" |
| 22 | + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" |
| 23 | +) |
| 24 | + |
| 25 | +const chunkSize = 128 |
| 26 | + |
| 27 | +var _ plugins.Picker = &PrefixMatchPicker{} |
| 28 | + |
| 29 | +// PrefixMatchPicker selects the engine whose URL was returned by the |
| 30 | +// longest-prefix match against previously-seen prompts (same idea as the |
| 31 | +// Python `route_request`). Ties are broken at random. |
| 32 | +type PrefixMatchPicker struct { |
| 33 | + trie *hashTrie |
| 34 | + rnd *rand.Rand |
| 35 | +} |
| 36 | + |
| 37 | +// NewPrefixMatchPicker returns a ready-to-use picker instance. |
| 38 | +func NewPrefixMatchPicker() *PrefixMatchPicker { |
| 39 | + return &PrefixMatchPicker{ |
| 40 | + trie: newHashTrie(), |
| 41 | + rnd: rand.New(rand.NewSource(time.Now().UnixNano())), |
| 42 | + } |
| 43 | +} |
| 44 | + |
| 45 | +func (p *PrefixMatchPicker) Name() string { return "prefixmatch" } |
| 46 | + |
| 47 | +// Pick implements plugins.Picker. |
| 48 | +// |
| 49 | +// SchedulingContext is assumed to carry the inference request body in |
| 50 | +// ctx.RequestBody (map[string]any) with the prompt at key "prompt". Adjust |
| 51 | +// the accessor if your integration differs. |
| 52 | +func (p *PrefixMatchPicker) Pick( |
| 53 | + ctx *types.SchedulingContext, |
| 54 | + scoredPods []*types.ScoredPod, |
| 55 | +) *types.Result { |
| 56 | + if len(scoredPods) == 0 { |
| 57 | + return &types.Result{} |
| 58 | + } |
| 59 | + |
| 60 | + var prompt string |
| 61 | + |
| 62 | + if msgs, ok := ctx.RequestBody["messages"]; ok { |
| 63 | + if arr, ok := msgs.([]any); ok { |
| 64 | + var parts []string |
| 65 | + for _, m := range arr { |
| 66 | + mm, ok := m.(map[string]any) |
| 67 | + if !ok { |
| 68 | + continue |
| 69 | + } |
| 70 | + switch c := mm["content"].(type) { |
| 71 | + case string: |
| 72 | + parts = append(parts, c) |
| 73 | + case []any: |
| 74 | + for _, part := range c { |
| 75 | + mp, ok := part.(map[string]any) |
| 76 | + if !ok { |
| 77 | + continue |
| 78 | + } |
| 79 | + if mp["type"] == "text" { |
| 80 | + if txt, ok := mp["text"].(string); ok { |
| 81 | + parts = append(parts, txt) |
| 82 | + } |
| 83 | + } |
| 84 | + } |
| 85 | + } |
| 86 | + } |
| 87 | + prompt = strings.Join(parts, "\n") |
| 88 | + } |
| 89 | + } |
| 90 | + |
| 91 | + if prompt == "" { |
| 92 | + prompt, _ = ctx.RequestBody["prompt"].(string) |
| 93 | + } |
| 94 | + |
| 95 | + // 1. Build the set of available endpoints. |
| 96 | + available := make(map[string]struct{}, len(scoredPods)) |
| 97 | + for _, sp := range scoredPods { |
| 98 | + ep := sp.GetPod().EndpointURL // <-- adapt this accessor |
| 99 | + available[ep] = struct{}{} |
| 100 | + } |
| 101 | + |
| 102 | + // 2. Longest-prefix match within the trie. |
| 103 | + matched := p.trie.longestPrefixMatch(prompt, available) |
| 104 | + |
| 105 | + // 3. Fallback: no match --> all endpoints are candidates. |
| 106 | + if len(matched) == 0 { |
| 107 | + for ep := range available { |
| 108 | + matched[ep] = struct{}{} |
| 109 | + } |
| 110 | + } |
| 111 | + |
| 112 | + // 4. Convert the matched set to a slice and pick randomly. |
| 113 | + endpoints := make([]string, 0, len(matched)) |
| 114 | + for ep := range matched { |
| 115 | + endpoints = append(endpoints, ep) |
| 116 | + } |
| 117 | + selected := endpoints[p.rnd.Intn(len(endpoints))] |
| 118 | + |
| 119 | + // 5. Cache the decision for future prefix look-ups. |
| 120 | + p.trie.insert(prompt, selected) |
| 121 | + |
| 122 | + // 6. Return the pod whose URL matches `selected`. |
| 123 | + for _, sp := range scoredPods { |
| 124 | + if sp.GetPod().EndpointURL == selected { // same accessor as above |
| 125 | + return &types.Result{TargetPod: sp} |
| 126 | + } |
| 127 | + } |
| 128 | + // Should never hit; safe fallback. |
| 129 | + return &types.Result{TargetPod: scoredPods[0]} |
| 130 | +} |
| 131 | + |
| 132 | +/*---------------------------- trie implementation ---------------------------*/ |
| 133 | + |
| 134 | +type hashTrie struct { |
| 135 | + mu sync.RWMutex |
| 136 | + children map[uint64]*hashTrie |
| 137 | + endpoints map[string]struct{} |
| 138 | +} |
| 139 | + |
| 140 | +func newHashTrie() *hashTrie { |
| 141 | + return &hashTrie{children: make(map[uint64]*hashTrie)} |
| 142 | +} |
| 143 | + |
| 144 | +func intersection(a, b map[string]struct{}) map[string]struct{} { |
| 145 | + res := make(map[string]struct{}) |
| 146 | + for k := range a { |
| 147 | + if _, ok := b[k]; ok { |
| 148 | + res[k] = struct{}{} |
| 149 | + } |
| 150 | + } |
| 151 | + return res |
| 152 | +} |
| 153 | + |
| 154 | +func chunkAndHash(s string) []uint64 { |
| 155 | + hashes := make([]uint64, 0, (len(s)+chunkSize-1)/chunkSize) |
| 156 | + for i := 0; i < len(s); i += chunkSize { |
| 157 | + end := i + chunkSize |
| 158 | + if end > len(s) { |
| 159 | + end = len(s) |
| 160 | + } |
| 161 | + hashes = append(hashes, xxhash.Sum64([]byte(s[i:end]))) |
| 162 | + } |
| 163 | + return hashes |
| 164 | +} |
| 165 | + |
| 166 | +func (t *hashTrie) insert(key, endpoint string) { |
| 167 | + t.mu.Lock() |
| 168 | + defer t.mu.Unlock() |
| 169 | + |
| 170 | + node := t |
| 171 | + if node.endpoints == nil { |
| 172 | + node.endpoints = make(map[string]struct{}) |
| 173 | + } |
| 174 | + node.endpoints[endpoint] = struct{}{} |
| 175 | + |
| 176 | + for _, h := range chunkAndHash(key) { |
| 177 | + child, ok := node.children[h] |
| 178 | + if !ok { |
| 179 | + child = newHashTrie() |
| 180 | + node.children[h] = child |
| 181 | + } |
| 182 | + node = child |
| 183 | + if node.endpoints == nil { |
| 184 | + node.endpoints = make(map[string]struct{}) |
| 185 | + } |
| 186 | + node.endpoints[endpoint] = struct{}{} |
| 187 | + } |
| 188 | +} |
| 189 | + |
| 190 | +func (t *hashTrie) longestPrefixMatch( |
| 191 | + key string, |
| 192 | + available map[string]struct{}, |
| 193 | +) map[string]struct{} { |
| 194 | + t.mu.RLock() |
| 195 | + defer t.mu.RUnlock() |
| 196 | + |
| 197 | + node := t |
| 198 | + matched := intersection(node.endpoints, available) |
| 199 | + |
| 200 | + for _, h := range chunkAndHash(key) { |
| 201 | + child, ok := node.children[h] |
| 202 | + if !ok { |
| 203 | + break |
| 204 | + } |
| 205 | + node = child |
| 206 | + cand := intersection(node.endpoints, available) |
| 207 | + if len(cand) == 0 { |
| 208 | + break |
| 209 | + } |
| 210 | + matched = cand |
| 211 | + } |
| 212 | + return matched |
| 213 | +} |
0 commit comments