1414# limitations under the License.
1515
1616import asyncio
17- import logging
1817import os
1918import threading
20- from functools import wraps
2119from typing import Any , Callable , Optional
22- from uuid import uuid4
2320
2421import torch
2522import zmq
3431 TransferQueueStorageManagerFactory ,
3532)
3633from transfer_queue .utils .common import limit_pytorch_auto_parallel_threads
34+ from transfer_queue .utils .logging_utils import get_logger
3735from transfer_queue .utils .zmq_utils import (
3836 ZMQMessage ,
3937 ZMQRequestType ,
4038 ZMQServerInfo ,
41- create_zmq_socket ,
42- format_zmq_address ,
39+ with_zmq_socket ,
4340)
4441
45- logger = logging .getLogger (__name__ )
46- logger .setLevel (os .getenv ("TQ_LOGGING_LEVEL" , logging .WARNING ))
47-
48- # Ensure logger has a handler
49- if not logger .hasHandlers ():
50- handler = logging .StreamHandler ()
51- handler .setFormatter (logging .Formatter ("%(asctime)s - %(levelname)s - %(name)s - %(message)s" ))
52- logger .addHandler (handler )
42+ logger = get_logger (__name__ )
5343
5444TQ_NUM_THREADS = int (os .environ .get ("TQ_NUM_THREADS" , 8 ))
5545
46+ # Pre-bound decorator for controller socket operations.
47+ with_controller_socket = with_zmq_socket (
48+ "request_handle_socket" ,
49+ get_identity = lambda self : self .client_id ,
50+ get_peer = lambda self , target : self ._controller ,
51+ )
52+
5653
5754class AsyncTransferQueueClient :
5855 """Asynchronous client for interacting with TransferQueue controller and storage systems.
@@ -99,63 +96,8 @@ def initialize_storage_manager(
9996 manager_type , controller_info = self ._controller , config = config
10097 )
10198
102- # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong.
103- @staticmethod
104- def dynamic_socket (socket_name : str ):
105- """Decorator to auto-manage ZMQ sockets for Controller/Storage servers.
106-
107- Handles socket lifecycle: create -> connect -> inject -> close.
108-
109- Args:
110- socket_name: Port name from server config to use for ZMQ connection (e.g., "data_req_port")
111-
112- Decorated Function Requirements:
113- 1. Must be an async class method (needs `self`)
114- 2. `self` must have:
115- - `_controller`: Server registry
116- - `client_id`: Unique client ID for socket identity
117- 3. Receives ZMQ socket via `socket` keyword argument (injected by decorator)
118- """
119-
120- def decorator (func : Callable ):
121- @wraps (func )
122- async def wrapper (self , * args , ** kwargs ):
123- server_info = self ._controller
124- if not server_info :
125- raise RuntimeError ("No controller registered" )
126-
127- context = zmq .asyncio .Context ()
128- address = format_zmq_address (server_info .ip , server_info .ports .get (socket_name ))
129- identity = f"{ self .client_id } _to_{ server_info .id } _{ uuid4 ().hex [:8 ]} " .encode ()
130- sock = create_zmq_socket (context , zmq .DEALER , identity = identity , ip = server_info .ip )
131-
132- try :
133- sock .connect (address )
134- logger .debug (
135- f"[{ self .client_id } ]: Connected to Controller { server_info .id } at { address } "
136- f"with identity { identity .decode ()} "
137- )
138-
139- kwargs ["socket" ] = sock
140- return await func (self , * args , ** kwargs )
141- except Exception as e :
142- logger .error (f"[{ self .client_id } ]: Error in socket operation with Controller { server_info .id } : { e } " )
143- raise
144- finally :
145- try :
146- if not sock .closed :
147- sock .close (linger = - 1 )
148- except Exception as e :
149- logger .warning (f"[{ self .client_id } ]: Error closing socket to Controller { server_info .id } : { e } " )
150-
151- context .term ()
152-
153- return wrapper
154-
155- return decorator
156-
15799 # ==================== Basic API ====================
158- @dynamic_socket ( socket_name = "request_handle_socket" )
100+ @with_controller_socket
159101 async def async_get_meta (
160102 self ,
161103 data_fields : list [str ],
@@ -245,7 +187,7 @@ async def async_get_meta(
245187 f"{ response_msg .body .get ('message' , 'Unknown error' )} "
246188 )
247189
248- @dynamic_socket ( socket_name = "request_handle_socket" )
190+ @with_controller_socket
249191 async def async_set_custom_meta (
250192 self ,
251193 metadata : BatchMeta ,
@@ -545,7 +487,7 @@ async def async_clear_samples(self, metadata: BatchMeta):
545487 except Exception as e :
546488 raise RuntimeError (f"Error in clear_samples operation: { str (e )} " ) from e
547489
548- @dynamic_socket ( socket_name = "request_handle_socket" )
490+ @with_controller_socket
549491 async def _clear_meta_in_controller (self , metadata : BatchMeta , socket = None ):
550492 """Clear metadata in the controller.
551493
@@ -571,7 +513,7 @@ async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None):
571513 if response_msg .request_type != ZMQRequestType .CLEAR_META_RESPONSE :
572514 raise RuntimeError ("Failed to clear samples metadata in controller." )
573515
574- @dynamic_socket ( socket_name = "request_handle_socket" )
516+ @with_controller_socket
575517 async def _get_partition_meta (self , partition_id : str , socket = None ) -> BatchMeta :
576518 """Get metadata required for the whole partition from controller.
577519
@@ -601,7 +543,7 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta
601543
602544 return response_msg .body ["metadata" ]
603545
604- @dynamic_socket ( socket_name = "request_handle_socket" )
546+ @with_controller_socket
605547 async def _clear_partition_in_controller (self , partition_id , socket = None ):
606548 """Clear the whole partition in the controller.
607549
@@ -628,7 +570,7 @@ async def _clear_partition_in_controller(self, partition_id, socket=None):
628570 raise RuntimeError (f"Failed to clear partition { partition_id } in controller." )
629571
630572 # ==================== Status Query API ====================
631- @dynamic_socket ( socket_name = "request_handle_socket" )
573+ @with_controller_socket
632574 async def async_get_consumption_status (
633575 self ,
634576 task_name : str ,
@@ -691,7 +633,7 @@ async def async_get_consumption_status(
691633 except Exception as e :
692634 raise RuntimeError (f"[{ self .client_id } ]: Error in get_consumption_status: { str (e )} " ) from e
693635
694- @dynamic_socket ( socket_name = "request_handle_socket" )
636+ @with_controller_socket
695637 async def async_get_production_status (
696638 self ,
697639 data_fields : list [str ],
@@ -823,7 +765,7 @@ async def async_check_production_status(
823765 return False
824766 return torch .all (production_status == 1 ).item ()
825767
826- @dynamic_socket ( socket_name = "request_handle_socket" )
768+ @with_controller_socket
827769 async def async_reset_consumption (
828770 self ,
829771 partition_id : str ,
@@ -885,7 +827,7 @@ async def async_reset_consumption(
885827 except Exception as e :
886828 raise RuntimeError (f"[{ self .client_id } ]: Error in reset_consumption: { str (e )} " ) from e
887829
888- @dynamic_socket ( socket_name = "request_handle_socket" )
830+ @with_controller_socket
889831 async def async_get_partition_list (
890832 self ,
891833 socket : Optional [zmq .asyncio .Socket ] = None ,
@@ -931,7 +873,7 @@ async def async_get_partition_list(
931873 raise RuntimeError (f"[{ self .client_id } ]: Error in get_partition_list: { str (e )} " ) from e
932874
933875 # ==================== KV Interface API ====================
934- @dynamic_socket ( socket_name = "request_handle_socket" )
876+ @with_controller_socket
935877 async def async_kv_retrieve_meta (
936878 self ,
937879 keys : list [str ] | str ,
@@ -997,7 +939,7 @@ async def async_kv_retrieve_meta(
997939 except Exception as e :
998940 raise RuntimeError (f"[{ self .client_id } ]: Error in kv_retrieve_keys: { str (e )} " ) from e
999941
1000- @dynamic_socket ( socket_name = "request_handle_socket" )
942+ @with_controller_socket
1001943 async def async_kv_retrieve_keys (
1002944 self ,
1003945 global_indexes : list [int ] | int ,
@@ -1060,7 +1002,7 @@ async def async_kv_retrieve_keys(
10601002 except Exception as e :
10611003 raise RuntimeError (f"[{ self .client_id } ]: Error in kv_retrieve_indexes: { str (e )} " ) from e
10621004
1063- @dynamic_socket ( socket_name = "request_handle_socket" )
1005+ @with_controller_socket
10641006 async def async_kv_list (
10651007 self ,
10661008 partition_id : Optional [str ] = None ,
0 commit comments