Skip to content

Commit 7435934

Browse files
committed
remove unnecessary poller
Signed-off-by: ji-huazhong <hzji210@gmail.com>
1 parent 8acb4b9 commit 7435934

2 files changed

Lines changed: 36 additions & 49 deletions

File tree

transfer_queue/controller.py

Lines changed: 30 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import numpy as np
2828
import ray
2929
import torch
30-
import zmq
3130
from omegaconf import DictConfig
3231
from torch import Tensor
3332

@@ -1047,51 +1046,44 @@ def _start_daemon_threads(self):
10471046
)
10481047

10491048
def _wait_connection(self):
1050-
"""Wait for storage instances to complete handshake with retransmission support."""
1049+
"""Wait for storage instances to complete handshake."""
10511050
handshake_socket = self._transport.get_socket("handshake_socket")
1052-
poller = zmq.Poller()
1053-
poller.register(handshake_socket, zmq.POLLIN)
10541051

10551052
logger.debug(f"Controller {self.controller_id} started waiting for storage connections...")
10561053

10571054
while True:
1058-
socks = dict(poller.poll(1000))
1059-
1060-
if handshake_socket in socks:
1061-
try:
1062-
messages = handshake_socket.recv_multipart(copy=False)
1063-
identity = messages.pop(0)
1064-
serialized_msg = messages
1065-
request_msg = ZMQMessage.deserialize(serialized_msg)
1055+
try:
1056+
messages = handshake_socket.recv_multipart(copy=False)
1057+
identity = messages.pop(0)
1058+
serialized_msg = messages
1059+
request_msg = ZMQMessage.deserialize(serialized_msg)
10661060

1067-
if request_msg.request_type == ZMQRequestType.HANDSHAKE:
1068-
storage_manager_id = request_msg.sender_id
1061+
if request_msg.request_type == ZMQRequestType.HANDSHAKE:
1062+
storage_manager_id = request_msg.sender_id
10691063

1070-
# Always send ACK for HANDSHAKE
1071-
response_msg = ZMQMessage.create(
1072-
request_type=ZMQRequestType.HANDSHAKE_ACK,
1073-
sender_id=self.controller_id,
1074-
body={},
1075-
).serialize()
1076-
handshake_socket.send_multipart([identity, *response_msg])
1077-
1078-
# Track new connections
1079-
if storage_manager_id not in self._connected_storage_managers:
1080-
self._connected_storage_managers.add(storage_manager_id)
1081-
storage_manager_type = request_msg.body.get("storage_manager_type", "Unknown")
1082-
logger.debug(
1083-
f"[{self.controller_id}]: received handshake from "
1084-
f"storage manager {storage_manager_id} (type: {storage_manager_type}). "
1085-
f"Total connected: {len(self._connected_storage_managers)}"
1086-
)
1087-
else:
1088-
logger.debug(
1089-
f"[{self.controller_id}]: received duplicate handshake from "
1090-
f"storage manager {storage_manager_id}. Resending ACK."
1091-
)
1064+
response_msg = ZMQMessage.create(
1065+
request_type=ZMQRequestType.HANDSHAKE_ACK,
1066+
sender_id=self.controller_id,
1067+
body={},
1068+
).serialize()
1069+
handshake_socket.send_multipart([identity, *response_msg])
1070+
1071+
if storage_manager_id not in self._connected_storage_managers:
1072+
self._connected_storage_managers.add(storage_manager_id)
1073+
storage_manager_type = request_msg.body.get("storage_manager_type", "Unknown")
1074+
logger.debug(
1075+
f"[{self.controller_id}]: received handshake from "
1076+
f"storage manager {storage_manager_id} (type: {storage_manager_type}). "
1077+
f"Total connected: {len(self._connected_storage_managers)}"
1078+
)
1079+
else:
1080+
logger.debug(
1081+
f"[{self.controller_id}]: received duplicate handshake from "
1082+
f"storage manager {storage_manager_id}. Resending ACK."
1083+
)
10921084

1093-
except Exception as e:
1094-
logger.error(f"[{self.controller_id}]: error processing handshake: {e}")
1085+
except Exception as e:
1086+
logger.error(f"[{self.controller_id}]: error processing handshake: {e}")
10951087

10961088
def _process_request(self):
10971089
"""Main request processing loop - adapted for partition-based operations."""
@@ -1708,7 +1700,6 @@ def get_metadata(
17081700

17091701
elif mode == "force_fetch":
17101702
batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id)
1711-
consumed_indexes = []
17121703

17131704
# Package into metadata
17141705
metadata = self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode)

transfer_queue/utils/zmq_utils.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,7 @@ class ZMQRequestType(ExplicitEnum):
104104

105105

106106
class ZMQServerInfo:
107-
"""
108-
TransferQueue server info class.
109-
"""
107+
"""TransferQueue server info class."""
110108

111109
def __init__(self, role: Role, id: str, ip: str, ports: dict[str, int]):
112110
self.role = role
@@ -133,9 +131,7 @@ def __str__(self) -> str:
133131

134132
@dataclass
135133
class ZMQMessage:
136-
"""
137-
ZMQMessage class for TransferQueue communication.
138-
"""
134+
"""ZMQMessage class for TransferQueue communication."""
139135

140136
request_type: ZMQRequestType
141137
sender_id: str
@@ -192,10 +188,7 @@ def deserialize(cls, frames: list) -> "ZMQMessage":
192188

193189

194190
class ZMQServerTransport:
195-
"""Unified ZMQ transport abstraction for Controller / StorageUnit.
196-
Encapsulates socket creation, binding, inproc endpoint, daemon thread,
197-
ZMQ proxy lifecycle, and unified resource cleanup.
198-
"""
191+
"""Unified management of ZMQ Router Sockets, port binding, daemon threads, and message I/O."""
199192

200193
def __init__(self, node_ip: str, ctx: zmq.Context | None = None):
201194
self.node_ip = node_ip
@@ -222,14 +215,17 @@ def create_router_socket(self, name: str) -> None:
222215
logger.warning(f"ZMQ bind {name} failed, retrying...")
223216

224217
def get_socket(self, name: str) -> zmq.Socket:
218+
"""Get ZMQ socket by name."""
225219
return self.sockets[name]
226220

227221
def start_daemon_thread(self, target, name: str) -> None:
222+
"""Start a daemon thread with the given target functions."""
228223
t = threading.Thread(target=target, name=name, daemon=True)
229224
t.start()
230225
self.threads.append(t)
231226

232227
def build_server_info(self, role: Role, server_id: str) -> ZMQServerInfo:
228+
"""Build ZMQServerInfo."""
233229
return ZMQServerInfo(
234230
role=role,
235231
id=server_id,

0 commit comments

Comments
 (0)