Skip to content

Commit ba5710e

Browse files
authored
[perf] Add zmq.proxy to accelerate request processing for SimpleStorageUnit (Ascend#37)
## Background Previously, `SimpleStorageUnit` relied on a single-threaded event loop for request processing. This design could lead to bottlenecks and increased latency when multiple requests arrived simultaneously, as operations like ZMQ message deserialization and memory I/O would block the main socket loop from receiving new requests. ## Key Changes 1. Refactored `SimpleStorageUnit` to utilize a native `zmq.proxy`. This acts as a highly efficient, C-level load balancer between a frontend `ROUTER` socket (handling external client connections) and an internal backend `DEALER` socket (inproc://). 2. ~~Introduced a worker thread pool where each worker binds its own independent `DEALER` socket to process `PUT/GET/CLEAR` requests concurrently. This preserves ZMQ's "share-nothing" concurrency philosophy.~~ 3. ~~Added a `threading.Lock()` to `StorageUnitData` to prevent race conditions introduced by multiple threads.~~ 4. ~~Added `num_worker_threads` as an explicit input parameter for `SimpleStorageUnit` (configurable via TQ system config items).~~ > During performance testing, we were surprised to find that the refactored multi-threaded code achieves better performance with `num_worker_threads=1`. The introduction of the native C-level `zmq.proxy` offloads the high-frequency I/O from the main Python thread. Therefore, we retired the multi-threaded version and kept only the `zmq.proxy` optimization, with a single worker thread behind the proxy. Note: the benchmark script and the `--threads` results below were written against the intermediate multi-threaded revision; the diff in this commit still constructs `SimpleStorageUnit` with `storage_unit_size` only, so the script's `num_worker_threads` argument does not apply to the final code. 
## Architecture ### Old Version <img width="1067" height="1760" alt="mermaid-diagram-2026-02-26-192209" src="https://github.com/user-attachments/assets/3a61673b-9e91-4cc9-9930-b20e6cd06217" /> ### New Version <img width="1374" height="3104" alt="mermaid-diagram-2026-02-26-220631" src="https://github.com/user-attachments/assets/824386e0-5b57-4a7c-a15c-ac3c6258d9ad" /> ## Performance Gain We provide a simple benchmark script for this PR: ```python3 import argparse import multiprocessing import time import ray import torch import zmq import tensordict # Ensure this runs in the repository root directory, otherwise sys.path.append might be needed from transfer_queue.storage.simple_backend import SimpleStorageUnit from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType class StorageClient: """Independent test client that interacts directly with the frontend ROUTER of SimpleStorageUnit""" def __init__(self, address): self.context = zmq.Context() self.socket = self.context.socket(zmq.DEALER) self.socket.setsockopt(zmq.RCVTIMEO, 20000) # Timeout set to 20s to prevent timeouts under heavy concurrency self.socket.connect(address) def send_put(self, client_id, local_indexes, field_data): msg = ZMQMessage.create( request_type=ZMQRequestType.PUT_DATA, sender_id=f"bench_client_{client_id}", body={"local_indexes": local_indexes, "data": field_data}, ) self.socket.send_multipart(msg.serialize()) return ZMQMessage.deserialize(self.socket.recv_multipart()) def close(self): self.socket.close() self.context.term() def client_worker(worker_id, address, num_requests, batch_size): """Worker process task: Continuously bombard the Storage Unit with PUT requests""" client = StorageClient(address) start_time = time.time() # Construct Dummy Tensor data to simulate actual memory and serialization overhead # As noted in the PR description, serialization and memory I/O are the bottlenecks blocking the main loop field_data = { "dummy_tensor": [torch.randn(256, 256) for _ in 
range(batch_size)] } for i in range(num_requests): local_indexes = list(range(i * batch_size, (i + 1) * batch_size)) client.send_put(worker_id, local_indexes, field_data) elapsed = time.time() - start_time client.close() print(f"[Worker {worker_id}] Completed {num_requests} write requests, took {elapsed:.3f} seconds " f"(QPS: {num_requests / elapsed:.2f} req/s)") def main(num_clients, storage_threads, requests_per_client): # Initialize Ray and global settings ray.init(ignore_reinit_error=True) tensordict.set_list_to_stack(True).set() try: print(f"🚀 Launching SimpleStorageUnit, internal worker threads (num_worker_threads): {storage_threads} ...") # Launch the backend Actor. PR 37 exposes the num_worker_threads parameter storage_actor = SimpleStorageUnit.options( max_concurrency=50, num_cpus=2 ).remote( storage_unit_size=1000000, num_worker_threads=storage_threads # comment this line for old version comparison ) zmq_info = ray.get(storage_actor.get_zmq_server_info.remote()) put_get_address = zmq_info.to_addr("put_get_socket") print(f"✅ Storage unit ready, ZMQ Address: {put_get_address}") # Wait for zmq.proxy and all worker threads to bind to the inproc port time.sleep(2) print(f"🔥 Spawning {num_clients} independent concurrent write processes...") processes = [] batch_size = 256 start_time = time.time() # 1. Create and start multiple processes for i in range(num_clients): p = multiprocessing.Process( target=client_worker, args=(i, put_get_address, requests_per_client, batch_size) ) p.start() processes.append(p) # 2. 
Wait for all concurrent processes to complete for p in processes: p.join() total_time = time.time() - start_time total_requests = num_clients * requests_per_client print("\n" + "=" * 50) print(f" 📊 Benchmark Results") print("=" * 50) print(f" SimpleStorageUnit internal threads : {storage_threads}") print(f" External concurrent clients : {num_clients}") print(f" Total processed requests (Batches) : {total_requests} (Batch Size: {batch_size})") print(f" Total benchmark duration : {total_time:.3f} seconds") print(f" 🚀 Overall Throughput : {total_requests / total_time:.2f} req/s") print("=" * 50 + "\n") finally: # Resource cleanup if 'storage_actor' in locals(): ray.kill(storage_actor) ray.shutdown() if __name__ == "__main__": parser = argparse.ArgumentParser(description="PR Ascend#37 Performance Benchmark") parser.add_argument("--clients", type=int, default=8, help="Number of concurrent client processes") parser.add_argument("--threads", type=int, default=4, help="Number of processing threads in SimpleStorageUnit") parser.add_argument("--requests", type=int, default=300, help="Number of requests sent per client") args = parser.parse_args() main(args.clients, args.threads, args.requests) ``` ### Small Scale Test (`batch_size=20`, `clients=4`) On a mac mini with M2 chip with 24GB memory: #### Old Version ```bash python benchmark.py --clients 4 ``` <img width="680" height="343" alt="image" src="https://github.com/user-attachments/assets/0e5fedc4-a185-4d34-94d0-8cde007d1a74" /> #### New Version ```bash python benchmark.py --clients 4 --threads 1 ``` <img width="663" height="342" alt="image" src="https://github.com/user-attachments/assets/c325bc27-0ad7-485a-9717-9255662b3733" /> ```bash python benchmark.py --clients 4 --threads 2 ``` <img width="663" height="343" alt="image" src="https://github.com/user-attachments/assets/66e64858-08ac-4358-b8f9-8b0f56506ffa" /> ### Middle Scale Test (`batch_size=256`, `clients=4`) On a mac mini with M2 chip with 24GB memory: #### Old 
Version ```bash python benchmark.py --clients 4 ``` <img width="683" height="327" alt="image" src="https://github.com/user-attachments/assets/47b4b8a7-d81a-4572-9235-14c3c68059f7" /> #### New Version ```bash python benchmark.py --clients 4 --threads 1 ``` <img width="731" height="343" alt="image" src="https://github.com/user-attachments/assets/ae22115e-9433-4a80-a4d3-238beba9fec1" /> ```bash python benchmark.py --clients 4 --threads 2 ``` <img width="716" height="341" alt="image" src="https://github.com/user-attachments/assets/ba9ff4c6-9d0c-45cd-83c5-881be2b5c118" /> ### Large Scale Test (`batch_size=256`, `clients=50`) On an Ubuntu server with an Intel(R) Xeon(R) Platinum 8358P CPU @ 2.60GHz x 128 cores: Notes: 1. The benchmark script was also modified to cover `get` performance. 2. We exported the following environment variables: ```bash export OMP_NUM_THREADS=1 export MKL_NUM_THREADS=1 export OPENBLAS_NUM_THREADS=1 export VECLIB_MAXIMUM_THREADS=1 export NUMEXPR_NUM_THREADS=1 export TORCH_NUM_THREADS=1 export TQ_ZERO_COPY_SERIALIZATION=True ``` #### Old Version ```bash python benchmark.py --clients 50 ``` <img width="555" height="196" alt="image" src="https://github.com/user-attachments/assets/f47397d6-1819-4230-bb46-3073d36a1633" /> #### New Version ```bash python benchmark.py --clients 50 --threads 1 ``` <img width="551" height="195" alt="image" src="https://github.com/user-attachments/assets/0a9dcee1-326e-43eb-901a-5e1f9b1a75f1" /> ```bash python benchmark.py --clients 50 --threads 2 ``` <img width="556" height="195" alt="image" src="https://github.com/user-attachments/assets/b8a6daaf-5644-4607-b8f9-a4d00c7a8b34" /> ```bash python benchmark.py --clients 50 --threads 4 ``` <img width="526" height="190" alt="image" src="https://github.com/user-attachments/assets/1e0a5e3f-c13b-4ba0-ac95-4ee54f30be79" /> --------- Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent bd31c02 commit ba5710e

