Skip to content

Commit b928421

Browse files
0oshowero0ascend-robot
authored and committed
[perf, serialization] Remove memory copy during zero-copy deserialization
Co-authored-by: 0oshowero0<o0shower0o@outlook.com> # message auto-generated for no-merge-commit merge: !13 merge main into main [perf, serialization] Remove memory copy during zero-copy deserialization Created-by: hanzhenyu8 Commit-by: 0oshowero0 Merged-by: ascend-robot Description: ## Background In the previous design, we manually performed a memory copy `bytearray(buffer)` to ensure the decoded tensor was **writable** for users. However, as described in [my previous PR](TransferQueue/TransferQueue#121 (comment)), actually we do not need to copy the buffer in `serial_utils.py`. There are post-processing steps (in `StorageManager`) that pack the received tensors into a `TensorDict` before handing them over to users, which handles the mutability requirement implicitly. ```python3 def _decode_tensor(self, arr: Any) -> torch.Tensor: ... arr = torch.frombuffer(bytearray(buffer), dtype=torch.uint8) ... ``` ## Solution In this PR, we remove the `bytearray` copy and ignore the resulting PyTorch warning regarding non-writable buffers. Additionally, `AsyncSimpleStorageManager` has been optimized to remove a redundant memory copy caused by `torch.nested.as_nested_tensor()`: ```python3 torch.stack(torch.nested.as_nested_tensor(v).unbind()) # this as_nested_tensor() call is redundant, and it will lead to another memory copy if v and all(isinstance(item, torch.Tensor) for item in v) and all(item.shape == v[0].shape for item in v) ``` ### Verification Code The following script validates that the `TensorDict` returned to the user is writable even without the explicit copy. 
```python3 import zmq import time import torch from tensordict import TensorDict from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType from tensordict.tensorclass import NonTensorData import random import multiprocessing def create_complex_test_case(): batch = TensorDict( { "nested_tensor": torch.nested.as_nested_tensor( [torch.randn(4, 3), torch.randn(2, 4)], layout=torch.strided ), "jagged_tensor": torch.nested.as_nested_tensor( [torch.randn(4, 5), torch.randn(4, 54)], layout=torch.jagged ), "normal_tensor": torch.randn(2, 10, 3), "numpy_array": torch.randn(2, 2).numpy(), }, batch_size=2, ) return batch # -------------------------- Server(ROUTER Socket) -------------------------- def router_server(): context = zmq.Context() router_socket = context.socket(zmq.ROUTER) router_socket.bind("tcp://127.0.0.1:5555") print("ROUTER Server is ready, binding:tcp://127.0.0.1:5555") print("\n=== start communication(send_multipart/recv_multipart)===") messages = router_socket.recv_multipart() id = messages.pop(0) response_msg = ZMQMessage.deserialize(messages) print(response_msg) # Try to do in-place modification to see if it's allowed td = response_msg.body['data'] print( td['nested_tensor'] ) td['nested_tensor'] += 9999999 print( td['nested_tensor'] ) print( td['jagged_tensor'] ) td['jagged_tensor'] += 9999999 print( td['jagged_tensor'] ) print( td['normal_tensor'] ) td['normal_tensor'] += 9999999 print( td['normal_tensor'] ) print( td['numpy_array'] ) td['numpy_array'] += 9999999 print( td['numpy_array'] ) # it's safe to do in-place modification even we set # arr = torch.frombuffer(buffer, dtype=torch.uint8) router_socket.send_multipart([ id, b"ack", ]) time.sleep(1) router_socket.close() context.term() # -------------------------- Client(DEALER Socket) -------------------------- def dealer_client(): context = zmq.Context() dealer_socket = context.socket(zmq.DEALER) # set client identity dealer_socket.setsockopt_string(zmq.IDENTITY, "client_001") 
dealer_socket.connect("tcp://127.0.0.1:5555") print("DEALER Client is ready, connecting:tcp://127.0.0.1:5555") time.sleep(0.5) test_data = create_complex_test_case() request_msg = ZMQMessage.create( request_type=ZMQRequestType.PUT_DATA, sender_id='123', receiver_id='456', body={"data":test_data}, ) dealer_socket.send_multipart(request_msg.serialize(),copy=False) response_frames = dealer_socket.recv_multipart() response_frame1 = response_frames[0] print(f"DEALER Receive → Frame: {response_frame1}") dealer_socket.close() context.term() # -------------------------- Start all processes -------------------------- if __name__ == "__main__": # Start server process server_process = multiprocessing.Process(target=router_server) server_process.start() time.sleep(0.5) # Start client process client_process = multiprocessing.Process(target=dealer_client) client_process.start() server_process.join() client_process.join() print("Test Finish!") ``` See merge request: Ascend/TransferQueue!13
1 parent d292bec commit b928421

