Skip to content

Commit b221e2a

Browse files
author
ascend-robot
committed
[feat] Introduce Zero-Copy to use YuanrongStorageClient for transmitting CPU Tensors
Co-authored-by: liwenlin<liwenlin8@huawei.com> # message auto-generated for no-merge-commit merge: !10 merge ds_zero_copy into main [feat] Introduce Zero-Copy to use YuanrongStorageClient for transmitting CPU Tensors Created-by: Lexie-7 Commit-by: liwenlin Merged-by: ascend-robot Description: ### Summary When connecting to the backend of the YuanrongStorageClient, zero-copy is activated to enhance the transmission speed. ### Change 1. Modified the `transfer_queue/storage/clients/yuanrong_client.py` to call the zero-copy interface, and performed operations such as serialization and pack. 2. Add mget and mset UT: `tests/test_yuanrong_storage_client.py` . ### Testing - Test on CPU: `pytest tests/test_yuanrong_storage_client.py ` ### Result When transmitting 512 pieces of data, each 32 MB in size, with a total data volume of 16GB: End-to-end **Get** took **10s** and the bandwidth was **1.6 GB/s**. The time spent calling the **YuanrongStorageClient** interface was **2.27s** with a bandwidth of **7.05 GB/s**. End-to-end **Put** took **3.42s** and the bandwidth was **4.68 GB/s**. The time spent calling the **YuanrongStorageClient** interface was **3.32s** with a bandwidth of **4.83 GB/s**. ### Related Links - Previous issues can be viewed: [[Feat]: Try zero-copy serialize objects that can be converted to memoryview](TransferQueue/TransferQueue#147) - Yuanrong Datasystem PR: [https://atomgit.com/openeuler/yuanrong-datasystem/pull/141](https://atomgit.com/openeuler/yuanrong-datasystem/pull/141) See merge request: Ascend/TransferQueue!10
1 parent bfc0ba6 commit b221e2a

2 files changed

Lines changed: 195 additions & 14 deletions

File tree

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2+
# Copyright 2025 The TransferQueue Team
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import sys
17+
from pathlib import Path
18+
from unittest.mock import MagicMock
19+
20+
import numpy as np
21+
import pytest
22+
import torch
23+
24+
parent_dir = Path(__file__).resolve().parent.parent
25+
sys.path.append(str(parent_dir))
26+
27+
from transfer_queue.storage.clients.yuanrong_client import ( # noqa: E402
28+
YuanrongStorageClient,
29+
)
30+
31+
32+
class MockBuffer:
33+
def __init__(self, size):
34+
self.data = bytearray(size)
35+
36+
def mutable_data(self):
37+
return self.data
38+
39+
40+
class TestYuanrongStorageZCopy:
41+
@pytest.fixture
42+
def mock_kv_client(self, mocker):
43+
mock_client = MagicMock()
44+
mock_client.init.return_value = None
45+
46+
mocker.patch("yr.datasystem.KVClient", return_value=mock_client)
47+
mocker.patch("yr.datasystem.DsTensorClient")
48+
mocker.patch("transfer_queue.storage.clients.yuanrong_client.TORCH_NPU_IMPORTED", False)
49+
50+
return mock_client
51+
52+
@pytest.fixture
53+
def storage_client(self, mock_kv_client):
54+
return YuanrongStorageClient({"host": "127.0.0.1", "port": 31501})
55+
56+
def test_mset_mget_p2p(self, storage_client, mocker):
57+
# Mock serialization/deserialization
58+
def mock_serialization(obj):
59+
if isinstance(obj, torch.Tensor):
60+
return [obj.numpy().tobytes()]
61+
return [str(obj).encode("utf-8")]
62+
63+
def mock_deserialization(items):
64+
data = items[0]
65+
if len(data) == 12:
66+
return torch.from_numpy(np.frombuffer(data, dtype=np.float32).copy())
67+
try:
68+
return data.tobytes().decode("utf-8")
69+
except UnicodeDecodeError:
70+
return data
71+
72+
mocker.patch("transfer_queue.storage.clients.yuanrong_client.serialization", side_effect=mock_serialization)
73+
mocker.patch("transfer_queue.storage.clients.yuanrong_client.deserialization", side_effect=mock_deserialization)
74+
75+
stored_raw_buffers = []
76+
77+
def side_effect_mcreate(keys, sizes):
78+
buffers = [MockBuffer(size) for size in sizes]
79+
for b in buffers:
80+
stored_raw_buffers.append(b.mutable_data())
81+
return 0, buffers
82+
83+
storage_client._cpu_ds_client.mcreate.side_effect = side_effect_mcreate
84+
storage_client._cpu_ds_client.get_buffers.return_value = (0, stored_raw_buffers)
85+
86+
storage_client.mset_zcopy(
87+
["tensor_key", "string_key"], [torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32), "hello yuanrong"]
88+
)
89+
results = storage_client.mget_zcopy(["tensor_key", "string_key"])
90+
91+
assert torch.allclose(results[0], torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32))
92+
assert results[1] == "hello yuanrong"