4 files changed

Lines changed: 142 additions & 32 deletions

File tree

scripts/performance_test.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,11 @@
3030
parent_dir = Path(__file__).resolve().parent.parent.parent
3131
sys.path.append(str(parent_dir))
3232

33-
34-
from transfer_queue import ( # noqa: E402
35-
SimpleStorageUnit,
36-
TransferQueueClient,
37-
TransferQueueController,
38-
process_zmq_server_info,
39-
)
33+
from transfer_queue.client import TransferQueueClient # noqa: E402
34+
from transfer_queue.controller import TransferQueueController # noqa: E402
35+
from transfer_queue.storage.simple_backend import SimpleStorageUnit # noqa: E402
4036
from transfer_queue.utils.common import get_placement_group # noqa: E402
37+
from transfer_queue.utils.zmq_utils import process_zmq_server_info # noqa: E402
4138

4239
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
4340
logger = logging.getLogger(__name__)

transfer_queue/controller.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1523,7 +1523,7 @@ def kv_retrieve_keys(
15231523
)
15241524
data_fields = []
15251525
for fname, col_idx in partition.field_name_mapping.items():
1526-
if col_mask[col_idx]:
1526+
if col_idx < len(col_mask) and col_mask[col_idx]:
15271527
data_fields.append(fname)
15281528

