|
27 | 27 | import numpy as np |
28 | 28 | import ray |
29 | 29 | import torch |
30 | | -import zmq |
31 | 30 | from omegaconf import DictConfig |
32 | 31 | from torch import Tensor |
33 | 32 |
|
@@ -1040,51 +1039,44 @@ def _start_daemon_threads(self): |
1040 | 1039 | ) |
1041 | 1040 |
|
1042 | 1041 | def _wait_connection(self): |
1043 | | - """Wait for storage instances to complete handshake with retransmission support.""" |
| 1042 | + """Wait for storage instances to complete handshake.""" |
1044 | 1043 | handshake_socket = self._transport.get_socket("handshake_socket") |
1045 | | - poller = zmq.Poller() |
1046 | | - poller.register(handshake_socket, zmq.POLLIN) |
1047 | 1044 |
|
1048 | 1045 | logger.debug(f"Controller {self.controller_id} started waiting for storage connections...") |
1049 | 1046 |
|
1050 | 1047 | while True: |
1051 | | - socks = dict(poller.poll(1000)) |
| 1048 | + try: |
| 1049 | + messages = handshake_socket.recv_multipart(copy=False) |
| 1050 | + identity = messages.pop(0) |
| 1051 | + serialized_msg = messages |
| 1052 | + request_msg = ZMQMessage.deserialize(serialized_msg) |
1052 | 1053 |
|
1053 | | - if handshake_socket in socks: |
1054 | | - try: |
1055 | | - messages = handshake_socket.recv_multipart(copy=False) |
1056 | | - identity = messages.pop(0) |
1057 | | - serialized_msg = messages |
1058 | | - request_msg = ZMQMessage.deserialize(serialized_msg) |
| 1054 | + if request_msg.request_type == ZMQRequestType.HANDSHAKE: |
| 1055 | + storage_manager_id = request_msg.sender_id |
1059 | 1056 |
|
1060 | | - if request_msg.request_type == ZMQRequestType.HANDSHAKE: |
1061 | | - storage_manager_id = request_msg.sender_id |
1062 | | - |
1063 | | - # Always send ACK for HANDSHAKE |
1064 | | - response_msg = ZMQMessage.create( |
1065 | | - request_type=ZMQRequestType.HANDSHAKE_ACK, |
1066 | | - sender_id=self.controller_id, |
1067 | | - body={}, |
1068 | | - ).serialize() |
1069 | | - handshake_socket.send_multipart([identity, *response_msg]) |
1070 | | - |
1071 | | - # Track new connections |
1072 | | - if storage_manager_id not in self._connected_storage_managers: |
1073 | | - self._connected_storage_managers.add(storage_manager_id) |
1074 | | - storage_manager_type = request_msg.body.get("storage_manager_type", "Unknown") |
1075 | | - logger.debug( |
1076 | | - f"[{self.controller_id}]: received handshake from " |
1077 | | - f"storage manager {storage_manager_id} (type: {storage_manager_type}). " |
1078 | | - f"Total connected: {len(self._connected_storage_managers)}" |
1079 | | - ) |
1080 | | - else: |
1081 | | - logger.debug( |
1082 | | - f"[{self.controller_id}]: received duplicate handshake from " |
1083 | | - f"storage manager {storage_manager_id}. Resending ACK." |
1084 | | - ) |
| 1057 | + response_msg = ZMQMessage.create( |
| 1058 | + request_type=ZMQRequestType.HANDSHAKE_ACK, |
| 1059 | + sender_id=self.controller_id, |
| 1060 | + body={}, |
| 1061 | + ).serialize() |
| 1062 | + handshake_socket.send_multipart([identity, *response_msg]) |
| 1063 | + |
| 1064 | + if storage_manager_id not in self._connected_storage_managers: |
| 1065 | + self._connected_storage_managers.add(storage_manager_id) |
| 1066 | + storage_manager_type = request_msg.body.get("storage_manager_type", "Unknown") |
| 1067 | + logger.debug( |
| 1068 | + f"[{self.controller_id}]: received handshake from " |
| 1069 | + f"storage manager {storage_manager_id} (type: {storage_manager_type}). " |
| 1070 | + f"Total connected: {len(self._connected_storage_managers)}" |
| 1071 | + ) |
| 1072 | + else: |
| 1073 | + logger.debug( |
| 1074 | + f"[{self.controller_id}]: received duplicate handshake from " |
| 1075 | + f"storage manager {storage_manager_id}. Resending ACK." |
| 1076 | + ) |
1085 | 1077 |
|
1086 | | - except Exception as e: |
1087 | | - logger.error(f"[{self.controller_id}]: error processing handshake: {e}") |
| 1078 | + except Exception as e: |
| 1079 | + logger.error(f"[{self.controller_id}]: error processing handshake: {e}") |
1088 | 1080 |
|
1089 | 1081 | def _process_request(self): |
1090 | 1082 | """Main request processing loop - adapted for partition-based operations.""" |
@@ -1701,7 +1693,6 @@ def get_metadata( |
1701 | 1693 |
|
1702 | 1694 | elif mode == "force_fetch": |
1703 | 1695 | batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id) |
1704 | | - consumed_indexes = [] |
1705 | 1696 |
|
1706 | 1697 | # Package into metadata |
1707 | 1698 | metadata = self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode) |
|
0 commit comments