def test_calc_packed_size_then_pack_unpack_roundtrip():
    """Pack into an exactly sized buffer, unpack it, and expect the original payloads back."""
    payloads = [b"hello", b"world!", b"x"]
    packed = bytearray(serial_utils.calc_packed_size(payloads))
    serial_utils.pack_into(packed, payloads)
    views = serial_utils.unpack_from(packed)
    # unpack_from yields memoryview slices; materialise them for comparison.
    assert list(map(bytes, views)) == payloads
def _mooncake_alloc(sizes: list[int]) -> list[torch.Tensor]:
    """Allocate one uint8 tensor and return consecutive views of it, one per size (mooncake-style)."""
    backing = torch.empty(sum(sizes), dtype=torch.uint8)
    views: list[torch.Tensor] = []
    start = 0
    for nbytes in sizes:
        stop = start + nbytes
        # Slicing produces a view, so every buffer shares `backing`'s storage.
        views.append(backing[start:stop])
        start = stop
    return views
def test_batch_encode_into_allows_padded_buffers():
    """Alloc may hand back oversized buffers: the reported sizes are still the
    packed lengths, and decoding the leading bytes restores the values."""
    extra = 32

    def padded_alloc(sizes):
        return [bytearray(nbytes + extra) for nbytes in sizes]

    payloads = [b"alpha", {"k": "v"}, torch.arange(4, dtype=torch.float32)]
    filled, packed_lengths = serial_utils.batch_encode_into(payloads, padded_alloc)

    for buf, length in zip(filled, packed_lengths, strict=True):
        assert len(buf) == length + extra

    # Only the first `length` bytes of each buffer carry data; the padding tail is ignored.
    trimmed = [bytes(buf[:length]) for buf, length in zip(filled, packed_lengths, strict=True)]
    decoded = serial_utils.batch_decode_from(trimmed)
    _assert_equal_payloads(decoded, payloads)
@pytest.mark.parametrize("values", _ROUNDTRIP_PARAMS)
def test_batch_encode_into_parallel_matches_serial(values):
    """Packing with four workers must yield the same sizes and bytes as packing serially."""

    def encode_with(workers):
        return serial_utils.batch_encode_into(values, _yuanrong_alloc, num_workers=workers)

    serial_buffers, serial_sizes = encode_with(1)
    par_buffers, par_sizes = encode_with(4)

    assert serial_sizes == par_sizes
    serial_bytes = list(map(bytes, serial_buffers))
    parallel_bytes = list(map(bytes, par_buffers))
    assert serial_bytes == parallel_bytes