15291529
metadata = self.generate_batch_meta(partition_id, verified_global_indexes, data_fields, mode="force_fetch")

transfer_queue/interface.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig:
8282
placement_group_bundle_index=storage_unit_rank,
8383
name=f"TransferQueueStorageUnit#{storage_unit_rank}",
8484
lifetime="detached",
85-
).remote(storage_unit_size=math.ceil(total_storage_size / num_data_storage_units))
85+
).remote(
86+
storage_unit_size=math.ceil(total_storage_size / num_data_storage_units),
87+
)
8688
_TRANSFER_QUEUE_STORAGE[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node
8789
logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.")
8890

transfer_queue/storage/simple_backend.py

Lines changed: 134 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616
import dataclasses
1717
import logging
1818
import os
19+
import time
20+
import weakref
1921
from dataclasses import dataclass
2022
from operator import itemgetter
21-
from threading import Thread
22-
from typing import Any
23+
from threading import Event, Thread
24+
from typing import Any, Optional
2325
from uuid import uuid4
2426

2527
import ray
@@ -173,16 +175,41 @@ def __init__(self, storage_unit_size: int):
173175

174176
self.storage_data = StorageUnitData(self.storage_unit_size)
175177

178+
# Internal communication address for proxy and workers
179+
self._inproc_addr = f"inproc://simple_storage_workers_{self.storage_unit_id}"
180+
181+
# Shutdown event for graceful termination
182+
self._shutdown_event = Event()
183+
184+
# Placeholder for zmq_context, proxy_thread and worker_threads
185+
self.zmq_context: Optional[zmq.Context] = None
186+
self.put_get_socket: Optional[zmq.Socket] = None
187+
self.proxy_thread: Optional[Thread] = None
188+
self.worker_thread: Optional[Thread] = None
189+
176190
self._init_zmq_socket()
177191
self._start_process_put_get()
178192

193+
# Register finalizer for graceful cleanup when garbage collected
194+
self._finalizer = weakref.finalize(
195+
self,
196+
self._shutdown_resources,
197+
self._shutdown_event,
198+
self.worker_thread,
199+
self.proxy_thread,
200+
self.zmq_context,
201+
self.put_get_socket,
202+
)
203+
179204
def _init_zmq_socket(self) -> None:
180205
"""
181206
Initialize ZMQ socket connections between storage unit and controller/clients:
182-
- put_get_socket:
183-
Handle put/get requests from clients.
207+
- put_get_socket (ROUTER): Handle put/get requests from clients.
208+
- worker_socket (DEALER): Backend socket for worker communication.
184209
"""
185210
self.zmq_context = zmq.Context()
211+
212+
# Frontend: ROUTER for receiving client requests
186213
self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER)
187214
self._node_ip = get_node_ip_address()
188215

