From 0a1e7ac937de4ef2aece3d92a55e934da8d628e7 Mon Sep 17 00:00:00 2001 From: xupinjie Date: Mon, 25 May 2026 07:31:55 -0700 Subject: [PATCH 1/5] refactor(mooncake): non-tensor data paths update --- .../storage/clients/mooncake_client.py | 401 +++++++++--------- transfer_queue/utils/serial_utils.py | 110 ++++- 2 files changed, 313 insertions(+), 198 deletions(-) diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index ede1de4e..a0c8412f 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pickle import time from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any @@ -22,6 +21,7 @@ from torch import Tensor from transfer_queue.storage.clients.base import StorageClientFactory, StorageKVClient +from transfer_queue.utils import serial_utils from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.tensor_utils import allocate_empty_tensors, get_nbytes, merge_contiguous_memory @@ -98,12 +98,17 @@ def __init__(self, config: dict[str, Any]): if ret != 0: raise RuntimeError(f"Mooncake store setup failed with error code: {ret}") - def put(self, keys: list[str], values: list[Any]) -> None: + def put(self, keys: list[str], values: list[Any]) -> list[dict | None]: """Stores multiple key-value pairs to MooncakeStore. Args: keys (List[str]): List of unique string identifiers. values (List[Any]): List of values to store (tensors, scalars, dicts, etc.). + + Returns: + Per-key metadata aligned with ``keys``. Tensor entries are ``None``; + non-tensor entries carry ``{"packed_size": int}`` so the get-side + can pre-allocate the receive buffer. 
""" if not isinstance(keys, list) or not isinstance(values, list): @@ -124,22 +129,34 @@ def put(self, keys: list[str], values: list[Any]) -> None: non_tensor_keys.append(key) non_tensor_values.append(value) - futures = [] + tensor_futures = [] + bytes_futures = [] with ThreadPoolExecutor(max_workers=MAX_WORKER_THREADS) as executor: for i in range(0, len(tensor_keys), BATCH_SIZE_LIMIT): batch_keys = tensor_keys[i : i + BATCH_SIZE_LIMIT] batch_tensors = tensor_values[i : i + BATCH_SIZE_LIMIT] - futures.append(executor.submit(self._put_tensors_thread_worker, batch_keys, batch_tensors)) + tensor_futures.append(executor.submit(self._put_tensors_thread_worker, batch_keys, batch_tensors)) for i in range(0, len(non_tensor_keys), BATCH_SIZE_LIMIT): batch_keys = non_tensor_keys[i : i + BATCH_SIZE_LIMIT] batch_values = non_tensor_values[i : i + BATCH_SIZE_LIMIT] - futures.append(executor.submit(self._put_bytes_thread_worker, batch_keys, batch_values)) + bytes_futures.append(executor.submit(self._put_bytes_thread_worker, batch_keys, batch_values)) - for future in as_completed(futures): + for future in tensor_futures: future.result() + packed_sizes = [] + for future in bytes_futures: + packed_sizes.extend(future.result()) + + # bytes results arrive in non-tensor submit order, which matches the order of + # non-tensor values; walk values once to scatter packed_size back to its key slot. 
+ custom_meta: list[dict | None] = [None] * len(keys) + sizes_iter = iter(packed_sizes) + for i, value in enumerate(values): + if not isinstance(value, torch.Tensor): + custom_meta[i] = {"packed_size": next(sizes_iter)} - return None + return custom_meta def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[Tensor]) -> None: """Worker thread for putting batch of tensors to MooncakeStore.""" @@ -147,106 +164,42 @@ def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[ batch_ptrs, batch_sizes, _contiguous_tensors = self._preprocess_tensors_for_put(batch_tensors) batch_ptr_reduced, batch_sizes_reduced = merge_contiguous_memory(batch_ptrs, batch_sizes) self._register_all_buffers(batch_ptr_reduced, batch_sizes_reduced) - try: - results = self._store.batch_upsert_from(batch_keys, batch_ptrs, batch_sizes, config=self.replica_config) - if len(results) != len(batch_keys): - raise RuntimeError(f"batch_upsert_from returned {len(results)} results, expected {len(batch_keys)}") - - failed_indices = [j for j, r in enumerate(results) if r != 0] - if not failed_indices: - return - - current_failed_keys = [batch_keys[i] for i in failed_indices] - current_failed_codes = [results[i] for i in failed_indices] - current_failed_indices = failed_indices - - logger.error( - f"batch_upsert_from failed for keys {current_failed_keys} with error codes {current_failed_codes}. " - f"Retrying up to {MAX_RETRIES} times..." 
- ) - - for attempt in range(1, MAX_RETRIES + 1): - retry_ptrs = [batch_ptrs[i] for i in current_failed_indices] - retry_sizes = [batch_sizes[i] for i in current_failed_indices] - - retry_results = self._store.batch_upsert_from( - current_failed_keys, retry_ptrs, retry_sizes, config=self.replica_config - ) - - next_failed_indices = [] - next_failed_keys = [] - next_failed_codes = [] - - for i, ret in enumerate(retry_results): - if ret != 0: - next_failed_indices.append(current_failed_indices[i]) - next_failed_keys.append(current_failed_keys[i]) - next_failed_codes.append(ret) - - if not next_failed_indices: - logger.info("batch_upsert_from succeeded after retransmission.") - break # All retries in this attempt succeeded. - - logger.error( - f"batch_upsert_from retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys " - f"with error codes {next_failed_codes}." - ) - - current_failed_indices = next_failed_indices - current_failed_keys = next_failed_keys - current_failed_codes = next_failed_codes - - if attempt < MAX_RETRIES: - time.sleep(RETRY_DELAY_SECONDS) - else: - raise RuntimeError( - f"batch_upsert_from failed for keys {current_failed_keys} with error codes " - f"{current_failed_codes} after retrying {MAX_RETRIES} times." - ) - + self._batch_upsert_with_retry(batch_keys, batch_ptrs, batch_sizes) finally: self._unregister_all_buffers(batch_ptr_reduced) - def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any]): + def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any]) -> list[int]: """Worker thread for putting batch of non-tensors to MooncakeStore.""" - serialized_values = [pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL) for v in batch_values] + # [TQ-REFACTOR-VERIFY] one-shot marker — REMOVE after verifying new bytes-put data path is hit. 
+ if not type(self).__dict__.get("_put_bytes_verified"): + type(self)._put_bytes_verified = True - # FIXME: When MooncakeStore supports per-key status codes for upsert_batch and get_batch, - # switch the bytes write/read paths from whole-batch retry to per-key selective retry, - # matching the tensor-path behaviour. - ret = self._store.upsert_batch(batch_keys, serialized_values, self.replica_config) - if ret == 0: - return - - logger.error( - f"upsert_batch failed for {len(batch_keys)} keys with error code: {ret}. " - f"Retrying up to {MAX_RETRIES} times..." - ) + # Encode + pack happens in-thread (so CPU serialization parallelizes); + # alloc is our hook to get a torch buffer of the right total size. + # TODO: switch to a pre-registered buffer from MooncakeStore once such + # an API is available, so we can skip the explicit register/unregister. + def alloc(total: int) -> tuple[Tensor, int]: + tensors, ptrs, _, _ = allocate_empty_tensors([torch.uint8], [(total,)]) + return tensors[0], ptrs[0] - for attempt in range(1, MAX_RETRIES + 1): - ret = self._store.upsert_batch(batch_keys, serialized_values, self.replica_config) - if ret == 0: - logger.info("upsert_batch succeeded after retransmission.") - return + big_buf, batch_ptrs, batch_sizes = serial_utils.batch_encode_into(batch_values, alloc) - logger.error( - f"upsert_batch retry {attempt}/{MAX_RETRIES} failed for {len(batch_keys)} keys with error code: {ret}." - ) - if attempt < MAX_RETRIES: - time.sleep(RETRY_DELAY_SECONDS) + self._store.register_buffer(big_buf.data_ptr(), big_buf.nbytes) + try: + self._batch_upsert_with_retry(batch_keys, batch_ptrs, batch_sizes) + finally: + self._store.unregister_buffer(big_buf.data_ptr()) - raise RuntimeError( - f"upsert_batch failed for {len(batch_keys)} keys with error code: {ret} after retrying {MAX_RETRIES} times." 
- ) + return batch_sizes def get( self, keys: list[str], shapes: list[Any] | None = None, dtypes: list[Any] | None = None, - custom_backend_meta: list[str] | None = None, + custom_backend_meta: list[dict | None] | None = None, ) -> list[Any]: """Get multiple key-value pairs from MooncakeStore. @@ -254,7 +207,8 @@ def get( keys: Keys to fetch. shapes: Expected tensor shapes (use [] for scalars). dtypes: Expected dtypes; use None for non-tensor data. - custom_backend_meta: Optional custom backend metadata. + custom_backend_meta: Per-key dicts; non-tensor entries must carry + ``{"packed_size": int}`` so the receive buffer can be sized. Returns: Retrieved values in the same order as input keys. @@ -274,6 +228,11 @@ def get( else: non_tensor_indices.append(i) + if non_tensor_indices and (custom_backend_meta is None or len(custom_backend_meta) != len(keys)): + raise ValueError( + "custom_backend_meta with per-key packed_size is required when any dtype is None." + ) + results = [None] * len(keys) futures = [] @@ -292,7 +251,12 @@ def get( for i in range(0, len(non_tensor_indices), BATCH_SIZE_LIMIT): batch_indexes = non_tensor_indices[i : i + BATCH_SIZE_LIMIT] batch_keys = [keys[i] for i in batch_indexes] - futures.append(executor.submit(self._get_bytes_thread_worker, batch_keys, batch_indexes)) + batch_packed_sizes = [custom_backend_meta[j]["packed_size"] for j in batch_indexes] + futures.append( + executor.submit( + self._get_bytes_thread_worker, batch_keys, batch_packed_sizes, batch_indexes + ) + ) for future in as_completed(futures): retrieved_values, batch_indexes = future.result() @@ -311,137 +275,180 @@ def _get_tensors_thread_worker( self._register_all_buffers(region_ptrs, region_sizes) try: - ret_codes = self._store.batch_get_into(batch_keys, batch_buffer_ptrs, batch_nbytes) - if len(ret_codes) != len(batch_keys): - raise RuntimeError(f"batch_get_into returned {len(ret_codes)} results, expected {len(batch_keys)}") + self._batch_get_into_with_retry(batch_keys, 
batch_buffer_ptrs, batch_nbytes) + finally: + self._unregister_all_buffers(region_ptrs) - failed_indices = [i for i, ret in enumerate(ret_codes) if ret < 0] - if not failed_indices: - return batch_buffer_tensors, indexes + return batch_buffer_tensors, indexes - # error handling - current_failed_keys = [batch_keys[i] for i in failed_indices] - current_failed_codes = [ret_codes[i] for i in failed_indices] - current_failed_indices = failed_indices + def _get_bytes_thread_worker( + self, batch_keys: list[str], batch_packed_sizes: list[int], indexes: list[int] + ) -> tuple[list[Any], list[int]]: + # [TQ-REFACTOR-VERIFY] one-shot marker — REMOVE after verifying new bytes-get data path is hit. + if not type(self).__dict__.get("_get_bytes_verified"): + type(self)._get_bytes_verified = True + + # Allocate uint8 receive buffers of packed_size; unpack + decode in-thread. + # Decoded tensors are zero-copy views; their storage transitively keeps + # batch_buffer_tensors alive after this function returns. + batch_shapes = [(sz,) for sz in batch_packed_sizes] + batch_dtypes = [torch.uint8] * len(batch_keys) + batch_nbytes = get_nbytes(batch_dtypes, batch_shapes) + batch_buffer_tensors, batch_buffer_ptrs, region_ptrs, region_sizes = allocate_empty_tensors( + batch_dtypes, batch_shapes + ) - logger.error( - f"batch_get_into failed for keys {current_failed_keys} with error codes {current_failed_codes}. " - f"Retrying up to {MAX_RETRIES} times..." - ) + self._register_all_buffers(region_ptrs, region_sizes) + try: + self._batch_get_into_with_retry(batch_keys, batch_buffer_ptrs, batch_nbytes) + finally: + self._unregister_all_buffers(region_ptrs) + + return serial_utils.batch_decode_from(batch_buffer_tensors), indexes + + def clear(self, keys: list[str], custom_backend_meta: list[Any] | None = None) -> None: + """Deletes multiple keys from MooncakeStore. + + Args: + keys (List[str]): List of keys to remove. + custom_backend_meta (List[Any], optional): ... 
+ """ + ret_codes = self._store.batch_remove(keys, force=True) + for i, ret in enumerate(ret_codes): + if not (ret == 0 or ret == -704): + logger.error(f"remove failed for key `{keys[i]}` with error code: {ret}") - for attempt in range(1, MAX_RETRIES + 1): - # Reuse the originally allocated pointers; no need to allocate/register new buffers. - retry_ptrs = [batch_buffer_ptrs[i] for i in current_failed_indices] - retry_nbytes = [batch_nbytes[i] for i in current_failed_indices] + def close(self): + """Closes MooncakeStore.""" + if self._store: + self._store.close() + self._store = None - retry_codes = self._store.batch_get_into(current_failed_keys, retry_ptrs, retry_nbytes) + def _batch_upsert_with_retry( + self, batch_keys: list[str], batch_ptrs: list[int], batch_sizes: list[int] + ) -> None: + """Run ``batch_upsert_from`` with per-key retry; raise on permanent failure. - next_failed_indices = [] - next_failed_keys = [] - next_failed_codes = [] + Caller owns the memory regions (register/unregister and lifetime of the + backing tensors/buffers). + """ + results = self._store.batch_upsert_from(batch_keys, batch_ptrs, batch_sizes, config=self.replica_config) + if len(results) != len(batch_keys): + raise RuntimeError(f"batch_upsert_from returned {len(results)} results, expected {len(batch_keys)}") - for i, ret in enumerate(retry_codes): - if ret < 0: - next_failed_indices.append(current_failed_indices[i]) - next_failed_keys.append(current_failed_keys[i]) - next_failed_codes.append(ret) + failed_indices = [j for j, r in enumerate(results) if r != 0] + if not failed_indices: + return - if not next_failed_indices: - logger.info("batch_get_into succeeded after retransmission.") - break # All retries in this attempt succeeded. 
+ current_failed_keys = [batch_keys[i] for i in failed_indices] + current_failed_codes = [results[i] for i in failed_indices] + current_failed_indices = failed_indices - logger.error( - f"batch_get_into retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys " - f"with error codes {next_failed_codes}." - ) + logger.error( + f"batch_upsert_from failed for keys {current_failed_keys} with error codes {current_failed_codes}. " + f"Retrying up to {MAX_RETRIES} times..." + ) - # Narrow down to still-failed items for the next retry attempt. - current_failed_indices = next_failed_indices - current_failed_keys = next_failed_keys - current_failed_codes = next_failed_codes + for attempt in range(1, MAX_RETRIES + 1): + retry_ptrs = [batch_ptrs[i] for i in current_failed_indices] + retry_sizes = [batch_sizes[i] for i in current_failed_indices] - if attempt < MAX_RETRIES: - time.sleep(RETRY_DELAY_SECONDS) - else: - # All retries exhausted. - raise RuntimeError( - f"batch_get_into failed for keys {current_failed_keys} with error codes " - f"{current_failed_codes} after retrying {MAX_RETRIES} times." - ) + retry_results = self._store.batch_upsert_from( + current_failed_keys, retry_ptrs, retry_sizes, config=self.replica_config + ) - finally: - self._unregister_all_buffers(region_ptrs) + next_failed_indices = [] + next_failed_keys = [] + next_failed_codes = [] - return batch_buffer_tensors, indexes + for i, ret in enumerate(retry_results): + if ret != 0: + next_failed_indices.append(current_failed_indices[i]) + next_failed_keys.append(current_failed_keys[i]) + next_failed_codes.append(ret) + + if not next_failed_indices: + logger.info("batch_upsert_from succeeded after retransmission.") + return + + logger.error( + f"batch_upsert_from retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys " + f"with error codes {next_failed_codes}." 
+ ) + + current_failed_indices = next_failed_indices + current_failed_keys = next_failed_keys + current_failed_codes = next_failed_codes + + if attempt < MAX_RETRIES: + time.sleep(RETRY_DELAY_SECONDS) - def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) -> tuple[list[Any], list[int]]: - raw_results = self._store.get_batch(batch_keys) - if len(raw_results) != len(batch_keys): - raise RuntimeError(f"get_batch returned {len(raw_results)} items, expected {len(batch_keys)}") + raise RuntimeError( + f"batch_upsert_from failed for keys {current_failed_keys} with error codes " + f"{current_failed_codes} after retrying {MAX_RETRIES} times." + ) - # FIXME: Use MooncakeStore provided ret codes to detect transmission failures when supported - # Currently we rely on empty bytes (b'') to detect transmission failures because - # MooncakeStore does not currently return a separate status code per key. - failed_indices = [i for i, result in enumerate(raw_results) if result == b""] - if failed_indices: - current_failed_keys = [batch_keys[i] for i in failed_indices] - current_failed_indices = failed_indices + def _batch_get_into_with_retry( + self, batch_keys: list[str], batch_buffer_ptrs: list[int], batch_nbytes: list[int] + ) -> None: + """Run ``batch_get_into`` with per-key retry; raise on permanent failure. - logger.error(f"get_batch failed for keys {current_failed_keys}. Retrying up to {MAX_RETRIES} times...") + Caller owns the receive buffers (allocate/register/unregister). 
+ """ + ret_codes = self._store.batch_get_into(batch_keys, batch_buffer_ptrs, batch_nbytes) + if len(ret_codes) != len(batch_keys): + raise RuntimeError(f"batch_get_into returned {len(ret_codes)} results, expected {len(batch_keys)}") - for attempt in range(1, MAX_RETRIES + 1): - retry_results = self._store.get_batch(current_failed_keys) + failed_indices = [i for i, ret in enumerate(ret_codes) if ret < 0] + if not failed_indices: + return - next_failed_keys = [] - next_failed_indices = [] + current_failed_keys = [batch_keys[i] for i in failed_indices] + current_failed_codes = [ret_codes[i] for i in failed_indices] + current_failed_indices = failed_indices - for i, result in enumerate(retry_results): - original_idx = current_failed_indices[i] - if result == b"": - next_failed_keys.append(current_failed_keys[i]) - next_failed_indices.append(original_idx) - else: - # Write the successfully retried value back to its original slot immediately. - raw_results[original_idx] = result + logger.error( + f"batch_get_into failed for keys {current_failed_keys} with error codes {current_failed_codes}. " + f"Retrying up to {MAX_RETRIES} times..." + ) - if not next_failed_indices: - logger.info("get_batch succeeded after retransmission.") - break # All retries in this attempt succeeded. + for attempt in range(1, MAX_RETRIES + 1): + # Reuse the originally allocated pointers; no need to allocate/register new buffers. + retry_ptrs = [batch_buffer_ptrs[i] for i in current_failed_indices] + retry_nbytes = [batch_nbytes[i] for i in current_failed_indices] - logger.error(f"get_batch retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys.") + retry_codes = self._store.batch_get_into(current_failed_keys, retry_ptrs, retry_nbytes) - # Narrow down to still-failed items for the next retry attempt. 
- current_failed_keys = next_failed_keys - current_failed_indices = next_failed_indices + next_failed_indices = [] + next_failed_keys = [] + next_failed_codes = [] - if attempt < MAX_RETRIES: - time.sleep(RETRY_DELAY_SECONDS) - else: - # All retries exhausted. - raise RuntimeError( - f"get_batch failed for keys {current_failed_keys} after retrying {MAX_RETRIES} times." - ) + for i, ret in enumerate(retry_codes): + if ret < 0: + next_failed_indices.append(current_failed_indices[i]) + next_failed_keys.append(current_failed_keys[i]) + next_failed_codes.append(ret) - deserialized_results = [pickle.loads(result) if result != b"" else None for result in raw_results] - return deserialized_results, indexes + if not next_failed_indices: + logger.info("batch_get_into succeeded after retransmission.") + return - def clear(self, keys: list[str], custom_backend_meta: list[Any] | None = None) -> None: - """Deletes multiple keys from MooncakeStore. + logger.error( + f"batch_get_into retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys " + f"with error codes {next_failed_codes}." + ) - Args: - keys (List[str]): List of keys to remove. - custom_backend_meta (List[Any], optional): ... - """ - ret_codes = self._store.batch_remove(keys, force=True) - for i, ret in enumerate(ret_codes): - if not (ret == 0 or ret == -704): - logger.error(f"remove failed for key `{keys[i]}` with error code: {ret}") + current_failed_indices = next_failed_indices + current_failed_keys = next_failed_keys + current_failed_codes = next_failed_codes - def close(self): - """Closes MooncakeStore.""" - if self._store: - self._store.close() - self._store = None + if attempt < MAX_RETRIES: + time.sleep(RETRY_DELAY_SECONDS) + + raise RuntimeError( + f"batch_get_into failed for keys {current_failed_keys} with error codes " + f"{current_failed_codes} after retrying {MAX_RETRIES} times." 
+ ) @staticmethod def _preprocess_tensors_for_put(values: list[Tensor]) -> tuple[list[int], list[int], list[Tensor]]: diff --git a/transfer_queue/utils/serial_utils.py b/transfer_queue/utils/serial_utils.py index 853b3a3e..124bbb23 100644 --- a/transfer_queue/utils/serial_utils.py +++ b/transfer_queue/utils/serial_utils.py @@ -18,8 +18,9 @@ import pickle +import struct import warnings -from collections.abc import Sequence +from collections.abc import Callable, Sequence from contextvars import ContextVar from typing import Any, TypeAlias @@ -387,3 +388,110 @@ def decode(frames: list) -> Any: if len(frames) >= 2 and frames[0] == _PICKLE_FALLBACK_SENTINEL: return pickle.loads(frames[1]) return _decoder.decode(frames) + + +# Packed buffer layout: +# [item_count: uint32 LE] +# [N × (payload_offset: uint32 LE, payload_size: uint32 LE)] +# [payload_0 ... payload_{N-1}] +_PACK_HEADER_FMT = "<I" +_PACK_ENTRY_FMT = "<II" +_PACK_HEADER_SIZE = struct.calcsize(_PACK_HEADER_FMT) +_PACK_ENTRY_SIZE = struct.calcsize(_PACK_ENTRY_FMT) + + +def calc_packed_size(items: Sequence[bytestr]) -> int: + """Total bytes required to pack ``items`` into one buffer.""" + return _PACK_HEADER_SIZE + len(items) * _PACK_ENTRY_SIZE + sum(memoryview(item).nbytes for item in items) + + +def pack_into(target: bytestr, items: Sequence[bytestr]) -> None: + """Concatenate ``items`` into ``target``, which must be at least ``calc_packed_size(items)`` bytes.""" + target_mv = memoryview(target) + struct.pack_into(_PACK_HEADER_FMT, target_mv, 0, len(items)) + + entry_offset = _PACK_HEADER_SIZE + payload_offset = _PACK_HEADER_SIZE + len(items) * _PACK_ENTRY_SIZE + + target_tensor = torch.frombuffer(target_mv, dtype=torch.uint8) + + for item in items: + item_mv = memoryview(item) + nbytes = item_mv.nbytes + struct.pack_into(_PACK_ENTRY_FMT, target_mv, entry_offset, payload_offset, nbytes) + src_tensor = torch.frombuffer(item_mv, dtype=torch.uint8) + target_tensor[payload_offset : payload_offset + nbytes].copy_(src_tensor) + entry_offset += _PACK_ENTRY_SIZE + payload_offset += nbytes + + +def unpack_from(source: bytestr) -> list[memoryview]: + """Split a packed buffer back into N memoryview slices 
over ``source``.""" + mv = memoryview(source) + item_count = struct.unpack_from(_PACK_HEADER_FMT, mv, 0)[0] + result: list[memoryview] = [] + for i in range(item_count): + offset, length = struct.unpack_from(_PACK_ENTRY_FMT, mv, _PACK_HEADER_SIZE + i * _PACK_ENTRY_SIZE) + result.append(mv[offset : offset + length]) + return result + + +def batch_encode_into( + values: list[Any], + alloc: Callable[[int], tuple[Any, int]], +) -> tuple[Any, list[int], list[int]]: + """Encode all ``values`` into a single caller-allocated buffer. + + Encoding is done once per value (msgpack zero-copy + transparent pickle + fallback); each encoded frame is then ``pack_into`` its slice of the + buffer returned by ``alloc``. + + Args: + values: Values to encode. + alloc: Backend-supplied factory ``(total_size) -> (buf, base_ptr)``. + ``buf`` must keep the underlying storage alive; ``base_ptr`` is + the address of the first byte. Register/unregister for RDMA-style + backends remain the caller's responsibility. + + Returns: + ``(buf, ptrs, sizes)``: + * ``buf`` is whatever ``alloc`` returned; the caller must keep this + alive while ``ptrs`` are in flight (and for register/unregister). + * ``ptrs[i]`` / ``sizes[i]`` are the absolute pointer and byte length + of value ``i``'s packed slice within ``buf``. + """ + batch_items = [encode(v) for v in values] + batch_sizes = [calc_packed_size(items) for items in batch_items] + buf, base_ptr = alloc(sum(batch_sizes)) + # torch tensors don't directly implement the buffer protocol; reach in + # through numpy. Anything else (bytearray, ndarray) supports memoryview(). 
+ buf_mv = buf.numpy().data if hasattr(buf, "numpy") else memoryview(buf) + + batch_ptrs: list[int] = [] + offset = 0 + for items, size in zip(batch_items, batch_sizes, strict=True): + pack_into(buf_mv[offset : offset + size], items) + batch_ptrs.append(base_ptr + offset) + offset += size + + return buf, batch_ptrs, batch_sizes + + +def batch_decode_from(buffers: Sequence[Any]) -> list[Any]: + """Reverse of ``batch_encode_into``: for each filled buffer, unpack + decode. + + Tensors / ndarrays in the result are zero-copy views over ``buf``. The + natural Python ref chain (``torch.frombuffer`` → ``Py_buffer`` → memoryview + slice → parent memoryview → numpy array → original ``buf``) keeps the + source alive as long as the returned object is reachable. Caller does NOT + need to retain ``buf`` separately. + + Args: + buffers: Per-value receive buffers, in order. Each must support the + buffer protocol (``torch.Tensor`` via ``.numpy().data``; + ``bytearray`` / ``ndarray`` directly). + """ + return [ + decode(unpack_from(buf.numpy().data if hasattr(buf, "numpy") else memoryview(buf))) + for buf in buffers + ] From f2f8dd2ec720a887ef90e3acd99c783decaa0370 Mon Sep 17 00:00:00 2001 From: xupinjie Date: Mon, 25 May 2026 08:53:25 -0700 Subject: [PATCH 2/5] add serial utils batch encoder decoder test --- tests/test_serial_utils_batch_on_cpu.py | 250 ++++++++++++++++++++++++ 1 file changed, 250 insertions(+) create mode 100644 tests/test_serial_utils_batch_on_cpu.py diff --git a/tests/test_serial_utils_batch_on_cpu.py b/tests/test_serial_utils_batch_on_cpu.py new file mode 100644 index 00000000..d5a60af1 --- /dev/null +++ b/tests/test_serial_utils_batch_on_cpu.py @@ -0,0 +1,250 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the packed-buffer batch serialization helpers in +``transfer_queue.utils.serial_utils``: + +* ``calc_packed_size`` +* ``pack_into`` / ``unpack_from`` +* ``batch_encode_into`` +* ``batch_decode_from`` +""" + +import ctypes + +import numpy as np +import pytest +import torch + +from transfer_queue.utils import serial_utils + + +# ============================================================================ +# low-level: calc_packed_size + pack_into + unpack_from (raw bytes layer) +# ============================================================================ + + +def test_calc_packed_size_then_pack_unpack_roundtrip(): + items = [b"hello", b"world!", b"x"] + size = serial_utils.calc_packed_size(items) + buf = bytearray(size) + serial_utils.pack_into(buf, items) + recovered = serial_utils.unpack_from(buf) + assert [bytes(mv) for mv in recovered] == items + + +def test_pack_unpack_roundtrip_random_bytes(): + rng = np.random.default_rng(0) + items = [rng.bytes(int(rng.integers(1, 2048))) for _ in range(16)] + size = serial_utils.calc_packed_size(items) + buf = bytearray(size) + serial_utils.pack_into(buf, items) + + recovered = serial_utils.unpack_from(buf) + assert len(recovered) == len(items) + for r, item in zip(recovered, items, strict=True): + assert bytes(r) == item + + +def test_pack_into_writes_only_within_its_slice(): + """``pack_into`` is called on a slice of a larger buf in ``batch_encode_into``; + it must not write outside the slice.""" + items = [b"alpha", b"beta", b"gamma"] + sz = serial_utils.calc_packed_size(items) + pad_before, 
pad_after = 17, 23 + big = bytearray(pad_before + sz + pad_after) + serial_utils.pack_into(memoryview(big)[pad_before : pad_before + sz], items) + + assert all(b == 0 for b in big[:pad_before]) + assert all(b == 0 for b in big[pad_before + sz :]) + + recovered = serial_utils.unpack_from(memoryview(big)[pad_before : pad_before + sz]) + assert [bytes(mv) for mv in recovered] == items + + +def test_unpack_from_zero_item_buffer(): + items: list[bytes] = [] + sz = serial_utils.calc_packed_size(items) + buf = bytearray(sz) + serial_utils.pack_into(buf, items) + assert serial_utils.unpack_from(buf) == [] + + +# ============================================================================ +# batch_encode_into + batch_decode_from (high-level batch layer) +# ============================================================================ + + +def _torch_alloc(total: int) -> tuple[torch.Tensor, int]: + """Mooncake-style alloc callback used by tests.""" + buf = torch.empty(total, dtype=torch.uint8) + return buf, buf.data_ptr() + + +def _bytearray_alloc(total: int) -> tuple[bytearray, int]: + """Non-torch alloc callback (exercises ``not hasattr(buf, 'numpy')`` branch).""" + buf = bytearray(total) + base_ptr = ctypes.addressof(ctypes.c_char.from_buffer(buf)) if total else 0 + return buf, base_ptr + + +def _roundtrip(values, alloc): + """Encode ``values`` via ``batch_encode_into``, then slice the big buf back into + per-value buffers and decode via ``batch_decode_from``. 
Returns decoded list.""" + buf, ptrs, sizes = serial_utils.batch_encode_into(values, alloc) + if isinstance(buf, torch.Tensor): + base = buf.data_ptr() + per_value = [buf[p - base : p - base + s] for p, s in zip(ptrs, sizes, strict=True)] + else: + base = ctypes.addressof(ctypes.c_char.from_buffer(buf)) if len(buf) else 0 + per_value = [bytes(buf[p - base : p - base + s]) for p, s in zip(ptrs, sizes, strict=True)] + return serial_utils.batch_decode_from(per_value), buf, ptrs, sizes + + +# ---- structural: return shapes / alloc contract / ptrs layout ---- + + +def test_batch_encode_into_return_shapes(): + values = [{"x": 1}, "a string", torch.arange(8, dtype=torch.float32)] + buf, ptrs, sizes = serial_utils.batch_encode_into(values, _torch_alloc) + + assert isinstance(buf, torch.Tensor) + assert len(ptrs) == len(values) + assert len(sizes) == len(values) + assert buf.nbytes == sum(sizes) + + +def test_batch_encode_into_alloc_called_with_total_size(): + captured = {} + + def spy_alloc(total: int): + captured["total"] = total + return _torch_alloc(total) + + values = [b"x" * 100, b"y" * 50, {"k": "v"}] + buf, _, sizes = serial_utils.batch_encode_into(values, spy_alloc) + assert captured["total"] == sum(sizes) == buf.nbytes + + +def test_batch_encode_into_ptrs_are_tightly_packed_within_buf(): + values = [torch.arange(i, dtype=torch.float32) for i in (4, 16, 64)] + buf, ptrs, sizes = serial_utils.batch_encode_into(values, _torch_alloc) + + base = buf.data_ptr() + expected_offset = 0 + for p, s in zip(ptrs, sizes, strict=True): + assert p == base + expected_offset + assert p + s <= base + buf.nbytes + expected_offset += s + assert expected_offset == buf.nbytes # no padding between slices + + +# ---- semantic: encode → decode roundtrip preserves values ---- + + +@pytest.mark.parametrize( + "values", + [ + pytest.param([42, 3.14, "hello", b"bytes"], id="primitives"), + pytest.param([{"a": 1, "b": [1, 2, 3]}, {"nested": {"k": "v"}}], id="nested-dicts"), + 
pytest.param([torch.arange(10, dtype=torch.float32)], id="single-tensor"), + pytest.param( + [ + torch.arange(100, dtype=torch.float32), + torch.randn(4, 4, dtype=torch.bfloat16), + torch.zeros(3, 5, dtype=torch.int64), + ], + id="mixed-tensors", + ), + pytest.param( + [np.arange(50, dtype=np.float64), np.ones((3, 3), dtype=np.int32)], + id="numpy-arrays", + ), + pytest.param( + [{"meta": "v1", "arr": torch.arange(5, dtype=torch.float32)}, [1, 2, "three"]], + id="heterogeneous", + ), + ], +) +def test_batch_encode_decode_roundtrip(values): + decoded, *_ = _roundtrip(values, _torch_alloc) + _assert_equal_payloads(decoded, values) + + +def test_batch_encode_decode_single_value(): + values = [{"only": "one", "tensor": torch.arange(3, dtype=torch.float32)}] + decoded, *_ = _roundtrip(values, _torch_alloc) + _assert_equal_payloads(decoded, values) + + +def test_batch_encode_decode_empty_list(): + """Empty input: alloc invoked with 0, no ptrs, no sizes, decode of [] returns [].""" + calls = [] + + def alloc(total: int): + calls.append(total) + return _torch_alloc(total) + + buf, ptrs, sizes = serial_utils.batch_encode_into([], alloc) + assert ptrs == [] and sizes == [] + assert calls == [0] + assert buf.nbytes == 0 + assert serial_utils.batch_decode_from([]) == [] + + +def test_batch_encode_decode_with_bytearray_alloc(): + """Exercise the non-torch buffer branch end-to-end.""" + values = [b"hello", b"world", {"k": 1}] + decoded, buf, _, sizes = _roundtrip(values, _bytearray_alloc) + assert isinstance(buf, bytearray) + assert len(buf) == sum(sizes) + _assert_equal_payloads(decoded, values) + + +def test_batch_decode_from_accepts_torch_tensor_slices(): + """End-to-end the way the Mooncake bytes get path feeds it: uint8 tensor buffers.""" + values = [torch.arange(20, dtype=torch.float32), {"k": [1, 2, 3]}, "trailing-string"] + decoded, *_ = _roundtrip(values, _torch_alloc) + _assert_equal_payloads(decoded, values) + + +# 
============================================================================ +# helpers +# ============================================================================ + + +def _assert_equal_payloads(decoded, original): + assert len(decoded) == len(original) + for got, want in zip(decoded, original, strict=True): + if isinstance(want, torch.Tensor): + assert isinstance(got, torch.Tensor) + assert got.dtype == want.dtype + assert got.shape == want.shape + assert torch.equal(got, want) + elif isinstance(want, np.ndarray): + assert isinstance(got, np.ndarray) + assert got.dtype == want.dtype + assert got.shape == want.shape + assert np.array_equal(got, want) + elif isinstance(want, dict): + assert isinstance(got, dict) + assert got.keys() == want.keys() + for k in want: + _assert_equal_payloads([got[k]], [want[k]]) + elif isinstance(want, list): + assert isinstance(got, list) + _assert_equal_payloads(got, want) + else: + assert got == want From 23d2efaf2bba5a681899e3aa8a4f7b12972c4a66 Mon Sep 17 00:00:00 2001 From: xupinjie Date: Mon, 25 May 2026 23:35:21 -0700 Subject: [PATCH 3/5] Code Style Improvement --- .../storage/clients/mooncake_client.py | 28 +++--------- transfer_queue/utils/serial_utils.py | 45 ++++++++++--------- 2 files changed, 32 insertions(+), 41 deletions(-) diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index a0c8412f..e4ee0f0c 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -150,13 +150,13 @@ def put(self, keys: list[str], values: list[Any]) -> list[dict | None]: # bytes results arrive in non-tensor submit order, which matches the order of # non-tensor values; walk values once to scatter packed_size back to its key slot. 
- custom_meta: list[dict | None] = [None] * len(keys) + custom_backend_meta: list[dict | None] = [None] * len(keys) sizes_iter = iter(packed_sizes) for i, value in enumerate(values): if not isinstance(value, torch.Tensor): - custom_meta[i] = {"packed_size": next(sizes_iter)} + custom_backend_meta[i] = {"packed_size": next(sizes_iter)} - return custom_meta + return custom_backend_meta def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[Tensor]) -> None: """Worker thread for putting batch of tensors to MooncakeStore.""" @@ -172,17 +172,10 @@ def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[ def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any]) -> list[int]: """Worker thread for putting batch of non-tensors to MooncakeStore.""" - # [TQ-REFACTOR-VERIFY] one-shot marker — REMOVE after verifying new bytes-put data path is hit. - if not type(self).__dict__.get("_put_bytes_verified"): - type(self)._put_bytes_verified = True - - # Encode + pack happens in-thread (so CPU serialization parallelizes); - # alloc is our hook to get a torch buffer of the right total size. - # TODO: switch to a pre-registered buffer from MooncakeStore once such - # an API is available, so we can skip the explicit register/unregister. - def alloc(total: int) -> tuple[Tensor, int]: - tensors, ptrs, _, _ = allocate_empty_tensors([torch.uint8], [(total,)]) - return tensors[0], ptrs[0] + # TODO: switch to a pre-registered buffer from MooncakeStore once such an API is available. 
+ def alloc(buffer_size: int) -> tuple[Tensor, int]: + buffer = torch.empty(buffer_size, dtype=torch.uint8) + return buffer, buffer.data_ptr() big_buf, batch_ptrs, batch_sizes = serial_utils.batch_encode_into(batch_values, alloc) @@ -284,13 +277,6 @@ def _get_tensors_thread_worker( def _get_bytes_thread_worker( self, batch_keys: list[str], batch_packed_sizes: list[int], indexes: list[int] ) -> tuple[list[Any], list[int]]: - # [TQ-REFACTOR-VERIFY] one-shot marker — REMOVE after verifying new bytes-get data path is hit. - if not type(self).__dict__.get("_get_bytes_verified"): - type(self)._get_bytes_verified = True - - # Allocate uint8 receive buffers of packed_size; unpack + decode in-thread. - # Decoded tensors are zero-copy views; their storage transitively keeps - # batch_buffer_tensors alive after this function returns. batch_shapes = [(sz,) for sz in batch_packed_sizes] batch_dtypes = [torch.uint8] * len(batch_keys) batch_nbytes = get_nbytes(batch_dtypes, batch_shapes) diff --git a/transfer_queue/utils/serial_utils.py b/transfer_queue/utils/serial_utils.py index 124bbb23..0095e962 100644 --- a/transfer_queue/utils/serial_utils.py +++ b/transfer_queue/utils/serial_utils.py @@ -405,9 +405,14 @@ def calc_packed_size(items: Sequence[bytestr]) -> int: return _PACK_HEADER_SIZE + len(items) * _PACK_ENTRY_SIZE + sum(memoryview(item).nbytes for item in items) -def pack_into(target: bytestr, items: Sequence[bytestr]) -> None: - """Concatenate ``items`` into ``target``, which must be at least ``calc_packed_size(items)`` bytes.""" - target_mv = memoryview(target) +def pack_into(target_buffer: bytestr, items: Sequence[bytestr]) -> None: + """Concatenate ``items`` into ``target_buffer``, which must be at least ``calc_packed_size(items)`` bytes.""" + target_mv = memoryview(target_buffer) + required = calc_packed_size(items) + if target_mv.nbytes < required: + raise ValueError( + f"pack_into: target buffer has {target_mv.nbytes} bytes, requires {required}" + ) 
struct.pack_into(_PACK_HEADER_FMT, target_mv, 0, len(items)) entry_offset = _PACK_HEADER_SIZE @@ -425,9 +430,9 @@ def pack_into(target: bytestr, items: Sequence[bytestr]) -> None: payload_offset += nbytes -def unpack_from(source: bytestr) -> list[memoryview]: - """Split a packed buffer back into N memoryview slices over ``source``.""" - mv = memoryview(source) +def unpack_from(source_buffer: bytestr) -> list[memoryview]: + """Split a packed buffer back into N memoryview slices over ``source_buffer``.""" + mv = memoryview(source_buffer) item_count = struct.unpack_from(_PACK_HEADER_FMT, mv, 0)[0] result: list[memoryview] = [] for i in range(item_count): @@ -437,32 +442,32 @@ def unpack_from(source: bytestr) -> list[memoryview]: def batch_encode_into( - values: list[Any], - alloc: Callable[[int], tuple[Any, int]], + objs: list[Any], + alloc_buff_func: Callable[[int], tuple[Any, int]], ) -> tuple[Any, list[int], list[int]]: - """Encode all ``values`` into a single caller-allocated buffer. + """Encode all ``objs`` into a single caller-allocated buffer. - Encoding is done once per value (msgpack zero-copy + transparent pickle + Encoding is done once per object (msgpack zero-copy + transparent pickle fallback); each encoded frame is then ``pack_into`` its slice of the - buffer returned by ``alloc``. + buffer returned by ``alloc_buff_func``. Args: - values: Values to encode. - alloc: Backend-supplied factory ``(total_size) -> (buf, base_ptr)``. + objs: Objects to encode. + alloc_buff_func: Backend-supplied factory ``(total_size) -> (buf, base_ptr)``. ``buf`` must keep the underlying storage alive; ``base_ptr`` is the address of the first byte. Register/unregister for RDMA-style backends remain the caller's responsibility. Returns: ``(buf, ptrs, sizes)``: - * ``buf`` is whatever ``alloc`` returned; the caller must keep this + * ``buf`` is whatever ``alloc_buff_func`` returned; the caller must keep this alive while ``ptrs`` are in flight (and for register/unregister). 
* ``ptrs[i]`` / ``sizes[i]`` are the absolute pointer and byte length - of value ``i``'s packed slice within ``buf``. + of object ``i``'s packed slice within ``buf``. """ - batch_items = [encode(v) for v in values] + batch_items = [encode(obj) for obj in objs] batch_sizes = [calc_packed_size(items) for items in batch_items] - buf, base_ptr = alloc(sum(batch_sizes)) + buf, base_ptr = alloc_buff_func(sum(batch_sizes)) # torch tensors don't directly implement the buffer protocol; reach in # through numpy. Anything else (bytearray, ndarray) supports memoryview(). buf_mv = buf.numpy().data if hasattr(buf, "numpy") else memoryview(buf) @@ -477,7 +482,7 @@ def batch_encode_into( return buf, batch_ptrs, batch_sizes -def batch_decode_from(buffers: Sequence[Any]) -> list[Any]: +def batch_decode_from(source_buffers: Sequence[Any]) -> list[Any]: """Reverse of ``batch_encode_into``: for each filled buffer, unpack + decode. Tensors / ndarrays in the result are zero-copy views over ``buf``. The @@ -487,11 +492,11 @@ def batch_decode_from(buffers: Sequence[Any]) -> list[Any]: need to retain ``buf`` separately. Args: - buffers: Per-value receive buffers, in order. Each must support the + source_buffers: Per-object receive buffers, in order. Each must support the buffer protocol (``torch.Tensor`` via ``.numpy().data``; ``bytearray`` / ``ndarray`` directly). 
""" return [ decode(unpack_from(buf.numpy().data if hasattr(buf, "numpy") else memoryview(buf))) - for buf in buffers + for buf in source_buffers ] From dfb9cebed3498c2181df1403ee46b8cfa0ae58d2 Mon Sep 17 00:00:00 2001 From: xupinjie Date: Wed, 27 May 2026 07:30:12 -0700 Subject: [PATCH 4/5] update batch_encode_into --- tests/test_serial_utils_batch_on_cpu.py | 201 ++++++++++-------- .../storage/clients/mooncake_client.py | 23 +- transfer_queue/utils/serial_utils.py | 63 +++--- 3 files changed, 161 insertions(+), 126 deletions(-) diff --git a/tests/test_serial_utils_batch_on_cpu.py b/tests/test_serial_utils_batch_on_cpu.py index d5a60af1..9627fa0a 100644 --- a/tests/test_serial_utils_batch_on_cpu.py +++ b/tests/test_serial_utils_batch_on_cpu.py @@ -22,8 +22,6 @@ * ``batch_decode_from`` """ -import ctypes - import numpy as np import pytest import torch @@ -59,8 +57,6 @@ def test_pack_unpack_roundtrip_random_bytes(): def test_pack_into_writes_only_within_its_slice(): - """``pack_into`` is called on a slice of a larger buf in ``batch_encode_into``; - it must not write outside the slice.""" items = [b"alpha", b"beta", b"gamma"] sz = serial_utils.calc_packed_size(items) pad_before, pad_after = 17, 23 @@ -87,136 +83,159 @@ def test_unpack_from_zero_item_buffer(): # ============================================================================ -def _torch_alloc(total: int) -> tuple[torch.Tensor, int]: - """Mooncake-style alloc callback used by tests.""" - buf = torch.empty(total, dtype=torch.uint8) - return buf, buf.data_ptr() +def _mooncake_alloc(sizes: list[int]) -> list[torch.Tensor]: + """Single big torch.uint8 tensor sliced into N views (mooncake-style).""" + big = torch.empty(sum(sizes), dtype=torch.uint8) + buffers: list[torch.Tensor] = [] + offset = 0 + for s in sizes: + buffers.append(big[offset : offset + s]) + offset += s + return buffers + +def _yuanrong_alloc(sizes: list[int]) -> list[bytearray]: + """N independent bytearrays (yuanrong-style per-key buffer).""" 
+ return [bytearray(s) for s in sizes] -def _bytearray_alloc(total: int) -> tuple[bytearray, int]: - """Non-torch alloc callback (exercises ``not hasattr(buf, 'numpy')`` branch).""" - buf = bytearray(total) - base_ptr = ctypes.addressof(ctypes.c_char.from_buffer(buf)) if total else 0 - return buf, base_ptr +def _decode_from_returned(buffers, alloc_kind): + if alloc_kind == "mooncake": + return serial_utils.batch_decode_from(buffers) + return serial_utils.batch_decode_from([bytes(b) for b in buffers]) -def _roundtrip(values, alloc): - """Encode ``values`` via ``batch_encode_into``, then slice the big buf back into - per-value buffers and decode via ``batch_decode_from``. Returns decoded list.""" - buf, ptrs, sizes = serial_utils.batch_encode_into(values, alloc) - if isinstance(buf, torch.Tensor): - base = buf.data_ptr() - per_value = [buf[p - base : p - base + s] for p, s in zip(ptrs, sizes, strict=True)] - else: - base = ctypes.addressof(ctypes.c_char.from_buffer(buf)) if len(buf) else 0 - per_value = [bytes(buf[p - base : p - base + s]) for p, s in zip(ptrs, sizes, strict=True)] - return serial_utils.batch_decode_from(per_value), buf, ptrs, sizes +def _roundtrip(values, alloc, alloc_kind, *, num_workers: int = 1): + buffers, sizes = serial_utils.batch_encode_into(values, alloc, num_workers=num_workers) + decoded = _decode_from_returned(buffers, alloc_kind) + return decoded, buffers, sizes -# ---- structural: return shapes / alloc contract / ptrs layout ---- + +# ---- structural: return shapes / alloc contract ---- def test_batch_encode_into_return_shapes(): values = [{"x": 1}, "a string", torch.arange(8, dtype=torch.float32)] - buf, ptrs, sizes = serial_utils.batch_encode_into(values, _torch_alloc) + buffers, sizes = serial_utils.batch_encode_into(values, _mooncake_alloc) - assert isinstance(buf, torch.Tensor) - assert len(ptrs) == len(values) + assert len(buffers) == len(values) assert len(sizes) == len(values) - assert buf.nbytes == sum(sizes) + for b, s in 
zip(buffers, sizes, strict=True): + assert b.nbytes == s -def test_batch_encode_into_alloc_called_with_total_size(): +def test_batch_encode_into_alloc_called_with_per_obj_sizes(): captured = {} - def spy_alloc(total: int): - captured["total"] = total - return _torch_alloc(total) + def spy_alloc(sizes: list[int]): + captured["sizes"] = list(sizes) + return _mooncake_alloc(sizes) values = [b"x" * 100, b"y" * 50, {"k": "v"}] - buf, _, sizes = serial_utils.batch_encode_into(values, spy_alloc) - assert captured["total"] == sum(sizes) == buf.nbytes + _, sizes = serial_utils.batch_encode_into(values, spy_alloc) + assert captured["sizes"] == sizes + +def test_batch_encode_into_allows_padded_buffers(): + """Alloc may return buffers larger than requested sizes; batch_sizes still + reports the actual packed length, and the data round-trips correctly.""" + pad = 32 -def test_batch_encode_into_ptrs_are_tightly_packed_within_buf(): - values = [torch.arange(i, dtype=torch.float32) for i in (4, 16, 64)] - buf, ptrs, sizes = serial_utils.batch_encode_into(values, _torch_alloc) + def padded_alloc(sizes): + return [bytearray(s + pad) for s in sizes] - base = buf.data_ptr() - expected_offset = 0 - for p, s in zip(ptrs, sizes, strict=True): - assert p == base + expected_offset - assert p + s <= base + buf.nbytes - expected_offset += s - assert expected_offset == buf.nbytes # no padding between slices + values = [b"alpha", {"k": "v"}, torch.arange(4, dtype=torch.float32)] + buffers, sizes = serial_utils.batch_encode_into(values, padded_alloc) + + for b, s in zip(buffers, sizes, strict=True): + assert len(b) == s + pad + + # decoding uses only the first `s` bytes; the pad tail is harmless + decoded = serial_utils.batch_decode_from([bytes(b[:s]) for b, s in zip(buffers, sizes, strict=True)]) + _assert_equal_payloads(decoded, values) # ---- semantic: encode → decode roundtrip preserves values ---- -@pytest.mark.parametrize( - "values", - [ - pytest.param([42, 3.14, "hello", b"bytes"], 
id="primitives"), - pytest.param([{"a": 1, "b": [1, 2, 3]}, {"nested": {"k": "v"}}], id="nested-dicts"), - pytest.param([torch.arange(10, dtype=torch.float32)], id="single-tensor"), - pytest.param( - [ - torch.arange(100, dtype=torch.float32), - torch.randn(4, 4, dtype=torch.bfloat16), - torch.zeros(3, 5, dtype=torch.int64), - ], - id="mixed-tensors", - ), - pytest.param( - [np.arange(50, dtype=np.float64), np.ones((3, 3), dtype=np.int32)], - id="numpy-arrays", - ), - pytest.param( - [{"meta": "v1", "arr": torch.arange(5, dtype=torch.float32)}, [1, 2, "three"]], - id="heterogeneous", - ), - ], -) -def test_batch_encode_decode_roundtrip(values): - decoded, *_ = _roundtrip(values, _torch_alloc) +_ROUNDTRIP_PARAMS = [ + pytest.param([42, 3.14, "hello", b"bytes"], id="primitives"), + pytest.param([{"a": 1, "b": [1, 2, 3]}, {"nested": {"k": "v"}}], id="nested-dicts"), + pytest.param([torch.arange(10, dtype=torch.float32)], id="single-tensor"), + pytest.param( + [ + torch.arange(100, dtype=torch.float32), + torch.randn(4, 4, dtype=torch.bfloat16), + torch.zeros(3, 5, dtype=torch.int64), + ], + id="mixed-tensors", + ), + pytest.param( + [np.arange(50, dtype=np.float64), np.ones((3, 3), dtype=np.int32)], + id="numpy-arrays", + ), + pytest.param( + [{"meta": "v1", "arr": torch.arange(5, dtype=torch.float32)}, [1, 2, "three"]], + id="heterogeneous", + ), +] + + +@pytest.mark.parametrize("values", _ROUNDTRIP_PARAMS) +def test_batch_encode_decode_roundtrip_mooncake(values): + decoded, *_ = _roundtrip(values, _mooncake_alloc, "mooncake") + _assert_equal_payloads(decoded, values) + + +@pytest.mark.parametrize("values", _ROUNDTRIP_PARAMS) +def test_batch_encode_decode_roundtrip_yuanrong(values): + decoded, *_ = _roundtrip(values, _yuanrong_alloc, "yuanrong") _assert_equal_payloads(decoded, values) def test_batch_encode_decode_single_value(): values = [{"only": "one", "tensor": torch.arange(3, dtype=torch.float32)}] - decoded, *_ = _roundtrip(values, _torch_alloc) + decoded, *_ = 
_roundtrip(values, _mooncake_alloc, "mooncake") _assert_equal_payloads(decoded, values) def test_batch_encode_decode_empty_list(): - """Empty input: alloc invoked with 0, no ptrs, no sizes, decode of [] returns [].""" calls = [] - def alloc(total: int): - calls.append(total) - return _torch_alloc(total) + def alloc(sizes): + calls.append(list(sizes)) + return [] - buf, ptrs, sizes = serial_utils.batch_encode_into([], alloc) - assert ptrs == [] and sizes == [] - assert calls == [0] - assert buf.nbytes == 0 + buffers, sizes = serial_utils.batch_encode_into([], alloc) + assert buffers == [] and sizes == [] + assert calls == [[]] assert serial_utils.batch_decode_from([]) == [] -def test_batch_encode_decode_with_bytearray_alloc(): - """Exercise the non-torch buffer branch end-to-end.""" - values = [b"hello", b"world", {"k": 1}] - decoded, buf, _, sizes = _roundtrip(values, _bytearray_alloc) - assert isinstance(buf, bytearray) - assert len(buf) == sum(sizes) - _assert_equal_payloads(decoded, values) +# ---- num_workers: parallel pack must produce identical bytes vs serial ---- + + +@pytest.mark.parametrize("values", _ROUNDTRIP_PARAMS) +def test_batch_encode_into_parallel_matches_serial(values): + serial_buffers, serial_sizes = serial_utils.batch_encode_into( + values, _yuanrong_alloc, num_workers=1 + ) + par_buffers, par_sizes = serial_utils.batch_encode_into( + values, _yuanrong_alloc, num_workers=4 + ) + + assert serial_sizes == par_sizes + assert [bytes(b) for b in serial_buffers] == [bytes(b) for b in par_buffers] + +def test_batch_encode_into_parallel_roundtrip_many_objects(): + rng = np.random.default_rng(42) + values = [] + for _ in range(64): + n = int(rng.integers(1, 257)) + values.append(torch.from_numpy(rng.random(n).astype(np.float32))) -def test_batch_decode_from_accepts_torch_tensor_slices(): - """End-to-end the way the Mooncake bytes get path feeds it: uint8 tensor buffers.""" - values = [torch.arange(20, dtype=torch.float32), {"k": [1, 2, 3]}, 
"trailing-string"] - decoded, *_ = _roundtrip(values, _torch_alloc) + decoded, *_ = _roundtrip(values, _yuanrong_alloc, "yuanrong", num_workers=8) _assert_equal_payloads(decoded, values) diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index e4ee0f0c..08a22426 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -173,17 +173,26 @@ def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any """Worker thread for putting batch of non-tensors to MooncakeStore.""" # TODO: switch to a pre-registered buffer from MooncakeStore once such an API is available. - def alloc(buffer_size: int) -> tuple[Tensor, int]: - buffer = torch.empty(buffer_size, dtype=torch.uint8) - return buffer, buffer.data_ptr() + region_ptrs: list[int] = [] + region_sizes: list[int] = [] - big_buf, batch_ptrs, batch_sizes = serial_utils.batch_encode_into(batch_values, alloc) + def alloc(sizes: list[int]) -> list[Tensor]: + nonlocal region_ptrs, region_sizes + # Borrow allocate_empty_tensors for its "N views sharing one register-able + # region" layout; uint8 1D shapes are byte scratch, not real tensors. 
+ dtypes = [torch.uint8] * len(sizes) + shapes = [(s,) for s in sizes] + buffers, _, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes) + return buffers - self._store.register_buffer(big_buf.data_ptr(), big_buf.nbytes) + buffers, batch_sizes = serial_utils.batch_encode_into(batch_values, alloc) + batch_ptrs = [b.data_ptr() for b in buffers] + + self._register_all_buffers(region_ptrs, region_sizes) try: self._batch_upsert_with_retry(batch_keys, batch_ptrs, batch_sizes) finally: - self._store.unregister_buffer(big_buf.data_ptr()) + self._unregister_all_buffers(region_ptrs) return batch_sizes @@ -277,6 +286,8 @@ def _get_tensors_thread_worker( def _get_bytes_thread_worker( self, batch_keys: list[str], batch_packed_sizes: list[int], indexes: list[int] ) -> tuple[list[Any], list[int]]: + # Borrow allocate_empty_tensors for its "N views sharing one register-able + # region" layout; uint8 1D shapes are byte scratch, not real tensors. batch_shapes = [(sz,) for sz in batch_packed_sizes] batch_dtypes = [torch.uint8] * len(batch_keys) batch_nbytes = get_nbytes(batch_dtypes, batch_shapes) diff --git a/transfer_queue/utils/serial_utils.py b/transfer_queue/utils/serial_utils.py index 0095e962..0f7a6c58 100644 --- a/transfer_queue/utils/serial_utils.py +++ b/transfer_queue/utils/serial_utils.py @@ -21,6 +21,7 @@ import struct import warnings from collections.abc import Callable, Sequence +from concurrent.futures import ThreadPoolExecutor from contextvars import ContextVar from typing import Any, TypeAlias @@ -443,43 +444,47 @@ def unpack_from(source_buffer: bytestr) -> list[memoryview]: def batch_encode_into( objs: list[Any], - alloc_buff_func: Callable[[int], tuple[Any, int]], -) -> tuple[Any, list[int], list[int]]: - """Encode all ``objs`` into a single caller-allocated buffer. 
- - Encoding is done once per object (msgpack zero-copy + transparent pickle - fallback); each encoded frame is then ``pack_into`` its slice of the - buffer returned by ``alloc_buff_func``. + alloc_buff_func: Callable[[list[int]], list[Any]], + *, + num_workers: int = 1, +) -> tuple[list[Any], list[int]]: + """Encode all ``objs`` into caller-allocated per-object buffers. Args: objs: Objects to encode. - alloc_buff_func: Backend-supplied factory ``(total_size) -> (buf, base_ptr)``. - ``buf`` must keep the underlying storage alive; ``base_ptr`` is - the address of the first byte. Register/unregister for RDMA-style - backends remain the caller's responsibility. + alloc_buff_func: ``(per_obj_sizes) -> buffers``. ``buffers[i]`` must + hold at least ``sizes[i]`` bytes and be consumable by + ``memoryview()`` (or ``.numpy().data`` for ``torch.Tensor``). + num_workers: ``pack_into`` parallelism. Default 1 (serial); set ``>1`` + when the upper-layer is single-threaded. Returns: - ``(buf, ptrs, sizes)``: - * ``buf`` is whatever ``alloc_buff_func`` returned; the caller must keep this - alive while ``ptrs`` are in flight (and for register/unregister). - * ``ptrs[i]`` / ``sizes[i]`` are the absolute pointer and byte length - of object ``i``'s packed slice within ``buf``. + ``(buffers, batch_sizes)``: ``buffers`` is what ``alloc_buff_func`` + returned with packed bytes written; ``batch_sizes[i]`` is object + ``i``'s packed byte length (≤ ``buffers[i]`` capacity). + + Note: + Lifetime is caller-owned: this function holds no references to + ``buffers`` after return. Whatever the alloc closure used to back + them must remain alive until all downstream consumers finish. """ batch_items = [encode(obj) for obj in objs] batch_sizes = [calc_packed_size(items) for items in batch_items] - buf, base_ptr = alloc_buff_func(sum(batch_sizes)) - # torch tensors don't directly implement the buffer protocol; reach in - # through numpy. Anything else (bytearray, ndarray) supports memoryview(). 
- buf_mv = buf.numpy().data if hasattr(buf, "numpy") else memoryview(buf) - - batch_ptrs: list[int] = [] - offset = 0 - for items, size in zip(batch_items, batch_sizes, strict=True): - pack_into(buf_mv[offset : offset + size], items) - batch_ptrs.append(base_ptr + offset) - offset += size - - return buf, batch_ptrs, batch_sizes + buffers = alloc_buff_func(batch_sizes) + + def _pack_one(pair: tuple[Any, list[bytestr]]) -> None: + buf, items = pair + mv = buf.numpy().data if hasattr(buf, "numpy") else memoryview(buf) + pack_into(mv, items) + + if num_workers <= 1: + for pair in zip(buffers, batch_items, strict=True): + _pack_one(pair) + else: + with ThreadPoolExecutor(max_workers=num_workers) as executor: + list(executor.map(_pack_one, zip(buffers, batch_items, strict=True))) + + return buffers, batch_sizes def batch_decode_from(source_buffers: Sequence[Any]) -> list[Any]: From 6cd82c57970a26b0b590d937c5af6e32a3fab271 Mon Sep 17 00:00:00 2001 From: xupinjie Date: Thu, 28 May 2026 04:03:57 -0700 Subject: [PATCH 5/5] update docstr and test case --- tests/test_serial_utils_batch_on_cpu.py | 77 +++++++++++-------- .../storage/clients/mooncake_client.py | 22 ++++-- transfer_queue/utils/serial_utils.py | 76 ++++++++++++------ 3 files changed, 109 insertions(+), 66 deletions(-) diff --git a/tests/test_serial_utils_batch_on_cpu.py b/tests/test_serial_utils_batch_on_cpu.py index 9627fa0a..e96dda1f 100644 --- a/tests/test_serial_utils_batch_on_cpu.py +++ b/tests/test_serial_utils_batch_on_cpu.py @@ -43,19 +43,6 @@ def test_calc_packed_size_then_pack_unpack_roundtrip(): assert [bytes(mv) for mv in recovered] == items -def test_pack_unpack_roundtrip_random_bytes(): - rng = np.random.default_rng(0) - items = [rng.bytes(int(rng.integers(1, 2048))) for _ in range(16)] - size = serial_utils.calc_packed_size(items) - buf = bytearray(size) - serial_utils.pack_into(buf, items) - - recovered = serial_utils.unpack_from(buf) - assert len(recovered) == len(items) - for r, item in 
zip(recovered, items, strict=True): - assert bytes(r) == item - - def test_pack_into_writes_only_within_its_slice(): items = [b"alpha", b"beta", b"gamma"] sz = serial_utils.calc_packed_size(items) @@ -124,18 +111,6 @@ def test_batch_encode_into_return_shapes(): assert b.nbytes == s -def test_batch_encode_into_alloc_called_with_per_obj_sizes(): - captured = {} - - def spy_alloc(sizes: list[int]): - captured["sizes"] = list(sizes) - return _mooncake_alloc(sizes) - - values = [b"x" * 100, b"y" * 50, {"k": "v"}] - _, sizes = serial_utils.batch_encode_into(values, spy_alloc) - assert captured["sizes"] == sizes - - def test_batch_encode_into_allows_padded_buffers(): """Alloc may return buffers larger than requested sizes; batch_sizes still reports the actual packed length, and the data round-trips correctly.""" @@ -178,6 +153,38 @@ def padded_alloc(sizes): [{"meta": "v1", "arr": torch.arange(5, dtype=torch.float32)}, [1, 2, "three"]], id="heterogeneous", ), + pytest.param( + [ + torch.randn(2, 3, 4, 5, dtype=torch.float32), + torch.randn(2, 3, 4, 5, 6, dtype=torch.bfloat16), + ], + id="high-rank-tensors", + ), + pytest.param( + [ + torch.nested.nested_tensor( + [torch.arange(3, dtype=torch.float32), torch.arange(5, dtype=torch.float32)], + layout=torch.strided, + ), + torch.nested.nested_tensor( + [torch.randn(3, dtype=torch.bfloat16), torch.randn(5, dtype=torch.bfloat16)], + layout=torch.strided, + ), + torch.nested.nested_tensor( + [torch.arange(4, dtype=torch.float32), torch.arange(7, dtype=torch.float32)], + layout=torch.jagged, + ), + torch.nested.nested_tensor( + [torch.randn(4, dtype=torch.bfloat16), torch.randn(7, dtype=torch.bfloat16)], + layout=torch.jagged, + ), + ], + id="nested-tensors", + ), + pytest.param( + [{"only": "one", "tensor": torch.arange(3, dtype=torch.float32)}], + id="single-value", + ), ] @@ -193,12 +200,6 @@ def test_batch_encode_decode_roundtrip_yuanrong(values): _assert_equal_payloads(decoded, values) -def 
test_batch_encode_decode_single_value(): - values = [{"only": "one", "tensor": torch.arange(3, dtype=torch.float32)}] - decoded, *_ = _roundtrip(values, _mooncake_alloc, "mooncake") - _assert_equal_payloads(decoded, values) - - def test_batch_encode_decode_empty_list(): calls = [] @@ -250,8 +251,18 @@ def _assert_equal_payloads(decoded, original): if isinstance(want, torch.Tensor): assert isinstance(got, torch.Tensor) assert got.dtype == want.dtype - assert got.shape == want.shape - assert torch.equal(got, want) + if want.is_nested: + assert got.is_nested + assert got.layout == want.layout + got_subs = got.unbind() + want_subs = want.unbind() + assert len(got_subs) == len(want_subs) + for g, w in zip(got_subs, want_subs, strict=True): + assert g.shape == w.shape + assert torch.equal(g, w) + else: + assert got.shape == want.shape + assert torch.equal(got, want) elif isinstance(want, np.ndarray): assert isinstance(got, np.ndarray) assert got.dtype == want.dtype diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 08a22426..8d1e7df2 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -150,11 +150,11 @@ def put(self, keys: list[str], values: list[Any]) -> list[dict | None]: # bytes results arrive in non-tensor submit order, which matches the order of # non-tensor values; walk values once to scatter packed_size back to its key slot. 
- custom_backend_meta: list[dict | None] = [None] * len(keys) sizes_iter = iter(packed_sizes) - for i, value in enumerate(values): - if not isinstance(value, torch.Tensor): - custom_backend_meta[i] = {"packed_size": next(sizes_iter)} + custom_backend_meta: list[dict | None] = [ + {"packed_size": next(sizes_iter)} if not isinstance(value, torch.Tensor) else None + for value in values + ] return custom_backend_meta @@ -178,8 +178,11 @@ def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any def alloc(sizes: list[int]) -> list[Tensor]: nonlocal region_ptrs, region_sizes - # Borrow allocate_empty_tensors for its "N views sharing one register-able - # region" layout; uint8 1D shapes are byte scratch, not real tensors. + # `sizes` are byte counts. With torch.uint8 (1 byte/element), + # a 1-D shape of (N,) corresponds to exactly N bytes. We use + # `allocate_empty_tensors` to get N uint8 views over a single contiguous, + # register-able region. These are plain byte buffers, not real tensors; + # consumers apply the actual dtype/shape interpretation when unpacking. dtypes = [torch.uint8] * len(sizes) shapes = [(s,) for s in sizes] buffers, _, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes) @@ -286,8 +289,11 @@ def _get_tensors_thread_worker( def _get_bytes_thread_worker( self, batch_keys: list[str], batch_packed_sizes: list[int], indexes: list[int] ) -> tuple[list[Any], list[int]]: - # Borrow allocate_empty_tensors for its "N views sharing one register-able - # region" layout; uint8 1D shapes are byte scratch, not real tensors. + # `batch_packed_sizes` are byte counts. With torch.uint8 (1 byte/element), + # a 1-D shape of (N,) corresponds to exactly N bytes. We use + # `allocate_empty_tensors` to get N uint8 views over a single contiguous, + # register-able region. These are plain byte buffers, not real tensors; + # consumers apply the actual dtype/shape interpretation when unpacking. 
batch_shapes = [(sz,) for sz in batch_packed_sizes] batch_dtypes = [torch.uint8] * len(batch_keys) batch_nbytes = get_nbytes(batch_dtypes, batch_shapes) diff --git a/transfer_queue/utils/serial_utils.py b/transfer_queue/utils/serial_utils.py index 0f7a6c58..813f8950 100644 --- a/transfer_queue/utils/serial_utils.py +++ b/transfer_queue/utils/serial_utils.py @@ -447,26 +447,40 @@ def batch_encode_into( alloc_buff_func: Callable[[list[int]], list[Any]], *, num_workers: int = 1, -) -> tuple[list[Any], list[int]]: - """Encode all ``objs`` into caller-allocated per-object buffers. +) -> tuple[list[np.ndarray | memoryview], list[int]]: + """Encode multiple objects in-place into caller-allocated buffers. + + Each object is msgpack-encoded (with zero-copy tensor/ndarray extraction) + and packed into a buffer slot supplied by ``alloc_buff_func``. Buffers are + written in place; the function returns the same buffer list along with + each slot's packed byte length. Args: - objs: Objects to encode. - alloc_buff_func: ``(per_obj_sizes) -> buffers``. ``buffers[i]`` must - hold at least ``sizes[i]`` bytes and be consumable by - ``memoryview()`` (or ``.numpy().data`` for ``torch.Tensor``). - num_workers: ``pack_into`` parallelism. Default 1 (serial); set ``>1`` - when the upper-layer is single-threaded. + objs: Objects to encode, one per output buffer slot. + alloc_buff_func: Callable taking per-object packed sizes and returning + the corresponding buffer list. ``buffers[i]`` must be a + ``torch.Tensor``, ``np.ndarray`` or ``memoryview`` holding at least ``sizes[i]`` + bytes. + num_workers: Thread count for parallel packing. Default 1 (serial); + set ``>1`` only when the upper layer is single-threaded. Returns: - ``(buffers, batch_sizes)``: ``buffers`` is what ``alloc_buff_func`` - returned with packed bytes written; ``batch_sizes[i]`` is object - ``i``'s packed byte length (≤ ``buffers[i]`` capacity). 
+ tuple[list[np.ndarray | memoryview], list[int]]: The buffers returned by + ``alloc_buff_func`` with packed bytes written, paired with each + object's packed length (``<=`` buffer capacity). Note: - Lifetime is caller-owned: this function holds no references to - ``buffers`` after return. Whatever the alloc closure used to back - them must remain alive until all downstream consumers finish. + Lifetime is caller-owned: this function holds no references to the + buffers after return. Whatever backs the allocation must outlive all + downstream consumers. + + Example: + >>> # Pack two tensors into pre-allocated pinned uint8 tensor buffers + >>> def alloc(sizes): + ... return [torch.empty(s, dtype=torch.uint8, pin_memory=True) for s in sizes] + >>> objs = [torch.tensor([1, 2, 3]), torch.tensor([4.0, 5.0])] + >>> bufs, lengths = batch_encode_into(objs, alloc) + >>> print(f"packed sizes: {lengths}") """ batch_items = [encode(obj) for obj in objs] batch_sizes = [calc_packed_size(items) for items in batch_items] @@ -487,19 +501,31 @@ def _pack_one(pair: tuple[Any, list[bytestr]]) -> None: return buffers, batch_sizes -def batch_decode_from(source_buffers: Sequence[Any]) -> list[Any]: - """Reverse of ``batch_encode_into``: for each filled buffer, unpack + decode. - - Tensors / ndarrays in the result are zero-copy views over ``buf``. The - natural Python ref chain (``torch.frombuffer`` → ``Py_buffer`` → memoryview - slice → parent memoryview → numpy array → original ``buf``) keeps the - source alive as long as the returned object is reachable. Caller does NOT - need to retain ``buf`` separately. +def batch_decode_from(source_buffers: Sequence[np.ndarray | memoryview]) -> list[Any]: + """Reverse of ``batch_encode_into``: unpack and decode each filled buffer. Args: - source_buffers: Per-object receive buffers, in order. Each must support the - buffer protocol (``torch.Tensor`` via ``.numpy().data``; - ``bytearray`` / ``ndarray`` directly). 
+ source_buffers: Per-object receive buffers in order. Each must be a + ``torch.Tensor`` (via ``.numpy().data``), ``np.ndarray`` or ``memoryview``. + + Returns: + list[Any]: Decoded objects, one per input buffer, in the same order. + + Note: + Tensors and ndarrays in the result are zero-copy views over the + source buffers. The Python reference chain (``torch.frombuffer`` -> + ``Py_buffer`` -> memoryview slice -> parent memoryview -> numpy array + -> original buffer) keeps the source alive as long as the decoded + object is reachable; the caller does NOT need to retain the source + buffer separately. + + Example: + >>> # Round-trip: encode then decode + >>> def alloc(sizes): + ... return [torch.empty(s, dtype=torch.uint8) for s in sizes] + >>> objs = [torch.tensor([1, 2, 3]), torch.tensor([4.0, 5.0])] + >>> bufs, _ = batch_encode_into(objs, alloc) + >>> decoded = batch_decode_from(bufs) """ return [ decode(unpack_from(buf.numpy().data if hasattr(buf, "numpy") else memoryview(buf)))