Skip to content

Commit 44c683b

Browse files
authored
[fix,serialization] Fix FieldMeta status update and remove unnecessary copy and use recv_multipart(copy=False) by default (Ascend#46)
1. Use `recv_multipart(copy=False)` by default, which returns a writable memory object in `zmq.Frame` 2. Remove redundant memory copy 3. Fix bugs introduced in Ascend#45 when dynamically update the `FieldMeta` status. --- Scripts to validate: ```python3 import zmq import torch import numpy as np import multiprocessing import time # Import your serialization and deserialization interfaces # Assuming the code you just wrote is saved in serial_utils.py in the same directory from transfer_queue.utils.serial_utils import encode, decode def sender(): """Sender process""" context = zmq.Context() socket = context.socket(zmq.PUSH) socket.bind("tcp://127.0.0.1:5557") # 1. Create data to send (containing a numpy array and a torch tensor) np_arr = np.ones((5, 5), dtype=np.float32) pt_tensor = torch.ones((5, 5), dtype=torch.float32) print(f"[Sender] Created NumPy Array, shape: {np_arr.shape}") print(f"[Sender] Created PyTorch Tensor, shape: {pt_tensor.shape}") # 2. Assemble into a complex nested structure for testing payload = { "metadata": {"version": "1.0", "description": "test zero-copy"}, "data_np": np_arr, "data_pt": pt_tensor } # 3. Call your encode interface # encode returns a list[bytes], the first frame is msgpack, and subsequent frames are underlying memory views frames = encode(payload) print(f"[Sender] Serialization complete, generated {len(frames)} frames.") # 4. Send using multipart + zero-copy socket.send_multipart(frames, copy=False) print("[Sender] Data sending complete.") time.sleep(1) # Wait for receiver to process socket.close() context.term() def receiver(): """Receiver process""" context = zmq.Context() socket = context.socket(zmq.PULL) socket.connect("tcp://127.0.0.1:5557") # 1. Receive multiple frames with zero-copy # copy=False will make the returned result a list[zmq.Frame] frames = socket.recv_multipart(copy=False) print(f"\n[Receiver] Received {len(frames)} frames.") print(f"[Receiver] Frame types: {[type(f) for f in frames]}") # 2. Call your decode interface to deserialize # At this point, the passed frames are a list of zmq.Frame payload = decode(frames) recv_np = payload["data_np"] recv_pt = payload["data_pt"] print("\n--- Verify NumPy ---") print(f"[Receiver] NumPy object type: {type(recv_np)}, dtype: {recv_np.dtype}") print(f"[Receiver] Is NumPy memory writeable: {recv_np.flags.writeable}") try: recv_np[0, 0] = 99.0 print(f"[Receiver] ✅ NumPy write successful! recv_np[0, 0] = {recv_np[0, 0]}") except Exception as e: print(f"[Receiver] ❌ NumPy write failed: {e}") print("\n--- Verify PyTorch Tensor ---") print(f"[Receiver] Tensor object type: {type(recv_pt)}, dtype: {recv_pt.dtype}") try: recv_pt[0, 0] = 88.0 print(f"[Receiver] ✅ Tensor write successful! recv_pt[0, 0] = {recv_pt[0, 0].item()}") except Exception as e: print(f"[Receiver] ❌ Tensor write failed: {e}") socket.close() context.term() if __name__ == '__main__': p_recv = multiprocessing.Process(target=receiver) p_send = multiprocessing.Process(target=sender) p_recv.start() time.sleep(0.5) # Ensure the receiver binds first p_send.start() p_send.join() p_recv.join() ``` --------- Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent cde575c commit 44c683b

10 files changed

Lines changed: 492 additions & 160 deletions

tests/test_async_simple_storage_manager.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,8 @@ async def test_async_storage_manager_mock_operations(mock_async_storage_manager)
128128
manager._put_to_single_storage_unit = AsyncMock()
129129
manager._get_from_single_storage_unit = AsyncMock(
130130
return_value=(
131-
[0, 1],
132131
["test_field"],
133132
{"test_field": [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]},
134-
b"this is the serialized message object.",
135133
)
136134
)
137135
manager._clear_single_storage_unit = AsyncMock()
@@ -286,7 +284,7 @@ async def fake_get(global_indexes, fields, target_storage_unit=None, **kwargs):
286284
su = target_storage_unit
287285
called_with[su] = list(global_indexes)
288286
tensors = [torch.zeros(2) for _ in global_indexes]
289-
return global_indexes, fields, {"f": tensors}, b""
287+
return fields, {"f": tensors}
290288

291289
manager._get_from_single_storage_unit = fake_get
292290

tests/test_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _handle_requests(self):
9191
try:
9292
socks = dict(poller.poll(100)) # 100ms timeout
9393
if self.request_socket in socks:
94-
messages = self.request_socket.recv_multipart()
94+
messages = self.request_socket.recv_multipart(copy=False)
9595
identity = messages.pop(0)
9696
serialized_msg = messages
9797
request_msg = ZMQMessage.deserialize(serialized_msg)
@@ -332,7 +332,7 @@ def _handle_data_requests(self):
332332
try:
333333
socks = dict(poller.poll(100)) # 100ms timeout
334334
if self.data_socket in socks:
335-
messages = self.data_socket.recv_multipart()
335+
messages = self.data_socket.recv_multipart(copy=False)
336336
identity = messages.pop(0)
337337
serialized_msg = messages
338338
msg = ZMQMessage.deserialize(serialized_msg)

0 commit comments

Comments
 (0)