Skip to content

Commit c6efcdb

Browse files
authored
Support HF downloading models (#16)
* Add HF downloader support Signed-off-by: kerthcet <kerthcet@gmail.com> * add bars Signed-off-by: kerthcet <kerthcet@gmail.com> * fix color Signed-off-by: kerthcet <kerthcet@gmail.com> * fix color Signed-off-by: kerthcet <kerthcet@gmail.com> * add download successfully message Signed-off-by: kerthcet <kerthcet@gmail.com> * change the color Signed-off-by: kerthcet <kerthcet@gmail.com> * change the rending shape Signed-off-by: kerthcet <kerthcet@gmail.com> --------- Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 333b891 commit c6efcdb

9 files changed

Lines changed: 781 additions & 29 deletions

File tree

Cargo.lock

Lines changed: 473 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,5 @@ env_logger = "0.11.6"
1717
log = "0.4.26"
1818
indicatif = "0.17.11"
1919
dirs = "6.0.0"
20+
hf-hub = { version = "0.5.0", features = ["tokio"] }
21+
colored = "2.1"

Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
build:
2-
cargo build && cp target/debug/puma ./puma
2+
cargo build && cp target/debug/puma ./puma
3+
4+
test:
5+
cargo test

README.md

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,49 @@
11
# PUMA
22

3-
**PUMA** aims to be a lightweight, high-performance inference engine for local AI. Play for fun.
3+
**PUMA** aims to be a lightweight, high-performance inference engine for local AI. *Play for fun.*
44

5-
## How to Run
5+
## Features
66

7-
### Build
7+
- 🚀 **Model Management** - Download and manage AI models from multiple providers
8+
9+
## Quick Start
10+
11+
### Installation
12+
13+
```bash
14+
make build
15+
```
816

9-
Run `make build` to build the **puma** binary.
17+
## Commands
18+
19+
| Command | Description |
20+
|---------|-------------|
21+
| `pull` | Download a model from a provider |
22+
| `ls` | List local models |
23+
| `ps` | List running models |
24+
| `run` | Create and run a model |
25+
| `stop` | Stop a running model |
26+
| `rm` | Remove a model |
27+
| `info` | Display system-wide information |
28+
| `inspect` | Return detailed information about a model |
29+
| `version` | Show PUMA version |
30+
| `help` | Show help information |
31+
32+
## Development
33+
34+
### Build
1035

11-
### Run
36+
```bash
37+
make compile
38+
```
1239

13-
Run `./puma help` to see all available commands.
40+
### Test
1441

15-
For example, you can run `./puma version` to see the binary version.
42+
```bash
43+
make test
44+
```
1645

17-
## Supported Backends
46+
### Supported Providers
1847

19-
Use [llama.cpp](https://github.com/ggerganov/llama.cpp) as the default backend for quick prototyping, will implement our own backend in the future.
48+
-**Hugging Face** - Full support with custom cache directories
49+
- 🚧 **ModelScope** - Coming soon

src/cli/commands.rs

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
use clap::{Parser, Subcommand};
22
use prettytable::{format, row, Table};
3+
use std::path::PathBuf;
4+
5+
use crate::downloader::downloader::Downloader;
6+
use crate::downloader::huggingface::HuggingFaceDownloader;
37

48
#[derive(Parser)]
59
#[command(name = "PUMA")]
@@ -33,10 +37,18 @@ enum Commands {
3337

3438
#[derive(Parser)]
3539
struct PullArgs {
36-
#[arg(long, value_name = "model name")]
40+
#[arg(short = 'm', long, value_name = "model name")]
3741
model: String,
38-
#[arg(long, value_name = "model provider", value_enum)]
42+
#[arg(
43+
short = 'p',
44+
long,
45+
value_name = "model provider",
46+
value_enum,
47+
default_value = "huggingface"
48+
)]
3949
provider: Provider,
50+
#[arg(long, value_name = "cache directory")]
51+
cache_dir: Option<PathBuf>,
4052
}
4153

4254
#[derive(Debug, Clone, clap::ValueEnum)]
@@ -85,7 +97,15 @@ pub async fn run(cli: Cli) {
8597

8698
Commands::PULL(args) => match args.provider {
8799
Provider::Huggingface => {
88-
println!("Downloading model from Huggingface...");
100+
let downloader = HuggingFaceDownloader::new();
101+
let cache_dir = args.cache_dir.unwrap_or_else(|| PathBuf::new());
102+
match downloader.download_model(&args.model, &cache_dir).await {
103+
Ok(_) => {}
104+
Err(e) => {
105+
eprintln!("Error downloading model: {}", e);
106+
std::process::exit(1);
107+
}
108+
}
89109
}
90110
Provider::Modelscope => {
91111
println!("Downloading model from Modelscope...");

src/downloader/downloader.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,29 @@
11
use core::fmt;
2+
use std::path::PathBuf;
23

34
#[derive(Debug)]
45
pub enum DownloadError {
5-
RequestError(String),
6-
ParseError(String),
6+
NetworkError(String),
7+
AuthError(String),
8+
ModelNotFound(String),
9+
IoError(String),
10+
ApiError(String),
711
}
812

913
impl std::error::Error for DownloadError {}
1014

1115
impl fmt::Display for DownloadError {
1216
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1317
match self {
14-
DownloadError::RequestError(e) => write!(f, "RequestError: {}", e),
15-
DownloadError::ParseError(e) => write!(f, "ParseError: {}", e),
18+
DownloadError::NetworkError(e) => write!(f, "Network error: {}", e),
19+
DownloadError::AuthError(e) => write!(f, "Authentication error: {}", e),
20+
DownloadError::ModelNotFound(e) => write!(f, "Model not found: {}", e),
21+
DownloadError::IoError(e) => write!(f, "IO error: {}", e),
22+
DownloadError::ApiError(e) => write!(f, "API error: {}", e),
1623
}
1724
}
1825
}
26+
27+
pub trait Downloader {
28+
async fn download_model(&self, name: &str, cache_dir: &PathBuf) -> Result<(), DownloadError>;
29+
}

src/downloader/huggingface.rs

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
use colored::Colorize;
2+
use log::{debug, info};
3+
use std::path::PathBuf;
4+
use std::sync::Arc;
5+
6+
use hf_hub::api::tokio::{ApiBuilder, Progress};
7+
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
8+
9+
use crate::downloader::downloader::{DownloadError, Downloader};
10+
11+
#[derive(Clone)]
12+
struct FileProgressBar {
13+
pb: ProgressBar,
14+
}
15+
16+
impl Progress for FileProgressBar {
17+
async fn init(&mut self, size: usize, _filename: &str) {
18+
self.pb.set_length(size as u64);
19+
self.pb.reset();
20+
self.pb.tick(); // Force render with correct size
21+
}
22+
23+
async fn update(&mut self, size: usize) {
24+
self.pb.inc(size as u64);
25+
}
26+
27+
async fn finish(&mut self) {}
28+
}
29+
30+
pub struct HuggingFaceDownloader;
31+
32+
impl HuggingFaceDownloader {
33+
pub fn new() -> Self {
34+
Self
35+
}
36+
}
37+
38+
impl Default for HuggingFaceDownloader {
39+
fn default() -> Self {
40+
Self::new()
41+
}
42+
}
43+
44+
impl Downloader for HuggingFaceDownloader {
45+
async fn download_model(&self, name: &str, cache_dir: &PathBuf) -> Result<(), DownloadError> {
46+
let start_time = std::time::Instant::now();
47+
48+
info!("Downloading model {} from Hugging Face...", name);
49+
50+
// Build API without default progress bars (we have our own implementation)
51+
let api = if cache_dir.as_os_str().is_empty() {
52+
// Use default HF cache
53+
ApiBuilder::new().build().map_err(|e| {
54+
DownloadError::ApiError(format!("Failed to initialize Hugging Face API: {}", e))
55+
})?
56+
} else {
57+
// Use custom cache directory
58+
ApiBuilder::new()
59+
.with_cache_dir(cache_dir.clone())
60+
.build()
61+
.map_err(|e| {
62+
DownloadError::ApiError(format!(
63+
"Failed to initialize Hugging Face API with custom cache: {}",
64+
e
65+
))
66+
})?
67+
};
68+
69+
// Download the entire model repository using snapshot download
70+
let repo = api.model(name.to_string());
71+
72+
// Get model info to list all files
73+
let model_info = repo.info().await.map_err(|e| {
74+
let err_str = e.to_string();
75+
if err_str.contains("404") || err_str.contains("not found") {
76+
DownloadError::ModelNotFound(format!("Model '{}' not found", name))
77+
} else if err_str.contains("401") || err_str.contains("403") {
78+
DownloadError::AuthError(format!("Authentication failed: {}", e))
79+
} else if err_str.contains("network") || err_str.contains("connection") {
80+
DownloadError::NetworkError(format!("Network error: {}", e))
81+
} else {
82+
DownloadError::ApiError(format!("Failed to fetch model info: {}", e))
83+
}
84+
})?;
85+
86+
debug!("Model info for {}: {:?}", name, model_info);
87+
88+
// Create multi-progress for parallel downloads
89+
let multi_progress = Arc::new(MultiProgress::new());
90+
91+
// Progress bar style with block characters (chart-like, not #)
92+
let style = ProgressStyle::default_bar()
93+
.template("{msg:<30} [{elapsed_precise}] {bar:60.white} {bytes}/{total_bytes}")
94+
.unwrap()
95+
.progress_chars("▇▆▅▄▃▂▁ ");
96+
97+
// Download all files in parallel
98+
let mut tasks = Vec::new();
99+
100+
for sibling in model_info.siblings {
101+
let api_clone = api.clone();
102+
let model_name = name.to_string();
103+
let filename = sibling.rfilename.clone();
104+
105+
let pb = multi_progress.add(ProgressBar::hidden());
106+
pb.set_style(style.clone());
107+
pb.set_message(filename.clone());
108+
109+
let task = tokio::spawn(async move {
110+
debug!("Downloading: {}", filename);
111+
112+
let repo = api_clone.model(model_name);
113+
let progress = FileProgressBar { pb: pb.clone() };
114+
115+
let result = repo.download_with_progress(&filename, progress).await;
116+
117+
match &result {
118+
Ok(_) => {
119+
pb.finish();
120+
}
121+
Err(_) => {
122+
pb.abandon();
123+
}
124+
}
125+
126+
result.map_err(|e| {
127+
DownloadError::NetworkError(format!("Failed to download {}: {}", filename, e))
128+
})
129+
});
130+
131+
tasks.push(task);
132+
}
133+
134+
// Wait for all downloads to complete
135+
for task in tasks {
136+
task.await
137+
.map_err(|e| DownloadError::ApiError(format!("Task join error: {}", e)))??;
138+
}
139+
140+
let elapsed_time = start_time.elapsed();
141+
142+
println!(
143+
"\n{} {} {} {} {:.2?}",
144+
"✓".green().bold(),
145+
"Successfully downloaded model".bright_white(),
146+
name.cyan().bold(),
147+
"in".bright_white(),
148+
elapsed_time
149+
);
150+
151+
Ok(())
152+
}
153+
}
154+
155+
#[cfg(test)]
156+
mod tests {
157+
use super::*;
158+
159+
#[tokio::test]
160+
async fn test_download_model_invalid() {
161+
let downloader = HuggingFaceDownloader::new();
162+
let result = downloader
163+
.download_model("invalid-model-that-does-not-exist-12345", &PathBuf::new())
164+
.await;
165+
assert!(result.is_err());
166+
}
167+
168+
#[tokio::test]
169+
async fn test_download_real_tiny_model() {
170+
let downloader = HuggingFaceDownloader::new();
171+
// Use HF's official tiny test model (only a few KB)
172+
let result = downloader
173+
.download_model("InftyAI/tiny-random-gpt2", &PathBuf::new())
174+
.await;
175+
assert!(
176+
result.is_ok(),
177+
"Failed to download tiny model: {:?}",
178+
result
179+
);
180+
181+
// Cleanup: remove the downloaded files from the default HF cache (~/.cache/huggingface/hub)
182+
if let Some(home_dir) = dirs::home_dir() {
183+
let cache_dir = home_dir
184+
.join(".cache")
185+
.join("huggingface")
186+
.join("hub")
187+
.join("models--InftyAI--tiny-random-gpt2");
188+
189+
if cache_dir.exists() {
190+
let _ = std::fs::remove_dir_all(&cache_dir);
191+
}
192+
}
193+
}
194+
195+
#[tokio::test]
196+
async fn test_download_with_custom_cache() {
197+
use std::env;
198+
use std::fs;
199+
200+
let downloader = HuggingFaceDownloader::new();
201+
let temp_dir = env::temp_dir().join("puma_test_cache");
202+
203+
print!("Using temporary cache directory: {:?}\n", temp_dir);
204+
205+
// Create the directory first
206+
fs::create_dir_all(&temp_dir).unwrap();
207+
208+
let result = downloader
209+
.download_model("InftyAI/tiny-random-gpt2", &temp_dir)
210+
.await;
211+
212+
assert!(
213+
result.is_ok(),
214+
"Failed to download with custom cache: {:?}",
215+
result
216+
);
217+
218+
// Cleanup
219+
let _ = std::fs::remove_dir_all(&temp_dir);
220+
}
221+
}

src/downloader/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
mod downloader;
1+
pub mod downloader;
2+
pub mod huggingface;

src/main.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ use crate::cli::commands::{run, Cli};
1010
use crate::util::file;
1111

1212
fn main() {
13-
env_logger::init();
13+
// Initialize logger.
14+
env_logger::Builder::from_env(env_logger::Env::default()).init();
1415

1516
// Create the root folder if it doesn't exist.
1617
file::create_folder_if_not_exists(&file::root_home()).unwrap();

0 commit comments

Comments
 (0)