|
3 | 3 | //! Downloads use Range headers to resume from `.partial` files if paused. |
4 | 4 | //! Progress events are emitted via channel for UI integration. |
5 | 5 |
|
6 | | -use crate::{hf, meta, net}; |
7 | | -use keyless_core::error::{KeylessError, KeylessResult}; |
8 | | -use std::fs::{self, File}; |
9 | | -use std::io::Write; |
10 | | -use std::sync::mpsc::SyncSender; |
11 | | -use std::sync::{Arc, atomic::AtomicBool, atomic::Ordering}; |
12 | | -use tracing::{error, info}; |
13 | | - |
14 | | -/// Progress events for model download. |
15 | | -#[derive(Debug)] |
16 | | -pub enum DownloadEvent { |
17 | | - /// Download started for the given model. |
18 | | - Started { |
19 | | - /// Model identifier (e.g., "openai/whisper-tiny"). |
20 | | - model: String, |
21 | | - }, |
22 | | - /// Stage update (e.g., "downloading config.json"). |
23 | | - Stage { |
24 | | - /// Human-readable stage description. |
25 | | - text: String, |
26 | | - }, |
27 | | - /// Progress update with bytes downloaded, total size, speed, and ETA. |
28 | | - Progress { |
29 | | - /// Bytes downloaded so far. |
30 | | - bytes: u64, |
31 | | - /// Total bytes if known (Some), else None. |
32 | | - total: Option<u64>, |
33 | | - /// Throughput in megabytes per second. |
34 | | - mbps: f64, |
35 | | - /// Estimated seconds remaining. |
36 | | - eta_s: f64, |
37 | | - }, |
38 | | - /// Download completed (Ok) or failed (Err). |
39 | | - Done(KeylessResult<()>), |
40 | | -} |
41 | | - |
42 | | -/// Ensure the required Whisper model files are present in the local cache. |
43 | | -/// |
44 | | -/// Downloads are implemented with `reqwest` (no `hf-hub`). We fetch small |
45 | | -/// files (`config.json`, `tokenizer.json`) with retry/backoff, then stream |
46 | | -/// `model.safetensors` with Range resume into a `.partial` file and atomically |
47 | | -/// rename on completion. Progress and stage events are emitted for UI integration. |
48 | | -/// |
49 | | -/// An optional token (`HF_TOKEN` or `HUGGINGFACE_HUB_TOKEN`) is used if set. |
50 | | -/// |
51 | | -/// # Arguments |
52 | | -/// * `model_id` - Hugging Face model identifier (e.g., "openai/whisper-base") |
53 | | -/// * `tx` - Channel to send progress events |
54 | | -/// * `cancel` - Atomic flag to check for cancellation requests |
55 | | -pub fn ensure_model_cached( |
56 | | - model_id: String, |
57 | | - tx: SyncSender<DownloadEvent>, |
58 | | - cancel: Arc<AtomicBool>, |
59 | | -) { |
60 | | - // Try-send for non-blocking start event (may drop if channel full; OK for best-effort). |
61 | | - let _ = tx.try_send(DownloadEvent::Started { |
62 | | - model: model_id.clone(), |
63 | | - }); |
64 | | - info!(model = %model_id, "starting model download"); |
65 | | - // Clone channel senders for different scopes (closure captures require owned values). |
66 | | - let tx_clone = tx.clone(); |
67 | | - let tx_for_loop = tx.clone(); |
68 | | - let result: KeylessResult<()> = (|| { |
69 | | - // Blocking reqwest client (with timeouts); this function is not async. |
70 | | - let client = |
71 | | - net::build_blocking_client().map_err(|e| KeylessError::Other(e.to_string()))?; |
72 | | - // Optional auth header (HF token from env); works without token for public models. |
73 | | - let auth = net::auth_header(); |
74 | | - |
75 | | - // Ensure repo cache dir exists (e.g., ~/.cache/keyless/openai--whisper-base/). |
76 | | - let repo_dir = hf::keyless_cache_repo_dir(&model_id); |
77 | | - fs::create_dir_all(&repo_dir).map_err(KeylessError::from)?; |
78 | | - |
79 | | - // Plan line (if size known from cache): show total expected size to user. |
80 | | - if let Some(total) = meta::plan_total_bytes(&model_id) { |
81 | | - let _ = tx_clone.try_send(DownloadEvent::Stage { |
82 | | - text: format!("plan: {}", keyless_core::utils::human_size(total)), |
83 | | - }); |
84 | | - } |
85 | | - |
86 | | - // Small files (with retry/backoff): fetch into memory, then write atomically. |
87 | | - for name in ["config.json", "tokenizer.json"] { |
88 | | - // Check cancellation before each file (allows prompt abort). |
89 | | - if cancel.load(Ordering::Relaxed) { |
90 | | - return Err(KeylessError::Other("cancelled".into())); |
91 | | - } |
92 | | - let dst = repo_dir.join(name); |
93 | | - // Skip if already downloaded (idempotent operation). |
94 | | - if dst.exists() { |
95 | | - continue; |
96 | | - } |
97 | | - let url = net::hf_resolve_url(&model_id, name); |
98 | | - let _ = tx_for_loop.try_send(DownloadEvent::Stage { |
99 | | - text: format!("downloading {}", name), |
100 | | - }); |
101 | | - // Retry with exponential backoff (handles transient network errors). |
102 | | - let (attempts, initial, max_ms) = meta::backoff_config(); |
103 | | - let (bytes, content_len_hdr) = |
104 | | - net::blocking_get_with_backoff(&client, &url, &auth, attempts, initial, max_ms)?; |
105 | | - // Write file atomically (create overwrites if exists; we checked above, but defensive). |
106 | | - let mut f = File::create(&dst).map_err(KeylessError::from)?; |
107 | | - f.write_all(&bytes).map_err(KeylessError::from)?; |
108 | | - // Verify downloaded size matches Content-Length header (catches truncation/corruption). |
109 | | - if let Some(cl) = content_len_hdr |
110 | | - && cl != bytes.len() as u64 |
111 | | - { |
112 | | - return Err(KeylessError::Other(format!( |
113 | | - "{} size mismatch ({} != {})", |
114 | | - name, |
115 | | - bytes.len(), |
116 | | - cl |
117 | | - ))); |
118 | | - } |
119 | | - } |
120 | | - |
121 | | - // Large file streaming with progress: stream chunks to avoid memory issues. |
122 | | - if cancel.load(Ordering::Relaxed) { |
123 | | - return Err(KeylessError::Other("cancelled".into())); |
124 | | - } |
125 | | - let weights = repo_dir.join("model.safetensors"); |
126 | | - if !weights.exists() { |
127 | | - let url = net::hf_resolve_url(&model_id, "model.safetensors"); |
128 | | - let mut req = client.get(&url); |
129 | | - // Add auth header if available (for private models or rate limit increases). |
130 | | - if let Some((h, v)) = &auth { |
131 | | - req = req.header(h, v.clone()); |
132 | | - } |
133 | | - // Resume support: check for existing partial file to resume download. |
134 | | - let tmp = repo_dir.join("model.safetensors.partial"); |
135 | | - let mut downloaded: u64 = 0; |
136 | | - // Read partial file size to determine resume offset. |
137 | | - if tmp.exists() |
138 | | - && let Ok(m) = std::fs::metadata(&tmp) |
139 | | - { |
140 | | - downloaded = m.len(); |
141 | | - } |
142 | | - // Add Range header if resuming (bytes=<offset>- requests from offset to end). |
143 | | - if downloaded > 0 { |
144 | | - req = req.header(reqwest::header::RANGE, format!("bytes={}-", downloaded)); |
145 | | - } |
146 | | - let mut resp = req |
147 | | - .send() |
148 | | - .map_err(|e| KeylessError::Other(format!("GET {}: {}", url, e)))?; |
149 | | - let status = resp.status(); |
150 | | - let mut file: Box<dyn std::io::Write>; |
151 | | - if status == reqwest::StatusCode::PARTIAL_CONTENT { |
152 | | - // Resume OK (206): server honored Range header; append to existing partial file. |
153 | | - resp = resp |
154 | | - .error_for_status() |
155 | | - .map_err(|e| KeylessError::Other(e.to_string()))?; |
156 | | - file = Box::new( |
157 | | - std::fs::OpenOptions::new() |
158 | | - .append(true) |
159 | | - .open(&tmp) |
160 | | - .map_err(KeylessError::from)?, |
161 | | - ); |
162 | | - } else if status == reqwest::StatusCode::OK { |
163 | | - // Range ignored (200): server doesn't support Range; restart from beginning. |
164 | | - // Remove partial file to avoid corruption (truncate would also work). |
165 | | - let _ = std::fs::remove_file(&tmp); |
166 | | - downloaded = 0; |
167 | | - resp = resp |
168 | | - .error_for_status() |
169 | | - .map_err(|e| KeylessError::Other(e.to_string()))?; |
170 | | - // Create new file (truncate overwrites existing; defensive for edge cases). |
171 | | - file = Box::new( |
172 | | - std::fs::OpenOptions::new() |
173 | | - .write(true) |
174 | | - .create(true) |
175 | | - .truncate(true) |
176 | | - .open(&tmp) |
177 | | - .map_err(KeylessError::from)?, |
178 | | - ); |
179 | | - } else { |
180 | | - // Unexpected status (e.g., 416 Range Not Satisfiable); fail fast. |
181 | | - return Err(KeylessError::Other(format!("unexpected status {}", status))); |
182 | | - } |
183 | | - |
184 | | - // Calculate total size: for 206, add downloaded bytes to Content-Length (partial response). |
185 | | - // For 200, Content-Length is full file size (no resume offset). |
186 | | - let total = if status == reqwest::StatusCode::PARTIAL_CONTENT { |
187 | | - resp.content_length() |
188 | | - .unwrap_or(0) |
189 | | - .saturating_add(downloaded) |
190 | | - } else { |
191 | | - resp.content_length().unwrap_or(0) |
192 | | - }; |
193 | | - // Cache total size for future resume attempts (if known). |
194 | | - if total > 0 { |
195 | | - meta::update_saved_size(&model_id, total); |
196 | | - } |
197 | | - // Track time for speed/ETA calculations. |
198 | | - let start = std::time::Instant::now(); |
199 | | - let mut last_emit = std::time::Instant::now(); |
200 | | - // 64KB buffer: balances memory usage vs I/O syscall overhead. |
201 | | - let mut buf = vec![0u8; 64 * 1024]; |
202 | | - loop { |
203 | | - // Check cancellation before reading (allows prompt abort on slow networks). |
204 | | - if cancel.load(Ordering::Relaxed) { |
205 | | - return Err(KeylessError::Other("cancelled".into())); |
206 | | - } |
207 | | - use std::io::Read; |
208 | | - // Read chunk from response stream (may return < buf.len() at end). |
209 | | - let n = resp |
210 | | - .read(&mut buf) |
211 | | - .map_err(|e| KeylessError::Other(format!("read chunk: {}", e)))?; |
212 | | - // EOF: n == 0 indicates stream end (normal completion). |
213 | | - if n == 0 { |
214 | | - break; |
215 | | - } |
216 | | - // Check cancellation after read (allows abort during write). |
217 | | - if cancel.load(Ordering::Relaxed) { |
218 | | - return Err(KeylessError::Other("cancelled".into())); |
219 | | - } |
220 | | - // Write chunk to file (only write n bytes, not full buffer). |
221 | | - file.write_all(&buf[..n]).map_err(KeylessError::from)?; |
222 | | - downloaded += n as u64; |
223 | | - // Throttle progress events to at most every 250ms (prevents channel saturation). |
224 | | - if last_emit.elapsed().as_millis() >= 250 { |
225 | | - last_emit = std::time::Instant::now(); |
226 | | - if total > 0 && downloaded <= total { |
227 | | - // Calculate speed (MB/s) and ETA (seconds remaining). |
228 | | - // max(0.001) prevents division by zero on very fast downloads. |
229 | | - let secs = start.elapsed().as_secs_f64().max(0.001); |
230 | | - let mbps = (downloaded as f64 / 1_000_000.0) / secs; |
231 | | - // saturating_sub prevents underflow if downloaded > total (shouldn't happen). |
232 | | - let left = (total.saturating_sub(downloaded)) as f64; |
233 | | - let bps = (downloaded as f64) / secs; |
234 | | - // ETA = remaining_bytes / bytes_per_second; max(0.0) prevents negative ETA. |
235 | | - let eta = if bps > 0.0 { |
236 | | - (left / bps).max(0.0) |
237 | | - } else { |
238 | | - 0.0 |
239 | | - }; |
240 | | - let _ = tx_clone.try_send(DownloadEvent::Progress { |
241 | | - bytes: downloaded, |
242 | | - total: Some(total), |
243 | | - mbps, |
244 | | - eta_s: eta, |
245 | | - }); |
246 | | - } else { |
247 | | - // Total unknown or downloaded exceeds expected (shouldn't happen). |
248 | | - let _ = tx_clone.try_send(DownloadEvent::Progress { |
249 | | - bytes: downloaded, |
250 | | - total: None, |
251 | | - mbps: 0.0, |
252 | | - eta_s: 0.0, |
253 | | - }); |
254 | | - } |
255 | | - } |
256 | | - } |
257 | | - // Close file handle before rename (ensures all buffers flushed to disk). |
258 | | - drop(file); |
259 | | - // Atomic rename: partial → final (prevents partial files from being used). |
260 | | - fs::rename(&tmp, &weights).map_err(KeylessError::from)?; |
261 | | - // Verify final size if we know total (catches truncation/corruption). |
262 | | - if total > 0 { |
263 | | - let final_len = std::fs::metadata(&weights) |
264 | | - .map_err(KeylessError::from)? |
265 | | - .len(); |
266 | | - if final_len != total { |
267 | | - // Clean up corrupted file (better to have no file than wrong size). |
268 | | - // Controller will surface error to user and allow retry. |
269 | | - let _ = std::fs::remove_file(&weights); |
270 | | - return Err(KeylessError::Other(format!( |
271 | | - "downloaded size mismatch ({} != {}), please retry", |
272 | | - final_len, total |
273 | | - ))); |
274 | | - } |
275 | | - // Emit a final 100% progress to fully fill the gauge (UI completeness). |
276 | | - let _ = tx_clone.try_send(DownloadEvent::Progress { |
277 | | - bytes: total, |
278 | | - total: Some(total), |
279 | | - mbps: 0.0, |
280 | | - eta_s: 0.0, |
281 | | - }); |
282 | | - } |
283 | | - } |
284 | | - Ok(()) |
285 | | - })(); |
286 | | - |
287 | | - // Log result for debugging (goes to session.log; not shown in TUI). |
288 | | - match &result { |
289 | | - Ok(()) => info!(model = %model_id, "model download completed successfully"), |
290 | | - Err(e) => error!(model = %model_id, error = %e, "model download failed"), |
291 | | - } |
292 | | - |
293 | | - // Use blocking send for completion event (must be delivered; wait if channel full). |
294 | | - // try_send would drop the event if channel saturated, hiding errors from user. |
295 | | - let _ = tx.send(DownloadEvent::Done(result)); |
296 | | -} |
| 6 | +mod events; |
| 7 | +mod large_file; |
| 8 | +mod model; |
| 9 | +mod progress; |
| 10 | +mod small_files; |
| 11 | + |
| 12 | +pub use events::DownloadEvent; |
| 13 | +pub use model::ensure_model_cached; |
0 commit comments