1717import logging
1818import os
1919import threading
20- from typing import Any , Optional
20+ from typing import Any , Callable , Optional
2121
2222import torch
2323import zmq
3636 ZMQMessage ,
3737 ZMQRequestType ,
3838 ZMQServerInfo ,
39- dynamic_zmq_socket ,
39+ with_zmq_socket ,
4040)
4141
4242logger = logging .getLogger (__name__ )
5151TQ_NUM_THREADS = int (os .environ .get ("TQ_NUM_THREADS" , 8 ))
5252
5353# Pre-bound decorator for controller socket operations.
54- _controller_socket = dynamic_zmq_socket (
54+ with_controller_socket = with_zmq_socket (
5555 "request_handle_socket" ,
56- owner_id_attr = " client_id" ,
57- server_attr = " _controller" ,
56+ get_identity = lambda self : self . client_id ,
57+ get_peer = lambda self , target : self . _controller ,
5858)
5959
6060
@@ -104,7 +104,7 @@ def initialize_storage_manager(
104104 )
105105
106106 # ==================== Basic API ====================
107- @_controller_socket
107+ @with_controller_socket
108108 async def async_get_meta (
109109 self ,
110110 data_fields : list [str ],
@@ -194,7 +194,7 @@ async def async_get_meta(
194194 f"{ response_msg .body .get ('message' , 'Unknown error' )} "
195195 )
196196
197- @_controller_socket
197+ @with_controller_socket
198198 async def async_set_custom_meta (
199199 self ,
200200 metadata : BatchMeta ,
@@ -494,7 +494,7 @@ async def async_clear_samples(self, metadata: BatchMeta):
494494 except Exception as e :
495495 raise RuntimeError (f"Error in clear_samples operation: { str (e )} " ) from e
496496
497- @_controller_socket
497+ @with_controller_socket
498498 async def _clear_meta_in_controller (self , metadata : BatchMeta , socket = None ):
499499 """Clear metadata in the controller.
500500
@@ -520,7 +520,7 @@ async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None):
520520 if response_msg .request_type != ZMQRequestType .CLEAR_META_RESPONSE :
521521 raise RuntimeError ("Failed to clear samples metadata in controller." )
522522
523- @_controller_socket
523+ @with_controller_socket
524524 async def _get_partition_meta (self , partition_id : str , socket = None ) -> BatchMeta :
525525 """Get metadata required for the whole partition from controller.
526526
@@ -550,7 +550,7 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta
550550
551551 return response_msg .body ["metadata" ]
552552
553- @_controller_socket
553+ @with_controller_socket
554554 async def _clear_partition_in_controller (self , partition_id , socket = None ):
555555 """Clear the whole partition in the controller.
556556
@@ -577,7 +577,7 @@ async def _clear_partition_in_controller(self, partition_id, socket=None):
577577 raise RuntimeError (f"Failed to clear partition { partition_id } in controller." )
578578
579579 # ==================== Status Query API ====================
580- @_controller_socket
580+ @with_controller_socket
581581 async def async_get_consumption_status (
582582 self ,
583583 task_name : str ,
@@ -640,7 +640,7 @@ async def async_get_consumption_status(
640640 except Exception as e :
641641 raise RuntimeError (f"[{ self .client_id } ]: Error in get_consumption_status: { str (e )} " ) from e
642642
643- @_controller_socket
643+ @with_controller_socket
644644 async def async_get_production_status (
645645 self ,
646646 data_fields : list [str ],
@@ -772,7 +772,7 @@ async def async_check_production_status(
772772 return False
773773 return torch .all (production_status == 1 ).item ()
774774
775- @_controller_socket
775+ @with_controller_socket
776776 async def async_reset_consumption (
777777 self ,
778778 partition_id : str ,
@@ -834,7 +834,7 @@ async def async_reset_consumption(
834834 except Exception as e :
835835 raise RuntimeError (f"[{ self .client_id } ]: Error in reset_consumption: { str (e )} " ) from e
836836
837- @_controller_socket
837+ @with_controller_socket
838838 async def async_get_partition_list (
839839 self ,
840840 socket : Optional [zmq .asyncio .Socket ] = None ,
@@ -880,7 +880,7 @@ async def async_get_partition_list(
880880 raise RuntimeError (f"[{ self .client_id } ]: Error in get_partition_list: { str (e )} " ) from e
881881
882882 # ==================== KV Interface API ====================
883- @_controller_socket
883+ @with_controller_socket
884884 async def async_kv_retrieve_meta (
885885 self ,
886886 keys : list [str ] | str ,
@@ -946,7 +946,7 @@ async def async_kv_retrieve_meta(
946946 except Exception as e :
947947 raise RuntimeError (f"[{ self .client_id } ]: Error in kv_retrieve_keys: { str (e )} " ) from e
948948
949- @_controller_socket
949+ @with_controller_socket
950950 async def async_kv_retrieve_keys (
951951 self ,
952952 global_indexes : list [int ] | int ,
@@ -1009,7 +1009,7 @@ async def async_kv_retrieve_keys(
10091009 except Exception as e :
10101010 raise RuntimeError (f"[{ self .client_id } ]: Error in kv_retrieve_indexes: { str (e )} " ) from e
10111011
1012- @_controller_socket
1012+ @with_controller_socket
10131013 async def async_kv_list (
10141014 self ,
10151015 partition_id : Optional [str ] = None ,
0 commit comments