Skip to content

Commit a0ebe7e

Browse files
authored
Support inspect command (#29)
* add support for inspect Signed-off-by: kerthcet <kerthcet@gmail.com> * add support for inspect Signed-off-by: kerthcet <kerthcet@gmail.com> * add pull progress bar Signed-off-by: kerthcet <kerthcet@gmail.com> * polish the download progress Signed-off-by: kerthcet <kerthcet@gmail.com> * reorganize the structure Signed-off-by: kerthcet <kerthcet@gmail.com> * optimize the structure Signed-off-by: kerthcet <kerthcet@gmail.com> * fix test Signed-off-by: kerthcet <kerthcet@gmail.com> * fix lint Signed-off-by: kerthcet <kerthcet@gmail.com> --------- Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent bed7e59 commit a0ebe7e

7 files changed

Lines changed: 444 additions & 22 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ PUMA Information:
6868
| `stop` | 🚧 | Stop a running model | `puma stop <model-id>` |
6969
| `rm` || Remove a model | `puma rm InftyAI/tiny-random-gpt2` |
7070
| `info` || Display system-wide information | `puma info` |
71-
| `inspect` | 🚧 | Return detailed information about a model or service | `puma inspect InftyAI/tiny-random-gpt2` |
71+
| `inspect` | | Return detailed information about a model or service | `puma inspect InftyAI/tiny-random-gpt2` |
7272
| `version` || Show PUMA version | `puma version` |
7373
| `help` || Show help information | `puma help` |
7474

src/cli/commands.rs

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ enum Commands {
3333
/// Display system-wide information
3434
INFO,
3535
/// Return detailed information about a model
36-
INSPECT,
36+
INSPECT(InspectArgs),
3737
/// Returns the version of PUMA.
3838
VERSION,
3939
}
@@ -58,6 +58,12 @@ struct RmArgs {
5858
model: String,
5959
}
6060

61+
#[derive(Parser)]
62+
struct InspectArgs {
63+
/// Model name to inspect (e.g., InftyAI/tiny-random-gpt2)
64+
model: String,
65+
}
66+
6167
#[derive(Debug, Clone, Default, clap::ValueEnum)]
6268
pub enum Provider {
6369
#[default]
@@ -70,7 +76,12 @@ pub async fn run(cli: Cli) {
7076
match cli.command {
7177
Commands::PS => {
7278
let mut table = Table::new();
73-
table.set_format(*format::consts::FORMAT_CLEAN);
79+
table.set_format(
80+
format::FormatBuilder::new()
81+
.column_separator(' ')
82+
.padding(0, 1)
83+
.build(),
84+
);
7485
table.add_row(row!["NAME", "PROVIDER", "MODEL", "STATUS", "AGE"]);
7586
table.add_row(row![
7687
"deepseek-r1",
@@ -88,7 +99,12 @@ pub async fn run(cli: Cli) {
8899
let models = registry.load_models().unwrap_or_default();
89100

90101
let mut table = Table::new();
91-
table.set_format(*format::consts::FORMAT_CLEAN);
102+
table.set_format(
103+
format::FormatBuilder::new()
104+
.column_separator(' ')
105+
.padding(0, 1)
106+
.build(),
107+
);
92108
table.add_row(row!["MODEL", "PROVIDER", "REVISION", "SIZE", "MODIFIED"]);
93109

94110
for model in models {
@@ -163,8 +179,51 @@ pub async fn run(cli: Cli) {
163179
info.display();
164180
}
165181

166-
Commands::INSPECT => {
167-
println!("Returning detailed information about model...");
182+
Commands::INSPECT(args) => {
183+
let registry = ModelRegistry::new(None);
184+
185+
match registry.get_model(&args.model) {
186+
Ok(Some(model)) => {
187+
println!("Name: {}", model.name);
188+
println!("Kind: Model");
189+
190+
println!("Spec:");
191+
// Architecture section (only if info is available)
192+
if let Some(arch) = &model.arch {
193+
println!(" Architecture:");
194+
if let Some(model_type) = &arch.model_type {
195+
println!(" Type: {}", model_type);
196+
}
197+
if let Some(classes) = &arch.classes {
198+
println!(" Classes: {}", classes.join(", "));
199+
}
200+
if let Some(parameters) = &arch.parameters {
201+
println!(" Parameters: {}", parameters);
202+
}
203+
if let Some(context_window) = arch.context_window {
204+
println!(" Context Window: {}", context_window);
205+
}
206+
}
207+
// Registry section
208+
println!(" Registry:");
209+
println!(" Provider: {}", model.provider);
210+
println!(" Revision: {}", model.revision);
211+
println!(" Size: {}", format_size_decimal(model.size));
212+
println!(
213+
" Modified: {}",
214+
format_time_ago(&model.modified_at)
215+
);
216+
println!(" Cache Path: {}", model.cache_path);
217+
}
218+
Ok(None) => {
219+
eprintln!("Model not found: {}", args.model);
220+
std::process::exit(1);
221+
}
222+
Err(e) => {
223+
eprintln!("Failed to load registry: {}", e);
224+
std::process::exit(1);
225+
}
226+
}
168227
}
169228

170229
Commands::VERSION => {

src/downloader/huggingface.rs

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ use colored::Colorize;
22
use log::debug;
33

44
use hf_hub::api::tokio::{ApiBuilder, Progress};
5+
use indicatif::{ProgressBar, ProgressStyle};
56

67
use crate::downloader::downloader::{DownloadError, Downloader};
78
use crate::downloader::progress::{DownloadProgressManager, FileProgress};
8-
use crate::registry::model_registry::{ModelInfo, ModelRegistry};
9+
use crate::registry::model_registry::{ModelArchitecture, ModelInfo, ModelRegistry};
910
use crate::utils::file::{self, format_model_name};
1011

1112
/// Adapter to bridge HuggingFace's Progress trait with our FileProgress
@@ -62,7 +63,15 @@ impl Downloader for HuggingFaceDownloader {
6263
DownloadError::ApiError(format!("Failed to initialize Hugging Face API: {}", e))
6364
})?;
6465

65-
println!("🐆 pulling manifest");
66+
// Create a simple spinner for manifest pulling
67+
let manifest_spinner = ProgressBar::new_spinner();
68+
manifest_spinner.set_style(
69+
ProgressStyle::default_spinner()
70+
.tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏")
71+
.template("pulling manifest {spinner:.white}")
72+
.unwrap(),
73+
);
74+
manifest_spinner.enable_steady_tick(std::time::Duration::from_millis(80));
6675

6776
// Download the entire model repository using snapshot download
6877
let repo = api.model(name.to_string());
@@ -81,6 +90,10 @@ impl Downloader for HuggingFaceDownloader {
8190
}
8291
})?;
8392

93+
// Stop manifest spinner and print clean message
94+
manifest_spinner.finish_and_clear();
95+
println!("pulling manifest");
96+
8497
debug!("Model info for {}: {:?}", name, model_info);
8598

8699
// Calculate the longest filename for proper alignment
@@ -91,6 +104,8 @@ impl Downloader for HuggingFaceDownloader {
91104
.max()
92105
.unwrap_or(30);
93106

107+
// Add extra space for "pulling " prefix
108+
let max_filename_len = max_filename_len + 8;
94109
// Create progress manager
95110
let progress_manager = DownloadProgressManager::new(max_filename_len);
96111

@@ -124,8 +139,9 @@ impl Downloader for HuggingFaceDownloader {
124139
debug!("File {} found in cache, showing as complete", filename);
125140

126141
// Create progress bar for cached file (no speed display)
142+
let display_name = format!("pulling {}", filename);
127143
let mut file_progress =
128-
progress_manager_clone.create_cached_file_progress(&filename);
144+
progress_manager_clone.create_cached_file_progress(&display_name);
129145
let file_size = cached_file_path.metadata().map(|m| m.len()).unwrap_or(0);
130146
file_progress.init(file_size);
131147
file_progress.update(file_size);
@@ -136,7 +152,8 @@ impl Downloader for HuggingFaceDownloader {
136152

137153
// File not in cache, download with progress
138154
debug!("Downloading: {}", filename);
139-
let file_progress = progress_manager_clone.create_file_progress(&filename);
155+
let display_name = format!("pulling {}", filename);
156+
let file_progress = progress_manager_clone.create_file_progress(&display_name);
140157
let progress = HfProgressAdapter {
141158
progress: file_progress,
142159
};
@@ -156,37 +173,65 @@ impl Downloader for HuggingFaceDownloader {
156173
tasks.push(task);
157174
}
158175

176+
// Give tasks a moment to start and create their progress bars
177+
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
178+
179+
// Show spinner at the bottom after all progress bars are created (only if not fully cached)
180+
let spinner = if !model_totally_cached {
181+
Some(progress_manager.create_spinner())
182+
} else {
183+
None
184+
};
185+
159186
// Wait for all downloads to complete
160187
for task in tasks {
161188
task.await
162189
.map_err(|e| DownloadError::ApiError(format!("Task join error: {}", e)))??;
163190
}
164191

192+
// Finish spinner after downloads complete
193+
if let Some(spinner) = &spinner {
194+
spinner.finish_and_clear();
195+
}
196+
165197
let elapsed_time = start_time.elapsed();
166198

167199
// Get accumulated size from downloads
168200
let downloaded_size = progress_manager.total_downloaded_bytes();
169201
let model_cache_path = cache_dir.join(format_model_name(name));
170202

171-
// Register the model
172-
let model_info_record = ModelInfo {
173-
name: name.to_string(),
174-
provider: "huggingface".to_string(),
175-
revision: sha,
176-
size: downloaded_size,
177-
modified_at: chrono::Local::now().to_rfc3339(),
178-
cache_path: model_cache_path.to_string_lossy().to_string(),
179-
};
180-
203+
// Register the model only if not totally cached
181204
if !model_totally_cached {
205+
// Extract architecture info from config.json
206+
let config_path = snapshot_path.join("config.json");
207+
let arch = if config_path.exists() {
208+
std::fs::read_to_string(&config_path)
209+
.ok()
210+
.and_then(|content| serde_json::from_str::<serde_json::Value>(&content).ok())
211+
.and_then(|config| ModelArchitecture::from_config(&config))
212+
} else {
213+
None
214+
};
215+
216+
let model_info_record = ModelInfo {
217+
name: name.to_string(),
218+
provider: "huggingface".to_string(),
219+
revision: sha,
220+
size: downloaded_size,
221+
modified_at: chrono::Local::now().to_rfc3339(),
222+
cache_path: model_cache_path.to_string_lossy().to_string(),
223+
arch,
224+
};
225+
182226
let registry = ModelRegistry::new(None);
183227
registry
184228
.register_model(model_info_record)
185229
.map_err(|e| DownloadError::ApiError(format!("Failed to register model: {}", e)))?;
186230
}
187231

232+
// Print success message
188233
println!(
189-
"\n{} {} {} {} {:.2?}",
234+
"{} {} {} {} {:.2?}",
190235
"✓".green().bold(),
191236
"Successfully downloaded model".bright_white(),
192237
name.cyan().bold(),

src/downloader/progress.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,19 @@ impl DownloadProgressManager {
8585
pub fn total_downloaded_bytes(&self) -> u64 {
8686
self.total_size.load(Ordering::Relaxed)
8787
}
88+
89+
/// Create a spinner progress bar (for post-download operations)
90+
pub fn create_spinner(&self) -> ProgressBar {
91+
let pb = self.multi_progress.add(ProgressBar::new_spinner());
92+
pb.set_style(
93+
ProgressStyle::default_spinner()
94+
.tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏")
95+
.template("{spinner} ")
96+
.unwrap(),
97+
);
98+
pb.enable_steady_tick(std::time::Duration::from_millis(80));
99+
pb
100+
}
88101
}
89102

90103
/// Tracks progress for a single file download

0 commit comments

Comments
 (0)