From 8efb8593122400f3c80ce055993528e0bbce2c3c Mon Sep 17 00:00:00 2001 From: enriquetomasmb Date: Mon, 16 Jun 2025 18:13:10 +0200 Subject: [PATCH 01/12] fix deadlock during disconnection --- nebula/core/network/communications.py | 85 +++++++++++++++++---------- 1 file changed, 54 insertions(+), 31 deletions(-) diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index d65cc2540..e9a4ccec7 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -35,10 +35,10 @@ class CommunicationsManager: - Handling and dispatching incoming messages to the appropriate handlers. - Preventing message duplication via message hash tracking. - It acts as a central coordinator for message-based interactions and is + It acts as a central coordinator for message-based interactions and is designed to work asynchronously to support non-blocking network operations. """ - + _instance = None _lock = Locker("communications_manager_lock", async_lock=False) @@ -143,7 +143,7 @@ def discoverer(self): def health(self): """ Returns the HealthMonitor component that checks and maintains node health status. - """ + """ return self._health @property @@ -157,7 +157,7 @@ def forwarder(self): def propagator(self): """ Returns the component responsible for propagating messages throughout the network. - """ + """ return self._propagator @property @@ -529,11 +529,12 @@ async def handle_connection(self, reader, writer, priority="medium"): Wrapper coroutine to handle a new incoming connection. Schedules the actual connection handling coroutine as an asyncio task. - + Args: reader (asyncio.StreamReader): Stream reader for the connection. writer (asyncio.StreamWriter): Stream writer for the connection. """ + async def process_connection(reader, writer, priority="medium"): """ Handles the lifecycle of a new incoming connection, including validation, authorization, @@ -790,7 +791,7 @@ async def send_message_to_neighbors(self, message, neighbors=None, interval=0): Args: message (Any): The message to send. - neighbors (set, optional): A set of neighbor addresses to send the message to. + neighbors (set, optional): A set of neighbor addresses to send the message to. If None, the message is sent to all direct neighbors. interval (float, optional): Delay in seconds between sending the message to each neighbor. """ @@ -1057,52 +1058,74 @@ async def disconnect(self, dest_addr, mutual_disconnection=True, forced=False): """ Disconnects from a specified destination address and performs cleanup tasks. - Optionally sends a mutual disconnection message to the peer, adds the address to the blacklist - if the disconnection is forced, and updates the list of current neighbors accordingly. - Args: dest_addr (str): The address of the node to disconnect from. mutual_disconnection (bool, optional): Whether to notify the peer about the disconnection. Defaults to True. forced (bool, optional): If True, the destination address will be blacklisted. Defaults to False. """ - removed = False + logging.info(f"Trying to disconnect {dest_addr}") + + # Check if this is a direct neighbor before proceeding is_neighbor = dest_addr in await self.get_addrs_current_connections(only_direct=True, myself=True) + # Add to blacklist if forced disconnection if forced: await self.add_to_blacklist(dest_addr) - logging.info(f"Trying to disconnect {dest_addr}") + # Get the connection under lock to prevent race conditions async with self.connections_lock: if dest_addr not in self.connections: logging.info(f"Connection {dest_addr} not found") return + conn = self.connections[dest_addr] + try: + # Attempt mutual disconnection if requested if mutual_disconnection: - await self.connections[dest_addr].send(data=self.create_message("connection", "disconnect")) - await asyncio.sleep(1) + try: + await conn.send(data=self.create_message("connection", "disconnect")) + async with self.connections_lock: + if dest_addr in self.connections: + self.connections.pop(dest_addr) + await conn.stop() + except Exception as e: + logging.warning(f"Failed to send disconnect message to {dest_addr}: {e!s}") + # Ensure connection is removed even if message sending fails + async with self.connections_lock: + if dest_addr in self.connections: + self.connections.pop(dest_addr) + await conn.stop() + else: + # For non-mutual disconnection, just stop and remove async with self.connections_lock: - conn = self.connections.pop(dest_addr) + if dest_addr in self.connections: + self.connections.pop(dest_addr) await conn.stop() + + # Update configuration and neighbors + current_connections = await self.get_all_addrs_current_connections(only_direct=True) + current_connections = set(current_connections) + logging.info(f"Current connections after disconnection: {current_connections}") + + # Update configuration + self.config.update_neighbors_from_config(current_connections, dest_addr) + + # Update engine if this was a direct neighbor + if is_neighbor: + current_connections = await self.get_addrs_current_connections(only_direct=True, myself=True) + await self.engine.update_neighbors(dest_addr, current_connections, remove=True) + except Exception as e: - logging.exception(f"❗️ Error while disconnecting {dest_addr}: {e!s}") - if dest_addr in self.connections: - logging.info(f"Removing {dest_addr} from connections") + logging.exception(f"Error during disconnection of {dest_addr}: {e!s}") + # Ensure connection is removed even if there's an error + async with self.connections_lock: + if dest_addr in self.connections: + self.connections.pop(dest_addr) try: - removed = True - async with self.connections_lock: - conn = self.connections.pop(dest_addr) await conn.stop() - except Exception as e: - logging.exception(f"❗️ Error while removing connection {dest_addr}: {e!s}") - current_connections = await self.get_all_addrs_current_connections(only_direct=True) - current_connections = set(current_connections) - logging.info(f"Current connections: {current_connections}") - self.config.update_neighbors_from_config(current_connections, dest_addr) - - if removed: - current_connections = await self.get_addrs_current_connections(only_direct=True, myself=True) - if is_neighbor: - await self.engine.update_neighbors(dest_addr, current_connections, remove=removed) + except Exception as stop_error: + logging.warning(f"Error stopping connection during cleanup: {stop_error!s}") + raise async def get_all_addrs_current_connections(self, only_direct=False, only_undirected=False): """ From cd3fba8d9702720f97f783a4cbf70bb38e03468a Mon Sep 17 00:00:00 2001 From: enriquetomasmb Date: Tue, 17 Jun 2025 09:11:23 +0200 Subject: [PATCH 02/12] remove unnecessary reconnection attempts after the experiment finishes --- nebula/core/engine.py | 13 +++++++++++-- nebula/core/network/connection.py | 18 ++++++++++++++---- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index fd0a24ffe..9d00fd93d 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -687,7 +687,7 @@ async def _waiting_model_updates(self): - Logs an error indicating aggregation failure. This method is called after local training and before proceeding to the next round, - ensuring the model is synchronized with the federation’s latest aggregated state. + ensuring the model is synchronized with the federation's latest aggregated state. """ logging.info(f"💤 Waiting convergence in round {self.round}.") params = await self.aggregator.get_aggregation() @@ -710,7 +710,7 @@ def learning_cycle_finished(self): if not self.round or not self.total_rounds: return False else: - return (self.round < self.total_rounds) + return self.round < self.total_rounds async def _learning_cycle(self): """ @@ -768,6 +768,15 @@ async def _learning_cycle(self): indent=2, title="Round information", ) + + # Removing random neighbor (if I am the starter) + if self.config.participant["device_args"]["start"]: + random_neighbor = random.choice(list(direct_connections)) + try: + await self.cm.disconnect(random_neighbor, mutual_disconnection=True) + except Exception as e: + logging.error(f"Error disconnecting from {random_neighbor}: {e}") + # await self.aggregator.reset() self.trainer.on_round_end() self.round += 1 diff --git a/nebula/core/network/connection.py b/nebula/core/network/connection.py index 29a79da5e..6ac74d153 100755 --- a/nebula/core/network/connection.py +++ b/nebula/core/network/connection.py @@ -288,6 +288,11 @@ async def reconnect(self, max_retries: int = 5, delay: int = 5) -> None: logging.info("Not going to reconnect because this connection is not direct") return + # Check if learning cycle has finished to prevent unnecessary reconnection attempts + if self.cm.learning_finished(): + logging.info(f"Not attempting reconnection to {self.addr} because learning cycle has finished") + return + self.incompleted_reconnections += 1 if self.incompleted_reconnections == MAX_INCOMPLETED_RECONNECTIONS: logging.info(f"Reconnection with {self.addr} failed...") @@ -361,8 +366,10 @@ async def send( await self._send_chunks(message_id, data_to_send) except Exception as e: logging.exception(f"Error sending data: {e}") - if self.direct: + if self.direct and not self.cm.learning_finished(): await self.reconnect() + elif self.cm.learning_finished(): + logging.info(f"Not attempting reconnection to {self.addr} because learning cycle has finished") def _prepare_data(self, data: Any, pb: bool, encoding_type: str) -> tuple[bytes, bytes]: """ @@ -486,11 +493,14 @@ async def handle_incoming_message(self) -> None: logging.exception(f"Connection closed while reading: {e}") except Exception as e: logging.exception(f"Error handling incoming message: {e}") - except BrokenPipeError: - logging.exception(f"Error handling incoming message: {e}") + except BrokenPipeError as e: + logging.exception(f"Broken pipe error handling incoming message: {e}") finally: - if self.direct or self._prio == ConnectionPriority.HIGH: + # Only attempt reconnection if the learning cycle hasn't finished and the connection is direct or high priority + if (self.direct or self._prio == ConnectionPriority.HIGH) and not self.cm.learning_finished(): await self.reconnect() + elif self.cm.learning_finished(): + logging.info(f"Not attempting reconnection to {self.addr} because learning cycle has finished") async def _read_exactly(self, num_bytes: int, max_retries: int = 3) -> bytes: """ From f6a80607f970e462db0584bce85cd8c39dcdecde Mon Sep 17 00:00:00 2001 From: enriquetomasmb Date: Tue, 17 Jun 2025 09:13:11 +0200 Subject: [PATCH 03/12] remove debug code --- nebula/core/engine.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 9d00fd93d..710c64214 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -769,14 +769,6 @@ async def _learning_cycle(self): title="Round information", ) - # Removing random neighbor (if I am the starter) - if self.config.participant["device_args"]["start"]: - random_neighbor = random.choice(list(direct_connections)) - try: - await self.cm.disconnect(random_neighbor, mutual_disconnection=True) - except Exception as e: - logging.error(f"Error disconnecting from {random_neighbor}: {e}") - # await self.aggregator.reset() self.trainer.on_round_end() self.round += 1 From 88449ab26dc1073ac69f6ddc4b41fabbc0ac2ccb Mon Sep 17 00:00:00 2001 From: enriquetomasmb Date: Tue, 17 Jun 2025 11:29:07 +0200 Subject: [PATCH 04/12] fix race conditions and soem issues --- nebula/core/aggregation/aggregator.py | 37 +++++++++++++++++++-------- nebula/core/engine.py | 24 ++++++++++++++++- 2 files changed, 50 insertions(+), 11 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 4b4600f71..b9f94066d 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -56,25 +56,37 @@ async def update_federation_nodes(self, federation_nodes: set): This method informs the update handler (`us`) about the new set of federation nodes, clears any pending models, and attempts to acquire the aggregation lock to prepare - for model aggregation. If the aggregation process is already running, it raises an exception. + for model aggregation. If the aggregation process is already running, it releases the lock + and tries again to ensure proper cleanup between rounds. Args: federation_nodes (set): A set of addresses representing the nodes expected to contribute updates for the next aggregation round. Raises: - Exception: If the aggregation process is already running and the lock is currently held. + Exception: If the aggregation process is already running and the lock cannot be released. """ await self.us.round_expected_updates(federation_nodes=federation_nodes) - if not self._aggregation_done_lock.locked(): - self._federation_nodes = federation_nodes - self._pending_models_to_aggregate.clear() - await self._aggregation_done_lock.acquire_async( - timeout=self.config.participant["aggregator_args"]["aggregation_timeout"] - ) - else: - raise Exception("It is not possible to set nodes to aggregate when the aggregation is running.") + # If the aggregation lock is held, release it to prepare for the new round + if self._aggregation_done_lock.locked(): + logging.info("🔄 update_federation_nodes | Aggregation lock is held, releasing for new round") + try: + await self._aggregation_done_lock.release_async() + except Exception as e: + logging.warning(f"🔄 update_federation_nodes | Error releasing aggregation lock: {e}") + # If we can't release the lock, we might be in the middle of aggregation + # In this case, we should wait a bit and try again + await asyncio.sleep(0.1) + if self._aggregation_done_lock.locked(): + raise Exception("It is not possible to set nodes to aggregate when the aggregation is running.") + + # Now acquire the lock for the new round + self._federation_nodes = federation_nodes + self._pending_models_to_aggregate.clear() + await self._aggregation_done_lock.acquire_async( + timeout=self.config.participant["aggregator_args"]["aggregation_timeout"] + ) def get_nodes_pending_models_to_aggregate(self): return self._federation_nodes @@ -97,6 +109,11 @@ async def get_aggregation(self): asyncio.CancelledError: If the aggregation lock acquisition is cancelled. Exception: For any other unexpected errors during the aggregation process. """ + # Check if learning cycle has finished to prevent blocking + if not self.engine.learning_cycle_finished(): + logging.info("🔄 get_aggregation | Learning cycle has finished, skipping aggregation") + return None + try: timeout = self.config.participant["aggregator_args"]["aggregation_timeout"] logging.info(f"Aggregation timeout: {timeout} starts...") diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 710c64214..1134deb6a 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -253,6 +253,17 @@ async def model_initialization_callback(self, source, message): async def model_update_callback(self, source, message): logging.info(f"🤖 handle_model_message | Received model update from {source} with round {message.round}") + + # Ignore model updates if the learning cycle has finished + if self.learning_cycle_finished(): + logging.info(f"🤖 Ignoring model update from {source} because learning cycle has finished") + return + + # Ignore updates from different rounds + if message.round != self.round: + logging.info(f"🤖 Ignoring model update from {source} because it's from round {message.round} but we're in round {self.round}") + return + if not self.get_federation_ready_lock().locked() and len(await self.get_federation_nodes()) == 0: logging.info("🤖 handle_model_message | There are no defined federation nodes") return @@ -348,6 +359,12 @@ async def _federation_federation_start_callback(self, source, message): async def _federation_federation_models_included_callback(self, source, message): logging.info(f"📝 handle_federation_message | Trigger | Received aggregation finished message from {source}") + + # Ignore aggregation finished messages if the learning cycle has finished + if not self.learning_cycle_finished(): + logging.info(f"📝 Ignoring aggregation finished message from {source} because learning cycle has finished") + return + try: await self.cm.get_connections_lock().acquire_async() if self.round is not None and source in self.cm.connections: @@ -448,6 +465,11 @@ async def broadcast_models_include(self, age: AggregationEvent): Sends: federation_models_included: A message containing the round number of the aggregation. """ + # Don't broadcast if the learning cycle has finished + if not self.learning_cycle_finished(): + logging.info(f"🔄 Not broadcasting MODELS_INCLUDED because learning cycle has finished") + return + logging.info(f"🔄 Broadcasting MODELS_INCLUDED for round {self.get_round()}") message = self.cm.create_message( "federation", "federation_models_included", [str(arg) for arg in [self.get_round()]] @@ -710,7 +732,7 @@ def learning_cycle_finished(self): if not self.round or not self.total_rounds: return False else: - return self.round < self.total_rounds + return self.round >= self.total_rounds async def _learning_cycle(self): """ From 9de5e6a4aa0fcd1e3b9b25d377eb5c3c99533981 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 17 Jun 2025 12:37:48 +0200 Subject: [PATCH 05/12] fix aggregation skipped --- nebula/core/aggregation/aggregator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index b9f94066d..0648b5b42 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -110,7 +110,7 @@ async def get_aggregation(self): Exception: For any other unexpected errors during the aggregation process. """ # Check if learning cycle has finished to prevent blocking - if not self.engine.learning_cycle_finished(): + if self.engine.learning_cycle_finished(): logging.info("🔄 get_aggregation | Learning cycle has finished, skipping aggregation") return None From 38c2263ee0e225a618ed2065defc16c8444185a1 Mon Sep 17 00:00:00 2001 From: enriquetomasmb Date: Tue, 17 Jun 2025 13:49:15 +0200 Subject: [PATCH 06/12] fix issues in blacklist --- nebula/core/network/blacklist.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/nebula/core/network/blacklist.py b/nebula/core/network/blacklist.py index cf95340d2..b239bc07e 100644 --- a/nebula/core/network/blacklist.py +++ b/nebula/core/network/blacklist.py @@ -80,7 +80,9 @@ async def add_to_blacklist(self, addr): asyncio.create_task(self._start_blacklist_cleaner()) await self._blacklisted_nodes_lock.release_async() nbe = NodeBlacklistedEvent(addr, blacklisted=True) - asyncio.create_task(EventManager.get_instance().publish_node_event(nbe)) + event_manager = EventManager.get_instance() + if event_manager is not None: + asyncio.create_task(event_manager.publish_node_event(nbe)) async def get_blacklist(self) -> set: """ @@ -94,7 +96,7 @@ async def get_blacklist(self) -> set: if self._blacklisted_nodes: bl = set(self._blacklisted_nodes.keys()) await self._blacklisted_nodes_lock.release_async() - return bl + return bl or set() async def clear_blacklist(self): """ @@ -192,21 +194,23 @@ async def add_recently_disconnected(self, addr): addr (str): Address of the disconnected node. """ logging.info(f"Recently disconnected from: {addr}") - self._recently_disconnected_lock.acquire_async() + await self._recently_disconnected_lock.acquire_async() self._recently_disconnected.add(addr) - self._recently_disconnected_lock.release_async() + await self._recently_disconnected_lock.release_async() asyncio.create_task(self._remove_recently_disc(addr)) nbe = NodeBlacklistedEvent(addr) - asyncio.create_task(EventManager.get_instance().publish_node_event(nbe)) + event_manager = EventManager.get_instance() + if event_manager is not None: + asyncio.create_task(event_manager.publish_node_event(nbe)) async def clear_recently_disconected(self): """ Clears the list of recently disconnected nodes. """ - self._recently_disconnected_lock.acquire_async() + await self._recently_disconnected_lock.acquire_async() logging.info("🧹 Removing nodes from Recently Disconencted list") self._recently_disconnected.clear() - self._recently_disconnected_lock.release_async() + await self._recently_disconnected_lock.release_async() async def get_recently_disconnected(self): """ @@ -216,9 +220,9 @@ async def get_recently_disconnected(self): set: Addresses of recently disconnected nodes. """ rd = None - self._recently_disconnected_lock.acquire_async() + await self._recently_disconnected_lock.acquire_async() rd = self._recently_disconnected.copy() - self._recently_disconnected_lock.release_async() + await self._recently_disconnected_lock.release_async() return rd async def _remove_recently_disc(self, addr): @@ -229,10 +233,10 @@ async def _remove_recently_disc(self, addr): addr (str): Address to remove after expiration. """ await asyncio.sleep(RECENTLY_DISCONNECTED_EXPIRE_TIME) - self._recently_disconnected_lock.acquire_async() + await self._recently_disconnected_lock.acquire_async() self._recently_disconnected.discard(addr) logging.info(f"Recently disconnection timeout expired for souce: {addr}") - self._recently_disconnected_lock.release_async() + await self._recently_disconnected_lock.release_async() async def verify_not_recently_disc(self, nodes: set) -> set | None: """ @@ -247,10 +251,10 @@ async def verify_not_recently_disc(self, nodes: set) -> set | None: if not nodes: return None nodes_not_listed = nodes - self._recently_disconnected_lock.acquire_async() + await self._recently_disconnected_lock.acquire_async() rec_disc = self._recently_disconnected # logging.info(f"recently disconencted nodes: {rec_disc}") if rec_disc: nodes_not_listed = nodes.difference(rec_disc) - self._recently_disconnected_lock.release_async() + await self._recently_disconnected_lock.release_async() return nodes_not_listed From 5c85e9f3e3fba866d276529fd3f36dbeefd23e9a Mon Sep 17 00:00:00 2001 From: enriquetomasmb Date: Tue, 17 Jun 2025 14:32:57 +0200 Subject: [PATCH 07/12] remove unnecesary condition --- nebula/core/engine.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 1134deb6a..b94f0f81a 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -253,17 +253,12 @@ async def model_initialization_callback(self, source, message): async def model_update_callback(self, source, message): logging.info(f"🤖 handle_model_message | Received model update from {source} with round {message.round}") - + # Ignore model updates if the learning cycle has finished if self.learning_cycle_finished(): logging.info(f"🤖 Ignoring model update from {source} because learning cycle has finished") return - - # Ignore updates from different rounds - if message.round != self.round: - logging.info(f"🤖 Ignoring model update from {source} because it's from round {message.round} but we're in round {self.round}") - return - + if not self.get_federation_ready_lock().locked() and len(await self.get_federation_nodes()) == 0: logging.info("🤖 handle_model_message | There are no defined federation nodes") return @@ -359,12 +354,12 @@ async def _federation_federation_start_callback(self, source, message): async def _federation_federation_models_included_callback(self, source, message): logging.info(f"📝 handle_federation_message | Trigger | Received aggregation finished message from {source}") - + # Ignore aggregation finished messages if the learning cycle has finished if not self.learning_cycle_finished(): logging.info(f"📝 Ignoring aggregation finished message from {source} because learning cycle has finished") return - + try: await self.cm.get_connections_lock().acquire_async() if self.round is not None and source in self.cm.connections: @@ -469,7 +464,7 @@ async def broadcast_models_include(self, age: AggregationEvent): if not self.learning_cycle_finished(): logging.info(f"🔄 Not broadcasting MODELS_INCLUDED because learning cycle has finished") return - + logging.info(f"🔄 Broadcasting MODELS_INCLUDED for round {self.get_round()}") message = self.cm.create_message( "federation", "federation_models_included", [str(arg) for arg in [self.get_round()]] From 7c9a465ef4c7732d132565e1ecc35bd25f6ffb2e Mon Sep 17 00:00:00 2001 From: enriquetomasmb Date: Wed, 18 Jun 2025 09:26:47 +0200 Subject: [PATCH 08/12] improve cleaning when the federation ends --- nebula/config/config.py | 25 +++++++++++++++++++-- nebula/core/engine.py | 50 ++++++++++++++++++++++++++++++++++------- 2 files changed, 65 insertions(+), 10 deletions(-) diff --git a/nebula/config/config.py b/nebula/config/config.py index c15324712..cae3cf7f8 100755 --- a/nebula/config/config.py +++ b/nebula/config/config.py @@ -48,14 +48,35 @@ def get_participant_config(self): def get_train_logging_config(self): # TBD pass - + def reset_logging_configuration(self): for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) - + self.__set_default_logging(mode="a") self.__set_training_logging(mode="a") + def shutdown_logging(self): + """ + Properly shuts down all loggers and their handlers in the system. + This ensures all buffered logs are written to their respective files. + """ + for handler in logging.getLogger().handlers: + handler.flush() + handler.close() + + training_logger = logging.getLogger(TRAINING_LOGGER) + for handler in training_logger.handlers: + handler.flush() + handler.close() + + pl_logger = logging.getLogger("lightning.pytorch") + for handler in pl_logger.handlers: + handler.flush() + handler.close() + + logging.shutdown() + def __default_config(self): self.participant["device_args"]["name"] = ( f"participant_{self.participant['device_args']['idx']}_{self.participant['network_args']['ip']}_{self.participant['network_args']['port']}" diff --git a/nebula/core/engine.py b/nebula/core/engine.py index b94f0f81a..5d9e286f1 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -809,22 +809,56 @@ async def _learning_cycle(self): ) # Report if self.config.participant["scenario_args"]["controller"] != "nebula-test": - result = await self.reporter.report_scenario_finished() - if result: - logging.info("📝 Scenario finished reported succesfully") - else: - logging.error("📝 Error reporting scenario finished") + try: + result = await self.reporter.report_scenario_finished() + if result: + logging.info("📝 Scenario finished reported successfully") + else: + logging.error("📝 Error reporting scenario finished") + except Exception as e: + logging.error(f"📝 Error during scenario finish report: {e}") + + # Get all tasks except the current one + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + + logging.info("Starting graceful shutdown process...") + + model_tasks = [t for t in tasks if any(name in t.get_name().lower() for name in ["model", "aggregation"])] + if model_tasks: + logging.info("Waiting for model and aggregation tasks to complete...") + try: + await asyncio.wait_for(asyncio.gather(*model_tasks, return_exceptions=True), timeout=15) + except asyncio.TimeoutError: + logging.warning("Model tasks did not complete in time") + + other_tasks = [t for t in tasks if t not in model_tasks] + if other_tasks: + logging.info("Waiting for remaining tasks to complete...") + try: + await asyncio.wait_for(asyncio.gather(*other_tasks, return_exceptions=True), timeout=15) + except asyncio.TimeoutError: + logging.warning("Some tasks did not complete in time, forcing cancellation...") + for task in other_tasks: + if not task.done(): + task.cancel() + await asyncio.gather(*other_tasks, return_exceptions=True) + + # Remove all pending tasks + asyncio.all_tasks().clear() + + # Shutdown all logging handlers + self.config.shutdown_logging() - await asyncio.sleep(5) + # From here, logging is disabled + print("Shutdown complete. Terminating NEBULA CORE...") # Kill itself if self.config.participant["scenario_args"]["deployment"] == "docker": try: docker_id = socket.gethostname() - logging.info(f"📦 Killing docker container with ID {docker_id}") self.client.containers.get(docker_id).kill() except Exception as e: - logging.exception(f"📦 Error stopping Docker container with ID {docker_id}: {e}") + print(f"Error stopping Docker container with ID {docker_id}: {e}") async def _extended_learning_cycle(self): """ From 07e9d4c311bbcfb9e5e02f101838a131b8bf71bc Mon Sep 17 00:00:00 2001 From: enriquetomasmb Date: Wed, 18 Jun 2025 12:57:47 +0200 Subject: [PATCH 09/12] fix stop from frontend --- nebula/core/network/connection.py | 3 --- nebula/frontend/app.py | 17 ++++++++--------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/nebula/core/network/connection.py b/nebula/core/network/connection.py index 6ac74d153..d44ef0447 100755 --- a/nebula/core/network/connection.py +++ b/nebula/core/network/connection.py @@ -123,9 +123,6 @@ def __str__(self): def __repr__(self): return self.__str__() - async def __del__(self): - await self.stop() - @property def cm(self): """Communication Manager""" diff --git a/nebula/frontend/app.py b/nebula/frontend/app.py index abc403abd..cf3820ee7 100755 --- a/nebula/frontend/app.py +++ b/nebula/frontend/app.py @@ -1688,8 +1688,8 @@ async def nebula_update_node(scenario_name: str, request: Request): @app.post("/platform/dashboard/{scenario_name}/node/done") async def node_stopped(scenario_name: str, request: Request): """ - Handle notification that a node has finished its task; mark the node as finished, - stop the scenario if all nodes are done, and signal scenario completion. + Handle notification that a node has finished its task; mark the node as finished + and signal scenario completion when all nodes are done. Parameters: scenario_name (str): Name of the scenario. @@ -1708,14 +1708,13 @@ async def node_stopped(scenario_name: str, request: Request): data = await request.json() user_data.nodes_finished.append(data["idx"]) nodes_list = await list_nodes_by_scenario_name(scenario_name) - finished = True - # Check if all the nodes of the scenario have finished the experiment - for node in nodes_list: - if str(node["idx"]) not in map(str, user_data.nodes_finished): - finished = False - if finished: - await stop_scenario_by_name(scenario_name, user) + # Check if all nodes have finished by comparing sets of node IDs + finished_node_ids = set(map(str, user_data.nodes_finished)) + all_node_ids = {str(node["idx"]) for node in nodes_list} + all_nodes_finished = finished_node_ids >= all_node_ids + + if all_nodes_finished: user_data.nodes_finished.clear() user_data.finish_scenario_event.set() return JSONResponse( From c069e1bf0653dac110474768072b78eb822d9f5b Mon Sep 17 00:00:00 2001 From: FerTV Date: Wed, 18 Jun 2025 13:39:58 +0200 Subject: [PATCH 10/12] add kill tasks --- nebula/addons/reporter.py | 6 +++++- nebula/core/engine.py | 26 ++++++++++++++++++++++++++ nebula/core/network/communications.py | 4 ++++ nebula/core/network/forwarder.py | 6 +++++- 4 files changed, 40 insertions(+), 2 deletions(-) diff --git a/nebula/addons/reporter.py b/nebula/addons/reporter.py index 0a8e1426c..bf16a4b44 100755 --- a/nebula/addons/reporter.py +++ b/nebula/addons/reporter.py @@ -67,6 +67,7 @@ def __init__(self, config, trainer): self.acc_bytes_recv = 0 self.acc_packets_sent = 0 self.acc_packets_recv = 0 + self.scenario_finished = asyncio.Event() @property def cm(self): @@ -132,7 +133,7 @@ async def run_reporter(self): Notes: - The reporting frequency is determined by the 'report_frequency' setting in the config file. """ - while True: + while not self.scenario_finished.is_set(): if self.config.participant["reporter_args"]["report_status_data_queue"]: if self.config.participant["scenario_args"]["controller"] != "nebula-test": await self.__report_status_to_controller() @@ -193,6 +194,9 @@ async def report_scenario_finished(self): except aiohttp.ClientError: logging.exception(f"Error connecting to the controller at {url}") return False + + def shutdown(self): + self.scenario_finished.set() async def __report_data_queue(self): """ diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 5d9e286f1..3269868db 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -813,6 +813,7 @@ async def _learning_cycle(self): result = await self.reporter.report_scenario_finished() if result: logging.info("📝 Scenario finished reported successfully") + self.reporter.shutdown() else: logging.error("📝 Error reporting scenario finished") except Exception as e: @@ -820,9 +821,34 @@ async def _learning_cycle(self): # Get all tasks except the current one tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + # Cut connections + await self.cm.stop() logging.info("Starting graceful shutdown process...") + + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + current_task = asyncio.current_task() + all_tasks = asyncio.all_tasks() + + # Log basic task info + logging.info("Task Summary:") + + # Log task names if available + if current_task: + logging.info(f" • Current task: {current_task}") + + for task in all_tasks: + logging.info(f" • Task: {task}") + logging.info(f" • Task name: {task.get_name()}") + logging.info(f" • Task state: {task.get_state()}") + logging.info(f" • Task coroutine: {task.get_coro()}") + logging.info(f" • Task done: {task.done()}") + logging.info(f" • Task cancelled: {task.cancelled()}") + logging.info(f" • Task exception: {task.exception()}") + logging.info(f" • Task result: {task.result()}") + + model_tasks = [t for t in tasks if any(name in t.get_name().lower() for name in ["model", "aggregation"])] if model_tasks: logging.info("Waiting for model and aggregation tasks to complete...") diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index e9a4ccec7..274b586f1 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -691,6 +691,10 @@ async def stop(self): self.network_engine.close() await self.network_engine.wait_closed() self.network_task.cancel() + if self.forwarder: + self.forwarder.shutdown() + if self.ecs: + await self.ecs.stop() async def run_reconnections(self): for connection in self.connections_reconnect: diff --git a/nebula/core/network/forwarder.py b/nebula/core/network/forwarder.py index a5862a471..0c51505b2 100755 --- a/nebula/core/network/forwarder.py +++ b/nebula/core/network/forwarder.py @@ -39,6 +39,7 @@ def __init__(self, config): self.interval = self.config.participant["forwarder_args"]["forwarder_interval"] self.number_forwarded_messages = self.config.participant["forwarder_args"]["number_forwarded_messages"] self.messages_interval = self.config.participant["forwarder_args"]["forward_messages_interval"] + self.scenario_finished = asyncio.Event() @property def cm(self): @@ -72,7 +73,7 @@ async def run_forwarder(self): if self.config.participant["scenario_args"]["federation"] == "CFL": logging.info("🔁 Federation is CFL. Forwarder is disabled...") return - while True: + while not self.scenario_finished.is_set(): # logging.debug(f"🔁 Pending messages: {self.pending_messages.qsize()}") start_time = time.time() await self.pending_messages_lock.acquire_async() @@ -80,6 +81,9 @@ async def run_forwarder(self): await self.pending_messages_lock.release_async() sleep_time = max(0, self.interval - (time.time() - start_time)) await asyncio.sleep(sleep_time) + + def shutdown(self): + self.scenario_finished.set() async def process_pending_messages(self, messages_left): """ From 9c87b20a2733f1025cf4e87d66fffc33680e654f Mon Sep 17 00:00:00 2001 From: enriquetomasmb Date: Thu, 19 Jun 2025 09:56:38 +0200 Subject: [PATCH 11/12] create a shutdown system, improve readability and management --- nebula/addons/gps/nebulagps.py | 77 +++++++-- nebula/addons/mobility.py | 48 ++++-- .../nebulanetworksimulator.py | 72 ++++---- nebula/addons/reporter.py | 28 ++- nebula/controller/database.py | 19 ++- nebula/core/addonmanager.py | 24 ++- nebula/core/engine.py | 161 ++++++++++-------- nebula/core/network/blacklist.py | 69 ++++++-- nebula/core/network/communications.py | 69 ++++++-- nebula/core/network/connection.py | 52 +++--- nebula/core/network/discoverer.py | 11 +- .../externalconnectionservice.py | 5 +- .../nebuladiscoveryservice.py | 80 ++++++--- nebula/core/network/forwarder.py | 31 +++- nebula/core/network/health.py | 11 +- nebula/core/network/propagator.py | 17 +- nebula/core/node.py | 21 +-- .../awareness/sanetwork/sanetwork.py | 70 +++++++- .../awareness/sareasoner.py | 88 +++++----- .../discovery/federationconnector.py | 110 +++++++++--- .../situationalawareness.py | 24 ++- nebula/frontend/static/js/monitor/monitor.js | 84 +++++---- nebula/frontend/templates/dashboard.html | 22 +-- nebula/frontend/templates/monitor.html | 6 +- 24 files changed, 828 insertions(+), 371 deletions(-) diff --git a/nebula/addons/gps/nebulagps.py b/nebula/addons/gps/nebulagps.py index 8a310c561..d2dcc6f98 100644 --- a/nebula/addons/gps/nebulagps.py +++ b/nebula/addons/gps/nebulagps.py @@ -19,16 +19,17 @@ def __init__(self, config, addr, update_interval: float = 5.0, verbose=False): self._config = config self._addr = addr self.update_interval = update_interval # Frequency - self.running = False self._node_locations = {} # Dictionary for storing node locations self._broadcast_socket = None self._nodes_location_lock = Locker("nodes_location_lock", async_lock=True) self._verbose = verbose + self._running = asyncio.Event() + self._background_tasks = [] # Track background tasks async def start(self): """Starts the GPS service, sending and receiving locations.""" logging.info("Starting NebulaGPS service...") - self.running = True + self._running.set() # Create broadcast socket self._broadcast_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -38,20 +39,39 @@ async def start(self): self._broadcast_socket.bind(("", self.BROADCAST_PORT)) # Start sending and receiving tasks - asyncio.create_task(self._send_location_loop()) - asyncio.create_task(self._receive_location_loop()) - asyncio.create_task(self._notify_geolocs()) + self._background_tasks = [ + asyncio.create_task(self._send_location_loop(), name="NebulaGPS_send_location"), + asyncio.create_task(self._receive_location_loop(), name="NebulaGPS_receive_location"), + asyncio.create_task(self._notify_geolocs(), name="NebulaGPS_notify_geolocs"), + ] async def stop(self): """Stops the GPS service.""" - logging.info("Stopping NebulaGPS service...") - self.running = False + logging.info("🛑 Stopping NebulaGPS service...") + self._running.clear() + logging.info("🛑 NebulaGPS _running event cleared") + + # Cancel all background tasks + if self._background_tasks: + logging.info(f"🛑 Cancelling {len(self._background_tasks)} background tasks...") + for task in self._background_tasks: + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + self._background_tasks.clear() + logging.info("🛑 All background tasks cancelled") + if self._broadcast_socket: self._broadcast_socket.close() self._broadcast_socket = None + logging.info("🛑 NebulaGPS broadcast socket closed") + logging.info("✅ NebulaGPS service stopped successfully") async def is_running(self): - return self.running + return self._running.is_set() async def get_geoloc(self): latitude = self._config.participant["mobility_args"]["latitude"] @@ -64,7 +84,18 @@ async def calculate_distance(self, self_lat, self_long, other_lat, other_long): async def _send_location_loop(self): """Send the geolocation periodically by broadcast.""" - while self.running: + while await self.is_running(): + # Check if learning cycle has finished + try: + from nebula.core.network.communications import CommunicationsManager + + cm = CommunicationsManager.get_instance() + if cm.learning_finished(): + logging.info("GPS: Learning cycle finished, stopping location broadcast") + break + except Exception: + pass # If we can't get the communications manager, continue + latitude, longitude = await self.get_geoloc() # Obtener ubicación actual message = f"GPS-UPDATE {self._addr} {latitude} {longitude}" self._broadcast_socket.sendto(message.encode(), (self.BROADCAST_IP, self.BROADCAST_PORT)) @@ -74,7 +105,18 @@ async def _send_location_loop(self): async def _receive_location_loop(self): """Listens to and stores geolocations from other nodes.""" - while self.running: + while await self.is_running(): + # Check if learning cycle has finished + try: + from nebula.core.network.communications import CommunicationsManager + + cm = CommunicationsManager.get_instance() + if cm.learning_finished(): + logging.info("GPS: Learning cycle finished, stopping location reception") + break + except Exception: + pass # If we can't get the communications manager, continue + try: data, addr = await asyncio.get_running_loop().run_in_executor( None, self._broadcast_socket.recvfrom, 1024 @@ -91,7 +133,18 @@ async def _receive_location_loop(self): logging.exception(f"Error receiving GPS update: {e}") async def _notify_geolocs(self): - while True: + while await self.is_running(): + # Check if learning cycle has finished + try: + from nebula.core.network.communications import CommunicationsManager + + cm = CommunicationsManager.get_instance() + if cm.learning_finished(): + logging.info("GPS: Learning cycle finished, stopping geolocation notifications") + break + except Exception: + pass # If we can't get the communications manager, continue + await asyncio.sleep(self.update_interval) await self._nodes_location_lock.acquire_async() geolocs: dict = self._node_locations.copy() @@ -102,7 +155,7 @@ async def _notify_geolocs(self): for addr, (lat, long) in geolocs.items(): dist = await self.calculate_distance(self_lat, self_long, lat, long) distances[addr] = (dist, (lat, long)) - + self._config.update_nodes_distance(distances) gpsevent = GPSEvent(distances) asyncio.create_task(EventManager.get_instance().publish_addonevent(gpsevent)) diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index ad3bfd9f5..c4b9e00c3 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -50,24 +50,28 @@ def __init__(self, config, verbose=False): """ logging.info("Starting mobility module...") self.config = config - self.grace_time = self.config.participant["mobility_args"]["grace_time_mobility"] - self.period = self.config.participant["mobility_args"]["change_geo_interval"] + self._verbose = verbose + self._running = asyncio.Event() + self._nodes_distances = {} + self._nodes_distances_lock = Locker("nodes_distances_lock", async_lock=True) + self._mobility_task = None # Track the background task + + # Mobility configuration self.mobility = self.config.participant["mobility_args"]["mobility"] self.mobility_type = self.config.participant["mobility_args"]["mobility_type"] - self.radius_federation = float(self.config.participant["mobility_args"]["radius_federation"]) - self.scheme_mobility = self.config.participant["mobility_args"]["scheme_mobility"] - self.round_frequency = int(self.config.participant["mobility_args"]["round_frequency"]) + self.grace_time = self.config.participant["mobility_args"]["grace_time_mobility"] + self.period = self.config.participant["mobility_args"]["change_geo_interval"] # INFO: These values may change according to the needs of the federation self.max_distance_with_direct_connections = 150 # meters self.max_movement_random_strategy = 50 # meters self.max_movement_nearest_strategy = 50 # meters self.max_initiate_approximation = self.max_distance_with_direct_connections * 1.2 + self.radius_federation = float(config.participant["mobility_args"]["radius_federation"]) + self.scheme_mobility = config.participant["mobility_args"]["scheme_mobility"] + self.round_frequency = int(config.participant["mobility_args"]["round_frequency"]) # Logging box with mobility information mobility_msg = f"Mobility: {self.mobility}\nMobility type: {self.mobility_type}\nRadius federation: {self.radius_federation}\nScheme mobility: {self.scheme_mobility}\nEach {self.round_frequency} rounds" print_msg_box(msg=mobility_msg, indent=2, title="Mobility information") - self._nodes_distances = {} - self._nodes_distances_lock = Locker("nodes_distances_lock", async_lock=True) - self._verbose = verbose @cached_property def cm(self): @@ -103,8 +107,30 @@ async def start(self): """ await EventManager.get_instance().subscribe_addonevent(GPSEvent, self.update_nodes_distances) await EventManager.get_instance().subscribe_addonevent(GPSEvent, self.update_nodes_distances) - task = asyncio.create_task(self.run_mobility()) - return task + self._running.set() + self._mobility_task = asyncio.create_task(self.run_mobility(), name="Mobility_run_mobility") + return self._mobility_task + + async def stop(self): + """ + Stops the mobility module. + """ + logging.info("Stopping Mobility module...") + self._running.clear() + + # Cancel the background task + if self._mobility_task and not self._mobility_task.done(): + logging.info("🛑 Cancelling Mobility background task...") + self._mobility_task.cancel() + try: + await self._mobility_task + except asyncio.CancelledError: + pass + self._mobility_task = None + logging.info("🛑 Mobility background task cancelled") + + async def is_running(self): + return self._running.is_set() async def update_nodes_distances(self, gpsevent: GPSEvent): distances = await gpsevent.get_event_data() @@ -138,7 +164,7 @@ async def run_mobility(self): if not self.mobility: return # await asyncio.sleep(self.grace_time) - while True: + while await self.is_running(): await self.change_geo_location() await asyncio.sleep(self.period) diff --git a/nebula/addons/networksimulation/nebulanetworksimulator.py b/nebula/addons/networksimulation/nebulanetworksimulator.py index 22010e872..2d87e9cdf 100644 --- a/nebula/addons/networksimulation/nebulanetworksimulator.py +++ b/nebula/addons/networksimulation/nebulanetworksimulator.py @@ -26,7 +26,7 @@ def __init__(self, changing_interval, interface, verbose=False): self._network_conditions = self.NETWORK_CONDITIONS.copy() self._network_conditions_lock = Locker("network_conditions_lock", async_lock=True) self._current_network_conditions = {} - self._running = False + self._running = asyncio.Event() @cached_property def cm(self): @@ -34,7 +34,7 @@ def cm(self): async def start(self): logging.info("🌐 Nebula Network Simulator starting...") - self._running = True + self._running.set() grace_time = self.cm.config.participant["mobility_args"]["grace_time_mobility"] # if self._verbose: logging.info(f"Waiting {grace_time}s to start applying network conditions based on distances between devices") # await asyncio.sleep(grace_time) @@ -43,37 +43,49 @@ async def start(self): ) async def stop(self): - self._running = False + logging.info("🌐 Nebula Network Simulator stopping...") + self._running.clear() + + async def is_running(self): + return self._running.is_set() async def _change_network_conditions_based_on_distances(self, gpsevent: GPSEvent): distances = await gpsevent.get_event_data() - await asyncio.sleep(self._refresh_interval) - if self._verbose: - logging.info("Refresh | conditions based on distances...") - try: - for addr, (distance, _) in distances.items(): - if distance is None: - # If the distance is not found, we skip the node - continue - conditions = await self._calculate_network_conditions(distance) - # Only update the network conditions if they have changed - if addr not in self._current_network_conditions or self._current_network_conditions[addr] != conditions: - addr_ip = addr.split(":")[0] - self._set_network_condition_for_addr( - self._node_interface, addr_ip, conditions["bandwidth"], conditions["delay"] - ) - self._set_network_condition_for_multicast( - self._node_interface, addr_ip, self.IP_MULTICAST, conditions["bandwidth"], conditions["delay"] - ) - async with self._network_conditions_lock: - self._current_network_conditions[addr] = conditions - else: - if self._verbose: - logging.info("network conditions havent changed since last time") - except KeyError: - logging.exception(f"📍 Connection {addr} not found") - except Exception: - logging.exception("📍 Error changing connections based on distance") + while await self.is_running(): + if self._verbose: + logging.info("Refresh | conditions based on distances...") + try: + for addr, (distance, _) in distances.items(): + if distance is None: + # If the distance is not found, we skip the node + continue + conditions = await self._calculate_network_conditions(distance) + # Only update the network conditions if they have changed + if ( + addr not in self._current_network_conditions + or self._current_network_conditions[addr] != conditions + ): + addr_ip = addr.split(":")[0] + self._set_network_condition_for_addr( + self._node_interface, addr_ip, conditions["bandwidth"], conditions["delay"] + ) + self._set_network_condition_for_multicast( + self._node_interface, + addr_ip, + self.IP_MULTICAST, + conditions["bandwidth"], + conditions["delay"], + ) + async with self._network_conditions_lock: + self._current_network_conditions[addr] = conditions + else: + if self._verbose: + logging.info("network conditions havent changed since last time") + except KeyError: + logging.exception(f"📍 Connection {addr} not found") + except Exception: + logging.exception("📍 Error changing connections based on distance") + await asyncio.sleep(self._refresh_interval) async def set_thresholds(self, thresholds: dict): async with self._network_conditions_lock: diff --git a/nebula/addons/reporter.py b/nebula/addons/reporter.py index bf16a4b44..376f6a208 100755 --- a/nebula/addons/reporter.py +++ b/nebula/addons/reporter.py @@ -67,7 +67,8 @@ def __init__(self, config, trainer): self.acc_bytes_recv = 0 self.acc_packets_sent = 0 self.acc_packets_recv = 0 - self.scenario_finished = asyncio.Event() + self._running = asyncio.Event() + self._reporter_task = None # Track the background task @property def cm(self): @@ -114,9 +115,10 @@ async def start(self): - The grace period allows for a delay before the first reporting cycle. - The reporter loop runs in the background, ensuring continuous data updates. """ + self._running.set() await asyncio.sleep(self.grace_time) - task = asyncio.create_task(self.run_reporter()) - return task + self._reporter_task = asyncio.create_task(self.run_reporter(), name="Reporter_run_reporter") + return self._reporter_task async def run_reporter(self): """ @@ -133,7 +135,7 @@ async def run_reporter(self): Notes: - The reporting frequency is determined by the 'report_frequency' setting in the config file. """ - while not self.scenario_finished.is_set(): + while self._running.is_set(): if self.config.participant["reporter_args"]["report_status_data_queue"]: if self.config.participant["scenario_args"]["controller"] != "nebula-test": await self.__report_status_to_controller() @@ -194,9 +196,21 @@ async def report_scenario_finished(self): except aiohttp.ClientError: logging.exception(f"Error connecting to the controller at {url}") return False - - def shutdown(self): - self.scenario_finished.set() + + async def stop(self): + logging.info("🔍 Stopping reporter module...") + self._running.clear() + + # Cancel the background task + if self._reporter_task and not self._reporter_task.done(): + logging.info("🛑 Cancelling Reporter background task...") + self._reporter_task.cancel() + try: + await self._reporter_task + except asyncio.CancelledError: + pass + self._reporter_task = None + logging.info("🛑 Reporter background task cancelled") async def __report_data_queue(self): """ diff --git a/nebula/controller/database.py b/nebula/controller/database.py index 2fd77fca3..7a012fd8a 100755 --- a/nebula/controller/database.py +++ b/nebula/controller/database.py @@ -1031,7 +1031,7 @@ def get_running_scenario(username=None, get_all=False): sqlite3.Row or list[sqlite3.Row]: A single scenario record or a list of scenario records matching the criteria. Behavior: - - Filters scenarios with status "running" or "completed". + - Filters scenarios with status "running". - Applies username filter if provided. - Returns either one or all matching records depending on get_all. """ @@ -1042,14 +1042,14 @@ def get_running_scenario(username=None, get_all=False): if username: command = """ SELECT * FROM scenarios - WHERE (status = ? OR status = ?) AND username = ?; + WHERE (status = ?) AND username = ?; """ - c.execute(command, ("running", "completed", username)) + c.execute(command, ("running", username)) result = c.fetchone() else: - command = "SELECT * FROM scenarios WHERE status = ? OR status = ?;" - c.execute(command, ("running", "completed")) + command = "SELECT * FROM scenarios WHERE status = ?;" + c.execute(command, ("running",)) if get_all: result = c.fetchall() else: @@ -1174,7 +1174,8 @@ def check_scenario_federation_completed(scenario_name): return False # Check if all nodes have completed the total rounds - return all(node["round"] == total_rounds for node in nodes) + total_rounds_str = str(total_rounds) + return all(str(node["round"]) == total_rounds_str for node in nodes) except sqlite3.Error as e: print(f"Database error: {e}") @@ -1243,10 +1244,10 @@ def save_notes(scenario, notes): def get_notes(scenario): """ Retrieve notes associated with a specific scenario. - + Parameters: scenario (str): The unique identifier of the scenario. - + Returns: sqlite3.Row or None: The notes record for the given scenario, or None if no notes exist. """ @@ -1275,7 +1276,7 @@ def remove_note(scenario): if __name__ == "__main__": """ Entry point for the script to print the list of users. - + When executed directly, this block calls the `list_users()` function and prints its returned list of users. """ diff --git a/nebula/core/addonmanager.py b/nebula/core/addonmanager.py index c994e994c..46fbfd60a 100644 --- a/nebula/core/addonmanager.py +++ b/nebula/core/addonmanager.py @@ -1,9 +1,11 @@ +import logging from typing import TYPE_CHECKING -from nebula.config.config import Config + from nebula.addons.functions import print_msg_box from nebula.addons.gps.gpsmodule import factory_gpsmodule from nebula.addons.mobility import Mobility from nebula.addons.networksimulation.networksimulator import factory_network_simulator +from nebula.config.config import Config if TYPE_CHECKING: from nebula.core.engine import Engine @@ -16,7 +18,7 @@ class AddondManager: This class handles the lifecycle of optional services (add-ons) such as mobility simulation, GPS module, and network simulation. Add-ons are conditionally deployed based on the provided configuration. """ - + def __init__(self, engine: "Engine", config: Config): """ Initializes the AddondManager instance. @@ -51,10 +53,10 @@ async def deploy_additional_services(self): print_msg_box(msg="Deploying Additional Services", indent=2, title="Addons Manager") if self._config.participant["trustworthiness"]: from nebula.addons.trustworthiness.trustworthiness import Trustworthiness - + trustworthiness = Trustworthiness(self._engine, self._config) self._addons.append(trustworthiness) - + if self._config.participant["mobility_args"]["mobility"]: mobility = Mobility(self._config, verbose=False) self._addons.append(mobility) @@ -70,3 +72,17 @@ async def deploy_additional_services(self): for add in self._addons: await add.start() + + async def stop_additional_services(self): + """ + Stops all additional services. + """ + logging.info("🛑 Stopping additional services...") + for add in self._addons: + try: + logging.info(f"🛑 Stopping addon: {add.__class__.__name__}") + await add.stop() + logging.info(f"✅ Successfully stopped addon: {add.__class__.__name__}") + except Exception as e: + logging.exception(f"❌ Error stopping addon {add.__class__.__name__}: {e}") + logging.info("🛑 Finished stopping additional services") diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 3269868db..9ddd9609b 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -4,10 +4,9 @@ import random import socket import time + import docker -from nebula.core.role import Role, factory_node_role -from nebula.addons.attacks.attacks import create_attack from nebula.addons.functions import print_msg_box from nebula.addons.reporter import Reporter from nebula.addons.reputation.reputation import Reputation @@ -16,13 +15,14 @@ from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import ( AggregationEvent, + ExperimentFinishEvent, RoundEndEvent, RoundStartEvent, UpdateNeighborEvent, UpdateReceivedEvent, - ExperimentFinishEvent, ) from nebula.core.network.communications import CommunicationsManager +from nebula.core.role import Role, factory_node_role from nebula.core.situationalawareness.situationalawareness import SituationalAwareness from nebula.core.utils.locker import Locker @@ -155,6 +155,8 @@ def __init__( # Additional Components if "situational_awareness" in self.config.participant: self._situational_awareness = SituationalAwareness(self.config, self) + else: + self._situational_awareness = None if self.config.participant["defense_args"]["reputation"]["enabled"]: self._reputation = Reputation(engine=self, config=self.config) @@ -320,10 +322,10 @@ async def _control_leadership_transfer_callback(self, source, message): await self.cm.send_message(source, message) logging.info(f"🔧 handle_control_message | Trigger | Leadership transfer ack message sent to {source}") else: - logging.info(f"🔧 handle_control_message | Trigger | Only one neighbor found, I am the leader") + logging.info("🔧 handle_control_message | Trigger | Only one neighbor found, I am the leader") else: self.role = Role.AGGREGATOR - logging.info(f"🔧 handle_control_message | Trigger | I am now the leader") + logging.info("🔧 handle_control_message | Trigger | I am now the leader") message = self.cm.create_message("control", "leadership_transfer_ack") await self.cm.send_message(source, message) logging.info(f"🔧 handle_control_message | Trigger | Leadership transfer ack message sent to {source}") @@ -462,7 +464,7 @@ async def broadcast_models_include(self, age: AggregationEvent): """ # Don't broadcast if the learning cycle has finished if not self.learning_cycle_finished(): - logging.info(f"🔄 Not broadcasting MODELS_INCLUDED because learning cycle has finished") + logging.info("🔄 Not broadcasting MODELS_INCLUDED because learning cycle has finished") return logging.info(f"🔄 Broadcasting MODELS_INCLUDED for round {self.get_round()}") @@ -813,82 +815,105 @@ async def _learning_cycle(self): result = await self.reporter.report_scenario_finished() if result: logging.info("📝 Scenario finished reported successfully") - self.reporter.shutdown() + await self.reporter.stop() else: logging.error("📝 Error reporting scenario finished") except Exception as e: - logging.error(f"📝 Error during scenario finish report: {e}") + logging.exception(f"📝 Error during scenario finish report: {e}") - # Get all tasks except the current one - tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] - # Cut connections - await self.cm.stop() + # Call centralized shutdown + await self.shutdown() + return - logging.info("Starting graceful shutdown process...") - + async def _extended_learning_cycle(self): + """ + This method is called in each round of the learning cycle. It is used to extend the learning cycle with additional + functionalities. The method is called in the _learning_cycle method. + """ + pass + + async def shutdown(self): + logging.info("🚦 Engine shutdown initiated") + + # Stop addon services first + try: + await self._addon_manager.stop_additional_services() + except Exception as e: + logging.exception("Error stopping add-ons: %s", e) + + # Stop reporter + try: + await self._reporter.stop() + except Exception as e: + logging.exception("Error stopping reporter: %s", e) + + # Stop communications manager (includes forwarder, discoverer, propagator, ECS) + try: + await self.cm.stop() + except Exception as e: + logging.exception("Error stopping communications manager: %s", e) + + # Stop situational awareness + try: + if self.sa: + await self.sa.stop() + except Exception as e: + logging.exception("Error stopping situational awareness: %s", e) + + # Task cleanup with improved handling + logging.info("Starting graceful task cleanup...") tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] - current_task = asyncio.current_task() - all_tasks = asyncio.all_tasks() - - # Log basic task info - logging.info("Task Summary:") - - # Log task names if available - if current_task: - logging.info(f" • Current task: {current_task}") - - for task in all_tasks: - logging.info(f" • Task: {task}") - logging.info(f" • Task name: {task.get_name()}") - logging.info(f" • Task state: {task.get_state()}") - logging.info(f" • Task coroutine: {task.get_coro()}") - logging.info(f" • Task done: {task.done()}") - logging.info(f" • Task cancelled: {task.cancelled()}") - logging.info(f" • Task exception: {task.exception()}") - logging.info(f" • Task result: {task.result()}") - - - model_tasks = [t for t in tasks if any(name in t.get_name().lower() for name in ["model", "aggregation"])] - if model_tasks: - logging.info("Waiting for model and aggregation tasks to complete...") - try: - await asyncio.wait_for(asyncio.gather(*model_tasks, return_exceptions=True), timeout=15) - except asyncio.TimeoutError: - logging.warning("Model tasks did not complete in time") + if tasks: + logging.info(f"Found {len(tasks)} remaining tasks to clean up") + for task in tasks: + logging.info(f" • Task: {task.get_name()} - {task}") + logging.info(f" • State: {task._state} - Done: {task.done()} - Cancelled: {task.cancelled()}") - other_tasks = [t for t in tasks if t not in model_tasks] - if other_tasks: - logging.info("Waiting for remaining tasks to complete...") + # Wait for tasks to complete naturally with shorter timeout try: - await asyncio.wait_for(asyncio.gather(*other_tasks, return_exceptions=True), timeout=15) - except asyncio.TimeoutError: + await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=3) + except TimeoutError: logging.warning("Some tasks did not complete in time, forcing cancellation...") - for task in other_tasks: + for task in tasks: if not task.done(): task.cancel() - await asyncio.gather(*other_tasks, return_exceptions=True) - - # Remove all pending tasks - asyncio.all_tasks().clear() - - # Shutdown all logging handlers - self.config.shutdown_logging() - - # From here, logging is disabled - print("Shutdown complete. Terminating NEBULA CORE...") - - # Kill itself + # Wait a bit more for cancellations to take effect + try: + await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=2) + except TimeoutError: + logging.warning("Some tasks still not responding to cancellation") + + # Final aggressive cleanup - cancel all remaining tasks + remaining_tasks = [ + t for t in asyncio.all_tasks() if t is not asyncio.current_task() and not t.done() + ] + if remaining_tasks: + logging.warning(f"Forcing cancellation of {len(remaining_tasks)} remaining tasks") + for task in remaining_tasks: + task.cancel() + try: + await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=1) + except TimeoutError: + logging.exception("Some tasks still not responding to forced cancellation") + + logging.info("✅ Engine shutdown complete") + + # Kill Docker container if running in Docker if self.config.participant["scenario_args"]["deployment"] == "docker": try: docker_id = socket.gethostname() - self.client.containers.get(docker_id).kill() + logging.info(f"📦 Removing docker container with ID {docker_id}") + container = self.client.containers.get(docker_id) + container.remove(force=True) + logging.info(f"📦 Successfully removed docker container {docker_id}") except Exception as e: - print(f"Error stopping Docker container with ID {docker_id}: {e}") + logging.exception(f"📦 Error removing Docker container {docker_id}: {e}") + # Try to force kill the container as last resort + try: + import subprocess - async def _extended_learning_cycle(self): - """ - This method is called in each round of the learning cycle. It is used to extend the learning cycle with additional - functionalities. The method is called in the _learning_cycle method. - """ - pass + subprocess.run(["docker", "rm", "-f", docker_id], check=False) + logging.info(f"📦 Forced removal of container {docker_id} via subprocess") + except Exception as sub_e: + logging.exception(f"📦 Failed to force remove container {docker_id}: {sub_e}") diff --git a/nebula/core/network/blacklist.py b/nebula/core/network/blacklist.py index b239bc07e..2a6c9a66f 100644 --- a/nebula/core/network/blacklist.py +++ b/nebula/core/network/blacklist.py @@ -16,7 +16,7 @@ class BlackList: The blacklist tracks nodes that are temporarily excluded from communication or interaction due to malicious behavior or disconnection events. Nodes remain blacklisted for a fixed period defined by `max_time_listed`. - + The recently disconnected list tracks peers that were recently disconnected and may need to be temporarily avoided. Key features: @@ -26,14 +26,19 @@ class BlackList: """ def __init__(self, max_time_listed=BLACKLIST_EXPIRATION_TIME): + """ + Initialize the BlackList with the specified expiration time. + + Args: + max_time_listed (int): Maximum time in seconds for nodes to remain blacklisted. + """ self._max_time_listed = max_time_listed - self._blacklisted_nodes: dict = {} - self._recently_disconnected: set = set() - self._recently_disconnected_lock = Locker(name="recently_disconnected_lock", async_lock=True) - self._blacklisted_nodes_lock = Locker(name="blacklisted_nodes_lock", async_lock=True) - self._bl_cleaner_running = False - self._blacklist_cleaner_wake_up = asyncio.Event() - self._running = False + self._blacklisted_nodes = {} + self._recently_disconnected = set() + self._blacklisted_nodes_lock = Locker("blacklisted_nodes_lock", async_lock=True) + self._recently_disconnected_lock = Locker("recently_disconnected_lock", async_lock=True) + self._running = asyncio.Event() + self._background_tasks = [] # Track background tasks async def apply_restrictions(self, nodes) -> set | None: """ @@ -75,8 +80,8 @@ async def add_to_blacklist(self, addr): await self._blacklisted_nodes_lock.acquire_async() expiration_time = time.time() self._blacklisted_nodes[addr] = expiration_time - if not self._running: - self._running = True + if not self._running.is_set(): + self._running.set() asyncio.create_task(self._start_blacklist_cleaner()) await self._blacklisted_nodes_lock.release_async() nbe = NodeBlacklistedEvent(addr, blacklisted=True) @@ -111,7 +116,7 @@ async def _start_blacklist_cleaner(self): """ Background task that periodically removes expired entries from the blacklist. """ - while self._running: + while self._running.is_set(): await self._blacklist_clean() await self._blacklist_cleaner_wait() @@ -132,7 +137,7 @@ async def _blacklist_clean(self): self._blacklisted_nodes = new_bl if not new_bl: - self._running = False + self._running.clear() await self._blacklisted_nodes_lock.release_async() async def _blacklist_cleaner_wait(self): @@ -197,7 +202,8 @@ async def add_recently_disconnected(self, addr): await self._recently_disconnected_lock.acquire_async() self._recently_disconnected.add(addr) await self._recently_disconnected_lock.release_async() - asyncio.create_task(self._remove_recently_disc(addr)) + task = asyncio.create_task(self._remove_recently_disc(addr), name=f"BlackList_remove_recently_{addr}") + self._background_tasks.append(task) nbe = NodeBlacklistedEvent(addr) event_manager = EventManager.get_instance() if event_manager is not None: @@ -258,3 +264,40 @@ async def verify_not_recently_disc(self, nodes: set) -> set | None: nodes_not_listed = nodes.difference(rec_disc) await self._recently_disconnected_lock.release_async() return nodes_not_listed + + async def stop(self): + """ + Stop the BlackList by clearing all data and stopping background tasks. + """ + logging.info("🛑 Stopping BlackList...") + + # Stop the background cleaner + self._running.clear() + + # Cancel all background tasks + if self._background_tasks: + logging.info(f"🛑 Cancelling {len(self._background_tasks)} background tasks...") + for task in self._background_tasks: + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + self._background_tasks.clear() + logging.info("🛑 All background tasks cancelled") + + # Clear all data + try: + async with self._blacklisted_nodes_lock: + self._blacklisted_nodes.clear() + except Exception as e: + logging.warning(f"Error clearing blacklist: {e}") + + try: + async with self._recently_disconnected_lock: + self._recently_disconnected.clear() + except Exception as e: + logging.warning(f"Error clearing recently disconnected: {e}") + + logging.info("✅ BlackList stopped successfully") diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 274b586f1..5270c5d13 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -90,6 +90,7 @@ def __init__(self, engine: "Engine"): self._discoverer = Discoverer(addr=self.addr, config=self.config) # self._health = Health(addr=self.addr, config=self.config) + self._health = None self._forwarder = Forwarder(config=self.config) self._propagator = Propagator() @@ -109,6 +110,7 @@ def __init__(self, engine: "Engine"): self._external_connection_service = factory_connection_service("nebula", self.addr) self._initialized = True + self._running = asyncio.Event() logging.info("Communication Manager initialization completed") @property @@ -195,6 +197,7 @@ async def start_communications(self, initial_neighbors): Args: initial_neighbors (list): A list of neighbor addresses to connect to after startup. """ + self._running.set() logging.info(f"Neighbors: {self.config.participant['network_args']['neighbors']}") logging.info( f"💤 Cold start time: {self.config.participant['misc_args']['grace_time_connection']} seconds before connecting to the network" @@ -379,7 +382,7 @@ async def is_external_connection_service_running(self): Returns: bool: True if the ECS is running, False otherwise. """ - return self.ecs.is_running() + return await self.ecs.is_running() async def start_beacon(self): """ @@ -551,6 +554,15 @@ async def process_connection(reader, writer, priority="medium"): try: addr = writer.get_extra_info("peername") + # Check if learning cycle has finished - reject new connections + if self.engine.learning_cycle_finished(): + logging.info(f"🔗 [incoming] Rejecting connection from {addr} because learning cycle has finished") + writer.write(b"CONNECTION//CLOSE\n") + await writer.drain() + writer.close() + await writer.wait_closed() + return + connected_node_id = await reader.readline() connected_node_id = connected_node_id.decode("utf-8").strip() connected_node_port = addr[1] @@ -682,19 +694,48 @@ async def terminate_failed_reconnection(self, conn: Connection): await self.disconnect(connected_with, mutual_disconnection=False) async def stop(self): - logging.info("🌐 Stopping Communications Manager... [Removing connections and stopping network engine]") + logging.info("🌐 Stopping Communications Manager...") + + # Stop accepting new connections first + if self.network_engine: + logging.info("🌐 Closing network engine server...") + self.network_engine.close() + await self.network_engine.wait_closed() + if hasattr(self, "network_task") and self.network_task: + self.network_task.cancel() + try: + await self.network_task + except asyncio.CancelledError: + pass + + # Stop all existing connections async with self.connections_lock: connections = list(self.connections.values()) for node in connections: await node.stop() - if hasattr(self, "server"): - self.network_engine.close() - await self.network_engine.wait_closed() - self.network_task.cancel() - if self.forwarder: - self.forwarder.shutdown() + + # Stop additional services + if self._forwarder: + await self._forwarder.stop() if self.ecs: await self.ecs.stop() + if self.discoverer: + await self.discoverer.stop() + if self.health: + try: + await self.health.stop() + except Exception as e: + logging.warning(f"Error stopping health service: {e}") + if self._propagator: + await self._propagator.stop() + if self._blacklist: + await self._blacklist.stop() + + self._running.clear() + + self.stop_network_engine.set() + + logging.info("🌐 Communications Manager stopped successfully") async def run_reconnections(self): for connection in self.connections_reconnect: @@ -854,8 +895,13 @@ async def establish_connection(self, addr, direct=True, reconnect=False, priorit priority (str, optional): Priority level for this connection ("low", "medium", "high"). Defaults to "medium". Returns: - bool: True if the connection was successfully established or upgraded, False otherwise. + bool: True if the connection action (new or upgrade) succeeded, False otherwise. """ + # Check if learning cycle has finished - don't establish new connections + if self.engine.learning_cycle_finished(): + logging.info(f"🔗 [outgoing] Not establishing connection to {addr} because learning cycle has finished") + return False + logging.info(f"🔗 [outgoing] Establishing connection with {addr} (direct: {direct})") async def process_establish_connection(addr, direct, reconnect, priority): @@ -1049,7 +1095,7 @@ async def register(self): logging.error(f"Error registering node {self.addr} in the controller") async def wait_for_controller(self): - while True: + while await self.is_running(): response = requests.get(self.wait_endpoint) if response.status_code == 200: logging.info("Continue signal received from controller") @@ -1058,6 +1104,9 @@ async def wait_for_controller(self): logging.info("Waiting for controller signal...") await asyncio.sleep(1) + async def is_running(self): + return self._running.is_set() + async def disconnect(self, dest_addr, mutual_disconnection=True, forced=False): """ Disconnects from a specified destination address and performs cleanup tasks. diff --git a/nebula/core/network/connection.py b/nebula/core/network/connection.py index d44ef0447..3932a2bb9 100755 --- a/nebula/core/network/connection.py +++ b/nebula/core/network/connection.py @@ -39,7 +39,7 @@ class Connection: """ Manages TCP communication channels using asyncio for asynchronous networking. - This class encapsulates the logic for establishing, maintaining, + This class encapsulates the logic for establishing, maintaining, and handling TCP connections between nodes in the distributed system. Responsibilities: @@ -51,12 +51,12 @@ class Connection: Usage: - Used by nodes to communicate asynchronously with others. - Supports concurrent message exchange via asyncio streams. - + Note: This implementation leverages asyncio to enable scalable and efficient networking in distributed federated learning scenarios. """ - + DEFAULT_FEDERATED_ROUND = -1 INACTIVITY_TIMER = 120 INACTIVITY_DAEMON_SLEEP_TIME = 20 @@ -91,6 +91,7 @@ def __init__( self.loop = asyncio.get_event_loop() self.read_task = None self.process_task = None + self.inactivity_task = None self.pending_messages_queue = asyncio.Queue(maxsize=100) self.message_buffers: dict[bytes, dict[int, MessageChunk]] = {} self._prio: ConnectionPriority = ConnectionPriority(prio) @@ -112,6 +113,7 @@ def __init__( self.incompleted_reconnections = 0 self.forced_disconnection = False + self._running = asyncio.Event() logging.info( f"Connection [established]: {self.addr} (id: {self.id}) (active: {self.active}) (direct: {self.direct})" @@ -144,7 +146,7 @@ def get_prio(self): async def is_inactive(self): """ Check if the connection is currently marked as inactive. - + Returns: bool: True if inactive, False otherwise. """ @@ -162,12 +164,12 @@ async def _update_activity(self): async def _monitor_inactivity(self): """ Background task that monitors the connection for inactivity. - + Runs indefinitely until the connection is marked as direct, periodically checking if the last activity exceeds the inactivity threshold. If inactive, marks the connection as inactive and logs a warning. """ - while True: + while await self.is_running(): if self.direct: break await asyncio.sleep(self.INACTIVITY_DAEMON_SLEEP_TIME) @@ -196,7 +198,7 @@ def get_ready(self): def get_direct(self): """ Check if the connection is marked as direct ( a.k.a neighbor ). - + Returns: bool: True if direct, False otherwise. """ @@ -221,6 +223,9 @@ def is_active(self): def get_last_active(self): return self.last_active + async def is_running(self): + return self._running.is_set() + async def start(self): """ Start the connection by launching asynchronous tasks for handling incoming messages, @@ -231,9 +236,10 @@ async def start(self): 2. `process_message_queue` - processes messages queued for sending or further handling. 3. `_monitor_inactivity` - periodically checks if the connection has been inactive and updates its state accordingly. """ + self._running.set() self.read_task = asyncio.create_task(self.handle_incoming_message(), name=f"Connection {self.addr} reader") self.process_task = asyncio.create_task(self.process_message_queue(), name=f"Connection {self.addr} processor") - asyncio.create_task(self._monitor_inactivity()) + self.inactivity_task = asyncio.create_task(self._monitor_inactivity()) async def stop(self): """ @@ -245,9 +251,10 @@ async def stop(self): - Cancels the read and process tasks if they exist, awaiting their cancellation and logging any cancellation exceptions. - Closes the writer stream safely, awaiting its closure and logging any errors that occur during the closing process. """ + self._running.clear() logging.info(f"❗️ Connection [stopped]: {self.addr} (id: {self.id})") self.forced_disconnection = True - tasks = [self.read_task, self.process_task] + tasks = [self.read_task, self.process_task, self.inactivity_task] for task in tasks: if task is not None: task.cancel() @@ -276,7 +283,7 @@ async def reconnect(self, max_retries: int = 5, delay: int = 5) -> None: - Upon success, recreates the read and process asyncio tasks for this connection. - Logs the successful reconnection if not forced to disconnect, then returns. - If all retries fail, logs the failure and terminates the failed reconnection via the Communication manager. - + Args: max_retries (int): Maximum number of reconnection attempts. Defaults to 5. delay (int): Delay in seconds between reconnection attempts. Defaults to 5. @@ -285,7 +292,7 @@ async def reconnect(self, max_retries: int = 5, delay: int = 5) -> None: logging.info("Not going to reconnect because this connection is not direct") return - # Check if learning cycle has finished to prevent unnecessary reconnection attempts + # Check if learning cycle has finished - don't reconnect if self.cm.learning_finished(): logging.info(f"Not attempting reconnection to {self.addr} because learning cycle has finished") return @@ -318,7 +325,6 @@ async def reconnect(self, max_retries: int = 5, delay: int = 5) -> None: logging.exception(f"Reconnection attempt {attempt + 1} failed: {e}") await asyncio.sleep(delay) logging.error(f"Failed to reconnect to {self.addr} after {max_retries} attempts. Stopping connection...") - # await self.stop() await self.cm.terminate_failed_reconnection(self) async def send( @@ -347,6 +353,11 @@ async def send( logging.error("Cannot send data, writer is None") return + # Check if learning cycle has finished - don't send messages + if self.cm.learning_finished(): + logging.info(f"Not sending message to {self.addr} because learning cycle has finished") + return + try: message_id = uuid.uuid4().bytes data_prefix, encoded_data = self._prepare_data(data, pb, encoding_type) @@ -421,8 +432,8 @@ async def _send_chunks(self, message_id: bytes, data: bytes) -> None: """ Sends the encoded data over the connection in fixed-size chunks. - Each chunk is prefixed with a header containing the message ID, chunk index, - a flag indicating if it's the last chunk, and the size of the chunk. + Each chunk is prefixed with a header containing the message ID, chunk index, + a flag indicating if it's the last chunk, and the size of the chunk. An end-of-transmission (EOT) character is appended to each chunk. Args: @@ -469,7 +480,7 @@ async def handle_incoming_message(self) -> None: """ reusable_buffer = bytearray(self.MAX_CHUNK_SIZE) try: - while True: + while await self.is_running(): if self.pending_messages_queue.full(): await asyncio.sleep(0.1) # Wait a bit if the queue is full to create backpressure continue @@ -490,11 +501,8 @@ async def handle_incoming_message(self) -> None: logging.exception(f"Connection closed while reading: {e}") except Exception as e: logging.exception(f"Error handling incoming message: {e}") - except BrokenPipeError as e: - logging.exception(f"Broken pipe error handling incoming message: {e}") finally: - # Only attempt reconnection if the learning cycle hasn't finished and the connection is direct or high priority - if (self.direct or self._prio == ConnectionPriority.HIGH) and not self.cm.learning_finished(): + if self.direct or self._prio == ConnectionPriority.HIGH: await self.reconnect() elif self.cm.learning_finished(): logging.info(f"Not attempting reconnection to {self.addr} because learning cycle has finished") @@ -524,7 +532,7 @@ async def _read_exactly(self, num_bytes: int, max_retries: int = 3) -> bytes: try: while remaining > 0: chunk = await self.reader.read(min(remaining, self.BUFFER_SIZE)) - if not chunk and not self.cm.learning_finished(): + if not chunk: raise ConnectionError("Connection closed while reading") data += chunk remaining -= len(chunk) @@ -561,7 +569,7 @@ async def _read_chunk(self, buffer: bytearray = None) -> bytes: Reads a data chunk from the stream, validating its size and EOT marker. Args: - buffer (bytearray, optional): A reusable buffer to store the chunk. + buffer (bytearray, optional): A reusable buffer to store the chunk. If not provided, a new buffer of MAX_CHUNK_SIZE will be created. Returns: @@ -679,7 +687,7 @@ async def process_message_queue(self) -> None: Notes: Runs indefinitely unless externally cancelled or stopped. """ - while True: + while await self.is_running(): try: if self.pending_messages_queue is None: logging.error("Pending messages queue is not initialized") diff --git a/nebula/core/network/discoverer.py b/nebula/core/network/discoverer.py index b0f1e83cb..34bae2aac 100755 --- a/nebula/core/network/discoverer.py +++ b/nebula/core/network/discoverer.py @@ -13,6 +13,7 @@ def __init__(self, addr, config): self.grace_time = self.config.participant["discoverer_args"]["grace_time_discovery"] self.period = self.config.participant["discoverer_args"]["discovery_frequency"] self.interval = self.config.participant["discoverer_args"]["discovery_interval"] + self._running = asyncio.Event() @property def cm(self): @@ -25,6 +26,7 @@ def cm(self): return self._cm async def start(self): + self._running.set() asyncio.create_task(self.run_discover()) async def run_discover(self): @@ -32,7 +34,7 @@ async def run_discover(self): logging.info("🔍 Federation is CFL. Discoverer is disabled...") return await asyncio.sleep(self.grace_time) - while True: + while await self.is_running(): if len(self.cm.connections) > 0: latitude = self.config.participant["mobility_args"]["latitude"] longitude = self.config.participant["mobility_args"]["longitude"] @@ -44,3 +46,10 @@ async def run_discover(self): except Exception as e: logging.exception(f"🔍 Cannot send discovery message to neighbors. Error: {e!s}") await asyncio.sleep(self.period) + + async def stop(self): + self._running.clear() + logging.info("🔍 Stopping Discoverer module...") + + async def is_running(self): + return self._running.is_set() diff --git a/nebula/core/network/externalconnection/externalconnectionservice.py b/nebula/core/network/externalconnection/externalconnectionservice.py index b9d16a549..5ac01ada2 100644 --- a/nebula/core/network/externalconnection/externalconnectionservice.py +++ b/nebula/core/network/externalconnection/externalconnectionservice.py @@ -9,7 +9,7 @@ class ExternalConnectionService(ABC): for discovering federations and managing beacon signals that announce node presence in the network. """ - + @abstractmethod async def start(self): """ @@ -31,7 +31,7 @@ async def stop(self): pass @abstractmethod - def is_running(self): + async def is_running(self): """ Check whether the external connection service is currently active. @@ -85,6 +85,7 @@ class ExternalConnectionServiceException(Exception): """ Exception raised for errors related to external connection services. """ + pass diff --git a/nebula/core/network/externalconnection/nebuladiscoveryservice.py b/nebula/core/network/externalconnection/nebuladiscoveryservice.py index 72ba856d4..ec729577a 100644 --- a/nebula/core/network/externalconnection/nebuladiscoveryservice.py +++ b/nebula/core/network/externalconnection/nebuladiscoveryservice.py @@ -2,6 +2,7 @@ import logging import socket import struct + from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import BeaconRecievedEvent, ChangeLocationEvent from nebula.core.network.externalconnection.externalconnectionservice import ExternalConnectionService @@ -115,6 +116,12 @@ def connection_made(self, transport): sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 2) asyncio.create_task(self.keep_search()) + async def stop(self): + """ + Stop the client protocol by setting the search_done event to release any waiting tasks. + """ + self.search_done.set() + async def keep_search(self): """ Periodically broadcast search requests to discover other nodes in the federation. @@ -126,7 +133,6 @@ async def keep_search(self): to indicate that the search phase is finished. """ logging.info("Federation searching loop started") - # while True: for _ in range(self.SEARCH_TRIES): await self.search() await asyncio.sleep(self.SEARCH_INTERVAL) @@ -180,15 +186,15 @@ def __init__(self, nebula_service, addr, interval=20): self.nebula_service: NebulaConnectionService = nebula_service self.addr = addr self.interval = interval # Send interval in seconds - self.running = False self._latitude = None self._longitude = None + self._running = asyncio.Event() async def start(self): logging.info("[NebulaBeacon]: Starting sending pressence beacon") - self.running = True + self._running.set() await EventManager.get_instance().subscribe_addonevent(ChangeLocationEvent, self._proces_change_location_event) - while self.running: + while await self.is_running(): await asyncio.sleep(self.interval) await self.send_beacon() @@ -199,7 +205,11 @@ async def _proces_change_location_event(self, cle: ChangeLocationEvent): async def stop(self): logging.info("[NebulaBeacon]: Stop existance beacon") - self.running = False + self._running.clear() + logging.info("[NebulaBeacon]: _running event cleared") + + async def is_running(self): + return self._running.is_set() async def modify_beacon_frequency(self, frequency): logging.info(f"[NebulaBeacon]: Changing beacon frequency from {self.interval}s to {frequency}s") @@ -235,7 +245,8 @@ def __init__(self, addr): self.server: NebulaServerProtocol = None self.client: NebulaClientProtocol = None self.beacon: NebulaBeacon = NebulaBeacon(self, self.addr) - self.running = False + self._running = asyncio.Event() + self._beacon_task = None # Track the beacon task @property def cm(self): @@ -248,44 +259,67 @@ def cm(self): return self._cm async def start(self): - loop = asyncio.get_running_loop() - transport, self.server = await loop.create_datagram_endpoint( - lambda: NebulaServerProtocol(self, self.addr), local_addr=("0.0.0.0", 1900) - ) + self._running.set() + try: + loop = asyncio.get_running_loop() + transport, self.server = await loop.create_datagram_endpoint( + lambda: NebulaServerProtocol(self, self.addr), local_addr=("0.0.0.0", 1900) + ) + except Exception as e: + logging.exception(f"Error starting Nebula Connection Service server: {e}") + await self.stop() try: # Advanced socket settings sock = transport.get_extra_info("socket") if sock is not None: - group = socket.inet_aton("239.255.255.250") # Multicast to binary format. - mreq = struct.pack("4sL", group, socket.INADDR_ANY) # Join multicast group in every interface available + group = socket.inet_aton("239.255.255.250") # Multicast to binary format. + mreq = struct.pack("4sL", group, socket.INADDR_ANY) # Join multicast group in every interface available sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq) # SO listen multicast packages except Exception as e: - logging.exception(f"{e}") - self.running = True + logging.exception(f"Error starting Nebula Connection Service client: {e}") + await self.stop() async def stop(self): - logging.info("Stop Nebula Connection Service") - if self.server and self.server.transport: - self.server.transport.close() - await self.beacon.stop() - self.running = False + logging.info("🔗 Stopping Nebula Connection Service...") + if self.server: + if self.server.transport: + self.server.transport.close() + self.server = None + if self.client: + await self.client.stop() + if self.client.transport: + self.client.transport.close() + self.client = None + if self.beacon: + await self.stop_beacon() + self.beacon = None + self._running.clear() async def start_beacon(self): if not self.beacon: self.beacon = NebulaBeacon(self, self.addr) - asyncio.create_task(self.beacon.start()) + self._beacon_task = asyncio.create_task(self.beacon.start(), name="NebulaBeacon_start") async def stop_beacon(self): if self.beacon: await self.beacon.stop() - # self.beacon = None + # Cancel the beacon task + if self._beacon_task and not self._beacon_task.done(): + logging.info("🛑 Cancelling NebulaBeacon background task...") + self._beacon_task.cancel() + try: + await self._beacon_task + except asyncio.CancelledError: + pass + self._beacon_task = None + logging.info("🛑 NebulaBeacon background task cancelled") async def modify_beacon_frequency(self, frequency): if self.beacon: await self.beacon.modify_beacon_frequency(frequency=frequency) - def is_running(self): - return self.running + async def is_running(self): + return self._running.is_set() async def find_federation(self): logging.info(f"Node {self.addr} trying to find federation...") diff --git a/nebula/core/network/forwarder.py b/nebula/core/network/forwarder.py index 0c51505b2..6b3dae7eb 100755 --- a/nebula/core/network/forwarder.py +++ b/nebula/core/network/forwarder.py @@ -19,7 +19,7 @@ class Forwarder: This class is designed to run asynchronously, leveraging the existing connection pool and message routing logic to propagate messages reliably across the network. """ - + def __init__(self, config): """ Initialize the Forwarder module. @@ -35,11 +35,12 @@ def __init__(self, config): self._cm = None self.pending_messages = asyncio.Queue() self.pending_messages_lock = Locker("pending_messages_lock", verbose=False, async_lock=True) + self._forwarder_task = None # Track the background task self.interval = self.config.participant["forwarder_args"]["forwarder_interval"] self.number_forwarded_messages = self.config.participant["forwarder_args"]["number_forwarded_messages"] self.messages_interval = self.config.participant["forwarder_args"]["forward_messages_interval"] - self.scenario_finished = asyncio.Event() + self._running = asyncio.Event() @property def cm(self): @@ -61,7 +62,8 @@ async def start(self): """ Start the forwarder by scheduling the forwarding loop as a background task. """ - asyncio.create_task(self.run_forwarder()) + self._running.set() + self._forwarder_task = asyncio.create_task(self.run_forwarder(), name="Forwarder_run_forwarder") async def run_forwarder(self): """ @@ -73,7 +75,7 @@ async def run_forwarder(self): if self.config.participant["scenario_args"]["federation"] == "CFL": logging.info("🔁 Federation is CFL. Forwarder is disabled...") return - while not self.scenario_finished.is_set(): + while await self.is_running(): # logging.debug(f"🔁 Pending messages: {self.pending_messages.qsize()}") start_time = time.time() await self.pending_messages_lock.acquire_async() @@ -81,9 +83,24 @@ async def run_forwarder(self): await self.pending_messages_lock.release_async() sleep_time = max(0, self.interval - (time.time() - start_time)) await asyncio.sleep(sleep_time) - - def shutdown(self): - self.scenario_finished.set() + + async def stop(self): + self._running.clear() + logging.info("🔁 Stopping Forwarder module...") + + # Cancel the background task + if self._forwarder_task and not self._forwarder_task.done(): + logging.info("🛑 Cancelling Forwarder background task...") + self._forwarder_task.cancel() + try: + await self._forwarder_task + except asyncio.CancelledError: + pass + self._forwarder_task = None + logging.info("🛑 Forwarder background task cancelled") + + async def is_running(self): + return self._running.is_set() async def process_pending_messages(self, messages_left): """ diff --git a/nebula/core/network/health.py b/nebula/core/network/health.py index 554467e34..8e86bc3c6 100755 --- a/nebula/core/network/health.py +++ b/nebula/core/network/health.py @@ -15,6 +15,7 @@ def __init__(self, addr, config): self.alive_interval = self.config.participant["health_args"]["send_alive_interval"] self.check_alive_interval = self.config.participant["health_args"]["check_alive_interval"] self.timeout = self.config.participant["health_args"]["alive_timeout"] + self._running = asyncio.Event() @property def cm(self): @@ -35,7 +36,7 @@ async def run_send_alive(self): # Set all connections to active at the beginning of the health module for conn in self.cm.connections.values(): conn.set_active(True) - while True: + while await self.is_running(): if len(self.cm.connections) > 0: message = self.cm.create_message("control", "alive", log="Alive message") current_connections = list(self.cm.connections.values()) @@ -51,7 +52,7 @@ async def run_send_alive(self): async def run_check_alive(self): await asyncio.sleep(self.config.participant["health_args"]["grace_time_health"] + self.check_alive_interval) - while True: + while await self.is_running(): if len(self.cm.connections) > 0: current_connections = list(self.cm.connections.values()) for conn in current_connections: @@ -69,3 +70,9 @@ async def alive(self, source): if conn.get_last_active() < current_time: logging.debug(f"🕒 Updating last active time for {source}") conn.set_active(True) + + async def is_running(self): + return self._running.is_set() + + async def stop(self): + self._running.clear() diff --git a/nebula/core/network/propagator.py b/nebula/core/network/propagator.py index 427869ae7..2f37c227b 100755 --- a/nebula/core/network/propagator.py +++ b/nebula/core/network/propagator.py @@ -21,7 +21,7 @@ class PropagationStrategy(ABC): Subclasses implement eligibility checks and payload preparation for sending model updates to specific nodes in the federation. """ - + @abstractmethod def is_node_eligible(self, node: str) -> bool: """ @@ -56,7 +56,7 @@ class InitialModelPropagation(PropagationStrategy): Sends a fresh model initialized by the trainer with a default weight. """ - + def __init__(self, aggregator: "Aggregator", trainer: "Lightning", engine: "Engine"): """ Args: @@ -111,7 +111,7 @@ class StableModelPropagation(PropagationStrategy): Sends the latest trained model to neighbors. """ - + def __init__(self, aggregator: "Aggregator", trainer: "Lightning", engine: "Engine"): """ Args: @@ -174,9 +174,10 @@ class Propagator: Designed to work asynchronously, ensuring timely and scalable message dissemination across dynamically changing network topologies. """ - + def __init__(self): self._cm = None + self._running = asyncio.Event() @property def cm(self): @@ -225,6 +226,7 @@ def start(self): indent=2, title="Propagator", ) + self._running.set() def get_round(self): """ @@ -386,3 +388,10 @@ async def get_model_information(self, dest_addr, strategy_id: str, init=False): return (serialized_model, rounds, self.get_round()) return None + + async def stop(self): + logging.info("🌐 Stopping Propagator module...") + self._running.clear() + + async def is_running(self): + return self._running.is_set() diff --git a/nebula/core/node.py b/nebula/core/node.py index a035c32e3..644f0234b 100755 --- a/nebula/core/node.py +++ b/nebula/core/node.py @@ -1,3 +1,4 @@ +import asyncio import os import random import sys @@ -37,8 +38,8 @@ from nebula.core.models.fashionmnist.mlp import FashionMNISTModelMLP from nebula.core.models.mnist.cnn import MNISTModelCNN from nebula.core.models.mnist.mlp import MNISTModelMLP -from nebula.core.role import Role from nebula.core.noderole import AggregatorNode, IdleNode, MaliciousNode, ServerNode, TrainerNode +from nebula.core.role import Role from nebula.core.training.lightning import Lightning from nebula.core.training.siamese import Siamese @@ -160,7 +161,7 @@ async def main(config): local_test_set_indices=dataset.local_test_indices, num_workers=num_workers, batch_size=batch_size, - samples_per_label = samples_per_label + samples_per_label=samples_per_label, ) trainer = None @@ -238,17 +239,9 @@ def randomize_value(value, variability): if __name__ == "__main__": config_path = str(sys.argv[1]) config = Config(entity="participant", participant_config_file=config_path) - if sys.platform == "win32" or config.participant["scenario_args"]["deployment"] == "docker": - import asyncio + try: asyncio.run(main(config), debug=False) - else: - try: - import uvloop - - uvloop.run(main(config), debug=False) - except ImportError: - logging.warning("uvloop not available, using default loop") - import asyncio - - asyncio.run(main(config), debug=False) + except Exception as e: + logging.exception(f"Error starting node {config.participant['device_args']['idx']}: {e}") + raise e diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index 6ea1c1a6f..1f1413861 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -46,7 +46,7 @@ class SANetwork(SAMComponent): Inherits from SAMComponent to participate in the broader Situational Awareness pipeline. """ - + NEIGHBOR_VERIFICATION_TIMEOUT = 30 def __init__(self, config): @@ -61,12 +61,16 @@ def __init__(self, config): self._sar = config["sar"] # sar self._addr = config["addr"] # addr self._neighbor_policy = factory_NeighborPolicy(self._neighbor_policy) - self._restructure_process_lock = Locker(name="restructure_process_lock") + self._restructure_process_lock = Locker(name="restructure_process_lock", async_lock=True) self._restructure_cooldown = 0 self._verbose = config["verbose"] # verbose self._cm = CommunicationsManager.get_instance() self._sa_network_agent = SANetworkAgent(self) + # Track verification tasks for proper cleanup during shutdown + self._verification_tasks = set() + self._verification_tasks_lock = asyncio.Lock() + @property def sar(self) -> SAReasoner: """SA Reasoner""" @@ -376,7 +380,7 @@ async def reconnect_to_federation(self): 4. Release the restructure lock. """ logging.info("Going to reconnect with federation...") - self._restructure_process_lock.acquire() + await self._restructure_process_lock.acquire_async() await self.cm.clear_restrictions() # If we got some refs, try to reconnect to them if len(await self.np.get_nodes_known()) > 0: @@ -389,7 +393,7 @@ async def reconnect_to_federation(self): if self._verbose: logging.info("Reconnecting | NO Addrs availables") await self.sar.sad.start_late_connection_process(connected=False, msg_type="discover_nodes") - self._restructure_process_lock.release() + await self._restructure_process_lock.release_async() async def upgrade_connection_robustness(self, possible_neighbors): """ @@ -404,7 +408,7 @@ async def upgrade_connection_robustness(self, possible_neighbors): Args: possible_neighbors (set): Addresses of candidate nodes for connection enhancement. """ - self._restructure_process_lock.acquire() + await self._restructure_process_lock.acquire_async() # If we got some refs, try to connect to them if possible_neighbors and len(possible_neighbors) > 0: if self._verbose: @@ -416,7 +420,7 @@ async def upgrade_connection_robustness(self, possible_neighbors): if self._verbose: logging.info("Reestructuring | NO Addrs availables") await self.sar.sad.start_late_connection_process(connected=True, msg_type="discover_nodes") - self._restructure_process_lock.release() + await self._restructure_process_lock.release_async() async def stop_connections_with_federation(self): """ @@ -456,7 +460,24 @@ async def verify_neighbors_stablished(self, nodes: set): if neighbors: nodes_to_forget.difference_update(neighbors) logging.info(f"Connections dont stablished: {nodes_to_forget}") - self.forget_nodes(nodes_to_forget) + await self.forget_nodes(nodes_to_forget) + + async def create_verification_task(self, nodes: set): + """ + Create and track a verification task for neighbor establishment. + + Args: + nodes (set): The set of node addresses for which connections were attempted. + + Returns: + asyncio.Task: The created verification task. + """ + verification_task = asyncio.create_task(self.verify_neighbors_stablished(nodes)) + + async with self._verification_tasks_lock: + self._verification_tasks.add(verification_task) + + return verification_task async def forget_nodes(self, nodes_to_forget): """ @@ -467,6 +488,35 @@ async def forget_nodes(self, nodes_to_forget): """ await self.np.forget_nodes(nodes_to_forget) + async def stop(self): + """ + Stop the SANetwork component by releasing locks and clearing any pending operations. + """ + logging.info("🛑 Stopping SANetwork...") + + # Cancel all verification tasks + async with self._verification_tasks_lock: + if self._verification_tasks: + tasks_to_cancel = [task for task in self._verification_tasks if not task.done()] + logging.info(f"🛑 Cancelling {len(tasks_to_cancel)} verification tasks...") + for task in tasks_to_cancel: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + self._verification_tasks.clear() + logging.info("🛑 All verification tasks cancelled") + + # Release any held locks + try: + if self._restructure_process_lock.locked(): + self._restructure_process_lock.release() + except Exception as e: + logging.warning(f"Error releasing restructure_process_lock: {e}") + + logging.info("✅ SANetwork stopped successfully") + """ ############################### # SA NETWORK AGENT # ############################### @@ -489,7 +539,9 @@ async def suggest_action(self, sac: SACommand): async def notify_all_suggestions_done(self, event_type): await SuggestionBuffer.get_instance().notify_all_suggestions_done_for_agent(self, event_type) - async def create_and_suggest_action(self, saca: SACommandAction, function: Callable = None, more_suggestions=False, *args): + async def create_and_suggest_action( + self, saca: SACommandAction, function: Callable = None, more_suggestions=False, *args + ): """ Create a situational awareness command based on the specified action and suggest it for arbitration. @@ -529,7 +581,7 @@ async def create_and_suggest_action(self, saca: SACommandAction, function: Calla sa_command_state = await sac.get_state_future() # By using 'await' we get future.set_result() if sa_command_state == SACommandState.EXECUTED: (nodes_to_forget,) = args - asyncio.create_task(self._san.verify_neighbors_stablished(nodes_to_forget)) + await self._san.create_verification_task(nodes_to_forget) elif saca == SACommandAction.RECONNECT: sac = factory_sa_command( "connectivity", SACommandAction.RECONNECT, self, "", SACommandPRIO.HIGH, True, function diff --git a/nebula/core/situationalawareness/awareness/sareasoner.py b/nebula/core/situationalawareness/awareness/sareasoner.py index 30108b124..40e6de94d 100644 --- a/nebula/core/situationalawareness/awareness/sareasoner.py +++ b/nebula/core/situationalawareness/awareness/sareasoner.py @@ -88,9 +88,9 @@ def __init__( self._config = copy.deepcopy(config.participant) self._addr = config.participant["network_args"]["addr"] self._topology = config.participant["mobility_args"]["topology_type"] - self._situational_awareness_network = None + self._situational_awareness_network: SANetwork | None = None self._situational_awareness_training = None - self._restructure_process_lock = Locker(name="restructure_process_lock") + self._restructure_process_lock = Locker(name="restructure_process_lock", async_lock=True) self._restructure_cooldown = 0 self._arbitrator_notification = asyncio.Event() self._suggestion_buffer = SuggestionBuffer(self._arbitrator_notification, verbose=True) @@ -99,11 +99,11 @@ def __init__( arb_pol = config.participant["situational_awareness"]["sa_reasoner"]["arbitration_policy"] self._arbitatrion_policy = factory_arbitration_policy(arb_pol, True) self._sa_components: dict[str, SAMComponent] = {} - self._sa_discovery: ISADiscovery = None + self._sa_discovery: ISADiscovery | None = None self._verbose = config.participant["situational_awareness"]["sa_reasoner"]["verbose"] @property - def san(self) -> SANetwork: + def san(self) -> SANetwork | None: """Situational Awareness Network""" return self._situational_awareness_network @@ -123,7 +123,7 @@ def ab(self): return self._arbitatrion_policy @property - def sad(self) -> ISADiscovery: + def sad(self) -> ISADiscovery | None: """SA Discovery""" return self._sa_discovery @@ -134,7 +134,7 @@ async def init(self, sa_discovery): Args: sa_discovery (ISADiscovery): The discovery component to coordinate with. """ - self._sa_discovery: ISADiscovery = sa_discovery + self._sa_discovery = sa_discovery await self._loading_sa_components() await EventManager.get_instance().subscribe_node_event(RoundEndEvent, self._process_round_end_event) await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) @@ -154,6 +154,8 @@ def is_additional_participant(self): """ def get_restructure_process_lock(self): + if self.san is None: + raise RuntimeError("Situational Awareness Network (san) is not initialized.") return self.san.get_restructure_process_lock() """ ############################### @@ -162,44 +164,18 @@ def get_restructure_process_lock(self): """ async def get_nodes_known(self, neighbors_too=False, neighbors_only=False): - """ - Retrieve the set of nodes known to the situational awareness reasoner. - - This may include additional metadata depending on the flags. - - Args: - neighbors_too (bool, optional): If True, include neighboring nodes in the result. Defaults to False. - neighbors_only (bool, optional): If True, return only neighbors. Defaults to False. - - Returns: - set: Identifiers of known nodes based on the provided filters. - """ + if self.san is None: + raise RuntimeError("Situational Awareness Network (san) is not initialized.") return await self.san.get_nodes_known(neighbors_too, neighbors_only) async def accept_connection(self, source, joining=False): - """ - Decide whether to accept a connection request from a source node. - - Delegates to the underlying reasoner logic to determine acceptance. - - Args: - source (str): The identifier or address of the requesting node. - joining (bool, optional): If True, this connection is part of a join operation. Defaults to False. - - Returns: - bool: True if the connection should be accepted, False otherwise. - """ + if self.san is None: + raise RuntimeError("Situational Awareness Network (san) is not initialized.") return await self.san.accept_connection(source, joining) async def get_actions(self): - """ - Retrieve the list of situational awareness actions available to execute. - - Delegates to the underlying reasoner component. - - Returns: - list: Action identifiers that the reasoner can perform. - """ + if self.san is None: + raise RuntimeError("Situational Awareness Network (san) is not initialized.") return await self.san.get_actions() """ ############################### @@ -381,13 +357,15 @@ async def _initialize_sa_components(self): await sacomp.init() def _load_minimal_requirement_config(self): - #self._config["situational_awareness"]["sa_reasoner"]["sa_network"]["addr"] = self._addr - #self._config["situational_awareness"]["sa_reasoner"]["sa_network"]["sar"] = self - self._config["situational_awareness"]["sa_reasoner"]["sa_network"]["strict_topology"] = self._config["situational_awareness"]["strict_topology"] - + # self._config["situational_awareness"]["sa_reasoner"]["sa_network"]["addr"] = self._addr + # self._config["situational_awareness"]["sa_reasoner"]["sa_network"]["sar"] = self + self._config["situational_awareness"]["sa_reasoner"]["sa_network"]["strict_topology"] = self._config[ + "situational_awareness" + ]["strict_topology"] + # SA Reasoner instance for all SA Reasoner Components sar_components: dict = self._config["situational_awareness"]["sa_reasoner"]["sar_components"] - for sar_comp in sar_components.keys(): + for sar_comp in sar_components: self._config["situational_awareness"]["sa_reasoner"][sar_comp]["sar"] = self self._config["situational_awareness"]["sa_reasoner"][sar_comp]["addr"] = self._addr @@ -397,3 +375,27 @@ async def _set_minimal_requirements(self): self._situational_awareness_network = self._sa_components["sanetwork"] else: raise ValueError("SA Network not found") + + async def stop(self): + """ + Stop the SAReasoner by stopping all SA components and clearing any pending operations. + """ + logging.info("🛑 Stopping SAReasoner...") + self._arbitrator_notification.set() + + # Stop all SA components + if self._sa_components: + for component_name, component in self._sa_components.items(): + try: + # Check if component has a stop method + stop_method = getattr(component, "stop", None) + if stop_method and callable(stop_method): + if asyncio.iscoroutinefunction(stop_method): + await stop_method() + else: + stop_method() + logging.info(f"✅ Stopped SA component: {component_name}") + except Exception as e: + logging.warning(f"Error stopping SA component {component_name}: {e}") + + logging.info("✅ SAReasoner stopped successfully") diff --git a/nebula/core/situationalawareness/discovery/federationconnector.py b/nebula/core/situationalawareness/discovery/federationconnector.py index 756636e77..8beb6ba52 100644 --- a/nebula/core/situationalawareness/discovery/federationconnector.py +++ b/nebula/core/situationalawareness/discovery/federationconnector.py @@ -37,28 +37,44 @@ class FederationConnector(ISADiscovery): """ def __init__(self, aditional_participant, selector, model_handler, engine: "Engine", verbose=False): + """ + Initialize the FederationConnector. + + Args: + aditional_participant (bool): Whether this is an additional participant. + selector: The candidate selector instance. + model_handler: The model handler instance. + engine (Engine): The main engine instance. + verbose (bool): Whether to enable verbose logging. + """ self._aditional_participant = aditional_participant self._selector = selector + self._model_handler = model_handler + self._engine = engine + self._verbose = verbose + self._sar = None + + # Locks for thread safety + self._update_neighbors_lock = Locker("update_neighbors_lock", async_lock=True) + self.pending_confirmation_from_nodes_lock = Locker("pending_confirmation_from_nodes_lock", async_lock=True) + self.discarded_offers_addr_lock = Locker("discarded_offers_addr_lock", async_lock=True) + self.accept_candidates_lock = Locker("accept_candidates_lock", async_lock=True) + self.late_connection_process_lock = Locker("late_connection_process_lock", async_lock=True) + + # Data structures + self.pending_confirmation_from_nodes = set() + self.discarded_offers_addr = [] + self._background_tasks = [] # Track background tasks + print_msg_box(msg="Starting FederationConnector module...", indent=2, title="FederationConnector module") logging.info("🌐 Initializing Federation Connector") - self._engine = engine self._cm = None self.config = engine.get_config() logging.info("Initializing Candidate Selector") self._candidate_selector = factory_CandidateSelector(self._selector) logging.info("Initializing Model Handler") self._model_handler = factory_ModelHandler(model_handler) - self._update_neighbors_lock = Locker(name="_update_neighbors_lock", async_lock=True) - self.late_connection_process_lock = Locker(name="late_connection_process_lock") - self.pending_confirmation_from_nodes = set() - self.pending_confirmation_from_nodes_lock = Locker(name="pending_confirmation_from_nodes_lock", async_lock=True) - self.accept_candidates_lock = Locker(name="accept_candidates_lock") self.recieve_offer_timer = OFFER_TIMEOUT - self.discarded_offers_addr_lock = Locker(name="discarded_offers_addr_lock", async_lock=True) - self.discarded_offers_addr = [] - - self._sa_reasoner: ISAReasoner = None - self._verbose = verbose @property def engine(self): @@ -83,7 +99,7 @@ def model_handler(self): @property def sar(self): """Situational Awareness Reasoner""" - return self._sa_reasoner + return self._sar async def init(self, sa_reasoner): """ @@ -107,7 +123,7 @@ async def init(self, sa_reasoner): sa_reasoner (ISAReasoner): An instance of the situational awareness reasoner used for decision-making. """ logging.info("Building Federation Connector configurations...") - self._sa_reasoner: ISAReasoner = sa_reasoner + self._sar: ISAReasoner = sa_reasoner await self._register_message_events_callbacks() await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self._update_neighbors) await EventManager.get_instance().subscribe(("model", "update"), self._model_update_callback) @@ -146,10 +162,10 @@ def _still_waiting_for_candidates(self): async def _add_pending_connection_confirmation(self, addr): """ - Adds a node to the set of pending connection confirmations if it's not already known or pending. + Adds a node to the pending confirmation set and schedules a cleanup task. Args: - addr (str): The address of the node to track for pending confirmation. + addr (str): Address of the node to add to pending confirmations. """ added = False async with self._update_neighbors_lock: @@ -160,7 +176,10 @@ async def _add_pending_connection_confirmation(self, addr): self.pending_confirmation_from_nodes.add(addr) added = True if added: - asyncio.create_task(self._clear_pending_confirmations(node=addr)) + task = asyncio.create_task( + self._clear_pending_confirmations(node=addr), name=f"FederationConnector_clear_pending_{addr}" + ) + self._background_tasks.append(task) async def _remove_pending_confirmation_from(self, addr): """ @@ -362,7 +381,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove to nearby nodes. Nodes that receive this message respond with an `OFFER_MODEL` or `OFFER_METRIC` message, which contains the necessary information to evaluate and select the most suitable candidates. - The process is protected by locks to avoid race conditions, and it continues iteratively until at least + The process is protected by locks to avoid race conditions, and it continues iteratively until at least one valid candidate is found. Once candidates are selected, a connection message is sent to the best nodes. Args: @@ -377,7 +396,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove """ logging.info("🌐 Initializing late connection process..") - self.late_connection_process_lock.acquire() + await self.late_connection_process_lock.acquire_async() best_candidates = [] await self.candidate_selector.remove_candidates() @@ -393,7 +412,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove await asyncio.sleep(self.recieve_offer_timer) # acquire lock to not accept late candidates - self.accept_candidates_lock.acquire() + await self.accept_candidates_lock.acquire_async() if await self.candidate_selector.any_candidate(): if self._verbose: @@ -415,19 +434,20 @@ async def start_late_connection_process(self, connected=False, msg_type="discove if self._verbose: logging.info("Error during stablishment") - self.accept_candidates_lock.release() - self.late_connection_process_lock.release() + await self.accept_candidates_lock.release_async() + await self.late_connection_process_lock.release_async() await self.candidate_selector.remove_candidates() logging.info("🌐 Ending late connection process..") # if no candidates, repeat process else: if self._verbose: logging.info("❗️ No Candidates found...") - self.accept_candidates_lock.release() - self.late_connection_process_lock.release() + await self.accept_candidates_lock.release_async() + await self.late_connection_process_lock.release_async() if not connected: - if self._verbose: logging.info("❗️ repeating process...") - await self.start_late_connection_process(connected, msg_type, addrs_known) + if self._verbose: + logging.info("❗️ repeating process...") + await self.start_late_connection_process(connected, msg_type, addrs_known) """ ############################## # Mobility callbacks # @@ -615,6 +635,40 @@ async def _link_connect_to_callback(self, source, message): async def _link_disconnect_from_callback(self, source, message): logging.info(f"🔗 handle_link_message | Trigger | Received disconnect_from message from {source}") - addrs = message.addrs - for addr in addrs.split(): - await asyncio.create_task(self.cm.disconnect(addr, mutual_disconnection=False)) + for addr in message.addrs.split(): + await self.cm.disconnect(addr, mutual_disconnection=True) + + async def stop(self): + """ + Stop the FederationConnector by clearing pending confirmations and stopping background tasks. + """ + logging.info("🛑 Stopping FederationConnector...") + + # Cancel all background tasks + if self._background_tasks: + logging.info(f"🛑 Cancelling {len(self._background_tasks)} background tasks...") + for task in self._background_tasks: + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + self._background_tasks.clear() + logging.info("🛑 All background tasks cancelled") + + # Clear any pending confirmations + try: + async with self.pending_confirmation_from_nodes_lock: + self.pending_confirmation_from_nodes.clear() + except Exception as e: + logging.warning(f"Error clearing pending confirmations: {e}") + + # Clear discarded offers + try: + async with self.discarded_offers_addr_lock: + self.discarded_offers_addr.clear() + except Exception as e: + logging.warning(f"Error clearing discarded offers: {e}") + + logging.info("✅ FederationConnector stopped successfully") diff --git a/nebula/core/situationalawareness/situationalawareness.py b/nebula/core/situationalawareness/situationalawareness.py index c5bcdc845..6a5dbcbd6 100644 --- a/nebula/core/situationalawareness/situationalawareness.py +++ b/nebula/core/situationalawareness/situationalawareness.py @@ -1,3 +1,4 @@ +import asyncio from abc import ABC, abstractmethod from nebula.addons.functions import print_msg_box @@ -10,7 +11,7 @@ class ISADiscovery(ABC): Defines methods for initializing discovery, handling late connection processes, and retrieving training-related information. """ - + @abstractmethod async def init(self, sa_reasoner): """ @@ -51,7 +52,7 @@ class ISAReasoner(ABC): Defines methods for initializing the reasoner, accepting or rejecting connections, and querying known nodes and available actions. """ - + @abstractmethod async def init(self, sa_discovery): """ @@ -162,7 +163,7 @@ class SituationalAwareness: Manages discovery and reasoning components, wiring them together and exposing simple methods for initialization and late-connection handling. """ - + def __init__(self, config, engine): """ Initialize Situational Awareness module by creating discovery and reasoner instances. @@ -234,3 +235,20 @@ async def get_trainning_info(self): Any: Information relevant to training decisions. """ return await self.sad.get_trainning_info() + + async def stop(self): + """ + Stop both discovery and reasoner components if they implement a stop method. + """ + sad_stop = getattr(self.sad, "stop", None) + if callable(sad_stop): + if asyncio.iscoroutinefunction(sad_stop): + await sad_stop() + else: + sad_stop() + sar_stop = getattr(self.sar, "stop", None) + if callable(sar_stop): + if asyncio.iscoroutinefunction(sar_stop): + await sar_stop() + else: + sar_stop() diff --git a/nebula/frontend/static/js/monitor/monitor.js b/nebula/frontend/static/js/monitor/monitor.js index 9d460c16f..45a9e99ad 100644 --- a/nebula/frontend/static/js/monitor/monitor.js +++ b/nebula/frontend/static/js/monitor/monitor.js @@ -702,9 +702,6 @@ class Monitor { neighbors_distance: data.neighbors_distance || {} }); - // Check if all nodes are offline - this.checkAllNodesOffline(); - this.log('Node update completed successfully'); } catch (error) { this.error('Error handling node update:', error); @@ -1466,7 +1463,6 @@ class Monitor { startPeriodicStatusCheck() { this.log("Starting periodic status check"); this.checkNodeStatus(); - setInterval(() => this.checkNodeStatus(), 5000); } @@ -1490,14 +1486,22 @@ class Monitor { return; } + this.log('Request to monitor API successful'); + const data = await response.json(); + + // Update scenario status if provided + if (data.scenario_status) { + this.updateScenarioStatusBadge(data.scenario_status); + } + + // Process nodes as before if (!data.nodes || data.nodes.length === 0) { this.warn('No nodes in status check response'); return; } this.log('Received status check data:', data); - // Create a Set to track processed nodes in this status check const processedNodes = new Set(); @@ -1562,13 +1566,51 @@ class Monitor { this.updateAllMarkers(); this.updateAllRelatedLines(); - // Check if all nodes are offline - this.checkAllNodesOffline(); } catch (error) { this.error('Error in status check:', error); } } + updateScenarioStatusBadge(status) { + const statusBadge = document.getElementById('scenario_status'); + if (!statusBadge) return; + + // Update the data attribute + statusBadge.setAttribute('data-scenario-status', status); + + // Update the badge appearance based on status + switch (status) { + case 'running': + statusBadge.className = 'badge bg-warning-subtle text-warning px-3 py-2 ms-3'; + statusBadge.innerHTML = 'Running'; + break; + case 'completed': + statusBadge.className = 'badge bg-success-subtle text-success px-3 py-2 ms-3'; + statusBadge.innerHTML = 'Completed'; + // Hide stop button when completed + const stopButton = document.getElementById('stop_button'); + if (stopButton) { + stopButton.style.display = 'none'; + } + break; + case 'finished': + statusBadge.className = 'badge bg-danger-subtle text-danger px-3 py-2 ms-3'; + statusBadge.innerHTML = 'Finished'; + // Hide stop button when finished + const stopButtonFinished = document.getElementById('stop_button'); + if (stopButtonFinished) { + stopButtonFinished.style.display = 'none'; + } + break; + case 'not exists': + statusBadge.className = 'badge bg-secondary-subtle text-secondary px-3 py-2 ms-3'; + statusBadge.innerHTML = 'Not Found'; + break; + default: + this.warn('Unknown scenario status:', status); + } + } + updateTableRow(data) { // Validate required data if (!data || !data.uid) { @@ -1795,34 +1837,6 @@ class Monitor { }); } - checkAllNodesOffline() { - // Get all unique node IPs from markers - const allNodeIPs = new Set(Object.values(this.droneMarkers).map(marker => marker.ip)); - - // Check if all nodes are in the offlineNodes set - const allOffline = allNodeIPs.size > 0 && Array.from(allNodeIPs).every(ip => this.offlineNodes.has(ip)); - - // Update scenario status badge - const statusBadge = document.getElementById('scenario_status'); - if (statusBadge) { - if (allNodeIPs.size === 0) { - // Show Running status when there are no nodes - statusBadge.className = 'badge bg-warning-subtle text-warning px-3 py-2 ms-3'; - statusBadge.innerHTML = 'Running'; - } else if (allOffline) { - statusBadge.className = 'badge bg-danger-subtle text-danger px-3 py-2 ms-3'; - statusBadge.innerHTML = 'Finished'; - const stopButton = document.getElementById('stop_button'); - if (stopButton) { - stopButton.style.display = 'none'; - } - } else { - statusBadge.className = 'badge bg-warning-subtle text-warning px-3 py-2 ms-3'; - statusBadge.innerHTML = 'Running'; - } - } - } - // Helper method to compare two sets areSetsEqual(a, b) { if (a.size !== b.size) return false; diff --git a/nebula/frontend/templates/dashboard.html b/nebula/frontend/templates/dashboard.html index db53589ec..ac6b8a051 100755 --- a/nebula/frontend/templates/dashboard.html +++ b/nebula/frontend/templates/dashboard.html @@ -227,16 +227,7 @@
Status
class="btn btn-sm btn-outline-danger" title="Stop scenario"> - {% elif scenario.status == "completed" %} - - - - - - - {% else %} + {% elif scenario.status == "completed" or scenario.status == "finished" %} + {% else %} + + + + + + {% endif %} @@ -311,4 +311,4 @@
Status
{% endif %} -{% endblock %} \ No newline at end of file +{% endblock %} diff --git a/nebula/frontend/templates/monitor.html b/nebula/frontend/templates/monitor.html index 941444df7..c7cc61d83 100755 --- a/nebula/frontend/templates/monitor.html +++ b/nebula/frontend/templates/monitor.html @@ -36,15 +36,15 @@
Status
{% if scenario.status == "running" %} - + Running {% elif scenario.status == "completed" %} - + Completed {% else %} - + Finished {% endif %} From 8522c7d24f2546f0e5c47ec9748d617541574819 Mon Sep 17 00:00:00 2001 From: enriquetomasmb Date: Thu, 19 Jun 2025 12:11:31 +0200 Subject: [PATCH 12/12] improve error management in async functions --- nebula/core/engine.py | 20 +++++++- nebula/core/network/connection.py | 76 +++++++++++++++---------------- nebula/core/network/forwarder.py | 19 ++++---- nebula/core/node.py | 7 +++ 4 files changed, 74 insertions(+), 48 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 9ddd9609b..a88ca1ddb 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -873,6 +873,11 @@ async def shutdown(self): # Wait for tasks to complete naturally with shorter timeout try: await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=3) + except asyncio.CancelledError: + logging.warning( + "Timeout reached during task cleanup (CancelledError); proceeding with shutdown anyway." + ) + # Do not re-raise, just continue except TimeoutError: logging.warning("Some tasks did not complete in time, forcing cancellation...") for task in tasks: @@ -881,9 +886,13 @@ async def shutdown(self): # Wait a bit more for cancellations to take effect try: await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=2) + except asyncio.CancelledError: + logging.warning( + "Timeout reached during forced cancellation (CancelledError); proceeding with shutdown anyway." + ) + # Do not re-raise, just continue except TimeoutError: logging.warning("Some tasks still not responding to cancellation") - # Final aggressive cleanup - cancel all remaining tasks remaining_tasks = [ t for t in asyncio.all_tasks() if t is not asyncio.current_task() and not t.done() @@ -894,8 +903,17 @@ async def shutdown(self): task.cancel() try: await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=1) + except asyncio.CancelledError: + logging.warning( + "Timeout reached during final forced cancellation (CancelledError); proceeding with shutdown anyway." + ) + # Do not re-raise, just continue except TimeoutError: logging.exception("Some tasks still not responding to forced cancellation") + # Proceed anyway after all cancellation attempts + logging.warning("Proceeding with shutdown even if some tasks are still pending/cancelled.") + else: + logging.info("No remaining tasks to clean up.") logging.info("✅ Engine shutdown complete") diff --git a/nebula/core/network/connection.py b/nebula/core/network/connection.py index 3932a2bb9..39c105656 100755 --- a/nebula/core/network/connection.py +++ b/nebula/core/network/connection.py @@ -169,19 +169,23 @@ async def _monitor_inactivity(self): periodically checking if the last activity exceeds the inactivity threshold. If inactive, marks the connection as inactive and logs a warning. """ - while await self.is_running(): - if self.direct: - break - await asyncio.sleep(self.INACTIVITY_DAEMON_SLEEP_TIME) - async with self._activity_lock: - time_since_last = time.time() - self._last_activity - if time_since_last > self.INACTIVITY_TIMER: - if not self._inactivity: - self._inactivity = True - logging.info(f"[{self}] Connection marked as inactive.") - else: - if self._inactivity: - self._inactivity = False + try: + while await self.is_running(): + if self.direct: + break + await asyncio.sleep(self.INACTIVITY_DAEMON_SLEEP_TIME) + async with self._activity_lock: + time_since_last = time.time() - self._last_activity + if time_since_last > self.INACTIVITY_TIMER: + if not self._inactivity: + self._inactivity = True + logging.info(f"[{self}] Connection marked as inactive.") + else: + if self._inactivity: + self._inactivity = False + except asyncio.CancelledError: + logging.info("_monitor_inactivity cancelled during shutdown.") + return def get_federated_round(self): return self.federated_round @@ -482,21 +486,19 @@ async def handle_incoming_message(self) -> None: try: while await self.is_running(): if self.pending_messages_queue.full(): - await asyncio.sleep(0.1) # Wait a bit if the queue is full to create backpressure + await asyncio.sleep(0.1) continue header = await self._read_exactly(self.HEADER_SIZE) message_id, chunk_index, is_last_chunk = self._parse_header(header) - chunk_data = await self._read_chunk(reusable_buffer) await self._update_activity() self._store_chunk(message_id, chunk_index, chunk_data, is_last_chunk) - # logging.debug(f"Received chunk {chunk_index} of message {message_id.hex()} | size: {len(chunk_data)} bytes") - # Active connection without fails self.incompleted_reconnections = 0 if is_last_chunk: await self._process_complete_message(message_id) - except asyncio.CancelledError as e: - logging.exception(f"Message handling cancelled: {e}") + except asyncio.CancelledError: + logging.info("handle_incoming_message cancelled during shutdown.") + return except ConnectionError as e: logging.exception(f"Connection closed while reading: {e}") except Exception as e: @@ -678,27 +680,23 @@ def _decompress(self, data: bytes, compression: str) -> bytes | None: async def process_message_queue(self) -> None: """ Continuously processes messages from the pending queue. - - Behavior: - - Retrieves messages from the queue one by one. - - Delegates the message to the appropriate handler based on its type. - - Ensures the queue is marked as processed. - - Notes: - Runs indefinitely unless externally cancelled or stopped. """ - while await self.is_running(): - try: - if self.pending_messages_queue is None: - logging.error("Pending messages queue is not initialized") - return - data_type_prefix, message = await self.pending_messages_queue.get() - await self._handle_message(data_type_prefix, message) - self.pending_messages_queue.task_done() - except Exception as e: - logging.exception(f"Error processing message queue: {e}") - finally: - await asyncio.sleep(0) + try: + while await self.is_running(): + try: + if self.pending_messages_queue is None: + logging.error("Pending messages queue is not initialized") + return + data_type_prefix, message = await self.pending_messages_queue.get() + await self._handle_message(data_type_prefix, message) + self.pending_messages_queue.task_done() + except Exception as e: + logging.exception(f"Error processing message queue: {e}") + finally: + await asyncio.sleep(0) + except asyncio.CancelledError: + logging.info("process_message_queue cancelled during shutdown.") + return async def _handle_message(self, data_type_prefix: bytes, message: bytes) -> None: """ diff --git a/nebula/core/network/forwarder.py b/nebula/core/network/forwarder.py index 6b3dae7eb..86ce75536 100755 --- a/nebula/core/network/forwarder.py +++ b/nebula/core/network/forwarder.py @@ -75,14 +75,17 @@ async def run_forwarder(self): if self.config.participant["scenario_args"]["federation"] == "CFL": logging.info("🔁 Federation is CFL. Forwarder is disabled...") return - while await self.is_running(): - # logging.debug(f"🔁 Pending messages: {self.pending_messages.qsize()}") - start_time = time.time() - await self.pending_messages_lock.acquire_async() - await self.process_pending_messages(messages_left=self.number_forwarded_messages) - await self.pending_messages_lock.release_async() - sleep_time = max(0, self.interval - (time.time() - start_time)) - await asyncio.sleep(sleep_time) + try: + while await self.is_running(): + start_time = time.time() + await self.pending_messages_lock.acquire_async() + await self.process_pending_messages(messages_left=self.number_forwarded_messages) + await self.pending_messages_lock.release_async() + sleep_time = max(0, self.interval - (time.time() - start_time)) + await asyncio.sleep(sleep_time) + except asyncio.CancelledError: + logging.info("run_forwarder cancelled during shutdown.") + return async def stop(self): self._running.clear() diff --git a/nebula/core/node.py b/nebula/core/node.py index 644f0234b..267e1b28b 100755 --- a/nebula/core/node.py +++ b/nebula/core/node.py @@ -235,6 +235,13 @@ def randomize_value(value, variability): if node.cm is not None: await node.cm.network_wait() + # Ensure shutdown is always called and awaited before main() returns + if hasattr(node, "shutdown") and callable(node.shutdown): + logging.info("Calling node.shutdown() for final cleanup and Docker removal...") + await node.shutdown() + else: + logging.warning("Node does not have a shutdown() method; skipping explicit shutdown.") + if __name__ == "__main__": config_path = str(sys.argv[1])