Skip to content

Commit 0d13597

Browse files
committed
refactor(models): modularize download module into submodules
Reorganize download.rs (299 lines) into a modern Rust module structure:

- download.rs: re-exports the public API (events, model)
- download/model.rs: main ensure_model_cached orchestration
- download/events.rs: DownloadEvent enum
- download/small_files.rs: small-file download logic
- download/large_file.rs: large-file streaming with resume support
- download/progress.rs: progress-reporting utilities

Maintains the same public API; improves code organization and maintainability.
Updated README.md to reflect the new module structure.
1 parent 68f3364 commit 0d13597

7 files changed

Lines changed: 405 additions & 292 deletions

File tree

keyless-models/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ HTTP/HF helpers and model metadata/cache utilities for Keyless.
1818
- `KEYLESS_CACHE_DIR`: override cache root (tests/dev)
1919

2020
## Module structure
21-
- `download.rs`: Model download with progress reporting and cancellation (`DownloadEvent`, `ensure_model_cached`)
21+
- `download`: Model download with progress reporting and cancellation (`DownloadEvent`, `ensure_model_cached`)
2222
- Status-aware resume (206 append, 200 restart/truncate) with periodic progress (≤4 Hz)
23+
- Organized into submodules: `events`, `model`, `small_files`, `large_file`, `progress`
2324
- `hf.rs`: HF cache path helpers (`keyless_cache_repo_dir`, `get_local_model_size`, `delete_partial_file`)
2425
- `net.rs`: HTTP utilities (auth headers, URL resolution, retry/backoff)
2526
- `meta.rs`: Sizes TTL cache management

keyless-models/src/download.rs

