1717import logging
1818import os
1919import threading
20- from functools import wraps
2120from typing import Any , Callable , Optional
22- from uuid import uuid4
2321
2422import torch
2523import zmq
3836 ZMQMessage ,
3937 ZMQRequestType ,
4038 ZMQServerInfo ,
41- create_zmq_socket ,
42- format_zmq_address ,
39+ with_zmq_socket ,
4340)
4441
4542logger = logging .getLogger (__name__ )
5350
5451TQ_NUM_THREADS = int (os .environ .get ("TQ_NUM_THREADS" , 8 ))
5552
53+ # Pre-bound decorator for controller socket operations: uses the "request_handle_socket" port, self.client_id as identity and self._controller as peer.
54+ with_controller_socket = with_zmq_socket (
55+ "request_handle_socket" ,
56+ get_identity = lambda self : self .client_id ,
57+ get_peer = lambda self , target : self ._controller ,
58+ )
59+
5660
5761class AsyncTransferQueueClient :
5862 """Asynchronous client for interacting with TransferQueue controller and storage systems.
@@ -99,63 +103,8 @@ def initialize_storage_manager(
99103 manager_type , controller_info = self ._controller , config = config
100104 )
101105
102- # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong.
103- @staticmethod
104- def dynamic_socket (socket_name : str ):
105- """Decorator to auto-manage a per-call ZMQ DEALER socket to the Controller server.
106-
107- Handles socket lifecycle: create -> connect -> inject -> close.
108-
109- Args:
110- socket_name: Key into the controller's `ports` mapping used to build the ZMQ connection address (e.g., "request_handle_socket")
111-
112- Decorated Function Requirements:
113- 1. Must be an async instance method (the wrapper passes `self` and awaits it)
114- 2. `self` must have:
115- - `_controller`: Controller server info providing `ip`, `ports` and `id` (RuntimeError is raised if it is not set)
116- - `client_id`: Unique client ID for socket identity
117- 3. Receives ZMQ socket via `socket` keyword argument (injected by decorator)
118- """
119-
120- def decorator (func : Callable ):
121- @wraps (func )
122- async def wrapper (self , * args , ** kwargs ):
123- server_info = self ._controller
124- if not server_info :
125- raise RuntimeError ("No controller registered" )
126-
127- context = zmq .asyncio .Context ()
128- address = format_zmq_address (server_info .ip , server_info .ports .get (socket_name ))
129- identity = f"{ self .client_id } _to_{ server_info .id } _{ uuid4 ().hex [:8 ]} " .encode ()
130- sock = create_zmq_socket (context , zmq .DEALER , identity = identity , ip = server_info .ip )
131-
132- try :
133- sock .connect (address )
134- logger .debug (
135- f"[{ self .client_id } ]: Connected to Controller { server_info .id } at { address } "
136- f"with identity { identity .decode ()} "
137- )
138-
139- kwargs ["socket" ] = sock
140- return await func (self , * args , ** kwargs )
141- except Exception as e :
142- logger .error (f"[{ self .client_id } ]: Error in socket operation with Controller { server_info .id } : { e } " )
143- raise
144- finally :
145- try :
146- if not sock .closed :
147- sock .close (linger = - 1 )
148- except Exception as e :
149- logger .warning (f"[{ self .client_id } ]: Error closing socket to Controller { server_info .id } : { e } " )
150-
151- context .term ()
152-
153- return wrapper
154-
155- return decorator
156-
157106 # ==================== Basic API ====================
158- @dynamic_socket ( socket_name = "request_handle_socket" )
107+ @with_controller_socket
159108 async def async_get_meta (
160109 self ,
161110 data_fields : list [str ],
@@ -245,7 +194,7 @@ async def async_get_meta(
245194 f"{ response_msg .body .get ('message' , 'Unknown error' )} "
246195 )
247196
248- @dynamic_socket ( socket_name = "request_handle_socket" )
197+ @with_controller_socket
249198 async def async_set_custom_meta (
250199 self ,
251200 metadata : BatchMeta ,
@@ -545,7 +494,7 @@ async def async_clear_samples(self, metadata: BatchMeta):
545494 except Exception as e :
546495 raise RuntimeError (f"Error in clear_samples operation: { str (e )} " ) from e
547496
548- @dynamic_socket ( socket_name = "request_handle_socket" )
497+ @with_controller_socket
549498 async def _clear_meta_in_controller (self , metadata : BatchMeta , socket = None ):
550499 """Clear metadata in the controller.
551500
@@ -571,7 +520,7 @@ async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None):
571520 if response_msg .request_type != ZMQRequestType .CLEAR_META_RESPONSE :
572521 raise RuntimeError ("Failed to clear samples metadata in controller." )
573522
574- @dynamic_socket ( socket_name = "request_handle_socket" )
523+ @with_controller_socket
575524 async def _get_partition_meta (self , partition_id : str , socket = None ) -> BatchMeta :
576525 """Get metadata required for the whole partition from controller.
577526
@@ -601,7 +550,7 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta
601550
602551 return response_msg .body ["metadata" ]
603552
604- @dynamic_socket ( socket_name = "request_handle_socket" )
553+ @with_controller_socket
605554 async def _clear_partition_in_controller (self , partition_id , socket = None ):
606555 """Clear the whole partition in the controller.
607556
@@ -628,7 +577,7 @@ async def _clear_partition_in_controller(self, partition_id, socket=None):
628577 raise RuntimeError (f"Failed to clear partition { partition_id } in controller." )
629578
630579 # ==================== Status Query API ====================
631- @dynamic_socket ( socket_name = "request_handle_socket" )
580+ @with_controller_socket
632581 async def async_get_consumption_status (
633582 self ,
634583 task_name : str ,
@@ -691,7 +640,7 @@ async def async_get_consumption_status(
691640 except Exception as e :
692641 raise RuntimeError (f"[{ self .client_id } ]: Error in get_consumption_status: { str (e )} " ) from e
693642
694- @dynamic_socket ( socket_name = "request_handle_socket" )
643+ @with_controller_socket
695644 async def async_get_production_status (
696645 self ,
697646 data_fields : list [str ],
@@ -823,7 +772,7 @@ async def async_check_production_status(
823772 return False
824773 return torch .all (production_status == 1 ).item ()
825774
826- @dynamic_socket ( socket_name = "request_handle_socket" )
775+ @with_controller_socket
827776 async def async_reset_consumption (
828777 self ,
829778 partition_id : str ,
@@ -885,7 +834,7 @@ async def async_reset_consumption(
885834 except Exception as e :
886835 raise RuntimeError (f"[{ self .client_id } ]: Error in reset_consumption: { str (e )} " ) from e
887836
888- @dynamic_socket ( socket_name = "request_handle_socket" )
837+ @with_controller_socket
889838 async def async_get_partition_list (
890839 self ,
891840 socket : Optional [zmq .asyncio .Socket ] = None ,
@@ -931,7 +880,7 @@ async def async_get_partition_list(
931880 raise RuntimeError (f"[{ self .client_id } ]: Error in get_partition_list: { str (e )} " ) from e
932881
933882 # ==================== KV Interface API ====================
934- @dynamic_socket ( socket_name = "request_handle_socket" )
883+ @with_controller_socket
935884 async def async_kv_retrieve_meta (
936885 self ,
937886 keys : list [str ] | str ,
@@ -997,7 +946,7 @@ async def async_kv_retrieve_meta(
997946 except Exception as e :
998947 raise RuntimeError (f"[{ self .client_id } ]: Error in kv_retrieve_keys: { str (e )} " ) from e
999948
1000- @dynamic_socket ( socket_name = "request_handle_socket" )
949+ @with_controller_socket
1001950 async def async_kv_retrieve_keys (
1002951 self ,
1003952 global_indexes : list [int ] | int ,
@@ -1060,7 +1009,7 @@ async def async_kv_retrieve_keys(
10601009 except Exception as e :
10611010 raise RuntimeError (f"[{ self .client_id } ]: Error in kv_retrieve_indexes: { str (e )} " ) from e
10621011
1063- @dynamic_socket ( socket_name = "request_handle_socket" )
1012+ @with_controller_socket
10641013 async def async_kv_list (
10651014 self ,
10661015 partition_id : Optional [str ] = None ,
0 commit comments