values.append(torch.from_numpy(rng.random(n).astype(np.float32))) + + decoded, *_ = _roundtrip(values, _yuanrong_alloc, "yuanrong", num_workers=8) + _assert_equal_payloads(decoded, values) + + +# ============================================================================ +# helpers +# ============================================================================ + + +def _assert_equal_payloads(decoded, original): + assert len(decoded) == len(original) + for got, want in zip(decoded, original, strict=True): + if isinstance(want, torch.Tensor): + assert isinstance(got, torch.Tensor) + assert got.dtype == want.dtype + if want.is_nested: + assert got.is_nested + assert got.layout == want.layout + got_subs = got.unbind() + want_subs = want.unbind() + assert len(got_subs) == len(want_subs) + for g, w in zip(got_subs, want_subs, strict=True): + assert g.shape == w.shape + assert torch.equal(g, w) + else: + assert got.shape == want.shape + assert torch.equal(got, want) + elif isinstance(want, np.ndarray): + assert isinstance(got, np.ndarray) + assert got.dtype == want.dtype + assert got.shape == want.shape + assert np.array_equal(got, want) + elif isinstance(want, dict): + assert isinstance(got, dict) + assert got.keys() == want.keys() + for k in want: + _assert_equal_payloads([got[k]], [want[k]]) + elif isinstance(want, list): + assert isinstance(got, list) + _assert_equal_payloads(got, want) + else: + assert got == want diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index ede1de4e..8d1e7df2 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pickle import time from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any @@ -22,6 +21,7 @@ from torch import Tensor from transfer_queue.storage.clients.base import StorageClientFactory, StorageKVClient +from transfer_queue.utils import serial_utils from transfer_queue.utils.logging_utils import get_logger from transfer_queue.utils.tensor_utils import allocate_empty_tensors, get_nbytes, merge_contiguous_memory @@ -98,12 +98,17 @@ def __init__(self, config: dict[str, Any]): if ret != 0: raise RuntimeError(f"Mooncake store setup failed with error code: {ret}") - def put(self, keys: list[str], values: list[Any]) -> None: + def put(self, keys: list[str], values: list[Any]) -> list[dict | None]: """Stores multiple key-value pairs to MooncakeStore. Args: keys (List[str]): List of unique string identifiers. values (List[Any]): List of values to store (tensors, scalars, dicts, etc.). + + Returns: + Per-key metadata aligned with ``keys``. Tensor entries are ``None``; + non-tensor entries carry ``{"packed_size": int}`` so the get-side + can pre-allocate the receive buffer. 
""" if not isinstance(keys, list) or not isinstance(values, list): @@ -124,22 +129,34 @@ def put(self, keys: list[str], values: list[Any]) -> None: non_tensor_keys.append(key) non_tensor_values.append(value) - futures = [] + tensor_futures = [] + bytes_futures = [] with ThreadPoolExecutor(max_workers=MAX_WORKER_THREADS) as executor: for i in range(0, len(tensor_keys), BATCH_SIZE_LIMIT): batch_keys = tensor_keys[i : i + BATCH_SIZE_LIMIT] batch_tensors = tensor_values[i : i + BATCH_SIZE_LIMIT] - futures.append(executor.submit(self._put_tensors_thread_worker, batch_keys, batch_tensors)) + tensor_futures.append(executor.submit(self._put_tensors_thread_worker, batch_keys, batch_tensors)) for i in range(0, len(non_tensor_keys), BATCH_SIZE_LIMIT): batch_keys = non_tensor_keys[i : i + BATCH_SIZE_LIMIT] batch_values = non_tensor_values[i : i + BATCH_SIZE_LIMIT] - futures.append(executor.submit(self._put_bytes_thread_worker, batch_keys, batch_values)) + bytes_futures.append(executor.submit(self._put_bytes_thread_worker, batch_keys, batch_values)) - for future in as_completed(futures): + for future in tensor_futures: future.result() + packed_sizes = [] + for future in bytes_futures: + packed_sizes.extend(future.result()) + + # bytes results arrive in non-tensor submit order, which matches the order of + # non-tensor values; walk values once to scatter packed_size back to its key slot. 
+ sizes_iter = iter(packed_sizes) + custom_backend_meta: list[dict | None] = [ + {"packed_size": next(sizes_iter)} if not isinstance(value, torch.Tensor) else None + for value in values + ] - return None + return custom_backend_meta def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[Tensor]) -> None: """Worker thread for putting batch of tensors to MooncakeStore.""" @@ -147,106 +164,47 @@ def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[ batch_ptrs, batch_sizes, _contiguous_tensors = self._preprocess_tensors_for_put(batch_tensors) batch_ptr_reduced, batch_sizes_reduced = merge_contiguous_memory(batch_ptrs, batch_sizes) self._register_all_buffers(batch_ptr_reduced, batch_sizes_reduced) - try: - results = self._store.batch_upsert_from(batch_keys, batch_ptrs, batch_sizes, config=self.replica_config) - if len(results) != len(batch_keys): - raise RuntimeError(f"batch_upsert_from returned {len(results)} results, expected {len(batch_keys)}") - - failed_indices = [j for j, r in enumerate(results) if r != 0] - if not failed_indices: - return - - current_failed_keys = [batch_keys[i] for i in failed_indices] - current_failed_codes = [results[i] for i in failed_indices] - current_failed_indices = failed_indices - - logger.error( - f"batch_upsert_from failed for keys {current_failed_keys} with error codes {current_failed_codes}. " - f"Retrying up to {MAX_RETRIES} times..." 
def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any]) -> list[int]:
    """Encode a batch of non-tensor values into uint8 buffers and upsert them.

    Returns:
        The packed byte length of each value, in ``batch_values`` order.
    """
    # TODO: switch to a pre-registered buffer from MooncakeStore once such an API is available.
    # Filled in by `alloc`, which batch_encode_into invokes with the packed sizes.
    regions: dict[str, list[int]] = {"ptrs": [], "sizes": []}

    def alloc(sizes: list[int]) -> list[Tensor]:
        # Sizes are byte counts; a 1-D torch.uint8 tensor of shape (N,) is
        # exactly N bytes, so these tensors serve as plain byte buffers.
        byte_dtypes = [torch.uint8] * len(sizes)
        byte_shapes = [(nbytes,) for nbytes in sizes]
        views, _, regions["ptrs"], regions["sizes"] = allocate_empty_tensors(byte_dtypes, byte_shapes)
        return views

    buffers, batch_sizes = serial_utils.batch_encode_into(batch_values, alloc)
    batch_ptrs = [buf.data_ptr() for buf in buffers]

    self._register_all_buffers(regions["ptrs"], regions["sizes"])
    try:
        self._batch_upsert_with_retry(batch_keys, batch_ptrs, batch_sizes)
    finally:
        # Always release the registration, even when the upsert raises.
        self._unregister_all_buffers(regions["ptrs"])

    return batch_sizes
def _get_bytes_thread_worker(
    self, batch_keys: list[str], batch_packed_sizes: list[int], indexes: list[int]
) -> tuple[list[Any], list[int]]:
    """Fetch packed non-tensor payloads into uint8 receive buffers and decode them.

    Args:
        batch_keys: Keys to fetch.
        batch_packed_sizes: Packed byte length recorded for each key at put time.
        indexes: Positions of these keys in the caller's result list; returned unchanged.

    Returns:
        ``(decoded_values, indexes)``.
    """
    # Sizes are byte counts; a 1-D torch.uint8 tensor of shape (N,) is exactly
    # N bytes, so the allocated tensors act as plain byte buffers that
    # batch_decode_from interprets afterwards.
    byte_shapes = [(nbytes,) for nbytes in batch_packed_sizes]
    byte_dtypes = [torch.uint8] * len(batch_keys)
    recv_nbytes = get_nbytes(byte_dtypes, byte_shapes)
    recv_tensors, recv_ptrs, region_ptrs, region_sizes = allocate_empty_tensors(byte_dtypes, byte_shapes)

    self._register_all_buffers(region_ptrs, region_sizes)
    try:
        self._batch_get_into_with_retry(batch_keys, recv_ptrs, recv_nbytes)
    finally:
        self._unregister_all_buffers(region_ptrs)

    decoded = serial_utils.batch_decode_from(recv_tensors)
    return decoded, indexes
def _batch_upsert_with_retry(
    self, batch_keys: list[str], batch_ptrs: list[int], batch_sizes: list[int]
) -> None:
    """Run ``batch_upsert_from`` with per-key retry; raise on permanent failure.

    The first call covers every key. Keys whose result code is non-zero are
    retried (only those keys) up to ``MAX_RETRIES`` times, sleeping
    ``RETRY_DELAY_SECONDS`` between attempts.

    Caller owns the memory regions (register/unregister and lifetime of the
    backing tensors/buffers).

    Args:
        batch_keys: Keys to write.
        batch_ptrs: Source pointer for each key.
        batch_sizes: Byte length for each key.

    Raises:
        RuntimeError: If any call returns a result list whose length does not
            match the number of keys submitted, or if keys still fail after
            all retries.
    """
    results = self._store.batch_upsert_from(batch_keys, batch_ptrs, batch_sizes, config=self.replica_config)
    if len(results) != len(batch_keys):
        raise RuntimeError(f"batch_upsert_from returned {len(results)} results, expected {len(batch_keys)}")

    failed_indices = [j for j, r in enumerate(results) if r != 0]
    if not failed_indices:
        return

    current_failed_indices = failed_indices
    current_failed_keys = [batch_keys[i] for i in failed_indices]
    current_failed_codes = [results[i] for i in failed_indices]

    logger.error(
        f"batch_upsert_from failed for keys {current_failed_keys} with error codes {current_failed_codes}. "
        f"Retrying up to {MAX_RETRIES} times..."
    )

    for attempt in range(1, MAX_RETRIES + 1):
        retry_ptrs = [batch_ptrs[i] for i in current_failed_indices]
        retry_sizes = [batch_sizes[i] for i in current_failed_indices]

        retry_results = self._store.batch_upsert_from(
            current_failed_keys, retry_ptrs, retry_sizes, config=self.replica_config
        )
        # A short result list would otherwise make the missing keys look
        # successful, so apply the same length check as the first call.
        if len(retry_results) != len(current_failed_keys):
            raise RuntimeError(
                f"batch_upsert_from returned {len(retry_results)} results, expected {len(current_failed_keys)}"
            )

        still_failed = [
            (idx, key, ret)
            for idx, key, ret in zip(current_failed_indices, current_failed_keys, retry_results)
            if ret != 0
        ]
        if not still_failed:
            logger.info("batch_upsert_from succeeded after retransmission.")
            return

        # Narrow down to the keys that are still failing for the next attempt.
        current_failed_indices = [idx for idx, _, _ in still_failed]
        current_failed_keys = [key for _, key, _ in still_failed]
        current_failed_codes = [ret for _, _, ret in still_failed]

        logger.error(
            f"batch_upsert_from retry {attempt}/{MAX_RETRIES} failed for {len(current_failed_keys)} keys "
            f"with error codes {current_failed_codes}."
        )

        if attempt < MAX_RETRIES:
            time.sleep(RETRY_DELAY_SECONDS)

    raise RuntimeError(
        f"batch_upsert_from failed for keys {current_failed_keys} with error codes "
        f"{current_failed_codes} after retrying {MAX_RETRIES} times."
    )
+ """ + ret_codes = self._store.batch_get_into(batch_keys, batch_buffer_ptrs, batch_nbytes) + if len(ret_codes) != len(batch_keys): + raise RuntimeError(f"batch_get_into returned {len(ret_codes)} results, expected {len(batch_keys)}") - for attempt in range(1, MAX_RETRIES + 1): - retry_results = self._store.get_batch(current_failed_keys) + failed_indices = [i for i, ret in enumerate(ret_codes) if ret < 0] + if not failed_indices: + return - next_failed_keys = [] - next_failed_indices = [] + current_failed_keys = [batch_keys[i] for i in failed_indices] + current_failed_codes = [ret_codes[i] for i in failed_indices] + current_failed_indices = failed_indices - for i, result in enumerate(retry_results): - original_idx = current_failed_indices[i] - if result == b"": - next_failed_keys.append(current_failed_keys[i]) - next_failed_indices.append(original_idx) - else: - # Write the successfully retried value back to its original slot immediately. - raw_results[original_idx] = result + logger.error( + f"batch_get_into failed for keys {current_failed_keys} with error codes {current_failed_codes}. " + f"Retrying up to {MAX_RETRIES} times..." + ) - if not next_failed_indices: - logger.info("get_batch succeeded after retransmission.") - break # All retries in this attempt succeeded. + for attempt in range(1, MAX_RETRIES + 1): + # Reuse the originally allocated pointers; no need to allocate/register new buffers. + retry_ptrs = [batch_buffer_ptrs[i] for i in current_failed_indices] + retry_nbytes = [batch_nbytes[i] for i in current_failed_indices] - logger.error(f"get_batch retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys.") + retry_codes = self._store.batch_get_into(current_failed_keys, retry_ptrs, retry_nbytes) - # Narrow down to still-failed items for the next retry attempt. 
- current_failed_keys = next_failed_keys - current_failed_indices = next_failed_indices + next_failed_indices = [] + next_failed_keys = [] + next_failed_codes = [] - if attempt < MAX_RETRIES: - time.sleep(RETRY_DELAY_SECONDS) - else: - # All retries exhausted. - raise RuntimeError( - f"get_batch failed for keys {current_failed_keys} after retrying {MAX_RETRIES} times." - ) + for i, ret in enumerate(retry_codes): + if ret < 0: + next_failed_indices.append(current_failed_indices[i]) + next_failed_keys.append(current_failed_keys[i]) + next_failed_codes.append(ret) - deserialized_results = [pickle.loads(result) if result != b"" else None for result in raw_results] - return deserialized_results, indexes + if not next_failed_indices: + logger.info("batch_get_into succeeded after retransmission.") + return - def clear(self, keys: list[str], custom_backend_meta: list[Any] | None = None) -> None: - """Deletes multiple keys from MooncakeStore. + logger.error( + f"batch_get_into retry {attempt}/{MAX_RETRIES} failed for {len(next_failed_keys)} keys " + f"with error codes {next_failed_codes}." + ) - Args: - keys (List[str]): List of keys to remove. - custom_backend_meta (List[Any], optional): ... - """ - ret_codes = self._store.batch_remove(keys, force=True) - for i, ret in enumerate(ret_codes): - if not (ret == 0 or ret == -704): - logger.error(f"remove failed for key `{keys[i]}` with error code: {ret}") + current_failed_indices = next_failed_indices + current_failed_keys = next_failed_keys + current_failed_codes = next_failed_codes - def close(self): - """Closes MooncakeStore.""" - if self._store: - self._store.close() - self._store = None + if attempt < MAX_RETRIES: + time.sleep(RETRY_DELAY_SECONDS) + + raise RuntimeError( + f"batch_get_into failed for keys {current_failed_keys} with error codes " + f"{current_failed_codes} after retrying {MAX_RETRIES} times." 
# Packed buffer layout:
#   [item_count: uint32 LE]
#   [N × (payload_offset: uint32 LE, payload_size: uint32 LE)]
#   [payload_0 ... payload_{N-1}]
_PACK_HEADER_FMT = "<I"
_PACK_ENTRY_FMT = "<II"
_PACK_HEADER_SIZE = struct.calcsize(_PACK_HEADER_FMT)
_PACK_ENTRY_SIZE = struct.calcsize(_PACK_ENTRY_FMT)


