Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ name: inftyai/tiny-random-gpt2
kind: Model
spec:
author: inftyai
provider: huggingface
task: text-generation
license: MIT
model_series: gpt2
Expand All @@ -236,8 +237,7 @@ spec:
total: 7.00B
parameters:
f32: 7.00B
artifact:
provider: huggingface
cache:
revision: abc123de
size: 1.24 GB
cache_path: ~/.puma/cache/...
Expand Down
6 changes: 3 additions & 3 deletions src/api/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use tower::util::ServiceExt; // for `oneshot` and `ready`

use super::routes::create_router;
use crate::backend::mock::MockEngine;
use crate::registry::model_registry::{ArtifactInfo, ModelInfo, ModelMetadata, ModelRegistry};
use crate::registry::model_registry::{CacheInfo, ModelInfo, ModelMetadata, ModelRegistry};

/// Helper to create test app with a pre-registered test model
/// Returns the router and the temp directory (which must be kept alive)
Expand All @@ -27,15 +27,15 @@ fn create_test_app() -> (axum::Router, TempDir) {
let test_model = ModelInfo {
uuid: "test-uuid".to_string(),
name: "test-model".to_string(),
provider: "test".to_string(),
author: Some("test-author".to_string()),
task: Some("text-generation".to_string()),
model_series: Some("test-series".to_string()),
provider: "test".to_string(),
license: Some("MIT".to_string()),
created_at: chrono::Utc::now().to_rfc3339(),
updated_at: chrono::Utc::now().to_rfc3339(),
metadata: ModelMetadata {
artifact: ArtifactInfo {
cache: CacheInfo {
revision: "test-rev".to_string(),
size: 1000,
path: "/tmp/test-model".to_string(),
Expand Down
22 changes: 12 additions & 10 deletions src/cli/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ struct InspectArgs {
#[derive(Debug, Clone, Default, clap::ValueEnum)]
pub enum Provider {
#[default]
#[value(alias = "hf")]
Huggingface,
#[value(alias = "ms")]
Modelscope,
}

Expand Down Expand Up @@ -141,12 +143,12 @@ pub async fn run(cli: Cli) {
"MODEL", "TASK", "PROVIDER", "REVISION", "SIZE", "CREATED"
]);
for model in models {
let size_str = format_size_decimal(model.metadata.artifact.size);
let size_str = format_size_decimal(model.metadata.cache.size);

let revision_short = if model.metadata.artifact.revision.len() > 8 {
&model.metadata.artifact.revision[..8]
let revision_short = if model.metadata.cache.revision.len() > 8 {
&model.metadata.cache.revision[..8]
} else {
&model.metadata.artifact.revision
&model.metadata.cache.revision
};

let created_str = format_time_ago(&model.created_at);
Expand Down Expand Up @@ -230,7 +232,7 @@ pub async fn run(cli: Cli) {
#[cfg(test)]
mod tests {
use super::*;
use crate::registry::model_registry::{ArtifactInfo, ModelInfo, ModelMetadata};
use crate::registry::model_registry::{CacheInfo, ModelInfo, ModelMetadata};
use tempfile::TempDir;

// Helper to create a test model
Expand All @@ -245,15 +247,15 @@ mod tests {
ModelInfo {
uuid: revision.to_string(),
name: name.to_string(),
provider: "huggingface".to_string(),
author: Some("test-author".to_string()),
task: Some("text-generation".to_string()),
model_series: Some("gpt2".to_string()),
provider: "huggingface".to_string(),
license: Some("mit".to_string()),
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-01T00:00:00Z".to_string(),
metadata: ModelMetadata {
artifact: ArtifactInfo {
cache: CacheInfo {
revision: revision.to_string(),
size: 1000,
path: "/tmp/test".to_string(),
Expand Down Expand Up @@ -375,7 +377,7 @@ mod tests {

// Update the model
let mut updated_model = create_test_model("test/updated-model", "v2");
updated_model.metadata.artifact.size = 2000;
updated_model.metadata.cache.size = 2000;
updated_model.created_at = "2025-01-05T00:00:00Z".to_string();
updated_model.updated_at = "2025-01-05T00:00:00Z".to_string();

Expand All @@ -387,7 +389,7 @@ mod tests {
// updated_at should be new
assert_eq!(result.updated_at, "2025-01-05T00:00:00Z");
// Other fields should be updated
assert_eq!(result.metadata.artifact.revision, "v2");
assert_eq!(result.metadata.artifact.size, 2000);
assert_eq!(result.metadata.cache.revision, "v2");
assert_eq!(result.metadata.cache.size, 2000);
}
}
26 changes: 13 additions & 13 deletions src/cli/inspect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ pub fn display(model: &ModelInfo) {
" author: {}",
model.author.as_deref().unwrap_or("N/A")
);
println!(" provider: {}", model.provider);
println!(
" model_series: {}",
model.model_series.as_deref().unwrap_or("N/A")
);
println!(
" task: {}",
model.task.as_deref().unwrap_or("N/A")
Expand All @@ -31,10 +36,6 @@ pub fn display(model: &ModelInfo) {
.map(|s| s.to_uppercase())
.unwrap_or_else(|| "N/A".to_string())
);
println!(
" model_series: {}",
model.model_series.as_deref().unwrap_or("N/A")
);
println!(
" context_window: {}",
model
Expand Down Expand Up @@ -65,15 +66,14 @@ pub fn display(model: &ModelInfo) {
println!(" safetensors: N/A");
}

// Artifact section
println!(" artifact:");
println!(" provider: {}", model.provider);
println!(" revision: {}", model.metadata.artifact.revision);
// Cache section
println!(" cache:");
println!(" revision: {}", model.metadata.cache.revision);
println!(
" size: {}",
format_size_decimal(model.metadata.artifact.size)
format_size_decimal(model.metadata.cache.size)
);
println!(" cache_path: {}", model.metadata.artifact.path);
println!(" path: {}", model.metadata.cache.path);
println!("status:");
println!(" created: {}", format_time_ago(&model.created_at));
println!(" updated: {}", format_time_ago(&model.updated_at));
Expand All @@ -82,7 +82,7 @@ pub fn display(model: &ModelInfo) {
#[cfg(test)]
mod tests {
use super::*;
use crate::registry::model_registry::{ArtifactInfo, ModelInfo, ModelMetadata};
use crate::registry::model_registry::{CacheInfo, ModelInfo, ModelMetadata};
use tempfile::TempDir;

fn create_test_model(name: &str, uuid: &str) -> ModelInfo {
Expand All @@ -94,15 +94,15 @@ mod tests {
ModelInfo {
uuid: uuid.to_string(),
name: name.to_string(),
provider: "huggingface".to_string(),
author: Some("test-author".to_string()),
task: Some("text-generation".to_string()),
model_series: Some("gpt2".to_string()),
provider: "huggingface".to_string(),
license: Some("mit".to_string()),
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-01T00:00:00Z".to_string(),
metadata: ModelMetadata {
artifact: ArtifactInfo {
cache: CacheInfo {
revision: uuid.to_string(),
size: 1000,
path: "/tmp/test".to_string(),
Expand Down
6 changes: 3 additions & 3 deletions src/cli/ls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub fn execute(
#[cfg(test)]
mod tests {
use super::*;
use crate::registry::model_registry::{ArtifactInfo, ModelInfo, ModelMetadata};
use crate::registry::model_registry::{CacheInfo, ModelInfo, ModelMetadata};
use tempfile::TempDir;

fn create_test_model(name: &str, uuid: &str, author: &str) -> ModelInfo {
Expand All @@ -64,15 +64,15 @@ mod tests {
ModelInfo {
uuid: uuid.to_string(),
name: name.to_string(),
provider: "huggingface".to_string(),
author: Some(author.to_string()),
task: Some("text-generation".to_string()),
model_series: Some("gpt2".to_string()),
provider: "huggingface".to_string(),
license: Some("mit".to_string()),
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-01T00:00:00Z".to_string(),
metadata: ModelMetadata {
artifact: ArtifactInfo {
cache: CacheInfo {
revision: uuid.to_string(),
size: 1000,
path: "/tmp/test".to_string(),
Expand Down
8 changes: 4 additions & 4 deletions src/cli/rm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub fn execute(registry: &ModelRegistry, model_name: &str) -> Result<(), String>
#[cfg(test)]
mod tests {
use super::*;
use crate::registry::model_registry::{ArtifactInfo, ModelInfo, ModelMetadata};
use crate::registry::model_registry::{CacheInfo, ModelInfo, ModelMetadata};
use tempfile::TempDir;

fn create_test_model(name: &str, uuid: &str) -> ModelInfo {
Expand All @@ -26,15 +26,15 @@ mod tests {
ModelInfo {
uuid: uuid.to_string(),
name: name.to_string(),
provider: "huggingface".to_string(),
author: Some("test-author".to_string()),
task: Some("text-generation".to_string()),
model_series: Some("gpt2".to_string()),
provider: "huggingface".to_string(),
license: Some("mit".to_string()),
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-01T00:00:00Z".to_string(),
metadata: ModelMetadata {
artifact: ArtifactInfo {
cache: CacheInfo {
revision: uuid.to_string(),
size: 1000,
path: "/tmp/test".to_string(),
Expand All @@ -55,7 +55,7 @@ mod tests {
std::fs::write(cache_dir.join("model.safetensors"), "fake data").unwrap();

let mut model = create_test_model("test/remove-model", "abc123");
model.metadata.artifact.path = cache_dir.to_string_lossy().to_string();
model.metadata.cache.path = cache_dir.to_string_lossy().to_string();

registry.register_model(model).unwrap();
assert!(registry.get_model("test/remove-model").unwrap().is_some());
Expand Down
8 changes: 4 additions & 4 deletions src/downloader/huggingface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use indicatif::{ProgressBar, ProgressStyle};

use crate::downloader::downloader::{DownloadError, Downloader};
use crate::downloader::progress::{DownloadProgressManager, FileProgress};
use crate::registry::model_registry::{ArtifactInfo, ModelInfo, ModelMetadata, ModelRegistry};
use crate::registry::model_registry::{CacheInfo, ModelInfo, ModelMetadata, ModelRegistry};
use crate::utils::file::{self, format_model_name};

/// Adapter to bridge HuggingFace's Progress trait with our FileProgress
Expand Down Expand Up @@ -284,14 +284,14 @@ impl Downloader for HuggingFaceDownloader {
let model_size =
storage_from_api.unwrap_or_else(|| progress_manager.total_downloaded_bytes());

let artifact = ArtifactInfo {
let cache = CacheInfo {
revision: sha.clone(),
size: model_size,
path: model_cache_path.to_string_lossy().to_string(),
};

let metadata = ModelMetadata {
artifact,
cache,
context_window,
safetensors: safetensors_from_api,
};
Expand All @@ -300,10 +300,10 @@ impl Downloader for HuggingFaceDownloader {
let model_info_record = ModelInfo {
uuid: sha, // Use revision SHA as UUID for now
name: name.to_string(),
provider: "huggingface".to_string(),
author: author_from_api,
task: task_from_api,
model_series: model_series_from_api,
provider: "huggingface".to_string(),
license: license_from_api,
created_at: now.clone(),
updated_at: now,
Expand Down
5 changes: 3 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ fn main() {
// Setup tracing subscriber for tower-http TraceLayer
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "info,hf_hub=warn,tower_http=info".into()),
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
"info,hf_hub=warn,tower_http=info,rusqlite_migration=warn".into()
}),
)
.init();

Expand Down
Loading
Loading