@@ -9,7 +9,7 @@ use crate::otel_http::{self, ClientRequestOptions};
99use crate :: policies:: { LoadBalancingPolicy , PolicyRegistry } ;
1010use crate :: protocols:: spec:: {
1111 ChatCompletionRequest , CompletionRequest , EmbeddingRequest , GenerateRequest , GenerationRequest ,
12- RerankRequest , RerankResponse , RerankResult , ResponsesRequest ,
12+ RerankRequest , RerankResponse , RerankResult , ResponsesRequest , DEFAULT_MODEL_NAME ,
1313} ;
1414use crate :: routers:: header_utils;
1515use crate :: routers:: http:: dp_utils;
@@ -154,6 +154,10 @@ impl Router {
154154 )
155155 } ;
156156 ctx. worker_registry . register ( worker_arc. clone ( ) ) ;
157+ if !models. is_empty ( ) {
158+ ctx. worker_registry
159+ . sync_worker_models ( worker_arc. url ( ) , & models) ;
160+ }
157161
158162 // Notify PolicyRegistry about the new worker
159163 let model_id = worker_arc. model_id ( ) ;
@@ -538,6 +542,39 @@ impl Router {
538542 } )
539543 }
540544
545+ fn normalize_model_id ( model_id : Option < & str > ) -> Option < & str > {
546+ let model = model_id?. trim ( ) ;
547+ ( !model. is_empty ( ) ) . then_some ( model)
548+ }
549+
550+ fn body_model_for_route < ' a > ( route : & str , model_id : Option < & ' a str > ) -> Option < & ' a str > {
551+ let model = Self :: normalize_model_id ( model_id) ?;
552+ if route == "/v1/rerank" && model == DEFAULT_MODEL_NAME {
553+ return None ;
554+ }
555+ Some ( model)
556+ }
557+
558+ fn resolve_body_model_filter < ' a > (
559+ & self ,
560+ route : & str ,
561+ model_id : Option < & ' a str > ,
562+ run_id : Option < & str > ,
563+ ) -> Option < & ' a str > {
564+ let model = Self :: body_model_for_route ( route, model_id) ?;
565+
566+ if run_id. is_some ( ) || !self . worker_registry . get_by_model_fast ( model) . is_empty ( ) {
567+ return Some ( model) ;
568+ }
569+
570+ debug ! (
571+ model_id = %model,
572+ route,
573+ "body model is not indexed; routing without a model filter"
574+ ) ;
575+ None
576+ }
577+
541578 /// Select worker for a specific model considering circuit breaker state
542579 fn select_worker_for_model (
543580 & self ,
@@ -586,11 +623,24 @@ impl Router {
586623 let text = typed_req. extract_text_for_routing ( ) ;
587624 let run_id = run_id. map ( |s| s. to_string ( ) ) ;
588625
626+ // Fall back to the body's `model` field when the caller doesn't pass one, but
627+ // only use it as a routing filter when the registry has already indexed that
628+ // model. This keeps compatibility for generic upstream model validation while
629+ // still preventing known LoRA requests from being sent to workers that have not
630+ // loaded the adapter. Run-scoped requests keep the body model as a hard filter.
631+ let effective_model_id = Self :: normalize_model_id ( model_id) . or_else ( || {
632+ self . resolve_body_model_filter ( route, typed_req. get_model ( ) , run_id. as_deref ( ) )
633+ } ) ;
634+
589635 let response = RetryExecutor :: execute_response_with_retry (
590636 & self . retry_config ,
591637 // operation per attempt
592638 |_: u32 | async {
593- let worker = match self . select_worker_for_model ( model_id, Some ( & text) , headers) {
639+ let worker = match self . select_worker_for_model (
640+ effective_model_id,
641+ Some ( & text) ,
642+ headers,
643+ ) {
594644 Some ( w) => w,
595645 None => {
596646 RouterMetrics :: record_request_error ( route, "no_available_workers" ) ;
@@ -604,7 +654,7 @@ impl Router {
604654
605655 // Optional load tracking for cache-aware policy
606656 // Get the policy for this model to check if it's cache-aware
607- let policy = match model_id {
657+ let policy = match effective_model_id {
608658 Some ( model) => self . policy_registry . get_policy_or_default ( model) ,
609659 None => self . policy_registry . get_default_policy ( ) ,
610660 } ;
@@ -1209,6 +1259,10 @@ impl Router {
12091259
12101260 let worker_arc: Arc < dyn Worker > = Arc :: new ( new_worker) ;
12111261 self . worker_registry . register ( worker_arc. clone ( ) ) ;
1262+ if !models. is_empty ( ) {
1263+ self . worker_registry
1264+ . sync_worker_models ( worker_arc. url ( ) , & models) ;
1265+ }
12121266
12131267 // Notify PolicyRegistry about the new worker
12141268 let model_id = worker_arc. model_id ( ) ;
@@ -1246,6 +1300,10 @@ impl Router {
12461300
12471301 let worker_arc = Arc :: new ( new_worker) ;
12481302 self . worker_registry . register ( worker_arc. clone ( ) ) ;
1303+ if !models. is_empty ( ) {
1304+ self . worker_registry
1305+ . sync_worker_models ( worker_arc. url ( ) , & models) ;
1306+ }
12491307
12501308 // Notify PolicyRegistry about the new worker
12511309 let model_id = worker_arc. model_id ( ) ;
@@ -2273,6 +2331,268 @@ mod tests {
22732331 ) ;
22742332 }
22752333
2334+ /// Verify select_worker_for_model only picks workers whose model_index
2335+ /// includes the requested model. This is the safety property the
2336+ /// `effective_model_id` fallback in route_typed_request depends on: a
2337+ /// request body that asks for a LoRA adapter must never be dispatched
2338+ /// to a pod that hasn't loaded it.
2339+ #[ test]
2340+ fn test_select_worker_filters_to_pods_with_loaded_lora ( ) {
2341+ let router = create_test_regular_router ( ) ;
2342+
2343+ // Both workers initially serve the base model only (this matches the
2344+ // state of a freshly-scaled-up vLLM pod before the orchestrator has
2345+ // pushed the LoRA adapter).
2346+ router
2347+ . worker_registry
2348+ . sync_worker_models ( "http://worker1:8080" , & [ "base-model" . to_string ( ) ] ) ;
2349+ router
2350+ . worker_registry
2351+ . sync_worker_models ( "http://worker2:8080" , & [ "base-model" . to_string ( ) ] ) ;
2352+
2353+ // Simulate orchestrator pushing the LoRA to worker1 only.
2354+ router. worker_registry . sync_worker_models (
2355+ "http://worker1:8080" ,
2356+ & [ "base-model" . to_string ( ) , "rft-run-1" . to_string ( ) ] ,
2357+ ) ;
2358+
2359+ // A request for the LoRA must land on worker1.
2360+ for _ in 0 ..20 {
2361+ let worker = router
2362+ . select_worker_for_model ( Some ( "rft-run-1" ) , Some ( r#"{"prompt":"x"}"# ) , None )
2363+ . expect ( "a worker is available" ) ;
2364+ assert_eq ! (
2365+ worker. url( ) ,
2366+ "http://worker1:8080" ,
2367+ "LoRA request leaked to a worker without the adapter loaded"
2368+ ) ;
2369+ }
2370+
2371+ // The base model should still see both workers.
2372+ let base_workers = router. worker_registry . get_by_model_fast ( "base-model" ) ;
2373+ assert_eq ! ( base_workers. len( ) , 2 ) ;
2374+ }
2375+
2376+ #[ test]
2377+ fn test_body_model_filter_only_uses_indexed_models ( ) {
2378+ let router = create_test_regular_router ( ) ;
2379+
2380+ router. worker_registry . sync_worker_models (
2381+ "http://worker1:8080" ,
2382+ & [ "base-model" . to_string ( ) , "rft-run-1" . to_string ( ) ] ,
2383+ ) ;
2384+
2385+ assert_eq ! (
2386+ router. resolve_body_model_filter(
2387+ "/v1/chat/completions" ,
2388+ Some ( "rft-run-1" ) ,
2389+ None ,
2390+ ) ,
2391+ Some ( "rft-run-1" )
2392+ ) ;
2393+ assert_eq ! (
2394+ router. resolve_body_model_filter(
2395+ "/v1/chat/completions" ,
2396+ Some ( "not-indexed" ) ,
2397+ None ,
2398+ ) ,
2399+ None
2400+ ) ;
2401+ assert_eq ! (
2402+ router. resolve_body_model_filter(
2403+ "/v1/chat/completions" ,
2404+ Some ( "not-indexed" ) ,
2405+ Some ( "run-123" ) ,
2406+ ) ,
2407+ Some ( "not-indexed" )
2408+ ) ;
2409+ assert_eq ! (
2410+ router. resolve_body_model_filter( "/v1/rerank" , Some ( DEFAULT_MODEL_NAME ) , None ) ,
2411+ None
2412+ ) ;
2413+ }
2414+
2415+ async fn start_counting_chat_worker (
2416+ request_count : Arc < std:: sync:: atomic:: AtomicUsize > ,
2417+ ) -> ( String , tokio:: task:: JoinHandle < ( ) > ) {
2418+ use axum:: { routing:: post, Json , Router as AxumRouter } ;
2419+ use tokio:: net:: TcpListener ;
2420+
2421+ let listener = TcpListener :: bind ( "127.0.0.1:0" ) . await . unwrap ( ) ;
2422+ let addr = listener. local_addr ( ) . unwrap ( ) ;
2423+ let app = AxumRouter :: new ( ) . route (
2424+ "/v1/chat/completions" ,
2425+ post ( move |Json ( body) : Json < serde_json:: Value > | {
2426+ let request_count = request_count. clone ( ) ;
2427+ async move {
2428+ request_count. fetch_add ( 1 , std:: sync:: atomic:: Ordering :: SeqCst ) ;
2429+ let model = body
2430+ . get ( "model" )
2431+ . and_then ( |v| v. as_str ( ) )
2432+ . unwrap_or ( "unknown" )
2433+ . to_string ( ) ;
2434+ Json ( serde_json:: json!( {
2435+ "id" : "chatcmpl-test" ,
2436+ "object" : "chat.completion" ,
2437+ "created" : 0 ,
2438+ "model" : model,
2439+ "choices" : [ {
2440+ "index" : 0 ,
2441+ "message" : {
2442+ "role" : "assistant" ,
2443+ "content" : "ok"
2444+ } ,
2445+ "finish_reason" : "stop"
2446+ } ] ,
2447+ "usage" : {
2448+ "prompt_tokens" : 1 ,
2449+ "completion_tokens" : 1 ,
2450+ "total_tokens" : 2
2451+ }
2452+ } ) )
2453+ }
2454+ } ) ,
2455+ ) ;
2456+ let handle = tokio:: spawn ( async move {
2457+ axum:: serve ( listener, app) . await . unwrap ( ) ;
2458+ } ) ;
2459+ tokio:: time:: sleep ( std:: time:: Duration :: from_millis ( 10 ) ) . await ;
2460+ ( format ! ( "http://{}" , addr) , handle)
2461+ }
2462+
2463+ async fn start_model_listing_worker (
2464+ models : Vec < & ' static str > ,
2465+ ) -> ( String , tokio:: task:: JoinHandle < ( ) > ) {
2466+ use axum:: { routing:: get, Json , Router as AxumRouter } ;
2467+ use tokio:: net:: TcpListener ;
2468+
2469+ let models: Vec < String > = models. into_iter ( ) . map ( String :: from) . collect ( ) ;
2470+ let listener = TcpListener :: bind ( "127.0.0.1:0" ) . await . unwrap ( ) ;
2471+ let addr = listener. local_addr ( ) . unwrap ( ) ;
2472+ let app = AxumRouter :: new ( )
2473+ . route ( "/health" , get ( || async { StatusCode :: OK } ) )
2474+ . route (
2475+ "/v1/models" ,
2476+ get ( move || {
2477+ let models = models. clone ( ) ;
2478+ async move {
2479+ let data: Vec < serde_json:: Value > = models
2480+ . iter ( )
2481+ . map ( |id| serde_json:: json!( { "id" : id, "object" : "model" } ) )
2482+ . collect ( ) ;
2483+ Json ( serde_json:: json!( { "object" : "list" , "data" : data} ) )
2484+ }
2485+ } ) ,
2486+ ) ;
2487+ let handle = tokio:: spawn ( async move {
2488+ axum:: serve ( listener, app) . await . unwrap ( ) ;
2489+ } ) ;
2490+ tokio:: time:: sleep ( std:: time:: Duration :: from_millis ( 10 ) ) . await ;
2491+ ( format ! ( "http://{}" , addr) , handle)
2492+ }
2493+
2494+ #[ tokio:: test]
2495+ async fn test_route_chat_with_body_model_filters_to_loaded_lora_worker ( ) {
2496+ let base_only_count = Arc :: new ( std:: sync:: atomic:: AtomicUsize :: new ( 0 ) ) ;
2497+ let lora_count = Arc :: new ( std:: sync:: atomic:: AtomicUsize :: new ( 0 ) ) ;
2498+ let ( base_only_url, _base_only_handle) =
2499+ start_counting_chat_worker ( base_only_count. clone ( ) ) . await ;
2500+ let ( lora_url, _lora_handle) = start_counting_chat_worker ( lora_count. clone ( ) ) . await ;
2501+
2502+ let worker_registry = Arc :: new ( WorkerRegistry :: new ( ) ) ;
2503+ worker_registry. register ( Arc :: new ( BasicWorker :: new (
2504+ base_only_url. clone ( ) ,
2505+ WorkerType :: Regular ,
2506+ ) ) ) ;
2507+ worker_registry. register ( Arc :: new ( BasicWorker :: new (
2508+ lora_url. clone ( ) ,
2509+ WorkerType :: Regular ,
2510+ ) ) ) ;
2511+
2512+ worker_registry. sync_worker_models ( & base_only_url, & [ "base-model" . to_string ( ) ] ) ;
2513+ worker_registry. sync_worker_models (
2514+ & lora_url,
2515+ & [ "base-model" . to_string ( ) , "rft-run-1" . to_string ( ) ] ,
2516+ ) ;
2517+
2518+ let policy_registry = Arc :: new ( PolicyRegistry :: new (
2519+ crate :: config:: types:: PolicyConfig :: RoundRobin ,
2520+ ) ) ;
2521+ let ( _, rx) = tokio:: sync:: watch:: channel ( HashMap :: new ( ) ) ;
2522+ let router = Router {
2523+ worker_registry,
2524+ policy_registry,
2525+ worker_startup_timeout_secs : 5 ,
2526+ worker_startup_check_interval_secs : 1 ,
2527+ intra_node_data_parallel_size : 1 ,
2528+ api_key : None ,
2529+ client : Client :: new ( ) ,
2530+ retry_config : RetryConfig :: default ( ) ,
2531+ circuit_breaker_config : CircuitBreakerConfig :: default ( ) ,
2532+ _worker_loads : Arc :: new ( rx) ,
2533+ _load_monitor_handle : None ,
2534+ } ;
2535+
2536+ let request: ChatCompletionRequest = serde_json:: from_value ( serde_json:: json!( {
2537+ "model" : "rft-run-1" ,
2538+ "messages" : [ { "role" : "user" , "content" : "hello" } ]
2539+ } ) )
2540+ . unwrap ( ) ;
2541+
2542+ for _ in 0 ..6 {
2543+ let response = router. route_chat ( None , & request, None , None ) . await ;
2544+ assert_eq ! ( response. status( ) , StatusCode :: OK ) ;
2545+ }
2546+
2547+ assert_eq ! (
2548+ base_only_count. load( std:: sync:: atomic:: Ordering :: SeqCst ) ,
2549+ 0 ,
2550+ "body model was ignored and the request reached a worker without the LoRA"
2551+ ) ;
2552+ assert_eq ! (
2553+ lora_count. load( std:: sync:: atomic:: Ordering :: SeqCst ) ,
2554+ 6 ,
2555+ "all requests should be routed to the worker indexed for the requested LoRA"
2556+ ) ;
2557+ }
2558+
2559+ #[ tokio:: test]
2560+ async fn test_router_new_indexes_all_discovered_models_immediately ( ) {
2561+ let ( url, _handle) =
2562+ start_model_listing_worker ( vec ! [ "base-model" , "rft-run-1" ] ) . await ;
2563+ let config = crate :: config:: types:: RouterConfig {
2564+ mode : crate :: config:: types:: RoutingMode :: Regular {
2565+ worker_urls : vec ! [ url. clone( ) ] ,
2566+ } ,
2567+ policy : crate :: config:: types:: PolicyConfig :: RoundRobin ,
2568+ worker_startup_timeout_secs : 2 ,
2569+ worker_startup_check_interval_secs : 1 ,
2570+ ..Default :: default ( )
2571+ } ;
2572+ let ctx = Arc :: new (
2573+ crate :: server:: AppContext :: new (
2574+ config. clone ( ) ,
2575+ Client :: new ( ) ,
2576+ config. max_concurrent_requests ,
2577+ config. rate_limit_tokens_per_second ,
2578+ config. api_key_validation_urls . clone ( ) ,
2579+ )
2580+ . unwrap ( ) ,
2581+ ) ;
2582+
2583+ let router = Router :: new ( vec ! [ url. clone( ) ] , & ctx) . await . unwrap ( ) ;
2584+
2585+ let base_workers = router. worker_registry . get_by_model_fast ( "base-model" ) ;
2586+ let lora_workers = router. worker_registry . get_by_model_fast ( "rft-run-1" ) ;
2587+ assert_eq ! ( base_workers. len( ) , 1 ) ;
2588+ assert_eq ! ( lora_workers. len( ) , 1 ) ;
2589+
2590+ let worker = router
2591+ . select_worker_for_model ( Some ( "rft-run-1" ) , Some ( r#"{"prompt":"x"}"# ) , None )
2592+ . expect ( "LoRA model should be routable immediately after registration" ) ;
2593+ assert_eq ! ( worker. url( ) , url) ;
2594+ }
2595+
22762596 #[ test]
22772597 fn test_inline_header_conversion_matches_headers_to_request_headers ( ) {
22782598 // Verify that the inline header conversion pattern used in pd_router and
0 commit comments