Skip to content

Commit ad9732b

Browse files
committed
remove unnecessary poller
Signed-off-by: ji-huazhong <hzji210@gmail.com>
1 parent 9ff85a9 commit ad9732b

2 files changed

Lines changed: 31 additions & 43 deletions

File tree

transfer_queue/controller.py

Lines changed: 30 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import numpy as np
2828
import ray
2929
import torch
30-
import zmq
3130
from omegaconf import DictConfig
3231
from torch import Tensor
3332

@@ -1040,51 +1039,44 @@ def _start_daemon_threads(self):
10401039
)
10411040

10421041
def _wait_connection(self):
1043-
"""Wait for storage instances to complete handshake with retransmission support."""
1042+
"""Wait for storage instances to complete handshake."""
10441043
handshake_socket = self._transport.get_socket("handshake_socket")
1045-
poller = zmq.Poller()
1046-
poller.register(handshake_socket, zmq.POLLIN)
10471044

10481045
logger.debug(f"Controller {self.controller_id} started waiting for storage connections...")
10491046

10501047
while True:
1051-
socks = dict(poller.poll(1000))
1048+
try:
1049+
messages = handshake_socket.recv_multipart(copy=False)
1050+
identity = messages.pop(0)
1051+
serialized_msg = messages
1052+
request_msg = ZMQMessage.deserialize(serialized_msg)
10521053

1053-
if handshake_socket in socks:
1054-
try:
1055-
messages = handshake_socket.recv_multipart(copy=False)
1056-
identity = messages.pop(0)
1057-
serialized_msg = messages
1058-
request_msg = ZMQMessage.deserialize(serialized_msg)
1054+
if request_msg.request_type == ZMQRequestType.HANDSHAKE:
1055+
storage_manager_id = request_msg.sender_id
10591056

1060-
if request_msg.request_type == ZMQRequestType.HANDSHAKE:
1061-
storage_manager_id = request_msg.sender_id
1062-
1063-
# Always send ACK for HANDSHAKE
1064-
response_msg = ZMQMessage.create(
1065-
request_type=ZMQRequestType.HANDSHAKE_ACK,
1066-
sender_id=self.controller_id,
1067-
body={},
1068-
).serialize()
1069-
handshake_socket.send_multipart([identity, *response_msg])
1070-
1071-
# Track new connections
1072-
if storage_manager_id not in self._connected_storage_managers:
1073-
self._connected_storage_managers.add(storage_manager_id)
1074-
storage_manager_type = request_msg.body.get("storage_manager_type", "Unknown")
1075-
logger.debug(
1076-
f"[{self.controller_id}]: received handshake from "
1077-
f"storage manager {storage_manager_id} (type: {storage_manager_type}). "
1078-
f"Total connected: {len(self._connected_storage_managers)}"
1079-
)
1080-
else:
1081-
logger.debug(
1082-
f"[{self.controller_id}]: received duplicate handshake from "
1083-
f"storage manager {storage_manager_id}. Resending ACK."
1084-
)
1057+
response_msg = ZMQMessage.create(
1058+
request_type=ZMQRequestType.HANDSHAKE_ACK,
1059+
sender_id=self.controller_id,
1060+
body={},
1061+
).serialize()
1062+
handshake_socket.send_multipart([identity, *response_msg])
1063+
1064+
if storage_manager_id not in self._connected_storage_managers:
1065+
self._connected_storage_managers.add(storage_manager_id)
1066+
storage_manager_type = request_msg.body.get("storage_manager_type", "Unknown")
1067+
logger.debug(
1068+
f"[{self.controller_id}]: received handshake from "
1069+
f"storage manager {storage_manager_id} (type: {storage_manager_type}). "
1070+
f"Total connected: {len(self._connected_storage_managers)}"
1071+
)
1072+
else:
1073+
logger.debug(
1074+
f"[{self.controller_id}]: received duplicate handshake from "
1075+
f"storage manager {storage_manager_id}. Resending ACK."
1076+
)
10851077

1086-
except Exception as e:
1087-
logger.error(f"[{self.controller_id}]: error processing handshake: {e}")
1078+
except Exception as e:
1079+
logger.error(f"[{self.controller_id}]: error processing handshake: {e}")
10881080

10891081
def _process_request(self):
10901082
"""Main request processing loop - adapted for partition-based operations."""
@@ -1701,7 +1693,6 @@ def get_metadata(
17011693

17021694
elif mode == "force_fetch":
17031695
batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id)
1704-
consumed_indexes = []
17051696

17061697
# Package into metadata
17071698
metadata = self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode)

transfer_queue/utils/zmq_utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,7 @@ def deserialize(cls, frames: list) -> "ZMQMessage":
188188

189189

190190
class ZMQServerTransport:
191-
"""Unified ZMQ transport abstraction for Controller / StorageUnit.
192-
Encapsulates socket creation, binding, inproc endpoint, daemon thread,
193-
ZMQ proxy lifecycle, and unified resource cleanup.
194-
"""
191+
"""Unified management of ZMQ Router Sockets, port binding, daemon threads, and message I/O."""
195192

196193
def __init__(self, node_ip: str, ctx: zmq.Context | None = None):
197194
self.node_ip = node_ip

0 commit comments

Comments
 (0)