@@ -195,6 +222,10 @@ def _init_zmq_socket(self) -> None:
195222
logger.warning(f"[{self.storage_unit_id}]: Try to bind ZMQ sockets failed, retrying...")
196223
continue
197224

225+
# Backend: DEALER for worker communication (connected via zmq.proxy)
226+
self.worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER)
227+
self.worker_socket.bind(self._inproc_addr)
228+
198229
self.zmq_server_info = ZMQServerInfo(
199230
role=TransferQueueRole.STORAGE,
200231
id=str(self.storage_unit_id),
@@ -203,33 +234,78 @@ def _init_zmq_socket(self) -> None:
203234
)
204235

205236
def _start_process_put_get(self) -> None:
206-
"""Create a daemon thread and start put/get process."""
207-
self.process_put_get_thread = Thread(
208-
target=self._process_put_get, name=f"StorageUnitProcessPutGetThread-{self.storage_unit_id}", daemon=True
237+
"""Start worker threads and ZMQ proxy for handling requests."""
238+
239+
# Start worker thread
240+
self.worker_thread = Thread(
241+
target=self._worker_routine,
242+
name=f"StorageUnitWorkerThread-{self.storage_unit_id}",
243+
daemon=True,
244+
)
245+
self.worker_thread.start()
246+
247+
time.sleep(0.5) # make sure worker thread is ready before zmq.proxy forwarding messages
248+
249+
# Start proxy thread (ROUTER <-> DEALER)
250+
self.proxy_thread = Thread(
251+
target=self._proxy_routine,
252+
name=f"StorageUnitProxyThread-{self.storage_unit_id}",
253+
daemon=True,
209254
)
210-
self.process_put_get_thread.start()
255+
self.proxy_thread.start()
256+
257+
def _proxy_routine(self) -> None:
258+
"""ZMQ proxy for message forwarding between frontend ROUTER and backend DEALER."""
259+
logger.info(f"[{self.storage_unit_id}]: start ZMQ proxy...")
260+
try:
261+
zmq.proxy(self.put_get_socket, self.worker_socket)
262+
except zmq.ContextTerminated:
263+
logger.info(f"[{self.storage_unit_id}]: ZMQ Proxy stopped gracefully (Context Terminated)")
264+
except Exception as e:
265+
if self._shutdown_event.is_set():
266+
logger.info(f"[{self.storage_unit_id}]: ZMQ Proxy shutting down...")
267+
else:
268+
logger.error(f"[{self.storage_unit_id}]: ZMQ Proxy unexpected error: {e}")
269+
270+
def _worker_routine(self) -> None:
271+
"""Worker thread for processing requests."""
272+
# Each worker must have its own socket
273+
worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER)
274+
worker_socket.connect(self._inproc_addr)
211275

212-
def _process_put_get(self) -> None:
213-
"""Process put_get_socket request."""
214276
poller = zmq.Poller()
215-
poller.register(self.put_get_socket, zmq.POLLIN)
277+
poller.register(worker_socket, zmq.POLLIN)
216278