def calc_packed_size(items: Sequence["bytestr"]) -> int:
    """Total bytes required to pack ``items`` into one buffer."""
    return _PACK_HEADER_SIZE + len(items) * _PACK_ENTRY_SIZE + sum(memoryview(item).nbytes for item in items)


def pack_into(target_buffer: "bytestr", items: Sequence["bytestr"]) -> None:
    """Concatenate ``items`` into ``target_buffer`` using the packed layout.

    Args:
        target_buffer: Writable buffer of at least ``calc_packed_size(items)``
            bytes. Bytes beyond the packed length are left untouched.
        items: Buffer-protocol objects to pack, in order. Zero-length items
            are allowed and are recorded with ``payload_size == 0``.

    Raises:
        ValueError: If ``target_buffer`` is smaller than the required size.
    """
    target_mv = memoryview(target_buffer)
    required = calc_packed_size(items)
    if target_mv.nbytes < required:
        raise ValueError(
            f"pack_into: target buffer has {target_mv.nbytes} bytes, requires {required}"
        )
    struct.pack_into(_PACK_HEADER_FMT, target_mv, 0, len(items))

    entry_offset = _PACK_HEADER_SIZE
    payload_offset = _PACK_HEADER_SIZE + len(items) * _PACK_ENTRY_SIZE

    # Payload bytes are copied through uint8 tensor views of both buffers.
    target_tensor = torch.frombuffer(target_mv, dtype=torch.uint8)

    for item in items:
        item_mv = memoryview(item)
        nbytes = item_mv.nbytes
        struct.pack_into(_PACK_ENTRY_FMT, target_mv, entry_offset, payload_offset, nbytes)
        # torch.frombuffer rejects zero-length buffers, so an empty item only
        # gets its table entry and no copy.
        if nbytes:
            src_tensor = torch.frombuffer(item_mv, dtype=torch.uint8)
            target_tensor[payload_offset : payload_offset + nbytes].copy_(src_tensor)
        entry_offset += _PACK_ENTRY_SIZE
        payload_offset += nbytes


