Skip to content

Commit 1cd11d7

Browse files
committed
update
Signed-off-by: ji-huazhong <hzji210@gmail.com>
1 parent 20a5ece commit 1cd11d7

3 files changed

Lines changed: 32 additions & 48 deletions

File tree

transfer_queue/client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
ZMQMessage,
3737
ZMQRequestType,
3838
ZMQServerInfo,
39-
dynamic_zmq_socket,
39+
with_zmq_socket,
4040
)
4141

4242
logger = logging.getLogger(__name__)
@@ -51,10 +51,10 @@
5151
TQ_NUM_THREADS = int(os.environ.get("TQ_NUM_THREADS", 8))
5252

5353
# Pre-bound decorator for controller socket operations.
54-
_controller_socket = dynamic_zmq_socket(
54+
_controller_socket = with_zmq_socket(
5555
"request_handle_socket",
56-
owner_id_attr="client_id",
57-
server_attr="_controller",
56+
get_identity=lambda self: self.client_id,
57+
get_peer=lambda self, target: self._controller,
5858
)
5959

6060

transfer_queue/storage/managers/simple_backend_manager.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
ZMQMessage,
3535
ZMQRequestType,
3636
ZMQServerInfo,
37-
dynamic_zmq_socket,
37+
with_zmq_socket,
3838
)
3939

4040
logger = logging.getLogger(__name__)
@@ -49,11 +49,11 @@
4949
TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT", 200)) # seconds
5050

5151
# Pre-bound decorator for storage-unit socket operations.
52-
_storage_unit_socket = dynamic_zmq_socket(
52+
_storage_unit_socket = with_zmq_socket(
5353
"put_get_socket",
54-
owner_id_attr="storage_manager_id",
55-
server_attr="storage_unit_infos",
56-
target_kwarg="target_storage_unit",
54+
get_identity=lambda self: self.storage_manager_id,
55+
get_peer=lambda self, target: self.storage_unit_infos[target],
56+
resolve_target=lambda args, kwargs: kwargs.get("target_storage_unit"),
5757
timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT,
5858
)
5959

transfer_queue/utils/zmq_utils.py

Lines changed: 23 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -304,12 +304,12 @@ def create_zmq_socket(
304304
return socket
305305

306306

307-
def dynamic_zmq_socket(
307+
def with_zmq_socket(
308308
socket_name: str,
309309
*,
310-
owner_id_attr: str,
311-
server_attr: str,
312-
target_kwarg: Optional[str] = None,
310+
get_identity: Callable[[Any], str],
311+
get_peer: Callable[[Any, Optional[str]], ZMQServerInfo],
312+
resolve_target: Optional[Callable[[tuple, dict], Optional[str]]] = None,
313313
timeout: Optional[int] = None,
314314
):
315315
"""Create a reusable async decorator for request sockets.
@@ -320,50 +320,34 @@ def dynamic_zmq_socket(
320320
321321
Args:
322322
socket_name: Socket port key in ``ZMQServerInfo.ports``.
323-
owner_id_attr: Attribute name on ``self`` used in identity/log prefix
324-
(e.g., ``client_id`` or ``storage_manager_id``).
325-
server_attr: Attribute name on ``self`` that stores server info.
326-
- ``ZMQServerInfo`` for single-target calls.
327-
- ``Mapping[str, ZMQServerInfo]`` for multi-target calls.
328-
target_kwarg: Optional kwarg name that provides target server id when
329-
``server_attr`` is a mapping.
323+
get_identity: Callable that extracts owner identity from ``self``.
324+
Example: ``lambda self: self.client_id``
325+
get_peer: Callable that returns ``ZMQServerInfo`` for the target.
326+
For single-target scenarios, ignore the target parameter.
327+
Example: ``lambda self, target: self.server_info``
328+
Example: ``lambda self, target: self.storage_unit_infos[target]``
329+
resolve_target: Optional callable that extracts target identifier from
330+
function arguments. Receives (args, kwargs) and returns target name.
331+
Example: ``lambda args, kwargs: kwargs.get("target_storage_unit")``
330332
timeout: Optional timeout (seconds) for both send/recv operations.
331333
"""
332334

333335
def decorator(func: Callable):
334336
@wraps(func)
335337
async def wrapper(self, *args, **kwargs):
336-
owner_id = getattr(self, owner_id_attr, None)
338+
owner_id = get_identity(self)
337339
if owner_id is None:
338-
raise RuntimeError(f"Missing owner id attribute: {owner_id_attr}")
339-
340-
server_obj = getattr(self, server_attr, None)
341-
if server_obj is None:
342-
raise RuntimeError(f"Missing server registry attribute: {server_attr}")
340+
raise RuntimeError("get_identity returned None")
343341

344342
target_name: Optional[str] = None
345-
if target_kwarg is not None:
346-
target_name = kwargs.get(target_kwarg)
347-
if target_name is None:
348-
for arg in args:
349-
if isinstance(arg, str):
350-
target_name = arg
351-
break
352-
353-
if isinstance(server_obj, ZMQServerInfo):
354-
if target_name is not None and target_name != server_obj.id:
355-
raise RuntimeError(
356-
f"Target mismatch: target '{target_name}' does not match registered server '{server_obj.id}'"
357-
)
358-
server_info = server_obj
359-
elif isinstance(server_obj, Mapping):
360-
if target_name is None:
361-
raise RuntimeError(f"Missing target server identifier via '{target_kwarg}'")
362-
server_info = server_obj.get(target_name)
363-
if server_info is None:
364-
raise RuntimeError(f"Server '{target_name}' not found in registered servers")
365-
else:
366-
raise RuntimeError(f"Unsupported server registry type for '{server_attr}': {type(server_obj).__name__}")
343+
if resolve_target is not None:
344+
target_name = resolve_target(args, kwargs)
345+
346+
server_info = get_peer(self, target_name)
347+
if server_info is None:
348+
raise RuntimeError(
349+
f"get_peer returned None for target '{target_name}'"
350+
)
367351

368352
port = server_info.ports.get(socket_name)
369353
if port is None:

0 commit comments

Comments
 (0)