Skip to content

Commit a726485

Browse files
committed
[refactor] Provide common serialization tools for KV backends to speed up tensor serial in nested values
Signed-off-by: xupinjie <xupinjie321@outlook.com>
1 parent 94baa19 commit a726485

3 files changed

Lines changed: 634 additions & 200 deletions

File tree

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2+
# Copyright 2025 The TransferQueue Team
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Unit tests for the packed-buffer batch serialization helpers in
17+
``transfer_queue.utils.serial_utils``:
18+
19+
* ``calc_packed_size``
20+
* ``pack_into`` / ``unpack_from``
21+
* ``batch_encode_into``
22+
* ``batch_decode_from``
23+
"""
24+
25+
import numpy as np
26+
import pytest
27+
import torch
28+
29+
from transfer_queue.utils import serial_utils
30+
31+
32+
# ============================================================================
33+
# low-level: calc_packed_size + pack_into + unpack_from (raw bytes layer)
34+
# ============================================================================
35+
36+
37+
def test_calc_packed_size_then_pack_unpack_roundtrip():
38+
items = [b"hello", b"world!", b"x"]
39+
size = serial_utils.calc_packed_size(items)
40+
buf = bytearray(size)
41+
serial_utils.pack_into(buf, items)
42+
recovered = serial_utils.unpack_from(buf)
43+
assert [bytes(mv) for mv in recovered] == items
44+
45+
46+
def test_pack_into_writes_only_within_its_slice():
47+
items = [b"alpha", b"beta", b"gamma"]
48+
sz = serial_utils.calc_packed_size(items)
49+
pad_before, pad_after = 17, 23
50+
big = bytearray(pad_before + sz + pad_after)
51+
serial_utils.pack_into(memoryview(big)[pad_before : pad_before + sz], items)
52+
53+
assert all(b == 0 for b in big[:pad_before])
54+
assert all(b == 0 for b in big[pad_before + sz :])
55+
56+
recovered = serial_utils.unpack_from(memoryview(big)[pad_before : pad_before + sz])
57+
assert [bytes(mv) for mv in recovered] == items
58+
59+
60+
def test_unpack_from_zero_item_buffer():
61+
items: list[bytes] = []
62+
sz = serial_utils.calc_packed_size(items)
63+
buf = bytearray(sz)
64+
serial_utils.pack_into(buf, items)
65+
assert serial_utils.unpack_from(buf) == []
66+
67+
68+
# ============================================================================
69+
# batch_encode_into + batch_decode_from (high-level batch layer)
70+
# ============================================================================
71+
72+
73+
def _mooncake_alloc(sizes: list[int]) -> list[torch.Tensor]:
74+
"""Single big torch.uint8 tensor sliced into N views (mooncake-style)."""
75+
big = torch.empty(sum(sizes), dtype=torch.uint8)
76+
buffers: list[torch.Tensor] = []
77+
offset = 0
78+
for s in sizes:
79+
buffers.append(big[offset : offset + s])
80+
offset += s
81+
return buffers
82+
83+
84+
def _yuanrong_alloc(sizes: list[int]) -> list[bytearray]:
85+
"""N independent bytearrays (yuanrong-style per-key buffer)."""
86+
return [bytearray(s) for s in sizes]
87+
88+
89+
def _decode_from_returned(buffers, alloc_kind):
90+
if alloc_kind == "mooncake":
91+
return serial_utils.batch_decode_from(buffers)
92+
return serial_utils.batch_decode_from([bytes(b) for b in buffers])
93+
94+
95+
def _roundtrip(values, alloc, alloc_kind, *, num_workers: int = 1):
96+
buffers, sizes = serial_utils.batch_encode_into(values, alloc, num_workers=num_workers)
97+
decoded = _decode_from_returned(buffers, alloc_kind)
98+
return decoded, buffers, sizes
99+
100+
101+
# ---- structural: return shapes / alloc contract ----
102+
103+
104+
def test_batch_encode_into_return_shapes():
105+
values = [{"x": 1}, "a string", torch.arange(8, dtype=torch.float32)]
106+
buffers, sizes = serial_utils.batch_encode_into(values, _mooncake_alloc)
107+
108+
assert len(buffers) == len(values)
109+
assert len(sizes) == len(values)
110+
for b, s in zip(buffers, sizes, strict=True):
111+
assert b.nbytes == s
112+
113+
114+
def test_batch_encode_into_allows_padded_buffers():
115+
"""Alloc may return buffers larger than requested sizes; batch_sizes still
116+
reports the actual packed length, and the data round-trips correctly."""
117+
pad = 32
118+
119+
def padded_alloc(sizes):
120+
return [bytearray(s + pad) for s in sizes]
121+
122+
values = [b"alpha", {"k": "v"}, torch.arange(4, dtype=torch.float32)]
123+
buffers, sizes = serial_utils.batch_encode_into(values, padded_alloc)
124+
125+
for b, s in zip(buffers, sizes, strict=True):
126+
assert len(b) == s + pad
127+
128+
# decoding uses only the first `s` bytes; the pad tail is harmless
129+
decoded = serial_utils.batch_decode_from([bytes(b[:s]) for b, s in zip(buffers, sizes, strict=True)])
130+
_assert_equal_payloads(decoded, values)
131+
132+
133+
# ---- semantic: encode → decode roundtrip preserves values ----
134+
135+
136+
_ROUNDTRIP_PARAMS = [
137+
pytest.param([42, 3.14, "hello", b"bytes"], id="primitives"),
138+
pytest.param([{"a": 1, "b": [1, 2, 3]}, {"nested": {"k": "v"}}], id="nested-dicts"),
139+
pytest.param([torch.arange(10, dtype=torch.float32)], id="single-tensor"),
140+
pytest.param(
141+
[
142+
torch.arange(100, dtype=torch.float32),
143+
torch.randn(4, 4, dtype=torch.bfloat16),
144+
torch.zeros(3, 5, dtype=torch.int64),
145+
],
146+
id="mixed-tensors",
147+
),
148+
pytest.param(
149+
[np.arange(50, dtype=np.float64), np.ones((3, 3), dtype=np.int32)],
150+
id="numpy-arrays",
151+
),
152+
pytest.param(
153+
[{"meta": "v1", "arr": torch.arange(5, dtype=torch.float32)}, [1, 2, "three"]],
154+
id="heterogeneous",
155+
),
156+
pytest.param(
157+
[
158+
torch.randn(2, 3, 4, 5, dtype=torch.float32),
159+
torch.randn(2, 3, 4, 5, 6, dtype=torch.bfloat16),
160+
],
161+
id="high-rank-tensors",
162+
),
163+
pytest.param(
164+
[
165+
torch.nested.nested_tensor(
166+
[torch.arange(3, dtype=torch.float32), torch.arange(5, dtype=torch.float32)],
167+
layout=torch.strided,
168+
),
169+
torch.nested.nested_tensor(
170+
[torch.randn(3, dtype=torch.bfloat16), torch.randn(5, dtype=torch.bfloat16)],
171+
layout=torch.strided,
172+
),
173+
torch.nested.nested_tensor(
174+
[torch.arange(4, dtype=torch.float32), torch.arange(7, dtype=torch.float32)],
175+
layout=torch.jagged,
176+
),
177+
torch.nested.nested_tensor(
178+
[torch.randn(4, dtype=torch.bfloat16), torch.randn(7, dtype=torch.bfloat16)],
179+
layout=torch.jagged,
180+
),
181+
],
182+
id="nested-tensors",
183+
),
184+
pytest.param(
185+
[{"only": "one", "tensor": torch.arange(3, dtype=torch.float32)}],
186+
id="single-value",
187+
),
188+
]
189+
190+
191+
@pytest.mark.parametrize("values", _ROUNDTRIP_PARAMS)
192+
def test_batch_encode_decode_roundtrip_mooncake(values):
193+
decoded, *_ = _roundtrip(values, _mooncake_alloc, "mooncake")
194+
_assert_equal_payloads(decoded, values)
195+
196+
197+
@pytest.mark.parametrize("values", _ROUNDTRIP_PARAMS)
198+
def test_batch_encode_decode_roundtrip_yuanrong(values):
199+
decoded, *_ = _roundtrip(values, _yuanrong_alloc, "yuanrong")
200+
_assert_equal_payloads(decoded, values)
201+
202+
203+
def test_batch_encode_decode_empty_list():
204+
calls = []
205+
206+
def alloc(sizes):
207+
calls.append(list(sizes))
208+
return []
209+
210+
buffers, sizes = serial_utils.batch_encode_into([], alloc)
211+
assert buffers == [] and sizes == []
212+
assert calls == [[]]
213+
assert serial_utils.batch_decode_from([]) == []
214+
215+
216+
# ---- num_workers: parallel pack must produce identical bytes vs serial ----
217+
218+
219+
@pytest.mark.parametrize("values", _ROUNDTRIP_PARAMS)
220+
def test_batch_encode_into_parallel_matches_serial(values):
221+
serial_buffers, serial_sizes = serial_utils.batch_encode_into(
222+
values, _yuanrong_alloc, num_workers=1
223+
)
224+
par_buffers, par_sizes = serial_utils.batch_encode_into(
225+
values, _yuanrong_alloc, num_workers=4
226+
)
227+
228+
assert serial_sizes == par_sizes
229+
assert [bytes(b) for b in serial_buffers] == [bytes(b) for b in par_buffers]
230+
231+
232+
def test_batch_encode_into_parallel_roundtrip_many_objects():
233+
rng = np.random.default_rng(42)
234+
values = []
235+
for _ in range(64):
236+
n = int(rng.integers(1, 257))
237+
values.append(torch.from_numpy(rng.random(n).astype(np.float32)))
238+
239+
decoded, *_ = _roundtrip(values, _yuanrong_alloc, "yuanrong", num_workers=8)
240+
_assert_equal_payloads(decoded, values)
241+
242+
243+
# ============================================================================
244+
# helpers
245+
# ============================================================================
246+
247+
248+
def _assert_equal_payloads(decoded, original):
249+
assert len(decoded) == len(original)
250+
for got, want in zip(decoded, original, strict=True):
251+
if isinstance(want, torch.Tensor):
252+
assert isinstance(got, torch.Tensor)
253+
assert got.dtype == want.dtype
254+
if want.is_nested:
255+
assert got.is_nested
256+
assert got.layout == want.layout
257+
got_subs = got.unbind()
258+
want_subs = want.unbind()
259+
assert len(got_subs) == len(want_subs)
260+
for g, w in zip(got_subs, want_subs, strict=True):
261+
assert g.shape == w.shape
262+
assert torch.equal(g, w)
263+
else:
264+
assert got.shape == want.shape
265+
assert torch.equal(got, want)
266+
elif isinstance(want, np.ndarray):
267+
assert isinstance(got, np.ndarray)
268+
assert got.dtype == want.dtype
269+
assert got.shape == want.shape
270+
assert np.array_equal(got, want)
271+
elif isinstance(want, dict):
272+
assert isinstance(got, dict)
273+
assert got.keys() == want.keys()
274+
for k in want:
275+
_assert_equal_payloads([got[k]], [want[k]])
276+
elif isinstance(want, list):
277+
assert isinstance(got, list)
278+
_assert_equal_payloads(got, want)
279+
else:
280+
assert got == want

0 commit comments

Comments
 (0)