def unpack_from(source_buffer: "bytestr") -> list[memoryview]:
    """Split a packed buffer back into N memoryview slices over ``source_buffer``.

    The buffer may be longer than the packed data (e.g. a padded allocation);
    only the regions named by the entry table are returned.

    Raises:
        ValueError: If the entry table or any payload extends past the end of
            ``source_buffer`` (truncated or corrupt data).
    """
    mv = memoryview(source_buffer)
    total = mv.nbytes
    (item_count,) = struct.unpack_from(_PACK_HEADER_FMT, mv, 0)
    table_end = _PACK_HEADER_SIZE + item_count * _PACK_ENTRY_SIZE
    if table_end > total:
        raise ValueError(
            f"unpack_from: entry table for {item_count} items needs {table_end} bytes, buffer has {total}"
        )
    result: list[memoryview] = []
    for i in range(item_count):
        offset, length = struct.unpack_from(_PACK_ENTRY_FMT, mv, _PACK_HEADER_SIZE + i * _PACK_ENTRY_SIZE)
        # Slicing would silently truncate an out-of-range entry; fail loudly instead.
        if offset + length > total:
            raise ValueError(
                f"unpack_from: item {i} spans bytes {offset}..{offset + length}, buffer has {total}"
            )
        result.append(mv[offset : offset + length])
    return result
