Skip to content

Commit 51638cd

Browse files
authored
fix: TITO route silently rewritten to /v1/chat/completions (#17)
* fix: remove explicit TITO route so transparent proxy preserves path * fix: preserve /v1/chat/completions/tokens path when routing to backend * test: verify TITO route forwards to correct backend path * fix: override route_chat_tokens in PD routers to forward correct path * refactor: extract route_chat_with_path helper in PDRouter to avoid duplication * fix: use unique port for TITO test to avoid conflicts * refactor: simplify mock TITO handler to delegate to chat_completions_handler * fix: avoid double capture in mock TITO handler * fix: restore comments stripped during refactor
1 parent 4051968 commit 51638cd

8 files changed

Lines changed: 304 additions & 118 deletions

File tree

src/routers/http/pd_router.rs

Lines changed: 52 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1847,6 +1847,46 @@ impl PDRouter {
18471847
);
18481848
Ok(bytes::Bytes::from(merged_str))
18491849
}
1850+
1851+
/// Internal helper for routing chat requests with a configurable backend path.
1852+
async fn route_chat_with_path(
1853+
&self,
1854+
headers: Option<&HeaderMap>,
1855+
body: &ChatCompletionRequest,
1856+
model_id: Option<&str>,
1857+
run_id: Option<&str>,
1858+
route: &'static str,
1859+
) -> Response {
1860+
let is_stream = body.stream;
1861+
let return_logprob = body.logprobs;
1862+
1863+
let request_text = if self.policies_need_request_text() {
1864+
body.messages.first().and_then(|msg| match msg {
1865+
ChatMessage::User { content, .. } => match content {
1866+
UserMessageContent::Text(text) => Some(text.clone()),
1867+
UserMessageContent::Parts(_) => None,
1868+
},
1869+
ChatMessage::System { content, .. } => Some(content.clone()),
1870+
_ => None,
1871+
})
1872+
} else {
1873+
None
1874+
};
1875+
1876+
let batch_size = Self::get_chat_batch_size(body);
1877+
1878+
let context = PDRequestContext {
1879+
route,
1880+
batch_size,
1881+
is_stream,
1882+
return_logprob,
1883+
request_text,
1884+
model_id,
1885+
run_id,
1886+
};
1887+
1888+
self.execute_dual_dispatch(headers, body, context).await
1889+
}
18501890
}
18511891

18521892
// Helper functions
@@ -2114,40 +2154,19 @@ impl RouterTrait for PDRouter {
21142154
model_id: Option<&str>,
21152155
run_id: Option<&str>,
21162156
) -> Response {
2117-
// Extract parameters
2118-
let is_stream = body.stream;
2119-
let return_logprob = body.logprobs;
2120-
2121-
// Extract text for cache-aware routing
2122-
let request_text = if self.policies_need_request_text() {
2123-
body.messages.first().and_then(|msg| match msg {
2124-
ChatMessage::User { content, .. } => match content {
2125-
UserMessageContent::Text(text) => Some(text.clone()),
2126-
UserMessageContent::Parts(_) => None,
2127-
},
2128-
ChatMessage::System { content, .. } => Some(content.clone()),
2129-
_ => None,
2130-
})
2131-
} else {
2132-
None
2133-
};
2134-
2135-
// Calculate batch size
2136-
let batch_size = Self::get_chat_batch_size(body);
2137-
2138-
// Create context
2139-
let context = PDRequestContext {
2140-
route: "/v1/chat/completions",
2141-
batch_size,
2142-
is_stream,
2143-
return_logprob,
2144-
request_text,
2145-
model_id,
2146-
run_id,
2147-
};
2157+
self.route_chat_with_path(headers, body, model_id, run_id, "/v1/chat/completions")
2158+
.await
2159+
}
21482160

2149-
// Execute with retry and bootstrap injection
2150-
self.execute_dual_dispatch(headers, body, context).await
2161+
async fn route_chat_tokens(
2162+
&self,
2163+
headers: Option<&HeaderMap>,
2164+
body: &ChatCompletionRequest,
2165+
model_id: Option<&str>,
2166+
run_id: Option<&str>,
2167+
) -> Response {
2168+
self.route_chat_with_path(headers, body, model_id, run_id, "/v1/chat/completions/tokens")
2169+
.await
21512170
}
21522171

