|
27 | 27 | import numpy as np |
28 | 28 | import ray |
29 | 29 | import torch |
30 | | -import zmq |
31 | 30 | from omegaconf import DictConfig |
32 | 31 | from torch import Tensor |
33 | 32 |
|
@@ -1047,51 +1046,44 @@ def _start_daemon_threads(self): |
1047 | 1046 | ) |
1048 | 1047 |
|
1049 | 1048 | def _wait_connection(self): |
1050 | | - """Wait for storage instances to complete handshake with retransmission support.""" |
| 1049 | + """Wait for storage instances to complete handshake.""" |
1051 | 1050 | handshake_socket = self._transport.get_socket("handshake_socket") |
1052 | | - poller = zmq.Poller() |
1053 | | - poller.register(handshake_socket, zmq.POLLIN) |
1054 | 1051 |
|
1055 | 1052 | logger.debug(f"Controller {self.controller_id} started waiting for storage connections...") |
1056 | 1053 |
|
1057 | 1054 | while True: |
1058 | | - socks = dict(poller.poll(1000)) |
1059 | | - |
1060 | | - if handshake_socket in socks: |
1061 | | - try: |
1062 | | - messages = handshake_socket.recv_multipart(copy=False) |
1063 | | - identity = messages.pop(0) |
1064 | | - serialized_msg = messages |
1065 | | - request_msg = ZMQMessage.deserialize(serialized_msg) |
| 1055 | + try: |
| 1056 | + messages = handshake_socket.recv_multipart(copy=False) |
| 1057 | + identity = messages.pop(0) |
| 1058 | + serialized_msg = messages |
| 1059 | + request_msg = ZMQMessage.deserialize(serialized_msg) |
1066 | 1060 |
|
1067 | | - if request_msg.request_type == ZMQRequestType.HANDSHAKE: |
1068 | | - storage_manager_id = request_msg.sender_id |
| 1061 | + if request_msg.request_type == ZMQRequestType.HANDSHAKE: |
| 1062 | + storage_manager_id = request_msg.sender_id |
1069 | 1063 |
|
1070 | | - # Always send ACK for HANDSHAKE |
1071 | | - response_msg = ZMQMessage.create( |
1072 | | - request_type=ZMQRequestType.HANDSHAKE_ACK, |
1073 | | - sender_id=self.controller_id, |
1074 | | - body={}, |
1075 | | - ).serialize() |
1076 | | - handshake_socket.send_multipart([identity, *response_msg]) |
1077 | | - |
1078 | | - # Track new connections |
1079 | | - if storage_manager_id not in self._connected_storage_managers: |
1080 | | - self._connected_storage_managers.add(storage_manager_id) |
1081 | | - storage_manager_type = request_msg.body.get("storage_manager_type", "Unknown") |
1082 | | - logger.debug( |
1083 | | - f"[{self.controller_id}]: received handshake from " |
1084 | | - f"storage manager {storage_manager_id} (type: {storage_manager_type}). " |
1085 | | - f"Total connected: {len(self._connected_storage_managers)}" |
1086 | | - ) |
1087 | | - else: |
1088 | | - logger.debug( |
1089 | | - f"[{self.controller_id}]: received duplicate handshake from " |
1090 | | - f"storage manager {storage_manager_id}. Resending ACK." |
1091 | | - ) |
| 1064 | + response_msg = ZMQMessage.create( |
| 1065 | + request_type=ZMQRequestType.HANDSHAKE_ACK, |
| 1066 | + sender_id=self.controller_id, |
| 1067 | + body={}, |
| 1068 | + ).serialize() |
| 1069 | + handshake_socket.send_multipart([identity, *response_msg]) |
| 1070 | + |
| 1071 | + if storage_manager_id not in self._connected_storage_managers: |
| 1072 | + self._connected_storage_managers.add(storage_manager_id) |
| 1073 | + storage_manager_type = request_msg.body.get("storage_manager_type", "Unknown") |
| 1074 | + logger.debug( |
| 1075 | + f"[{self.controller_id}]: received handshake from " |
| 1076 | + f"storage manager {storage_manager_id} (type: {storage_manager_type}). " |
| 1077 | + f"Total connected: {len(self._connected_storage_managers)}" |
| 1078 | + ) |
| 1079 | + else: |
| 1080 | + logger.debug( |
| 1081 | + f"[{self.controller_id}]: received duplicate handshake from " |
| 1082 | + f"storage manager {storage_manager_id}. Resending ACK." |
| 1083 | + ) |
1092 | 1084 |
|
1093 | | - except Exception as e: |
1094 | | - logger.error(f"[{self.controller_id}]: error processing handshake: {e}") |
| 1085 | + except Exception as e: |
| 1086 | + logger.error(f"[{self.controller_id}]: error processing handshake: {e}") |
1095 | 1087 |
|
1096 | 1088 | def _process_request(self): |
1097 | 1089 | """Main request processing loop - adapted for partition-based operations.""" |
@@ -1708,7 +1700,6 @@ def get_metadata( |
1708 | 1700 |
|
1709 | 1701 | elif mode == "force_fetch": |
1710 | 1702 | batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id) |
1711 | | - consumed_indexes = [] |
1712 | 1703 |
|
1713 | 1704 | # Package into metadata |
1714 | 1705 | metadata = self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode) |
|
0 commit comments