@@ -1,15 +1,14 @@
 use include_dir::{include_dir, Dir};
 use mcp_core::Tool;
-use std::collections::HashMap;
 use std::error::Error;
 use std::hash::{Hash, Hasher};
 use std::path::Path;
 use std::sync::Arc;
-use std::collections::hash_map::DefaultHasher;
 use tokenizers::tokenizer::Tokenizer;
-use tokio::sync::{OnceCell, RwLock};
+use tokio::sync::OnceCell;
 use dashmap::DashMap;
 use std::fs;
+use ahash::AHasher;

 use crate::message::Message;

@@ -18,7 +17,11 @@ use crate::message::Message;
 static TOKENIZER_FILES: Dir = include_dir!("$CARGO_MANIFEST_DIR/../../tokenizer_files");

 // Global tokenizer cache to avoid repeated downloads and loading
-static TOKENIZER_CACHE: OnceCell<Arc<RwLock<HashMap<String, Arc<Tokenizer>>>>> = OnceCell::const_new();
+static TOKENIZER_CACHE: OnceCell<Arc<DashMap<String, Arc<Tokenizer>>>> = OnceCell::const_new();
+
+// Cache size limits to prevent unbounded growth
+const MAX_TOKEN_CACHE_SIZE: usize = 10_000;
+const MAX_TOKENIZER_CACHE_SIZE: usize = 50;

 /// Async token counter with caching capabilities
 pub struct AsyncTokenCounter {
@@ -36,18 +39,15 @@ impl AsyncTokenCounter {
     pub async fn new(tokenizer_name: &str) -> Result<Self, Box<dyn Error + Send + Sync>> {
         // Initialize global cache if not already done
         let cache = TOKENIZER_CACHE.get_or_init(|| async {
-            Arc::new(RwLock::new(HashMap::new()))
+            Arc::new(DashMap::new())
         }).await;

-        // Check cache first
-        {
-            let cache_read = cache.read().await;
-            if let Some(tokenizer) = cache_read.get(tokenizer_name) {
-                return Ok(Self {
-                    tokenizer: tokenizer.clone(),
-                    token_cache: Arc::new(DashMap::new()),
-                });
-            }
+        // Check cache first - DashMap allows concurrent reads
+        if let Some(tokenizer) = cache.get(tokenizer_name) {
+            return Ok(Self {
+                tokenizer: tokenizer.clone(),
+                token_cache: Arc::new(DashMap::new()),
+            });
         }

         // Try embedded first
@@ -59,11 +59,15 @@
             }
         };

-        // Cache the tokenizer
-        {
-            let mut cache_write = cache.write().await;
-            cache_write.insert(tokenizer_name.to_string(), tokenizer.clone());
+        // Cache the tokenizer with size management
+        if cache.len() >= MAX_TOKENIZER_CACHE_SIZE {
+            // Simple eviction: remove an arbitrary entry (DashMap iteration order is unspecified)
+            if let Some(entry) = cache.iter().next() {
+                let old_key = entry.key().clone();
+                cache.remove(&old_key);
+            }
         }
+        cache.insert(tokenizer_name.to_string(), tokenizer.clone());

         Ok(Self {
             tokenizer,
@@ -135,10 +139,10 @@ impl AsyncTokenCounter {
         Ok(())
     }

-    /// Count tokens with caching
+    /// Count tokens with optimized caching
     pub fn count_tokens(&self, text: &str) -> usize {
-        // Hash the input text for caching
-        let mut hasher = DefaultHasher::new();
+        // Use faster AHash for better performance
+        let mut hasher = AHasher::default();
         text.hash(&mut hasher);
         let hash = hasher.finish();

@@ -147,14 +151,24 @@ impl AsyncTokenCounter {
             return *count;
         }

-        // Compute and cache result
+        // Compute and cache result with size management
         let encoding = self.tokenizer.encode(text, false).unwrap_or_default();
         let count = encoding.len();
+
+        // Manage cache size to prevent unbounded growth
+        if self.token_cache.len() >= MAX_TOKEN_CACHE_SIZE {
+            // Simple eviction: remove a random entry
+            if let Some(entry) = self.token_cache.iter().next() {
+                let old_hash = *entry.key();
+                self.token_cache.remove(&old_hash);
+            }
+        }
+
         self.token_cache.insert(hash, count);
         count
     }

-    /// Count tokens for tools (using cached count_tokens)
+    /// Count tokens for tools with optimized string handling
     pub fn count_tokens_for_tools(&self, tools: &[Tool]) -> usize {
         // Token counts for different function components
         let func_init = 7; // Tokens for function initialization
@@ -170,6 +184,9 @@ impl AsyncTokenCounter {
             func_token_count += func_init;
             let name = &tool.name;
             let description = &tool.description.trim_end_matches('.');
+
+            // Keep the single formatted string rather than counting name and description separately:
+            // the separator (:) is likely tokenized with adjacent tokens, so this preserves accuracy
             let line = format!("{}:{}", name, description);
             func_token_count += self.count_tokens(&line);

@@ -184,8 +201,11 @@ impl AsyncTokenCounter {
                         .as_str()
                         .unwrap_or("")
                         .trim_end_matches('.');
+
+                    // Note: separators are tokenized with adjacent tokens, keep original for accuracy
                     let line = format!("{}:{}:{}", p_name, p_type, p_desc);
                     func_token_count += self.count_tokens(&line);
+
                     if let Some(enum_values) = value["enum"].as_array() {
                         func_token_count =
                             func_token_count.saturating_add_signed(enum_init);
@@ -227,6 +247,7 @@ impl AsyncTokenCounter {
                 num_tokens += self.count_tokens(content_text);
             } else if let Some(tool_request) = content.as_tool_request() {
                 let tool_call = tool_request.tool_call.as_ref().unwrap();
+                // Note: separators are tokenized with adjacent tokens, keep original for accuracy
                 let text = format!(
                     "{}:{}:{}",
                     tool_request.id, tool_call.name, tool_call.arguments
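
For context, a minimal usage sketch (not part of the commit) of the API this diff touches — `AsyncTokenCounter::new` and `count_tokens` as shown above. The tokenizer name is hypothetical; real names depend on what ships in tokenizer_files/:

```rust
async fn demo() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // First call for a given name loads the tokenizer (embedded or downloaded)
    // and stores it in the global TOKENIZER_CACHE behind the OnceCell.
    let counter = AsyncTokenCounter::new("gpt2").await?; // hypothetical name

    // Repeated counts of identical text hit the per-instance token_cache,
    // keyed by the AHash of the input, so only the first call pays for encode().
    let a = counter.count_tokens("hello world");
    let b = counter.count_tokens("hello world"); // served from cache
    assert_eq!(a, b);
    Ok(())
}
```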