Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
280 changes: 280 additions & 0 deletions tests/test_serial_utils_batch_on_cpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for the packed-buffer batch serialization helpers in
``transfer_queue.utils.serial_utils``:

* ``calc_packed_size``
* ``pack_into`` / ``unpack_from``
* ``batch_encode_into``
* ``batch_decode_from``
"""

import numpy as np
import pytest
import torch

from transfer_queue.utils import serial_utils


# ============================================================================
# low-level: calc_packed_size + pack_into + unpack_from (raw bytes layer)
# ============================================================================


def test_calc_packed_size_then_pack_unpack_roundtrip():
items = [b"hello", b"world!", b"x"]
size = serial_utils.calc_packed_size(items)
buf = bytearray(size)
serial_utils.pack_into(buf, items)
recovered = serial_utils.unpack_from(buf)
assert [bytes(mv) for mv in recovered] == items


def test_pack_into_writes_only_within_its_slice():
items = [b"alpha", b"beta", b"gamma"]
sz = serial_utils.calc_packed_size(items)
pad_before, pad_after = 17, 23
big = bytearray(pad_before + sz + pad_after)
serial_utils.pack_into(memoryview(big)[pad_before : pad_before + sz], items)

assert all(b == 0 for b in big[:pad_before])
assert all(b == 0 for b in big[pad_before + sz :])

recovered = serial_utils.unpack_from(memoryview(big)[pad_before : pad_before + sz])
assert [bytes(mv) for mv in recovered] == items


def test_unpack_from_zero_item_buffer():
items: list[bytes] = []
sz = serial_utils.calc_packed_size(items)
buf = bytearray(sz)
serial_utils.pack_into(buf, items)
assert serial_utils.unpack_from(buf) == []


# ============================================================================
# batch_encode_into + batch_decode_from (high-level batch layer)
# ============================================================================


def _mooncake_alloc(sizes: list[int]) -> list[torch.Tensor]:
"""Single big torch.uint8 tensor sliced into N views (mooncake-style)."""
big = torch.empty(sum(sizes), dtype=torch.uint8)
buffers: list[torch.Tensor] = []
offset = 0
for s in sizes:
buffers.append(big[offset : offset + s])
offset += s
return buffers


def _yuanrong_alloc(sizes: list[int]) -> list[bytearray]:
"""N independent bytearrays (yuanrong-style per-key buffer)."""
return [bytearray(s) for s in sizes]


def _decode_from_returned(buffers, alloc_kind):
if alloc_kind == "mooncake":
return serial_utils.batch_decode_from(buffers)
return serial_utils.batch_decode_from([bytes(b) for b in buffers])


def _roundtrip(values, alloc, alloc_kind, *, num_workers: int = 1):
buffers, sizes = serial_utils.batch_encode_into(values, alloc, num_workers=num_workers)
decoded = _decode_from_returned(buffers, alloc_kind)
return decoded, buffers, sizes


# ---- structural: return shapes / alloc contract ----


def test_batch_encode_into_return_shapes():
values = [{"x": 1}, "a string", torch.arange(8, dtype=torch.float32)]
buffers, sizes = serial_utils.batch_encode_into(values, _mooncake_alloc)

assert len(buffers) == len(values)
assert len(sizes) == len(values)
for b, s in zip(buffers, sizes, strict=True):
assert b.nbytes == s


def test_batch_encode_into_allows_padded_buffers():
"""Alloc may return buffers larger than requested sizes; batch_sizes still
reports the actual packed length, and the data round-trips correctly."""
pad = 32

def padded_alloc(sizes):
return [bytearray(s + pad) for s in sizes]

values = [b"alpha", {"k": "v"}, torch.arange(4, dtype=torch.float32)]
buffers, sizes = serial_utils.batch_encode_into(values, padded_alloc)

for b, s in zip(buffers, sizes, strict=True):
assert len(b) == s + pad

# decoding uses only the first `s` bytes; the pad tail is harmless
decoded = serial_utils.batch_decode_from([bytes(b[:s]) for b, s in zip(buffers, sizes, strict=True)])
_assert_equal_payloads(decoded, values)


# ---- semantic: encode → decode roundtrip preserves values ----


_ROUNDTRIP_PARAMS = [
pytest.param([42, 3.14, "hello", b"bytes"], id="primitives"),
pytest.param([{"a": 1, "b": [1, 2, 3]}, {"nested": {"k": "v"}}], id="nested-dicts"),
pytest.param([torch.arange(10, dtype=torch.float32)], id="single-tensor"),
pytest.param(
[
torch.arange(100, dtype=torch.float32),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add test coverage for

  1. Nested tensor (jagged/strided)
  2. 4D+ dim tensor

torch.randn(4, 4, dtype=torch.bfloat16),
torch.zeros(3, 5, dtype=torch.int64),
],
id="mixed-tensors",
),
pytest.param(
[np.arange(50, dtype=np.float64), np.ones((3, 3), dtype=np.int32)],
id="numpy-arrays",
),
pytest.param(
[{"meta": "v1", "arr": torch.arange(5, dtype=torch.float32)}, [1, 2, "three"]],
id="heterogeneous",
),
pytest.param(
[
torch.randn(2, 3, 4, 5, dtype=torch.float32),
torch.randn(2, 3, 4, 5, 6, dtype=torch.bfloat16),
],
id="high-rank-tensors",
),
pytest.param(
[
torch.nested.nested_tensor(
[torch.arange(3, dtype=torch.float32), torch.arange(5, dtype=torch.float32)],
layout=torch.strided,
),
torch.nested.nested_tensor(
[torch.randn(3, dtype=torch.bfloat16), torch.randn(5, dtype=torch.bfloat16)],
layout=torch.strided,
),
torch.nested.nested_tensor(
[torch.arange(4, dtype=torch.float32), torch.arange(7, dtype=torch.float32)],
layout=torch.jagged,
),
torch.nested.nested_tensor(
[torch.randn(4, dtype=torch.bfloat16), torch.randn(7, dtype=torch.bfloat16)],
layout=torch.jagged,
),
],
id="nested-tensors",
),
pytest.param(
[{"only": "one", "tensor": torch.arange(3, dtype=torch.float32)}],
id="single-value",
),
]


@pytest.mark.parametrize("values", _ROUNDTRIP_PARAMS)
def test_batch_encode_decode_roundtrip_mooncake(values):
decoded, *_ = _roundtrip(values, _mooncake_alloc, "mooncake")
_assert_equal_payloads(decoded, values)


@pytest.mark.parametrize("values", _ROUNDTRIP_PARAMS)
def test_batch_encode_decode_roundtrip_yuanrong(values):
decoded, *_ = _roundtrip(values, _yuanrong_alloc, "yuanrong")
_assert_equal_payloads(decoded, values)


def test_batch_encode_decode_empty_list():
calls = []

def alloc(sizes):
calls.append(list(sizes))
return []

buffers, sizes = serial_utils.batch_encode_into([], alloc)
assert buffers == [] and sizes == []
assert calls == [[]]
assert serial_utils.batch_decode_from([]) == []


# ---- num_workers: parallel pack must produce identical bytes vs serial ----


@pytest.mark.parametrize("values", _ROUNDTRIP_PARAMS)
def test_batch_encode_into_parallel_matches_serial(values):
serial_buffers, serial_sizes = serial_utils.batch_encode_into(
values, _yuanrong_alloc, num_workers=1
)
par_buffers, par_sizes = serial_utils.batch_encode_into(
values, _yuanrong_alloc, num_workers=4
)

assert serial_sizes == par_sizes
assert [bytes(b) for b in serial_buffers] == [bytes(b) for b in par_buffers]


def test_batch_encode_into_parallel_roundtrip_many_objects():
rng = np.random.default_rng(42)
values = []
for _ in range(64):
n = int(rng.integers(1, 257))
values.append(torch.from_numpy(rng.random(n).astype(np.float32)))

decoded, *_ = _roundtrip(values, _yuanrong_alloc, "yuanrong", num_workers=8)
_assert_equal_payloads(decoded, values)
Comment on lines +219 to +240
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can merge the two tests by setting @pytest.mark.parametrize("values", _ROUNDTRIP_PARAMS) to the roundtrip test



# ============================================================================
# helpers
# ============================================================================


def _assert_equal_payloads(decoded, original):
assert len(decoded) == len(original)
for got, want in zip(decoded, original, strict=True):
if isinstance(want, torch.Tensor):
assert isinstance(got, torch.Tensor)
assert got.dtype == want.dtype
if want.is_nested:
assert got.is_nested
assert got.layout == want.layout
got_subs = got.unbind()
want_subs = want.unbind()
assert len(got_subs) == len(want_subs)
for g, w in zip(got_subs, want_subs, strict=True):
assert g.shape == w.shape
assert torch.equal(g, w)
else:
assert got.shape == want.shape
assert torch.equal(got, want)
elif isinstance(want, np.ndarray):
assert isinstance(got, np.ndarray)
assert got.dtype == want.dtype
assert got.shape == want.shape
assert np.array_equal(got, want)
elif isinstance(want, dict):
assert isinstance(got, dict)
assert got.keys() == want.keys()
for k in want:
_assert_equal_payloads([got[k]], [want[k]])
elif isinstance(want, list):
assert isinstance(got, list)
_assert_equal_payloads(got, want)
else:
assert got == want
Loading