Skip to content

Commit 29fd9ba

Browse files
BrianPark314BrianPark314YuhanLiu11
authored
feature/prefix-aware-routing (vllm-project#546)
* feat: add vllmruntime Signed-off-by: BrianPark314 <brianpark314@gmail.com> * feat: add prefix aware picker Signed-off-by: BrianPark314 <brianpark314@gmail.com> * feat: add prefix aware picker Signed-off-by: BrianPark314 <brianpark314@gmail.com> * feat: update prefix aware logic Signed-off-by: BrianPark314 <brianpark314@gmail.com> * chore: revert roundrobin picker Signed-off-by: BrianPark314 <brianpark314@gmail.com> * chore: fix pa picker logic Signed-off-by: BrianPark314 <brianpark314@gmail.com> * chore: add newline Signed-off-by: BrianPark314 <brianpark314@gmail.com> --------- Signed-off-by: BrianPark314 <brianpark314@gmail.com> Co-authored-by: BrianPark314 <brianpark314@gmail.com> Co-authored-by: Yuhan Liu <32589867+YuhanLiu11@users.noreply.github.com>
1 parent 7324b96 commit 29fd9ba

2 files changed

Lines changed: 214 additions & 0 deletions

File tree

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
/*
2+
Copyright 2025 The vLLM Production Stack Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
*/
10+
11+
package picker
12+
13+
import (
14+
"math/rand"
15+
"strings"
16+
"sync"
17+
"time"
18+
19+
"github.com/cespare/xxhash/v2"
20+
21+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
22+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
23+
)
24+
25+
// chunkSize is the number of bytes of prompt text hashed into a single trie
// key by chunkAndHash.
const chunkSize = 128
26+
27+
// Compile-time check that *PrefixMatchPicker satisfies plugins.Picker.
var _ plugins.Picker = (*PrefixMatchPicker)(nil)
28+
29+
// PrefixMatchPicker selects the engine whose URL was returned by the
30+
// longest-prefix match against previously-seen prompts (same idea as the
31+
// Python `route_request`). Ties are broken at random.
32+
type PrefixMatchPicker struct {
33+
trie *hashTrie
34+
rnd *rand.Rand
35+
}
36+
37+
// NewPrefixMatchPicker returns a ready-to-use picker instance.
38+
func NewPrefixMatchPicker() *PrefixMatchPicker {
39+
return &PrefixMatchPicker{
40+
trie: newHashTrie(),
41+
rnd: rand.New(rand.NewSource(time.Now().UnixNano())),
42+
}
43+
}
44+
45+
// Name returns the fixed identifier of this picker, "prefixmatch".
func (p *PrefixMatchPicker) Name() string {
	return "prefixmatch"
}
46+
47+
// Pick implements plugins.Picker.
48+
//
49+
// SchedulingContext is assumed to carry the inference request body in
50+
// ctx.RequestBody (map[string]any) with the prompt at key "prompt". Adjust
51+
// the accessor if your integration differs.
52+
func (p *PrefixMatchPicker) Pick(
53+
ctx *types.SchedulingContext,
54+
scoredPods []*types.ScoredPod,
55+
) *types.Result {
56+
if len(scoredPods) == 0 {
57+
return &types.Result{}
58+
}
59+
60+
var prompt string
61+
62+
if msgs, ok := ctx.RequestBody["messages"]; ok {
63+
if arr, ok := msgs.([]any); ok {
64+
var parts []string
65+
for _, m := range arr {
66+
mm, ok := m.(map[string]any)
67+
if !ok {
68+
continue
69+
}
70+
switch c := mm["content"].(type) {
71+
case string:
72+
parts = append(parts, c)
73+
case []any:
74+
for _, part := range c {
75+
mp, ok := part.(map[string]any)
76+
if !ok {
77+
continue
78+
}
79+
if mp["type"] == "text" {
80+
if txt, ok := mp["text"].(string); ok {
81+
parts = append(parts, txt)
82+
}
83+
}
84+
}
85+
}
86+
}
87+
prompt = strings.Join(parts, "\n")
88+
}
89+
}
90+
91+
if prompt == "" {
92+
prompt, _ = ctx.RequestBody["prompt"].(string)
93+
}
94+
95+
// 1. Build the set of available endpoints.
96+
available := make(map[string]struct{}, len(scoredPods))
97+
for _, sp := range scoredPods {
98+
ep := sp.GetPod().EndpointURL // <-- adapt this accessor
99+
available[ep] = struct{}{}
100+
}
101+
102+
// 2. Longest-prefix match within the trie.
103+
matched := p.trie.longestPrefixMatch(prompt, available)
104+
105+
// 3. Fallback: no match --> all endpoints are candidates.
106+
if len(matched) == 0 {
107+
for ep := range available {
108+
matched[ep] = struct{}{}
109+
}
110+
}
111+
112+
// 4. Convert the matched set to a slice and pick randomly.
113+
endpoints := make([]string, 0, len(matched))
114+
for ep := range matched {
115+
endpoints = append(endpoints, ep)
116+
}
117+
selected := endpoints[p.rnd.Intn(len(endpoints))]
118+
119+
// 5. Cache the decision for future prefix look-ups.
120+
p.trie.insert(prompt, selected)
121+
122+
// 6. Return the pod whose URL matches `selected`.
123+
for _, sp := range scoredPods {
124+
if sp.GetPod().EndpointURL == selected { // same accessor as above
125+
return &types.Result{TargetPod: sp}
126+
}
127+
}
128+
// Should never hit; safe fallback.
129+
return &types.Result{TargetPod: scoredPods[0]}
130+
}
131+
132+
/*---------------------------- trie implementation ---------------------------*/
133+
134+
// hashTrie is one node of a trie whose edges are labelled with 64-bit chunk
// hashes. The node on which insert or longestPrefixMatch is called acts as
// the root for that operation.
type hashTrie struct {
	// mu is taken by insert (write lock) and longestPrefixMatch (read lock)
	// on the node they are invoked on.
	mu sync.RWMutex

	// children maps a chunk hash to the node reached by following that chunk.
	children map[uint64]*hashTrie

	// endpoints is the set of endpoints recorded on this node. It stays nil
	// until the first insert passes through the node.
	endpoints map[string]struct{}
}
139+
140+
func newHashTrie() *hashTrie {
141+
return &hashTrie{children: make(map[uint64]*hashTrie)}
142+
}
143+
144+
// intersection returns a newly allocated set containing the keys that occur
// in both a and b. Neither argument is modified. A nil argument behaves like
// an empty set, and the returned map is never nil.
func intersection(a, b map[string]struct{}) map[string]struct{} {
	common := map[string]struct{}{}
	for key := range a {
		_, shared := b[key]
		if !shared {
			continue
		}
		common[key] = struct{}{}
	}
	return common
}
153+
154+
func chunkAndHash(s string) []uint64 {
155+
hashes := make([]uint64, 0, (len(s)+chunkSize-1)/chunkSize)
156+
for i := 0; i < len(s); i += chunkSize {
157+
end := i + chunkSize
158+
if end > len(s) {
159+
end = len(s)
160+
}
161+
hashes = append(hashes, xxhash.Sum64([]byte(s[i:end])))
162+
}
163+
return hashes
164+
}
165+
166+
func (t *hashTrie) insert(key, endpoint string) {
167+
t.mu.Lock()
168+
defer t.mu.Unlock()
169+
170+
node := t
171+
if node.endpoints == nil {
172+
node.endpoints = make(map[string]struct{})
173+
}
174+
node.endpoints[endpoint] = struct{}{}
175+
176+
for _, h := range chunkAndHash(key) {
177+
child, ok := node.children[h]
178+
if !ok {
179+
child = newHashTrie()
180+
node.children[h] = child
181+
}
182+
node = child
183+
if node.endpoints == nil {
184+
node.endpoints = make(map[string]struct{})
185+
}
186+
node.endpoints[endpoint] = struct{}{}
187+
}
188+
}
189+
190+
func (t *hashTrie) longestPrefixMatch(
191+
key string,
192+
available map[string]struct{},
193+
) map[string]struct{} {
194+
t.mu.RLock()
195+
defer t.mu.RUnlock()
196+
197+
node := t
198+
matched := intersection(node.endpoints, available)
199+
200+
for _, h := range chunkAndHash(key) {
201+
child, ok := node.children[h]
202+
if !ok {
203+
break
204+
}
205+
node = child
206+
cand := intersection(node.endpoints, available)
207+
if len(cand) == 0 {
208+
break
209+
}
210+
matched = cand
211+
}
212+
return matched
213+
}

src/gateway_inference_extension/scheduler.patch

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ index b484cde..c7688a8 100644
1616
scorers: map[plugins.Scorer]int{},
1717
- picker: &picker.RandomPicker{},
1818
+ picker: &picker.RoundRobinPicker{},
19+
+ picker: &picker.PrefixAwarePicker{},
1920
postSchedulePlugins: []plugins.PostSchedule{},
2021
}
2122

0 commit comments

Comments
 (0)