transfer_queue/storage/clients/yuanrong_client.py

Lines changed: 103 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,18 @@
1616
import logging
1717
import os
1818
import pickle
19-
from typing import Any, Optional
19+
import struct
20+
from concurrent.futures import ThreadPoolExecutor
21+
from typing import Any, Optional, TypeAlias
2022

2123
import torch
2224
from torch import Tensor
2325

2426
from transfer_queue.storage.clients.base import TransferQueueStorageKVClient
2527
from transfer_queue.storage.clients.factory import StorageClientFactory
28+
from transfer_queue.utils.serial_utils import _decoder, _encoder
29+
30+
bytestr: TypeAlias = bytes | bytearray | memoryview
2631

2732
logger = logging.getLogger(__name__)
2833
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
@@ -31,11 +36,84 @@
3136
CPU_DS_CLIENT_KEYS_LIMIT: int = 1999
3237
YUANRONG_DATASYSTEM_IMPORTED: bool = True
3338
TORCH_NPU_IMPORTED: bool = True
39+
DS_MAX_WORKERS: int = 16
3440
try:
3541
from yr import datasystem
3642
except ImportError:
3743
YUANRONG_DATASYSTEM_IMPORTED = False
3844

45+
# Header: number of entries (uint32, little-endian)
46+
HEADER_FMT = "<I"
47+
HEADER_SIZE = struct.calcsize(HEADER_FMT)
48+
# Entry: (payload_offset: uint32, payload_size: uint32)
49+
ENTRY_FMT = "<II"
50+
ENTRY_SIZE = struct.calcsize(ENTRY_FMT)
51+
52+
53+
def calc_packed_size(items: list[memoryview]) -> int:
54+
"""
55+
Calculate the total size (in bytes) required to pack a list of memoryview items
56+
into the structured binary format used by pack_into.
57+
58+
Args:
59+
items: List of memoryview objects to be packed.
60+
61+
Returns:
62+
Total buffer size in bytes.
63+
"""
64+
return HEADER_SIZE + len(items) * ENTRY_SIZE + sum(item.nbytes for item in items)
65+
66+
67+
def pack_into(target: memoryview, items: list[memoryview]):
68+
"""
69+
Pack multiple contiguous buffers into a single buffer.
70+
┌───────────────┐
71+
│ item_count │ uint32
72+
├───────────────┤
73+
│ entries │ N * item entries
74+
├───────────────┤
75+
│ payload blob │ N * concatenated buffers
76+
└───────────────┘
77+
78+
Args:
79+
target (memoryview): A writable memoryview returned by StateValueBuffer.MutableData().
80+
It must be large enough to accommodate the total number of bytes of HEADER + ENTRY_TABLE + all items.
81+
This buffer is usually mapped to shared memory or Zero-Copy memory area.
82+
items (List[memoryview]): List of read-only memory views (e.g., from serialized objects). Each item must support
83+
the buffer protocol and be readable as raw bytes.
84+
85+
"""
86+
struct.pack_into(HEADER_FMT, target, 0, len(items))
87+
88+
entry_offset = HEADER_SIZE
89+
payload_offset = HEADER_SIZE + len(items) * ENTRY_SIZE
90+
91+
target_tensor = torch.frombuffer(target, dtype=torch.uint8)
92+
93+
for item in items:
94+
struct.pack_into(ENTRY_FMT, target, entry_offset, payload_offset, item.nbytes)
95+
src_tensor = torch.frombuffer(item, dtype=torch.uint8)
96+
target_tensor[payload_offset : payload_offset + item.nbytes].copy_(src_tensor)
97+
entry_offset += ENTRY_SIZE
98+
payload_offset += item.nbytes
99+
100+
101+
def unpack_from(source: memoryview) -> list[bytestr]:
102+
"""
103+
Unpack multiple contiguous buffers from a single packed buffer.
104+
Args:
105+
source (memoryview): The packed source buffer.
106+
Returns:
107+
list[bytestr]: List of unpacked contiguous buffers.
108+
"""
109+
mv = memoryview(source)
110+
item_count = struct.unpack_from(HEADER_FMT, mv, 0)[0]
111+
offsets = []
112+
for i in range(item_count):
113+
offset, length = struct.unpack_from(ENTRY_FMT, mv, HEADER_SIZE + i * ENTRY_SIZE)
114+
offsets.append((offset, length))
115+
return [mv[offset : offset + length] for offset, length in offsets]
116+
39117

