Skip to content

Commit 961c5df

Browse files
authored
[client, storage] refactor: unify dynamic ZMQ socket decorator between simple_backend_manager and client (#66)
As title.

---------

Signed-off-by: ji-huazhong <hzji210@gmail.com>
1 parent aec9192 commit 961c5df

3 files changed

Lines changed: 113 additions & 154 deletions

File tree

transfer_queue/client.py

Lines changed: 20 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -17,9 +17,7 @@
1717
import logging
1818
import os
1919
import threading
20-
from functools import wraps
2120
from typing import Any, Callable, Optional
22-
from uuid import uuid4
2321

2422
import torch
2523
import zmq
@@ -38,8 +36,7 @@
3836
ZMQMessage,
3937
ZMQRequestType,
4038
ZMQServerInfo,
41-
create_zmq_socket,
42-
format_zmq_address,
39+
with_zmq_socket,
4340
)
4441

4542
logger = logging.getLogger(__name__)
@@ -53,6 +50,13 @@
5350

5451
TQ_NUM_THREADS = int(os.environ.get("TQ_NUM_THREADS", 8))
5552

53+
# Pre-bound decorator for controller socket operations.
54+
with_controller_socket = with_zmq_socket(
55+
"request_handle_socket",
56+
get_identity=lambda self: self.client_id,
57+
get_peer=lambda self, target: self._controller,
58+
)
59+
5660

5761
class AsyncTransferQueueClient:
5862
"""Asynchronous client for interacting with TransferQueue controller and storage systems.
@@ -99,63 +103,8 @@ def initialize_storage_manager(
99103
manager_type, controller_info=self._controller, config=config
100104
)
101105

102-
# TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong.
103-
@staticmethod
104-
def dynamic_socket(socket_name: str):
105-
"""Decorator to auto-manage ZMQ sockets for Controller/Storage servers.
106-
107-
Handles socket lifecycle: create -> connect -> inject -> close.
108-
109-
Args:
110-
socket_name: Port name from server config to use for ZMQ connection (e.g., "data_req_port")
111-
112-
Decorated Function Requirements:
113-
1. Must be an async class method (needs `self`)
114-
2. `self` must have:
115-
- `_controller`: Server registry
116-
- `client_id`: Unique client ID for socket identity
117-
3. Receives ZMQ socket via `socket` keyword argument (injected by decorator)
118-
"""
119-
120-
def decorator(func: Callable):
121-
@wraps(func)
122-
async def wrapper(self, *args, **kwargs):
123-
server_info = self._controller
124-
if not server_info:
125-
raise RuntimeError("No controller registered")
126-
127-
context = zmq.asyncio.Context()
128-
address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name))
129-
identity = f"{self.client_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode()
130-
sock = create_zmq_socket(context, zmq.DEALER, identity=identity, ip=server_info.ip)
131-
132-
try:
133-
sock.connect(address)
134-
logger.debug(
135-
f"[{self.client_id}]: Connected to Controller {server_info.id} at {address} "
136-
f"with identity {identity.decode()}"
137-
)
138-
139-
kwargs["socket"] = sock
140-
return await func(self, *args, **kwargs)
141-
except Exception as e:
142-
logger.error(f"[{self.client_id}]: Error in socket operation with Controller {server_info.id}: {e}")
143-
raise
144-
finally:
145-
try:
146-
if not sock.closed:
147-
sock.close(linger=-1)
148-
except Exception as e:
149-
logger.warning(f"[{self.client_id}]: Error closing socket to Controller {server_info.id}: {e}")
150-
151-
context.term()
152-
153-
return wrapper
154-
155-
return decorator
156-
157106
# ==================== Basic API ====================
158-
@dynamic_socket(socket_name="request_handle_socket")
107+
@with_controller_socket
159108
async def async_get_meta(
160109
self,
161110
data_fields: list[str],
@@ -245,7 +194,7 @@ async def async_get_meta(
245194
f"{response_msg.body.get('message', 'Unknown error')}"
246195
)
247196

248-
@dynamic_socket(socket_name="request_handle_socket")
197+
@with_controller_socket
249198
async def async_set_custom_meta(
250199
self,
251200
metadata: BatchMeta,
@@ -545,7 +494,7 @@ async def async_clear_samples(self, metadata: BatchMeta):
545494
except Exception as e:
546495
raise RuntimeError(f"Error in clear_samples operation: {str(e)}") from e
547496

548-
@dynamic_socket(socket_name="request_handle_socket")
497+
@with_controller_socket
549498
async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None):
550499
"""Clear metadata in the controller.
551500
@@ -571,7 +520,7 @@ async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None):
571520
if response_msg.request_type != ZMQRequestType.CLEAR_META_RESPONSE:
572521
raise RuntimeError("Failed to clear samples metadata in controller.")
573522

