11use ahash:: AHasher ;
22use dashmap:: DashMap ;
3+ use futures_util:: stream:: StreamExt ;
34use include_dir:: { include_dir, Dir } ;
45use mcp_core:: Tool ;
56use std:: error:: Error ;
@@ -9,7 +10,6 @@ use std::path::Path;
910use std:: sync:: Arc ;
1011use tokenizers:: tokenizer:: Tokenizer ;
1112use tokio:: sync:: OnceCell ;
12- use futures_util:: stream:: StreamExt ;
1313
1414use crate :: message:: Message ;
1515
@@ -151,7 +151,7 @@ impl AsyncTokenCounter {
151151
152152 // Download with retry logic
153153 let response = Self :: download_with_retry ( & client, & file_url, 3 ) . await ?;
154-
154+
155155 // Stream download with progress reporting for large files
156156 let total_size = response. content_length ( ) ;
157157 let mut stream = response. bytes_stream ( ) ;
@@ -164,12 +164,17 @@ impl AsyncTokenCounter {
164164 let chunk = chunk_result?;
165165 file. write_all ( & chunk) . await ?;
166166 downloaded += chunk. len ( ) ;
167-
167+
168168 // Progress reporting for large downloads
169169 if let Some ( total) = total_size {
170- if total > 1024 * 1024 && downloaded % ( 256 * 1024 ) == 0 { // Report every 256KB for files >1MB
171- eprintln ! ( "Downloaded {}/{} bytes ({:.1}%)" ,
172- downloaded, total, ( downloaded as f64 / total as f64 ) * 100.0 ) ;
170+ if total > 1024 * 1024 && downloaded % ( 256 * 1024 ) == 0 {
171+ // Report every 256KB for files >1MB
172+ eprintln ! (
173+ "Downloaded {}/{} bytes ({:.1}%)" ,
174+ downloaded,
175+ total,
176+ ( downloaded as f64 / total as f64 ) * 100.0
177+ ) ;
173178 }
174179 }
175180 }
@@ -183,7 +188,10 @@ impl AsyncTokenCounter {
183188 return Err ( "Downloaded tokenizer file is invalid or corrupted" . into ( ) ) ;
184189 }
185190
186- eprintln ! ( "Successfully downloaded tokenizer: {} ({} bytes)" , repo_id, downloaded) ;
191+ eprintln ! (
192+ "Successfully downloaded tokenizer: {} ({} bytes)" ,
193+ repo_id, downloaded
194+ ) ;
187195 Ok ( ( ) )
188196 }
189197
@@ -194,7 +202,7 @@ impl AsyncTokenCounter {
194202 max_retries : u32 ,
195203 ) -> Result < reqwest:: Response , Box < dyn Error + Send + Sync > > {
196204 let mut delay = std:: time:: Duration :: from_millis ( 200 ) ;
197-
205+
198206 for attempt in 0 ..=max_retries {
199207 match client. get ( url) . send ( ) . await {
200208 Ok ( response) if response. status ( ) . is_success ( ) => {
@@ -203,28 +211,45 @@ impl AsyncTokenCounter {
203211 Ok ( response) if response. status ( ) . is_server_error ( ) => {
204212 // Retry on 5xx errors (server issues)
205213 if attempt < max_retries {
206- eprintln ! ( "Server error {} on attempt {}/{}, retrying in {:?}" ,
207- response. status( ) , attempt + 1 , max_retries + 1 , delay) ;
214+ eprintln ! (
215+ "Server error {} on attempt {}/{}, retrying in {:?}" ,
216+ response. status( ) ,
217+ attempt + 1 ,
218+ max_retries + 1 ,
219+ delay
220+ ) ;
208221 tokio:: time:: sleep ( delay) . await ;
209222 delay = std:: cmp:: min ( delay * 2 , std:: time:: Duration :: from_secs ( 30 ) ) ; // Cap at 30s
210223 continue ;
211224 }
212- return Err ( format ! ( "Server error after {} retries: {}" , max_retries, response. status( ) ) . into ( ) ) ;
225+ return Err ( format ! (
226+ "Server error after {} retries: {}" ,
227+ max_retries,
228+ response. status( )
229+ )
230+ . into ( ) ) ;
213231 }
214232 Ok ( response) => {
215233 // Don't retry on 4xx errors (client errors like 404, 403)
216234 return Err ( format ! ( "Client error: {} - {}" , response. status( ) , url) . into ( ) ) ;
217235 }
218236 Err ( e) if attempt < max_retries => {
219237 // Retry on network errors (timeout, connection refused, DNS, etc.)
220- eprintln ! ( "Network error on attempt {}/{}: {}, retrying in {:?}" ,
221- attempt + 1 , max_retries + 1 , e, delay) ;
238+ eprintln ! (
239+ "Network error on attempt {}/{}: {}, retrying in {:?}" ,
240+ attempt + 1 ,
241+ max_retries + 1 ,
242+ e,
243+ delay
244+ ) ;
222245 tokio:: time:: sleep ( delay) . await ;
223246 delay = std:: cmp:: min ( delay * 2 , std:: time:: Duration :: from_secs ( 30 ) ) ; // Cap at 30s
224247 continue ;
225248 }
226249 Err ( e) => {
227- return Err ( format ! ( "Network error after {} retries: {}" , max_retries, e) . into ( ) ) ;
250+ return Err (
251+ format ! ( "Network error after {} retries: {}" , max_retries, e) . into ( ) ,
252+ ) ;
228253 }
229254 }
230255 }
@@ -237,9 +262,9 @@ impl AsyncTokenCounter {
237262 if let Ok ( json_str) = std:: str:: from_utf8 ( bytes) {
238263 if let Ok ( json_value) = serde_json:: from_str :: < serde_json:: Value > ( json_str) {
239264 // Check for basic tokenizer structure
240- return json_value. get ( "version" ) . is_some ( ) ||
241- json_value. get ( "vocab" ) . is_some ( ) ||
242- json_value. get ( "model" ) . is_some ( ) ;
265+ return json_value. get ( "version" ) . is_some ( )
266+ || json_value. get ( "vocab" ) . is_some ( )
267+ || json_value. get ( "model" ) . is_some ( ) ;
243268 }
244269 }
245270 false
@@ -977,18 +1002,26 @@ mod tests {
9771002 fn test_tokenizer_json_validation ( ) {
9781003 // Test valid tokenizer JSON
9791004 let valid_json = r#"{"version": "1.0", "model": {"type": "BPE"}}"# ;
980- assert ! ( AsyncTokenCounter :: is_valid_tokenizer_json( valid_json. as_bytes( ) ) ) ;
1005+ assert ! ( AsyncTokenCounter :: is_valid_tokenizer_json(
1006+ valid_json. as_bytes( )
1007+ ) ) ;
9811008
9821009 let valid_json2 = r#"{"vocab": {"hello": 1, "world": 2}}"# ;
983- assert ! ( AsyncTokenCounter :: is_valid_tokenizer_json( valid_json2. as_bytes( ) ) ) ;
1010+ assert ! ( AsyncTokenCounter :: is_valid_tokenizer_json(
1011+ valid_json2. as_bytes( )
1012+ ) ) ;
9841013
9851014 // Test invalid JSON
9861015 let invalid_json = r#"{"incomplete": true"# ;
987- assert ! ( !AsyncTokenCounter :: is_valid_tokenizer_json( invalid_json. as_bytes( ) ) ) ;
1016+ assert ! ( !AsyncTokenCounter :: is_valid_tokenizer_json(
1017+ invalid_json. as_bytes( )
1018+ ) ) ;
9881019
9891020 // Test valid JSON but not tokenizer structure
9901021 let wrong_structure = r#"{"random": "data", "not": "tokenizer"}"# ;
991- assert ! ( !AsyncTokenCounter :: is_valid_tokenizer_json( wrong_structure. as_bytes( ) ) ) ;
1022+ assert ! ( !AsyncTokenCounter :: is_valid_tokenizer_json(
1023+ wrong_structure. as_bytes( )
1024+ ) ) ;
9921025
9931026 // Test binary data
9941027 let binary_data = [ 0xFF , 0xFE , 0x00 , 0x01 ] ;
@@ -1003,21 +1036,22 @@ mod tests {
10031036 // This test would require mocking HTTP responses
10041037 // For now, we test the retry logic structure by verifying the function exists
10051038 // In a full test suite, you'd use wiremock or similar to simulate failures
1006-
1039+
10071040 // Test that the function exists and has the right signature
10081041 let client = reqwest:: Client :: new ( ) ;
1009-
1042+
10101043 // Test with a known bad URL to verify error handling
1011- let result = AsyncTokenCounter :: download_with_retry (
1012- & client,
1013- "https://httpbin.org/status/404" ,
1014- 1
1015- ) . await ;
1016-
1044+ let result =
1045+ AsyncTokenCounter :: download_with_retry ( & client, "https://httpbin.org/status/404" , 1 )
1046+ . await ;
1047+
10171048 assert ! ( result. is_err( ) , "Should fail with 404 error" ) ;
1018-
1049+
10191050 let error_msg = result. unwrap_err ( ) . to_string ( ) ;
1020- assert ! ( error_msg. contains( "Client error: 404" ) , "Should contain client error message" ) ;
1051+ assert ! (
1052+ error_msg. contains( "Client error: 404" ) ,
1053+ "Should contain client error message"
1054+ ) ;
10211055 }
10221056
10231057 #[ tokio:: test]
@@ -1027,29 +1061,28 @@ mod tests {
10271061 . timeout ( std:: time:: Duration :: from_millis ( 100 ) ) // Very short timeout
10281062 . build ( )
10291063 . unwrap ( ) ;
1030-
1064+
10311065 // Use httpbin delay endpoint that takes longer than our timeout
10321066 let result = AsyncTokenCounter :: download_with_retry (
10331067 & client,
10341068 "https://httpbin.org/delay/1" , // 1 second delay, but 100ms timeout
1035- 1
1036- ) . await ;
1037-
1069+ 1 ,
1070+ )
1071+ . await ;
1072+
10381073 assert ! ( result. is_err( ) , "Should timeout and fail" ) ;
10391074 }
10401075
10411076 #[ tokio:: test]
10421077 async fn test_successful_download_retry ( ) {
10431078 // Test successful download after simulated retry
10441079 let client = reqwest:: Client :: new ( ) ;
1045-
1080+
10461081 // Use a reliable endpoint that should succeed
1047- let result = AsyncTokenCounter :: download_with_retry (
1048- & client,
1049- "https://httpbin.org/status/200" ,
1050- 2
1051- ) . await ;
1052-
1082+ let result =
1083+ AsyncTokenCounter :: download_with_retry ( & client, "https://httpbin.org/status/200" , 2 )
1084+ . await ;
1085+
10531086 assert ! ( result. is_ok( ) , "Should succeed with 200 status" ) ;
10541087 }
10551088}
0 commit comments