Lines changed: 8 additions & 291 deletions
Original file line numberDiff line numberDiff line change
@@ -3,294 +3,11 @@
33
//! Downloads use Range headers to resume from `.partial` files if paused.
44
//! Progress events are emitted via channel for UI integration.
55
6-
use crate::{hf, meta, net};
7-
use keyless_core::error::{KeylessError, KeylessResult};
8-
use std::fs::{self, File};
9-
use std::io::Write;
10-
use std::sync::mpsc::SyncSender;
11-
use std::sync::{Arc, atomic::AtomicBool, atomic::Ordering};
12-
use tracing::{error, info};
13-
14-
/// Progress events for model download.
15-
#[derive(Debug)]
16-
pub enum DownloadEvent {
17-
/// Download started for the given model.
18-
Started {
19-
/// Model identifier (e.g., "openai/whisper-tiny").
20-
model: String,
21-
},
22-
/// Stage update (e.g., "downloading config.json").
23-
Stage {
24-
/// Human-readable stage description.
25-
text: String,
26-
},
27-
/// Progress update with bytes downloaded, total size, speed, and ETA.
28-
Progress {
29-
/// Bytes downloaded so far.
30-
bytes: u64,
31-
/// Total bytes if known (Some), else None.
32-
total: Option<u64>,
33-
/// Throughput in megabytes per second.
34-
mbps: f64,
35-
/// Estimated seconds remaining.
36-
eta_s: f64,
37-
},
38-
/// Download completed (Ok) or failed (Err).
39-
Done(KeylessResult<()>),
40-
}
41-
42-
/// Ensure the required Whisper model files are present in the local cache.
43-
///
44-
/// Downloads are implemented with `reqwest` (no `hf-hub`). We fetch small
45-
/// files (`config.json`, `tokenizer.json`) with retry/backoff, then stream
46-
/// `model.safetensors` with Range resume into a `.partial` file and atomically
47-
/// rename on completion. Progress and stage events are emitted for UI integration.
48-
///
49-
/// An optional token (`HF_TOKEN` or `HUGGINGFACE_HUB_TOKEN`) is used if set.
50-
///
51-
/// # Arguments
52-
/// * `model_id` - Hugging Face model identifier (e.g., "openai/whisper-base")
53-
/// * `tx` - Channel to send progress events
54-
/// * `cancel` - Atomic flag to check for cancellation requests
55-
pub fn ensure_model_cached(
56-
model_id: String,
57-
tx: SyncSender<DownloadEvent>,
58-
cancel: Arc<AtomicBool>,
59-
) {
60-
// Try-send for non-blocking start event (may drop if channel full; OK for best-effort).
61-
let _ = tx.try_send(DownloadEvent::Started {
62-
model: model_id.clone(),
63-
});
64-
info!(model = %model_id, "starting model download");
65-
// Clone channel senders for different scopes (closure captures require owned values).
66-
let tx_clone = tx.clone();
67-
let tx_for_loop = tx.clone();
68-
let result: KeylessResult<()> = (|| {
69-
// Blocking reqwest client (with timeouts); this function is not async.
70-
let client =
71-
net::build_blocking_client().map_err(|e| KeylessError::Other(e.to_string()))?;
72-
// Optional auth header (HF token from env); works without token for public models.
73-
let auth = net::auth_header();
74-
75-
// Ensure repo cache dir exists (e.g., ~/.cache/keyless/openai--whisper-base/).
76-
let repo_dir = hf::keyless_cache_repo_dir(&model_id);
77-
fs::create_dir_all(&repo_dir).map_err(KeylessError::from)?;
78-
79-
// Plan line (if size known from cache): show total expected size to user.
80-
if let Some(total) = meta::plan_total_bytes(&model_id) {
81-
let _ = tx_clone.try_send(DownloadEvent::Stage {
82-
text: format!("plan: {}", keyless_core::utils::human_size(total)),
83-
});
84-
}
85-
86-
// Small files (with retry/backoff): fetch into memory, then write atomically.
87-
for name in ["config.json", "tokenizer.json"] {
88-
// Check cancellation before each file (allows prompt abort).
89-
if cancel.load(Ordering::Relaxed) {
90-
return Err(KeylessError::Other("cancelled".into()));
91-
}
92-
let dst = repo_dir.join(name);
93-
// Skip if already downloaded (idempotent operation).
94-
if dst.exists() {
95-
continue;
96-
}
97-
let url = net::hf_resolve_url(&model_id, name);
98-
let _ = tx_for_loop.try_send(DownloadEvent::Stage {
99-
text: format!("downloading {}", name),
100-
});
101-
// Retry with exponential backoff (handles transient network errors).
102-
let (attempts, initial, max_ms) = meta::backoff_config();
103-
let (bytes, content_len_hdr) =
104-
net::blocking_get_with_backoff(&client, &url, &auth, attempts, initial, max_ms)?;
105-
// Write file atomically (create overwrites if exists; we checked above, but defensive).
106-
let mut f = File::create(&dst).map_err(KeylessError::from)?;
107-
f.write_all(&bytes).map_err(KeylessError::from)?;
108-
// Verify downloaded size matches Content-Length header (catches truncation/corruption).
109-
if let Some(cl) = content_len_hdr
110-
&& cl != bytes.len() as u64
111-
{
112-
return Err(KeylessError::Other(format!(
113-
"{} size mismatch ({} != {})",
114-
name,
115-
bytes.len(),
116-
cl
117-
)));
118-
}
119-
}
120-
121-
// Large file streaming with progress: stream chunks to avoid memory issues.
122-
if cancel.load(Ordering::Relaxed) {
123-
return Err(KeylessError::Other("cancelled".into()));
124-
}
125-
let weights = repo_dir.join("model.safetensors");
126-
if !weights.exists() {
127-
let url = net::hf_resolve_url(&model_id, "model.safetensors");
128-
let mut req = client.get(&url);
129-
// Add auth header if available (for private models or rate limit increases).
130-
if let Some((h, v)) = &auth {
131-
req = req.header(h, v.clone());
132-
}
133-
// Resume support: check for existing partial file to resume download.
134-
let tmp = repo_dir.join("model.safetensors.partial");
135-
let mut downloaded: u64 = 0;
136-
// Read partial file size to determine resume offset.
137-
if tmp.exists()
138-
&& let Ok(m) = std::fs::metadata(&tmp)
139-
{
140-
downloaded = m.len();
141-
}
142-
// Add Range header if resuming (bytes=<offset>- requests from offset to end).
143-
if downloaded > 0 {
144-
req = req.header(reqwest::header::RANGE, format!("bytes={}-", downloaded));
145-
}
146-
let mut resp = req
147-
.send()
148-
.map_err(|e| KeylessError::Other(format!("GET {}: {}", url, e)))?;
149-
let status = resp.status();
150-
let mut file: Box<dyn std::io::Write>;
151-
if status == reqwest::StatusCode::PARTIAL_CONTENT {
152-
// Resume OK (206): server honored Range header; append to existing partial file.
153-
resp = resp
154-
.error_for_status()
155-
.map_err(|e| KeylessError::Other(e.to_string()))?;
156-
file = Box::new(
157-
std::fs::OpenOptions::new()
158-
.append(true)
159-
.open(&tmp)
160-
.map_err(KeylessError::from)?,
161-
);
162-
} else if status == reqwest::StatusCode::OK {
163-
// Range ignored (200): server doesn't support Range; restart from beginning.
164-
// Remove partial file to avoid corruption (truncate would also work).
165-
let _ = std::fs::remove_file(&tmp);
166-
downloaded = 0;
167-
resp = resp
168-
.error_for_status()
169-
.map_err(|e| KeylessError::Other(e.to_string()))?;
170-
// Create new file (truncate overwrites existing; defensive for edge cases).
171-
file = Box::new(
172-
std::fs::OpenOptions::new()
173-
.write(true)
174-
.create(true)
175-
.truncate(true)
176-
.open(&tmp)
177-
.map_err(KeylessError::from)?,
178-
);
179-
} else {
180-
// Unexpected status (e.g., 416 Range Not Satisfiable); fail fast.
181-
return Err(KeylessError::Other(format!("unexpected status {}", status)));
182-
}
183-
184-
// Calculate total size: for 206, add downloaded bytes to Content-Length (partial response).
185-
// For 200, Content-Length is full file size (no resume offset).
186-
let total = if status == reqwest::StatusCode::PARTIAL_CONTENT {
187-
resp.content_length()
188-
.unwrap_or(0)
189-
.saturating_add(downloaded)
190-
} else {
191-
resp.content_length().unwrap_or(0)
192-
};
193-
// Cache total size for future resume attempts (if known).
194-
if total > 0 {
195-
meta::update_saved_size(&model_id, total);
196-
}
197-
// Track time for speed/ETA calculations.
198-
let start = std::time::Instant::now();
199-
let mut last_emit = std::time::Instant::now();
200-
// 64KB buffer: balances memory usage vs I/O syscall overhead.
201-
let mut buf = vec![0u8; 64 * 1024];
202-
loop {
203-
// Check cancellation before reading (allows prompt abort on slow networks).
204-
if cancel.load(Ordering::Relaxed) {
205-
return Err(KeylessError::Other("cancelled".into()));
206-
}
207-
use std::io::Read;
208-
// Read chunk from response stream (may return < buf.len() at end).
209-
let n = resp
210-
.read(&mut buf)
211-
.map_err(|e| KeylessError::Other(format!("read chunk: {}", e)))?;
212-
// EOF: n == 0 indicates stream end (normal completion).
213-
if n == 0 {
214-
break;
215-
}
216-
// Check cancellation after read (allows abort during write).
217-
if cancel.load(Ordering::Relaxed) {
218-
return Err(KeylessError::Other("cancelled".into()));
219-
}
220-
// Write chunk to file (only write n bytes, not full buffer).
221-
file.write_all(&buf[..n]).map_err(KeylessError::from)?;
222-
downloaded += n as u64;
223-
// Throttle progress events to at most every 250ms (prevents channel saturation).
224-
if last_emit.elapsed().as_millis() >= 250 {
225-
last_emit = std::time::Instant::now();
226-
if total > 0 && downloaded <= total {
227-
// Calculate speed (MB/s) and ETA (seconds remaining).
228-
// max(0.001) prevents division by zero on very fast downloads.
229-
let secs = start.elapsed().as_secs_f64().max(0.001);
230-
let mbps = (downloaded as f64 / 1_000_000.0) / secs;
231-
// saturating_sub prevents underflow if downloaded > total (shouldn't happen).
232-
let left = (total.saturating_sub(downloaded)) as f64;
233-
let bps = (downloaded as f64) / secs;
234-
// ETA = remaining_bytes / bytes_per_second; max(0.0) prevents negative ETA.
235-
let eta = if bps > 0.0 {
236-
(left / bps).max(0.0)
237-
} else {
238-
0.0
239-
};
240-
let _ = tx_clone.try_send(DownloadEvent::Progress {
241-
bytes: downloaded,
242-
total: Some(total),
243-
mbps,
244-
eta_s: eta,
245-
});
246-
} else {
247-
// Total unknown or downloaded exceeds expected (shouldn't happen).
248-
let _ = tx_clone.try_send(DownloadEvent::Progress {
249-
bytes: downloaded,
250-
total: None,
251-
mbps: 0.0,
252-
eta_s: 0.0,
253-
});
254-
}
255-
}
256-
}
257-
// Close file handle before rename (ensures all buffers flushed to disk).
258-
drop(file);
259-
// Atomic rename: partial → final (prevents partial files from being used).
260-
fs::rename(&tmp, &weights).map_err(KeylessError::from)?;
261-
// Verify final size if we know total (catches truncation/corruption).
262-
if total > 0 {
263-
let final_len = std::fs::metadata(&weights)
264-
.map_err(KeylessError::from)?
265-
.len();
266-
if final_len != total {
267-
// Clean up corrupted file (better to have no file than wrong size).
268-
// Controller will surface error to user and allow retry.
269-
let _ = std::fs::remove_file(&weights);
270-
return Err(KeylessError::Other(format!(
271-
"downloaded size mismatch ({} != {}), please retry",
272-
final_len, total
273-
)));
274-
}
275-
// Emit a final 100% progress to fully fill the gauge (UI completeness).
276-
let _ = tx_clone.try_send(DownloadEvent::Progress {
277-
bytes: total,
278-
total: Some(total),
279-
mbps: 0.0,
280-
eta_s: 0.0,
281-
});
282-
}
283-
}
284-
Ok(())
285-
})();
286-
287-
// Log result for debugging (goes to session.log; not shown in TUI).
288-
match &result {
289-
Ok(()) => info!(model = %model_id, "model download completed successfully"),
290-
Err(e) => error!(model = %model_id, error = %e, "model download failed"),
291-
}
292-
293-
// Use blocking send for completion event (must be delivered; wait if channel full).
294-
// try_send would drop the event if channel saturated, hiding errors from user.
295-
let _ = tx.send(DownloadEvent::Done(result));
296-
}
6+
// Submodule layout (per the refactor commit: download.rs split into a module
// directory; purposes below are taken from the commit description).
mod events;      // `DownloadEvent` enum
mod large_file;  // large-file streaming with resume support
mod model;       // main `ensure_model_cached` orchestration
mod progress;    // progress-reporting utilities
mod small_files; // small-file (config/tokenizer) download logic

// Public API re-exports: callers keep using `download::DownloadEvent` and
// `download::ensure_model_cached` exactly as before the split.
pub use events::DownloadEvent;
pub use model::ensure_model_cached;
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//! Progress events for model download.
//!
//! Events are emitted via channel for UI integration (see the module-level
//! docs of the parent `download` module).

use keyless_core::error::KeylessResult;

/// Progress events for model download.
#[derive(Debug)]
pub enum DownloadEvent {
    /// Download started for the given model.
    Started {
        /// Model identifier (e.g., "openai/whisper-tiny").
        model: String,
    },
    /// Stage update (e.g., "downloading config.json").
    Stage {
        /// Human-readable stage description.
        text: String,
    },
    /// Progress update with bytes downloaded, total size, speed, and ETA.
    Progress {
        /// Bytes downloaded so far.
        bytes: u64,
        /// Total bytes if known (Some), else None.
        total: Option<u64>,
        /// Throughput in megabytes per second.
        mbps: f64,
        /// Estimated seconds remaining.
        eta_s: f64,
    },
    /// Download completed (Ok) or failed (Err).
    Done(KeylessResult<()>),
}

0 commit comments

Comments
 (0)