1616import logging
1717import os
1818import pickle
19+ from concurrent .futures import ThreadPoolExecutor , as_completed
1920from typing import Any , Optional
2021
2122import torch
3536except ImportError :
3637 MOONCAKE_STORE_IMPORTED = False
3738
# Maximum number of keys handed to the store in one batched call; larger
# requests are sliced into chunks of this size.
BATCH_SIZE_LIMIT: int = 200
# Upper bound on the worker threads used to process those chunks concurrently.
MAX_WORKER_THREADS: int = 4
3941
4042
4143@StorageClientFactory .register ("MooncakeStoreClient" )
@@ -81,7 +83,7 @@ def __init__(self, config: dict[str, Any]):
8183 self .metadata_server = self .metadata_server + "/metadata"
8284
8385 self .replica_config = ReplicateConfig ()
84- # FIXME: hard_pin is not supported yet
86+ # FIXME: hard_pin support
8587 # self.replica_config.with_hard_pin = True
8688
8789 self ._store = MooncakeDistributedStore ()
@@ -97,7 +99,7 @@ def __init__(self, config: dict[str, Any]):
9799 if ret != 0 :
98100 raise RuntimeError (f"Mooncake store setup failed with error code: { ret } " )
99101
100- def put (self , keys : list [str ], values : list [Any ]) -> Optional [ list [ Any ]] :
102+ def put (self , keys : list [str ], values : list [Any ]) -> None :
101103 """Stores multiple key-value pairs to MooncakeStore.
102104
103105 Args:
@@ -121,43 +123,51 @@ def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:
121123 tensor_values .append (value )
122124 else :
123125 non_tensor_keys .append (key )
124- non_tensor_values .append (pickle . dumps ( value ) )
126+ non_tensor_values .append (value )
125127
126- if tensor_keys :
127- self ._batch_put_tensors (tensor_keys , tensor_values )
128+ futures = []
129+ with ThreadPoolExecutor (max_workers = MAX_WORKER_THREADS ) as executor :
130+ for i in range (0 , len (tensor_keys ), BATCH_SIZE_LIMIT ):
131+ batch_keys = tensor_keys [i : i + BATCH_SIZE_LIMIT ]
132+ batch_tensors = tensor_values [i : i + BATCH_SIZE_LIMIT ]
133+ futures .append (executor .submit (self ._put_tensors_thread_worker , batch_keys , batch_tensors ))
128134
129- if non_tensor_keys :
130- self ._batch_put_bytes (non_tensor_keys , non_tensor_values )
135+ for i in range (0 , len (non_tensor_keys ), BATCH_SIZE_LIMIT ):
136+ batch_keys = non_tensor_keys [i : i + BATCH_SIZE_LIMIT ]
137+ batch_values = non_tensor_values [i : i + BATCH_SIZE_LIMIT ]
138+ futures .append (executor .submit (self ._put_bytes_thread_worker , batch_keys , batch_values ))
139+
140+ for future in as_completed (futures ):
141+ future .result ()
131142
132143 return None
133144
def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[Tensor]):
    """Store one batch of tensors in MooncakeStore.

    Args:
        batch_keys: Keys to write, one per tensor.
        batch_tensors: Tensors to store, in the same order as ``batch_keys``.

    Raises:
        RuntimeError: If the store returns a different number of result codes
            than keys, or if any result code is non-zero.
    """
    # ``contiguous_tensors`` is not read again; keeping it bound keeps the
    # tensors whose addresses are in ``batch_ptrs`` referenced for the whole call.
    batch_ptrs, batch_sizes, contiguous_tensors = self._preprocess_tensors_for_put(batch_tensors)
    # Presumably coalesces adjacent regions so fewer buffers need registering
    # -- TODO confirm against merge_continues_memory.
    batch_ptr_reduced, batch_sizes_reduced = merge_continues_memory(batch_ptrs, batch_sizes)
    self._register_all_buffers(batch_ptr_reduced, batch_sizes_reduced)

    try:
        results = self._store.batch_upsert_from(batch_keys, batch_ptrs, batch_sizes, config=self.replica_config)
        # Without this check an empty or truncated result list would pass the
        # all-zero test below and be treated as success.
        if len(results) != len(batch_keys):
            raise RuntimeError(f"batch_upsert_from returned {len(results)} results, expected {len(batch_keys)}")
        failed_indices = [j for j, r in enumerate(results) if r != 0]
        if failed_indices:
            error_codes = [results[j] for j in failed_indices]
            raise RuntimeError(
                f"batch_put_tensor failed for indices {failed_indices} with error codes: {error_codes}"
            )
    finally:
        # Always undo the registration, including when the upsert fails.
        self._unregister_all_buffers(batch_ptr_reduced)
152162
def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any]):
    """Pickle one batch of non-tensor values and store them in MooncakeStore.

    Args:
        batch_keys: Keys to write, one per value.
        batch_values: Picklable objects, in the same order as ``batch_keys``.

    Raises:
        RuntimeError: If the store returns a non-zero code for the batch.
    """
    # Serialize into a separate list instead of rebinding the parameter, so the
    # raw objects and their byte payloads stay clearly distinct.
    payloads = [pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL) for value in batch_values]

    ret = self._store.upsert_batch(batch_keys, payloads, self.replica_config)
    if ret != 0:
        raise RuntimeError(f"put_batch failed with error code: {ret}")
161171
162172 def get (
163173 self ,
@@ -194,71 +204,61 @@ def get(
194204
195205 results = [None ] * len (keys )
196206
197- if tensor_indices :
198- tensor_keys = [keys [i ] for i in tensor_indices ]
199- tensor_shapes = [shapes [i ] for i in tensor_indices ]
200- tensor_dtypes = [dtypes [i ] for i in tensor_indices ]
201- tensor_results = self ._batch_get_tensors (tensor_keys , tensor_shapes , tensor_dtypes )
202- # TODO: optimize these for loops
203- for idx , tensor in zip (tensor_indices , tensor_results , strict = True ):
204- results [idx ] = tensor
205-
206- if non_tensor_indices :
207- non_tensor_keys = [keys [i ] for i in non_tensor_indices ]
208- non_tensor_results = self ._batch_get_bytes (non_tensor_keys )
209- for idx , data in zip (non_tensor_indices , non_tensor_results , strict = True ):
210- results [idx ] = pickle .loads (data )
211-
212- return results
213-
214- def _batch_get_tensors (self , keys : list [str ], shapes : list , dtypes : list ) -> list [Tensor ]:
215- tensors = [None ] * len (keys )
207+ futures = []
208+ with ThreadPoolExecutor (max_workers = MAX_WORKER_THREADS ) as executor :
209+ for i in range (0 , len (tensor_indices ), BATCH_SIZE_LIMIT ):
210+ batch_indexes = tensor_indices [i : i + BATCH_SIZE_LIMIT ]
211+ batch_keys = [keys [i ] for i in batch_indexes ]
212+ batch_shapes = [shapes [i ] for i in batch_indexes ]
213+ batch_dtypes = [dtypes [i ] for i in batch_indexes ]
214+ futures .append (
215+ executor .submit (
216+ self ._get_tensors_thread_worker , batch_keys , batch_shapes , batch_dtypes , batch_indexes
217+ )
218+ )
216219
217- for i in range (0 , len (keys ), BATCH_SIZE_LIMIT ):
218- batch_keys = keys [i : i + BATCH_SIZE_LIMIT ]
219- batch_shapes = shapes [ i : i + BATCH_SIZE_LIMIT ]
220- batch_dtypes = dtypes [ i : i + BATCH_SIZE_LIMIT ]
220+ for i in range (0 , len (non_tensor_indices ), BATCH_SIZE_LIMIT ):
221+ batch_indexes = non_tensor_indices [i : i + BATCH_SIZE_LIMIT ]
222+ batch_keys = [ keys [ i ] for i in batch_indexes ]
223+ futures . append ( executor . submit ( self . _get_bytes_thread_worker , batch_keys , batch_indexes ))
221224
222- batch_nbytes = get_nbytes (batch_dtypes , batch_shapes )
223- batch_buffer_tensors , batch_buffer_ptrs = allocate_empty_tensors (batch_dtypes , batch_shapes )
225+ for future in as_completed (futures ):
226+ retrieved_values , batch_indexes = future .result ()
227+ for idx , val in zip (batch_indexes , retrieved_values , strict = True ):
228+ results [idx ] = val
224229
225- batch_ptrs = batch_buffer_ptrs
230+ return results
226231
227- self ._register_all_buffers (batch_ptrs , batch_nbytes )
228- ret_codes = self ._store .batch_get_into (batch_keys , batch_ptrs , batch_nbytes )
229- self ._unregister_all_buffers (batch_ptrs )
def _get_tensors_thread_worker(
    self, batch_keys: list[str], batch_shapes: list[tuple], batch_dtypes: list[torch.dtype], indexes: list[int]
) -> tuple[list[Tensor], list[int]]:
    """Fetch one batch of tensors from MooncakeStore into freshly allocated buffers.

    Args:
        batch_keys: Keys to read.
        batch_shapes: Shape of the tensor expected for each key.
        batch_dtypes: Dtype of the tensor expected for each key.
        indexes: Passed through unchanged so the caller can map the tensors
            back to their slots.

    Returns:
        The allocated tensors together with ``indexes``.

    Raises:
        RuntimeError: If the number of return codes differs from the number of
            keys, or if any return code is negative.
    """
    sizes_in_bytes = get_nbytes(batch_dtypes, batch_shapes)
    out_tensors, out_ptrs = allocate_empty_tensors(batch_dtypes, batch_shapes)

    self._register_all_buffers(out_ptrs, sizes_in_bytes)
    try:
        codes = self._store.batch_get_into(batch_keys, out_ptrs, sizes_in_bytes)
        if len(codes) != len(batch_keys):
            raise RuntimeError(f"batch_get_into returned {len(codes)} results, expected {len(batch_keys)}")
        # Lengths are equal here, so pairing keys with codes is safe.
        for key, code in zip(batch_keys, codes):
            if code < 0:
                raise RuntimeError(f"batch_get_into failed for key `{key}` with error code: {code}")
    finally:
        # Buffers are unregistered whether or not the read succeeded.
        self._unregister_all_buffers(out_ptrs)

    return out_tensors, indexes
252250
def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) -> tuple[list[Any], list[int]]:
    """Fetch one batch of pickled values from MooncakeStore and deserialize them.

    Args:
        batch_keys: Keys to read.
        indexes: Passed through unchanged so the caller can map the values
            back to their slots.

    Returns:
        The unpickled values, in ``batch_keys`` order, together with ``indexes``.

    Raises:
        RuntimeError: If the store returns a different number of items than keys.
    """
    batch_results = self._store.get_batch(batch_keys)
    if len(batch_results) != len(batch_keys):
        raise RuntimeError(f"get_batch returned {len(batch_results)} items, expected {len(batch_keys)}")

    # NOTE(security): pickle.loads can execute arbitrary code while decoding;
    # this is only safe if everything written to the store is trusted.
    return [pickle.loads(payload) for payload in batch_results], indexes
262262
def clear(self, keys: list[str], custom_backend_meta=None):
    """Deletes multiple keys from MooncakeStore.

    Removal failures are logged as errors; they are not raised.

    Args:
        keys (List[str]): List of keys to remove.
        custom_backend_meta (List[Any], optional): Not used by this method.
    """
    # -704 is tolerated alongside 0; presumably "object not found" -- TODO confirm
    # against the Mooncake error-code table.
    tolerated = (0, -704)
    removal_codes = self._store.batch_remove(keys, force=True)
    for position, code in enumerate(removal_codes):
        if code in tolerated:
            continue
        logger.error(f"remove failed for key `{keys[position]}` with error code: {code}")
274274
275275 def close (self ):
276276 """Closes MooncakeStore."""
@@ -279,17 +279,19 @@ def close(self):
279279 self ._store = None
280280
281281 @staticmethod
282- def _preprocess_tensors_for_put (values : list [Tensor ]) -> tuple [list [Any ], list [Any ]]:
282+ def _preprocess_tensors_for_put (values : list [Tensor ]) -> tuple [list [Any ], list [Any ], list [ Tensor ] ]:
283283 ptr_list = []
284284 size_list = []
285+ tensor_list = [] # hold reference for the contiguous tensor
285286 for t in values :
286287 t = t .contiguous ()
288+ tensor_list .append (t )
287289 ptr_list .append (t .data_ptr ())
288290 size_list .append (t .nbytes )
289- return ptr_list , size_list
291+ return ptr_list , size_list , tensor_list
290292
291293 def _register_all_buffers (self , ptrs , sizes ):
292- for ptr , size in zip (ptrs , sizes , strict = False ):
294+ for ptr , size in zip (ptrs , sizes , strict = True ):
293295 self ._store .register_buffer (ptr , size )
294296
295297 def _unregister_all_buffers (self , ptrs ):
0 commit comments