Skip to content

Commit 0ab427d

Browse files
Gruner-ateroAmit Gruner
andauthored
[SMG] Add /v1/models fallback for model name discovery (#25293)
Co-authored-by: Amit Gruner <agruner@crusoe.ai>
1 parent ba2ffcf commit 0ab427d

4 files changed

Lines changed: 233 additions & 0 deletions

File tree

sgl-model-gateway/src/core/steps/worker/local/discover_metadata.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,45 @@ pub async fn get_model_info(url: &str, api_key: Option<&str>) -> Result<ModelInf
169169
.map_err(|e| format!("Failed to parse response from {}: {}", model_info_url, e))
170170
}
171171

172+
/// Get model name from /v1/models endpoint (OpenAI-compatible fallback).
173+
async fn get_model_name_from_v1_models(url: &str, api_key: Option<&str>) -> Result<String, String> {
174+
let base_url = url.trim_end_matches('/');
175+
let models_url = format!("{}/v1/models", base_url);
176+
177+
let mut req = HTTP_CLIENT.get(&models_url);
178+
if let Some(key) = api_key {
179+
req = req.bearer_auth(key);
180+
}
181+
182+
let response = req
183+
.send()
184+
.await
185+
.map_err(|e| format!("Failed to connect to {}: {}", models_url, e))?;
186+
187+
if !response.status().is_success() {
188+
return Err(format!(
189+
"Server returned status {} from {}",
190+
response.status(),
191+
models_url
192+
));
193+
}
194+
195+
let json: Value = response
196+
.json()
197+
.await
198+
.map_err(|e| format!("Failed to parse response from {}: {}", models_url, e))?;
199+
200+
json["data"]
201+
.as_array()
202+
.and_then(|arr| {
203+
arr.iter()
204+
.find(|entry| entry["object"].as_str() == Some("model"))
205+
})
206+
.and_then(|entry| entry["id"].as_str())
207+
.map(|s| s.to_string())
208+
.ok_or_else(|| format!("No model found in response from {}", models_url))
209+
}
210+
172211
/// Fetch gRPC metadata (returns labels and detected runtime type).
173212
async fn fetch_grpc_metadata(
174213
url: &str,
@@ -283,6 +322,15 @@ impl StepExecutor<LocalWorkerWorkflowData> for DiscoverMetadataStep {
283322
}
284323
}
285324

325+
// If no model name discovered yet, try /v1/models as fallback
326+
if !labels.contains_key("model_path") && !labels.contains_key("served_model_name") {
327+
if let Ok(model_name) =
328+
get_model_name_from_v1_models(&config.url, config.api_key.as_deref()).await
329+
{
330+
labels.insert("served_model_name".to_string(), model_name);
331+
}
332+
}
333+
286334
Ok((labels, None))
287335
}
288336
ConnectionMode::Grpc { .. } => {

sgl-model-gateway/tests/common/mock_worker.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1781,3 +1781,93 @@ impl Default for MockWorkerConfig {
17811781
}
17821782
}
17831783
}
1784+
1785+
/// A minimal OpenAI-compatible mock worker that does not implement /server_info or /model_info.
1786+
/// Used to test fallback model name discovery via /v1/models.
1787+
pub struct OpenAiOnlyMockWorker {
1788+
port: u16,
1789+
model_name: String,
1790+
shutdown_handle: Option<tokio::task::JoinHandle<()>>,
1791+
shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
1792+
}
1793+
1794+
impl OpenAiOnlyMockWorker {
1795+
pub fn new(model_name: impl Into<String>) -> Self {
1796+
Self {
1797+
port: 0,
1798+
model_name: model_name.into(),
1799+
shutdown_handle: None,
1800+
shutdown_tx: None,
1801+
}
1802+
}
1803+
1804+
pub async fn start(&mut self) -> Result<String, Box<dyn std::error::Error>> {
1805+
let listener = std::net::TcpListener::bind("127.0.0.1:0")?;
1806+
self.port = listener.local_addr()?.port();
1807+
drop(listener);
1808+
1809+
let model_name = self.model_name.clone();
1810+
let port = self.port;
1811+
1812+
let app = Router::new()
1813+
.route("/health", get(|| async { Json(json!({ "status": "healthy" })) }))
1814+
.route("/health_generate", get(|| async { Json(json!({ "status": "ok" })) }))
1815+
.route(
1816+
"/v1/models",
1817+
get(move || {
1818+
let model_name = model_name.clone();
1819+
async move {
1820+
let ts = SystemTime::now()
1821+
.duration_since(UNIX_EPOCH)
1822+
.unwrap()
1823+
.as_secs();
1824+
Json(json!({
1825+
"object": "list",
1826+
"data": [{ "id": model_name, "object": "model", "created": ts, "owned_by": "owner" }]
1827+
}))
1828+
}
1829+
}),
1830+
);
1831+
1832+
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
1833+
self.shutdown_tx = Some(shutdown_tx);
1834+
1835+
let handle = tokio::spawn(async move {
1836+
let listener = match tokio::net::TcpListener::bind(("127.0.0.1", port)).await {
1837+
Ok(l) => l,
1838+
Err(e) => {
1839+
eprintln!("Failed to bind to port {}: {}", port, e);
1840+
return;
1841+
}
1842+
};
1843+
let server = axum::serve(listener, app).with_graceful_shutdown(async move {
1844+
let _ = shutdown_rx.await;
1845+
});
1846+
if let Err(e) = server.await {
1847+
eprintln!("Server error: {}", e);
1848+
}
1849+
});
1850+
1851+
self.shutdown_handle = Some(handle);
1852+
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
1853+
1854+
Ok(format!("http://127.0.0.1:{}", self.port))
1855+
}
1856+
1857+
pub async fn stop(&mut self) {
1858+
if let Some(tx) = self.shutdown_tx.take() {
1859+
let _ = tx.send(());
1860+
}
1861+
if let Some(h) = self.shutdown_handle.take() {
1862+
let _ = tokio::time::timeout(tokio::time::Duration::from_secs(5), h).await;
1863+
}
1864+
}
1865+
}
1866+
1867+
impl Drop for OpenAiOnlyMockWorker {
1868+
fn drop(&mut self) {
1869+
if let Some(tx) = self.shutdown_tx.take() {
1870+
let _ = tx.send(());
1871+
}
1872+
}
1873+
}

