Skip to content

Commit 16a2b7c

Browse files
committed
refactor(mooncake): unify tensor and non-tensor data paths into a single packed-buffer batch transfer
1 parent 94baa19 commit 16a2b7c

2 files changed

Lines changed: 154 additions & 170 deletions

File tree

transfer_queue/storage/clients/mooncake_client.py

Lines changed: 107 additions & 170 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,19 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import pickle
1716
import time
1817
from concurrent.futures import ThreadPoolExecutor, as_completed
1918
from typing import Any
2019

20+
import numpy as np
2121
import torch
22+
from tensordict import TensorDictBase
2223
from torch import Tensor
2324

2425
from transfer_queue.storage.clients.base import StorageClientFactory, StorageKVClient
26+
from transfer_queue.utils import serial_utils
2527
from transfer_queue.utils.logging_utils import get_logger
26-
from transfer_queue.utils.tensor_utils import allocate_empty_tensors, get_nbytes, merge_contiguous_memory
28+
from transfer_queue.utils.tensor_utils import allocate_empty_tensors, get_nbytes
2729

2830
logger = get_logger(__name__)
2931

@@ -40,12 +42,33 @@
4042
RETRY_DELAY_SECONDS = 1.0
4143

4244

45+
def _detach_from_buffer(obj: Any) -> Any:
46+
"""Deep-copy all tensor/array leaves so the result owns its storage."""
47+
# TODO: replace with a keep-alive scheme on the source buffer to skip the copy.
48+
if isinstance(obj, torch.Tensor):
49+
return obj.clone()
50+
if isinstance(obj, np.ndarray):
51+
return obj.copy()
52+
if isinstance(obj, dict):
53+
return {k: _detach_from_buffer(v) for k, v in obj.items()}
54+
if isinstance(obj, list):
55+
return [_detach_from_buffer(v) for v in obj]
56+
if isinstance(obj, tuple):
57+
return tuple(_detach_from_buffer(v) for v in obj)
58+
if isinstance(obj, TensorDictBase):
59+
return obj.apply(lambda t: t.clone())
60+
return obj
61+
62+
4363
@StorageClientFactory.register("MooncakeStoreClient")
4464
class MooncakeStoreClient(StorageKVClient):
4565
"""
4666
Storage client for MooncakeStore.
4767
"""
4868

69+
_logged_first_put: bool = False
70+
_logged_first_get: bool = False
71+
4972
def __init__(self, config: dict[str, Any]):
5073
super().__init__(config)
5174
if not MOONCAKE_STORE_IMPORTED:
@@ -98,55 +121,70 @@ def __init__(self, config: dict[str, Any]):
98121
if ret != 0:
99122
raise RuntimeError(f"Mooncake store setup failed with error code: {ret}")
100123

