Skip to content

Commit 35aaca0

Browse files
authored
Model list endpoint caching (#17)
* Model list endpoint caching * Review fixes * Review fixes
1 parent dffb1b9 commit 35aaca0

8 files changed

Lines changed: 364 additions & 187 deletions

File tree

src/app.rs

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,10 @@ pub struct AppState {
376376
/// Model catalog registry for enriching API responses with model metadata.
377377
/// Loaded from embedded data at startup and optionally synced at runtime.
378378
pub model_catalog: catalog::ModelCatalogRegistry,
379+
/// In-memory cache of model lists fetched from static (config-file) providers.
380+
/// Warmed on startup and refreshed periodically to avoid per-request latency.
381+
pub static_models_cache:
382+
Arc<tokio::sync::RwLock<std::collections::HashMap<String, providers::ModelsResponse>>>,
379383
}
380384

381385
impl AppState {
@@ -1059,7 +1063,7 @@ impl AppState {
10591063
Arc::new(services::ProviderMetricsService::new())
10601064
};
10611065

1062-
Ok(Self {
1066+
let result = Ok(Self {
10631067
http_client,
10641068
config: Arc::new(config),
10651069
db,
@@ -1096,7 +1100,19 @@ impl AppState {
10961100
default_org_id,
10971101
provider_metrics,
10981102
model_catalog,
1099-
})
1103+
static_models_cache: Arc::new(tokio::sync::RwLock::new(
1104+
std::collections::HashMap::new(),
1105+
)),
1106+
});
1107+
1108+
// Warm the static models cache so /v1/models is fast from the first request
1109+
if let Ok(ref state) = result
1110+
&& state.config.features.static_models_cache.enabled()
1111+
{
1112+
state.warm_static_models_cache().await;
1113+
}
1114+
1115+
result
11001116
}
11011117

11021118
/// Ensure a default user exists for anonymous access when auth is disabled.
@@ -1816,6 +1832,49 @@ impl AppState {
18161832
}
18171833
}
18181834
}
1835+
1836+
/// Fetch model lists from all static (config-file) providers in parallel and
1837+
/// store them in `self.static_models_cache`. Failures for individual providers
1838+
/// are logged and skipped so one slow/broken provider cannot block the rest.
1839+
pub async fn warm_static_models_cache(&self) {
1840+
use futures::future::join_all;
1841+
1842+
let futures: Vec<_> = self
1843+
.config
1844+
.providers
1845+
.iter()
1846+
.map(|(name, cfg)| {
1847+
let name = name.to_owned();
1848+
let http = self.http_client.clone();
1849+
let cbs = self.circuit_breakers.clone();
1850+
async move {
1851+
let result = providers::list_models_for_config(cfg, &name, &http, &cbs).await;
1852+
(name, result)
1853+
}
1854+
})
1855+
.collect();
1856+
1857+
let results = join_all(futures).await;
1858+
1859+
let mut cache = self.static_models_cache.write().await;
1860+
cache.retain(|name, _| self.config.providers.get(name).is_some());
1861+
for (name, result) in results {
1862+
match result {
1863+
Ok(response) => {
1864+
cache.insert(name, response);
1865+
}
1866+
Err(e) => {
1867+
tracing::warn!(provider = %name, error = %e, "Failed to fetch models for cache warm");
1868+
}
1869+
}
1870+
}
1871+
let total_models: usize = cache.values().map(|r| r.data.len()).sum();
1872+
tracing::info!(
1873+
providers = cache.len(),
1874+
models = total_models,
1875+
"Static models cache warmed"
1876+
);
1877+
}
18191878
}
18201879

18211880
#[cfg(feature = "server")]

src/cli/server.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,21 @@ pub(crate) async fn run_server(explicit_config_path: Option<&str>, no_browser: b
335335
None
336336
};
337337

338+
// Refresh the static models cache periodically in the background
339+
// (initial warming already happened in AppState::new)
340+
if config.features.static_models_cache.enabled() {
341+
let interval = config.features.static_models_cache.refresh_interval();
342+
let state_ref = state.clone();
343+
tokio::spawn(async move {
344+
let mut ticker = tokio::time::interval(interval);
345+
ticker.tick().await; // skip the immediate first tick (already warmed)
346+
loop {
347+
ticker.tick().await;
348+
state_ref.warm_static_models_cache().await;
349+
}
350+
});
351+
}
352+
338353
let task_tracker = state.task_tracker.clone();
339354
let app = build_app(&config, state);
340355