574-
@dynamic_socket(socket_name="request_handle_socket")
523+
@with_controller_socket
575524
async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta:
576525
"""Get metadata required for the whole partition from controller.
577526
@@ -601,7 +550,7 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta
601550

602551
return response_msg.body["metadata"]
603552

604-
@dynamic_socket(socket_name="request_handle_socket")
553+
@with_controller_socket
605554
async def _clear_partition_in_controller(self, partition_id, socket=None):
606555
"""Clear the whole partition in the controller.
607556
@@ -628,7 +577,7 @@ async def _clear_partition_in_controller(self, partition_id, socket=None):
628577
raise RuntimeError(f"Failed to clear partition {partition_id} in controller.")
629578

630579
# ==================== Status Query API ====================
631-
@dynamic_socket(socket_name="request_handle_socket")
580+
@with_controller_socket
632581
async def async_get_consumption_status(
633582
self,
634583
task_name: str,
@@ -691,7 +640,7 @@ async def async_get_consumption_status(
691640
except Exception as e:
692641
raise RuntimeError(f"[{self.client_id}]: Error in get_consumption_status: {str(e)}") from e
693642

694-
@dynamic_socket(socket_name="request_handle_socket")
643+
@with_controller_socket
695644
async def async_get_production_status(
696645
self,
697646
data_fields: list[str],
@@ -823,7 +772,7 @@ async def async_check_production_status(
823772
return False
824773
return torch.all(production_status == 1).item()
825774

826-
@dynamic_socket(socket_name="request_handle_socket")
775+
@with_controller_socket
827776
async def async_reset_consumption(
828777
self,
829778
partition_id: str,
@@ -885,7 +834,7 @@ async def async_reset_consumption(
885834
except Exception as e:
886835
raise RuntimeError(f"[{self.client_id}]: Error in reset_consumption: {str(e)}") from e
887836

888-
@dynamic_socket(socket_name="request_handle_socket")
837+
@with_controller_socket
889838
async def async_get_partition_list(
890839
self,
891840
socket: Optional[zmq.asyncio.Socket] = None,
@@ -931,7 +880,7 @@ async def async_get_partition_list(
931880
raise RuntimeError(f"[{self.client_id}]: Error in get_partition_list: {str(e)}") from e
932881

933882
# ==================== KV Interface API ====================
934-
@dynamic_socket(socket_name="request_handle_socket")
883+
@with_controller_socket
935884
async def async_kv_retrieve_meta(
936885
self,
937886
keys: list[str] | str,
@@ -997,7 +946,7 @@ async def async_kv_retrieve_meta(
997946
except Exception as e:
998947
raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e
999948

1000-
@dynamic_socket(socket_name="request_handle_socket")
949+
@with_controller_socket
1001950
async def async_kv_retrieve_keys(
1002951
self,
1003952
global_indexes: list[int] | int,
@@ -1060,7 +1009,7 @@ async def async_kv_retrieve_keys(
10601009
except Exception as e:
10611010
raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_indexes: {str(e)}") from e
10621011

1063-
@dynamic_socket(socket_name="request_handle_socket")
1012+
@with_controller_socket
10641013
async def async_kv_list(
10651014
self,
10661015
partition_id: Optional[str] = None,

transfer_queue/storage/managers/simple_backend_manager.py

Lines changed: 16 additions & 82 deletions
Original file line number | Diff line number | Diff line change
@@ -19,10 +19,8 @@
1919
import warnings
2020
from collections import defaultdict
2121
from collections.abc import Mapping
22-
from functools import wraps
2322
from operator import itemgetter
24-
from typing import Any, Callable, NamedTuple, Optional
25-
from uuid import uuid4
23+
from typing import Any, Callable, NamedTuple
2624

2725
import torch
2826
import zmq
@@ -36,8 +34,7 @@
3634
ZMQMessage,
3735
ZMQRequestType,
3836
ZMQServerInfo,
39-
create_zmq_socket,
40-
format_zmq_address,
37+
with_zmq_socket,
4138
)
4239

4340
logger = logging.getLogger(__name__)
@@ -51,6 +48,15 @@
5148

5249
TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT", 200)) # seconds
5350

51+
# Pre-bound decorator for storage-unit socket operations.
52+
with_storage_unit_socket = with_zmq_socket(
53+
"put_get_socket",
54+
get_identity=lambda self: self.storage_manager_id,
55+
get_peer=lambda self, target: self.storage_unit_infos[target],
56+
resolve_target=lambda args, kwargs: kwargs.get("target_storage_unit"),
57+
timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT,
58+
)
59+
5460

5561
class RoutingGroup(NamedTuple):
5662
"""Routing result for a single storage unit."""
@@ -114,78 +120,6 @@ def _register_servers(self, server_infos: "ZMQServerInfo | dict[Any, ZMQServerIn
114120

115121
return server_infos_transform
116122

117-
# TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong.
118-
@staticmethod
119-
def dynamic_storage_manager_socket(socket_name: str, timeout: int):
120-
"""Decorator to auto-manage ZMQ sockets for Controller/Storage servers (create -> connect -> inject -> close).
121-
122-
Args:
123-
socket_name (str): Port name (from server config) to use for ZMQ connection (e.g., "data_req_port").
124-
timeout (float): Timeout in seconds for ZMQ connection (in seconds).
125-
126-
Decorated Function Rules:
127-
1. Must be an async class method (needs `self`).
128-
2. `self` requires:
129-
- `storage_unit_infos: storage unit infos (ZMQServerInfo | dict[Any, ZMQServerInfo]).
130-
3. Specify target server via:
131-
- `target_storage_unit` arg.
132-
4. Receives ZMQ socket via `socket` keyword arg (injected by decorator).
133-
"""
134-
135-
def decorator(func: Callable):
136-
@wraps(func)
137-
async def wrapper(self, *args, **kwargs):
138-
server_key = kwargs.get("target_storage_unit")
139-
if server_key is None:
140-
for arg in args:
141-
if isinstance(arg, str) and arg in self.storage_unit_infos.keys():
142-
server_key = arg
143-
break
144-
145-
server_info = self.storage_unit_infos.get(server_key)
146-
147-
if not server_info:
148-
raise RuntimeError(f"Server {server_key} not found in registered servers")
149-
150-
context = zmq.asyncio.Context()
151-
address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name))
152-
identity = f"{self.storage_manager_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode()
153-
sock = create_zmq_socket(context, zmq.DEALER, server_info.ip, identity)
154-
155-
try:
156-
sock.connect(address)
157-
# Timeouts to avoid indefinite await on recv/send
158-
sock.setsockopt(zmq.RCVTIMEO, timeout * 1000)
159-
sock.setsockopt(zmq.SNDTIMEO, timeout * 1000)
160-
logger.debug(
161-
f"[{self.storage_manager_id}]: Connected to StorageUnit {server_info.id} at {address} "
162-
f"with identity {identity.decode()}"
163-
)
164-
165-
kwargs["socket"] = sock
166-
return await func(self, *args, **kwargs)
167-
except Exception as e:
168-
logger.error(
169-
f"[{self.storage_manager_id}]: Error in socket operation with "
170-
f"StorageUnit {server_info.id} at {address}: "
171-
f"{type(e).__name__}: {e}"
172-
)
173-
raise
174-
finally:
175-
try:
176-
if not sock.closed:
177-
sock.close(linger=-1)
178-
except Exception as e:
179-
logger.warning(
180-
f"[{self.storage_manager_id}]: Error closing socket to StorageUnit {server_info.id}: {e}"
181-
)
182-
183-
context.term()
184-
185-
return wrapper
186-
187-
return decorator
188-
189123
def _group_by_hash(self, global_indexes: list[int]) -> dict[str, RoutingGroup]:
190124
"""Group samples by global_idx % num_su, return {storage_id: RoutingGroup}.
191125
@@ -286,7 +220,7 @@ def _select_by_positions(field_data, positions: list[int]):
286220
return field_data[positions]
287221

288222
async def put_data(
289-
self, data: TensorDict, metadata: BatchMeta, data_parser: Optional[Callable[[Any], Any]] = None
223+
self, data: TensorDict, metadata: BatchMeta, data_parser: Callable[[Any], Any] | None = None
290224
) -> None:
291225
"""
292226
Send data to remote StorageUnit based on metadata.
@@ -347,13 +281,13 @@ async def put_data(
347281
field_schema,
348282
)
349283

350-
@dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT)
284+
@with_storage_unit_socket
351285
async def _put_to_single_storage_unit(
352286
self,
353287
global_indexes: list[int],
354288
storage_data: dict[str, Any],
355289
target_storage_unit: str,
356-
data_parser: Optional[Callable[[Any], Any]] = None,
290+
data_parser: Callable[[Any], Any] | None = None,
357291
socket: zmq.Socket = None,
358292
):
359293
"""
@@ -483,7 +417,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict:
483417

484418
return TensorDict(tensor_data, batch_size=len(metadata))
485419

486-
@dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT)
420+
@with_storage_unit_socket
487421
async def _get_from_single_storage_unit(
488422
self,
489423
global_indexes: list[int],
@@ -555,7 +489,7 @@ async def clear_data(self, metadata: BatchMeta) -> None:
555489
if isinstance(result, Exception):
556490
logger.error(f"[{self.storage_manager_id}]: Error in clear operation task {i}: {result}")
557491

558-
@dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT)
492+
@with_storage_unit_socket
559493
async def _clear_single_storage_unit(self, global_indexes, target_storage_unit=None, socket=None):
560494
try:
561495
request_msg = ZMQMessage.create(

0 commit comments

Comments
 (0)