101-
def put(self, keys: list[str], values: list[Any]) -> None:
124+
def put(self, keys: list[str], values: list[Any]) -> list[dict]:
    """Store key-value pairs in MooncakeStore, one packed buffer per batch.

    The inputs are split into chunks of at most ``BATCH_SIZE_LIMIT`` entries.
    For each chunk every value is encoded with ``serial_utils.encode``, the
    encoded frames are packed back-to-back into a single ``uint8`` tensor, and
    ``_put_batch_worker`` is submitted to a thread pool with the per-value
    ``(ptr, size)`` slices of that tensor.

    Args:
        keys (List[str]): List of unique string identifiers.
        values (List[Any]): List of values to store (tensors, scalars, dicts, etc.).

    Returns:
        List[Dict]: Per-key ``{"packed_size": int}`` metadata, in the same order as ``keys``.

    Raises:
        ValueError: If ``keys`` or ``values`` is not a list, or their lengths differ.
    """
    if not (isinstance(keys, list) and isinstance(values, list)):
        raise ValueError("keys and values must be lists")
    if len(keys) != len(values):
        raise ValueError("Number of keys must match number of values")

    # One-time marker log, tracked on the concrete class rather than the instance.
    cls = type(self)
    if not cls._logged_first_put:
        logger.info("[TQ-MOONCAKE-REFACTOR] put() entered: unified pack-into data path")
        cls._logged_first_put = True

    packed_meta: list[dict] = []
    pending = []
    with ThreadPoolExecutor(max_workers=MAX_WORKER_THREADS) as pool:
        for lo in range(0, len(values), BATCH_SIZE_LIMIT):
            hi = lo + BATCH_SIZE_LIMIT
            chunk_keys = keys[lo:hi]

            # Encode each value, then lay all encoded frames out in ONE
            # contiguous buffer so the worker registers a single region and
            # addresses each value by (ptr, size) within it.
            encoded = [serial_utils.encode(value) for value in values[lo:hi]]
            sizes = [serial_utils.calc_packed_size(frames) for frames in encoded]

            # TODO: switch to a MooncakeStore-allocated buffer once such an API exists.
            buf = torch.empty(sum(sizes), dtype=torch.uint8)
            base = buf.data_ptr()
            view = buf.numpy().data

            ptrs: list[int] = []
            cursor = 0
            for frames, nbytes in zip(encoded, sizes, strict=True):
                serial_utils.pack_into(view[cursor : cursor + nbytes], frames)
                ptrs.append(base + cursor)
                cursor += nbytes

            packed_meta += [{"packed_size": nbytes} for nbytes in sizes]
            # ``buf`` travels with the task so the memory behind ``ptrs`` stays alive.
            pending.append(pool.submit(self._put_batch_worker, chunk_keys, ptrs, sizes, buf))

        # Surface the first worker failure to the caller.
        for done in as_completed(pending):
            done.result()

    return packed_meta
177+
178+
def _put_batch_worker(
179+
self, batch_keys: list[str], batch_ptrs: list[int], batch_sizes: list[int], big_buf: Tensor
180+
) -> None:
181+
"""Worker thread for putting one packed-buffer batch to MooncakeStore.
143182
144-
def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[Tensor]) -> None:
145-
"""Worker thread for putting batch of tensors to MooncakeStore."""
183+
``big_buf`` is passed only to keep the underlying storage alive while
184+
``batch_ptrs`` (per-value slices into it) are in flight.
185+
"""
146186

147-
batch_ptrs, batch_sizes, _contiguous_tensors = self._preprocess_tensors_for_put(batch_tensors)
148-
batch_ptr_reduced, batch_sizes_reduced = merge_contiguous_memory(batch_ptrs, batch_sizes)
149-
self._register_all_buffers(batch_ptr_reduced, batch_sizes_reduced)
187+
self._store.register_buffer(big_buf.data_ptr(), big_buf.nbytes)
150188

151189
try:
152190
results = self._store.batch_upsert_from(batch_keys, batch_ptrs, batch_sizes, config=self.replica_config)
@@ -206,98 +244,65 @@ def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[
206244
)
207245

208246
finally:
209-
self._unregister_all_buffers(batch_ptr_reduced)
210-
211-
def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any]):
212-
"""Worker thread for putting batch of non-tensors to MooncakeStore."""
213-
214-
serialized_values = [pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL) for v in batch_values]
215-
216-
# FIXME: When MooncakeStore supports per-key status codes for upsert_batch and get_batch,
217-
# switch the bytes write/read paths from whole-batch retry to per-key selective retry,
218-
# matching the tensor-path behaviour.
219-
ret = self._store.upsert_batch(batch_keys, serialized_values, self.replica_config)
220-
if ret == 0:
221-
return
247+
self._store.unregister_buffer(big_buf.data_ptr())
222248

