Skip to content

Commit d07d6b1

Browse files
committed
reorganize the structure
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 8695655 commit d07d6b1

3 files changed

Lines changed: 153 additions & 80 deletions

File tree

src/downloader/huggingface.rs

Lines changed: 4 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@ use crate::downloader::downloader::{DownloadError, Downloader};
88
use crate::downloader::progress::{DownloadProgressManager, FileProgress};
99
use crate::registry::model_registry::{ModelInfo, ModelRegistry, ModelSpec};
1010
use crate::utils::file::{self, format_model_name};
11-
use crate::utils::format::format_parameters;
1211

1312
/// Adapter to bridge HuggingFace's Progress trait with our FileProgress
1413
#[derive(Clone)]
@@ -36,36 +35,6 @@ impl HuggingFaceDownloader {
3635
pub fn new() -> Self {
3736
Self
3837
}
39-
40-
fn estimate_parameters(config: &serde_json::Value) -> Option<String> {
41-
// Try to extract architecture dimensions for parameter estimation
42-
let n_layer = config
43-
.get("n_layer")
44-
.or_else(|| config.get("num_hidden_layers"))
45-
.and_then(|v| v.as_u64())?;
46-
47-
let n_embd = config
48-
.get("n_embd")
49-
.or_else(|| config.get("hidden_size"))
50-
.and_then(|v| v.as_u64())?;
51-
52-
let vocab_size = config.get("vocab_size").and_then(|v| v.as_u64())?;
53-
54-
let n_positions = config
55-
.get("n_positions")
56-
.or_else(|| config.get("max_position_embeddings"))
57-
.and_then(|v| v.as_u64())
58-
.unwrap_or(2048);
59-
60-
// Rough parameter estimation for transformer models
61-
// Each layer: ~12 * n_embd^2 (attention + FFN)
62-
// Embeddings: vocab_size * n_embd + n_positions * n_embd
63-
let layer_params = 12 * n_layer * n_embd * n_embd;
64-
let embedding_params = vocab_size * n_embd + n_positions * n_embd;
65-
let total_params = layer_params + embedding_params;
66-
67-
Some(format_parameters(total_params))
68-
}
6938
}
7039