3 files changed

Lines changed: 22 additions & 6 deletions

File tree

tests/test_async_simple_storage_manager.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -142,7 +142,12 @@ async def test_async_storage_manager_mock_operations(mock_async_storage_manager)
142142

143143
manager._put_to_single_storage_unit = AsyncMock()
144144
manager._get_from_single_storage_unit = AsyncMock(
145-
return_value=([0, 1], ["test_field"], {"test_field": [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]})
145+
return_value=(
146+
[0, 1],
147+
["test_field"],
148+
{"test_field": [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]},
149+
b"this is the serialized message object.",
150+
)
146151
)
147152
manager._clear_single_storage_unit = AsyncMock()
148153
manager.notify_data_update = AsyncMock()

transfer_queue/storage/managers/simple_backend_manager.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -285,7 +285,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict:
285285

286286
# post-process data segments to generate a batch of data
287287
merged_data: dict[int, dict[str, torch.Tensor]] = {}
288-
for global_indexes, fields, data_from_single_storage_unit in results:
288+
for global_indexes, fields, data_from_single_storage_unit, messages in results:
289289
field_getter = itemgetter(*fields)
290290
field_values = field_getter(data_from_single_storage_unit)
291291

@@ -303,9 +303,12 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict:
303303
for field in metadata.field_names:
304304
ordered_data[field] = [merged_data[global_idx][field] for global_idx in metadata.global_indexes]
305305

306+
# In the final packing stage we intentionally perform a memory copy through torch.stack and as_nested_tensor.
307+
# This detaches the received tensors from the original zero-copy buffers,
308+
# gives them their own lifetime, and ensures the resulting tensors are writable.
306309
tensor_data = {
307310
field: (
308-
torch.stack(torch.nested.as_nested_tensor(v).unbind())
311+
torch.stack(v)
309312
if v
310313
and all(isinstance(item, torch.Tensor) for item in v)
311314
and all(item.shape == v[0].shape for item in v)
@@ -341,8 +344,10 @@ async def _get_from_single_storage_unit(
341344

342345
if response_msg.request_type == ZMQRequestType.GET_DATA_RESPONSE:
343346
# Return data and index information from this storage unit
347+
# We need to return messages to get_data() because the zero-copy deserialized data points directly into the
348+
# memory of the messages object.
344349
storage_unit_data = response_msg.body["data"]
345-
return global_indexes, fields, storage_unit_data
350+
return global_indexes, fields, storage_unit_data, messages
346351
else:
347352
raise RuntimeError(
348353
f"Failed to get data from storage unit {target_storage_unit}: "

transfer_queue/utils/serial_utils.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
import logging
2121
import os
2222
import pickle
23+
import warnings
2324
from collections.abc import Sequence
2425
from inspect import isclass
2526
from types import FunctionType
@@ -52,6 +53,10 @@
5253
logger = logging.getLogger(__name__)
5354
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
5455

56+
# Ignore warnings about non-writable buffers from torch.frombuffer. Upper-layer code will ensure
57+
# the tensors are writable for users.
58+
# `message` is a regular expression matched (case-insensitively) against the START of the
# warning text, so a plain prefix is enough. A trailing glob-style "*" would only quantify the
# final "e" of "writable" and make the filter match more than intended.
warnings.filterwarnings(action="ignore", message=r"The given buffer is not writable", category=UserWarning)
59+
5560

5661
class MsgpackEncoder:
5762
"""Encoder with custom torch tensor and numpy array serialization.
@@ -154,8 +159,9 @@ def _decode_tensor(self, arr: Any) -> torch.Tensor:
154159
if not buffer: # torch.frombuffer doesn't like empty buffers
155160
assert 0 in shape
156161
return torch.empty(shape, dtype=torch_dtype)
157-
# Create uint8 array and convert read-only buffer into writable bytearray
158-
arr = torch.frombuffer(bytearray(buffer), dtype=torch.uint8)
162+
# Create uint8 array. Upper-layer code should make sure the tensor is cloned so that it has its own lifetime and
163+
# becomes writable for users.
164+
arr = torch.frombuffer(buffer, dtype=torch.uint8)
159165
# Convert back to proper shape & type
160166
return arr.view(torch_dtype).view(shape)
161167

0 commit comments

Comments
 (0)