1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import pickle
1716import time
1817from concurrent .futures import ThreadPoolExecutor , as_completed
1918from typing import Any
2019
20+ import numpy as np
2121import torch
22+ from tensordict import TensorDictBase
2223from torch import Tensor
2324
2425from transfer_queue .storage .clients .base import StorageClientFactory , StorageKVClient
26+ from transfer_queue .utils import serial_utils
2527from transfer_queue .utils .logging_utils import get_logger
26- from transfer_queue .utils .tensor_utils import allocate_empty_tensors , get_nbytes , merge_contiguous_memory
28+ from transfer_queue .utils .tensor_utils import allocate_empty_tensors , get_nbytes
2729
2830logger = get_logger (__name__ )
2931
4042RETRY_DELAY_SECONDS = 1.0
4143
4244
def _detach_from_buffer(obj: Any) -> Any:
    """Return a copy of ``obj`` whose tensor/array leaves own their storage.

    Dicts, lists and tuples are walked recursively and rebuilt; every
    ``torch.Tensor`` leaf is cloned and every ``np.ndarray`` leaf is copied.
    A ``TensorDictBase`` is rebuilt through ``apply`` with each tensor cloned.
    Any other object is returned unchanged (not copied).

    Args:
        obj: Arbitrary value, possibly nesting tensors/arrays inside
            dict / list / tuple containers.

    Returns:
        A structure of the same shape as ``obj``. Dict subclasses come back
        as plain ``dict`` and list subclasses as plain ``list``; namedtuples
        keep their class, other tuple subclasses come back as plain ``tuple``.
    """
    # TODO: replace with a keep-alive scheme on the source buffer to skip the copy.
    if isinstance(obj, torch.Tensor):
        return obj.clone()
    if isinstance(obj, np.ndarray):
        return obj.copy()
    if isinstance(obj, dict):
        return {k: _detach_from_buffer(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_detach_from_buffer(v) for v in obj]
    if isinstance(obj, tuple):
        copied = (_detach_from_buffer(v) for v in obj)
        tuple_cls = type(obj)
        # Namedtuples must be rebuilt through their own class; collapsing them
        # into a plain tuple would silently drop field-name access.
        if hasattr(tuple_cls, "_fields") and hasattr(tuple_cls, "_make"):
            return tuple_cls._make(copied)
        return tuple(copied)
    if isinstance(obj, TensorDictBase):
        return obj.apply(lambda t: t.clone())
    return obj
61+
62+
4363@StorageClientFactory .register ("MooncakeStoreClient" )
4464class MooncakeStoreClient (StorageKVClient ):
4565 """
4666 Storage client for MooncakeStore.
4767 """
4868
69+ _logged_first_put : bool = False
70+ _logged_first_get : bool = False
71+
4972 def __init__ (self , config : dict [str , Any ]):
5073 super ().__init__ (config )
5174 if not MOONCAKE_STORE_IMPORTED :
@@ -98,55 +121,70 @@ def __init__(self, config: dict[str, Any]):
98121 if ret != 0 :
99122 raise RuntimeError (f"Mooncake store setup failed with error code: { ret } " )
100123
def put(self, keys: list[str], values: list[Any]) -> list[dict]:
    """Store key-value pairs in MooncakeStore through one packed buffer per batch.

    The inputs are cut into chunks of at most ``BATCH_SIZE_LIMIT`` entries.
    Each value in a chunk is encoded with ``serial_utils.encode`` and all
    encoded values are packed back to back into a single ``uint8`` tensor.
    The chunk is then handed to ``_put_batch_worker`` on a thread pool,
    together with one pointer and one size per value.

    Args:
        keys (List[str]): List of unique string identifiers.
        values (List[Any]): List of values to store, one per key.

    Returns:
        List[Dict]: Per-key ``{"packed_size": int}`` metadata, in the same
        order as ``keys``.

    Raises:
        ValueError: If ``keys`` or ``values`` is not a list, or their lengths differ.
    """
    if not (isinstance(keys, list) and isinstance(values, list)):
        raise ValueError("keys and values must be lists")
    if len(keys) != len(values):
        raise ValueError("Number of keys must match number of values")

    # One-time marker log; the flag lives on the class, not the instance.
    if not type(self)._logged_first_put:
        logger.info("[TQ-MOONCAKE-REFACTOR] put() entered: unified pack-into data path")
        type(self)._logged_first_put = True

    per_key_meta: list[dict] = []
    pending = []
    with ThreadPoolExecutor(max_workers=MAX_WORKER_THREADS) as pool:
        for lo in range(0, len(values), BATCH_SIZE_LIMIT):
            hi = lo + BATCH_SIZE_LIMIT
            chunk_keys = keys[lo:hi]
            chunk_values = values[lo:hi]

            # Encode first so the exact size of the staging buffer is known
            # before it is allocated.
            encoded = [serial_utils.encode(value) for value in chunk_values]
            sizes = [serial_utils.calc_packed_size(frames) for frames in encoded]

            # TODO: switch to a MooncakeStore-allocated buffer once such an API exists.
            staging = torch.empty(sum(sizes), dtype=torch.uint8)
            staging_view = staging.numpy().data
            base_addr = staging.data_ptr()

            # Lay the encoded values out contiguously and remember the
            # absolute address of each slice.
            ptrs: list[int] = []
            cursor = 0
            for frames, nbytes in zip(encoded, sizes, strict=True):
                serial_utils.pack_into(staging_view[cursor : cursor + nbytes], frames)
                ptrs.append(base_addr + cursor)
                cursor += nbytes

            per_key_meta.extend({"packed_size": nbytes} for nbytes in sizes)
            # ``staging`` is passed along so the worker holds a reference to
            # the memory the pointers refer to.
            pending.append(pool.submit(self._put_batch_worker, chunk_keys, ptrs, sizes, staging))

        # Surface the first worker exception, if any.
        for finished in as_completed(pending):
            finished.result()

    return per_key_meta
177+
178+ def _put_batch_worker (
179+ self , batch_keys : list [str ], batch_ptrs : list [int ], batch_sizes : list [int ], big_buf : Tensor
180+ ) -> None :
181+ """Worker thread for putting one packed-buffer batch to MooncakeStore.
143182
144- def _put_tensors_thread_worker (self , batch_keys : list [str ], batch_tensors : list [Tensor ]) -> None :
145- """Worker thread for putting batch of tensors to MooncakeStore."""
183+ ``big_buf`` is passed only to keep the underlying storage alive while
184+ ``batch_ptrs`` (per-value slices into it) are in flight.
185+ """
146186
147- batch_ptrs , batch_sizes , _contiguous_tensors = self ._preprocess_tensors_for_put (batch_tensors )
148- batch_ptr_reduced , batch_sizes_reduced = merge_contiguous_memory (batch_ptrs , batch_sizes )
149- self ._register_all_buffers (batch_ptr_reduced , batch_sizes_reduced )
187+ self ._store .register_buffer (big_buf .data_ptr (), big_buf .nbytes )
150188
151189 try :
152190 results = self ._store .batch_upsert_from (batch_keys , batch_ptrs , batch_sizes , config = self .replica_config )
@@ -206,98 +244,65 @@ def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[
206244 )
207245
208246 finally :
209- self ._unregister_all_buffers (batch_ptr_reduced )
210-
211- def _put_bytes_thread_worker (self , batch_keys : list [str ], batch_values : list [Any ]):
212- """Worker thread for putting batch of non-tensors to MooncakeStore."""
213-
214- serialized_values = [pickle .dumps (v , protocol = pickle .HIGHEST_PROTOCOL ) for v in batch_values ]
215-
216- # FIXME: When MooncakeStore supports per-key status codes for upsert_batch and get_batch,
217- # switch the bytes write/read paths from whole-batch retry to per-key selective retry,
218- # matching the tensor-path behaviour.
219- ret = self ._store .upsert_batch (batch_keys , serialized_values , self .replica_config )
220- if ret == 0 :
221- return
247+ self ._store .unregister_buffer (big_buf .data_ptr ())
222248
223- logger .error (
224- f"upsert_batch failed for { len (batch_keys )} keys with error code: { ret } . "
225- f"Retrying up to { MAX_RETRIES } times..."
226- )
227-
228- for attempt in range (1 , MAX_RETRIES + 1 ):
229- ret = self ._store .upsert_batch (batch_keys , serialized_values , self .replica_config )
230- if ret == 0 :
231- logger .info ("upsert_batch succeeded after retransmission." )
232- return
233-
234- logger .error (
235- f"upsert_batch retry { attempt } /{ MAX_RETRIES } failed for { len (batch_keys )} keys with error code: { ret } ."
236- )
237- if attempt < MAX_RETRIES :
238- time .sleep (RETRY_DELAY_SECONDS )
239-
240- raise RuntimeError (
241- f"upsert_batch failed for { len (batch_keys )} keys with error code: { ret } after retrying { MAX_RETRIES } times."
242- )
243-
def get(self, keys: list[str], **kwargs) -> list[Any]:
    """Fetch and decode multiple values from MooncakeStore.

    Every key is read back as a flat ``uint8`` tensor of the size recorded
    in its ``packed_size`` metadata, using ``_get_tensors_thread_worker`` on
    a thread pool in chunks of at most ``BATCH_SIZE_LIMIT`` keys. Once all
    chunks have arrived, each buffer is unpacked and decoded through
    ``serial_utils``, and the decoded value is passed through
    ``_detach_from_buffer`` before being returned.

    Args:
        keys: Keys to fetch.
        **kwargs: Must contain ``custom_backend_meta`` — per-key dicts
            carrying ``"packed_size": int``.

    Returns:
        Retrieved values in the same order as input keys. A slot for which
        the worker delivered ``None`` is left as ``None``.

    Raises:
        ValueError: If ``custom_backend_meta`` is missing, has a different
            length than ``keys``, or has entries without ``"packed_size"``.
    """
    # One-time marker log; the flag lives on the class, not the instance.
    if not type(self)._logged_first_get:
        logger.info("[TQ-MOONCAKE-REFACTOR] get() entered: unified unpack+detach data path")
        type(self)._logged_first_get = True

    custom_backend_meta = kwargs.get("custom_backend_meta")
    if custom_backend_meta is None:
        raise ValueError("MooncakeStoreClient.get requires custom_backend_meta with per-key packed_size.")
    if len(custom_backend_meta) != len(keys):
        raise ValueError(
            f"Length of custom_backend_meta ({len(custom_backend_meta)}) must match keys ({len(keys)})"
        )

    try:
        packed_sizes = [entry["packed_size"] for entry in custom_backend_meta]
    except (KeyError, TypeError) as e:
        raise ValueError("custom_backend_meta entries must be dicts with 'packed_size'") from e

    total = len(keys)
    results: list[Any] = [None] * total

    # TODO: when MooncakeStore exposes a pre-registered receive-buffer API
    # (symmetric to YuanRong's get_buffers), drop the local alloc + register
    # below and hand decoded views straight from MooncakeStore's memory.
    pending = []
    with ThreadPoolExecutor(max_workers=MAX_WORKER_THREADS) as pool:
        for lo in range(0, total, BATCH_SIZE_LIMIT):
            hi = min(lo + BATCH_SIZE_LIMIT, total)
            slots = list(range(lo, hi))
            # Each payload is requested as a 1-D uint8 tensor of its packed size.
            chunk_shapes = [(packed_sizes[slot],) for slot in slots]
            chunk_dtypes = [torch.uint8] * (hi - lo)
            pending.append(
                pool.submit(self._get_tensors_thread_worker, keys[lo:hi], chunk_shapes, chunk_dtypes, slots)
            )

        # Chunks may finish in any order; the returned indexes place each
        # buffer back at its original position.
        for finished in as_completed(pending):
            buffers, slots = finished.result()
            for slot, raw in zip(slots, buffers, strict=True):
                results[slot] = raw

    # Replace each raw byte buffer with its decoded value.
    for slot, raw in enumerate(results):
        if raw is None:
            continue
        frames = serial_utils.unpack_from(raw.numpy().data)
        results[slot] = _detach_from_buffer(serial_utils.decode(frames))

    return results
303308
@@ -374,57 +379,6 @@ def _get_tensors_thread_worker(
374379
375380 return batch_buffer_tensors , indexes
376381
377- def _get_bytes_thread_worker (self , batch_keys : list [str ], indexes : list [int ]) -> tuple [list [Any ], list [int ]]:
378- raw_results = self ._store .get_batch (batch_keys )
379- if len (raw_results ) != len (batch_keys ):
380- raise RuntimeError (f"get_batch returned { len (raw_results )} items, expected { len (batch_keys )} " )
381-
382- # FIXME: Use MooncakeStore provided ret codes to detect transmission failures when supported
383- # Currently we rely on empty bytes (b'') to detect transmission failures because
384- # MooncakeStore does not currently return a separate status code per key.
385- failed_indices = [i for i , result in enumerate (raw_results ) if result == b"" ]
386- if failed_indices :
387- current_failed_keys = [batch_keys [i ] for i in failed_indices ]
388- current_failed_indices = failed_indices
389-
390- logger .error (f"get_batch failed for keys { current_failed_keys } . Retrying up to { MAX_RETRIES } times..." )
391-
392- for attempt in range (1 , MAX_RETRIES + 1 ):
393- retry_results = self ._store .get_batch (current_failed_keys )
394-
395- next_failed_keys = []
396- next_failed_indices = []
397-
398- for i , result in enumerate (retry_results ):
399- original_idx = current_failed_indices [i ]
400- if result == b"" :
401- next_failed_keys .append (current_failed_keys [i ])
402- next_failed_indices .append (original_idx )
403- else :
404- # Write the successfully retried value back to its original slot immediately.
405- raw_results [original_idx ] = result
406-
407- if not next_failed_indices :
408- logger .info ("get_batch succeeded after retransmission." )
409- break # All retries in this attempt succeeded.
410-
411- logger .error (f"get_batch retry { attempt } /{ MAX_RETRIES } failed for { len (next_failed_keys )} keys." )
412-
413- # Narrow down to still-failed items for the next retry attempt.
414- current_failed_keys = next_failed_keys
415- current_failed_indices = next_failed_indices
416-
417- if attempt < MAX_RETRIES :
418- time .sleep (RETRY_DELAY_SECONDS )
419- else :
420- # All retries exhausted.
421- raise RuntimeError (
422- f"get_batch failed for keys { current_failed_keys } after retrying { MAX_RETRIES } times."
423- )
424-
425- deserialized_results = [pickle .loads (result ) if result != b"" else None for result in raw_results ]
426- return deserialized_results , indexes
427-
428382 def clear (self , keys : list [str ], custom_backend_meta : list [Any ] | None = None ) -> None :
429383 """Deletes multiple keys from MooncakeStore.
430384
@@ -443,23 +397,6 @@ def close(self):
443397 self ._store .close ()
444398 self ._store = None
445399
446- @staticmethod
447- def _preprocess_tensors_for_put (values : list [Tensor ]) -> tuple [list [int ], list [int ], list [Tensor ]]:
448- ptr_list : list [int ] = []
449- size_list : list [int ] = []
450- tensor_list : list [Tensor ] = [] # hold reference for the contiguous tensor
451- for t in values :
452- # TODO: support gpu direct rdma and use different data paths.
453- # For GPU, it's more reasonable to perform data copy since
454- # The register overhead is much higher than CPU
455- if t .device .type == "cuda" :
456- t = t .cpu ()
457- t = t .contiguous ()
458- tensor_list .append (t )
459- ptr_list .append (t .data_ptr ())
460- size_list .append (t .nbytes )
461- return ptr_list , size_list , tensor_list
462-
463400 def _register_all_buffers (self , ptrs , sizes ):
464401 for ptr , size in zip (ptrs , sizes , strict = True ):
465402 self ._store .register_buffer (ptr , size )
0 commit comments