diff --git a/samples/rust/audio-transcription-example/src/main.rs b/samples/rust/audio-transcription-example/src/main.rs index 9e308a7d..c326006f 100644 --- a/samples/rust/audio-transcription-example/src/main.rs +++ b/samples/rust/audio-transcription-example/src/main.rs @@ -35,10 +35,10 @@ async fn main() -> Result<(), Box> { if !model.is_cached().await? { println!("Downloading model..."); model - .download(Some(Box::new(|progress: &str| { + .download(Some(|progress: &str| { print!("\r {progress}%"); io::stdout().flush().ok(); - }))) + })) .await?; println!(); } diff --git a/samples/rust/foundry-local-webserver/src/main.rs b/samples/rust/foundry-local-webserver/src/main.rs index 83e66f51..492cbbc1 100644 --- a/samples/rust/foundry-local-webserver/src/main.rs +++ b/samples/rust/foundry-local-webserver/src/main.rs @@ -34,10 +34,10 @@ async fn main() -> Result<(), Box> { if !model.is_cached().await? { print!("Downloading model {model_alias}..."); model - .download(Some(Box::new(move |progress: &str| { + .download(Some(move |progress: &str| { print!("\rDownloading model... {progress}%"); io::stdout().flush().ok(); - }))) + })) .await?; println!(); } diff --git a/samples/rust/native-chat-completions/src/main.rs b/samples/rust/native-chat-completions/src/main.rs index 00461f44..04d09372 100644 --- a/samples/rust/native-chat-completions/src/main.rs +++ b/samples/rust/native-chat-completions/src/main.rs @@ -32,10 +32,10 @@ async fn main() -> Result<(), Box> { if !model.is_cached().await? { println!("Downloading model..."); model - .download(Some(Box::new(|progress: &str| { + .download(Some(|progress: &str| { print!("\r {progress}%"); io::stdout().flush().ok(); - }))) + })) .await?; println!(); } diff --git a/samples/rust/tool-calling-foundry-local/src/main.rs b/samples/rust/tool-calling-foundry-local/src/main.rs index c055af9c..1ccda1e8 100644 --- a/samples/rust/tool-calling-foundry-local/src/main.rs +++ b/samples/rust/tool-calling-foundry-local/src/main.rs @@ -66,10 +66,10 @@ async fn main() -> Result<(), Box> { if !model.is_cached().await? { println!("Downloading model..."); model - .download(Some(Box::new(|progress: &str| { + .download(Some(|progress: &str| { print!("\r {progress}%"); io::stdout().flush().ok(); - }))) + })) .await?; println!(); } diff --git a/samples/rust/tutorial-chat-assistant/src/main.rs b/samples/rust/tutorial-chat-assistant/src/main.rs index 9e373b6d..6b0b587b 100644 --- a/samples/rust/tutorial-chat-assistant/src/main.rs +++ b/samples/rust/tutorial-chat-assistant/src/main.rs @@ -21,10 +21,10 @@ async fn main() -> anyhow::Result<()> { if !model.is_cached().await? { println!("Downloading model..."); model - .download(Some(Box::new(|progress: &str| { + .download(Some(|progress: &str| { print!("\r {progress}"); io::stdout().flush().ok(); - }))) + })) .await?; println!(); } diff --git a/samples/rust/tutorial-document-summarizer/src/main.rs b/samples/rust/tutorial-document-summarizer/src/main.rs index a1f3db7f..9ade2e77 100644 --- a/samples/rust/tutorial-document-summarizer/src/main.rs +++ b/samples/rust/tutorial-document-summarizer/src/main.rs @@ -96,10 +96,10 @@ async fn main() -> anyhow::Result<()> { if !model.is_cached().await? { println!("Downloading model..."); model - .download(Some(Box::new(|progress: &str| { + .download(Some(|progress: &str| { print!("\r {progress}"); io::stdout().flush().ok(); - }))) + })) .await?; println!(); } diff --git a/samples/rust/tutorial-tool-calling/src/main.rs b/samples/rust/tutorial-tool-calling/src/main.rs index ab477692..f4476643 100644 --- a/samples/rust/tutorial-tool-calling/src/main.rs +++ b/samples/rust/tutorial-tool-calling/src/main.rs @@ -199,10 +199,10 @@ async fn main() -> anyhow::Result<()> { if !model.is_cached().await? { println!("Downloading model..."); model - .download(Some(Box::new(|progress: &str| { + .download(Some(|progress: &str| { print!("\r {progress}"); io::stdout().flush().ok(); - }))) + })) .await?; println!(); } diff --git a/samples/rust/tutorial-voice-to-text/src/main.rs b/samples/rust/tutorial-voice-to-text/src/main.rs index c1bedd4c..2295c86a 100644 --- a/samples/rust/tutorial-voice-to-text/src/main.rs +++ b/samples/rust/tutorial-voice-to-text/src/main.rs @@ -28,10 +28,10 @@ async fn main() -> anyhow::Result<()> { if !speech_model.is_cached().await? { println!("Downloading speech model..."); speech_model - .download(Some(Box::new(|progress: &str| { + .download(Some(|progress: &str| { print!("\r {progress}"); io::stdout().flush().ok(); - }))) + })) .await?; println!(); } @@ -60,10 +60,10 @@ async fn main() -> anyhow::Result<()> { if !chat_model.is_cached().await? { println!("Downloading chat model..."); chat_model - .download(Some(Box::new(|progress: &str| { + .download(Some(|progress: &str| { print!("\r {progress}"); io::stdout().flush().ok(); - }))) + })) .await?; println!(); } diff --git a/sdk/rust/examples/chat_completion.rs b/sdk/rust/examples/chat_completion.rs index ea817ab7..3516aa60 100644 --- a/sdk/rust/examples/chat_completion.rs +++ b/sdk/rust/examples/chat_completion.rs @@ -40,9 +40,9 @@ async fn main() -> Result<()> { if !model.is_cached().await? { println!("Downloading model '{}'…", model.alias()); model - .download(Some(Box::new(|progress: &str| { + .download(Some(|progress: &str| { println!(" {progress}"); - }))) + })) .await?; } diff --git a/sdk/rust/examples/interactive_chat.rs b/sdk/rust/examples/interactive_chat.rs index bc8a992d..bd230155 100644 --- a/sdk/rust/examples/interactive_chat.rs +++ b/sdk/rust/examples/interactive_chat.rs @@ -41,7 +41,7 @@ async fn main() -> Result<(), Box> { // Download if needed if !model.is_cached().await? { println!("Downloading '{alias}'…"); - model.download(Some(Box::new(|p: &str| print!("\r {p}%")))).await?; + model.download(Some(|p: &str| print!("\r {p}%"))).await?; println!(); } diff --git a/sdk/rust/examples/tool_calling.rs b/sdk/rust/examples/tool_calling.rs index 38620cb5..fecf6bc5 100644 --- a/sdk/rust/examples/tool_calling.rs +++ b/sdk/rust/examples/tool_calling.rs @@ -67,7 +67,7 @@ async fn main() -> Result<()> { if !model.is_cached().await? { println!("Downloading model '{}'…", model.alias()); - model.download(Some(Box::new(|p: &str| println!(" {p}")))).await?; + model.download(Some(|p: &str| println!(" {p}"))).await?; } println!("Loading model '{}'…", model.alias()); model.load().await?; diff --git a/sdk/rust/src/catalog.rs b/sdk/rust/src/catalog.rs index 0e0a991a..26a737e9 100644 --- a/sdk/rust/src/catalog.rs +++ b/sdk/rust/src/catalog.rs @@ -6,11 +6,10 @@ use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use crate::detail::core_interop::CoreInterop; -use crate::detail::ModelLoadManager; -use crate::error::{FoundryLocalError, Result}; -use crate::imodel::IModel; use crate::detail::model::Model; use crate::detail::model_variant::ModelVariant; +use crate::detail::ModelLoadManager; +use crate::error::{FoundryLocalError, Result}; use crate::types::ModelInfo; /// How long the catalog cache remains valid before a refresh. @@ -40,7 +39,7 @@ impl CacheInvalidator { /// All mutable catalog data behind a single lock to prevent split-brain reads. struct CatalogState { models_by_alias: HashMap>, - variants_by_id: HashMap>, + variants_by_id: HashMap>, last_refresh: Option, } @@ -125,20 +124,14 @@ impl Catalog { } /// Return all known models keyed by alias. - pub async fn get_models(&self) -> Result>> { + pub async fn get_models(&self) -> Result>> { self.update_models().await?; let s = self.lock_state()?; - Ok(s.models_by_alias - .values() - .map(|m| { - let model: Arc = m.clone(); - model - }) - .collect()) + Ok(s.models_by_alias.values().cloned().collect()) } /// Look up a model by its alias. - pub async fn get_model(&self, alias: &str) -> Result> { + pub async fn get_model(&self, alias: &str) -> Result> { if alias.trim().is_empty() { return Err(FoundryLocalError::Validation { reason: "Model alias must be a non-empty string".into(), @@ -146,26 +139,20 @@ impl Catalog { } self.update_models().await?; let s = self.lock_state()?; - s.models_by_alias - .get(alias) - .map(|m| { - let model: Arc = m.clone(); - model - }) - .ok_or_else(|| { - let available: Vec<&str> = s.models_by_alias.keys().map(|k| k.as_str()).collect(); - FoundryLocalError::ModelOperation { - reason: format!("Unknown model alias '{alias}'. Available: {available:?}"), - } - }) + s.models_by_alias.get(alias).cloned().ok_or_else(|| { + let available: Vec<&str> = s.models_by_alias.keys().map(|k| k.as_str()).collect(); + FoundryLocalError::ModelOperation { + reason: format!("Unknown model alias '{alias}'. Available: {available:?}"), + } + }) } /// Look up a specific model variant by its unique id. /// - /// NOTE: This will return an `IModel` representing a single variant. Use - /// [`get_model`](Catalog::get_model) to obtain an `IModel` with all + /// NOTE: This will return a `Model` representing a single variant. Use + /// [`get_model`](Catalog::get_model) to obtain a `Model` with all /// available variants. - pub async fn get_model_variant(&self, id: &str) -> Result> { + pub async fn get_model_variant(&self, id: &str) -> Result> { if id.trim().is_empty() { return Err(FoundryLocalError::Validation { reason: "Variant id must be a non-empty string".into(), @@ -173,22 +160,16 @@ impl Catalog { } self.update_models().await?; let s = self.lock_state()?; - s.variants_by_id - .get(id) - .map(|v| { - let variant: Arc = v.clone(); - variant - }) - .ok_or_else(|| { - let available: Vec<&str> = s.variants_by_id.keys().map(|k| k.as_str()).collect(); - FoundryLocalError::ModelOperation { - reason: format!("Unknown variant id '{id}'. Available: {available:?}"), - } - }) + s.variants_by_id.get(id).cloned().ok_or_else(|| { + let available: Vec<&str> = s.variants_by_id.keys().map(|k| k.as_str()).collect(); + FoundryLocalError::ModelOperation { + reason: format!("Unknown variant id '{id}'. Available: {available:?}"), + } + }) } /// Return only the model variants that are currently cached on disk. - pub async fn get_cached_models(&self) -> Result>> { + pub async fn get_cached_models(&self) -> Result>> { self.update_models().await?; let raw = self .core @@ -201,37 +182,23 @@ impl Catalog { let s = self.lock_state()?; Ok(cached_ids .iter() - .filter_map(|id| { - s.variants_by_id - .get(id) - .map(|v| { - let variant: Arc = v.clone(); - variant - }) - }) + .filter_map(|id| s.variants_by_id.get(id).cloned()) .collect()) } /// Return model variants that are currently loaded into memory. - pub async fn get_loaded_models(&self) -> Result>> { + pub async fn get_loaded_models(&self) -> Result>> { self.update_models().await?; let loaded_ids = self.model_load_manager.list_loaded().await?; let s = self.lock_state()?; Ok(loaded_ids .iter() - .filter_map(|id| { - s.variants_by_id - .get(id) - .map(|v| { - let variant: Arc = v.clone(); - variant - }) - }) + .filter_map(|id| s.variants_by_id.get(id).cloned()) .collect()) } /// Resolve the latest catalog version for the provided model or variant. - pub async fn get_latest_version(&self, model_or_model_variant: &Arc) -> Result> { + pub async fn get_latest_version(&self, model_or_model_variant: &Model) -> Result> { self.update_models().await?; let s = self.lock_state()?; @@ -257,11 +224,7 @@ impl Catalog { ), })?; - if latest.id() == model_or_model_variant.id() { - Ok(Arc::clone(model_or_model_variant)) - } else { - Ok(latest) - } + Ok(latest) } async fn force_refresh(&self) -> Result<()> { @@ -287,22 +250,22 @@ impl Catalog { }; let mut alias_map_build: HashMap = HashMap::new(); - let mut id_map: HashMap> = HashMap::new(); + let mut id_map: HashMap> = HashMap::new(); for info in infos { let id = info.id.clone(); let alias = info.alias.clone(); - let variant = Arc::new(ModelVariant::new( + let variant = ModelVariant::new( info, Arc::clone(&self.core), Arc::clone(&self.model_load_manager), self.invalidator.clone(), - )); - id_map.insert(id, Arc::clone(&variant)); + ); + id_map.insert(id, Arc::new(Model::from_variant(variant.clone()))); alias_map_build .entry(alias) - .or_insert_with_key(|a| Model::new(a.clone(), Arc::clone(&self.core))) + .or_insert_with_key(|a| Model::from_group(a.clone(), Arc::clone(&self.core))) .add_variant(variant); } diff --git a/sdk/rust/src/detail/mod.rs b/sdk/rust/src/detail/mod.rs index 32f1f983..b153ed5b 100644 --- a/sdk/rust/src/detail/mod.rs +++ b/sdk/rust/src/detail/mod.rs @@ -1,6 +1,6 @@ pub(crate) mod core_interop; pub(crate) mod model; -pub(crate) mod model_variant; mod model_load_manager; +pub(crate) mod model_variant; pub use self::model_load_manager::ModelLoadManager; diff --git a/sdk/rust/src/detail/model.rs b/sdk/rust/src/detail/model.rs index fba868ad..196ebe35 100644 --- a/sdk/rust/src/detail/model.rs +++ b/sdk/rust/src/detail/model.rs @@ -1,196 +1,300 @@ -//! High-level model abstraction that wraps one or more model variants -//! sharing the same alias. +//! Public model type backed by an internal enum. +//! +//! Users interact solely with [`Model`]. The internal representation +//! distinguishes between a single variant and a group of variants sharing +//! the same alias, but callers never need to know which kind they hold. use std::fmt; -use std::future::Future; use std::path::PathBuf; -use std::pin::Pin; use std::sync::atomic::{AtomicUsize, Ordering::Relaxed}; use std::sync::Arc; use super::core_interop::CoreInterop; -use crate::error::{FoundryLocalError, Result}; -use crate::imodel::IModel; use super::model_variant::ModelVariant; +use crate::error::{FoundryLocalError, Result}; use crate::openai::AudioClient; use crate::openai::ChatClient; use crate::types::ModelInfo; -/// A model groups one or more variants that share the same alias. +/// The public model type. /// -/// By default the variant that is already cached locally is selected. You -/// can override the selection with [`IModel::select_variant`]. +/// A `Model` may represent either a group of variants (as returned by +/// [`Catalog::get_model`](crate::Catalog::get_model)) or a single variant (as +/// returned by [`Catalog::get_model_variant`](crate::Catalog::get_model_variant) +/// or [`Model::variants`]). /// -/// Implements [`IModel`] — all operations are forwarded to the currently -/// selected variant. +/// When a `Model` groups multiple variants, operations are forwarded to +/// the currently selected variant. Use [`variants`](Model::variants) to +/// inspect the available variants and [`select_variant`](Model::select_variant) +/// to change the selection. pub struct Model { - alias: String, - core: Arc, - variants: Vec>, - selected_index: AtomicUsize, + inner: ModelKind, +} + +#[allow(clippy::large_enum_variant)] +enum ModelKind { + /// A single model variant (from `get_model_variant` or `variants()`). + ModelVariant(ModelVariant), + /// A group of variants sharing the same alias (from `get_model`). + Model { + alias: String, + core: Arc, + variants: Vec, + selected: AtomicUsize, + }, } impl Clone for Model { fn clone(&self) -> Self { Self { - alias: self.alias.clone(), - core: Arc::clone(&self.core), - variants: self.variants.clone(), - selected_index: AtomicUsize::new(self.selected_index.load(Relaxed)), + inner: match &self.inner { + ModelKind::ModelVariant(v) => ModelKind::ModelVariant(v.clone()), + ModelKind::Model { + alias, + core, + variants, + selected, + } => ModelKind::Model { + alias: alias.clone(), + core: Arc::clone(core), + variants: variants.clone(), + selected: AtomicUsize::new(selected.load(Relaxed)), + }, + }, } } } impl fmt::Debug for Model { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Model") - .field("alias", &self.alias()) - .field("id", &self.id()) - .field("variants_count", &self.variants.len()) - .field("selected_index", &self.selected_index.load(Relaxed)) - .finish() + match &self.inner { + ModelKind::ModelVariant(v) => f + .debug_struct("Model::ModelVariant") + .field("id", &v.id()) + .field("alias", &v.alias()) + .finish(), + ModelKind::Model { + alias, + variants, + selected, + .. + } => f + .debug_struct("Model::Model") + .field("alias", alias) + .field("id", &variants[selected.load(Relaxed)].id()) + .field("variants_count", &variants.len()) + .field("selected_index", &selected.load(Relaxed)) + .finish(), + } } } +// ── Construction (crate-internal) ──────────────────────────────────────────── + impl Model { - pub(crate) fn new(alias: String, core: Arc) -> Self { + /// Create a `Model` wrapping a single variant. + pub(crate) fn from_variant(variant: ModelVariant) -> Self { Self { - alias, - core, - variants: Vec::new(), - selected_index: AtomicUsize::new(0), + inner: ModelKind::ModelVariant(variant), } } - /// Add a variant. If the new variant is cached and the current selection - /// is not, the new variant becomes the selected one. - pub(crate) fn add_variant(&mut self, variant: Arc) { - self.variants.push(variant); - let new_idx = self.variants.len() - 1; - let current = self.selected_index.load(Relaxed); - - // Prefer a cached variant over a non-cached one. - if self.variants[new_idx].info_ref().cached && !self.variants[current].info_ref().cached { - self.selected_index.store(new_idx, Relaxed); + /// Create a `Model` grouping multiple variants under one alias. + pub(crate) fn from_group(alias: String, core: Arc) -> Self { + Self { + inner: ModelKind::Model { + alias, + core, + variants: Vec::new(), + selected: AtomicUsize::new(0), + }, } } - /// Returns a reference to the currently selected variant (crate-internal). - pub(crate) fn selected_variant(&self) -> &ModelVariant { - &self.variants[self.selected_index.load(Relaxed)] + /// Add a variant to a group. Panics if called on a `ModelVariant` kind. + /// + /// If the new variant is cached and the current selection is not, the new + /// variant becomes the selected one. + pub(crate) fn add_variant(&mut self, variant: ModelVariant) { + match &mut self.inner { + ModelKind::Model { + variants, selected, .. + } => { + variants.push(variant); + let new_idx = variants.len() - 1; + let current = selected.load(Relaxed); + if variants[new_idx].info_ref().cached && !variants[current].info_ref().cached { + selected.store(new_idx, Relaxed); + } + } + ModelKind::ModelVariant(_) => { + panic!("add_variant called on a single-variant Model"); + } + } } +} - /// Download the selected variant with a generic progress callback. - /// - /// This is a convenience method that avoids the boxing overhead of the - /// trait method when the concrete type is known. - pub async fn download_generic(&self, progress: Option) -> Result<()> - where - F: FnMut(&str) + Send + 'static, - { - self.selected_variant().download_generic(progress).await +// ── Private helpers ────────────────────────────────────────────────────────── + +impl Model { + fn selected_variant(&self) -> &ModelVariant { + match &self.inner { + ModelKind::ModelVariant(v) => v, + ModelKind::Model { + variants, selected, .. + } => &variants[selected.load(Relaxed)], + } } } -#[allow(clippy::manual_async_fn)] -impl IModel for Model { - fn id(&self) -> &str { +// ── Public API ─────────────────────────────────────────────────────────────── + +impl Model { + /// Unique identifier of the (selected) variant. + pub fn id(&self) -> &str { self.selected_variant().id() } - fn alias(&self) -> &str { - &self.alias + /// Alias shared by all variants of this model. + pub fn alias(&self) -> &str { + match &self.inner { + ModelKind::ModelVariant(v) => v.alias(), + ModelKind::Model { alias, .. } => alias, + } } - fn info(&self) -> &ModelInfo { + /// Full catalog metadata for the (selected) variant. + pub fn info(&self) -> &ModelInfo { self.selected_variant().info() } - fn context_length(&self) -> Option { + /// Maximum context length (in tokens), or `None` if unknown. + pub fn context_length(&self) -> Option { self.selected_variant().info().context_length } - fn input_modalities(&self) -> Option<&str> { + /// Comma-separated input modalities (e.g. `"text,image"`), or `None`. + pub fn input_modalities(&self) -> Option<&str> { self.selected_variant().info().input_modalities.as_deref() } - fn output_modalities(&self) -> Option<&str> { + /// Comma-separated output modalities (e.g. `"text"`), or `None`. + pub fn output_modalities(&self) -> Option<&str> { self.selected_variant().info().output_modalities.as_deref() } - fn capabilities(&self) -> Option<&str> { + /// Capability tags (e.g. `"reasoning"`), or `None`. + pub fn capabilities(&self) -> Option<&str> { self.selected_variant().info().capabilities.as_deref() } - fn supports_tool_calling(&self) -> Option { + /// Whether the model supports tool/function calling, or `None`. + pub fn supports_tool_calling(&self) -> Option { self.selected_variant().info().supports_tool_calling } - fn is_cached(&self) -> Pin> + Send + '_>> { - self.selected_variant().is_cached() + /// Whether the (selected) variant is cached on disk. + pub async fn is_cached(&self) -> Result { + self.selected_variant().is_cached().await } - fn is_loaded(&self) -> Pin> + Send + '_>> { - self.selected_variant().is_loaded() + /// Whether the (selected) variant is loaded into memory. + pub async fn is_loaded(&self) -> Result { + self.selected_variant().is_loaded().await } - fn download( - &self, - progress: Option>, - ) -> Pin> + Send + '_>> { - self.selected_variant().download(progress) + /// Download the (selected) variant. If `progress` is provided it + /// receives human-readable progress strings as they arrive. + pub async fn download(&self, progress: Option) -> Result<()> + where + F: FnMut(&str) + Send + 'static, + { + self.selected_variant().download(progress).await } - fn path(&self) -> Pin> + Send + '_>> { - self.selected_variant().path() + /// Return the local file-system path of the (selected) variant. + pub async fn path(&self) -> Result { + self.selected_variant().path().await } - fn load(&self) -> Pin> + Send + '_>> { - self.selected_variant().load() + /// Load the (selected) variant into memory. + pub async fn load(&self) -> Result<()> { + self.selected_variant().load().await } - fn unload(&self) -> Pin> + Send + '_>> { - self.selected_variant().unload() + /// Unload the (selected) variant from memory. + pub async fn unload(&self) -> Result { + self.selected_variant().unload().await } - fn remove_from_cache(&self) -> Pin> + Send + '_>> { - self.selected_variant().remove_from_cache() + /// Remove the (selected) variant from the local cache. + pub async fn remove_from_cache(&self) -> Result { + self.selected_variant().remove_from_cache().await } - fn create_chat_client(&self) -> ChatClient { - ChatClient::new(self.id(), Arc::clone(&self.core)) + /// Create a [`ChatClient`] bound to the (selected) variant. + pub fn create_chat_client(&self) -> ChatClient { + self.selected_variant().create_chat_client() } - fn create_audio_client(&self) -> AudioClient { - AudioClient::new(self.id(), Arc::clone(&self.core)) + /// Create an [`AudioClient`] bound to the (selected) variant. + pub fn create_audio_client(&self) -> AudioClient { + self.selected_variant().create_audio_client() } - fn variants(&self) -> Vec> { - self.variants - .iter() - .map(|v| { - let variant: Arc = v.clone(); - variant - }) - .collect() + /// Available variants of this model. + /// + /// For a single-variant model (e.g. from + /// [`Catalog::get_model_variant`](crate::Catalog::get_model_variant)), + /// this returns a single-element list containing itself. + pub fn variants(&self) -> Vec> { + match &self.inner { + ModelKind::ModelVariant(v) => { + vec![Arc::new(Model::from_variant(v.clone()))] + } + ModelKind::Model { variants, .. } => variants + .iter() + .map(|v| Arc::new(Model::from_variant(v.clone()))) + .collect(), + } } /// Select a variant by its unique id. - fn select_variant(&self, id: &str) -> Result<()> { - match self.variants.iter().position(|v| v.id() == id) { - Some(pos) => { - self.selected_index.store(pos, Relaxed); - Ok(()) - } - None => { - let available: Vec<&str> = self.variants.iter().map(|v| v.id()).collect(); - Err(FoundryLocalError::ModelOperation { - reason: format!( - "Variant '{id}' not found for model '{}'. Available: {available:?}", - self.alias - ), - }) - } + /// + /// # Errors + /// + /// Returns an error if no variant with the given id exists. + /// For single-variant models this always returns an error — use + /// [`Catalog::get_model`](crate::Catalog::get_model) to obtain a model + /// with all variants available. + pub fn select_variant(&self, id: &str) -> Result<()> { + match &self.inner { + ModelKind::ModelVariant(v) => Err(FoundryLocalError::ModelOperation { + reason: format!( + "select_variant is not supported on a single variant. \ + Call Catalog::get_model(\"{}\") to get a model with all variants available.", + v.alias() + ), + }), + ModelKind::Model { + variants, + selected, + alias, + .. + } => match variants.iter().position(|v| v.id() == id) { + Some(pos) => { + selected.store(pos, Relaxed); + Ok(()) + } + None => { + let available: Vec<&str> = variants.iter().map(|v| v.id()).collect(); + Err(FoundryLocalError::ModelOperation { + reason: format!( + "Variant '{id}' not found for model '{alias}'. Available: {available:?}", + ), + }) + } + }, } } } diff --git a/sdk/rust/src/detail/model_variant.rs b/sdk/rust/src/detail/model_variant.rs index c9c80cd9..636c5d5b 100644 --- a/sdk/rust/src/detail/model_variant.rs +++ b/sdk/rust/src/detail/model_variant.rs @@ -1,21 +1,18 @@ //! A single model variant backed by [`ModelInfo`]. //! //! This type is an implementation detail. Public APIs return -//! [`Arc`](crate::IModel) instead. +//! [`Arc`](crate::Model) instead. use std::fmt; -use std::future::Future; use std::path::PathBuf; -use std::pin::Pin; use std::sync::Arc; use serde_json::json; -use crate::catalog::CacheInvalidator; use super::core_interop::CoreInterop; use super::ModelLoadManager; -use crate::error::{FoundryLocalError, Result}; -use crate::imodel::IModel; +use crate::catalog::CacheInvalidator; +use crate::error::Result; use crate::openai::AudioClient; use crate::openai::ChatClient; use crate::types::ModelInfo; @@ -23,7 +20,7 @@ use crate::types::ModelInfo; /// Represents one specific variant of a model (a particular id within an alias /// group). /// -/// This is an implementation detail — callers should use the [`IModel`] trait. +/// This is an implementation detail — callers should use [`Model`](crate::Model). #[derive(Clone)] pub(crate) struct ModelVariant { info: ModelInfo, @@ -35,8 +32,8 @@ pub(crate) struct ModelVariant { impl fmt::Debug for ModelVariant { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ModelVariant") - .field("id", &self.id()) - .field("alias", &self.alias()) + .field("id", &self.info.id) + .field("alias", &self.info.alias) .finish() } } @@ -56,12 +53,40 @@ impl ModelVariant { } } + pub(crate) fn id(&self) -> &str { + &self.info.id + } + + pub(crate) fn alias(&self) -> &str { + &self.info.alias + } + + pub(crate) fn info(&self) -> &ModelInfo { + &self.info + } + pub(crate) fn info_ref(&self) -> &ModelInfo { &self.info } - /// Download with a generic progress callback (used internally). - pub(crate) async fn download_generic(&self, progress: Option) -> Result<()> + pub(crate) async fn is_cached(&self) -> Result { + let raw = self + .core + .execute_command_async("get_cached_models".into(), None) + .await?; + if raw.trim().is_empty() { + return Ok(false); + } + let cached_ids: Vec = serde_json::from_str(&raw)?; + Ok(cached_ids.iter().any(|id| id == &self.info.id)) + } + + pub(crate) async fn is_loaded(&self) -> Result { + let loaded = self.model_load_manager.list_loaded().await?; + Ok(loaded.iter().any(|id| id == &self.info.id)) + } + + pub(crate) async fn download(&self, progress: Option) -> Result<()> where F: FnMut(&str) + Send + 'static, { @@ -81,136 +106,39 @@ impl ModelVariant { self.cache_invalidator.invalidate(); Ok(()) } -} - -#[allow(clippy::manual_async_fn)] -impl IModel for ModelVariant { - fn id(&self) -> &str { - &self.info.id - } - - fn alias(&self) -> &str { - &self.info.alias - } - - fn info(&self) -> &ModelInfo { - &self.info - } - fn context_length(&self) -> Option { - self.info.context_length - } - - fn input_modalities(&self) -> Option<&str> { - self.info.input_modalities.as_deref() - } - - fn output_modalities(&self) -> Option<&str> { - self.info.output_modalities.as_deref() - } - - fn capabilities(&self) -> Option<&str> { - self.info.capabilities.as_deref() - } - - fn supports_tool_calling(&self) -> Option { - self.info.supports_tool_calling - } - - fn is_cached(&self) -> Pin> + Send + '_>> { - Box::pin(async move { - let raw = self - .core - .execute_command_async("get_cached_models".into(), None) - .await?; - if raw.trim().is_empty() { - return Ok(false); - } - let cached_ids: Vec = serde_json::from_str(&raw)?; - Ok(cached_ids.iter().any(|id| id == &self.info.id)) - }) - } - - fn is_loaded(&self) -> Pin> + Send + '_>> { - Box::pin(async move { - let loaded = self.model_load_manager.list_loaded().await?; - Ok(loaded.iter().any(|id| id == &self.info.id)) - }) - } - - fn download( - &self, - progress: Option>, - ) -> Pin> + Send + '_>> { - Box::pin(async move { - let params = json!({ "Params": { "Model": self.info.id } }); - match progress { - Some(cb) => { - self.core - .execute_command_streaming_async("download_model".into(), Some(params), cb) - .await?; - } - None => { - self.core - .execute_command_async("download_model".into(), Some(params)) - .await?; - } - } - self.cache_invalidator.invalidate(); - Ok(()) - }) - } - - fn path(&self) -> Pin> + Send + '_>> { - Box::pin(async move { - let params = json!({ "Params": { "Model": self.info.id } }); - let path_str = self - .core - .execute_command_async("get_model_path".into(), Some(params)) - .await?; - Ok(PathBuf::from(path_str)) - }) + pub(crate) async fn path(&self) -> Result { + let params = json!({ "Params": { "Model": self.info.id } }); + let path_str = self + .core + .execute_command_async("get_model_path".into(), Some(params)) + .await?; + Ok(PathBuf::from(path_str)) } - fn load(&self) -> Pin> + Send + '_>> { - Box::pin(async move { self.model_load_manager.load(&self.info.id).await }) + pub(crate) async fn load(&self) -> Result<()> { + self.model_load_manager.load(&self.info.id).await } - fn unload(&self) -> Pin> + Send + '_>> { - Box::pin(async move { self.model_load_manager.unload(&self.info.id).await }) + pub(crate) async fn unload(&self) -> Result { + self.model_load_manager.unload(&self.info.id).await } - fn remove_from_cache(&self) -> Pin> + Send + '_>> { - Box::pin(async move { - let params = json!({ "Params": { "Model": self.info.id } }); - let result = self - .core - .execute_command_async("remove_cached_model".into(), Some(params)) - .await?; - self.cache_invalidator.invalidate(); - Ok(result) - }) + pub(crate) async fn remove_from_cache(&self) -> Result { + let params = json!({ "Params": { "Model": self.info.id } }); + let result = self + .core + .execute_command_async("remove_cached_model".into(), Some(params)) + .await?; + self.cache_invalidator.invalidate(); + Ok(result) } - fn create_chat_client(&self) -> ChatClient { + pub(crate) fn create_chat_client(&self) -> ChatClient { ChatClient::new(&self.info.id, Arc::clone(&self.core)) } - fn create_audio_client(&self) -> AudioClient { + pub(crate) fn create_audio_client(&self) -> AudioClient { AudioClient::new(&self.info.id, Arc::clone(&self.core)) } - - fn variants(&self) -> Vec> { - vec![Arc::new(self.clone())] - } - - fn select_variant(&self, _id: &str) -> Result<()> { - Err(FoundryLocalError::ModelOperation { - reason: format!( - "select_variant is not supported on a single variant. \ - Call Catalog::get_model(\"{}\") to get a Model with all variants available.", - self.alias() - ), - }) - } } diff --git a/sdk/rust/src/imodel.rs b/sdk/rust/src/imodel.rs deleted file mode 100644 index 80ff78f5..00000000 --- a/sdk/rust/src/imodel.rs +++ /dev/null @@ -1,100 +0,0 @@ -//! Public trait defining the model interface. -//! -//! All public APIs return [`Arc`] so that callers never need to -//! reference the internal `ModelVariant` type directly. - -use std::fmt::Debug; -use std::future::Future; -use std::path::PathBuf; -use std::pin::Pin; -use std::sync::Arc; - -use crate::error::Result; -use crate::openai::AudioClient; -use crate::openai::ChatClient; -use crate::types::ModelInfo; - -/// Common interface for interacting with a model. -/// -/// An `IModel` may represent either a group of variants (as returned by -/// [`Catalog::get_model`](crate::Catalog::get_model)) or a single variant (as -/// returned by [`Catalog::get_model_variant`](crate::Catalog::get_model_variant) -/// or [`IModel::variants`]). -/// -/// When an `IModel` groups multiple variants, operations are forwarded to -/// the currently selected variant. Use [`variants`](IModel::variants) to -/// inspect the available variants and [`select_variant`](IModel::select_variant) -/// to change the selection. -pub trait IModel: Send + Sync + Debug { - /// Unique identifier of the (selected) variant. - fn id(&self) -> &str; - - /// Alias shared by all variants of this model. - fn alias(&self) -> &str; - - /// Full catalog metadata for the (selected) variant. - fn info(&self) -> &ModelInfo; - - /// Maximum context length (in tokens), or `None` if unknown. - fn context_length(&self) -> Option; - - /// Comma-separated input modalities (e.g. `"text,image"`), or `None`. - fn input_modalities(&self) -> Option<&str>; - - /// Comma-separated output modalities (e.g. `"text"`), or `None`. - fn output_modalities(&self) -> Option<&str>; - - /// Capability tags (e.g. `"reasoning"`), or `None`. - fn capabilities(&self) -> Option<&str>; - - /// Whether the model supports tool/function calling, or `None`. - fn supports_tool_calling(&self) -> Option; - - /// Whether the (selected) variant is cached on disk. - fn is_cached(&self) -> Pin> + Send + '_>>; - - /// Whether the (selected) variant is loaded into memory. - fn is_loaded(&self) -> Pin> + Send + '_>>; - - /// Download the (selected) variant. If `progress` is provided it - /// receives human-readable progress strings as they arrive. - fn download( - &self, - progress: Option>, - ) -> Pin> + Send + '_>>; - - /// Return the local file-system path of the (selected) variant. - fn path(&self) -> Pin> + Send + '_>>; - - /// Load the (selected) variant into memory. - fn load(&self) -> Pin> + Send + '_>>; - - /// Unload the (selected) variant from memory. - fn unload(&self) -> Pin> + Send + '_>>; - - /// Remove the (selected) variant from the local cache. - fn remove_from_cache(&self) -> Pin> + Send + '_>>; - - /// Create a [`ChatClient`] bound to the (selected) variant. - fn create_chat_client(&self) -> ChatClient; - - /// Create an [`AudioClient`] bound to the (selected) variant. - fn create_audio_client(&self) -> AudioClient; - - /// Available variants of this model. - /// - /// For a single-variant model (e.g. from - /// [`Catalog::get_model_variant`](crate::Catalog::get_model_variant)), - /// this returns a single-element list containing itself. - fn variants(&self) -> Vec>; - - /// Select a variant by its unique id. - /// - /// # Errors - /// - /// Returns an error if no variant with the given id exists in this model. - /// For single-variant models this always returns an error — use - /// [`Catalog::get_model`](crate::Catalog::get_model) to obtain a model - /// with all variants available. - fn select_variant(&self, id: &str) -> Result<()>; -} diff --git a/sdk/rust/src/lib.rs b/sdk/rust/src/lib.rs index f405f9c0..872a875c 100644 --- a/sdk/rust/src/lib.rs +++ b/sdk/rust/src/lib.rs @@ -6,7 +6,6 @@ mod catalog; mod configuration; mod error; mod foundry_local_manager; -mod imodel; mod types; pub(crate) mod detail; @@ -14,10 +13,9 @@ pub mod openai; pub use self::catalog::Catalog; pub use self::configuration::{FoundryLocalConfig, LogLevel, Logger}; +pub use self::detail::model::Model; pub use self::error::FoundryLocalError; pub use self::foundry_local_manager::FoundryLocalManager; -pub use self::imodel::IModel; -pub use self::detail::model::Model; pub use self::types::{ ChatResponseFormat, ChatToolChoice, DeviceType, EpDownloadResult, EpInfo, ModelInfo, ModelSettings, Parameter, PromptTemplate, Runtime, diff --git a/sdk/rust/tests/integration/model_test.rs b/sdk/rust/tests/integration/model_test.rs index 7380f064..4e3b371b 100644 --- a/sdk/rust/tests/integration/model_test.rs +++ b/sdk/rust/tests/integration/model_test.rs @@ -1,4 +1,5 @@ use super::common; +use std::sync::Arc; // ── Cached model verification ────────────────────────────────────────────────