Skip to content

Commit c97c774

Browse files
authored
ENG-3780: use model id in body for routing (#40)
* use model id in body for routing

* Fix body-model routing fallback for LoRA-aware workers

Use request body models as worker filters only when the router registry has indexed that model, while keeping run-scoped requests strict. Treat rerank's "default" sentinel as unspecified so omitted-model rerank requests continue to route normally.

Also sync all advertised worker models during startup and add_worker so workers serving multiple models or LoRA adapters are routable immediately, before the next health refresh.

Adds regression coverage for indexed LoRA routing, rerank default handling, and multi-model worker registration.
1 parent 19cb83a commit c97c774

1 file changed

Lines changed: 323 additions & 3 deletions

File tree

src/routers/http/router.rs

Lines changed: 323 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::otel_http::{self, ClientRequestOptions};
99
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
1010
use crate::protocols::spec::{
1111
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest,
12-
RerankRequest, RerankResponse, RerankResult, ResponsesRequest,
12+
RerankRequest, RerankResponse, RerankResult, ResponsesRequest, DEFAULT_MODEL_NAME,
1313
};
1414
use crate::routers::header_utils;
1515
use crate::routers::http::dp_utils;
@@ -154,6 +154,10 @@ impl Router {
154154
)
155155
};
156156
ctx.worker_registry.register(worker_arc.clone());
157+
if !models.is_empty() {
158+
ctx.worker_registry
159+
.sync_worker_models(worker_arc.url(), &models);
160+
}
157161

158162
// Notify PolicyRegistry about the new worker
159163
let model_id = worker_arc.model_id();
@@ -538,6 +542,39 @@ impl Router {
538542
})
539543
}
540544

