Skip to content

Commit 6dbd10f

Browse files
authored
Reuse the model cache to avoid duplicate download (#19)
* polish the format of the ls command
  Signed-off-by: kerthcet <kerthcet@gmail.com>
* Have a progress manager
  Signed-off-by: kerthcet <kerthcet@gmail.com>
* Reuse caches
  Signed-off-by: kerthcet <kerthcet@gmail.com>
* rename util to utils
  Signed-off-by: kerthcet <kerthcet@gmail.com>
* polish the layout of the download progress
  Signed-off-by: kerthcet <kerthcet@gmail.com>
* revert change
  Signed-off-by: kerthcet <kerthcet@gmail.com>
* add make format
  Signed-off-by: kerthcet <kerthcet@gmail.com>
---------
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent cbaea83 commit 6dbd10f

14 files changed

Lines changed: 338 additions & 132 deletions

File tree

Cargo.lock

Lines changed: 3 additions & 35 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ serde = { version = "1.0", features = ["derive"] }
1515
serde_derive = "1.0"
1616
env_logger = "0.11.6"
1717
log = "0.4.26"
18-
indicatif = "0.17.11"
18+
indicatif = "0.18"
1919
dirs = "6.0.0"
2020
hf-hub = { version = "0.5.0", features = ["tokio"] }
2121
colored = "2.1"

Makefile

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,3 +7,7 @@ test:
77
lint:
88
cargo fmt --all -- --check
99
cargo clippy --all-targets --all-features -- -D warnings
10+
11+
format:
12+
cargo fmt --all
13+
cargo clippy --fix --allow-dirty

src/cli/commands.rs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ use crate::downloader::downloader::Downloader;
55
use crate::downloader::huggingface::HuggingFaceDownloader;
66
use crate::registry::model_registry::ModelRegistry;
77
use crate::system::system_info::SystemInfo;
8-
use crate::util::format::{format_size, format_time_ago};
8+
use crate::utils::format::{format_size_decimal, format_time_ago};
99

1010
#[derive(Parser)]
1111
#[command(name = "PUMA")]
@@ -92,7 +92,7 @@ pub async fn run(cli: Cli) {
9292
table.add_row(row!["MODEL", "PROVIDER", "REVISION", "SIZE", "CREATED"]);
9393

9494
for model in models {
95-
let size_str = format_size(model.size);
95+
let size_str = format_size_decimal(model.size);
9696

9797
let revision_short = if model.revision.len() > 8 {
9898
&model.revision[..8]
@@ -118,7 +118,7 @@ pub async fn run(cli: Cli) {
118118
Provider::Huggingface => {
119119
let downloader = HuggingFaceDownloader::new();
120120
if let Err(e) = downloader.download_model(&args.model).await {
121-
eprintln!("Error downloading model: {}", e);
121+
eprintln!("Error downloading model: {}", e);
122122
std::process::exit(1);
123123
}
124124
}

src/downloader/huggingface.rs

Lines changed: 56 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -1,34 +1,31 @@
11
use colored::Colorize;
2-
use log::{debug, info};
3-
use std::sync::atomic::{AtomicU64, Ordering};
4-
use std::sync::Arc;
2+
use log::debug;
53

64
use hf_hub::api::tokio::{ApiBuilder, Progress};
7-
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
85

96
use crate::downloader::downloader::{DownloadError, Downloader};
7+
use crate::downloader::progress::{DownloadProgressManager, FileProgress};
108
use crate::registry::model_registry::{ModelInfo, ModelRegistry};
11-
use crate::util::file;
9+
use crate::utils::file::{self, format_model_name};
1210

11+
/// Adapter to bridge HuggingFace's Progress trait with our FileProgress
1312
#[derive(Clone)]
14-
struct FileProgressBar {
15-
pb: ProgressBar,
16-
total_size: Arc<AtomicU64>,
13+
struct HfProgressAdapter {
14+
progress: FileProgress,
1715
}
1816

19-
impl Progress for FileProgressBar {
17+
impl Progress for HfProgressAdapter {
2018
async fn init(&mut self, size: usize, _filename: &str) {
21-
self.pb.set_length(size as u64);
22-
self.pb.reset();
23-
self.pb.tick(); // Force render with correct size
24-
self.total_size.fetch_add(size as u64, Ordering::Relaxed);
19+
self.progress.init(size as u64);
2520
}
2621

2722
async fn update(&mut self, size: usize) {
28-
self.pb.inc(size as u64);
23+
self.progress.update(size as u64);
2924
}
3025

31-
async fn finish(&mut self) {}
26+
async fn finish(&mut self) {
27+
self.progress.finish();
28+
}
3229
}
3330

3431
pub struct HuggingFaceDownloader;
@@ -49,7 +46,7 @@ impl Downloader for HuggingFaceDownloader {
4946
async fn download_model(&self, name: &str) -> Result<(), DownloadError> {
5047
let start_time = std::time::Instant::now();
5148

52-
info!("Downloading model {} from Hugging Face...", name);
49+
debug!("Downloading model {} from Hugging Face...", name);
5350

5451
// Use unified PUMA cache directory
5552
let cache_dir = file::huggingface_cache_dir();
@@ -65,6 +62,8 @@ impl Downloader for HuggingFaceDownloader {
6562
DownloadError::ApiError(format!("Failed to initialize Hugging Face API: {}", e))
6663
})?;
6764

65+
println!("🐆 pulling manifest");
66+
6867
// Download the entire model repository using snapshot download
6968
let repo = api.model(name.to_string());
7069

@@ -84,9 +83,6 @@ impl Downloader for HuggingFaceDownloader {
8483

8584
debug!("Model info for {}: {:?}", name, model_info);
8685

87-
// Create multi-progress for parallel downloads
88-
let multi_progress = Arc::new(MultiProgress::new());
89-
9086
// Calculate the longest filename for proper alignment
9187
let max_filename_len = model_info
9288
.siblings
@@ -95,54 +91,59 @@ impl Downloader for HuggingFaceDownloader {
9591
.max()
9692
.unwrap_or(30);
9793

98-
// Progress bar style with block characters (chart-like, not #)
99-
let template = format!(
100-
"{{msg:<{width}}} [{{elapsed_precise}}] {{bar:60.white}} {{bytes}}/{{total_bytes}}",
101-
width = max_filename_len
102-
);
103-
let style = ProgressStyle::default_bar()
104-
.template(&template)
105-
.unwrap()
106-
.progress_chars("▇▆▅▄▃▂▁ ");
94+
// Create progress manager
95+
let progress_manager = DownloadProgressManager::new(max_filename_len);
10796

108-
// Download all files in parallel
109-
let mut tasks = Vec::new();
97+
// Calculate cache paths
98+
let model_cache_path = cache_dir.join(format_model_name(name));
11099
let sha = model_info.sha.clone();
111-
let total_size = Arc::new(AtomicU64::new(0));
100+
let snapshot_path = model_cache_path.join("snapshots").join(&sha);
101+
102+
// Process all files in manifest order (cached files show as instantly complete)
103+
let mut tasks = Vec::new();
112104

113105
for sibling in model_info.siblings {
114106
let api_clone = api.clone();
115107
let model_name = name.to_string();
116108
let filename = sibling.rfilename.clone();
117-
let total_size_clone = Arc::clone(&total_size);
118-
119-
let pb = multi_progress.add(ProgressBar::hidden());
120-
pb.set_style(style.clone());
121-
pb.set_message(filename.clone());
109+
let progress_manager_clone = progress_manager.clone();
110+
let snapshot_path_clone = snapshot_path.clone();
122111

123112
let task = tokio::spawn(async move {
124-
debug!("Downloading: {}", filename);
125-
126113
let repo = api_clone.model(model_name);
127-
let progress = FileProgressBar {
128-
pb: pb.clone(),
129-
total_size: total_size_clone,
130-
};
131114

132-
let result = repo.download_with_progress(&filename, progress).await;
115+
// Check if file exists in cache
116+
let cached_file_path = snapshot_path_clone.join(&filename);
117+
if cached_file_path.exists() {
118+
debug!("File {} found in cache, showing as complete", filename);
119+
120+
// Create progress bar and mark as instantly complete
121+
let mut file_progress = progress_manager_clone.create_file_progress(&filename);
122+
let file_size = cached_file_path.metadata().map(|m| m.len()).unwrap_or(0);
123+
file_progress.init(file_size);
124+
file_progress.update(file_size);
125+
file_progress.finish();
133126

134-
match &result {
135-
Ok(_) => {
136-
pb.finish();
137-
}
138-
Err(_) => {
139-
pb.abandon();
140-
}
127+
return Ok(());
141128
}
142129

143-
result.map_err(|e| {
144-
DownloadError::NetworkError(format!("Failed to download {}: {}", filename, e))
145-
})
130+
// File not in cache, download with progress
131+
debug!("Downloading: {}", filename);
132+
let file_progress = progress_manager_clone.create_file_progress(&filename);
133+
let progress = HfProgressAdapter {
134+
progress: file_progress,
135+
};
136+
137+
repo.download_with_progress(&filename, progress)
138+
.await
139+
.map_err(|e| {
140+
DownloadError::NetworkError(format!(
141+
"Failed to download {}: {}",
142+
filename, e
143+
))
144+
})?;
145+
146+
Ok(())
146147
});
147148

148149
tasks.push(task);
@@ -157,8 +158,8 @@ impl Downloader for HuggingFaceDownloader {
157158
let elapsed_time = start_time.elapsed();
158159

159160
// Get accumulated size from downloads
160-
let downloaded_size = total_size.load(Ordering::Relaxed);
161-
let model_cache_path = cache_dir.join(format!("models--{}", name.replace("/", "--")));
161+
let downloaded_size = progress_manager.total_downloaded_bytes();
162+
let model_cache_path = cache_dir.join(format_model_name(name));
162163

163164
// Register the model
164165
let model_info_record = ModelInfo {

src/downloader/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
11
#[allow(clippy::module_inception)]
22
pub mod downloader;
33
pub mod huggingface;
4+
pub mod progress;

0 commit comments

Comments
 (0)