+ + Example: + >>> # Pack two tensors into pre-allocated pinned uint8 tensor buffers + >>> def alloc(sizes): + ... return [torch.empty(s, dtype=torch.uint8, pin_memory=True) for s in sizes] + >>> objs = [torch.tensor([1, 2, 3]), torch.tensor([4.0, 5.0])] + >>> bufs, lengths = batch_encode_into(objs, alloc) + >>> print(f"packed sizes: {lengths}") + """ + batch_items = [encode(obj) for obj in objs] + batch_sizes = [calc_packed_size(items) for items in batch_items] + buffers = alloc_buff_func(batch_sizes) + + def _pack_one(pair: tuple[Any, list[bytestr]]) -> None: + buf, items = pair + mv = buf.numpy().data if hasattr(buf, "numpy") else memoryview(buf) + pack_into(mv, items) + + if num_workers <= 1: + for pair in zip(buffers, batch_items, strict=True): + _pack_one(pair) + else: + with ThreadPoolExecutor(max_workers=num_workers) as executor: + list(executor.map(_pack_one, zip(buffers, batch_items, strict=True))) + + return buffers, batch_sizes + + +def batch_decode_from(source_buffers: Sequence[np.ndarray | memoryview]) -> list[Any]: + """Reverse of ``batch_encode_into``: unpack and decode each filled buffer. + + Args: + source_buffers: Per-object receive buffers in order. Each must be an + ``np.ndarray`` or ``memoryview``. + + Returns: + list[Any]: Decoded objects, one per input buffer, in the same order. + + Note: + Tensors and ndarrays in the result are zero-copy views over the + source buffers. The Python reference chain (``torch.frombuffer`` -> + ``Py_buffer`` -> memoryview slice -> parent memoryview -> numpy array + -> original buffer) keeps the source alive as long as the decoded + object is reachable; the caller does NOT need to retain the source + buffer separately. + + Example: + >>> # Round-trip: encode then decode + >>> def alloc(sizes): + ... 
return [torch.empty(s, dtype=torch.uint8) for s in sizes] + >>> objs = [torch.tensor([1, 2, 3]), torch.tensor([4.0, 5.0])] + >>> bufs, _ = batch_encode_into(objs, alloc) + >>> decoded = batch_decode_from(bufs) + """ + return [ + decode(unpack_from(buf.numpy().data if hasattr(buf, "numpy") else memoryview(buf))) + for buf in source_buffers + ]