@@ -2,10 +2,11 @@ use colored::Colorize;
22use log:: debug;
33
44use hf_hub:: api:: tokio:: { ApiBuilder , Progress } ;
5+ use indicatif:: { ProgressBar , ProgressStyle } ;
56
67use crate :: downloader:: downloader:: { DownloadError , Downloader } ;
78use crate :: downloader:: progress:: { DownloadProgressManager , FileProgress } ;
8- use crate :: registry:: model_registry:: { ModelInfo , ModelRegistry } ;
9+ use crate :: registry:: model_registry:: { ModelArchitecture , ModelInfo , ModelRegistry } ;
910use crate :: utils:: file:: { self , format_model_name} ;
1011
1112/// Adapter to bridge HuggingFace's Progress trait with our FileProgress
@@ -62,7 +63,15 @@ impl Downloader for HuggingFaceDownloader {
6263 DownloadError :: ApiError ( format ! ( "Failed to initialize Hugging Face API: {}" , e) )
6364 } ) ?;
6465
65- println ! ( "🐆 pulling manifest" ) ;
66+ // Create a simple spinner for manifest pulling
67+ let manifest_spinner = ProgressBar :: new_spinner ( ) ;
68+ manifest_spinner. set_style (
69+ ProgressStyle :: default_spinner ( )
70+ . tick_chars ( "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏" )
71+ . template ( "pulling manifest {spinner:.white}" )
72+ . unwrap ( ) ,
73+ ) ;
74+ manifest_spinner. enable_steady_tick ( std:: time:: Duration :: from_millis ( 80 ) ) ;
6675
6776 // Download the entire model repository using snapshot download
6877 let repo = api. model ( name. to_string ( ) ) ;
@@ -81,6 +90,10 @@ impl Downloader for HuggingFaceDownloader {
8190 }
8291 } ) ?;
8392
93+ // Stop manifest spinner and print clean message
94+ manifest_spinner. finish_and_clear ( ) ;
95+ println ! ( "pulling manifest" ) ;
96+
8497 debug ! ( "Model info for {}: {:?}" , name, model_info) ;
8598
8699 // Calculate the longest filename for proper alignment
@@ -91,6 +104,8 @@ impl Downloader for HuggingFaceDownloader {
91104 . max ( )
92105 . unwrap_or ( 30 ) ;
93106
107+ // Add extra space for "pulling " prefix
108+ let max_filename_len = max_filename_len + 8 ;
94109 // Create progress manager
95110 let progress_manager = DownloadProgressManager :: new ( max_filename_len) ;
96111
@@ -124,8 +139,9 @@ impl Downloader for HuggingFaceDownloader {
124139 debug ! ( "File {} found in cache, showing as complete" , filename) ;
125140
126141 // Create progress bar for cached file (no speed display)
142+ let display_name = format ! ( "pulling {}" , filename) ;
127143 let mut file_progress =
128- progress_manager_clone. create_cached_file_progress ( & filename ) ;
144+ progress_manager_clone. create_cached_file_progress ( & display_name ) ;
129145 let file_size = cached_file_path. metadata ( ) . map ( |m| m. len ( ) ) . unwrap_or ( 0 ) ;
130146 file_progress. init ( file_size) ;
131147 file_progress. update ( file_size) ;
@@ -136,7 +152,8 @@ impl Downloader for HuggingFaceDownloader {
136152
137153 // File not in cache, download with progress
138154 debug ! ( "Downloading: {}" , filename) ;
139- let file_progress = progress_manager_clone. create_file_progress ( & filename) ;
155+ let display_name = format ! ( "pulling {}" , filename) ;
156+ let file_progress = progress_manager_clone. create_file_progress ( & display_name) ;
140157 let progress = HfProgressAdapter {
141158 progress : file_progress,
142159 } ;
@@ -156,37 +173,65 @@ impl Downloader for HuggingFaceDownloader {
156173 tasks. push ( task) ;
157174 }
158175
176+ // Give tasks a moment to start and create their progress bars
177+ tokio:: time:: sleep ( tokio:: time:: Duration :: from_millis ( 100 ) ) . await ;
178+
179+ // Show spinner at the bottom after all progress bars are created (only if not fully cached)
180+ let spinner = if !model_totally_cached {
181+ Some ( progress_manager. create_spinner ( ) )
182+ } else {
183+ None
184+ } ;
185+
159186 // Wait for all downloads to complete
160187 for task in tasks {
161188 task. await
162189 . map_err ( |e| DownloadError :: ApiError ( format ! ( "Task join error: {}" , e) ) ) ??;
163190 }
164191
192+ // Finish spinner after downloads complete
193+ if let Some ( spinner) = & spinner {
194+ spinner. finish_and_clear ( ) ;
195+ }
196+
165197 let elapsed_time = start_time. elapsed ( ) ;
166198
167199 // Get accumulated size from downloads
168200 let downloaded_size = progress_manager. total_downloaded_bytes ( ) ;
169201 let model_cache_path = cache_dir. join ( format_model_name ( name) ) ;
170202
171- // Register the model
172- let model_info_record = ModelInfo {
173- name : name. to_string ( ) ,
174- provider : "huggingface" . to_string ( ) ,
175- revision : sha,
176- size : downloaded_size,
177- modified_at : chrono:: Local :: now ( ) . to_rfc3339 ( ) ,
178- cache_path : model_cache_path. to_string_lossy ( ) . to_string ( ) ,
179- } ;
180-
203+ // Register the model only if not totally cached
181204 if !model_totally_cached {
205+ // Extract architecture info from config.json
206+ let config_path = snapshot_path. join ( "config.json" ) ;
207+ let arch = if config_path. exists ( ) {
208+ std:: fs:: read_to_string ( & config_path)
209+ . ok ( )
210+ . and_then ( |content| serde_json:: from_str :: < serde_json:: Value > ( & content) . ok ( ) )
211+ . and_then ( |config| ModelArchitecture :: from_config ( & config) )
212+ } else {
213+ None
214+ } ;
215+
216+ let model_info_record = ModelInfo {
217+ name : name. to_string ( ) ,
218+ provider : "huggingface" . to_string ( ) ,
219+ revision : sha,
220+ size : downloaded_size,
221+ modified_at : chrono:: Local :: now ( ) . to_rfc3339 ( ) ,
222+ cache_path : model_cache_path. to_string_lossy ( ) . to_string ( ) ,
223+ arch,
224+ } ;
225+
182226 let registry = ModelRegistry :: new ( None ) ;
183227 registry
184228 . register_model ( model_info_record)
185229 . map_err ( |e| DownloadError :: ApiError ( format ! ( "Failed to register model: {}" , e) ) ) ?;
186230 }
187231
232+ // Print success message
188233 println ! (
189- "\n {} {} {} {} {:.2?}" ,
234+ "{} {} {} {} {:.2?}" ,
190235 "✓" . green( ) . bold( ) ,
191236 "Successfully downloaded model" . bright_white( ) ,
192237 name. cyan( ) . bold( ) ,
0 commit comments