diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 9661ba0e..2fab2051 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -27,7 +27,6 @@ import numpy as np import ray import torch -import zmq from omegaconf import DictConfig from torch import Tensor @@ -42,9 +41,7 @@ ZMQMessage, ZMQRequestType, ZMQServerInfo, - create_zmq_socket, - format_zmq_address, - get_free_port, + ZMQServerTransport, get_node_ip_address, ) @@ -1003,8 +1000,7 @@ def __init__( self.polling_mode = polling_mode self.tq_config = None # global config for TransferQueue system - # Initialize ZMQ sockets for communication - self._init_zmq_socket() + self._init_zmq_transport() # Partition management self.partitions: dict[str, DataPartitionStatus] = {} # partition_id -> DataPartitionStatus @@ -1020,9 +1016,7 @@ def __init__( self._metrics_endpoint: str = "" # Start background processing threads - self._start_process_handshake() - self._start_process_update_data_status() - self._start_process_request() + self._start_daemon_threads() logger.info(f"TransferQueue Controller {self.controller_id} initialized") @@ -1355,7 +1349,6 @@ def get_metadata( elif mode == "force_fetch": batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id) - consumed_indexes = [] # Package into metadata metadata = self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode) @@ -1676,140 +1669,82 @@ def kv_retrieve_keys( return keys - def _init_zmq_socket(self): - """Initialize ZMQ sockets for communication.""" - self.zmq_context = zmq.Context() - self._node_ip = get_node_ip_address() - - while True: - try: - self._handshake_socket_port = get_free_port(ip=self._node_ip) - self._request_handle_socket_port = get_free_port(ip=self._node_ip) - self._data_status_update_socket_port = get_free_port(ip=self._node_ip) - - self.handshake_socket = create_zmq_socket( - ctx=self.zmq_context, - socket_type=zmq.ROUTER, - ip=self._node_ip, - ) - 
def _init_zmq_transport(self):
    """Create the controller's ZMQ transport and publish its server info.

    Builds a ``ZMQServerTransport`` on this node's IP, asks it to create one
    ROUTER socket per controller channel (handshake, request handling and
    data-status updates) and stores the resulting ``ZMQServerInfo`` in
    ``self.zmq_server_info``.
    """
    self._transport = ZMQServerTransport(node_ip=get_node_ip_address())
    for socket_name in ("handshake_socket", "request_handle_socket", "data_status_update_socket"):
        self._transport.create_router_socket(socket_name)
    self.zmq_server_info = self._transport.build_server_info(
        role=Role.CONTROLLER,
        id=self.controller_id,
    )

def _wait_connection(self):
    """Serve storage-manager handshakes, with retransmission support.

    Blocks on the handshake socket and answers every ``HANDSHAKE`` message
    with a ``HANDSHAKE_ACK``, including duplicates, so a storage manager
    that re-sends its handshake always gets an ACK back. Newly seen sender
    ids are added to ``self._connected_storage_managers``. Messages of any
    other request type are ignored.

    A failure while handling a single message is logged and the loop keeps
    serving. A failure of the receive call itself ends the loop: the other
    receive loops of this controller also stop when their socket fails,
    and retrying a broken socket would only spin and flood the log.
    """
    handshake_socket = self._transport.get_socket("handshake_socket")

    logger.debug(f"Controller {self.controller_id} started waiting for storage connections...")

    while True:
        try:
            messages = handshake_socket.recv_multipart(copy=False)
        except Exception:
            # The blocking receive only fails when the socket/context is no
            # longer usable; stop instead of busy-looping on the same error.
            logger.exception(f"[{self.controller_id}]: handshake socket receive failed, stopping handshake loop")
            return

        try:
            # First frame is the ROUTER identity of the peer; the rest is the payload.
            identity = messages.pop(0)
            request_msg = ZMQMessage.deserialize(messages)

            if request_msg.request_type != ZMQRequestType.HANDSHAKE:
                continue

            storage_manager_id = request_msg.sender_id

            # Always ACK, even for a duplicate, so a lost ACK can be recovered.
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.HANDSHAKE_ACK,
                sender_id=self.controller_id,
                body={},
            ).serialize()
            handshake_socket.send_multipart([identity, *response_msg])

            if storage_manager_id not in self._connected_storage_managers:
                self._connected_storage_managers.add(storage_manager_id)
                storage_manager_type = request_msg.body.get("storage_manager_type", "Unknown")
                logger.debug(
                    f"[{self.controller_id}]: received handshake from "
                    f"storage manager {storage_manager_id} (type: {storage_manager_type}). "
                    f"Total connected: {len(self._connected_storage_managers)}"
                )
            else:
                logger.debug(
                    f"[{self.controller_id}]: received duplicate handshake from "
                    f"storage manager {storage_manager_id}. Resending ACK."
                )

        except Exception as e:
            # One malformed message must not take the handshake service down.
            logger.exception(f"[{self.controller_id}]: error processing handshake: {e}")

def _start_daemon_threads(self):
    """Start the controller's three background loops as daemon threads.

    The threads are created through the transport, in this order: handshake
    handling, data-status updates, request processing.
    """
    daemon_loops = (
        (self._wait_connection, "TQControllerWaitConnectionThread"),
        (self._update_data_status, "TQControllerProcessUpdateDataStatusThread"),
        (self._process_request, "TQControllerProcessRequestThread"),
    )
    for loop_target, thread_name in daemon_loops:
        self._transport.start_daemon_thread(target=loop_target, name=thread_name)
