Skip to content

Commit 36d58d2

Browse files
committed
replace pack/unpack and encode/decode with batch_encode_into/batch_decode_from
Signed-off-by: dpj135 <958208521@qq.com>
1 parent fbe56bb commit 36d58d2

2 files changed

Lines changed: 24 additions & 88 deletions

File tree

tests/test_yuanrong_client_zero_copy.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,22 +50,22 @@ def storage_client(self, mock_kv_client):
5050

5151
def test_mset_mget_p2p(self, storage_client, mocker):
5252
# Mock serialization/deserialization
53-
def mock_serialization(obj):
53+
def mock_encode(obj):
5454
if isinstance(obj, torch.Tensor):
5555
return [obj.numpy().tobytes()]
5656
return [str(obj).encode("utf-8")]
5757

58-
def mock_deserialization(items):
59-
data = items[0]
58+
def mock_decode(frames):
59+
data = frames[0]
6060
if len(data) == 12:
6161
return torch.from_numpy(np.frombuffer(data, dtype=np.float32).copy())
6262
try:
6363
return data.tobytes().decode("utf-8")
6464
except UnicodeDecodeError:
6565
return data
6666

67-
mocker.patch("transfer_queue.storage.clients.yuanrong_client._encoder.encode", side_effect=mock_serialization)
68-
mocker.patch("transfer_queue.storage.clients.yuanrong_client._decoder.decode", side_effect=mock_deserialization)
67+
mocker.patch("transfer_queue.utils.serial_utils.encode", side_effect=mock_encode)
68+
mocker.patch("transfer_queue.utils.serial_utils.decode", side_effect=mock_decode)
6969

7070
stored_raw_buffers = []
7171

transfer_queue/storage/clients/yuanrong_client.py

Lines changed: 19 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import struct
1716
from abc import ABC, abstractmethod
1817
from concurrent.futures import ThreadPoolExecutor
1918
from typing import Any, Callable, Optional
@@ -23,7 +22,7 @@
2322

2423
from transfer_queue.storage.clients.base import StorageClientFactory, StorageKVClient
2524
from transfer_queue.utils.logging_utils import get_logger
26-
from transfer_queue.utils.serial_utils import _decoder, _encoder
25+
from transfer_queue.utils.serial_utils import batch_decode_from, batch_encode_into
2726
from transfer_queue.utils.yuanrong_utils import find_reachable_host
2827

2928
logger = get_logger(__name__)
@@ -193,19 +192,11 @@ def _create_empty_npu_tensorlist(self, shapes: list[Any], dtypes: list[Any]) ->
193192
class GeneralKVClientAdapter(StorageStrategy):
194193
"""Adapter for general-purpose KV storage with serialization.
195194
Using yr.datasystem.KVClient to connect datasystem backends.
196-
The serialization method uses '_decoder' and '_encoder' from 'transfer_queue.utils.serial_utils'.
195+
The serialization method uses 'batch_encode_into' and 'batch_decode_from' from 'transfer_queue.utils.serial_utils'.
197196
"""
198197

199198
PUT_KEYS_LIMIT: int = 10_000
200199
GET_CLEAR_KEYS_LIMIT: int = 10_000
201-
202-
# Header: number of entries (uint32, little-endian)
203-
HEADER_FMT = "<I"
204-
HEADER_SIZE = struct.calcsize(HEADER_FMT)
205-
# Entry: (payload_offset: uint32, payload_size: uint32)
206-
ENTRY_FMT = "<II"
207-
ENTRY_SIZE = struct.calcsize(ENTRY_FMT)
208-
209200
DS_MAX_WORKERS: int = 16
210201

211202
def __init__(self, config: dict):
@@ -270,84 +261,23 @@ def clear(self, keys: list[str]) -> None:
270261
batch_keys = keys[i : i + self.GET_CLEAR_KEYS_LIMIT]
271262
self._ds_client.delete(batch_keys)
272263

