Skip to content

Commit 759aa40

Browse files
committed
update
Signed-off-by: ji-huazhong <hzji210@gmail.com>
1 parent 20a5ece commit 759aa40

3 files changed

Lines changed: 49 additions & 68 deletions

File tree

transfer_queue/client.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import logging
1818
import os
1919
import threading
20-
from typing import Any, Optional
20+
from typing import Any, Callable, Optional
2121

2222
import torch
2323
import zmq
@@ -36,7 +36,7 @@
3636
ZMQMessage,
3737
ZMQRequestType,
3838
ZMQServerInfo,
39-
dynamic_zmq_socket,
39+
with_zmq_socket,
4040
)
4141

4242
logger = logging.getLogger(__name__)
@@ -51,10 +51,10 @@
5151
TQ_NUM_THREADS = int(os.environ.get("TQ_NUM_THREADS", 8))
5252

5353
# Pre-bound decorator for controller socket operations.
54-
_controller_socket = dynamic_zmq_socket(
54+
with_controller_socket = with_zmq_socket(
5555
"request_handle_socket",
56-
owner_id_attr="client_id",
57-
server_attr="_controller",
56+
get_identity=lambda self: self.client_id,
57+
get_peer=lambda self, target: self._controller,
5858
)
5959

6060

@@ -104,7 +104,7 @@ def initialize_storage_manager(
104104
)
105105

106106
# ==================== Basic API ====================
107-
@_controller_socket
107+
@with_controller_socket
108108
async def async_get_meta(
109109
self,
110110
data_fields: list[str],
@@ -194,7 +194,7 @@ async def async_get_meta(
194194
f"{response_msg.body.get('message', 'Unknown error')}"
195195
)
196196

197-
@_controller_socket
197+
@with_controller_socket
198198
async def async_set_custom_meta(
199199
self,
200200
metadata: BatchMeta,
@@ -494,7 +494,7 @@ async def async_clear_samples(self, metadata: BatchMeta):
494494
except Exception as e:
495495
raise RuntimeError(f"Error in clear_samples operation: {str(e)}") from e
496496

497-
@_controller_socket
497+
@with_controller_socket
498498
async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None):
499499
"""Clear metadata in the controller.
500500
@@ -520,7 +520,7 @@ async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None):
520520
if response_msg.request_type != ZMQRequestType.CLEAR_META_RESPONSE:
521521
raise RuntimeError("Failed to clear samples metadata in controller.")
522522

523-
@_controller_socket
523+
@with_controller_socket
524524
async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta:
525525
"""Get metadata required for the whole partition from controller.
526526
@@ -550,7 +550,7 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta
550550

551551
return response_msg.body["metadata"]
552552

553-
@_controller_socket
553+
@with_controller_socket
554554
async def _clear_partition_in_controller(self, partition_id, socket=None):
555555
"""Clear the whole partition in the controller.
556556
@@ -577,7 +577,7 @@ async def _clear_partition_in_controller(self, partition_id, socket=None):
577577
raise RuntimeError(f"Failed to clear partition {partition_id} in controller.")
578578

579579
# ==================== Status Query API ====================
580-
@_controller_socket
580+
@with_controller_socket
581581
async def async_get_consumption_status(
582582
self,
583583
task_name: str,
@@ -640,7 +640,7 @@ async def async_get_consumption_status(
640640
except Exception as e:
641641
raise RuntimeError(f"[{self.client_id}]: Error in get_consumption_status: {str(e)}") from e
642642

643-
@_controller_socket
643+
@with_controller_socket
644644
async def async_get_production_status(
645645
self,
646646
data_fields: list[str],
@@ -772,7 +772,7 @@ async def async_check_production_status(
772772
return False
773773
return torch.all(production_status == 1).item()
774774

775-
@_controller_socket
775+
@with_controller_socket
776776
async def async_reset_consumption(
777777
self,
778778
partition_id: str,
@@ -834,7 +834,7 @@ async def async_reset_consumption(
834834
except Exception as e:
835835
raise RuntimeError(f"[{self.client_id}]: Error in reset_consumption: {str(e)}") from e
836836

837-
@_controller_socket
837+
@with_controller_socket
838838
async def async_get_partition_list(
839839
self,
840840
socket: Optional[zmq.asyncio.Socket] = None,
@@ -880,7 +880,7 @@ async def async_get_partition_list(
880880
raise RuntimeError(f"[{self.client_id}]: Error in get_partition_list: {str(e)}") from e
881881

882882
# ==================== KV Interface API ====================
883-
@_controller_socket
883+
@with_controller_socket
884884
async def async_kv_retrieve_meta(
885885
self,
886886
keys: list[str] | str,
@@ -946,7 +946,7 @@ async def async_kv_retrieve_meta(
946946
except Exception as e:
947947
raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e
948948

949-
@_controller_socket
949+
@with_controller_socket
950950
async def async_kv_retrieve_keys(
951951
self,
952952
global_indexes: list[int] | int,
@@ -1009,7 +1009,7 @@ async def async_kv_retrieve_keys(
10091009
except Exception as e:
10101010
raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_indexes: {str(e)}") from e
10111011

1012-
@_controller_socket
1012+
@with_controller_socket
10131013
async def async_kv_list(
10141014
self,
10151015
partition_id: Optional[str] = None,

transfer_queue/storage/managers/simple_backend_manager.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from collections import defaultdict
2121
from collections.abc import Mapping
2222
from operator import itemgetter
23-
from typing import Any, NamedTuple
23+
from typing import Any, Callable, NamedTuple
2424

2525
import torch
2626
import zmq
@@ -34,7 +34,7 @@
3434
ZMQMessage,
3535
ZMQRequestType,
3636
ZMQServerInfo,
37-
dynamic_zmq_socket,
37+
with_zmq_socket,
3838
)
3939

4040
logger = logging.getLogger(__name__)
@@ -49,11 +49,11 @@
4949
TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT", 200)) # seconds
5050

5151
# Pre-bound decorator for storage-unit socket operations.
52-
_storage_unit_socket = dynamic_zmq_socket(
52+
with_storage_unit_socket = with_zmq_socket(
5353
"put_get_socket",
54-
owner_id_attr="storage_manager_id",
55-
server_attr="storage_unit_infos",
56-
target_kwarg="target_storage_unit",
54+
get_identity=lambda self: self.storage_manager_id,
55+
get_peer=lambda self, target: self.storage_unit_infos[target],
56+
resolve_target=lambda args, kwargs: kwargs.get("target_storage_unit"),
5757
timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT,
5858
)
5959

@@ -220,7 +220,7 @@ def _select_by_positions(field_data, positions: list[int]):
220220
return field_data[positions]
221221

222222
async def put_data(
223-
self, data: TensorDict, metadata: BatchMeta, data_parser: Optional[Callable[[Any], Any]] = None
223+
self, data: TensorDict, metadata: BatchMeta, data_parser: Callable[[Any], Any] | None = None
224224
) -> None:
225225
"""
226226
Send data to remote StorageUnit based on metadata.
@@ -281,13 +281,13 @@ async def put_data(
281281
field_schema,
282282
)
283283

284-
@_storage_unit_socket
284+
@with_storage_unit_socket
285285
async def _put_to_single_storage_unit(
286286
self,
287287
global_indexes: list[int],
288288
storage_data: dict[str, Any],
289289
target_storage_unit: str,
290-
data_parser: Optional[Callable[[Any], Any]] = None,
290+
data_parser: Callable[[Any], Any] | None = None,
291291
socket: zmq.Socket = None,
292292
):
293293
"""
@@ -417,7 +417,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict:
417417

418418
return TensorDict(tensor_data, batch_size=len(metadata))
419419

420-
@_storage_unit_socket
420+
@with_storage_unit_socket
421421
async def _get_from_single_storage_unit(
422422
self,
423423
global_indexes: list[int],
@@ -489,7 +489,7 @@ async def clear_data(self, metadata: BatchMeta) -> None:
489489
if isinstance(result, Exception):
490490
logger.error(f"[{self.storage_manager_id}]: Error in clear operation task {i}: {result}")
491491

492-
@_storage_unit_socket
492+
@with_storage_unit_socket
493493
async def _clear_single_storage_unit(self, global_indexes, target_storage_unit=None, socket=None):
494494
try:
495495
request_msg = ZMQMessage.create(

transfer_queue/utils/zmq_utils.py

Lines changed: 21 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import os
1818
import socket
1919
import time
20-
from collections.abc import Mapping
2120
from dataclasses import dataclass
2221
from functools import wraps
2322
from typing import Any, Callable, Optional, TypeAlias
@@ -304,12 +303,12 @@ def create_zmq_socket(
304303
return socket
305304

306305

307-
def dynamic_zmq_socket(
306+
def with_zmq_socket(
308307
socket_name: str,
309308
*,
310-
owner_id_attr: str,
311-
server_attr: str,
312-
target_kwarg: Optional[str] = None,
309+
get_identity: Callable[[Any], str],
310+
get_peer: Callable[[Any, Optional[str]], ZMQServerInfo],
311+
resolve_target: Optional[Callable[[tuple, dict], Optional[str]]] = None,
313312
timeout: Optional[int] = None,
314313
):
315314
"""Create a reusable async decorator for request sockets.
@@ -320,50 +319,32 @@ def dynamic_zmq_socket(
320319
321320
Args:
322321
socket_name: Socket port key in ``ZMQServerInfo.ports``.
323-
owner_id_attr: Attribute name on ``self`` used in identity/log prefix
324-
(e.g., ``client_id`` or ``storage_manager_id``).
325-
server_attr: Attribute name on ``self`` that stores server info.
326-
- ``ZMQServerInfo`` for single-target calls.
327-
- ``Mapping[str, ZMQServerInfo]`` for multi-target calls.
328-
target_kwarg: Optional kwarg name that provides target server id when
329-
``server_attr`` is a mapping.
322+
get_identity: Callable that extracts the owner identity from ``self``; it must not return ``None`` (a ``RuntimeError`` is raised if it does).
323+
Example: ``lambda self: self.client_id``
324+
get_peer: Callable invoked as ``get_peer(self, target)`` that returns the ``ZMQServerInfo`` to connect to; a ``None`` result raises ``RuntimeError``.
325+
For single-target scenarios the callable can ignore ``target``; it is ``None`` unless ``resolve_target`` supplies a value.
326+
Example: ``lambda self, target: self.server_info``
327+
Example: ``lambda self, target: self.storage_unit_infos[target]``
328+
resolve_target: Optional callable that extracts target identifier from
329+
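function arguments. Receives the wrapped call's ``(args, kwargs)`` (without ``self``) and returns the target name or ``None``; a target passed positionally appears in ``args``, not ``kwargs``.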
330+
Example: ``lambda args, kwargs: kwargs.get("target_storage_unit")``
330331
timeout: Optional timeout (seconds) for both send/recv operations.
331332
"""
332333

333334
def decorator(func: Callable):
334335
@wraps(func)
335336
async def wrapper(self, *args, **kwargs):
336-
owner_id = getattr(self, owner_id_attr, None)
337+
owner_id = get_identity(self)
337338
if owner_id is None:
338-
raise RuntimeError(f"Missing owner id attribute: {owner_id_attr}")
339-
340-
server_obj = getattr(self, server_attr, None)
341-
if server_obj is None:
342-
raise RuntimeError(f"Missing server registry attribute: {server_attr}")
339+
raise RuntimeError("get_identity returned None")
343340

344341
target_name: Optional[str] = None
345-
if target_kwarg is not None:
346-
target_name = kwargs.get(target_kwarg)
347-
if target_name is None:
348-
for arg in args:
349-
if isinstance(arg, str):
350-
target_name = arg
351-
break
352-
353-
if isinstance(server_obj, ZMQServerInfo):
354-
if target_name is not None and target_name != server_obj.id:
355-
raise RuntimeError(
356-
f"Target mismatch: target '{target_name}' does not match registered server '{server_obj.id}'"
357-
)
358-
server_info = server_obj
359-
elif isinstance(server_obj, Mapping):
360-
if target_name is None:
361-
raise RuntimeError(f"Missing target server identifier via '{target_kwarg}'")
362-
server_info = server_obj.get(target_name)
363-
if server_info is None:
364-
raise RuntimeError(f"Server '{target_name}' not found in registered servers")
365-
else:
366-
raise RuntimeError(f"Unsupported server registry type for '{server_attr}': {type(server_obj).__name__}")
342+
if resolve_target is not None:
343+
target_name = resolve_target(args, kwargs)
344+
345+
server_info = get_peer(self, target_name)
346+
if server_info is None:
347+
raise RuntimeError(f"get_peer returned None for target '{target_name}'")
367348

368349
port = server_info.ports.get(socket_name)
369350
if port is None:

0 commit comments

Comments
 (0)