src/config/features.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ pub struct FeaturesConfig {
5757
/// Validates URLs with SSRF protection and enforces size limits.
5858
#[serde(default)]
5959
pub web_fetch: Option<WebFetchConfig>,
60+
61+
/// Static models cache configuration.
62+
/// Caches model lists from config-file providers to avoid per-request latency.
63+
#[serde(default)]
64+
pub static_models_cache: StaticModelsCacheConfig,
6065
}
6166

6267
impl FeaturesConfig {
@@ -2563,6 +2568,51 @@ fn default_catalog_api_url() -> String {
25632568
"https://models.dev/api.json".to_string()
25642569
}
25652570

2571+
/// Configuration for the static models cache.
2572+
///
2573+
/// Model lists from config-file providers are cached in memory and refreshed
2574+
/// periodically so that `/v1/models` does not make upstream HTTP calls on every
2575+
/// request.
2576+
///
2577+
/// ```toml
2578+
/// [features.static_models_cache]
2579+
/// refresh_interval_secs = 300
2580+
/// ```
2581+
#[derive(Debug, Clone, Serialize, Deserialize)]
2582+
#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
2583+
#[serde(deny_unknown_fields)]
2584+
pub struct StaticModelsCacheConfig {
2585+
/// How often to refresh the cached model lists, in seconds.
2586+
/// Set to 0 to disable caching (every request will query providers directly).
2587+
/// Default: 300 (5 minutes).
2588+
#[serde(default = "default_static_models_refresh_interval_secs")]
2589+
pub refresh_interval_secs: u64,
2590+
}
2591+
2592+
impl Default for StaticModelsCacheConfig {
2593+
fn default() -> Self {
2594+
Self {
2595+
refresh_interval_secs: default_static_models_refresh_interval_secs(),
2596+
}
2597+
}
2598+
}
2599+
2600+
impl StaticModelsCacheConfig {
2601+
/// Whether caching is enabled (interval > 0).
2602+
pub fn enabled(&self) -> bool {
2603+
self.refresh_interval_secs > 0
2604+
}
2605+
2606+
/// Refresh interval as a `Duration`.
2607+
pub fn refresh_interval(&self) -> std::time::Duration {
2608+
std::time::Duration::from_secs(self.refresh_interval_secs)
2609+
}
2610+
}
2611+
2612+
fn default_static_models_refresh_interval_secs() -> u64 {
2613+
300 // 5 minutes
2614+
}
2615+
25662616
#[cfg(test)]
25672617
mod tests {
25682618
use super::*;

src/middleware/layers/admin.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2373,6 +2373,9 @@ mod tests {
23732373
crate::services::ProviderMetricsService::with_local_metrics(|| None),
23742374
),
23752375
model_catalog: crate::catalog::ModelCatalogRegistry::new(),
2376+
static_models_cache: std::sync::Arc::new(tokio::sync::RwLock::new(
2377+
std::collections::HashMap::new(),
2378+
)),
23762379
}
23772380
}
23782381

@@ -2674,6 +2677,9 @@ mod tests {
26742677
crate::services::ProviderMetricsService::with_local_metrics(|| None),
26752678
),
26762679
model_catalog: crate::catalog::ModelCatalogRegistry::new(),
2680+
static_models_cache: std::sync::Arc::new(tokio::sync::RwLock::new(
2681+
std::collections::HashMap::new(),
2682+
)),
26772683
}
26782684
}
26792685

src/middleware/layers/api.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2289,6 +2289,9 @@ mod tests {
22892289
crate::services::ProviderMetricsService::with_local_metrics(|| None),
22902290
),
22912291
model_catalog: crate::catalog::ModelCatalogRegistry::new(),
2292+
static_models_cache: std::sync::Arc::new(tokio::sync::RwLock::new(
2293+
std::collections::HashMap::new(),
2294+
)),
22922295
}
22932296
}
22942297

@@ -2340,6 +2343,9 @@ mod tests {
23402343
crate::services::ProviderMetricsService::with_local_metrics(|| None),
23412344
),
23422345
model_catalog: crate::catalog::ModelCatalogRegistry::new(),
2346+
static_models_cache: std::sync::Arc::new(tokio::sync::RwLock::new(
2347+
std::collections::HashMap::new(),
2348+
)),
23432349
}
23442350
}
23452351

0 commit comments

Comments
 (0)