223-
logger.error(
224-
f"upsert_batch failed for {len(batch_keys)} keys with error code: {ret}. "
225-
f"Retrying up to {MAX_RETRIES} times..."
226-
)
227-
228-
for attempt in range(1, MAX_RETRIES + 1):
229-
ret = self._store.upsert_batch(batch_keys, serialized_values, self.replica_config)
230-
if ret == 0:
231-
logger.info("upsert_batch succeeded after retransmission.")
232-
return
233-
234-
logger.error(
235-
f"upsert_batch retry {attempt}/{MAX_RETRIES} failed for {len(batch_keys)} keys with error code: {ret}."
236-
)
237-
if attempt < MAX_RETRIES:
238-
time.sleep(RETRY_DELAY_SECONDS)
239-
240-
raise RuntimeError(
241-
f"upsert_batch failed for {len(batch_keys)} keys with error code: {ret} after retrying {MAX_RETRIES} times."
242-
)
243-
244-
def get(
245-
self,
246-
keys: list[str],
247-
shapes: list[Any] | None = None,
248-
dtypes: list[Any] | None = None,
249-
custom_backend_meta: list[str] | None = None,
250-
) -> list[Any]:
249+
def get(self, keys: list[str], **kwargs) -> list[Any]:
    """Fetch and decode multiple values from MooncakeStore.

    Each key is fetched as a flat ``uint8`` tensor of ``packed_size`` bytes via
    ``_get_tensors_thread_worker`` (in chunks of at most ``BATCH_SIZE_LIMIT``
    keys). Every fetched buffer is then unpacked and decoded with
    ``serial_utils``, and copied with ``_detach_from_buffer``. A slot whose
    fetched buffer is ``None`` stays ``None``.

    Args:
        keys: Keys to fetch.
        **kwargs: Must contain ``custom_backend_meta`` — per-key dicts
            carrying ``"packed_size": int``.

    Returns:
        Retrieved values in the same order as input keys.

    Raises:
        ValueError: If ``custom_backend_meta`` is missing, its length differs
            from ``keys``, or an entry lacks ``"packed_size"``.
    """
    # One-time marker log, tracked on the concrete class rather than the instance.
    cls = type(self)
    if not cls._logged_first_get:
        logger.info("[TQ-MOONCAKE-REFACTOR] get() entered: unified unpack+detach data path")
        cls._logged_first_get = True

    custom_backend_meta = kwargs.get("custom_backend_meta")
    if custom_backend_meta is None:
        raise ValueError("MooncakeStoreClient.get requires custom_backend_meta with per-key packed_size.")
    if len(custom_backend_meta) != len(keys):
        raise ValueError(
            f"Length of custom_backend_meta ({len(custom_backend_meta)}) must match keys ({len(keys)})"
        )

    try:
        packed_sizes = [entry["packed_size"] for entry in custom_backend_meta]
    except (KeyError, TypeError) as e:
        raise ValueError("custom_backend_meta entries must be dicts with 'packed_size'") from e

    total = len(keys)
    results: list[Any] = [None] * total

    # TODO: when MooncakeStore exposes a pre-registered receive-buffer API
    # (symmetric to YuanRong's get_buffers), drop the local alloc + register
    # below and hand decoded views straight from MooncakeStore's memory.
    pending = []
    with ThreadPoolExecutor(max_workers=MAX_WORKER_THREADS) as pool:
        for lo in range(0, total, BATCH_SIZE_LIMIT):
            hi = min(lo + BATCH_SIZE_LIMIT, total)
            slots = list(range(lo, hi))
            pending.append(
                pool.submit(
                    self._get_tensors_thread_worker,
                    keys[lo:hi],
                    # Each value is requested as a 1-D byte tensor of its packed size.
                    [(nbytes,) for nbytes in packed_sizes[lo:hi]],
                    [torch.uint8] * len(slots),
                    slots,
                )
            )

        # First pass: drop every raw packed buffer into its output slot.
        for done in as_completed(pending):
            packed_tensors, slots = done.result()
            for slot, packed in zip(slots, packed_tensors, strict=True):
                results[slot] = packed

    # Second pass: replace each raw buffer with its decoded, storage-owning value.
    for slot, packed in enumerate(results):
        if packed is not None:
            frames = serial_utils.unpack_from(packed.numpy().data)
            results[slot] = _detach_from_buffer(serial_utils.decode(frames))

    return results