sgl-model-gateway/tests/routing/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ pub mod power_of_two_test;
1111
pub mod service_discovery_test;
1212
pub mod test_openai_routing;
1313
pub mod test_pd_routing;
14+
pub mod worker_discovery_test;
1415
pub mod worker_management_test;
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
//! Worker metadata discovery integration tests.
2+
3+
use smg::{config::RouterConfig, core::Job};
4+
5+
use crate::common::{
6+
create_test_context,
7+
mock_worker::{HealthStatus, MockWorkerConfig, OpenAiOnlyMockWorker, WorkerType},
8+
AppTestContext,
9+
};
10+
11+
#[cfg(test)]
12+
mod worker_discovery_tests {
13+
use super::*;
14+
15+
/// Normal path: model name is discovered from /server_info.
16+
#[tokio::test]
17+
async fn test_model_name_discovered_via_server_info() {
18+
let ctx = AppTestContext::new(vec![MockWorkerConfig {
19+
port: 0,
20+
worker_type: WorkerType::Regular,
21+
health_status: HealthStatus::Healthy,
22+
response_delay_ms: 0,
23+
fail_rate: 0.0,
24+
}])
25+
.await;
26+
27+
let discovered_models = ctx.app_context.worker_registry.get_models();
28+
assert!(
29+
discovered_models.contains(&"mock-model-path".to_string()),
30+
"Expected 'mock-model-path' discovered via /server_info, got: {:?}",
31+
discovered_models
32+
);
33+
34+
ctx.shutdown().await;
35+
}
36+
37+
/// Fallback path: when /server_info is unavailable, model name is discovered via /v1/models.
38+
#[tokio::test]
39+
async fn test_model_name_discovered_via_v1_models_fallback() {
40+
let mut worker = OpenAiOnlyMockWorker::new("my-model");
41+
let url = worker.start().await.unwrap();
42+
43+
let config = RouterConfig::builder()
44+
.regular_mode(vec![url.clone()])
45+
.random_policy()
46+
.host("127.0.0.1")
47+
.port(0)
48+
.max_payload_size(256 * 1024 * 1024)
49+
.request_timeout_secs(600)
50+
.worker_startup_timeout_secs(5)
51+
.worker_startup_check_interval_secs(1)
52+
.max_concurrent_requests(64)
53+
.queue_timeout_secs(60)
54+
.build_unchecked();
55+
56+
let app_context = create_test_context(config.clone()).await;
57+
58+
let job_queue = app_context
59+
.worker_job_queue
60+
.get()
61+
.expect("JobQueue should be initialized");
62+
job_queue
63+
.submit(Job::InitializeWorkersFromConfig {
64+
router_config: Box::new(config),
65+
})
66+
.await
67+
.expect("Failed to submit worker initialization job");
68+
69+
let start = tokio::time::Instant::now();
70+
loop {
71+
if app_context
72+
.worker_registry
73+
.get_all()
74+
.iter()
75+
.any(|w| w.is_healthy())
76+
{
77+
break;
78+
}
79+
if start.elapsed().as_secs() > 10 {
80+
panic!("Timeout waiting for worker to become healthy");
81+
}
82+
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
83+
}
84+
85+
let discovered_models = app_context.worker_registry.get_models();
86+
assert!(
87+
discovered_models.contains(&"my-model".to_string()),
88+
"Expected 'my-model' discovered via /v1/models fallback, got: {:?}",
89+
discovered_models
90+
);
91+
92+
worker.stop().await;
93+
}
94+
}

0 commit comments

Comments
 (0)