@@ -9,6 +9,7 @@ use std::path::Path;
99use std:: sync:: Arc ;
1010use tokenizers:: tokenizer:: Tokenizer ;
1111use tokio:: sync:: OnceCell ;
12+ use futures_util:: stream:: StreamExt ;
1213
1314use crate :: message:: Message ;
1415
@@ -117,7 +118,7 @@ impl AsyncTokenCounter {
117118 Ok ( tokenizer)
118119 }
119120
120- /// Proper async download without blocking
121+ /// Robust async download with retry logic and network failure handling
121122 async fn download_tokenizer_async (
122123 repo_id : & str ,
123124 download_dir : & std:: path:: Path ,
@@ -130,20 +131,120 @@ impl AsyncTokenCounter {
130131 ) ;
131132 let file_path = download_dir. join ( "tokenizer.json" ) ;
132133
133- // Use async HTTP client - no runtime blocking!
134- let client = reqwest:: Client :: new ( ) ;
135- let response = client. get ( & file_url) . send ( ) . await ?;
134+ // Check if partial/corrupted file exists and remove it
135+ if file_path. exists ( ) {
136+ if let Ok ( existing_bytes) = tokio:: fs:: read ( & file_path) . await {
137+ if Self :: is_valid_tokenizer_json ( & existing_bytes) {
138+ return Ok ( ( ) ) ; // File is complete and valid
139+ }
140+ }
141+ // Remove corrupted/incomplete file
142+ let _ = tokio:: fs:: remove_file ( & file_path) . await ;
143+ }
136144
137- if !response. status ( ) . is_success ( ) {
138- return Err ( format ! ( "HTTP {}: Failed to download tokenizer" , response. status( ) ) . into ( ) ) ;
145+ // Create enhanced HTTP client with timeouts
146+ let client = reqwest:: Client :: builder ( )
147+ . timeout ( std:: time:: Duration :: from_secs ( 60 ) )
148+ . connect_timeout ( std:: time:: Duration :: from_secs ( 15 ) )
149+ . user_agent ( "goose-tokenizer/1.0" )
150+ . build ( ) ?;
151+
152+ // Download with retry logic
153+ let response = Self :: download_with_retry ( & client, & file_url, 3 ) . await ?;
154+
155+ // Stream download with progress reporting for large files
156+ let total_size = response. content_length ( ) ;
157+ let mut stream = response. bytes_stream ( ) ;
158+ let mut file = tokio:: fs:: File :: create ( & file_path) . await ?;
159+ let mut downloaded = 0 ;
160+
161+ use tokio:: io:: AsyncWriteExt ;
162+
163+ while let Some ( chunk_result) = stream. next ( ) . await {
164+ let chunk = chunk_result?;
165+ file. write_all ( & chunk) . await ?;
166+ downloaded += chunk. len ( ) ;
167+
168+ // Progress reporting for large downloads
169+ if let Some ( total) = total_size {
170+ if total > 1024 * 1024 && downloaded % ( 256 * 1024 ) == 0 { // Report every 256KB for files >1MB
171+ eprintln ! ( "Downloaded {}/{} bytes ({:.1}%)" ,
172+ downloaded, total, ( downloaded as f64 / total as f64 ) * 100.0 ) ;
173+ }
174+ }
139175 }
140176
141- let bytes = response. bytes ( ) . await ?;
142- tokio:: fs:: write ( & file_path, bytes) . await ?;
177+ file. flush ( ) . await ?;
178+
179+ // Validate downloaded file
180+ let final_bytes = tokio:: fs:: read ( & file_path) . await ?;
181+ if !Self :: is_valid_tokenizer_json ( & final_bytes) {
182+ tokio:: fs:: remove_file ( & file_path) . await ?;
183+ return Err ( "Downloaded tokenizer file is invalid or corrupted" . into ( ) ) ;
184+ }
143185
186+ eprintln ! ( "Successfully downloaded tokenizer: {} ({} bytes)" , repo_id, downloaded) ;
144187 Ok ( ( ) )
145188 }
146189
190+ /// Download with exponential backoff retry logic
191+ async fn download_with_retry (
192+ client : & reqwest:: Client ,
193+ url : & str ,
194+ max_retries : u32 ,
195+ ) -> Result < reqwest:: Response , Box < dyn Error + Send + Sync > > {
196+ let mut delay = std:: time:: Duration :: from_millis ( 200 ) ;
197+
198+ for attempt in 0 ..=max_retries {
199+ match client. get ( url) . send ( ) . await {
200+ Ok ( response) if response. status ( ) . is_success ( ) => {
201+ return Ok ( response) ;
202+ }
203+ Ok ( response) if response. status ( ) . is_server_error ( ) => {
204+ // Retry on 5xx errors (server issues)
205+ if attempt < max_retries {
206+ eprintln ! ( "Server error {} on attempt {}/{}, retrying in {:?}" ,
207+ response. status( ) , attempt + 1 , max_retries + 1 , delay) ;
208+ tokio:: time:: sleep ( delay) . await ;
209+ delay = std:: cmp:: min ( delay * 2 , std:: time:: Duration :: from_secs ( 30 ) ) ; // Cap at 30s
210+ continue ;
211+ }
212+ return Err ( format ! ( "Server error after {} retries: {}" , max_retries, response. status( ) ) . into ( ) ) ;
213+ }
214+ Ok ( response) => {
215+ // Don't retry on 4xx errors (client errors like 404, 403)
216+ return Err ( format ! ( "Client error: {} - {}" , response. status( ) , url) . into ( ) ) ;
217+ }
218+ Err ( e) if attempt < max_retries => {
219+ // Retry on network errors (timeout, connection refused, DNS, etc.)
220+ eprintln ! ( "Network error on attempt {}/{}: {}, retrying in {:?}" ,
221+ attempt + 1 , max_retries + 1 , e, delay) ;
222+ tokio:: time:: sleep ( delay) . await ;
223+ delay = std:: cmp:: min ( delay * 2 , std:: time:: Duration :: from_secs ( 30 ) ) ; // Cap at 30s
224+ continue ;
225+ }
226+ Err ( e) => {
227+ return Err ( format ! ( "Network error after {} retries: {}" , max_retries, e) . into ( ) ) ;
228+ }
229+ }
230+ }
231+ unreachable ! ( )
232+ }
233+
234+ /// Validate that the downloaded file is a valid tokenizer JSON
235+ fn is_valid_tokenizer_json ( bytes : & [ u8 ] ) -> bool {
236+ // Basic validation: check if it's valid JSON and has tokenizer structure
237+ if let Ok ( json_str) = std:: str:: from_utf8 ( bytes) {
238+ if let Ok ( json_value) = serde_json:: from_str :: < serde_json:: Value > ( json_str) {
239+ // Check for basic tokenizer structure
240+ return json_value. get ( "version" ) . is_some ( ) ||
241+ json_value. get ( "vocab" ) . is_some ( ) ||
242+ json_value. get ( "model" ) . is_some ( ) ;
243+ }
244+ }
245+ false
246+ }
247+
147248 /// Count tokens with optimized caching
148249 pub fn count_tokens ( & self , text : & str ) -> usize {
149250 // Use faster AHash for better performance
@@ -871,4 +972,84 @@ mod tests {
871972 assert ! ( counter. cache_size( ) > 0 ) ;
872973 assert ! ( counter. cache_size( ) <= MAX_TOKEN_CACHE_SIZE ) ;
873974 }
975+
#[test]
fn test_tokenizer_json_validation() {
    // Local shorthand for the validator under test.
    fn accepts(raw: &[u8]) -> bool {
        AsyncTokenCounter::is_valid_tokenizer_json(raw)
    }

    // Documents carrying at least one marker key are accepted.
    assert!(accepts(br#"{"version": "1.0", "model": {"type": "BPE"}}"#));
    assert!(accepts(br#"{"vocab": {"hello": 1, "world": 2}}"#));

    // Truncated JSON (missing closing brace) is rejected.
    assert!(!accepts(br#"{"incomplete": true"#));

    // Well-formed JSON without any marker key is rejected.
    assert!(!accepts(br#"{"random": "data", "not": "tokenizer"}"#));

    // Bytes that are not valid UTF-8 are rejected.
    assert!(!accepts(&[0xFF, 0xFE, 0x00, 0x01]));

    // Empty input is rejected.
    assert!(!accepts(&[]));
}
1000+
1001+ #[ tokio:: test]
1002+ async fn test_download_with_retry_logic ( ) {
1003+ // This test would require mocking HTTP responses
1004+ // For now, we test the retry logic structure by verifying the function exists
1005+ // In a full test suite, you'd use wiremock or similar to simulate failures
1006+
1007+ // Test that the function exists and has the right signature
1008+ let client = reqwest:: Client :: new ( ) ;
1009+
1010+ // Test with a known bad URL to verify error handling
1011+ let result = AsyncTokenCounter :: download_with_retry (
1012+ & client,
1013+ "https://httpbin.org/status/404" ,
1014+ 1
1015+ ) . await ;
1016+
1017+ assert ! ( result. is_err( ) , "Should fail with 404 error" ) ;
1018+
1019+ let error_msg = result. unwrap_err ( ) . to_string ( ) ;
1020+ assert ! ( error_msg. contains( "Client error: 404" ) , "Should contain client error message" ) ;
1021+ }
1022+
1023+ #[ tokio:: test]
1024+ async fn test_network_resilience_with_timeout ( ) {
1025+ // Test timeout handling with a slow endpoint
1026+ let client = reqwest:: Client :: builder ( )
1027+ . timeout ( std:: time:: Duration :: from_millis ( 100 ) ) // Very short timeout
1028+ . build ( )
1029+ . unwrap ( ) ;
1030+
1031+ // Use httpbin delay endpoint that takes longer than our timeout
1032+ let result = AsyncTokenCounter :: download_with_retry (
1033+ & client,
1034+ "https://httpbin.org/delay/1" , // 1 second delay, but 100ms timeout
1035+ 1
1036+ ) . await ;
1037+
1038+ assert ! ( result. is_err( ) , "Should timeout and fail" ) ;
1039+ }
1040+
1041+ #[ tokio:: test]
1042+ async fn test_successful_download_retry ( ) {
1043+ // Test successful download after simulated retry
1044+ let client = reqwest:: Client :: new ( ) ;
1045+
1046+ // Use a reliable endpoint that should succeed
1047+ let result = AsyncTokenCounter :: download_with_retry (
1048+ & client,
1049+ "https://httpbin.org/status/200" ,
1050+ 2
1051+ ) . await ;
1052+
1053+ assert ! ( result. is_ok( ) , "Should succeed with 200 status" ) ;
1054+ }
8741055}
0 commit comments