Commit b928421
[perf, serialization] Remove memory copy during zero-copy deserialization
Co-authored-by: 0oshowero0<o0shower0o@outlook.com>
# message auto-generated for no-merge-commit merge:
!13 merge main into main
Created-by: hanzhenyu8
Commit-by: 0oshowero0
Merged-by: ascend-robot
Description:
## Background
In the previous design, we manually copied the buffer with `bytearray(buffer)` to ensure the decoded tensor was **writable** for users. However, as described in [my previous PR](TransferQueue/TransferQueue#121 (comment)), this copy in `serial_utils.py` is unnecessary: the post-processing steps in `StorageManager` pack the received tensors into a `TensorDict` before handing them to users, which satisfies the mutability requirement implicitly.
```python3
def _decode_tensor(self, arr: Any) -> torch.Tensor:
    ...
    arr = torch.frombuffer(bytearray(buffer), dtype=torch.uint8)
    ...
```
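To make the cost of that copy concrete, here is a minimal standard-library sketch. It is not taken from `serial_utils.py`; `payload` and `try_write` are hypothetical names used only for illustration. It shows that a `memoryview` shares the received bytes but is read-only, while `bytearray(...)` allocates a second, writable buffer.

```python
# Hypothetical payload standing in for a received message frame.
payload = bytes(range(8))

# Zero-copy: the memoryview shares the payload's memory but is read-only.
view = memoryview(payload)

# Copy: bytearray(...) allocates new, writable memory of the same size.
copied = bytearray(view)
copied[0] = 255  # only the copy changes


def try_write(buf) -> bool:
    """Return True if the first byte of buf can be assigned in place."""
    try:
        buf[0] = 1
        return True
    except TypeError:
        return False


view_is_writable = try_write(view)
print(view.readonly, view_is_writable)  # True False
print(payload[0], copied[0])            # 0 255
```

The writable copy costs one extra allocation and one pass over every received payload, which is what this PR avoids.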
## Solution
In this PR, we remove the `bytearray` copy and ignore the resulting PyTorch warning regarding non-writable buffers.
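As a rough sketch of the idea (not the actual `_decode_tensor` code), the snippet below wraps an immutable `bytes` object with `torch.frombuffer`, silences the non-writable-buffer warning, and then produces a writable result through an operation that allocates its own output. Here `torch.stack` is only an assumed stand-in for the downstream packing step, and `raw` is a hypothetical payload.

```python
import warnings

import torch

# Hypothetical immutable payload, standing in for a received zero-copy frame.
raw = bytes([1, 2, 3, 4])

with warnings.catch_warnings():
    # torch.frombuffer warns when the buffer is not writable; ignore it here.
    warnings.simplefilter("ignore", UserWarning)
    shared = torch.frombuffer(raw, dtype=torch.uint8)  # shares memory with raw

# torch.stack allocates fresh output memory, so the result can be
# modified in place without touching the original buffer.
packed = torch.stack([shared, shared])
packed += 1

print(packed.tolist())  # [[2, 3, 4, 5], [2, 3, 4, 5]]
print(list(raw))        # [1, 2, 3, 4]
```

The tensor returned by `torch.frombuffer` itself should still be treated as read-only; only the newly allocated result is safe to mutate.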
Additionally, `AsyncSimpleStorageManager` has been optimized to remove a redundant memory copy caused by `torch.nested.as_nested_tensor()`:
```python3
# This as_nested_tensor() call is redundant and leads to another memory copy.
torch.stack(torch.nested.as_nested_tensor(v).unbind())
    if v
    and all(isinstance(item, torch.Tensor) for item in v)
    and all(item.shape == v[0].shape for item in v)
```
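The redundancy can be checked in isolation. The sketch below uses a hypothetical list `items` of same-shape tensors (the case guarded by the conditions above) and compares the nested-tensor round trip with a direct `torch.stack`. It illustrates the equivalence only and is not the code in `AsyncSimpleStorageManager`.

```python
import torch

# Hypothetical list of same-shape tensors, matching the guarded case above.
items = [torch.randn(4, 3) for _ in range(2)]

# Same guard as in the snippet: non-empty, all tensors, all shapes equal.
same_shape = (
    bool(items)
    and all(isinstance(item, torch.Tensor) for item in items)
    and all(item.shape == items[0].shape for item in items)
)

# Old path: build a nested tensor first, then unbind and stack.
via_nested = torch.stack(torch.nested.as_nested_tensor(items).unbind())

# Direct path: stack the list as-is.
direct = torch.stack(items)

print(same_shape, torch.equal(via_nested, direct))  # True True
```

Both paths yield the same `(2, 4, 3)` tensor, so when the shapes already match, the intermediate nested tensor contributes nothing to the result.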
### Verification Code
The following script validates that the `TensorDict` returned to the user is writable even without the explicit copy.
```python3
import zmq
import time
import torch
from tensordict import TensorDict
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType
from tensordict.tensorclass import NonTensorData
import random
import multiprocessing


def create_complex_test_case():
    batch = TensorDict(
        {
            "nested_tensor": torch.nested.as_nested_tensor(
                [torch.randn(4, 3), torch.randn(2, 4)], layout=torch.strided
            ),
            "jagged_tensor": torch.nested.as_nested_tensor(
                [torch.randn(4, 5), torch.randn(4, 54)], layout=torch.jagged
            ),
            "normal_tensor": torch.randn(2, 10, 3),
            "numpy_array": torch.randn(2, 2).numpy(),
        },
        batch_size=2,
    )
    return batch


# -------------------------- Server (ROUTER Socket) --------------------------
def router_server():
    context = zmq.Context()
    router_socket = context.socket(zmq.ROUTER)
    router_socket.bind("tcp://127.0.0.1:5555")
    print("ROUTER Server is ready, binding: tcp://127.0.0.1:5555")

    print("\n=== start communication (send_multipart/recv_multipart) ===")
    messages = router_socket.recv_multipart()
    # The first frame is the client identity added by the ROUTER socket.
    client_id = messages.pop(0)
    response_msg = ZMQMessage.deserialize(messages)
    print(response_msg)

    # Try in-place modification to see whether it is allowed.
    td = response_msg.body["data"]

    print(td["nested_tensor"])
    td["nested_tensor"] += 9999999
    print(td["nested_tensor"])

    print(td["jagged_tensor"])
    td["jagged_tensor"] += 9999999
    print(td["jagged_tensor"])

    print(td["normal_tensor"])
    td["normal_tensor"] += 9999999
    print(td["normal_tensor"])

    print(td["numpy_array"])
    td["numpy_array"] += 9999999
    print(td["numpy_array"])

    # In-place modification is safe even when the decoder uses
    # arr = torch.frombuffer(buffer, dtype=torch.uint8)

    router_socket.send_multipart([
        client_id,
        b"ack",
    ])

    time.sleep(1)
    router_socket.close()
    context.term()


# -------------------------- Client (DEALER Socket) --------------------------
def dealer_client():
    context = zmq.Context()
    dealer_socket = context.socket(zmq.DEALER)
    # Set the client identity.
    dealer_socket.setsockopt_string(zmq.IDENTITY, "client_001")
    dealer_socket.connect("tcp://127.0.0.1:5555")
    print("DEALER Client is ready, connecting: tcp://127.0.0.1:5555")
    time.sleep(0.5)

    test_data = create_complex_test_case()
    request_msg = ZMQMessage.create(
        request_type=ZMQRequestType.PUT_DATA,
        sender_id="123",
        receiver_id="456",
        body={"data": test_data},
    )
    dealer_socket.send_multipart(request_msg.serialize(), copy=False)

    response_frames = dealer_socket.recv_multipart()
    response_frame1 = response_frames[0]
    print(f"DEALER Receive → Frame: {response_frame1}")

    dealer_socket.close()
    context.term()


# -------------------------- Start all processes --------------------------
if __name__ == "__main__":
    # Start the server process.
    server_process = multiprocessing.Process(target=router_server)
    server_process.start()
    time.sleep(0.5)

    # Start the client process.
    client_process = multiprocessing.Process(target=dealer_client)
    client_process.start()

    server_process.join()
    client_process.join()
    print("Test Finish!")
```
See merge request: Ascend/TransferQueue!131
parent d292bec commit b928421
3 files changed
Lines changed: 22 additions & 6 deletions
File tree
- tests
- transfer_queue
- storage/managers
- utils