217-
logger.info(f"[{self.storage_unit_id}]: start processing put/get requests...")
279+
logger.info(f"[{self.storage_unit_id}]: worker thread started...")
280+
perf_monitor = IntervalPerfMonitor(caller_name=f"{self.storage_unit_id}")
281+
282+
while not self._shutdown_event.is_set():
283+
try:
284+
socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000))
285+
except zmq.error.ContextTerminated:
286+
# ZMQ context was terminated, exit gracefully
287+
logger.info(f"[{self.storage_unit_id}]: worker stopped gracefully (Context Terminated)")
288+
break
289+
except Exception as e:
290+
logger.warning(f"[{self.storage_unit_id}]: worker poll error: {e}")
291+
continue
218292

219-
perf_monitor = IntervalPerfMonitor(caller_name=self.storage_unit_id)
293+
if self._shutdown_event.is_set():
294+
break
220295

221-
while True:
222-
socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000))
296+
if worker_socket in socks:
297+
# Messages received from proxy: [identity, serialized_msg_frame1, ...]
298+
messages = worker_socket.recv_multipart()
299+
identity = messages[0]
300+
serialized_msg = messages[1:]
223301

224-
if self.put_get_socket in socks:
225-
messages = self.put_get_socket.recv_multipart()
226-
identity = messages.pop(0)
227-
serialized_msg = messages
228302
request_msg = ZMQMessage.deserialize(serialized_msg)
229303
operation = request_msg.request_type
304+
230305
try:
231-
logger.debug(f"[{self.storage_unit_id}]: receive operation: {operation}, message: {request_msg}")
306+
logger.debug(f"[{self.storage_unit_id}]: worker received operation: {operation}")
232307

308+
# Process request
233309
if operation == ZMQRequestType.PUT_DATA:
234310
with perf_monitor.measure(op_type="PUT_DATA"):
235311
response_msg = self._handle_put(request_msg)
@@ -253,12 +329,17 @@ def _process_put_get(self) -> None:
253329
request_type=ZMQRequestType.PUT_GET_ERROR,
254330
sender_id=self.storage_unit_id,
255331
body={
256-
"message": f"Storage unit id #{self.storage_unit_id} occur error in processing "
257-
f"put/get/clear request, detail error message: {str(e)}."
332+
"message": f"{self.storage_unit_id}, worker encountered error "
333+
f"during operation {operation}: {str(e)}."
258334
},
259335
)
260336

261-
self.put_get_socket.send_multipart([identity, *response_msg.serialize()], copy=False)
337+
# Send response back with identity for routing
338+
worker_socket.send_multipart([identity] + response_msg.serialize(), copy=False)
339+
340+
logger.info(f"[{self.storage_unit_id}]: worker stopped.")
341+
poller.unregister(worker_socket)
342+
worker_socket.close(linger=0)
262343

263344
def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage:
264345
"""
@@ -365,6 +446,36 @@ def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage:
365446
)
366447
return response_msg
367448

449+
@staticmethod
450+
def _shutdown_resources(
451+
shutdown_event: Event,
452+
worker_thread: Optional[Thread],
453+
proxy_thread: Optional[Thread],
454+
zmq_context: Optional[zmq.Context],
455+
put_get_socket: Optional[zmq.Socket],
456+
) -> None:
457+
"""Clean up resources on garbage collection."""
458+
logger.info("Shutting down SimpleStorageUnit resources...")
459+
460+
# Signal all threads to stop
461+
shutdown_event.set()
462+
463+
# Terminate put_get_socket
464+
if put_get_socket:
465+
put_get_socket.close(linger=0)
466+
467+
# Terminate ZMQ context to unblock proxy and workers
468+
if zmq_context:
469+
zmq_context.term()
470+
471+
# Wait for threads to finish (with timeout)
472+
if worker_thread and worker_thread.is_alive():
473+
worker_thread.join(timeout=5)
474+
if proxy_thread and proxy_thread.is_alive():
475+
proxy_thread.join(timeout=5)
476+
477+
logger.info("SimpleStorageUnit resources shutdown complete.")
478+
368479
def get_zmq_server_info(self) -> ZMQServerInfo:
369480
"""Get the ZMQ server information for this storage unit.
370481

0 commit comments

Comments
 (0)