273-
@classmethod
274-
def calc_packed_size(cls, items: list[memoryview]) -> int:
275-
"""
276-
Calculate the total size (in bytes) required to pack a list of memoryview items
277-
into the structured binary format used by pack_into.
278-
279-
Args:
280-
items: List of memoryview objects to be packed.
281-
282-
Returns:
283-
Total buffer size in bytes.
284-
"""
285-
return cls.HEADER_SIZE + len(items) * cls.ENTRY_SIZE + sum(item.nbytes for item in items)
286-
287-
@classmethod
288-
def pack_into(cls, target: memoryview, items: list[memoryview]):
289-
"""
290-
Pack multiple contiguous buffers into a single buffer.
291-
┌───────────────┐
292-
│ item_count │ uint32
293-
├───────────────┤
294-
│ entries │ N * item entries
295-
├───────────────┤
296-
│ payload blob │ N * concatenated buffers
297-
└───────────────┘
298-
299-
Args:
300-
target (memoryview): A writable memoryview returned by StateValueBuffer.MutableData().
301-
It must be large enough to accommodate the total number of bytes of HEADER + ENTRY_TABLE + all items.
302-
This buffer is usually mapped to shared memory or Zero-Copy memory area.
303-
items (List[memoryview]): List of read-only memory views (e.g., from serialized objects).
304-
Each item must support the buffer protocol and be readable as raw bytes.
305-
306-
"""
307-
struct.pack_into(cls.HEADER_FMT, target, 0, len(items))
308-
309-
entry_offset = cls.HEADER_SIZE
310-
payload_offset = cls.HEADER_SIZE + len(items) * cls.ENTRY_SIZE
311-
312-
target_tensor = torch.frombuffer(target, dtype=torch.uint8)
313-
314-
for item in items:
315-
struct.pack_into(cls.ENTRY_FMT, target, entry_offset, payload_offset, item.nbytes)
316-
src_tensor = torch.frombuffer(item, dtype=torch.uint8)
317-
target_tensor[payload_offset : payload_offset + item.nbytes].copy_(src_tensor)
318-
entry_offset += cls.ENTRY_SIZE
319-
payload_offset += item.nbytes
320-
321-
@classmethod
322-
def unpack_from(cls, source: memoryview) -> list[memoryview]:
323-
"""
324-
Unpack multiple contiguous buffers from a single packed buffer.
325-
Args:
326-
source (memoryview): The packed source buffer.
327-
Returns:
328-
list[memoryview]: List of unpacked contiguous buffers.
329-
"""
330-
mv = memoryview(source)
331-
item_count = struct.unpack_from(cls.HEADER_FMT, mv, 0)[0]
332-
offsets = []
333-
for i in range(item_count):
334-
offset, length = struct.unpack_from(cls.ENTRY_FMT, mv, cls.HEADER_SIZE + i * cls.ENTRY_SIZE)
335-
offsets.append((offset, length))
336-
return [mv[offset : offset + length] for offset, length in offsets]
337-
338264
def mset_zero_copy(self, keys: list[str], objs: list[Any]):
339265
"""Store multiple objects in zero-copy mode using parallel serialization and buffer packing.
340266
341267
Args:
342268
keys (list[str]): List of string keys under which the objects will be stored.
343269
objs (list[Any]): List of Python objects to store (e.g., tensors, strings).
344270
"""
345-
items_list = [[memoryview(b) for b in _encoder.encode(obj)] for obj in objs]
346-
packed_sizes = [self.calc_packed_size(items) for items in items_list]
347-
buffers = self._ds_client.mcreate(keys, packed_sizes)
348-
tasks = [(target.MutableData(), item) for target, item in zip(buffers, items_list, strict=True)]
349-
with ThreadPoolExecutor(max_workers=self.DS_MAX_WORKERS) as executor:
350-
list(executor.map(lambda p: self.pack_into(*p), tasks))
271+
buffers: list = []
272+
273+
def alloc(sizes):
274+
# DataSystem buffers must be converted via MutableData() to obtain
275+
# a memoryview-compatible data structure for zero-copy packing.
276+
mcreate_bufs = self._ds_client.mcreate(keys, sizes)
277+
buffers.extend(mcreate_bufs)
278+
return [buf.MutableData() for buf in mcreate_bufs]
279+
280+
batch_encode_into(objs, alloc, num_workers=self.DS_MAX_WORKERS)
351281
self._ds_client.mset_buffer(buffers)
352282

353283
def mget_zero_copy(self, keys: list[str]) -> list[Any]:
@@ -360,7 +290,13 @@ def mget_zero_copy(self, keys: list[str]) -> list[Any]:
360290
list[Any]: List of deserialized objects corresponding to the input keys.
361291
"""
362292
buffers = self._ds_client.get_buffers(keys)
363-
return [_decoder.decode(self.unpack_from(buffer)) if buffer is not None else None for buffer in buffers]
293+
valid_indexes = [i for i, buf in enumerate(buffers) if buf is not None]
294+
valid_bufs = [buffers[i] for i in valid_indexes]
295+
decoded_objs = batch_decode_from(valid_bufs)
296+
results = [None] * len(keys)
297+
for idx, obj in zip(valid_indexes, decoded_objs, strict=True):
298+
results[idx] = obj
299+
return results
364300

365301

366302
@StorageClientFactory.register("YuanrongStorageClient")

0 commit comments

Comments
 (0)