Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions samples/rust/audio-transcription-example/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
if !model.is_cached().await? {
println!("Downloading model...");
model
.download(Some(Box::new(|progress: &str| {
.download(Some(|progress: &str| {
print!("\r {progress}%");
io::stdout().flush().ok();
})))
}))
.await?;
println!();
}
Expand Down
4 changes: 2 additions & 2 deletions samples/rust/foundry-local-webserver/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
if !model.is_cached().await? {
print!("Downloading model {model_alias}...");
model
.download(Some(Box::new(move |progress: &str| {
.download(Some(move |progress: &str| {
print!("\rDownloading model... {progress}%");
io::stdout().flush().ok();
})))
}))
.await?;
println!();
}
Expand Down
4 changes: 2 additions & 2 deletions samples/rust/native-chat-completions/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
if !model.is_cached().await? {
println!("Downloading model...");
model
.download(Some(Box::new(|progress: &str| {
.download(Some(|progress: &str| {
print!("\r {progress}%");
io::stdout().flush().ok();
})))
}))
.await?;
println!();
}
Expand Down
4 changes: 2 additions & 2 deletions samples/rust/tool-calling-foundry-local/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
if !model.is_cached().await? {
println!("Downloading model...");
model
.download(Some(Box::new(|progress: &str| {
.download(Some(|progress: &str| {
print!("\r {progress}%");
io::stdout().flush().ok();
})))
}))
.await?;
println!();
}
Expand Down
4 changes: 2 additions & 2 deletions samples/rust/tutorial-chat-assistant/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ async fn main() -> anyhow::Result<()> {
if !model.is_cached().await? {
println!("Downloading model...");
model
.download(Some(Box::new(|progress: &str| {
.download(Some(|progress: &str| {
print!("\r {progress}");
io::stdout().flush().ok();
})))
}))
.await?;
println!();
}
Expand Down
4 changes: 2 additions & 2 deletions samples/rust/tutorial-document-summarizer/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ async fn main() -> anyhow::Result<()> {
if !model.is_cached().await? {
println!("Downloading model...");
model
.download(Some(Box::new(|progress: &str| {
.download(Some(|progress: &str| {
print!("\r {progress}");
io::stdout().flush().ok();
})))
}))
.await?;
println!();
}
Expand Down
4 changes: 2 additions & 2 deletions samples/rust/tutorial-tool-calling/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,10 @@ async fn main() -> anyhow::Result<()> {
if !model.is_cached().await? {
println!("Downloading model...");
model
.download(Some(Box::new(|progress: &str| {
.download(Some(|progress: &str| {
print!("\r {progress}");
io::stdout().flush().ok();
})))
}))
.await?;
println!();
}
Expand Down
8 changes: 4 additions & 4 deletions samples/rust/tutorial-voice-to-text/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ async fn main() -> anyhow::Result<()> {
if !speech_model.is_cached().await? {
println!("Downloading speech model...");
speech_model
.download(Some(Box::new(|progress: &str| {
.download(Some(|progress: &str| {
print!("\r {progress}");
io::stdout().flush().ok();
})))
}))
.await?;
println!();
}
Expand Down Expand Up @@ -60,10 +60,10 @@ async fn main() -> anyhow::Result<()> {
if !chat_model.is_cached().await? {
println!("Downloading chat model...");
chat_model
.download(Some(Box::new(|progress: &str| {
.download(Some(|progress: &str| {
print!("\r {progress}");
io::stdout().flush().ok();
})))
}))
.await?;
println!();
}
Expand Down
4 changes: 2 additions & 2 deletions sdk/rust/examples/chat_completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ async fn main() -> Result<()> {
if !model.is_cached().await? {
println!("Downloading model '{}'…", model.alias());
model
.download(Some(Box::new(|progress: &str| {
.download(Some(|progress: &str| {
println!(" {progress}");
})))
}))
.await?;
}

Expand Down
2 changes: 1 addition & 1 deletion sdk/rust/examples/interactive_chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Download if needed
if !model.is_cached().await? {
println!("Downloading '{alias}'…");
model.download(Some(Box::new(|p: &str| print!("\r {p}%")))).await?;
model.download(Some(|p: &str| print!("\r {p}%"))).await?;
println!();
}

Expand Down
2 changes: 1 addition & 1 deletion sdk/rust/examples/tool_calling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async fn main() -> Result<()> {

if !model.is_cached().await? {
println!("Downloading model '{}'…", model.alias());
model.download(Some(Box::new(|p: &str| println!(" {p}")))).await?;
model.download(Some(|p: &str| println!(" {p}"))).await?;
}
println!("Loading model '{}'…", model.alias());
model.load().await?;
Expand Down
101 changes: 32 additions & 69 deletions sdk/rust/src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

use crate::detail::core_interop::CoreInterop;
use crate::detail::ModelLoadManager;
use crate::error::{FoundryLocalError, Result};
use crate::imodel::IModel;
use crate::detail::model::Model;
use crate::detail::model_variant::ModelVariant;
use crate::detail::ModelLoadManager;
use crate::error::{FoundryLocalError, Result};
use crate::types::ModelInfo;

/// How long the catalog cache remains valid before a refresh.
Expand Down Expand Up @@ -40,7 +39,7 @@ impl CacheInvalidator {
/// All mutable catalog data behind a single lock to prevent split-brain reads.
struct CatalogState {
models_by_alias: HashMap<String, Arc<Model>>,
variants_by_id: HashMap<String, Arc<ModelVariant>>,
variants_by_id: HashMap<String, Arc<Model>>,
last_refresh: Option<Instant>,
}

Expand Down Expand Up @@ -125,70 +124,52 @@ impl Catalog {
}

/// Return all known models keyed by alias.
pub async fn get_models(&self) -> Result<Vec<Arc<dyn IModel>>> {
pub async fn get_models(&self) -> Result<Vec<Arc<Model>>> {
self.update_models().await?;
let s = self.lock_state()?;
Ok(s.models_by_alias
.values()
.map(|m| {
let model: Arc<dyn IModel> = m.clone();
model
})
.collect())
Ok(s.models_by_alias.values().cloned().collect())
}

/// Look up a model by its alias.
pub async fn get_model(&self, alias: &str) -> Result<Arc<dyn IModel>> {
pub async fn get_model(&self, alias: &str) -> Result<Arc<Model>> {
if alias.trim().is_empty() {
return Err(FoundryLocalError::Validation {
reason: "Model alias must be a non-empty string".into(),
});
}
self.update_models().await?;
let s = self.lock_state()?;
s.models_by_alias
.get(alias)
.map(|m| {
let model: Arc<dyn IModel> = m.clone();
model
})
.ok_or_else(|| {
let available: Vec<&str> = s.models_by_alias.keys().map(|k| k.as_str()).collect();
FoundryLocalError::ModelOperation {
reason: format!("Unknown model alias '{alias}'. Available: {available:?}"),
}
})
s.models_by_alias.get(alias).cloned().ok_or_else(|| {
let available: Vec<&str> = s.models_by_alias.keys().map(|k| k.as_str()).collect();
FoundryLocalError::ModelOperation {
reason: format!("Unknown model alias '{alias}'. Available: {available:?}"),
}
})
}

/// Look up a specific model variant by its unique id.
///
/// NOTE: This will return an `IModel` representing a single variant. Use
/// [`get_model`](Catalog::get_model) to obtain an `IModel` with all
/// NOTE: This will return a `Model` representing a single variant. Use
/// [`get_model`](Catalog::get_model) to obtain a `Model` with all
/// available variants.
pub async fn get_model_variant(&self, id: &str) -> Result<Arc<dyn IModel>> {
pub async fn get_model_variant(&self, id: &str) -> Result<Arc<Model>> {
if id.trim().is_empty() {
return Err(FoundryLocalError::Validation {
reason: "Variant id must be a non-empty string".into(),
});
}
self.update_models().await?;
let s = self.lock_state()?;
s.variants_by_id
.get(id)
.map(|v| {
let variant: Arc<dyn IModel> = v.clone();
variant
})
.ok_or_else(|| {
let available: Vec<&str> = s.variants_by_id.keys().map(|k| k.as_str()).collect();
FoundryLocalError::ModelOperation {
reason: format!("Unknown variant id '{id}'. Available: {available:?}"),
}
})
s.variants_by_id.get(id).cloned().ok_or_else(|| {
let available: Vec<&str> = s.variants_by_id.keys().map(|k| k.as_str()).collect();
FoundryLocalError::ModelOperation {
reason: format!("Unknown variant id '{id}'. Available: {available:?}"),
}
})
}

/// Return only the model variants that are currently cached on disk.
pub async fn get_cached_models(&self) -> Result<Vec<Arc<dyn IModel>>> {
pub async fn get_cached_models(&self) -> Result<Vec<Arc<Model>>> {
self.update_models().await?;
let raw = self
.core
Expand All @@ -201,37 +182,23 @@ impl Catalog {
let s = self.lock_state()?;
Ok(cached_ids
.iter()
.filter_map(|id| {
s.variants_by_id
.get(id)
.map(|v| {
let variant: Arc<dyn IModel> = v.clone();
variant
})
})
.filter_map(|id| s.variants_by_id.get(id).cloned())
.collect())
}

/// Return model variants that are currently loaded into memory.
pub async fn get_loaded_models(&self) -> Result<Vec<Arc<dyn IModel>>> {
pub async fn get_loaded_models(&self) -> Result<Vec<Arc<Model>>> {
self.update_models().await?;
let loaded_ids = self.model_load_manager.list_loaded().await?;
let s = self.lock_state()?;
Ok(loaded_ids
.iter()
.filter_map(|id| {
s.variants_by_id
.get(id)
.map(|v| {
let variant: Arc<dyn IModel> = v.clone();
variant
})
})
.filter_map(|id| s.variants_by_id.get(id).cloned())
.collect())
}

/// Resolve the latest catalog version for the provided model or variant.
pub async fn get_latest_version(&self, model_or_model_variant: &Arc<dyn IModel>) -> Result<Arc<dyn IModel>> {
pub async fn get_latest_version(&self, model_or_model_variant: &Model) -> Result<Arc<Model>> {
self.update_models().await?;
Comment thread
nenad1002 marked this conversation as resolved.
let s = self.lock_state()?;

Expand All @@ -257,11 +224,7 @@ impl Catalog {
),
})?;

if latest.id() == model_or_model_variant.id() {
Ok(Arc::clone(model_or_model_variant))
} else {
Ok(latest)
}
Ok(latest)
}

async fn force_refresh(&self) -> Result<()> {
Expand All @@ -287,22 +250,22 @@ impl Catalog {
};

let mut alias_map_build: HashMap<String, Model> = HashMap::new();
let mut id_map: HashMap<String, Arc<ModelVariant>> = HashMap::new();
let mut id_map: HashMap<String, Arc<Model>> = HashMap::new();

for info in infos {
let id = info.id.clone();
let alias = info.alias.clone();
let variant = Arc::new(ModelVariant::new(
let variant = ModelVariant::new(
info,
Arc::clone(&self.core),
Arc::clone(&self.model_load_manager),
self.invalidator.clone(),
));
id_map.insert(id, Arc::clone(&variant));
);
id_map.insert(id, Arc::new(Model::from_variant(variant.clone())));

Comment thread
nenad1002 marked this conversation as resolved.
alias_map_build
.entry(alias)
.or_insert_with_key(|a| Model::new(a.clone(), Arc::clone(&self.core)))
.or_insert_with_key(|a| Model::from_group(a.clone(), Arc::clone(&self.core)))
.add_variant(variant);
}

Expand Down
2 changes: 1 addition & 1 deletion sdk/rust/src/detail/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
pub(crate) mod core_interop;
pub(crate) mod model;
pub(crate) mod model_variant;
mod model_load_manager;
pub(crate) mod model_variant;

pub use self::model_load_manager::ModelLoadManager;
Loading
Loading