2222import zmq
2323import zmq .asyncio
2424from tensordict import TensorDict
25- from torch import Tensor
2625
27- from transfer_queue .metadata import (
28- BatchMeta ,
29- )
30- from transfer_queue .storage import (
31- StorageManagerFactory ,
32- )
26+ from transfer_queue .metadata import BatchMeta
27+ from transfer_queue .storage import StorageManagerFactory
3328from transfer_queue .utils .common import limit_pytorch_auto_parallel_threads
3429from transfer_queue .utils .logging_utils import get_logger
3530from transfer_queue .utils .zmq_utils import (
@@ -576,7 +571,7 @@ async def async_get_consumption_status(
576571 task_name : str ,
577572 partition_id : str ,
578573 socket : zmq .asyncio .Socket | None = None ,
579- ) -> tuple [Tensor | None , Tensor | None ]:
574+ ) -> tuple [torch . Tensor | None , torch . Tensor | None ]:
580575 """Get consumption status for current partition in a specific task.
581576
582577 Args:
@@ -639,7 +634,7 @@ async def async_get_production_status(
639634 data_fields : list [str ],
640635 partition_id : str ,
641636 socket : zmq .asyncio .Socket | None = None ,
642- ) -> tuple [Tensor | None , Tensor | None ]:
637+ ) -> tuple [torch . Tensor | None , torch . Tensor | None ]:
643638 """Get production status for specific data fields and partition.
644639
645640 Args:
@@ -881,31 +876,28 @@ async def async_kv_retrieve_meta(
881876 create : bool = False ,
882877 socket : zmq .asyncio .Socket | None = None ,
883878 ) -> BatchMeta :
884- """Asynchronously retrieve BatchMeta from the controller using user-specified keys.
879+ """Asynchronously retrieve BatchMeta by user-defined keys.
880+
881+ Retrieves metadata for given keys from a specified partition.
882+ If keys do not exist and `create=True`, they will be automatically registered.
885883
886884 Args:
887- keys: List of keys to retrieve from the controller
885+ keys: List of keys to retrieve.
888886 partition_id: The ID of the logical partition to search for keys.
889- create: Whether to register new keys if not found .
890- socket: ZMQ socket ( injected by decorator)
887+ create: If True, automatically create entries for missing keys .
888+ socket: ZMQ socket injected by @with_controller_socket.
891889
892890 Returns:
893- metadata: BatchMeta of the corresponding keys
894-
895- Raises:
896- TypeError: If `keys` is not a list of string or a string
891+ BatchMeta: Metadata for the requested keys.
897892 """
898-
899893 if isinstance (keys , str ):
900894 keys = [keys ]
901- elif isinstance (keys , list ):
902- if len (keys ) < 1 :
903- raise ValueError ("Received an empty list as keys." )
904- # validate all the elements are str
905- if not all (isinstance (k , str ) for k in keys ):
906- raise TypeError ("Not all elements in `keys` are strings." )
907- else :
908- raise TypeError ("Only string or list of strings are allowed as `keys`." )
895+
896+ if not isinstance (keys , list ) or len (keys ) < 1 :
897+ raise ValueError ("`keys` must be a non-empty string or list of strings." )
898+
899+ if not all (isinstance (k , str ) for k in keys ):
900+ raise TypeError ("All elements in `keys` must be strings." )
909901
910902 request_msg = ZMQMessage .create (
911903 request_type = ZMQRequestType .KV_RETRIEVE_META , # type: ignore[arg-type]
@@ -919,25 +911,23 @@ async def async_kv_retrieve_meta(
919911 )
920912
921913 try :
922- assert socket is not None
914+ assert socket is not None , "Socket must be initialized before use"
923915 await socket .send_multipart (request_msg .serialize ())
924916 response_serialized = await socket .recv_multipart (copy = False )
925917 response_msg = ZMQMessage .deserialize (response_serialized )
926918 logger .debug (
927- f"[{ self .client_id } ]: Client get kv_retrieve_keys response: { response_msg } "
919+ f"[{ self .client_id } ] Received KV_RETRIEVE_META response: { response_msg } "
928920 f"from controller { self ._controller .id } "
929921 )
930922
931923 if response_msg .request_type == ZMQRequestType .KV_RETRIEVE_META_RESPONSE :
932- metadata = response_msg .body .get ("metadata" , BatchMeta .empty ())
933- return metadata
934- else :
935- raise RuntimeError (
936- f"[{ self .client_id } ]: Failed to retrieve keys from controller { self ._controller .id } : "
937- f"{ response_msg .body .get ('message' , 'Unknown error' )} "
938- )
924+ return response_msg .body .get ("metadata" , BatchMeta .empty ())
925+
926+ raise RuntimeError (
927+ f"[{ self .client_id } ] Failed to retrieve metadata { response_msg .body .get ('message' , 'Unknown error' )} "
928+ )
939929 except Exception as e :
940- raise RuntimeError (f"[{ self .client_id } ]: Error in kv_retrieve_keys : { str ( e ) } " ) from e
930+ raise RuntimeError (f"[{ self .client_id } ] Failed in async_kv_retrieve_meta : { e } " ) from e
941931
942932 @with_controller_socket
943933 async def async_kv_retrieve_keys (
@@ -1356,7 +1346,7 @@ def get_consumption_status(
13561346 self ,
13571347 task_name : str ,
13581348 partition_id : str ,
1359- ) -> tuple [Tensor | None , Tensor | None ]:
1349+ ) -> tuple [torch . Tensor | None , torch . Tensor | None ]:
13601350 """Synchronously get consumption status for a specific task and partition.
13611351
13621352 Args:
@@ -1384,7 +1374,7 @@ def get_production_status(
13841374 self ,
13851375 data_fields : list [str ],
13861376 partition_id : str ,
1387- ) -> tuple [Tensor | None , Tensor | None ]:
1377+ ) -> tuple [torch . Tensor | None , torch . Tensor | None ]:
13881378 """Synchronously get production status for specific data fields and partition.
13891379
13901380 Args:
@@ -1501,20 +1491,22 @@ def kv_retrieve_meta(
15011491 partition_id : str ,
15021492 create : bool = False ,
15031493 ) -> BatchMeta :
1504- """Synchronously retrieve BatchMeta from the controller using user-specified keys.
1494+ """Synchronously retrieve BatchMeta by user-defined keys.
1495+
1496+ Retrieves metadata for given keys from a specified partition.
1497+ If keys do not exist and `create=True`, they will be automatically registered.
15051498
15061499 Args:
1507- keys: List of keys to retrieve from the controller
1508- partition_id: The ID of the logical partition to search for keys .
1509- create: Whether to register new keys if not found .
1500+ keys: List of keys to retrieve from the controller.
1501+ partition_id: Logical partition to query .
1502+ create: If True, automatically create entries for non-existent keys .
15101503
15111504 Returns:
1512- metadata: BatchMeta of the corresponding keys
1505+ BatchMeta: Metadata for the requested keys.
15131506
15141507 Raises:
15151508 TypeError: If `keys` is not a list of string or a string
15161509 """
1517-
15181510 return self ._kv_retrieve_meta (keys = keys , partition_id = partition_id , create = create )
15191511
15201512 def kv_retrieve_keys (
0 commit comments