Skip to content

Commit fbe56bb

Browse files
authored
[refactor] Provide common serialization tools for KV backends to speed up tensor serial in nested values (#107)
## Problem Multimodal RL puts nested-dict values into TransferQueue (e.g. `{"pixel_values": Tensor, "image_grid_thw": Tensor, ...}`). The old `MooncakeStoreClient` only zero-copied plain tensors; anything else — including dicts that contain tensors — got pickled through Mooncake's internal bytes pool, which saturated under concurrent multi-MB GETs and forced a `VERL_TQ_MC_GET_RETRIES` retry workaround upstream. ## Refactor Treat every value as one opaque payload. Each value is encoded by the existing zero-copy msgpack encoder, the whole batch is packed into one contiguous CPU buffer, registered once, and shipped through Mooncake's RDMA-backed `batch_upsert_from` / `batch_get_into`. The pickled bytes path and the retry workaround it required are gone. ## Test - [x] 30B-VL + onethinker, 2×8 GPU, RDMA: 3-step and 10-step runs clean — no retry / allocator failure / AssertionError. <img width="2306" height="578" alt="datapath_perf_3way_v2" src="https://github.com/user-attachments/assets/23894842-24dd-43f4-8a21-3087cf642378" /> Signed-off-by: xupinjie <xupinjie321@outlook.com>
1 parent f3a9a4b commit fbe56bb

3 files changed

Lines changed: 623 additions & 203 deletions

File tree

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2+
# Copyright 2025 The TransferQueue Team
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Unit tests for the packed-buffer batch serialization helpers in
17+
``transfer_queue.utils.serial_utils``:
18+
19+
* ``calc_packed_size``
20+
* ``pack_into`` / ``unpack_from``
21+
* ``batch_encode_into``
22+
* ``batch_decode_from``
23+
"""
24+
25+
import numpy as np
26+
import pytest
27+
import torch
28+
29+
from transfer_queue.utils import serial_utils
30+
31+
# ============================================================================
32+
# low-level: calc_packed_size + pack_into + unpack_from (raw bytes layer)
33+
# ============================================================================
34+
35+
36+
def test_calc_packed_size_then_pack_unpack_roundtrip():
37+
items = [b"hello", b"world!", b"x"]
38+
size = serial_utils.calc_packed_size(items)
39+
buf = bytearray(size)
40+
serial_utils.pack_into(buf, items)
41+
recovered = serial_utils.unpack_from(buf)
42+
assert [bytes(mv) for mv in recovered] == items
43+
44+
45+
def test_pack_into_writes_only_within_its_slice():
46+
items = [b"alpha", b"beta", b"gamma"]
47+
sz = serial_utils.calc_packed_size(items)
48+
pad_before, pad_after = 17, 23
49+
big = bytearray(pad_before + sz + pad_after)
50+
serial_utils.pack_into(memoryview(big)[pad_before : pad_before + sz], items)
51+
52+
assert all(b == 0 for b in big[:pad_before])
53+
assert all(b == 0 for b in big[pad_before + sz :])
54+
55+
recovered = serial_utils.unpack_from(memoryview(big)[pad_before : pad_before + sz])
56+
assert [bytes(mv) for mv in recovered] == items
57+
58+
59+
def test_unpack_from_zero_item_buffer():
60+
items: list[bytes] = []
61+
sz = serial_utils.calc_packed_size(items)
62+
buf = bytearray(sz)
63+
serial_utils.pack_into(buf, items)
64+
assert serial_utils.unpack_from(buf) == []
65+
66+
67+
# ============================================================================
68+
# batch_encode_into + batch_decode_from (high-level batch layer)
69+
# ============================================================================
70+
71+
72+
def _mooncake_alloc(sizes: list[int]) -> list[torch.Tensor]:
73+
"""Single big torch.uint8 tensor sliced into N views (mooncake-style)."""
74+
big = torch.empty(sum(sizes), dtype=torch.uint8)
75+
buffers: list[torch.Tensor] = []
76+
offset = 0
77+
for s in sizes:
78+
buffers.append(big[offset : offset + s])
79+
offset += s
80+
return buffers
81+
82+
83+
def _yuanrong_alloc(sizes: list[int]) -> list[bytearray]:
84+
"""N independent bytearrays (yuanrong-style per-key buffer)."""
85+
return [bytearray(s) for s in sizes]
86+
87+
88+
def _decode_from_returned(buffers, alloc_kind):
89+
if alloc_kind == "mooncake":
90+
return serial_utils.batch_decode_from(buffers)
91+
return serial_utils.batch_decode_from([bytes(b) for b in buffers])
92+
93+
94+
def _roundtrip(values, alloc, alloc_kind, *, num_workers: int = 1):
95+
buffers, sizes = serial_utils.batch_encode_into(values, alloc, num_workers=num_workers)
96+
decoded = _decode_from_returned(buffers, alloc_kind)
97+
return decoded, buffers, sizes
98+
99+
100+
# ---- structural: return shapes / alloc contract ----
101+
102+
103+
def test_batch_encode_into_return_shapes():
104+
values = [{"x": 1}, "a string", torch.arange(8, dtype=torch.float32)]
105+
buffers, sizes = serial_utils.batch_encode_into(values, _mooncake_alloc)
106+
107+
assert len(buffers) == len(values)
108+
assert len(sizes) == len(values)
109+
for b, s in zip(buffers, sizes, strict=True):
110+
assert b.nbytes == s
111+
112+
113+
def test_batch_encode_into_allows_padded_buffers():
114+
"""Alloc may return buffers larger than requested sizes; batch_sizes still
115+
reports the actual packed length, and the data round-trips correctly."""
116+
pad = 32
117+
118+
def padded_alloc(sizes):
119+
return [bytearray(s + pad) for s in sizes]
120+
121+
values = [b"alpha", {"k": "v"}, torch.arange(4, dtype=torch.float32)]
122+
buffers, sizes = serial_utils.batch_encode_into(values, padded_alloc)
123+
124+
for b, s in zip(buffers, sizes, strict=True):
125+
assert len(b) == s + pad
126+
127+
# decoding uses only the first `s` bytes; the pad tail is harmless
128+
decoded = serial_utils.batch_decode_from([bytes(b[:s]) for b, s in zip(buffers, sizes, strict=True)])
129+
_assert_equal_payloads(decoded, values)
130+
131+
132+
# ---- semantic: encode → decode roundtrip preserves values ----
133+
134+
135+
_ROUNDTRIP_PARAMS = [
136+
pytest.param([42, 3.14, "hello", b"bytes"], id="primitives"),
137+
pytest.param([{"a": 1, "b": [1, 2, 3]}, {"nested": {"k": "v"}}], id="nested-dicts"),
138+
pytest.param([torch.arange(10, dtype=torch.float32)], id="single-tensor"),
139+
pytest.param(
140+
[
141+
torch.arange(100, dtype=torch.float32),
142+
torch.randn(4, 4, dtype=torch.bfloat16),
143+
torch.zeros(3, 5, dtype=torch.int64),
144+
],
145+
id="mixed-tensors",
146+
),
147+
pytest.param(
148+
[np.arange(50, dtype=np.float64), np.ones((3, 3), dtype=np.int32)],
149+
id="numpy-arrays",
150+
),
151+
pytest.param(
152+
[{"meta": "v1", "arr": torch.arange(5, dtype=torch.float32)}, [1, 2, "three"]],
153+
id="heterogeneous",
154+
),
155+
pytest.param(
156+
[
157+
torch.randn(2, 3, 4, 5, dtype=torch.float32),
158+
torch.randn(2, 3, 4, 5, 6, dtype=torch.bfloat16),
159+
],
160+
id="high-rank-tensors",
161+
),
162+
pytest.param(
163+
[
164+
torch.nested.nested_tensor(
165+
[torch.arange(3, dtype=torch.float32), torch.arange(5, dtype=torch.float32)],
166+
layout=torch.strided,
167+
),
168+
torch.nested.nested_tensor(
169+
[torch.randn(3, dtype=torch.bfloat16), torch.randn(5, dtype=torch.bfloat16)],
170+
layout=torch.strided,
171+
),
172+
torch.nested.nested_tensor(
173+
[torch.arange(4, dtype=torch.float32), torch.arange(7, dtype=torch.float32)],
174+
layout=torch.jagged,
175+
),
176+
torch.nested.nested_tensor(
177+
[torch.randn(4, dtype=torch.bfloat16), torch.randn(7, dtype=torch.bfloat16)],
178+
layout=torch.jagged,
179+
),
180+
],
181+
id="nested-tensors",
182+
),
183+
pytest.param(
184+
[{"only": "one", "tensor": torch.arange(3, dtype=torch.float32)}],
185+
id="single-value",
186+
),
187+
]
188+
189+
190+
@pytest.mark.parametrize("values", _ROUNDTRIP_PARAMS)
191+
def test_batch_encode_decode_roundtrip_mooncake(values):
192+
decoded, *_ = _roundtrip(values, _mooncake_alloc, "mooncake")
193+
_assert_equal_payloads(decoded, values)
194+
195+
196+
@pytest.mark.parametrize("values", _ROUNDTRIP_PARAMS)
197+
def test_batch_encode_decode_roundtrip_yuanrong(values):
198+
decoded, *_ = _roundtrip(values, _yuanrong_alloc, "yuanrong")
199+
_assert_equal_payloads(decoded, values)
200+
201+
202+
def test_batch_encode_decode_empty_list():
203+
calls = []
204+
205+
def alloc(sizes):
206+
calls.append(list(sizes))
207+
return []
208+
209+
buffers, sizes = serial_utils.batch_encode_into([], alloc)
210+
assert buffers == [] and sizes == []
211+
assert calls == [[]]
212+
assert serial_utils.batch_decode_from([]) == []
213+
214+
215+
# ---- num_workers: parallel pack must produce identical bytes vs serial ----
216+
217+
218+
@pytest.mark.parametrize("values", _ROUNDTRIP_PARAMS)
219+
def test_batch_encode_into_parallel_matches_serial(values):
220+
serial_buffers, serial_sizes = serial_utils.batch_encode_into(values, _yuanrong_alloc, num_workers=1)
221+
par_buffers, par_sizes = serial_utils.batch_encode_into(values, _yuanrong_alloc, num_workers=4)
222+
223+
assert serial_sizes == par_sizes
224+
assert [bytes(b) for b in serial_buffers] == [bytes(b) for b in par_buffers]
225+
226+
227+
def test_batch_encode_into_parallel_roundtrip_many_objects():
228+
rng = np.random.default_rng(42)
229+
values = []
230+
for _ in range(64):
231+
n = int(rng.integers(1, 257))
232+
values.append(torch.from_numpy(rng.random(n).astype(np.float32)))
233+
234+
decoded, *_ = _roundtrip(values, _yuanrong_alloc, "yuanrong", num_workers=8)
235+
_assert_equal_payloads(decoded, values)
236+
237+
238+
# ============================================================================
239+
# helpers
240+
# ============================================================================
241+
242+
243+
def _assert_equal_payloads(decoded, original):
244+
assert len(decoded) == len(original)
245+
for got, want in zip(decoded, original, strict=True):
246+
if isinstance(want, torch.Tensor):
247+
assert isinstance(got, torch.Tensor)
248+
assert got.dtype == want.dtype
249+
if want.is_nested:
250+
assert got.is_nested
251+
assert got.layout == want.layout
252+
got_subs = got.unbind()
253+
want_subs = want.unbind()
254+
assert len(got_subs) == len(want_subs)
255+
for g, w in zip(got_subs, want_subs, strict=True):
256+
assert g.shape == w.shape
257+
assert torch.equal(g, w)
258+
else:
259+
assert got.shape == want.shape
260+
assert torch.equal(got, want)
261+
elif isinstance(want, np.ndarray):
262+
assert isinstance(got, np.ndarray)
263+
assert got.dtype == want.dtype
264+
assert got.shape == want.shape
265+
assert np.array_equal(got, want)
266+
elif isinstance(want, dict):
267+
assert isinstance(got, dict)
268+
assert got.keys() == want.keys()
269+
for k in want:
270+
_assert_equal_payloads([got[k]], [want[k]])
271+
elif isinstance(want, list):
272+
assert isinstance(got, list)
273+
_assert_equal_payloads(got, want)
274+
else:
275+
assert got == want

0 commit comments

Comments
 (0)