21532172
async fn route_completion(

src/routers/http/router.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1535,6 +1535,17 @@ impl RouterTrait for Router {
15351535
.await
15361536
}
15371537

1538+
async fn route_chat_tokens(
1539+
&self,
1540+
headers: Option<&HeaderMap>,
1541+
body: &ChatCompletionRequest,
1542+
model_id: Option<&str>,
1543+
run_id: Option<&str>,
1544+
) -> Response {
1545+
self.route_typed_request(headers, body, "/v1/chat/completions/tokens", model_id, run_id)
1546+
.await
1547+
}
1548+
15381549
async fn route_completion(
15391550
&self,
15401551
headers: Option<&HeaderMap>,

src/routers/http/vllm_pd_router.rs

Lines changed: 86 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,108 +1270,48 @@ impl VllmPDRouter {
12701270
pub fn worker_registry(&self) -> &crate::core::WorkerRegistry {
12711271
&self.pd_router.worker_registry
12721272
}
1273-
}
1274-
1275-
// Delegate most RouterTrait methods to the underlying PDRouter,
1276-
// but override specific ones for vLLM behavior
1277-
#[async_trait]
1278-
impl RouterTrait for VllmPDRouter {
1279-
fn as_any(&self) -> &dyn std::any::Any {
1280-
self
1281-
}
1282-
1283-
async fn health(&self, req: Request<Body>) -> Response {
1284-
self.pd_router.health(req).await
1285-
}
1286-
1287-
async fn health_generate(&self, req: Request<Body>) -> Response {
1288-
self.pd_router.health_generate(req).await
1289-
}
1290-
1291-
async fn get_server_info(&self, req: Request<Body>) -> Response {
1292-
self.pd_router.get_server_info(req).await
1293-
}
1294-
1295-
async fn get_models(&self, req: Request<Body>) -> Response {
1296-
self.pd_router.get_models(req).await
1297-
}
1298-
1299-
async fn get_model_info(&self, req: Request<Body>) -> Response {
1300-
self.pd_router.get_model_info(req).await
1301-
}
13021273

1303-
async fn route_generate(
1304-
&self,
1305-
headers: Option<&HeaderMap>,
1306-
body: &crate::protocols::spec::GenerateRequest,
1307-
model_id: Option<&str>,
1308-
run_id: Option<&str>,
1309-
) -> Response {
1310-
self.pd_router
1311-
.route_generate(headers, body, model_id, run_id)
1312-
.await
1313-
}
1314-
1315-
// Override OpenAI-compatible routes for vLLM two-stage processing
1316-
async fn route_chat(
1274+
/// Internal helper for routing chat requests with a configurable backend path.
1275+
async fn route_chat_with_path(
13171276
&self,
13181277
headers: Option<&HeaderMap>,
13191278
body: &crate::protocols::spec::ChatCompletionRequest,
1320-
_model_id: Option<&str>,
13211279
run_id: Option<&str>,
1280+
route: &str,
13221281
) -> Response {
13231282
info!(
13241283
"vLLM route_chat called, use_discovery={}",
13251284
self.use_discovery
13261285
);
13271286

1287+
let request_json = match serde_json::to_value(body) {
1288+
Ok(json) => {
1289+
debug!(
1290+
"Serialized chat request: {}",
1291+
serde_json::to_string_pretty(&json).unwrap_or_default()
1292+
);
1293+
json
1294+
}
1295+
Err(e) => {
1296+
return (
1297+
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
1298+
format!("Serialization error: {}", e),
1299+
)
1300+
.into_response()
1301+
}
1302+
};
1303+
13281304
if self.use_discovery {
13291305
// Discovery mode - use vLLM-specific two-stage processing
13301306
info!("Using service discovery mode, processing vLLM two-stage request");
13311307

1332-
// Convert to generic request and use vLLM processing
1333-
let request_json = match serde_json::to_value(body) {
1334-
Ok(json) => {
1335-
debug!(
1336-
"Serialized chat request: {}",
1337-
serde_json::to_string_pretty(&json).unwrap_or_default()
1338-
);
1339-
json
1340-
}
1341-
Err(e) => {
1342-
return (
1343-
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
1344-
format!("Serialization error: {}", e),
1345-
)
1346-
.into_response()
1347-
}
1348-
};
1349-
13501308
// Process vLLM two-stage request with service discovery
1351-
self.process_vllm_request(request_json, "/v1/chat/completions", headers, run_id)
1309+
self.process_vllm_request(request_json, route, headers, run_id)
13521310
.await
13531311
} else {
13541312
// Direct URL mode - implement routing logic here (not delegating to PDRouter)
13551313
info!("Using direct URL mode with VllmPDRouter's own routing logic");
13561314

1357-
// Convert request to JSON
1358-
let request_json = match serde_json::to_value(body) {
1359-
Ok(json) => {
1360-
debug!(
1361-
"Serialized chat request: {}",
1362-
serde_json::to_string_pretty(&json).unwrap_or_default()
1363-
);
1364-
json
1365-
}
1366-
Err(e) => {
1367-
return (
1368-
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
1369-
format!("Serialization error: {}", e),
1370-
)
1371-
.into_response()
1372-
}
1373-
};
1374-
13751315
// Get prefill and decode workers from worker_registry
13761316
let prefill_workers = self.pd_router.worker_registry.get_prefill_workers();
13771317
let decode_workers = self.pd_router.worker_registry.get_decode_workers();
@@ -1463,7 +1403,7 @@ impl RouterTrait for VllmPDRouter {
14631403
request_json,
14641404
prefill_worker.clone(),
14651405
decode_worker.clone(),
1466-
"/v1/chat/completions",
1406+
route,
14671407
headers,
14681408
run_id,
14691409
)
@@ -1485,6 +1425,70 @@ impl RouterTrait for VllmPDRouter {
14851425
resp
14861426
}
14871427
}
1428+
}
1429+
1430+
// Delegate most RouterTrait methods to the underlying PDRouter,
1431+
// but override specific ones for vLLM behavior
1432+
#[async_trait]
1433+
impl RouterTrait for VllmPDRouter {
1434+
fn as_any(&self) -> &dyn std::any::Any {
1435+
self
1436+
}
1437+
1438+
async fn health(&self, req: Request<Body>) -> Response {
1439+
self.pd_router.health(req).await
1440+
}
1441+
1442+
async fn health_generate(&self, req: Request<Body>) -> Response {
1443+
self.pd_router.health_generate(req).await
1444+
}
1445+
1446+
async fn get_server_info(&self, req: Request<Body>) -> Response {
1447+
self.pd_router.get_server_info(req).await
1448+
}
1449+
1450+
async fn get_models(&self, req: Request<Body>) -> Response {
1451+
self.pd_router.get_models(req).await
1452+
}
1453+
1454+
async fn get_model_info(&self, req: Request<Body>) -> Response {
1455+
self.pd_router.get_model_info(req).await
1456+
}
1457+
1458+
async fn route_generate(
1459+
&self,
1460+
headers: Option<&HeaderMap>,
1461+
body: &crate::protocols::spec::GenerateRequest,
1462+
model_id: Option<&str>,
1463+
run_id: Option<&str>,
1464+
) -> Response {
1465+
self.pd_router
1466+
.route_generate(headers, body, model_id, run_id)
1467+
.await
1468+
}
1469+
1470+
// Override OpenAI-compatible routes for vLLM two-stage processing
1471+
async fn route_chat(
1472+
&self,
1473+
headers: Option<&HeaderMap>,
1474+
body: &crate::protocols::spec::ChatCompletionRequest,
1475+
_model_id: Option<&str>,
1476+
run_id: Option<&str>,
1477+
) -> Response {
1478+
self.route_chat_with_path(headers, body, run_id, "/v1/chat/completions")
1479+
.await
1480+
}
1481+
1482+
async fn route_chat_tokens(
1483+
&self,
1484+
headers: Option<&HeaderMap>,
1485+
body: &crate::protocols::spec::ChatCompletionRequest,
1486+
_model_id: Option<&str>,
1487+
run_id: Option<&str>,
1488+
) -> Response {
1489+
self.route_chat_with_path(headers, body, run_id, "/v1/chat/completions/tokens")
1490+
.await
1491+
}
14881492

14891493
async fn route_completion(
14901494
&self,

src/routers/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,18 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
8282
run_id: Option<&str>,
8383
) -> Response;
8484

85+
/// Route a chat completion tokens (TITO) request.
86+
/// Defaults to route_chat; override to forward to /v1/chat/completions/tokens.
87+
async fn route_chat_tokens(
88+
&self,
89+
headers: Option<&HeaderMap>,
90+
body: &ChatCompletionRequest,
91+
model_id: Option<&str>,
92+
run_id: Option<&str>,
93+
) -> Response {
94+
self.route_chat(headers, body, model_id, run_id).await
95+
}
96+
8597
/// Route a completion request
8698
async fn route_completion(
8799
&self,

src/routers/router_manager.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,29 @@ impl RouterTrait for RouterManager {
612612
}
613613
}
614614

615+
async fn route_chat_tokens(
616+
&self,
617+
headers: Option<&HeaderMap>,
618+
body: &ChatCompletionRequest,
619+
_model_id: Option<&str>,
620+
_run_id: Option<&str>,
621+
) -> Response {
622+
let model_id = body.model.as_deref();
623+
let router = self.select_router_for_request(headers, model_id);
624+
625+
if let Some(router) = router {
626+
router
627+
.route_chat_tokens(headers, body, model_id, _run_id)
628+
.await
629+
} else {
630+
let msg = match model_id {
631+
Some(m) => format!("Model '{}' not found or no router available", m),
632+
None => "No routers registered to handle this request".to_string(),
633+
};
634+
(StatusCode::NOT_FOUND, msg).into_response()
635+
}
636+
}
637+
615638
/// Route a completion request
616639
async fn route_completion(
617640
&self,

0 commit comments

Comments
 (0)