self.request_handle_socket.recv_multipart(copy=False) + messages = request_handle_socket.recv_multipart(copy=False) identity = messages.pop(0) serialized_msg = messages request_msg = ZMQMessage.deserialize(serialized_msg) @@ -2045,18 +1980,18 @@ def _process_request(self): body={"partition_info": partition_info, "message": message}, ) - self.request_handle_socket.send_multipart([identity, *response_msg.serialize()]) + request_handle_socket.send_multipart([identity, *response_msg.serialize()]) def _update_data_status(self): """Process data status update messages from storage units - adapted for partitions.""" logger.debug(f"[{self.controller_id}]: start receiving update_data_status requests...") + data_status_update_socket = self._transport.get_socket("data_status_update_socket") perf_monitor = IntervalPerfMonitor(caller_name=self.controller_id) while True: monitor = self._metrics if self._metrics is not None else perf_monitor - - messages = self.data_status_update_socket.recv_multipart(copy=False) + messages = data_status_update_socket.recv_multipart(copy=False) identity = messages.pop(0) serialized_msg = messages request_msg = ZMQMessage.deserialize(serialized_msg) @@ -2074,6 +2009,7 @@ def _update_data_status(self): field_schema=message_data.get("field_schema", {}), custom_backend_meta=message_data.get("custom_backend_meta", {}), ) + if success: if self._metrics is not None: self._metrics.record_samples("NOTIFY_DATA_UPDATE", len(global_indexes)) @@ -2089,7 +2025,7 @@ def _update_data_status(self): "success": success, }, ) - self.data_status_update_socket.send_multipart([identity, *response_msg.serialize()]) + data_status_update_socket.send_multipart([identity, *response_msg.serialize()]) def get_zmq_server_info(self) -> ZMQServerInfo: """Get ZMQ server connection information.""" @@ -2196,7 +2132,7 @@ def start_metrics(self, port: int = 0) -> str: from transfer_queue.metrics import TQMetricsExporter self._metrics = TQMetricsExporter() - self._metrics_endpoint = 
self._metrics.start(node_ip=self._node_ip, port=port) + self._metrics_endpoint = self._metrics.start(node_ip=get_node_ip_address(), port=port) # Launch a daemon thread that periodically pushes controller state # snapshots to the exporter, keeping them process-isolated. self._metrics_snapshot_thread = Thread( diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 4fe32f0a..be62e677 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -14,6 +14,7 @@ # limitations under the License. import socket +import threading import time from dataclasses import dataclass from functools import wraps @@ -103,9 +104,7 @@ class ZMQRequestType(ExplicitEnum): class ZMQServerInfo: - """ - TransferQueue server info class. - """ + """TransferQueue server info class.""" def __init__(self, role: Role, id: str, ip: str, ports: dict[str, int]): self.role = role @@ -132,9 +131,7 @@ def __str__(self) -> str: @dataclass class ZMQMessage: - """ - ZMQMessage class for TransferQueue communication. 
class ZMQServerTransport:
    """Own the ZMQ ROUTER sockets, bound ports and daemon threads of one server.

    The transport creates and binds named ROUTER sockets on ``node_ip``,
    remembers the port each one was bound to, starts daemon threads on behalf
    of its owner and can package the bound ports into a ``ZMQServerInfo``.
    Sending and receiving messages is left to the caller, which obtains the
    raw socket through ``get_socket``.
    """

    def __init__(self, node_ip: str, ctx: zmq.Context | None = None):
        """Create an empty transport.

        Args:
            node_ip: IP address every socket of this transport binds to.
            ctx: ZMQ context used to create sockets; a new ``zmq.Context``
                is created when omitted.
        """
        self.node_ip = node_ip
        self.zmq_ctx = ctx if ctx is not None else zmq.Context()
        self.sockets: dict[str, zmq.Socket] = {}  # socket name -> bound socket
        self.ports: dict[str, int] = {}  # socket name -> bound port
        self.threads: list[threading.Thread] = []  # every thread started via this transport

    def create_router_socket(self, name: str) -> None:
        """Create a ROUTER socket, bind it to a free port and register it under ``name``.

        Picking a free port and binding to it are two separate steps, so the
        bind can fail; the whole attempt is retried until it succeeds. A socket
        whose bind failed is closed before the next attempt so that retries do
        not accumulate open sockets.

        Args:
            name: Key under which the socket and its port are stored.
        """
        while True:
            sock = None
            try:
                port = get_free_port(ip=self.node_ip)
                sock = create_zmq_socket(
                    ctx=self.zmq_ctx,
                    socket_type=zmq.ROUTER,
                    ip=self.node_ip,
                )
                sock.bind(format_zmq_address(self.node_ip, port))
            except zmq.ZMQError:
                if sock is not None:
                    # Nothing was sent on this socket, so drop it immediately.
                    sock.close(linger=0)
                logger.warning(f"ZMQ bind {name} failed, retrying...")
                continue

            self.sockets[name] = sock
            self.ports[name] = port
            return

    def get_socket(self, name: str) -> zmq.Socket:
        """Return the socket registered under ``name``.

        Raises:
            KeyError: If no socket was created with that name.
        """
        return self.sockets[name]

    def start_daemon_thread(self, target, name: str) -> None:
        """Start a daemon thread running ``target`` and keep a reference to it.

        Args:
            target: Callable executed by the thread; it is called without arguments.
            name: Name given to the thread.
        """
        t = threading.Thread(target=target, name=name, daemon=True)
        t.start()
        self.threads.append(t)

    def build_server_info(self, role: Role, id: str) -> ZMQServerInfo:
        """Build a ``ZMQServerInfo`` describing the sockets bound so far.

        The ports are passed as a copy, so sockets created later do not
        silently change an info object that was already handed out.

        Args:
            role: Role of the server owning this transport.
            id: Identifier of the server owning this transport.
        """
        return ZMQServerInfo(
            role=role,
            id=id,
            ip=self.node_ip,
            ports=dict(self.ports),
        )