40118
@StorageClientFactory.register("YuanrongStorageClient")
41119
class YuanrongStorageClient(TransferQueueStorageKVClient):
@@ -106,6 +184,19 @@ def _create_empty_npu_tensorlist(self, shapes, dtypes):
106184
tensors.append(tensor)
107185
return tensors
108186

187+
def mset_zcopy(self, keys: list[str], objs: list[Any]):
188+
items_list = [[memoryview(b) for b in _encoder.encode(obj)] for obj in objs]
189+
packed_sizes = [calc_packed_size(items) for items in items_list]
190+
status, buffers = self._cpu_ds_client.mcreate(keys, packed_sizes)
191+
tasks = [(target.MutableData(), item) for target, item in zip(buffers, items_list, strict=False)]
192+
with ThreadPoolExecutor(max_workers=DS_MAX_WORKERS) as executor:
193+
list(executor.map(lambda p: pack_into(*p), tasks))
194+
self._cpu_ds_client.mset_buffer(buffers)
195+
196+
def mget_zcopy(self, keys: list[str]) -> list[Any]:
197+
status, buffers = self._cpu_ds_client.get_buffers(keys, timeout_ms=500)
198+
return [_decoder.decode(unpack_from(buffer)) if buffer is not None else None for buffer in buffers]
199+
109200
def _batch_put(self, keys: list[str], values: list[Any]):
110201
"""Stores a batch of key-value pairs to remote storage, splitting by device type.
111202
@@ -125,17 +216,15 @@ def _batch_put(self, keys: list[str], values: list[Any]):
125216
cpu_values = []
126217

127218
for key, value in zip(keys, values, strict=True):
128-
if isinstance(value, Tensor) and value.device.type == "npu":
219+
if isinstance(value, torch.Tensor) and value.device.type == "npu":
129220
if not value.is_contiguous():
130221
raise ValueError(f"NPU Tensor is not contiguous: {value}")
131222
npu_keys.append(key)
132223
npu_values.append(value)
133224

134225
else:
135226
cpu_keys.append(key)
136-
# TODO: Optimize serialization of tensors
137-
# Serializing slice of tensors results in entire tensors being serialized
138-
cpu_values.append(pickle.dumps(value.clone() if isinstance(value, Tensor) else value))
227+
cpu_values.append(pickle.dumps(value))
139228

140229
# put NPU data
141230
for i in range(0, len(npu_keys), NPU_DS_CLIENT_KEYS_LIMIT):
@@ -157,11 +246,10 @@ def _batch_put(self, keys: list[str], values: list[Any]):
157246

158247
else:
159248
# All data goes through CPU path
160-
pickled_values = [pickle.dumps(v.clone() if isinstance(v, Tensor) else v) for v in values]
161249
for i in range(0, len(keys), CPU_DS_CLIENT_KEYS_LIMIT):
162250
batch_keys = keys[i : i + CPU_DS_CLIENT_KEYS_LIMIT]
163-
batch_vals = pickled_values[i : i + CPU_DS_CLIENT_KEYS_LIMIT]
164-
self._cpu_ds_client.mset(batch_keys, batch_vals)
251+
batch_vals = values[i : i + CPU_DS_CLIENT_KEYS_LIMIT]
252+
self.mset_zcopy(batch_keys, batch_vals)
165253

166254
def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:
167255
"""Stores multiple key-value pairs to remote storage.
@@ -253,16 +341,17 @@ def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> list[Any]:
253341
results[idx] = pickle.loads(raw_val)
254342

255343
return results
344+
256345
else:
257-
# npu is not available, goes through cpu_ds_client
258346
results = [None] * len(keys)
259-
idx = 0
347+
cpu_indices = list(range(len(keys)))
348+
260349
for i in range(0, len(keys), CPU_DS_CLIENT_KEYS_LIMIT):
261350
batch_keys = keys[i : i + CPU_DS_CLIENT_KEYS_LIMIT]
262-
raw_values = self._cpu_ds_client.get(batch_keys)
263-
for raw_val in raw_values:
264-
results[idx] = pickle.loads(raw_val)
265-
idx += 1
351+
batch_indices = cpu_indices[i : i + CPU_DS_CLIENT_KEYS_LIMIT]
352+
objects = self.mget_zcopy(batch_keys)
353+
for idx, obj in zip(batch_indices, objects, strict=False):
354+
results[idx] = obj
266355
return results
267356

268357
def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]:

0 commit comments

Comments
 (0)