7140
impl Default for HuggingFaceDownloader {
@@ -236,54 +205,10 @@ impl Downloader for HuggingFaceDownloader {
236205
// Extract architecture info from config.json
237206
let config_path = snapshot_path.join("config.json");
238207
let spec = if config_path.exists() {
239-
match std::fs::read_to_string(&config_path) {
240-
Ok(config_content) => {
241-
match serde_json::from_str::<serde_json::Value>(&config_content) {
242-
Ok(config) => {
243-
let model_type = config
244-
.get("model_type")
245-
.and_then(|v| v.as_str())
246-
.map(|s| s.to_string());
247-
248-
let architectures = config
249-
.get("architectures")
250-
.and_then(|v| v.as_array())
251-
.map(|arr| {
252-
arr.iter()
253-
.filter_map(|v| v.as_str().map(|s| s.to_string()))
254-
.collect::<Vec<String>>()
255-
})
256-
.filter(|v| !v.is_empty());
257-
258-
let context_window = config
259-
.get("max_position_embeddings")
260-
.or_else(|| config.get("n_ctx"))
261-
.or_else(|| config.get("n_positions"))
262-
.and_then(|v| v.as_u64())
263-
.map(|v| v as u32);
264-
265-
let parameters = Self::estimate_parameters(&config);
266-
267-
if model_type.is_some()
268-
|| architectures.is_some()
269-
|| context_window.is_some()
270-
|| parameters.is_some()
271-
{
272-
Some(ModelSpec {
273-
model_type,
274-
architectures,
275-
context_window,
276-
parameters,
277-
})
278-
} else {
279-
None
280-
}
281-
}
282-
Err(_) => None,
283-
}
284-
}
285-
Err(_) => None,
286-
}
208+
std::fs::read_to_string(&config_path)
209+
.ok()
210+
.and_then(|content| serde_json::from_str::<serde_json::Value>(&content).ok())
211+
.and_then(|config| ModelSpec::from_config(&config))
287212
} else {
288213
None
289214
};

src/registry/model_registry.rs

Lines changed: 147 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,9 @@ use std::fs;
44
use std::path::PathBuf;
55

66
use crate::utils::file;
7+
use crate::utils::format::format_parameters;
78

8-
#[derive(Debug, Serialize, Deserialize, Clone)]
9+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
910
pub struct ModelSpec {
1011
#[serde(skip_serializing_if = "Option::is_none")]
1112
pub model_type: Option<String>,
@@ -17,6 +18,80 @@ pub struct ModelSpec {
1718
pub parameters: Option<String>,
1819
}
1920

21+
impl ModelSpec {
22+
/// Extract model spec from config.json
23+
pub fn from_config(config: &serde_json::Value) -> Option<Self> {
24+
let model_type = config
25+
.get("model_type")
26+
.and_then(|v| v.as_str())
27+
.map(|s| s.to_string());
28+
29+
let architectures = config
30+
.get("architectures")
31+
.and_then(|v| v.as_array())
32+
.map(|arr| {
33+
arr.iter()
34+
.filter_map(|v| v.as_str().map(|s| s.to_string()))
35+
.collect::<Vec<String>>()
36+
})
37+
.filter(|v| !v.is_empty());
38+
39+
let context_window = config
40+
.get("n_positions")
41+
.or_else(|| config.get("max_position_embeddings"))
42+
.or_else(|| config.get("n_ctx"))
43+
.and_then(|v| v.as_u64())
44+
.map(|v| v as u32);
45+
46+
let parameters = Self::estimate_parameters(config);
47+
48+
if model_type.is_some()
49+
|| architectures.is_some()
50+
|| context_window.is_some()
51+
|| parameters.is_some()
52+
{
53+
Some(ModelSpec {
54+
model_type,
55+
architectures,
56+
context_window,
57+
parameters,
58+
})
59+
} else {
60+
None
61+
}
62+
}
63+
64+
/// Estimate model parameters from config
65+
fn estimate_parameters(config: &serde_json::Value) -> Option<String> {
66+
let n_layer = config
67+
.get("n_layer")
68+
.or_else(|| config.get("num_hidden_layers"))
69+
.and_then(|v| v.as_u64())?;
70+
71+
let n_embd = config
72+
.get("n_embd")
73+
.or_else(|| config.get("hidden_size"))
74+
.and_then(|v| v.as_u64())?;
75+
76+
let vocab_size = config.get("vocab_size").and_then(|v| v.as_u64())?;
77+
78+
let n_positions = config
79+
.get("n_positions")
80+
.or_else(|| config.get("max_position_embeddings"))
81+
.and_then(|v| v.as_u64())
82+
.unwrap_or(2048);
83+
84+
// Rough parameter estimation for transformer models
85+
// Each layer: ~12 * n_embd^2 (attention + FFN)
86+
// Embeddings: vocab_size * n_embd + n_positions * n_embd
87+
let layer_params = 12 * n_layer * n_embd * n_embd;
88+
let embedding_params = vocab_size * n_embd + n_positions * n_embd;
89+
let total_params = layer_params + embedding_params;
90+
91+
Some(format_parameters(total_params))
92+
}
93+
}
94+
2095
#[derive(Debug, Serialize, Deserialize, Clone)]
2196
pub struct ModelInfo {
2297
pub name: String,
@@ -349,4 +424,75 @@ mod tests {
349424
assert_eq!(model_info.name, "test/legacy-model");
350425
assert!(model_info.spec.is_none());
351426
}
427+
428+
#[test]
429+
fn test_model_spec_from_config_gpt2() {
430+
use serde_json::json;
431+
432+
let config = json!({
433+
"model_type": "gpt2",
434+
"architectures": ["GPT2LMHeadModel"],
435+
"n_layer": 5,
436+
"n_embd": 32,
437+
"vocab_size": 1000,
438+
"n_positions": 512
439+
});
440+
441+
let spec = ModelSpec::from_config(&config);
442+
assert!(spec.is_some());
443+
444+
let spec = spec.unwrap();
445+
assert_eq!(spec.model_type, Some("gpt2".to_string()));
446+
assert_eq!(spec.architectures, Some(vec!["GPT2LMHeadModel".to_string()]));
447+
assert_eq!(spec.context_window, Some(512));
448+
assert_eq!(spec.parameters, Some("109.82K".to_string()));
449+
}
450+
451+
#[test]
452+
fn test_model_spec_from_config_bert_style() {
453+
use serde_json::json;
454+
455+
let config = json!({
456+
"model_type": "bert",
457+
"num_hidden_layers": 12,
458+
"hidden_size": 768,
459+
"vocab_size": 30000,
460+
"max_position_embeddings": 512
461+
});
462+
463+
let spec = ModelSpec::from_config(&config);
464+
assert!(spec.is_some());
465+
466+
let spec = spec.unwrap();
467+
assert_eq!(spec.model_type, Some("bert".to_string()));
468+
assert_eq!(spec.context_window, Some(512));
469+
assert!(spec.parameters.unwrap().contains("M"));
470+
}
471+
472+
#[test]
473+
fn test_model_spec_from_config_partial() {
474+
use serde_json::json;
475+
476+
let config = json!({
477+
"model_type": "llama",
478+
"n_ctx": 4096
479+
});
480+
481+
let spec = ModelSpec::from_config(&config);
482+
assert!(spec.is_some());
483+
484+
let spec = spec.unwrap();
485+
assert_eq!(spec.model_type, Some("llama".to_string()));
486+
assert_eq!(spec.context_window, Some(4096));
487+
assert_eq!(spec.parameters, None);
488+
}
489+
490+
#[test]
491+
fn test_model_spec_from_config_empty() {
492+
use serde_json::json;
493+
494+
let config = json!({});
495+
let spec = ModelSpec::from_config(&config);
496+
assert_eq!(spec, None);
497+
}
352498
}

src/utils/format.rs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,7 @@ pub fn format_parameters(count: u64) -> String {
5151
}
5252
}
5353

54+
5455
/// Format RFC3339 timestamp to human-readable relative time (e.g., "2 hours ago")
5556
pub fn format_time_ago(timestamp: &str) -> String {
5657
// Try to parse as RFC3339
@@ -350,4 +351,5 @@ mod tests {
350351
assert_eq!(format_parameters(999_999_999), "1000.00M");
351352
assert_eq!(format_parameters(1_000_000_000), "1.00B");
352353
}
354+
353355
}

0 commit comments

Comments
 (0)