545+
/// Returns the model id with surrounding whitespace trimmed.
///
/// Yields `None` when no id was supplied or when the id is empty after trimming.
fn normalize_model_id(model_id: Option<&str>) -> Option<&str> {
    model_id
        .map(str::trim)
        .filter(|trimmed| !trimmed.is_empty())
}
549+
550+
fn body_model_for_route<'a>(route: &str, model_id: Option<&'a str>) -> Option<&'a str> {
551+
let model = Self::normalize_model_id(model_id)?;
552+
if route == "/v1/rerank" && model == DEFAULT_MODEL_NAME {
553+
return None;
554+
}
555+
Some(model)
556+
}
557+
558+
fn resolve_body_model_filter<'a>(
559+
&self,
560+
route: &str,
561+
model_id: Option<&'a str>,
562+
run_id: Option<&str>,
563+
) -> Option<&'a str> {
564+
let model = Self::body_model_for_route(route, model_id)?;
565+
566+
if run_id.is_some() || !self.worker_registry.get_by_model_fast(model).is_empty() {
567+
return Some(model);
568+
}
569+
570+
debug!(
571+
model_id = %model,
572+
route,
573+
"body model is not indexed; routing without a model filter"
574+
);
575+
None
576+
}
577+
541578
/// Select worker for a specific model considering circuit breaker state
542579
fn select_worker_for_model(
543580
&self,
@@ -586,11 +623,24 @@ impl Router {
586623
let text = typed_req.extract_text_for_routing();
587624
let run_id = run_id.map(|s| s.to_string());
588625

626+
// Fall back to the body's `model` field when the caller doesn't pass one, but
627+
// only use it as a routing filter when the registry has already indexed that
628+
// model. This keeps compatibility for generic upstream model validation while
629+
// still preventing known LoRA requests from being sent to workers that have not
630+
// loaded the adapter. Run-scoped requests keep the body model as a hard filter.
631+
let effective_model_id = Self::normalize_model_id(model_id).or_else(|| {
632+
self.resolve_body_model_filter(route, typed_req.get_model(), run_id.as_deref())
633+
});
634+
589635
let response = RetryExecutor::execute_response_with_retry(
590636
&self.retry_config,
591637
// operation per attempt
592638
|_: u32| async {
593-
let worker = match self.select_worker_for_model(model_id, Some(&text), headers) {
639+
let worker = match self.select_worker_for_model(
640+
effective_model_id,
641+
Some(&text),
642+
headers,
643+
) {
594644
Some(w) => w,
595645
None => {
596646
RouterMetrics::record_request_error(route, "no_available_workers");
@@ -604,7 +654,7 @@ impl Router {
604654

605655
// Optional load tracking for cache-aware policy
606656
// Get the policy for this model to check if it's cache-aware
607-
let policy = match model_id {
657+
let policy = match effective_model_id {
608658
Some(model) => self.policy_registry.get_policy_or_default(model),
609659
None => self.policy_registry.get_default_policy(),
610660
};
@@ -1209,6 +1259,10 @@ impl Router {
12091259

12101260
let worker_arc: Arc<dyn Worker> = Arc::new(new_worker);
12111261
self.worker_registry.register(worker_arc.clone());
1262+
if !models.is_empty() {
1263+
self.worker_registry
1264+
.sync_worker_models(worker_arc.url(), &models);
1265+
}
12121266

12131267
// Notify PolicyRegistry about the new worker
12141268
let model_id = worker_arc.model_id();
@@ -1246,6 +1300,10 @@ impl Router {
12461300

12471301
let worker_arc = Arc::new(new_worker);
12481302
self.worker_registry.register(worker_arc.clone());
1303+
if !models.is_empty() {
1304+
self.worker_registry
1305+
.sync_worker_models(worker_arc.url(), &models);
1306+
}
12491307

12501308
// Notify PolicyRegistry about the new worker
12511309
let model_id = worker_arc.model_id();
@@ -2273,6 +2331,268 @@ mod tests {
22732331
);
22742332
}
22752333

2334+
/// Guards the invariant that the `effective_model_id` fallback in
/// `route_typed_request` relies on: `select_worker_for_model` must only return
/// workers indexed for the requested model, so a request naming a LoRA adapter
/// is never dispatched to a pod that has not loaded it.
#[test]
fn test_select_worker_filters_to_pods_with_loaded_lora() {
    const WORKER_1: &str = "http://worker1:8080";
    const WORKER_2: &str = "http://worker2:8080";

    let router = create_test_regular_router();
    let registry = &router.worker_registry;

    // Start with every worker advertising only the base model, as a freshly
    // scaled-up vLLM pod would before the orchestrator pushes an adapter.
    for url in [WORKER_1, WORKER_2] {
        registry.sync_worker_models(url, &["base-model".to_string()]);
    }

    // The orchestrator then loads the LoRA on worker1 alone.
    registry.sync_worker_models(
        WORKER_1,
        &["base-model".to_string(), "rft-run-1".to_string()],
    );

    // Every selection for the LoRA has to land on worker1.
    for _attempt in 0..20 {
        let selected = router
            .select_worker_for_model(Some("rft-run-1"), Some(r#"{"prompt":"x"}"#), None)
            .expect("a worker is available");
        assert_eq!(
            selected.url(),
            WORKER_1,
            "LoRA request leaked to a worker without the adapter loaded"
        );
    }

    // Both workers must remain indexed under the base model.
    assert_eq!(registry.get_by_model_fast("base-model").len(), 2);
}
2375+
2376+
/// Checks when `resolve_body_model_filter` keeps the body model as a routing filter:
/// an indexed model is kept, an unindexed one is dropped unless a run id is present,
/// and `DEFAULT_MODEL_NAME` on `/v1/rerank` is treated as unspecified.
#[test]
fn test_body_model_filter_only_uses_indexed_models() {
    const CHAT_ROUTE: &str = "/v1/chat/completions";

    let router = create_test_regular_router();
    router.worker_registry.sync_worker_models(
        "http://worker1:8080",
        &["base-model".to_string(), "rft-run-1".to_string()],
    );

    // (route, body model, run id, expected filter)
    let cases: [(&str, Option<&str>, Option<&str>, Option<&str>); 4] = [
        (CHAT_ROUTE, Some("rft-run-1"), None, Some("rft-run-1")),
        (CHAT_ROUTE, Some("not-indexed"), None, None),
        (
            CHAT_ROUTE,
            Some("not-indexed"),
            Some("run-123"),
            Some("not-indexed"),
        ),
        ("/v1/rerank", Some(DEFAULT_MODEL_NAME), None, None),
    ];

    for (route, body_model, run_id, expected) in cases {
        assert_eq!(
            router.resolve_body_model_filter(route, body_model, run_id),
            expected
        );
    }
}
2414+
2415+
async fn start_counting_chat_worker(
2416+
request_count: Arc<std::sync::atomic::AtomicUsize>,
2417+
) -> (String, tokio::task::JoinHandle<()>) {
2418+
use axum::{routing::post, Json, Router as AxumRouter};
2419+
use tokio::net::TcpListener;
2420+
2421+
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2422+
let addr = listener.local_addr().unwrap();
2423+
let app = AxumRouter::new().route(
2424+
"/v1/chat/completions",
2425+
post(move |Json(body): Json<serde_json::Value>| {
2426+
let request_count = request_count.clone();
2427+
async move {
2428+
request_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
2429+
let model = body
2430+
.get("model")
2431+
.and_then(|v| v.as_str())
2432+
.unwrap_or("unknown")
2433+
.to_string();
2434+
Json(serde_json::json!({
2435+
"id": "chatcmpl-test",
2436+
"object": "chat.completion",
2437+
"created": 0,
2438+
"model": model,
2439+
"choices": [{
2440+
"index": 0,
2441+
"message": {
2442+
"role": "assistant",
2443+
"content": "ok"
2444+
},
2445+
"finish_reason": "stop"
2446+
}],
2447+
"usage": {
2448+
"prompt_tokens": 1,
2449+
"completion_tokens": 1,
2450+
"total_tokens": 2
2451+
}
2452+
}))
2453+
}
2454+
}),
2455+
);
2456+
let handle = tokio::spawn(async move {
2457+
axum::serve(listener, app).await.unwrap();
2458+
});
2459+
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
2460+
(format!("http://{}", addr), handle)
2461+
}
2462+
2463+
async fn start_model_listing_worker(
2464+
models: Vec<&'static str>,
2465+
) -> (String, tokio::task::JoinHandle<()>) {
2466+
use axum::{routing::get, Json, Router as AxumRouter};
2467+
use tokio::net::TcpListener;
2468+
2469+
let models: Vec<String> = models.into_iter().map(String::from).collect();
2470+
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2471+
let addr = listener.local_addr().unwrap();
2472+
let app = AxumRouter::new()
2473+
.route("/health", get(|| async { StatusCode::OK }))
2474+
.route(
2475+
"/v1/models",
2476+
get(move || {
2477+
let models = models.clone();
2478+
async move {
2479+
let data: Vec<serde_json::Value> = models
2480+
.iter()
2481+
.map(|id| serde_json::json!({"id": id, "object": "model"}))
2482+
.collect();
2483+
Json(serde_json::json!({"object": "list", "data": data}))
2484+
}
2485+
}),
2486+
);
2487+
let handle = tokio::spawn(async move {
2488+
axum::serve(listener, app).await.unwrap();
2489+
});
2490+
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
2491+
(format!("http://{}", addr), handle)
2492+
}
2493+
2494+
#[tokio::test]
2495+
async fn test_route_chat_with_body_model_filters_to_loaded_lora_worker() {
2496+
let base_only_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
2497+
let lora_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
2498+
let (base_only_url, _base_only_handle) =
2499+
start_counting_chat_worker(base_only_count.clone()).await;
2500+
let (lora_url, _lora_handle) = start_counting_chat_worker(lora_count.clone()).await;
2501+
2502+
let worker_registry = Arc::new(WorkerRegistry::new());
2503+
worker_registry.register(Arc::new(BasicWorker::new(
2504+
base_only_url.clone(),
2505+
WorkerType::Regular,
2506+
)));
2507+
worker_registry.register(Arc::new(BasicWorker::new(
2508+
lora_url.clone(),
2509+
WorkerType::Regular,
2510+
)));
2511+
2512+
worker_registry.sync_worker_models(&base_only_url, &["base-model".to_string()]);
2513+
worker_registry.sync_worker_models(
2514+
&lora_url,
2515+
&["base-model".to_string(), "rft-run-1".to_string()],
2516+
);
2517+
2518+
let policy_registry = Arc::new(PolicyRegistry::new(
2519+
crate::config::types::PolicyConfig::RoundRobin,
2520+
));
2521+
let (_, rx) = tokio::sync::watch::channel(HashMap::new());
2522+
let router = Router {
2523+
worker_registry,
2524+
policy_registry,
2525+
worker_startup_timeout_secs: 5,
2526+
worker_startup_check_interval_secs: 1,
2527+
intra_node_data_parallel_size: 1,
2528+
api_key: None,
2529+
client: Client::new(),
2530+
retry_config: RetryConfig::default(),
2531+
circuit_breaker_config: CircuitBreakerConfig::default(),
2532+
_worker_loads: Arc::new(rx),
2533+
_load_monitor_handle: None,
2534+
};
2535+
2536+
let request: ChatCompletionRequest = serde_json::from_value(serde_json::json!({
2537+
"model": "rft-run-1",
2538+
"messages": [{"role": "user", "content": "hello"}]
2539+
}))
2540+
.unwrap();
2541+
2542+
for _ in 0..6 {
2543+
let response = router.route_chat(None, &request, None, None).await;
2544+
assert_eq!(response.status(), StatusCode::OK);
2545+
}
2546+
2547+
assert_eq!(
2548+
base_only_count.load(std::sync::atomic::Ordering::SeqCst),
2549+
0,
2550+
"body model was ignored and the request reached a worker without the LoRA"
2551+
);
2552+
assert_eq!(
2553+
lora_count.load(std::sync::atomic::Ordering::SeqCst),
2554+
6,
2555+
"all requests should be routed to the worker indexed for the requested LoRA"
2556+
);
2557+
}
2558+
2559+
#[tokio::test]
2560+
async fn test_router_new_indexes_all_discovered_models_immediately() {
2561+
let (url, _handle) =
2562+
start_model_listing_worker(vec!["base-model", "rft-run-1"]).await;
2563+
let config = crate::config::types::RouterConfig {
2564+
mode: crate::config::types::RoutingMode::Regular {
2565+
worker_urls: vec![url.clone()],
2566+
},
2567+
policy: crate::config::types::PolicyConfig::RoundRobin,
2568+
worker_startup_timeout_secs: 2,
2569+
worker_startup_check_interval_secs: 1,
2570+
..Default::default()
2571+
};
2572+
let ctx = Arc::new(
2573+
crate::server::AppContext::new(
2574+
config.clone(),
2575+
Client::new(),
2576+
config.max_concurrent_requests,
2577+
config.rate_limit_tokens_per_second,
2578+
config.api_key_validation_urls.clone(),
2579+
)
2580+
.unwrap(),
2581+
);
2582+
2583+
let router = Router::new(vec![url.clone()], &ctx).await.unwrap();
2584+
2585+
let base_workers = router.worker_registry.get_by_model_fast("base-model");
2586+
let lora_workers = router.worker_registry.get_by_model_fast("rft-run-1");
2587+
assert_eq!(base_workers.len(), 1);
2588+
assert_eq!(lora_workers.len(), 1);
2589+
2590+
let worker = router
2591+
.select_worker_for_model(Some("rft-run-1"), Some(r#"{"prompt":"x"}"#), None)
2592+
.expect("LoRA model should be routable immediately after registration");
2593+
assert_eq!(worker.url(), url);
2594+
}
2595+
22762596
#[test]
22772597
fn test_inline_header_conversion_matches_headers_to_request_headers() {
22782598
// Verify that the inline header conversion pattern used in pd_router and

0 commit comments

Comments
 (0)