Skip to content

Commit 255df82

Browse files
authored
Optimize the commands (#31)
* Optimize the commands Signed-off-by: kerthcet <kerthcet@gmail.com> * fix lint Signed-off-by: kerthcet <kerthcet@gmail.com> --------- Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 9c2c93f commit 255df82

3 files changed

Lines changed: 180 additions & 211 deletions

File tree

src/cli/commands.rs

Lines changed: 61 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -185,33 +185,56 @@ pub async fn run(cli: Cli) {
185185
Ok(Some(model)) => {
186186
println!("Name: {}", model.name);
187187
println!("Kind: Model");
188-
println!("Metadata:");
189-
println!(" Created: {}", format_time_ago(&model.created_at));
190-
println!(" Updated: {}", format_time_ago(&model.updated_at));
191-
192188
println!("Spec:");
193-
// Architecture section (only if info is available)
194-
if let Some(arch) = &model.arch {
195-
println!(" Architecture:");
196-
if let Some(model_type) = &arch.model_type {
197-
println!(" Type: {}", model_type);
198-
}
199-
if let Some(classes) = &arch.classes {
200-
println!(" Classes: {}", classes.join(", "));
201-
}
202-
if let Some(parameters) = &arch.parameters {
203-
println!(" Parameters: {}", parameters);
204-
}
205-
if let Some(context_window) = arch.context_window {
206-
println!(" Context Window: {}", context_window);
207-
}
189+
if let Some(spec) = &model.spec {
190+
println!(
191+
" Author: {}",
192+
spec.author.as_deref().unwrap_or("N/A")
193+
);
194+
println!(
195+
" Task: {}",
196+
spec.task.as_deref().unwrap_or("N/A")
197+
);
198+
println!(
199+
" License: {}",
200+
spec.license
201+
.as_ref()
202+
.map(|s| s.to_uppercase())
203+
.unwrap_or_else(|| "N/A".to_string())
204+
);
205+
println!(
206+
" Model Type: {}",
207+
spec.model_type.as_deref().unwrap_or("N/A")
208+
);
209+
println!(
210+
" Parameters: {}",
211+
spec.parameters
212+
.map(crate::utils::format::format_parameters)
213+
.unwrap_or_else(|| "N/A".to_string())
214+
);
215+
println!(
216+
" Context Window: {}",
217+
spec.context_window
218+
.map(|w| crate::utils::format::format_parameters(w as u64))
219+
.unwrap_or_else(|| "N/A".to_string())
220+
);
221+
} else {
222+
println!(" Author: N/A");
223+
println!(" Task: N/A");
224+
println!(" License: N/A");
225+
println!(" Model Type: N/A");
226+
println!(" Parameters: N/A");
227+
println!(" Context Window: N/A");
208228
}
209229
// Registry section
210230
println!(" Registry:");
211231
println!(" Provider: {}", model.provider);
212232
println!(" Revision: {}", model.revision);
213233
println!(" Size: {}", format_size_decimal(model.size));
214234
println!(" Cache Path: {}", model.cache_path);
235+
println!("Status:");
236+
println!(" Created: {}", format_time_ago(&model.created_at));
237+
println!(" Updated: {}", format_time_ago(&model.updated_at));
215238
}
216239
Ok(None) => {
217240
eprintln!("Model not found: {}", args.model);
@@ -233,7 +256,7 @@ pub async fn run(cli: Cli) {
233256
#[cfg(test)]
234257
mod tests {
235258
use super::*;
236-
use crate::registry::model_registry::{ModelArchitecture, ModelInfo};
259+
use crate::registry::model_registry::ModelInfo;
237260
use tempfile::TempDir;
238261

239262
#[test]
@@ -258,7 +281,7 @@ mod tests {
258281
created_at: "2025-01-01T00:00:00Z".to_string(),
259282
updated_at: "2025-01-01T00:00:00Z".to_string(),
260283
cache_path: "/tmp/test".to_string(),
261-
arch: None,
284+
spec: None,
262285
};
263286

264287
registry.register_model(model).unwrap();
@@ -282,11 +305,13 @@ mod tests {
282305
created_at: "2025-01-01T00:00:00Z".to_string(),
283306
updated_at: "2025-01-02T00:00:00Z".to_string(),
284307
cache_path: "/tmp/test/gpt".to_string(),
285-
arch: Some(ModelArchitecture {
308+
spec: Some(crate::registry::model_registry::ModelSpec {
286309
model_type: Some("gpt2".to_string()),
287-
classes: Some(vec!["GPT2LMHeadModel".to_string()]),
310+
parameters: Some(7_000_000_000),
288311
context_window: Some(2048),
289-
parameters: Some("7.00B".to_string()),
312+
author: Some("test-org".to_string()),
313+
task: Some("text-generation".to_string()),
314+
license: Some("mit".to_string()),
290315
}),
291316
};
292317

@@ -300,11 +325,13 @@ mod tests {
300325
assert_eq!(model_info.created_at, "2025-01-01T00:00:00Z");
301326
assert_eq!(model_info.updated_at, "2025-01-02T00:00:00Z");
302327

303-
let arch = model_info.arch.unwrap();
304-
assert_eq!(arch.model_type, Some("gpt2".to_string()));
305-
assert_eq!(arch.classes, Some(vec!["GPT2LMHeadModel".to_string()]));
306-
assert_eq!(arch.context_window, Some(2048));
307-
assert_eq!(arch.parameters, Some("7.00B".to_string()));
328+
let spec = model_info.spec.as_ref().unwrap();
329+
assert_eq!(spec.author, Some("test-org".to_string()));
330+
assert_eq!(spec.task, Some("text-generation".to_string()));
331+
assert_eq!(spec.license, Some("mit".to_string()));
332+
assert_eq!(spec.model_type, Some("gpt2".to_string()));
333+
assert_eq!(spec.context_window, Some(2048));
334+
assert_eq!(spec.parameters, Some(7_000_000_000));
308335
}
309336

310337
#[test]
@@ -320,7 +347,7 @@ mod tests {
320347
created_at: "2025-01-01T00:00:00Z".to_string(),
321348
updated_at: "2025-01-01T00:00:00Z".to_string(),
322349
cache_path: "/tmp/test/simple".to_string(),
323-
arch: None,
350+
spec: None,
324351
};
325352

326353
registry.register_model(model).unwrap();
@@ -330,7 +357,7 @@ mod tests {
330357

331358
let model_info = retrieved.unwrap();
332359
assert_eq!(model_info.name, "test/simple-model");
333-
assert!(model_info.arch.is_none());
360+
assert!(model_info.spec.is_none());
334361
}
335362

336363
#[test]
@@ -346,7 +373,7 @@ mod tests {
346373
created_at: "2025-01-01T00:00:00Z".to_string(),
347374
updated_at: "2025-01-01T00:00:00Z".to_string(),
348375
cache_path: "/tmp/test/remove".to_string(),
349-
arch: None,
376+
spec: None,
350377
};
351378

352379
registry.register_model(model).unwrap();
@@ -400,7 +427,7 @@ mod tests {
400427
created_at: "2025-01-01T00:00:00Z".to_string(),
401428
updated_at: "2025-01-01T00:00:00Z".to_string(),
402429
cache_path: "/tmp/test".to_string(),
403-
arch: None,
430+
spec: None,
404431
};
405432

406433
registry.register_model(model).unwrap();
@@ -414,7 +441,7 @@ mod tests {
414441
created_at: "2025-01-05T00:00:00Z".to_string(),
415442
updated_at: "2025-01-05T00:00:00Z".to_string(),
416443
cache_path: "/tmp/test".to_string(),
417-
arch: None,
444+
spec: None,
418445
};
419446

420447
registry.register_model(updated_model).unwrap();

src/downloader/huggingface.rs

Lines changed: 92 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use indicatif::{ProgressBar, ProgressStyle};
66

77
use crate::downloader::downloader::{DownloadError, Downloader};
88
use crate::downloader::progress::{DownloadProgressManager, FileProgress};
9-
use crate::registry::model_registry::{ModelArchitecture, ModelInfo, ModelRegistry};
9+
use crate::registry::model_registry::{ModelInfo, ModelRegistry, ModelSpec};
1010
use crate::utils::file::{self, format_model_name};
1111

1212
/// Adapter to bridge HuggingFace's Progress trait with our FileProgress
@@ -35,6 +35,60 @@ impl HuggingFaceDownloader {
3535
/// Creates a new `HuggingFaceDownloader`.
///
/// The constructor takes no arguments and returns the bare `Self` value,
/// so there is no state to initialise.
pub fn new() -> Self {
3636
Self
3737
}
38+
39+
/// Fetches model metadata from the HuggingFace model API.
///
/// Sends a GET request to `https://huggingface.co/api/models/{model_name}`
/// and extracts selected fields from the JSON response.
///
/// # Returns
///
/// A tuple `(author, task, license, model_type, parameters, storage)`:
///
/// * `author` - top-level `author` string.
/// * `task` - top-level `pipeline_tag` string.
/// * `license` - `cardData.license` string.
/// * `model_type` - `config.model_type` string.
/// * `parameters` - `safetensors.total` as an unsigned integer.
/// * `storage` - top-level `usedStorage` as an unsigned integer.
///
/// Each element is `None` when its key is missing or holds a different JSON
/// type. All six elements are `None` when the request fails or the response
/// body cannot be parsed as JSON; no error is surfaced to the caller.
///
/// NOTE(review): the HTTP status code is not inspected here, so an error
/// response carrying a JSON body would presumably be parsed like a success
/// and simply yield `None` fields - confirm this best-effort behaviour is
/// intended.
async fn fetch_metadata_from_api(
40+
model_name: &str,
41+
) -> (
42+
Option<String>,
43+
Option<String>,
44+
Option<String>,
45+
Option<String>,
46+
Option<u64>,
47+
Option<u64>,
48+
) {
49+
let url = format!("https://huggingface.co/api/models/{}", model_name);
50+
let client = reqwest::Client::new();
51+
52+
match client.get(&url).send().await {
53+
Ok(response) => {
54+
// Parse into an untyped JSON value so that absent fields degrade to `None`.
if let Ok(json) = response.json::<serde_json::Value>().await {
55+
let author = json
56+
.get("author")
57+
.and_then(|v| v.as_str())
58+
.map(|s| s.to_string());
59+
60+
// The task is read from the `pipeline_tag` key.
let task = json
61+
.get("pipeline_tag")
62+
.and_then(|v| v.as_str())
63+
.map(|s| s.to_string());
64+
65+
// The license is nested one level down, under `cardData`.
let license = json
66+
.get("cardData")
67+
.and_then(|card| card.get("license"))
68+
.and_then(|v| v.as_str())
69+
.map(|s| s.to_string());
70+
71+
let model_type = json
72+
.get("config")
73+
.and_then(|config| config.get("model_type"))
74+
.and_then(|v| v.as_str())
75+
.map(|s| s.to_string());
76+
77+
// Presumably the total parameter count reported for safetensors weights - TODO confirm against the Hub API docs.
let parameters = json
78+
.get("safetensors")
79+
.and_then(|st| st.get("total"))
80+
.and_then(|v| v.as_u64());
81+
82+
// Presumably the repository's storage footprint in bytes - TODO confirm units.
let storage = json.get("usedStorage").and_then(|v| v.as_u64());
83+
84+
(author, task, license, model_type, parameters, storage)
85+
} else {
86+
// Body could not be decoded as JSON: report no metadata.
(None, None, None, None, None, None)
87+
}
88+
}
89+
// The request itself failed: report no metadata.
Err(_) => (None, None, None, None, None, None),
90+
}
91+
}
3892
}
3993

4094
impl Default for HuggingFaceDownloader {
@@ -195,34 +249,63 @@ impl Downloader for HuggingFaceDownloader {
195249
}
196250

197251
let elapsed_time = start_time.elapsed();
198-
199-
// Get accumulated size from downloads
200-
let downloaded_size = progress_manager.total_downloaded_bytes();
201252
let model_cache_path = cache_dir.join(format_model_name(name));
202253

203254
// Register the model only if not totally cached
204255
if !model_totally_cached {
205-
// Extract architecture info from config.json
256+
// Fetch metadata from HuggingFace API
257+
let (
258+
author_from_api,
259+
task_from_api,
260+
license_from_api,
261+
model_type_from_api,
262+
parameters_from_api,
263+
storage_from_api,
264+
) = Self::fetch_metadata_from_api(name).await;
265+
266+
// Extract context_window from config.json
206267
let config_path = snapshot_path.join("config.json");
207-
let arch = if config_path.exists() {
268+
let context_window = if config_path.exists() {
208269
std::fs::read_to_string(&config_path)
209270
.ok()
210271
.and_then(|content| serde_json::from_str::<serde_json::Value>(&content).ok())
211-
.and_then(|config| ModelArchitecture::from_config(&config))
272+
.and_then(|config| {
273+
config
274+
.get("text_config")
275+
.and_then(|tc| tc.get("max_position_embeddings"))
276+
.or_else(|| config.get("max_position_embeddings"))
277+
.or_else(|| config.get("n_positions"))
278+
.or_else(|| config.get("n_ctx"))
279+
.and_then(|v| v.as_u64())
280+
.map(|v| v as u32)
281+
})
212282
} else {
213283
None
214284
};
215285

286+
let spec = Some(ModelSpec {
287+
author: author_from_api,
288+
task: task_from_api,
289+
license: license_from_api,
290+
model_type: model_type_from_api,
291+
parameters: parameters_from_api,
292+
context_window,
293+
});
294+
295+
// Use storage from API, fallback to accumulated download size
296+
let model_size =
297+
storage_from_api.unwrap_or_else(|| progress_manager.total_downloaded_bytes());
298+
216299
let now = chrono::Local::now().to_rfc3339();
217300
let model_info_record = ModelInfo {
218301
name: name.to_string(),
219302
provider: "huggingface".to_string(),
220303
revision: sha,
221-
size: downloaded_size,
304+
size: model_size,
222305
created_at: now.clone(),
223306
updated_at: now,
224307
cache_path: model_cache_path.to_string_lossy().to_string(),
225-
arch,
308+
spec,
226309
};
227310

228311
let registry = ModelRegistry::new(None);

0 commit comments

Comments
 (0)