303308

@@ -374,57 +379,6 @@ def _get_tensors_thread_worker(
374379

375380
return batch_buffer_tensors, indexes
376381

377-
def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) -> tuple[list[Any], list[int]]:
378-
raw_results = self._store.get_batch(batch_keys)
379-
if len(raw_results) != len(batch_keys):
380-
raise RuntimeError(f"get_batch returned {len(raw_results)} items, expected {len(batch_keys)}")
381-
382-
# FIXME: Use MooncakeStore provided ret codes to detect transmission failures when supported
383-
# Currently we rely on empty bytes (b'') to detect transmission failures because
384-
# MooncakeStore does not currently return a separate status code per key.
385-
failed_indices = [i for i, result in enumerate(raw_results) if result == b""]
386-
if failed_indices:
387-
current_failed_keys = [batch_keys[i] for i in failed_indices]
388-
current_failed_indices = failed_indices
389-
390-
logger.error(f"get_batch failed for keys {current_failed_keys}. Retrying up to {MAX_RETRIES} times...")
391-
392-
for attempt in range(1, MAX_RETRIES + 1):
393-
retry_results = self._store.get_batch(current_failed_keys)
394-
395-
next_failed_keys = []
396-
next_failed_indices = []
397-
398-
for i, result in enumerate(retry_results):
399-
original_idx = current_failed_indices[i]
400-
if result == b"":
401-
next_failed_keys.append(current_failed_keys[i])
402-
next_failed_indices.append(original_idx)
403-
else:
404-
# Write the successfully retried value back to its original slot immediately.
405-
raw_results[original_idx] = result
406-
407-
if not next_failed_indices:
408-
logger.info("get_batch succeeded after retransmission.")
409-
break # All retries in this attempt succeeded.
410-
411-
logger.error(f"get_batch retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys.")
412-
413-
# Narrow down to still-failed items for the next retry attempt.
414-
current_failed_keys = next_failed_keys
415-
current_failed_indices = next_failed_indices
416-
417-
if attempt < MAX_RETRIES:
418-
time.sleep(RETRY_DELAY_SECONDS)
419-
else:
420-
# All retries exhausted.
421-
raise RuntimeError(
422-
f"get_batch failed for keys {current_failed_keys} after retrying {MAX_RETRIES} times."
423-
)
424-
425-
deserialized_results = [pickle.loads(result) if result != b"" else None for result in raw_results]
426-
return deserialized_results, indexes
427-
428382
def clear(self, keys: list[str], custom_backend_meta: list[Any] | None = None) -> None:
429383
"""Deletes multiple keys from MooncakeStore.
430384
@@ -443,23 +397,6 @@ def close(self):
443397
self._store.close()
444398
self._store = None
445399

446-
@staticmethod
447-
def _preprocess_tensors_for_put(values: list[Tensor]) -> tuple[list[int], list[int], list[Tensor]]:
448-
ptr_list: list[int] = []
449-
size_list: list[int] = []
450-
tensor_list: list[Tensor] = [] # hold reference for the contiguous tensor
451-
for t in values:
452-
# TODO: support gpu direct rdma and use different data paths.
453-
# For GPU, it's more reasonable to perform data copy since
454-
# The register overhead is much higher than CPU
455-
if t.device.type == "cuda":
456-
t = t.cpu()
457-
t = t.contiguous()
458-
tensor_list.append(t)
459-
ptr_list.append(t.data_ptr())
460-
size_list.append(t.nbytes)
461-
return ptr_list, size_list, tensor_list
462-
463400
def _register_all_buffers(self, ptrs, sizes):
464401
for ptr, size in zip(ptrs, sizes, strict=True):
465402
self._store.register_buffer(ptr, size)

0 commit comments

Comments
 (0)