From 83c3936f2a658e618c8e838775bf4282e280d0e9 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 3 Jul 2024 19:59:51 +0200 Subject: [PATCH 001/233] =?UTF-8?q?Servicio=20de=20conexi=C3=B3n=20UDP?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ImplementaciΓ³n de servicio cliente/servidor para que dispositivos ajenos a la federaciΓ³n entablen comunicaciΓ³n --- nebula/core/network/communications.py | 30 +++ .../core/network/externalconnectionservice.py | 15 ++ nebula/core/network/nebulaconnection.py | 195 ++++++++++++++++++ 3 files changed, 240 insertions(+) create mode 100644 nebula/core/network/externalconnectionservice.py create mode 100644 nebula/core/network/nebulaconnection.py diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 85203c21a..9966ef6d2 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -17,6 +17,7 @@ from nebula.core.pb import nebula_pb2 from nebula.core.network.messages import MessagesManager from nebula.core.network.connection import Connection +from nebula.core.network.nebulaconnection import NebulaConnectionService from nebula.core.utils.locker import Locker from nebula.core.utils.helper import ( @@ -71,6 +72,14 @@ def __init__(self, engine: "Engine"): self.connections_reconnect = [] self.max_connections = 1000 self.network_engine = None + + # Connection service to communicate with external devices + self._external_connection_service = None + + # The line below is neccesary when mobility would be set up + #if self.config.participant["mobility_args"]["mobility"] and not self.config.participant["mobility_args"]["late_creation"]: + self._external_connection_service = NebulaConnectionService(self.addr) + self.ecs.start() self.stop_network_engine = asyncio.Event() @@ -105,6 +114,10 @@ def propagator(self): @property def mobility(self): return self._mobility + + @property + def ecs(self): + return 
self._external_connection_service async def handle_incoming_message(self, data, addr_from): try: @@ -227,6 +240,23 @@ async def handle_connection_message(self, source, message): except Exception as e: logging.error(f"πŸ”— handle_connection_message | Error while processing: {message.action} | {e}") + def _start_external_connection_service(self): + self.ecs = NebulaConnectionService(self.addr) + self.ecs.start() + + async def establish_connection_with_federation(self, message): + """ + Using ExternalConnectionService to get addrs on local network, after that + stablishment of TCP connection and send the message broadcasted + + Args: + message to be sent + """ + addrs = self.ecs.find_federation() + for addr in addrs: + await self.cm.connect(addr, direct=False) + await self.send_message(addr, message) + def get_connections_lock(self): return self.connections_lock diff --git a/nebula/core/network/externalconnectionservice.py b/nebula/core/network/externalconnectionservice.py new file mode 100644 index 000000000..90b8d4ff1 --- /dev/null +++ b/nebula/core/network/externalconnectionservice.py @@ -0,0 +1,15 @@ +from abc import ABC, abstractmethod + +class ExternalConnectionService(ABC): + + @abstractmethod + def start(self): + pass + + @abstractmethod + def stop(self): + pass + + @abstractmethod + def find_federation(self): + pass \ No newline at end of file diff --git a/nebula/core/network/nebulaconnection.py b/nebula/core/network/nebulaconnection.py new file mode 100644 index 000000000..8a0b0a196 --- /dev/null +++ b/nebula/core/network/nebulaconnection.py @@ -0,0 +1,195 @@ + +import os +import socket +import sys +import platform +import time +import threading +import logging +from nebula.core.utils.locker import Locker +from nebula.core.network.externalconnectionservice import ExternalConnectionService + +class NebulaServer(threading.Thread): + + BCAST_IP = '239.255.255.250' + UPNP_PORT = 1900 + IP = '0.0.0.0' + M_SEARCH_REQ_MATCH = "M-SEARCH" + + def __init__(self, 
nebula_service: "NebulaConnectionService", addr): + threading.Thread.__init__(self) + self.interrupted = False + self.ns = nebula_service + self.addr = addr + + def run(self): + self.listen() + + def stop(self): + self.interrupted = True + logging.info("Nebula upnp server stop") + + def listen(self): + """ + Listen on broadcast addr with standard 1900 port + It will reponse a standard ssdp message with blockchain ip and port info if receive a M_SEARCH message + """ + try: + macro = socket.SO_REUSEPORT + os_name = platform.system() + if os_name == "Windows": + macro = socket.SO_REUSEADDR + + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.setsockopt(socket.SOL_SOCKET, macro, 1) + sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(self.BCAST_IP) + socket.inet_aton(self.IP)) + sock.bind((self.IP, self.UPNP_PORT)) + sock.settimeout(1) + logging.info("Nebula upnp server is listening...") + while True: + try: + data, addr = sock.recvfrom(1024) + except socket.error: + if self.interrupted: + sock.close() + return + else: + if self._is_nebula_message(data): + self.respond(addr) + time.sleep(1) + self.stop() + except Exception as e: + logging.info('Error in Nebula npnp server listening: %s', e) + + def _is_nebula_message(self, msg): + msg_str = msg.decode('utf-8') + return "ST: urn:nebula-service" in msg_str + + def respond(self, addr): + try: + #local_ip = # FIND THE IP + UPNP_RESPOND = """HTTP/1.1 200 OK + CACHE-CONTROL: max-age=1800 + ST: urn:nebula-service + EXT: + LOCATION: {} + """.format( + self.addr + ).replace("\n", "\r\n") + outSock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + outSock.sendto(UPNP_RESPOND.encode('ASCII'), addr) + outSock.close() + except Exception as e: + logging.info('Error in Nebula upnp response message to client %s', e) + +class NebulaClient(threading.Thread): + # 30 seconds for search_interval + SEARCH_INTERVAL = 5 + BCAST_IP = '239.255.255.250' + BCAST_PORT = 1900 + + def __init__(self, 
nebula_service: "NebulaConnectionService"): + threading.Thread.__init__(self) + self.interrupted = False + self.ns = nebula_service + + def run(self): + self.keep_search() + + def stop(self): + self.interrupted = True + logging.info(" Nebula upnp client stop") + + def keep_search(self): + """ + run search function every SEARCH_INTERVAL + """ + try: + while True: + self.search() + for x in range(self.SEARCH_INTERVAL): + time.sleep(1) + if self.interrupted: + return + except Exception as e: + logging.info('Error in Nebula upnp client keep search %s', e) + + def search(self): + """ + broadcast SSDP DISCOVER message to LAN network + filter our protocal and add to network + """ + try: + SSDP_DISCOVER = ('M-SEARCH * HTTP/1.1\r\n' + + 'HOST: 239.255.255.250:1900\r\n' + + 'MAN: "ssdp:discover"\r\n' + + 'MX: 1\r\n' + + 'ST: urn:nebula-service\r\n' + + '\r\n') + + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.sendto(SSDP_DISCOVER.encode('ASCII'), (self.BCAST_IP, self.BCAST_PORT)) + sock.settimeout(3) + while True: + data, addr = sock.recvfrom(1024) + if self._is_nebula_message(data): + self.ns.response_recieved(data, addr) + except: + sock.close() + + def _is_nebula_message(self, msg): + msg_str = msg.decode('utf-8') + return "ST: urn:nebula-service" in msg_str + +class NebulaConnectionService(ExternalConnectionService): + + def __init__(self, addr): + self.addrs_found_lock = Locker(name="addrs_found_lock") + self.nodes_found = [] + self.repeatsearch_interval = 3 + self.addr = addr + self.server = None + self.client = None + + def start(self): + self.server = NebulaServer(self, self.addr) + self.server.start() + + def stop(self): + self.server.stop + + def find_federation(self): + """ + Initialization client thread to send broadcast discover to federation + """ + logging.info(f"Node {self.addr} trying to find federation..") + self.nodes_found = [] + self.client = NebulaClient(self) + self.client.start() + time.sleep(self.repeatsearch_interval) + while not 
len(self.get_nodes()): + time.sleep(self.repeatsearch_interval) + self.client.stop() + + def response_recieved(self, data, addr): + print("NebulaMulticastingService: Response recieved") + msg_str = data.decode('utf-8') + self._add_addr(msg_str) + + def _add_addr(self, msg_str): + self.mutex.acquire() + lineas = msg_str.splitlines() + # Buscar la lΓ­nea que contiene "LOCATION: " + for linea in lineas: + if linea.strip().startswith("LOCATION:"): + addr = linea.split(": ")[1].strip() + break + self.nodes_found.append(addr) + self.mutex.release() + + def get_nodes(self): + self.mutex.acquire() + cp = self.nodes_found.copy() + self.mutex.release() + return cp + \ No newline at end of file From cbeb58fe3d5311dbf0e622e9a581d80fab7a0dca Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 3 Jul 2024 22:13:11 +0200 Subject: [PATCH 002/233] fix_filename Changed file name: nebulamulticasting from nebulaconnection --- nebula/core/network/communications.py | 2 +- .../core/network/{nebulaconnection.py => nebulamulticasting.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename nebula/core/network/{nebulaconnection.py => nebulamulticasting.py} (100%) diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 9966ef6d2..f5e7d3f68 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -17,7 +17,7 @@ from nebula.core.pb import nebula_pb2 from nebula.core.network.messages import MessagesManager from nebula.core.network.connection import Connection -from nebula.core.network.nebulaconnection import NebulaConnectionService +from nebula.core.network.nebulamulticasting import NebulaConnectionService from nebula.core.utils.locker import Locker from nebula.core.utils.helper import ( diff --git a/nebula/core/network/nebulaconnection.py b/nebula/core/network/nebulamulticasting.py similarity index 100% rename from nebula/core/network/nebulaconnection.py rename to 
nebula/core/network/nebulamulticasting.py From 3e9ad7b7e94349e0cb0ba5767863522542cb2f27 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sat, 6 Jul 2024 11:30:02 +0200 Subject: [PATCH 003/233] fix_nebulamulticasting Additions: -communication - init_ecs - stop_ecs Remove: -nebulamulticasting -stop condition on nebula_server --- nebula/core/network/communications.py | 7 +++++++ nebula/core/network/nebulamulticasting.py | 5 ++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index f5e7d3f68..8953800ec 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -244,6 +244,13 @@ def _start_external_connection_service(self): self.ecs = NebulaConnectionService(self.addr) self.ecs.start() + def stop_external_connection_service(self): + self.ecs.stop() + + def init_external_connection_service(self): + self.ecs = NebulaConnectionService(self.addr) + self.start_external_connection_service() + async def establish_connection_with_federation(self, message): """ Using ExternalConnectionService to get addrs on local network, after that diff --git a/nebula/core/network/nebulamulticasting.py b/nebula/core/network/nebulamulticasting.py index 8a0b0a196..0612a0789 100644 --- a/nebula/core/network/nebulamulticasting.py +++ b/nebula/core/network/nebulamulticasting.py @@ -56,8 +56,8 @@ def listen(self): else: if self._is_nebula_message(data): self.respond(addr) - time.sleep(1) - self.stop() + #time.sleep(1) + #self.stop() except Exception as e: logging.info('Error in Nebula npnp server listening: %s', e) @@ -172,7 +172,6 @@ def find_federation(self): self.client.stop() def response_recieved(self, data, addr): - print("NebulaMulticastingService: Response recieved") msg_str = data.decode('utf-8') self._add_addr(msg_str) From 7c039a3375ae34e44f30020d95ff79bdbe496c77 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 10 Jul 2024 17:22:17 +0200 
Subject: [PATCH 004/233] feat_mobility_module Additions: -neighbormanagement folder -mobility messages -proto files updated -functionalities on comunnications and engine to new messages - get_loss on nebula model - model_info on propagator --- nebula/core/engine.py | 161 +++++++++++- nebula/core/models/nebulamodel.py | 6 + nebula/core/neighbormanagement/README.txt | 64 +++++ nebula/core/neighbormanagement/__init__.py | 0 .../candidateselection/__init__.py | 0 .../candidateselection/candidateselector.py | 37 +++ .../candidateselection/fccandidateselector.py | 36 +++ .../hetcandidateselector.py | 109 +++++++++ .../ringcandidateselector.py | 37 +++ .../modelhandlers/__init__.py | 0 .../modelhandlers/aggmodelhandler.py | 52 ++++ .../modelhandlers/modelhandler.py | 31 +++ .../modelhandlers/stdmodelhandler.py | 50 ++++ .../neighborpolicies/__init__.py | 0 .../neighborpolicies/fcneighborpolicy.py | 98 ++++++++ .../neighborpolicies/idleneighborpolicy.py | 30 +++ .../neighborpolicies/neighborpolicy.py | 51 ++++ .../neighborpolicies/ringneighborpolicy.py | 94 +++++++ .../neighborpolicies/starneighborpolicy.py | 79 ++++++ nebula/core/neighbormanagement/nodemanager.py | 231 ++++++++++++++++++ nebula/core/network/communications.py | 38 ++- nebula/core/network/messages.py | 39 +++ nebula/core/network/propagator.py | 22 ++ nebula/core/pb/nebula.proto | 36 +++ nebula/core/pb/nebula_pb2.py | 81 +++--- .../frontend/config/participant.json.example | 10 +- 26 files changed, 1354 insertions(+), 38 deletions(-) create mode 100644 nebula/core/neighbormanagement/README.txt create mode 100644 nebula/core/neighbormanagement/__init__.py create mode 100644 nebula/core/neighbormanagement/candidateselection/__init__.py create mode 100644 nebula/core/neighbormanagement/candidateselection/candidateselector.py create mode 100644 nebula/core/neighbormanagement/candidateselection/fccandidateselector.py create mode 100644 nebula/core/neighbormanagement/candidateselection/hetcandidateselector.py 
create mode 100644 nebula/core/neighbormanagement/candidateselection/ringcandidateselector.py create mode 100644 nebula/core/neighbormanagement/modelhandlers/__init__.py create mode 100644 nebula/core/neighbormanagement/modelhandlers/aggmodelhandler.py create mode 100644 nebula/core/neighbormanagement/modelhandlers/modelhandler.py create mode 100644 nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py create mode 100644 nebula/core/neighbormanagement/neighborpolicies/__init__.py create mode 100644 nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py create mode 100644 nebula/core/neighbormanagement/neighborpolicies/idleneighborpolicy.py create mode 100644 nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py create mode 100644 nebula/core/neighbormanagement/neighborpolicies/ringneighborpolicy.py create mode 100644 nebula/core/neighbormanagement/neighborpolicies/starneighborpolicy.py create mode 100644 nebula/core/neighbormanagement/nodemanager.py diff --git a/nebula/core/engine.py b/nebula/core/engine.py index c240bd7e2..fb327f1be 100755 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -13,6 +13,7 @@ from nebula.core.utils.nebulalogger_tensorboard import NebulaTensorBoardLogger from nebula.core.utils.nebulalogger import NebulaLogger from nebula.core.utils.locker import Locker +from nebula.core.neighbormanagement.nodemanager import NodeManager logging.getLogger("requests").setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING) @@ -165,15 +166,39 @@ def __init__( # Register additional callbacks self._event_manager.register_event((nebula_pb2.FederationMessage, nebula_pb2.FederationMessage.Action.REPUTATION), self._reputation_callback) + + self._event_manager.register_event((nebula_pb2.DiscoverMessage, nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN), self._discover_discover_join_callback) + self._event_manager.register_event((nebula_pb2.DiscoverMessage, 
nebula_pb2.DiscoverMessage.Action.DISCOVER_NODE), self._discover_discover_node_callback) + + self._event_manager.register_event((nebula_pb2.OfferMessage, nebula_pb2.OfferMessage.Action.OFFER_METRIC), self._offer_offer_metric_callback) + self._event_manager.register_event((nebula_pb2.OfferMessage, nebula_pb2.OfferMessage.Action.OFFER_MODEL), self._offer_offer_model_callback) + + self._event_manager.register_event((nebula_pb2.ConnectionMessage, nebula_pb2.ConnectionMessage.Action.LATE_CONNECT), self._connection_late_connect_callback) + self._event_manager.register_event((nebula_pb2.ConnectionMessage, nebula_pb2.ConnectionMessage.Action.RESTRUCTURE), self._connection_late_connect_callback) + + self._event_manager.register_event((nebula_pb2.LinkMessage, nebula_pb2.LinkMessage.Action.CONNECT_TO), self._link_connect_to_callback) + self._event_manager.register_event((nebula_pb2.LinkMessage, nebula_pb2.LinkMessage.Action.DISCONNECT_FROM), self._link_disconnect_from_callback) # ... # Thread for the trainer service, it is created when the learning starts self.trainer_service = None + + if self.config.participant["mobility_args"]["mobility"]: + topology = self.config.participant["mobility_args"]["mobility_type"] + model_handler = self.config.participant["mobility_args"]["model_handler"] + self._node_manager = NodeManager(topology, model_handler, engine=self) + if self.config.participant["mobility_args"]["late_creation"]: + self._init_late_node() + @property def cm(self): return self._cm + @property + def nm(self): + return self._node_manager + @property def reporter(self): return self._reporter @@ -248,11 +273,13 @@ async def _connection_connect_callback(self, source, message): if source not in self.cm.get_addrs_current_connections(myself=True): logging.info(f"πŸ”— handle_connection_message | Trigger | Connecting to {source}") await self.cm.connect(source, direct=True) + self.nm.update_neighbors(source) @event_handler(nebula_pb2.ConnectionMessage, 
nebula_pb2.ConnectionMessage.Action.DISCONNECT) async def _connection_disconnect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received disconnection message from {source}") await self.cm.disconnect(source, mutual_disconnection=False) + self.nm.update_neighbors(source, remove=True) @event_handler(nebula_pb2.FederationMessage, nebula_pb2.FederationMessage.Action.FEDERATION_START) async def _start_federation_callback(self, source, message): @@ -287,10 +314,111 @@ def _federation_models_included_callback(self, source, message): finally: self.cm.get_connections_lock().release() - def create_trainer_service(self): + @event_handler(nebula_pb2.ConnectionMessage, nebula_pb2.ConnectionMessage.Action.LATE_CONNECT) + async def _connection_late_connect_callback(self, source, message): + logging.info(f"πŸ”— handle_connection_message | Trigger | Received late_connect message from {source}") + if self.nm.accept_connection(source, joining=True): + self.nm.add_weight_modifier(source) + ct_actions , df_actions = self.nm.get_actions() + + # connect to + for addr in ct_actions.split(): + cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECTO_TO, addr) + await self.cm.send_message(source, cnt_msg) + + # disconnect from + for addr in df_actions.split(): + df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, addr) + await self.cm.send_message(source, df_msg) + + await self.cm.connect(source, direct=True) + self.nm.update_neighbors(source) + + @event_handler(nebula_pb2.ConnectionMessage, nebula_pb2.ConnectionMessage.Action.RESTRUCTURE) + async def _connection_disconnect_callback(self, source, message): + logging.info(f"πŸ”— handle_connection_message | Trigger | Received restructure message from {source}") + if self.nm.accept_connection(source, joining=False): + logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection accepted from {source}") + ct_actions , 
df_actions = self.nm.get_actions() + + for addr in ct_actions.split(): + cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECTO_TO, addr) + await self.cm.send_message(source, cnt_msg) + + for addr in df_actions.split(): + df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, addr) + await self.cm.send_message(source, df_msg) + else: + logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection denied from {source}") + await self.cm.disconnect(source, mutual_disconnection=False) + self.nm.update_neighbors(source, remove=True) + + @event_handler(nebula_pb2.DiscoverMessage, nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN) + async def _discover_discover_join_callback(self, source, message): + logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_join message from {source} ") + + self.nm.meet_node(source) + # if no neighbors means i'm new + if len(self.get_federation_nodes()) > 0: + model, rounds, round = self.cm.propagator.get_model_information(source, "stable") if self.get_round() > 0 else self.cm.propagator.get_model_information(source, "initialization") + epochs = self.config.participant["training_args"]["epochs"] + msg = self.cm.mm.generate_offer_message( + nebula_pb2.OfferMessage.Action.OFFER_MODEL, + len(self.get_federation_nodes()), + self.trainer.get_loss(), + model, + rounds, + round, + epochs + ) + await self.cm.send_message(source, msg) + + @event_handler(nebula_pb2.DiscoverMessage, nebula_pb2.DiscoverMessage.Action.DISCOVER_NODE) + async def _discover_discover_node_callback(self, source, message): + logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_node message from {source} ") + self.nm.meet_node(source) + msg = self.cm.mm.generate_offer_message(nebula_pb2.OfferMessage.Action.OFFER_METRIC, len(self.get_federation_nodes()), self.trainer.get_loss()) + await self.cm.send_message(source, msg) + + @event_handler(nebula_pb2.OfferMessage, 
nebula_pb2.OfferMessage.Action.OFFER_MODEL) + async def _offer_offer_model_callback(self, source, message): + logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_model message from {source}") + if not self.nm.get_restructure_process_lock().locked(): + decoded_model = self.trainer.deserialize_model(message.parameters) + self.nm.accept_model(source, decoded_model, message.rounds, message.round, message.epochs, message.n_neighbors, message.loss) + self.nm.add_candidate(source, message.n_neighbors, message.loss) + self.nm.meet_node(source) + + @event_handler(nebula_pb2.OfferMessage, nebula_pb2.OfferMessage.Action.OFFER_METRIC) + async def _offer_offer_metric_callback(self, source, message): + logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_metric message from {source}") + if not self.nm.get_restructure_process_lock().locked(): + n_neighbors, loss, _, _, _, _ = message.arguments + self.nm.add_candidate(source, n_neighbors, loss) + self.nm.meet_node(source) + + @event_handler(nebula_pb2.LinkMessage, nebula_pb2.LinkMessage.Action.CONNECTO_TO) + async def _link_connect_to_callback(self, source, message): + logging.info(f"πŸ”— handle_link_message | Trigger | Received connecto_to message from {source}") + addrs = message.arguments + for addr in addrs: + await self.cm.connect(addr, direct=True) + self.nm.update_neighbors(addr) + self.nm.meet_node(source) + + @event_handler(nebula_pb2.LinkMessage, nebula_pb2.LinkMessage.Action.DISCONNECT_FROM) + async def _link_disconnect_from_callback(self, source, message): + logging.info(f"πŸ”— handle_link_message | Trigger | Received disconnect_from message from {source}") + addrs = message.arguments + for addr in addrs: + await self.cm.disconnect(source, mutual_disconnection=False) + self.nm.update_neighbors(addr, remove=True) + + def create_trainer_service(self, round=0): if self.trainer_service is None: self.trainer_service = threading.Thread( target=self._start_learning, + args=(round,), daemon=True, 
name="trainer_service_thread-" + self.addr, ) @@ -335,14 +463,14 @@ async def deploy_federation(self): else: logging.info(f"πŸ’€ Waiting until receiving the start signal from the start node") - def _start_learning(self): + def _start_learning(self, round=0): self.learning_cycle_lock.acquire() try: if self.round is None: self.total_rounds = self.config.participant["scenario_args"]["rounds"] epochs = self.config.participant["training_args"]["epochs"] self.get_round_lock().acquire() - self.round = 0 + self.round = round self.get_round_lock().release() self.learning_cycle_lock.release() print_msg_box(msg=f"Starting Federated Learning process...", indent=2, title="Start of the experiment") @@ -507,6 +635,33 @@ async def send_reputation(self, malicious_nodes): await self.cm.send_message_to_neighbors(message) + def _init_late_node(self): + """ + Method to initialize a late connected node, creating its trainer and setting up the learning process + + First step broadcasting discover message, after that we select candidates and connect to them. + The information to create the trainer is recieved from nodes that are already on federation and answared the discover message. 
+ -model: params + -rounds: total rounds + -round: current round of the learning process + -epochs: epochs + """ + logging.info("🌐 Initializing late creation node life from Engine") + model, rounds, round, epochs = self.nm.start_late_connection_process() + + self.config.participant["scenario_args"]["rounds"] = rounds + self.config.participant["training_args"]["epochs"] = epochs + + self.round = round + + #self._trainer = trainer(model, self.dataset, config=self.config, logger=nebulalogger) + self.trainer.set_model_parameters(model, initialize=True) + + self.set_initialization_status(True) + self.get_federation_ready_lock().release() + self._create_trainer_service(round=round) + self.cm.start_external_connection_service() + class MaliciousNode(Engine): def __init__(self, model, dataset, config=Config, trainer=Lightning, security=False, model_poisoning=False, poisoned_ratio=0, noise_type="gaussian"): diff --git a/nebula/core/models/nebulamodel.py b/nebula/core/models/nebulamodel.py index 89f00866d..456e28a20 100755 --- a/nebula/core/models/nebulamodel.py +++ b/nebula/core/models/nebulamodel.py @@ -153,6 +153,8 @@ def __init__( torch.cuda.manual_seed_all(seed) self.global_number = {"Train": 0, "Validation": 0, "Test (Local)": 0, "Test (Global)": 0} + + self.current_loss = -1 # not calculated yet @abstractmethod def forward(self, x): @@ -164,6 +166,9 @@ def configure_optimizers(self): """Optimizer configuration.""" pass + def get_loss(self): + return self.current_loss + def step(self, batch, batch_idx, phase): """Training/validation/test step.""" x, y = batch @@ -171,6 +176,7 @@ def step(self, batch, batch_idx, phase): loss = self.criterion(y_pred, y) self.process_metrics(phase, y_pred, y, loss) + self.current_loss = loss return loss def training_step(self, batch, batch_idx): diff --git a/nebula/core/neighbormanagement/README.txt b/nebula/core/neighbormanagement/README.txt new file mode 100644 index 000000000..ee4b08556 --- /dev/null +++ 
b/nebula/core/neighbormanagement/README.txt @@ -0,0 +1,64 @@ + β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•— β–ˆβ–ˆβ•— β–ˆβ–ˆβ•—β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•—β–ˆβ–ˆβ•— β–ˆβ–ˆβ•— β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•— β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•— +β–ˆβ–ˆβ•”β•β•β–ˆβ–ˆβ•—β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘β•šβ•β•β–ˆβ–ˆβ•”β•β•β•β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘β–ˆβ–ˆβ•”β•β•β•β–ˆβ–ˆβ•—β–ˆβ–ˆβ•”β•β•β–ˆβ–ˆβ•— +β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•‘β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•‘β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•”β• +β–ˆβ–ˆβ•”β•β•β–ˆβ–ˆβ•‘β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•”β•β•β–ˆβ–ˆβ•‘β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘β–ˆβ–ˆβ•”β•β–ˆβ–ˆβ•— +β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘β•šβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•”β• β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘β•šβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•”β•β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•— +β•šβ•β• β•šβ•β• β•šβ•β•β•β•β•β• β•šβ•β• β•šβ•β• β•šβ•β• β•šβ•β•β•β•β•β• β•šβ•β• β•šβ•β• + +Alejandro AvilΓ©s Serrano. + + +This module implements the functionality for managing federation nodes in terms of mobility. +To achieve this, a NodeManager class will manage the processes of establishing a connection +with the federation when a node wants to join it, receiving models from the federation, +selecting the best candidates to connect with, and deciding, once inside the federation, +when it is necessary or convenient to establish connections with more nodes, +as well as applying strategies to increase the relevance of information coming from recently joined nodes in the federation. + +To accomplish this, it relies on three main blocks that together develop the previously described activity, namely: +-> Candidate Selection Module +-> Received Models Handling Module +-> Neighbor Policy Module + +The first two are self-explanatory. Surely, the third one may raise more questions. +Neighbor policies refer to, for example, restrictions on neighbor aggregation due to the specific topology of the use case. 
+Similarly, it will have information about the current neighbors of the node and the rest of the known nodes, +which will allow it to make decisions about when it might be an appropriate time to initiate a process of establishing new connections. + + +The process of discovering potential candidates is based on an exchange of discover and offer messages, +with certain characteristics differing depending on whether it is a federation joining or a network restructuring. +After this exchange of messages, the candidate selector will choose the most promising ones and a connection will be made to them. +It is important to note that the receiving node can reject the connection. + +1) Establish Communication + __________ ________________ +| New node | -------> --------> *DISCOVER* --------> --------> | Federation node | +|__________| | _______________ | + + __________ _________________ +| New node | -------> --------> *OFFER* --------> --------> | Federation node | +|__________| | _______________ | + +2) Select Candidates and connect to them + + __________ ____________________ ___________ +| New node | -------> | Candidate Selector | ----> *CONNECT* ----> | Candidate | +|__________| | __________________ | | _________ | + +Retopology works the same way but with different arguments on the messages. + +The supported topologies are: +-> RING +-> FULLY CONNECTED +-> RANDOM +-> STAR (not yet) + +If you want to make new implementations, use the interfaces provided. + + + ######################## + ### WORK IN PROGRESS ### + ######################## + +Currently working on retopology process. 
\ No newline at end of file diff --git a/nebula/core/neighbormanagement/__init__.py b/nebula/core/neighbormanagement/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/nebula/core/neighbormanagement/candidateselection/__init__.py b/nebula/core/neighbormanagement/candidateselection/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/nebula/core/neighbormanagement/candidateselection/candidateselector.py b/nebula/core/neighbormanagement/candidateselection/candidateselector.py new file mode 100644 index 000000000..6c05a57a9 --- /dev/null +++ b/nebula/core/neighbormanagement/candidateselection/candidateselector.py @@ -0,0 +1,37 @@ +from abc import ABC, abstractmethod + +class CandidateSelector(ABC): + + @abstractmethod + def set_config(self, config): + pass + + @abstractmethod + def add_candidate(self, candidate): + pass + + @abstractmethod + def select_candidates(self): + pass + + @abstractmethod + def remove_candidates(self): + pass + + @abstractmethod + def any_candidate(self): + pass + +def factory_CandidateSelector(topology): + from nebula.core.neighbormanagement.candidateselection.fccandidateselector import FCCandidateSelector + from nebula.core.neighbormanagement.candidateselection.hetcandidateselector import HETCandidateSelector + from nebula.core.neighbormanagement.candidateselection.ringcandidateselector import RINGCandidateSelector + + options = { + 'ring': RINGCandidateSelector, + "fully": FCCandidateSelector, + "random": HETCandidateSelector + } + + cs = options.get(topology) + return cs() \ No newline at end of file diff --git a/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py b/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py new file mode 100644 index 000000000..0cfeded70 --- /dev/null +++ b/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py @@ -0,0 +1,36 @@ +from nebula.core.neighbormanagement.candidateselection.candidateselector import 
import math

from nebula.core.neighbormanagement.candidateselection.candidateselector import CandidateSelector
from nebula.core.utils.locker import Locker


class HETCandidateSelector(CandidateSelector):
    """Candidate selector that ranks candidates by arrival position and loss difference.

    Each stored candidate is a tuple ``(addr, n_neighbors, heterogeneity)``,
    where heterogeneity is the absolute difference between the local loss and
    the candidate's loss.
    """

    def __init__(self):
        self.candidates = []
        self.loss, self.weight_distance, self.weight_hetereogeneity = (0, 0.2, 0.8)
        self.candidates_lock = Locker(name="candidates_lock")

    def set_config(self, config):
        """
        Args:
            config: iterable of three values, in order: own loss, weight of the
                position (distance) term, weight of the heterogeneity term.
        """
        self.loss, self.weight_distance, self.weight_hetereogeneity = config

    def add_candidate(self, candidate):
        """
        Args:
            candidate: tuple of three values:
                - candidate address
                - candidate number of neighbors
                - candidate current model loss
        """
        addr, n_neighbors, loss = candidate
        hv = self.__calculate_hetereogeneity(loss)
        self.candidates_lock.acquire()
        try:
            self.candidates.append((addr, n_neighbors, hv))
        finally:
            self.candidates_lock.release()

    def select_candidates(self):
        """
        Rank the candidates by suitability and keep as many as the average
        number of neighbors reported by the candidates themselves.

        Returns:
            Addresses of the best 'n' candidates, best first. All candidates are
            returned when the average number of neighbors is zero; an empty list
            is returned when there are no candidates.
        """
        self.candidates_lock.acquire()
        try:
            ranked = self.__suitability_function()
            # The average is a float; a slice bound must be an int. Rounding up
            # guarantees at least one pick whenever the average is positive.
            n = math.ceil(self.__calculate_ideal_neighbors())
        finally:
            self.candidates_lock.release()
        n = n if n > 0 else len(ranked)
        return [addr for addr, _ in ranked[:n]]

    def remove_candidates(self):
        """Forget every registered candidate."""
        self.candidates_lock.acquire()
        try:
            self.candidates = []
        finally:
            self.candidates_lock.release()

    def any_candidate(self):
        """Return True when at least one candidate is registered."""
        self.candidates_lock.acquire()
        try:
            return len(self.candidates) > 0
        finally:
            self.candidates_lock.release()

    def __calculate_ideal_neighbors(self):
        """
        Returns:
            Average number of neighbors among the candidates (0 when there are none).
            The caller is expected to hold ``candidates_lock``.
        """
        if not self.candidates:
            return 0
        n_neighbors = [candidate[1] for candidate in self.candidates]
        return sum(n_neighbors) / len(n_neighbors)

    def __calculate_hetereogeneity(self, loss):
        """
        Estimate dataset heterogeneity from the gap between the local loss and
        the candidate loss. A negative loss on either side yields 0.
        """
        if self.loss < 0 or loss < 0:
            return 0
        return abs(self.loss - loss)

    def __suitability_function(self):
        """
        Score every candidate as
        ``weight_distance * position_weight + weight_hetereogeneity * normalized_hv``.

        Position is used because earlier entries are assumed to be the closer /
        better-connected nodes; a smaller heterogeneity value scores higher.
        The caller is expected to hold ``candidates_lock``.

        Returns:
            List of ``(addr, suitability)`` sorted from most to least suitable;
            empty when there are no candidates.
        """
        total_positions = len(self.candidates)
        if total_positions == 0:
            # min()/max() below would raise on an empty list.
            return []

        hvs = [hv for _, _, hv in self.candidates]
        min_hv, max_hv = min(hvs), max(hvs)
        hv_range = max_hv - min_hv

        def calculate_position_weight(position):
            # A single candidate has no relative position: give it full weight
            # instead of dividing by (1 - 1).
            if total_positions == 1:
                return 1.0
            return (total_positions - position - 1) / (total_positions - 1)

        def normalize_hv(hv):
            # All candidates equally heterogeneous -> neutral score.
            return (max_hv - hv) / hv_range if hv_range else 0.5

        best_candidates = [
            (
                addr,
                self.weight_distance * calculate_position_weight(position)
                + self.weight_hetereogeneity * normalize_hv(hv),
            )
            for position, (addr, _, hv) in enumerate(self.candidates)
        ]
        best_candidates.sort(key=lambda item: item[1], reverse=True)
        return best_candidates
from nebula.core.neighbormanagement.modelhandlers.modelhandler import ModelHandler
from nebula.core.utils.locker import Locker


class AGGModelHandler(ModelHandler):
    """Model handler that keeps the first model received and collects the rest.

    The extra models are stored in ``model_list`` so that ``pre_process_model``
    can use them once a pre-processing strategy is defined.
    """

    def __init__(self):
        self.model = None
        self.rounds = 0
        self.round = 0
        self.epochs = 1
        self.model_list = []
        self.models_lock = Locker(name="model_lock")
        self.params_lock = Locker(name="param_lock")

    def set_config(self, config):
        """
        Args:
            config[0] -> total rounds
            config[1] -> current round
            config[2] -> epochs
        """
        self.params_lock.acquire()
        try:
            self.rounds = config[0]
            # Keep the most advanced round seen so far. The original stored
            # config[0] (total rounds) here, which corrupted the current round.
            if config[1] > self.round:
                self.round = config[1]
            self.epochs = config[2]
        finally:
            self.params_lock.release()

    def accept_model(self, model):
        """
        Save the first model received and collect the rest for pre-processing.
        """
        self.models_lock.acquire()
        try:
            if self.model is None:
                self.model = model
            else:
                self.model_list.append(model)
        finally:
            self.models_lock.release()

    def get_model(self, model=None):
        """
        Args:
            model: unused; kept (now optional) for compatibility with the
                ModelHandler interface.

        Returns:
            Tuple ``(model, rounds, round, epochs)`` with the data necessary to
            create the trainer, after pre-processing.
        """
        self.models_lock.acquire()
        try:
            self.pre_process_model()
            current_model = self.model
        finally:
            self.models_lock.release()

        self.params_lock.acquire()
        try:
            return (current_model, self.rounds, self.round, self.epochs)
        finally:
            self.params_lock.release()

    def pre_process_model(self):
        # define pre-processing strategy
        pass
from nebula.core.neighbormanagement.neighborpolicies.neighborpolicy import NeighborPolicy
from nebula.core.utils.locker import Locker


class FCNeighborPolicy(NeighborPolicy):
    """Neighbor policy for a fully connected topology: every known node should be a neighbor."""

    def __init__(self):
        self.max_neighbors = None
        self.nodes_known = set()
        self.neighbors = set()
        self.neighbors_lock = Locker(name="neighbors_lock")
        self.nodes_known_lock = Locker(name="nodes_known_lock")

    def need_more_neighbors(self):
        """
        Fully connected network requires to be connected to all devices, therefore,
        if there are more nodes known than neighbors, more neighbors are required.
        """
        self.neighbors_lock.acquire()
        try:
            return len(self.neighbors) < len(self.nodes_known)
        finally:
            self.neighbors_lock.release()

    def set_config(self, config):
        """
        Args:
            config[0] -> iterable of self neighbors
            config[1] -> iterable of nodes known on federation
        """
        self.neighbors_lock.acquire()
        try:
            # Always store a set: the other methods rely on set operations
            # (difference, add, discard) that a plain list does not support.
            self.neighbors = set(config[0])
        finally:
            self.neighbors_lock.release()

        self.nodes_known_lock.acquire()
        try:
            self.nodes_known.update(config[1])
        finally:
            self.nodes_known_lock.release()

    def accept_connection(self, source, joining=False):
        """
        Return True if the connection is accepted, i.e. source is not already a neighbor.
        """
        self.neighbors_lock.acquire()
        try:
            return source not in self.neighbors
        finally:
            self.neighbors_lock.release()

    def meet_node(self, node):
        """
        Update the set of nodes known on federation.
        """
        self.nodes_known_lock.acquire()
        try:
            self.nodes_known.add(node)
        finally:
            self.nodes_known_lock.release()

    def get_nodes_known(self, neighbors_too=False):
        """
        Return a copy of the known nodes; neighbors are excluded unless
        neighbors_too is True.
        """
        self.nodes_known_lock.acquire()
        try:
            if neighbors_too:
                return self.nodes_known.copy()
            self.neighbors_lock.acquire()
            try:
                return self.nodes_known - self.neighbors
            finally:
                self.neighbors_lock.release()
        finally:
            self.nodes_known_lock.release()

    def forget_nodes(self, node, forget_all=False):
        """
        Forget a single known node, or every known node when forget_all is True.
        """
        self.nodes_known_lock.acquire()
        try:
            if forget_all:
                self.nodes_known.clear()
            else:
                self.nodes_known.discard(node)
        finally:
            self.nodes_known_lock.release()

    def get_actions(self):
        """
        Return list of actions to do in response to connection
            - First item represents addrs argument to LinkMessage to connect to
            - Second one represents the same but for disconnect from LinkMessage
        """
        return [self._connect_to(), self._disconnect_from()]

    def _disconnect_from(self):
        # A fully connected node never asks the peer to drop a link.
        return ""

    def _connect_to(self):
        # Space-separated addresses of every current neighbor.
        self.neighbors_lock.acquire()
        try:
            return " ".join(self.neighbors)
        finally:
            self.neighbors_lock.release()

    def update_neighbors(self, node, remove=False):
        """
        Add node to the neighbors, or drop it when remove is True.
        Dropping a node that is not a neighbor is a no-op (mirrors forget_nodes).
        """
        self.neighbors_lock.acquire()
        try:
            if remove:
                self.neighbors.discard(node)
            else:
                self.neighbors.add(node)
        finally:
            self.neighbors_lock.release()
from abc import ABC, abstractmethod


class NeighborPolicy(ABC):
    """Interface for topology-specific rules about which neighbors a node keeps."""

    @abstractmethod
    def set_config(self, config):
        """Receive the policy-specific initial state (neighbors, known nodes, ...)."""
        pass

    @abstractmethod
    def need_more_neighbors(self):
        """Return whether the node should look for additional neighbors."""
        pass

    @abstractmethod
    def accept_connection(self, source, joining=False):
        """Return whether a connection requested by source is accepted."""
        pass

    @abstractmethod
    def get_actions(self):
        """Return the connect / disconnect actions to send in response to a connection."""
        pass

    @abstractmethod
    def meet_node(self, node):
        """Record node as known."""
        pass

    # The decorator was written without '@', which left this method concrete.
    @abstractmethod
    def forget_nodes(self, node, forget_all=False):
        """Forget node, or every known node when forget_all is True."""
        pass

    @abstractmethod
    def get_nodes_known(self, neighbors_too=False):
        """Return the known nodes, optionally including current neighbors."""
        pass

    @abstractmethod
    def update_neighbors(self, node, remove=False):
        """Add node to the neighbors, or remove it when remove is True."""
        pass


def factory_NeighborPolicy(topology):
    """Build the neighbor policy that matches *topology*.

    Args:
        topology: one of "random", "fully", "ring" or "star". "random" and
            "star" both map to the idle policy, as in the original table.

    Returns:
        A new instance of the policy class registered for that topology.

    Raises:
        ValueError: if no policy is registered for *topology*.
    """
    supported = ("random", "fully", "ring", "star")
    # Fail fast, before importing the concrete policy modules.
    if topology not in supported:
        raise ValueError(
            f"Unsupported topology for neighbor policy: {topology!r} "
            f"(expected one of {', '.join(supported)})"
        )

    # Imported here because the concrete modules import this one (avoids a cycle).
    from nebula.core.neighbormanagement.neighborpolicies.idleneighborpolicy import IDLENeighborPolicy
    from nebula.core.neighbormanagement.neighborpolicies.fcneighborpolicy import FCNeighborPolicy
    from nebula.core.neighbormanagement.neighborpolicies.ringneighborpolicy import RINGNeighborPolicy

    options = {
        "random": IDLENeighborPolicy,  # default value (was misspelled, raising NameError)
        "fully": FCNeighborPolicy,
        "ring": RINGNeighborPolicy,
        "star": IDLENeighborPolicy,
    }
    return options[topology]()
else: + ac = not len(self.neighbors) == self.max_neighbors + self.neighbors_lock.release() + return ac + + def meet_node(self, node): + self.nodes_known_lock.acquire() + self.nodes_known.add(node) + self.nodes_known_lock.release() + + def forget_nodes(self, node, forget_all=False): + self.nodes_known_lock.acquire() + if forget_all: + self.nodes_known.clear() + else: + self.nodes_known.discard(node) + self.nodes_known_lock.release() + + def get_nodes_known(self, neighbors_too=False): + self.nodes_known_lock.acquire() + nk = self.nodes_known.copy() + if not neighbors_too: + self.neighbors_lock.acquire() + nk = self.nodes_known - self.neighbors + self.neighbors_lock.release() + self.nodes_known_lock.release() + return nk + + def get_actions(self): + """ + return list of actions to do in response to connection + - First list represents addrs argument to LinkMessage to connect to + - Second one represents the same but for disconnect from LinkMessage + """ + self.neighbors_lock.acquire() + actions = [] + if len(self.neighbors) < self.max_neighbors: + list_neighbors = list(self.neighbors) + index = random.randint(0, len(list_neighbors)-1) + node = list_neighbors[index] + actions.append(node) # connect to + actions.append(self.addr) # disconnect from + self.neighbors_lock.release() + return actions + + def update_neighbors(self, node, remove=False): + self.neighbors_lock.acquire() + if remove: + self.neighbors.remove(node) + else: + self.neighbors.add(node) + self.neighbors_lock.release() \ No newline at end of file diff --git a/nebula/core/neighbormanagement/neighborpolicies/starneighborpolicy.py b/nebula/core/neighbormanagement/neighborpolicies/starneighborpolicy.py new file mode 100644 index 000000000..75a21677a --- /dev/null +++ b/nebula/core/neighbormanagement/neighborpolicies/starneighborpolicy.py @@ -0,0 +1,79 @@ +from nebula.core.neighbormanagement.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.utils.locker import Locker + +class 
STARNeighborPolicy(NeighborPolicy): + + def __init__(self): + self.max_neighbors = 1 + self.nodes_known = set() + self.neighbors = set() + self.neighbors_lock = Locker(name="neighbors_lock") + self.nodes_known_lock = Locker(name="nodes_known_lock") + self.addr = "" + + def need_more_neighbors(self): + self.neighbors_lock.acquire() + need_more = len(self.neighbors) < self.max_neighbors + self.neighbors_lock.release() + return need_more + + def set_config(self, config): + """ + Args: + config[0] -> list of self neighbors, in this case, the star point + config[1] -> list of nodes known on federation + config[2] -> self.addr + """ + self.neighbors_lock.acquire() + self.neighbors = config[0] + self.neighbors_lock.release() + for addr in config[1]: + self.nodes_known.add(addr) + self.addr = config[2] + + def accept_connection(self, source, joining=False): + """ + return true if connection is accepted + """ + ac = joining + return ac + + def meet_node(self, node): + self.nodes_known_lock.acquire() + self.nodes_known.add(node) + self.nodes_known_lock.release() + + def forget_nodes(self, node, forget_all=False): + self.nodes_known_lock.acquire() + if forget_all: + self.nodes_known.clear() + else: + self.nodes_known.discard(node) + self.nodes_known_lock.release() + + def get_nodes_known(self, neighbors_too=False): + self.nodes_known_lock.acquire() + nk = self.nodes_known.copy() + if not neighbors_too: + self.neighbors_lock.acquire() + nk = self.nodes_known - self.neighbors + self.neighbors_lock.release() + self.nodes_known_lock.release() + return nk + + def get_actions(self): + """ + return list of actions to do in response to connection + - First list represents addrs argument to LinkMessage to connect to + - Second one represents the same but for disconnect from LinkMessage + """ + self.neighbors_lock.acquire() + actions = [] + if len(self.neighbors) < self.max_neighbors: + actions.append(self.neighbors[0]) # connect to star point + actions.append(self.addr) # disconnect 
import asyncio
import logging
import os
import threading
from typing import TYPE_CHECKING

from nebula.core.neighbormanagement.candidateselection.candidateselector import factory_CandidateSelector
from nebula.core.neighbormanagement.modelhandlers.modelhandler import factory_ModelHandler
from nebula.core.neighbormanagement.neighborpolicies.neighborpolicy import factory_NeighborPolicy
from nebula.core.network.communications import CommunicationsManager
from nebula.core.pb import nebula_pb2
from nebula.core.utils.locker import Locker

if TYPE_CHECKING:
    from nebula.core.engine import Engine


class NodeManager():
    """Coordinates neighbor policy, candidate selection and model hand-over for a node.

    It drives the late-connection process (discover -> offers -> select ->
    connect) and keeps a temporary per-address weight modifier.
    """

    def __init__(
        self,
        topology,
        model_handler,
        engine: "Engine"
    ):
        logging.info("🌐 Initializing Node Manager")
        self._engine = engine
        self.config = engine.get_config()
        self._neighbor_policy = factory_NeighborPolicy(topology)
        self._candidate_selector = factory_CandidateSelector(topology)
        self._model_handler = factory_ModelHandler(model_handler)
        self.late_connection_process_lock = Locker(name="late_connection_process_lock")
        self.weight_modifier = {}
        self.weight_modifier_lock = Locker(name="weight_modifier_lock")
        self.new_node_weight_value = 3
        self.accept_candidates_lock = Locker(name="accept_candidates_lock")
        self.recieve_offer_timer = 5
        self._restructure_process_lock = Locker(name="restructure_process_lock")
        self.restructure = False
        self.set_confings()

    @property
    def engine(self):
        return self._engine

    @property
    def neighbor_policy(self):
        return self._neighbor_policy

    @property
    def candidate_selector(self):
        return self._candidate_selector

    @property
    def model_handler(self):
        return self._model_handler

    def get_restructure_process_lock(self):
        return self._restructure_process_lock

    def set_confings(self):
        """
        neighbor_policy config:
            - direct connections a.k.a neighbors
            - non-direct connections
            - self addr

        candidate_selector config:
            - self model loss
            - self weight distance
            - self weight hetereogeneity

        NOTE: the model_handler is not configured here; it receives its
        configuration (rounds, round, epochs) from accept_model().
        """
        self.neighbor_policy.set_config([
            self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False),
            self.engine.cm.get_addrs_current_connections(only_direct=False, myself=False),
            self.engine.addr,
        ])
        # The section key was misspelled "molibity_args"; the rest of the code
        # base reads the "mobility_args" section.
        # NOTE(review): confirm "weight_distance" / "weight_het" live in that section.
        mobility_args = self.config.participant["mobility_args"]
        self.candidate_selector.set_config([
            self.engine.get_loss(),
            mobility_args["weight_distance"],
            mobility_args["weight_het"],
        ])

    def add_weight_modifier(self, addr):
        """Give addr the initial weight modifier unless it already has one."""
        self.weight_modifier_lock.acquire()
        try:
            if addr not in self.weight_modifier:
                self.weight_modifier[addr] = self.new_node_weight_value
        finally:
            self.weight_modifier_lock.release()

    def remove_weight_modifier(self, addr):
        """Drop the weight modifier of addr, if any."""
        self.weight_modifier_lock.acquire()
        try:
            self.weight_modifier.pop(addr, None)
        finally:
            self.weight_modifier_lock.release()

    def _decay_weight_modifier_locked(self, addr):
        """Decrease the modifier of addr by 1/round**2; drop it once it is <= 1.

        Must be called with weight_modifier_lock held. It exists so callers that
        already hold the lock do not try to acquire it a second time.
        """
        if addr not in self.weight_modifier:
            return
        # Guard against round 0 (or no round yet) to avoid dividing by zero.
        current_round = self.engine.get_round() or 1
        new_weight = self.weight_modifier[addr] - 1 / current_round**2
        if new_weight > 1:
            self.weight_modifier[addr] = new_weight
        else:
            del self.weight_modifier[addr]

    def _update_weight_modifier(self, addr):
        self.weight_modifier_lock.acquire()
        try:
            self._decay_weight_modifier_locked(addr)
        finally:
            self.weight_modifier_lock.release()

    def get_weight_modifier(self, addr):
        """Return the current modifier of addr (1 when it has none) and decay it."""
        self.weight_modifier_lock.acquire()
        try:
            if addr not in self.weight_modifier:
                return 1
            wm = self.weight_modifier[addr]
            self._decay_weight_modifier_locked(addr)
            return wm
        finally:
            self.weight_modifier_lock.release()

    def accept_connection(self, source):
        # While candidates are being selected no new connection is accepted.
        if self.accept_candidates_lock.locked():
            return False
        return self.neighbor_policy.accept_connection(source)

    def need_more_neighbors(self):
        return self.neighbor_policy.need_more_neighbors()

    def get_actions(self):
        return self.neighbor_policy.get_actions()

    def update_neighbors(self, node, remove=False):
        self.neighbor_policy.update_neighbors(node, remove)
        if not remove:
            self.neighbor_policy.meet_node(node)

    def no_neighbors_left(self):
        """Return True when the node has no direct connections left."""
        return len(self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False)) == 0

    def meet_node(self, node):
        self.neighbor_policy.meet_node(node)

    def get_nodes_known(self, neighbors_too=False):
        return self.neighbor_policy.get_nodes_known(neighbors_too)

    def accept_model(self, source, decoded_model, rounds, round, epochs, n_neighbors, loss):
        """Store an offered model and register its sender as a candidate.

        Ignored while accept_candidates_lock is held (offer arrived too late).
        """
        if not self.accept_candidates_lock.locked():
            self.model_handler.accept_model(decoded_model)
            self.model_handler.set_config(config=(rounds, round, epochs))
            self.candidate_selector.add_candidate((source, n_neighbors, loss))

    def add_candidate(self, source, n_neighbors, loss):
        """Register source as a candidate unless the selection is already closed."""
        if not self.accept_candidates_lock.locked():
            self.candidate_selector.add_candidate((source, n_neighbors, loss))

    async def start_late_connection_process(self):
        """
        This function represents the process of discovering the federation and establishing the first
        connections with it. The first step is to send the DISCOVER_JOIN message to look for nodes,
        the ones that receive that message will send back an OFFER_MODEL message. It contains info to do
        a selection process among candidates to later on connect to the best ones.
        The process repeats until at least one candidate is found, and each attempt is locked
        to avoid concurrency.

        Returns:
            data necessary to create trainer: (model, rounds, round, epochs)
        """
        # Loop instead of recursing: the original returned the un-awaited
        # coroutine of the retry, so callers never got the actual result.
        while True:
            result = await self._late_connection_attempt()
            if result is not None:
                return result

    async def _late_connection_attempt(self):
        """Run one discover/offer/select/connect round; return None if nobody answered."""
        logging.info("🌐 Initializing start late connection process from Node Manager")

        self.late_connection_process_lock.acquire()
        try:
            self.candidate_selector.remove_candidates()

            # send discover
            msg = self.engine.cm.mm.generate_discover_message(nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN)
            await self.engine.cm.establish_connection_with_federation(msg)

            # wait offer
            await asyncio.sleep(self.recieve_offer_timer)

            # acquire lock to not accept late candidates
            self.accept_candidates_lock.acquire()
            try:
                if not self.candidate_selector.any_candidate():
                    return None

                # create message to send to new neighbors
                msg = self.engine.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.LATE_CONNECT)

                for candidate in self.candidate_selector.select_candidates():
                    # Selectors return either bare addresses or candidate tuples
                    # whose first item is the address.
                    addr = candidate[0] if isinstance(candidate, (tuple, list)) else candidate
                    await self.engine.cm.connect(addr, direct=True)
                    await self.engine.cm.send_message(addr, msg)

                # The handler interface takes a (currently unused) model argument.
                model, rounds, round, epochs = self.model_handler.get_model(None)
                return (model, rounds, round, epochs)
            finally:
                self.accept_candidates_lock.release()
        finally:
            self.late_connection_process_lock.release()

    """
                Retopology in progress
    """

    async def find_new_connections(self):
        """Look for new connections until the neighbor policy is satisfied."""
        logging.info("🌐 Initializing restructure process from Node Manager")
        self._restructure_process_lock.acquire()
        try:
            # Update the config params of candidate_selector
            # NOTE(review): assumes the engine exposes weight_distance / weight_het — confirm.
            self.candidate_selector.set_config([self.engine.get_loss(), self.engine.weight_distance, self.engine.weight_het])
            # The worker is a coroutine, so it is awaited on the running loop.
            # Handing it to threading.Thread never executed its body and left
            # this method polling forever with the lock held.
            self.restructure = True
            await self._find_connections_thread()
        finally:
            self.restructure = False
            self._restructure_process_lock.release()

    async def _find_connections_thread(self):
        posible_connections = self.get_nodes_known(neighbors_too=False)
        while self.restructure:
            # out of federation but got info about nodes inside
            if len(posible_connections) > 0:
                # TODO: send a DISCOVER_NODE discover message to the known nodes,
                # wait for the responses and select among them.
                # Yield so the event loop keeps running while this is unimplemented.
                await asyncio.sleep(1)
            # im out of federation without info about any nodes inside of it
            else:
                await self.start_late_connection_process()

            self.restructure = self.need_more_neighbors()
message_wrapper.HasField("offer_message"): + if self.include_received_message_hash(hashlib.md5(data).hexdigest()): + await self.handle_offer_message(source, message_wrapper.offer_message) + elif message_wrapper.HasField("link_message"): + if self.include_received_message_hash(hashlib.md5(data).hexdigest()): + await self.handle_offer_message(source, message_wrapper.link_message) else: logging.info(f"Unknown handler for message: {message_wrapper}") except Exception as e: @@ -240,7 +249,28 @@ async def handle_connection_message(self, source, message): except Exception as e: logging.error(f"πŸ”— handle_connection_message | Error while processing: {message.action} | {e}") - def _start_external_connection_service(self): + async def handle_discover_message(self, source, message): + logging.info(f"πŸ” handle_discover_message | Received [Action {message.action}] from {source}") + try: + await self.engine.event_manager.trigger_event(source, message) + except Exception as e: + logging.error(f"πŸ” handle_discover_message | Error while processing: {e}") + + async def handle_offer_message(self, source, message): + logging.info(f"πŸ” handle_offer_message | Received [Action {message.action}] from {source} with arguments {message.arguments}") + try: + await self.engine.event_manager.trigger_event(source, message) + except Exception as e: + logging.error(f"πŸ” handle_offer_message | Error while processing: {message.action} {message.arguments} | {e}") + + async def handle_link_message(self, source, message): + logging.info(f"πŸ” handle_link_message | Received [Action {message.action}] from {source} with arguments {message.arguments}") + try: + await self.engine.event_manager.trigger_event(source, message) + except Exception as e: + logging.error(f"πŸ” handle_link_message | Error while processing: {message.action} {message.arguments} | {e}") + + def start_external_connection_service(self): self.ecs = NebulaConnectionService(self.addr) self.ecs.start() diff --git 
a/nebula/core/network/messages.py b/nebula/core/network/messages.py index 8fd1c2dc9..a95ad1c61 100755 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -81,3 +81,42 @@ def generate_reputation_message(self, reputation): message_wrapper.reputation_message.CopyFrom(message) data = message_wrapper.SerializeToString() return data + + def generate_discover_message(self, action): + message = nebula_pb2.DiscoverMessage( + action=action, + ) + message_wrapper = nebula_pb2.Wrapper() + message_wrapper.source = self.addr + message_wrapper.discovery_message.CopyFrom(message) + data = message_wrapper.SerializeToString() + return data + + + def generate_offer_message(self, action, n_neighbors, loss, serialized_model=None, rounds=1, round=-1, epochs = 1): + message = nebula_pb2.OfferMessage( + action=action, + n_neighbors = n_neighbors, + loss = loss, + parameters = serialized_model, + rounds = rounds, + round = round, + epochs = epochs + ) + message_wrapper = nebula_pb2.Wrapper() + message_wrapper.source = self.addr + message_wrapper.discovery_message.CopyFrom(message) + data = message_wrapper.SerializeToString() + return data + + + def generate_link_message(self, action, addrs): + message = nebula_pb2.LinkMessage( + action=action, + addrs = addrs, + ) + message_wrapper = nebula_pb2.Wrapper() + message_wrapper.source = self.addr + message_wrapper.link_message.CopyFrom(message) + data = message_wrapper.SerializeToString() + return data diff --git a/nebula/core/network/propagator.py b/nebula/core/network/propagator.py index a147fd751..5b9361531 100755 --- a/nebula/core/network/propagator.py +++ b/nebula/core/network/propagator.py @@ -156,3 +156,25 @@ async def propagate_continuously(self, strategy_id: str): if not propagated: logging.info("Exiting continuous propagation...") return + + async def get_model_information(self, dest_addr, strategy_id: str): + if strategy_id not in self.strategies: + logging.info(f"Strategy {strategy_id} not found.") + return 
None + if self.get_round() is None: + logging.info("Propagation halted: round is not set.") + return None + + strategy = self.strategies[strategy_id] + logging.info(f"Preparing model information with strategy to make an offer: {strategy_id}") + + serialized_model, weight = strategy.prepare_model_payload(dest_addr) + + if serialized_model: + serialized_model = serialized_model if isinstance(serialized_model, bytes) else self.trainer.serialize_model(serialized_model) + if strategy_id == "initialization": + return (serialized_model, weight, -1) + else: + return (serialized_model, weight, self.get_round()) + + return None diff --git a/nebula/core/pb/nebula.proto b/nebula/core/pb/nebula.proto index 7108b3dde..a2a6427b7 100755 --- a/nebula/core/pb/nebula.proto +++ b/nebula/core/pb/nebula.proto @@ -22,6 +22,9 @@ message Wrapper { ModelMessage model_message = 5; ConnectionMessage connection_message = 6; ResponseMessage response_message = 7; + DiscoverMessage discover_message = 8; + OfferMessage offer_message = 9; + LinkMessage link_message = 10; } } @@ -70,10 +73,43 @@ message ConnectionMessage { enum Action { CONNECT = 0; DISCONNECT = 1; + LATE_CONNECT = 2; // Message send when late connection to federation + RESTRUCTURE = 3; // Message to notify connection is because restructuration of topology } Action action = 1; } +message DiscoverMessage { + enum Action { + DISCOVER_JOIN = 0; // Message to discover nodes on federation when i'm new + DISCOVER_NODES = 1; // Message to discover nodes on federation when i'm already in + } + Action action = 1; +} + +message OfferMessage{ + enum Action { + OFFER_MODEL = 0; // Message to offer model info to a new node + OFFER_METRIC = 1; // Message to offer metrics info to a node on federation + } + Action action = 1; + float n_neighbors = 2; + float loss = 3; + bytes parameters = 4; + int32 rounds = 5; + int32 round = 6; + int32 epochs = 7; +} + +message LinkMessage { + enum Action { + CONNECT_TO = 0; // Message to tell a node who to 
connect to + DISCONNECT_FROM = 1; // Message to tell a node who to disconnect + } + Action action = 1; + string addrs = 2; +} + // Response transmits the outcome of a requested operation, including any errors. message ResponseMessage { string response = 1; // Outcome of the requested operation. diff --git a/nebula/core/pb/nebula_pb2.py b/nebula/core/pb/nebula_pb2.py index c13d10ea5..2a7463f86 100755 --- a/nebula/core/pb/nebula_pb2.py +++ b/nebula/core/pb/nebula_pb2.py @@ -1,11 +1,22 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE # source: nebula.proto +# Protobuf Python Version: 5.27.2 """Generated protocol buffer code.""" -from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 27, + 2, + '', + 'nebula.proto' +) # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -13,33 +24,45 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cnebula.proto\x12\x06nebula\"\xe4\x02\n\x07Wrapper\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x35\n\x11\x64iscovery_message\x18\x02 \x01(\x0b\x32\x18.nebula.DiscoveryMessageH\x00\x12\x31\n\x0f\x63ontrol_message\x18\x03 \x01(\x0b\x32\x16.nebula.ControlMessageH\x00\x12\x37\n\x12\x66\x65\x64\x65ration_message\x18\x04 \x01(\x0b\x32\x19.nebula.FederationMessageH\x00\x12-\n\rmodel_message\x18\x05 \x01(\x0b\x32\x14.nebula.ModelMessageH\x00\x12\x37\n\x12\x63onnection_message\x18\x06 \x01(\x0b\x32\x19.nebula.ConnectionMessageH\x00\x12\x33\n\x10response_message\x18\x07 
\x01(\x0b\x32\x17.nebula.ResponseMessageH\x00\x42\t\n\x07message\"\x9e\x01\n\x10\x44iscoveryMessage\x12/\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1f.nebula.DiscoveryMessage.Action\x12\x10\n\x08latitude\x18\x02 \x01(\x02\x12\x11\n\tlongitude\x18\x03 \x01(\x02\"4\n\x06\x41\x63tion\x12\x0c\n\x08\x44ISCOVER\x10\x00\x12\x0c\n\x08REGISTER\x10\x01\x12\x0e\n\nDEREGISTER\x10\x02\"\x9a\x01\n\x0e\x43ontrolMessage\x12-\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1d.nebula.ControlMessage.Action\x12\x0b\n\x03log\x18\x02 \x01(\t\"L\n\x06\x41\x63tion\x12\t\n\x05\x41LIVE\x10\x00\x12\x0c\n\x08OVERHEAD\x10\x01\x12\x0c\n\x08MOBILITY\x10\x02\x12\x0c\n\x08RECOVERY\x10\x03\x12\r\n\tWEAK_LINK\x10\x04\"\xb7\x01\n\x11\x46\x65\x64\x65rationMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.FederationMessage.Action\x12\x11\n\targuments\x18\x02 \x03(\t\x12\r\n\x05round\x18\x03 \x01(\x05\"N\n\x06\x41\x63tion\x12\x14\n\x10\x46\x45\x44\x45RATION_START\x10\x00\x12\x0e\n\nREPUTATION\x10\x01\x12\x1e\n\x1a\x46\x45\x44\x45RATION_MODELS_INCLUDED\x10\x02\"A\n\x0cModelMessage\x12\x12\n\nparameters\x18\x01 \x01(\x0c\x12\x0e\n\x06weight\x18\x02 \x01(\x03\x12\r\n\x05round\x18\x03 \x01(\x05\"l\n\x11\x43onnectionMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.ConnectionMessage.Action\"%\n\x06\x41\x63tion\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\x0e\n\nDISCONNECT\x10\x01\"#\n\x0fResponseMessage\x12\x10\n\x08response\x18\x01 \x01(\tb\x06proto3') - -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'nebula_pb2', globals()) -if _descriptor._USE_C_DESCRIPTORS == False: +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cnebula.proto\x12\x06nebula\"\xf5\x03\n\x07Wrapper\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x35\n\x11\x64iscovery_message\x18\x02 \x01(\x0b\x32\x18.nebula.DiscoveryMessageH\x00\x12\x31\n\x0f\x63ontrol_message\x18\x03 
\x01(\x0b\x32\x16.nebula.ControlMessageH\x00\x12\x37\n\x12\x66\x65\x64\x65ration_message\x18\x04 \x01(\x0b\x32\x19.nebula.FederationMessageH\x00\x12-\n\rmodel_message\x18\x05 \x01(\x0b\x32\x14.nebula.ModelMessageH\x00\x12\x37\n\x12\x63onnection_message\x18\x06 \x01(\x0b\x32\x19.nebula.ConnectionMessageH\x00\x12\x33\n\x10response_message\x18\x07 \x01(\x0b\x32\x17.nebula.ResponseMessageH\x00\x12\x33\n\x10\x64iscover_message\x18\x08 \x01(\x0b\x32\x17.nebula.DiscoverMessageH\x00\x12-\n\roffer_message\x18\t \x01(\x0b\x32\x14.nebula.OfferMessageH\x00\x12+\n\x0clink_message\x18\n \x01(\x0b\x32\x13.nebula.LinkMessageH\x00\x42\t\n\x07message\"\x9e\x01\n\x10\x44iscoveryMessage\x12/\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1f.nebula.DiscoveryMessage.Action\x12\x10\n\x08latitude\x18\x02 \x01(\x02\x12\x11\n\tlongitude\x18\x03 \x01(\x02\"4\n\x06\x41\x63tion\x12\x0c\n\x08\x44ISCOVER\x10\x00\x12\x0c\n\x08REGISTER\x10\x01\x12\x0e\n\nDEREGISTER\x10\x02\"\x9a\x01\n\x0e\x43ontrolMessage\x12-\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1d.nebula.ControlMessage.Action\x12\x0b\n\x03log\x18\x02 \x01(\t\"L\n\x06\x41\x63tion\x12\t\n\x05\x41LIVE\x10\x00\x12\x0c\n\x08OVERHEAD\x10\x01\x12\x0c\n\x08MOBILITY\x10\x02\x12\x0c\n\x08RECOVERY\x10\x03\x12\r\n\tWEAK_LINK\x10\x04\"\xb7\x01\n\x11\x46\x65\x64\x65rationMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.FederationMessage.Action\x12\x11\n\targuments\x18\x02 \x03(\t\x12\r\n\x05round\x18\x03 \x01(\x05\"N\n\x06\x41\x63tion\x12\x14\n\x10\x46\x45\x44\x45RATION_START\x10\x00\x12\x0e\n\nREPUTATION\x10\x01\x12\x1e\n\x1a\x46\x45\x44\x45RATION_MODELS_INCLUDED\x10\x02\"A\n\x0cModelMessage\x12\x12\n\nparameters\x18\x01 \x01(\x0c\x12\x0e\n\x06weight\x18\x02 \x01(\x03\x12\r\n\x05round\x18\x03 \x01(\x05\"\x8f\x01\n\x11\x43onnectionMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 
.nebula.ConnectionMessage.Action\"H\n\x06\x41\x63tion\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\x0e\n\nDISCONNECT\x10\x01\x12\x10\n\x0cLATE_CONNECT\x10\x02\x12\x0f\n\x0bRESTRUCTURE\x10\x03\"r\n\x0f\x44iscoverMessage\x12.\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1e.nebula.DiscoverMessage.Action\"/\n\x06\x41\x63tion\x12\x11\n\rDISCOVER_JOIN\x10\x00\x12\x12\n\x0e\x44ISCOVER_NODES\x10\x01\"\xce\x01\n\x0cOfferMessage\x12+\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1b.nebula.OfferMessage.Action\x12\x13\n\x0bn_neighbors\x18\x02 \x01(\x02\x12\x0c\n\x04loss\x18\x03 \x01(\x02\x12\x12\n\nparameters\x18\x04 \x01(\x0c\x12\x0e\n\x06rounds\x18\x05 \x01(\x05\x12\r\n\x05round\x18\x06 \x01(\x05\x12\x0e\n\x06\x65pochs\x18\x07 \x01(\x05\"+\n\x06\x41\x63tion\x12\x0f\n\x0bOFFER_MODEL\x10\x00\x12\x10\n\x0cOFFER_METRIC\x10\x01\"w\n\x0bLinkMessage\x12*\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1a.nebula.LinkMessage.Action\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x01(\t\"-\n\x06\x41\x63tion\x12\x0e\n\nCONNECT_TO\x10\x00\x12\x13\n\x0f\x44ISCONNECT_FROM\x10\x01\"#\n\x0fResponseMessage\x12\x10\n\x08response\x18\x01 \x01(\tb\x06proto3') - DESCRIPTOR._options = None - _WRAPPER._serialized_start=25 - _WRAPPER._serialized_end=381 - _DISCOVERYMESSAGE._serialized_start=384 - _DISCOVERYMESSAGE._serialized_end=542 - _DISCOVERYMESSAGE_ACTION._serialized_start=490 - _DISCOVERYMESSAGE_ACTION._serialized_end=542 - _CONTROLMESSAGE._serialized_start=545 - _CONTROLMESSAGE._serialized_end=699 - _CONTROLMESSAGE_ACTION._serialized_start=623 - _CONTROLMESSAGE_ACTION._serialized_end=699 - _FEDERATIONMESSAGE._serialized_start=702 - _FEDERATIONMESSAGE._serialized_end=885 - _FEDERATIONMESSAGE_ACTION._serialized_start=807 - _FEDERATIONMESSAGE_ACTION._serialized_end=885 - _MODELMESSAGE._serialized_start=887 - _MODELMESSAGE._serialized_end=952 - _CONNECTIONMESSAGE._serialized_start=954 - _CONNECTIONMESSAGE._serialized_end=1062 - _CONNECTIONMESSAGE_ACTION._serialized_start=1025 - _CONNECTIONMESSAGE_ACTION._serialized_end=1062 - 
_RESPONSEMESSAGE._serialized_start=1064 - _RESPONSEMESSAGE._serialized_end=1099 +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'nebula_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_WRAPPER']._serialized_start=25 + _globals['_WRAPPER']._serialized_end=526 + _globals['_DISCOVERYMESSAGE']._serialized_start=529 + _globals['_DISCOVERYMESSAGE']._serialized_end=687 + _globals['_DISCOVERYMESSAGE_ACTION']._serialized_start=635 + _globals['_DISCOVERYMESSAGE_ACTION']._serialized_end=687 + _globals['_CONTROLMESSAGE']._serialized_start=690 + _globals['_CONTROLMESSAGE']._serialized_end=844 + _globals['_CONTROLMESSAGE_ACTION']._serialized_start=768 + _globals['_CONTROLMESSAGE_ACTION']._serialized_end=844 + _globals['_FEDERATIONMESSAGE']._serialized_start=847 + _globals['_FEDERATIONMESSAGE']._serialized_end=1030 + _globals['_FEDERATIONMESSAGE_ACTION']._serialized_start=952 + _globals['_FEDERATIONMESSAGE_ACTION']._serialized_end=1030 + _globals['_MODELMESSAGE']._serialized_start=1032 + _globals['_MODELMESSAGE']._serialized_end=1097 + _globals['_CONNECTIONMESSAGE']._serialized_start=1100 + _globals['_CONNECTIONMESSAGE']._serialized_end=1243 + _globals['_CONNECTIONMESSAGE_ACTION']._serialized_start=1171 + _globals['_CONNECTIONMESSAGE_ACTION']._serialized_end=1243 + _globals['_DISCOVERMESSAGE']._serialized_start=1245 + _globals['_DISCOVERMESSAGE']._serialized_end=1359 + _globals['_DISCOVERMESSAGE_ACTION']._serialized_start=1312 + _globals['_DISCOVERMESSAGE_ACTION']._serialized_end=1359 + _globals['_OFFERMESSAGE']._serialized_start=1362 + _globals['_OFFERMESSAGE']._serialized_end=1568 + _globals['_OFFERMESSAGE_ACTION']._serialized_start=1525 + _globals['_OFFERMESSAGE_ACTION']._serialized_end=1568 + _globals['_LINKMESSAGE']._serialized_start=1570 + _globals['_LINKMESSAGE']._serialized_end=1689 + 
_globals['_LINKMESSAGE_ACTION']._serialized_start=1644 + _globals['_LINKMESSAGE_ACTION']._serialized_end=1689 + _globals['_RESPONSEMESSAGE']._serialized_start=1691 + _globals['_RESPONSEMESSAGE']._serialized_end=1726 # @@protoc_insertion_point(module_scope) diff --git a/nebula/frontend/config/participant.json.example b/nebula/frontend/config/participant.json.example index 03cbbfb5d..e094af826 100755 --- a/nebula/frontend/config/participant.json.example +++ b/nebula/frontend/config/participant.json.example @@ -68,7 +68,12 @@ "status": false, "round_start": 0, "scheme": "random" - } + }, + "model_handler": "std", + "late_creation": false, + "sleeping_time": 0, + "weight_distance": 0.2, + "weight_het": 0.8 }, "data_args": { "dataset": "MNIST", @@ -82,7 +87,8 @@ }, "training_args": { "trainer": "lightning", - "epochs": 3 + "epochs": 3, + "flexible_participation": false }, "aggregator_args": { "algorithm": "FedAvg", From f9a3964bc248b64909ee6a3371fd46b24d87aa72 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 10 Jul 2024 18:11:56 +0200 Subject: [PATCH 005/233] fix_function_arguments --- nebula/core/neighbormanagement/nodemanager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index fd2eaeb21..5119e8327 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -136,13 +136,13 @@ def meet_node(self, node): def get_nodes_known(self, neighbors_too=False): return self.neighbor_policy.get_nodes_known(neighbors_too) - def accept_model(source, decoded_model, rounds, round, epochs, n_neighbors, loss): + def accept_model(self, source, decoded_model, rounds, round, epochs, n_neighbors, loss): if not self.accept_candidates_lock().locked(): self.model_handler.accept_model(decoded_model) self.model_handler.setConfig(config=(rounds, round, epochs)) self.candidate_selector.add_candidate((source, 
n_neighbors, loss)) - def add_candidate(source, n_neighbors, loss): + def add_candidate(self,source, n_neighbors, loss): if not self.accept_candidates_lock().locked(): self.candidate_selector.add_candidate((source, n_neighbors, loss)) From edd4ef54340ce96a4d2ff3424effa7c904a50317 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sat, 13 Jul 2024 12:07:11 +0200 Subject: [PATCH 006/233] fix_compiler_errors type notation on factory methods --- .../candidateselection/candidateselector.py | 3 ++- nebula/core/neighbormanagement/modelhandlers/modelhandler.py | 3 ++- .../neighborpolicies/idleneighborpolicy.py | 2 +- .../neighbormanagement/neighborpolicies/neighborpolicy.py | 5 +++-- nebula/core/neighbormanagement/nodemanager.py | 2 +- 5 files changed, 9 insertions(+), 6 deletions(-) diff --git a/nebula/core/neighbormanagement/candidateselection/candidateselector.py b/nebula/core/neighbormanagement/candidateselection/candidateselector.py index 6c05a57a9..8beb90099 100644 --- a/nebula/core/neighbormanagement/candidateselection/candidateselector.py +++ b/nebula/core/neighbormanagement/candidateselection/candidateselector.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Type class CandidateSelector(ABC): @@ -22,7 +23,7 @@ def remove_candidates(self): def any_candidate(self): pass -def factory_CandidateSelector(topology): +def factory_CandidateSelector(topology) -> CandidateSelector: from nebula.core.neighbormanagement.candidateselection.fccandidateselector import FCCandidateSelector from nebula.core.neighbormanagement.candidateselection.hetcandidateselector import HETCandidateSelector from nebula.core.neighbormanagement.candidateselection.ringcandidateselector import RINGCandidateSelector diff --git a/nebula/core/neighbormanagement/modelhandlers/modelhandler.py b/nebula/core/neighbormanagement/modelhandlers/modelhandler.py index 31e061a3c..9a06de8fb 100644 --- a/nebula/core/neighbormanagement/modelhandlers/modelhandler.py +++ 
b/nebula/core/neighbormanagement/modelhandlers/modelhandler.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Type class ModelHandler(ABC): @@ -18,7 +19,7 @@ def get_model(self, model): def pre_process_model(self): pass -def factory_ModelHandler(model_handler): +def factory_ModelHandler(model_handler) -> ModelHandler: from nebula.core.neighbormanagement.modelhandlers.stdmodelhandler import STDModelHandler from nebula.core.neighbormanagement.modelhandlers.aggmodelhandler import AGGModelHandler diff --git a/nebula/core/neighbormanagement/neighborpolicies/idleneighborpolicy.py b/nebula/core/neighbormanagement/neighborpolicies/idleneighborpolicy.py index f28ba85eb..475b368f6 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/idleneighborpolicy.py +++ b/nebula/core/neighbormanagement/neighborpolicies/idleneighborpolicy.py @@ -24,7 +24,7 @@ def forget_nodes(self, node, forget_all=False): pass def get_nodes_known(self, neighbors_too=False): - return Set() + return set() def update_neighbors(self, node, remove=False): pass \ No newline at end of file diff --git a/nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py b/nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py index 297044976..7fb4e9071 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py +++ b/nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Type class NeighborPolicy(ABC): @@ -34,14 +35,14 @@ def get_nodes_known(self, neighbors_too=False): def update_neighbors(self, node, remove=False): pass -def factory_NeighborPolicy(topology): +def factory_NeighborPolicy(topology) -> NeighborPolicy: from nebula.core.neighbormanagement.neighborpolicies.idleneighborpolicy import IDLENeighborPolicy from nebula.core.neighbormanagement.neighborpolicies.fcneighborpolicy import FCNeighborPolicy from 
nebula.core.neighbormanagement.neighborpolicies.ringneighborpolicy import RINGNeighborPolicy from nebula.core.neighbormanagement.neighborpolicies.starneighborpolicy import STARNeighborPolicy options = { - 'random': IDLEneighborPolicy, # default value + 'random': IDLENeighborPolicy, # default value 'fully': FCNeighborPolicy, 'ring': RINGNeighborPolicy, 'star': IDLENeighborPolicy diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 5119e8327..eee4d59ae 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -139,7 +139,7 @@ def get_nodes_known(self, neighbors_too=False): def accept_model(self, source, decoded_model, rounds, round, epochs, n_neighbors, loss): if not self.accept_candidates_lock().locked(): self.model_handler.accept_model(decoded_model) - self.model_handler.setConfig(config=(rounds, round, epochs)) + self.model_handler.set_config(config=(rounds, round, epochs)) self.candidate_selector.add_candidate((source, n_neighbors, loss)) def add_candidate(self,source, n_neighbors, loss): From 092af15ac2281d0140cebceea28d4f4c417c1acb Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 16 Jul 2024 13:58:17 +0200 Subject: [PATCH 007/233] feat_timer_generator Additions: -TimerGenerator class to generate timer used for waiting updates from nodes -Weight_modifier for updates received if NodeManager is up --- nebula/core/engine.py | 13 +- nebula/core/neighbormanagement/nodemanager.py | 20 +++ .../core/neighbormanagement/timergenerator.py | 142 ++++++++++++++++++ nebula/core/network/communications.py | 3 +- 4 files changed, 173 insertions(+), 5 deletions(-) create mode 100644 nebula/core/neighbormanagement/timergenerator.py diff --git a/nebula/core/engine.py b/nebula/core/engine.py index fb327f1be..5a7ed2590 100755 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -183,6 +183,7 @@ def __init__( # Thread for the trainer service, it is created 
when the learning starts self.trainer_service = None + self._node_manager = None if self.config.participant["mobility_args"]["mobility"]: topology = self.config.participant["mobility_args"]["mobility_type"] model_handler = self.config.participant["mobility_args"]["model_handler"] @@ -273,13 +274,15 @@ async def _connection_connect_callback(self, source, message): if source not in self.cm.get_addrs_current_connections(myself=True): logging.info(f"πŸ”— handle_connection_message | Trigger | Connecting to {source}") await self.cm.connect(source, direct=True) - self.nm.update_neighbors(source) + if self.nm is not None: + self.nm.update_neighbors(source) @event_handler(nebula_pb2.ConnectionMessage, nebula_pb2.ConnectionMessage.Action.DISCONNECT) async def _connection_disconnect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received disconnection message from {source}") await self.cm.disconnect(source, mutual_disconnection=False) - self.nm.update_neighbors(source, remove=True) + if self.nm is not None: + self.nm.update_neighbors(source, remove=True) @event_handler(nebula_pb2.FederationMessage, nebula_pb2.FederationMessage.Action.FEDERATION_START) async def _start_federation_callback(self, source, message): @@ -634,7 +637,9 @@ async def send_reputation(self, malicious_nodes): message = self.cm.mm.generate_federation_message(nebula_pb2.FederationMessage.Action.REPUTATION, malicious_nodes) await self.cm.send_message_to_neighbors(message) - + def get_weight_modifier(self, addr): + return self.nm.get_weight_modifier(addr) if self.nm is not None else 1 + def _init_late_node(self): """ Method to initialize a late connected node, creating its trainer and setting up the learning process @@ -659,7 +664,7 @@ def _init_late_node(self): self.set_initialization_status(True) self.get_federation_ready_lock().release() - self._create_trainer_service(round=round) + self.create_trainer_service(round=round) 
self.cm.start_external_connection_service() class MaliciousNode(Engine): diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index eee4d59ae..8089524a4 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -8,6 +8,7 @@ from nebula.core.neighbormanagement.candidateselection.candidateselector import factory_CandidateSelector from nebula.core.neighbormanagement.modelhandlers.modelhandler import factory_ModelHandler from nebula.core.neighbormanagement.neighborpolicies.neighborpolicy import factory_NeighborPolicy +from nebula.core.neighbormanagement.timergenerator import TimerGenerator from nebula.core.pb import nebula_pb2 from nebula.core.network.communications import CommunicationsManager @@ -37,6 +38,9 @@ def __init__( self.recieve_offer_timer = 5 self._restructure_process_lock = Locker(name="restructure_process_lock") self.restructure = False + self.max_time_to_wait = 6 + self._timer_generator = TimerGenerator(self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False), self.max_time_to_wait, 80) + self.set_confings() @property @@ -55,6 +59,10 @@ def candidate_selector(self): def model_handler(self): return self._model_handler + @property + def timer_generator(self): + return self._timer_generator + def get_restructure_process_lock(self): return self._restructure_process_lock @@ -78,6 +86,15 @@ def set_confings(self): self.neighbor_policy.set_config([self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False), self.engine.cm.get_addrs_current_connections(only_direct=False, myself=False), self.engine.addr]) #self.model_handler.set_config([self.engine.get_round(), self.engine.config.participant["training_args"]["epochs"]]) self.candidate_selector.set_config([self.engine.get_loss(), self.config.participant["molibity_args"]["weight_distance"], self.config.participant["molibity_args"]["weight_het"]]) + + def get_timer(self): 
+ return self.timer_generator.get_timer(self.engine.get_round()) + + def adjust_timer(self): + self.timer_generator.adjust_timer() + + def get_stop_condition(self): + return self.timer_generator.get_stop_condition() def add_weight_modifier(self, addr): self.weight_modifier_lock().acquire() @@ -124,6 +141,9 @@ def get_actions(self): def update_neighbors(self, node, remove=False): self.neighbor_policy.update_neighbors(node, remove) + self.timer_generator.update_node(node, remove) + if remove: + self.remove_weight_modifier(node) if not remove: self.neighbor_policy.meet_node(node) diff --git a/nebula/core/neighbormanagement/timergenerator.py b/nebula/core/neighbormanagement/timergenerator.py new file mode 100644 index 000000000..89ad5962b --- /dev/null +++ b/nebula/core/neighbormanagement/timergenerator.py @@ -0,0 +1,142 @@ +import asyncio +import numpy as np +from collections import deque +import time +from nebula.core.utils.locker import Locker + +class TimerGenerator(): + def __init__( + self, + nodes, + max_timer_value, + acceptable_percent, + round = None, + initial_timer_value = None, + max_historic_size = 10, + adaptative=False + ): + self.waiting_time = initial_timer_value if initial_timer_value != None else max_timer_value + self.max_timer_value = max_timer_value + self.acceptable_percent = acceptable_percent + self.max_historic_size = max_historic_size + self.round = round + self.nodes_historic = {node_id: deque(maxlen=self.max_historic_size) for node_id in nodes} + self.adaptative = adaptative + self.max_updates_number = len(self.nodes_historic) + self.updates_receive_from_nodes = set() + self.n_updates_receive = 0 + self.last_update_receive_time = 0 + self.start_moment = 0 + self.update_lock = Locker(name="update_lock") + self.all_updates_received = asyncio.Condition() + + def get_stop_condition(self): + return self.all_updates_received + + def get_timer(self, round): + self.round = round + self.start_moment = time.time() + return self.waiting_time + + def 
update_node(self, node, remove=False): + if remove: + self.nodes_historic.pop(node, None) + else: + self.nodes_historic.update({node: deque(maxlen=self.max_historic_size)}) + + async def receive_update(self, node_id, node_response_time): + """ + In this function the response time is saved in the historic, structures are updated and + condition is checked to stop the process because al responses are being received + + Args: + node_id : node addr + node_response_time : the time when the update was received + """ + t_n = node_response_time - self.start_moment + async with self.update_lock: + self.n_updates_receive +=1 + self.updates_receive_from_nodes.add(node_id) # this node has send update + self.nodes_historic[node_id].append(t_n) # add time + if self.n_updates_receive == self.max_updates_number: # it means all updates are receive + self.last_update_receive_time = t_n + async with self.all_updates_received: + self.all_updates_received.notify_all() + + def adjust_timer(self): + """ + The process of adjusting the timer is simple. if adaptative is not set up it will use the MAX_TIMER all the time. + If not, the strategy will depend on the percent of updates receive the last round. 
+ -updates < 25%: + timer = MAX_TIMER, the results are so bad, then we need an aggresive strategy + -updates < 75% + timer = last_timer + 40%, big increase, because the results are not good enough yet + -updates < 100% + if updates < acceptable_percent + if the historic is enough it will use EMA * 1.25, else last_timer * 1.20 + acceptable_percent + if the historic is enough it will use EMA * 1.05, else last_timer * 1.15 + -all updates received + we will reduce the timer the min value between 10% of the time wasted and last_updated_time * 1.2, trying + to adjust the timer to not waste time when bad variations occur + """ + self._complete_data() # fill not receive data for historic + + if self.adaptative: + if self.n_updates_receive == self.max_updates_number: + update_times = [] + for node_id, times_deque in self.nodes_historic.items(): + if node_id in self.updates_receive_from_nodes: + if times_deque: + update_times.append(times_deque[-1]) + max_update_time = np.max(update_times) + new_waiting_time = self.waiting_time - (self.waiting_time - max_update_time)*0.1 # Reduced 10% the difference between last update receive and waiting_time + new_waiting_time = np.max([(max_update_time*1.20),new_waiting_time]) # select max from worst_time*1.20 or 10% reduced + self.waiting_time = self._change_timer_value(new_waiting_time) + else: + percentile = (self.n_updates_receive / self.max_updates_number) * 100 + if percentile <= 25: + self.waiting_time = self.max_timer_value + elif percentile <= 75: + self.waiting_time = self._change_timer_value(self.waiting_time*1.4) # timer + 40% from max + else: + max_ema = 0 + for node_id, times_deque in self.nodes_historic.items(): + if len(times_deque) == 0: # If the deque is empty, skip this node + continue + ema = self._exponential_moving_average(times_deque, alpha=0.1) + max_ema = max(max_ema, ema) + if percentile < self.acceptable_percent: + if not self.round >= self.max_historic_size: + self.waiting_time = 
self._change_timer_value(self.waiting_time*1.2) # timer + 20% from max + else: + self.waiting_time = self._change_timer_value(max_ema*1.25) # if enough data for historic EMA, EMA*1.25 + else: + if not self.round >= self.max_historic_size: + self.waiting_time = self._change_timer_value(self.waiting_time*1.05) # timer + 5% from max + else: + self.waiting_time = self._change_timer_value(max_ema*1.15) # if enough data for historic EMA, EMA*1.15 + # reset round variables + self.n_updates_receive = 0 + self.last_update_receive_time = 0 + self.updates_receive_from_nodes.clear() + + def _exponential_moving_average(self, data, alpha=0.1): + if not data: # Handle the case where the data list is empty + return 0 + ema = [data[0]] + data_left = list(data)[1:] + for value in data_left: + ema.append((1 - alpha) * ema[-1] + alpha * value) + return ema[-1] + + def _change_timer_value(self, new_value): + return new_value if new_value < self.max_timer_value else self.max_timer_value + + def _complete_data(self): + """ + fill empty times using worst acceptable time + """ + for node_id, times_deque in self.nodes_historic.items(): + if not node_id in self.updates_receive_from_nodes: + times_deque.append(self.max_timer_value) \ No newline at end of file diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 0bb37c989..30e56d24f 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -210,9 +210,10 @@ async def handle_model_message(self, source, message): f.write("timestamp,source_ip,nodes,round,current_round,cosine,euclidean,minkowski,manhattan,pearson_correlation,jaccard\n") f.write(f"{datetime.now()}, {source}, {message.round}, {current_round}, {cosine_value}, {euclidean_value}, {minkowski_value}, {manhattan_value}, {pearson_correlation_value}, {jaccard_value}\n") + model_weight = message.weight * self.engine.get_weight_modifier(source) await self.engine.aggregator.include_model_in_buffer( 
decoded_model, - message.weight, + model_weight, source=source, round=message.round, ) From 85b706bf22f5f32b645f44a5d2d6f0ce26d500d9 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 17 Jul 2024 11:32:28 +0200 Subject: [PATCH 008/233] fix_timer_integrated Additions: - TimerGenerator functions and references to be used - Updated learning cycle to use timer if mobility is up --- nebula/core/aggregation/aggregator.py | 88 ++++++++++--------- nebula/core/engine.py | 29 ++++++ nebula/core/neighbormanagement/nodemanager.py | 5 +- .../core/neighbormanagement/timergenerator.py | 8 +- nebula/core/network/communications.py | 20 +++-- 5 files changed, 95 insertions(+), 55 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 928da2971..88af1ff2a 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from functools import partial import logging +from typing import Type from nebula.core.utils.locker import Locker from nebula.core.pb import nebula_pb2 @@ -8,49 +9,6 @@ class AggregatorException(Exception): pass - -def create_aggregator(config, engine): - from nebula.core.aggregation.fedavg import FedAvg - from nebula.core.aggregation.krum import Krum - from nebula.core.aggregation.median import Median - from nebula.core.aggregation.trimmedmean import TrimmedMean - from nebula.core.aggregation.blockchainReputation import BlockchainReputation - - ALGORITHM_MAP = { - "FedAvg": FedAvg, - "Krum": Krum, - "Median": Median, - "TrimmedMean": TrimmedMean, - "BlockchainReputation": BlockchainReputation, - } - algorithm = config.participant["aggregator_args"]["algorithm"] - aggregator = ALGORITHM_MAP.get(algorithm) - if aggregator: - return aggregator(config=config, engine=engine) - else: - raise AggregatorException(f"Aggregation algorithm {algorithm} not found.") - - -def create_target_aggregator(config, engine): - from 
nebula.core.aggregation.fedavg import FedAvg - from nebula.core.aggregation.krum import Krum - from nebula.core.aggregation.median import Median - from nebula.core.aggregation.trimmedmean import TrimmedMean - - ALGORITHM_MAP = { - "FedAvg": FedAvg, - "Krum": Krum, - "Median": Median, - "TrimmedMean": TrimmedMean, - } - algorithm = config.participant["defense_args"]["target_aggregation"] - aggregator = ALGORITHM_MAP.get(algorithm) - if aggregator: - return aggregator(config=config, engine=engine) - else: - raise AggregatorException(f"Aggregation algorithm {algorithm} not found.") - - class Aggregator(ABC): def __init__(self, config=None, engine=None): self.config = config @@ -90,6 +48,9 @@ def update_federation_nodes(self, federation_nodes): def set_waiting_global_update(self): self._waiting_global_update = True + def stop_waiting_for_updates(self): + self._aggregation_done_lock.release() + def reset(self): self._add_model_lock.acquire() self._federation_nodes.clear() @@ -208,3 +169,44 @@ def malicious_aggregate(self, models): aggregator.run_aggregation = partial(malicious_aggregate, aggregator) return aggregator + +def create_aggregator(config, engine) -> Aggregator: + from nebula.core.aggregation.fedavg import FedAvg + from nebula.core.aggregation.krum import Krum + from nebula.core.aggregation.median import Median + from nebula.core.aggregation.trimmedmean import TrimmedMean + from nebula.core.aggregation.blockchainReputation import BlockchainReputation + + ALGORITHM_MAP = { + "FedAvg": FedAvg, + "Krum": Krum, + "Median": Median, + "TrimmedMean": TrimmedMean, + "BlockchainReputation": BlockchainReputation, + } + algorithm = config.participant["aggregator_args"]["algorithm"] + aggregator = ALGORITHM_MAP.get(algorithm) + if aggregator: + return aggregator(config=config, engine=engine) + else: + raise AggregatorException(f"Aggregation algorithm {algorithm} not found.") + + +def create_target_aggregator(config, engine) -> Aggregator: + from 
nebula.core.aggregation.fedavg import FedAvg + from nebula.core.aggregation.krum import Krum + from nebula.core.aggregation.median import Median + from nebula.core.aggregation.trimmedmean import TrimmedMean + + ALGORITHM_MAP = { + "FedAvg": FedAvg, + "Krum": Krum, + "Median": Median, + "TrimmedMean": TrimmedMean, + } + algorithm = config.participant["defense_args"]["target_aggregation"] + aggregator = ALGORITHM_MAP.get(algorithm) + if aggregator: + return aggregator(config=config, engine=engine) + else: + raise AggregatorException(f"Aggregation algorithm {algorithm} not found.") \ No newline at end of file diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 72f913632..6ade03841 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -8,6 +8,7 @@ from nebula.core.aggregation.aggregator import create_aggregator, create_malicious_aggregator, create_target_aggregator from nebula.core.eventmanager import EventManager, event_handler from nebula.core.network.communications import CommunicationsManager +from nebula.core.neighbormanagement.nodemanager import NodeManager from nebula.core.pb import nebula_pb2 from nebula.core.utils.locker import Locker from lightning.pytorch.loggers import CSVLogger @@ -187,6 +188,7 @@ def __init__( # Thread for the trainer service, it is created when the learning starts self.trainer_service = None + self._waiting_updates_lock = Locker(name="waiting_updates_lock") self._node_manager = None if self.config.participant["mobility_args"]["mobility"]: topology = self.config.participant["mobility_args"]["mobility_type"] @@ -548,6 +550,7 @@ async def _dynamic_aggregator(self, aggregated_models_weights, malicious_nodes): async def _waiting_model_updates(self): logging.info(f"πŸ’€ Waiting convergence in round {self.round}.") + self._set_updates_timer() params = self.aggregator.get_aggregation() if params is not None: logging.info(f"_waiting_model_updates | Aggregation done for round {self.round}, including parameters in local 
model.") @@ -571,6 +574,7 @@ async def _learning_cycle(self): self.aggregator.reset() self.trainer.on_round_end() self.round = self.round + 1 + self._waiting_updates_lock.release() self.config.participant["federation_args"]["round"] = self.round # Set current round in config (send to the controller) self.get_round_lock().release() @@ -647,6 +651,31 @@ async def send_reputation(self, malicious_nodes): def get_weight_modifier(self, addr): return self.nm.get_weight_modifier(addr) if self.nm is not None else 1 + + async def receive_update_from_node(self, node, nose_response_time): + self.nm.receive_update_from_node(node, nose_response_time) + + def still_waiting_for_updates(self): + return not self._waiting_updates_lock.locked() + + async def _set_updates_timer(self): + if self.nm is not None: + task = asyncio.create_task(self._wait_stop_condition()) + + async def _wait_stop_condition(self): + """ + Set up the timer to wait for updates and check condition, after one of them is interrupted + the aggregation process will end up + """ + time_to_wait = self.nm.get_timer() + logging.info(f"πŸ’€ Waiting for all updates received or timeout = {time_to_wait} in round {self.round}.") + try: + await asyncio.wait_for(self.nm.get_stop_condition().wait(), timeout=time_to_wait) + except asyncio.TimeoutError: + pass + + self._waiting_updates_lock.acquire() + self.aggregator.stop_waiting_for_updates() def _init_late_node(self): """ diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 8089524a4..df3d892d6 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -94,7 +94,10 @@ def adjust_timer(self): self.timer_generator.adjust_timer() def get_stop_condition(self): - return self.timer_generator.get_stop_condition() + return self.timer_generator.get_stop_condition() + + async def receive_update_from_node(self, node_id, node_response_time): + await 
self.timer_generator.receive_update(node_id, node_response_time) def add_weight_modifier(self, addr): self.weight_modifier_lock().acquire() diff --git a/nebula/core/neighbormanagement/timergenerator.py b/nebula/core/neighbormanagement/timergenerator.py index 89ad5962b..eddc99bc6 100644 --- a/nebula/core/neighbormanagement/timergenerator.py +++ b/nebula/core/neighbormanagement/timergenerator.py @@ -35,7 +35,8 @@ def get_stop_condition(self): def get_timer(self, round): self.round = round - self.start_moment = time.time() + sm = time.time() + self.start_moment = round(sm, 2) return self.waiting_time def update_node(self, node, remove=False): @@ -53,12 +54,13 @@ async def receive_update(self, node_id, node_response_time): node_id : node addr node_response_time : the time when the update was received """ - t_n = node_response_time - self.start_moment + nrt = round(node_response_time,2) + t_n = nrt - self.start_moment if nrt - self.start_moment >= 0 else 0 async with self.update_lock: self.n_updates_receive +=1 self.updates_receive_from_nodes.add(node_id) # this node has send update self.nodes_historic[node_id].append(t_n) # add time - if self.n_updates_receive == self.max_updates_number: # it means all updates are receive + if self.n_updates_receive == self.max_updates_number: # it means all updates are being receive self.last_update_receive_time = t_n async with self.all_updates_received: self.all_updates_received.notify_all() diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 0cdf01dba..3915179cb 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -8,6 +8,7 @@ import requests import asyncio import subprocess +import time from nebula.addons.mobility import Mobility from nebula.core.network.discoverer import Discoverer @@ -211,14 +212,17 @@ async def handle_model_message(self, source, message): if os.stat(f"{self.log_dir}/participant_{self.idx}_similarity.csv").st_size == 0: 
f.write("timestamp,source_ip,nodes,round,current_round,cosine,euclidean,minkowski,manhattan,pearson_correlation,jaccard\n") f.write(f"{datetime.now()}, {source}, {message.round}, {current_round}, {cosine_value}, {euclidean_value}, {minkowski_value}, {manhattan_value}, {pearson_correlation_value}, {jaccard_value}\n") - - model_weight = message.weight * self.engine.get_weight_modifier(source) - await self.engine.aggregator.include_model_in_buffer( - decoded_model, - model_weight, - source=source, - round=message.round, - ) + + if self.engine.still_waiting_for_updates(): + model_weight = message.weight * self.engine.get_weight_modifier(source) + rt = time.time() + self.engine.receive_update_from_node(source, rt) + await self.engine.aggregator.include_model_in_buffer( + decoded_model, + model_weight, + source=source, + round=message.round, + ) else: if message.round != -1: From 2012cc11a8582de5e1947c5b822b4ca94993847f Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 17 Jul 2024 11:42:05 +0200 Subject: [PATCH 009/233] fix_lock_release_excep check lock is acquire --- nebula/core/engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 6ade03841..f83193239 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -574,7 +574,8 @@ async def _learning_cycle(self): self.aggregator.reset() self.trainer.on_round_end() self.round = self.round + 1 - self._waiting_updates_lock.release() + if self._waiting_updates_lock.locked(): + self._waiting_updates_lock.release() self.config.participant["federation_args"]["round"] = self.round # Set current round in config (send to the controller) self.get_round_lock().release() From f53b32c38e78488d155ea002f1e1d421d7f3106d Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 17 Jul 2024 12:43:40 +0200 Subject: [PATCH 010/233] fix_waiting_logic Changed the way the node wait and update its state of waiting. 
--- nebula/core/aggregation/aggregator.py | 8 +++--- nebula/core/engine.py | 27 +++++++------------ .../core/neighbormanagement/timergenerator.py | 10 ++++--- 3 files changed, 21 insertions(+), 24 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 88af1ff2a..8a1be9db2 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -47,9 +47,6 @@ def update_federation_nodes(self, federation_nodes): def set_waiting_global_update(self): self._waiting_global_update = True - - def stop_waiting_for_updates(self): - self._aggregation_done_lock.release() def reset(self): self._add_model_lock.acquire() @@ -116,6 +113,9 @@ async def include_model_in_buffer(self, model, weight, source=None, round=None, return + def set_timer(self, time_value): + self.config.participant["aggregator_args"]["aggregation_timeout"] = time_value + def get_aggregation(self): if self._aggregation_done_lock.acquire(timeout=self.config.participant["aggregator_args"]["aggregation_timeout"]): try: @@ -125,6 +125,8 @@ def get_aggregation(self): else: logging.error(f"πŸ”„ get_aggregation | Timeout reached for aggregation") + self.engine.stop_waiting_for_updates() + if self._waiting_global_update and len(self._pending_models_to_aggregate) == 1: logging.info(f"πŸ”„ get_aggregation | Received an global model. 
Overwriting my model with the aggregated model.") return next(iter(self._pending_models_to_aggregate.values()))[0] diff --git a/nebula/core/engine.py b/nebula/core/engine.py index f83193239..feddc3f14 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -654,29 +654,22 @@ def get_weight_modifier(self, addr): return self.nm.get_weight_modifier(addr) if self.nm is not None else 1 async def receive_update_from_node(self, node, nose_response_time): - self.nm.receive_update_from_node(node, nose_response_time) + if self.nm is not None: + self.nm.receive_update_from_node(node, nose_response_time) def still_waiting_for_updates(self): - return not self._waiting_updates_lock.locked() + return not self._waiting_updates_lock.locked() + + def stop_waiting_for_updates(self): + self._waiting_updates_lock.acquire() + if self.nm is not None: + self.nm.adjust_timer() async def _set_updates_timer(self): if self.nm is not None: - task = asyncio.create_task(self._wait_stop_condition()) - - async def _wait_stop_condition(self): - """ - Set up the timer to wait for updates and check condition, after one of them is interrupted - the aggregation process will end up - """ - time_to_wait = self.nm.get_timer() - logging.info(f"πŸ’€ Waiting for all updates received or timeout = {time_to_wait} in round {self.round}.") - try: - await asyncio.wait_for(self.nm.get_stop_condition().wait(), timeout=time_to_wait) - except asyncio.TimeoutError: - pass + time_to_wait = self.nm.get_timer() + self.aggregator.set_timer(time_to_wait) - self._waiting_updates_lock.acquire() - self.aggregator.stop_waiting_for_updates() def _init_late_node(self): """ diff --git a/nebula/core/neighbormanagement/timergenerator.py b/nebula/core/neighbormanagement/timergenerator.py index eddc99bc6..60a5e1b87 100644 --- a/nebula/core/neighbormanagement/timergenerator.py +++ b/nebula/core/neighbormanagement/timergenerator.py @@ -28,7 +28,7 @@ def __init__( self.last_update_receive_time = 0 self.start_moment = 0 
self.update_lock = Locker(name="update_lock") - self.all_updates_received = asyncio.Condition() + #self.all_updates_received = asyncio.Condition() def get_stop_condition(self): return self.all_updates_received @@ -42,8 +42,10 @@ def get_timer(self, round): def update_node(self, node, remove=False): if remove: self.nodes_historic.pop(node, None) + self.max_updates_number -= 1 else: - self.nodes_historic.update({node: deque(maxlen=self.max_historic_size)}) + self.nodes_historic.update({node: deque(maxlen=self.max_historic_size)}) + self.max_updates_number += 1 async def receive_update(self, node_id, node_response_time): """ @@ -62,8 +64,8 @@ async def receive_update(self, node_id, node_response_time): self.nodes_historic[node_id].append(t_n) # add time if self.n_updates_receive == self.max_updates_number: # it means all updates are being receive self.last_update_receive_time = t_n - async with self.all_updates_received: - self.all_updates_received.notify_all() + #async with self.all_updates_received: + # self.all_updates_received.notify_all() def adjust_timer(self): """ From 5ca17610b72308028e2d51e139cd09ec3c7a0a15 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 15 Nov 2024 11:02:37 +0100 Subject: [PATCH 011/233] fixed_some_mobility_configurations --- nebula/core/engine.py | 196 +++++++++++++++++- nebula/core/network/communications.py | 1 + .../frontend/config/participant.json.example | 2 +- nebula/scenarios.py | 3 + 4 files changed, 199 insertions(+), 3 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 974c7f13e..d6c0b839c 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -11,6 +11,7 @@ from nebula.core.network.communications import CommunicationsManager from nebula.core.pb import nebula_pb2 from nebula.core.utils.locker import Locker +from nebula.core.neighbormanagement.nodemanager import NodeManager logging.getLogger("requests").setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING) @@ 
-131,6 +132,17 @@ def __init__( self.trainer.model.set_communication_manager(self._cm) self._reporter = Reporter(config=self.config, trainer=self.trainer, cm=self.cm) + + # Mobility setup + self._waiting_updates_lock = Locker(name="waiting_updates_lock") + self._node_manager = None + if self.config.participant["mobility_args"]["mobility"]: + topology = self.config.participant["mobility_args"]["mobility_type"] + model_handler = self.config.participant["mobility_args"]["model_handler"] + self._node_manager = NodeManager(topology, model_handler, engine=self) + if self.config.participant["mobility_args"]["late_creation"]: + self._init_late_node() + self._event_manager = EventManager( default_callbacks=[ @@ -146,7 +158,18 @@ def __init__( # Register additional callbacks self._event_manager.register_event((nebula_pb2.FederationMessage, nebula_pb2.FederationMessage.Action.REPUTATION), self._reputation_callback) - # ... add more callbacks here + # ... add more callbacks here + self._event_manager.register_event((nebula_pb2.DiscoverMessage, nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN), self._discover_discover_join_callback) + self._event_manager.register_event((nebula_pb2.DiscoverMessage, nebula_pb2.DiscoverMessage.Action.DISCOVER_NODE), self._discover_discover_node_callback) + + self._event_manager.register_event((nebula_pb2.OfferMessage, nebula_pb2.OfferMessage.Action.OFFER_METRIC), self._offer_offer_metric_callback) + self._event_manager.register_event((nebula_pb2.OfferMessage, nebula_pb2.OfferMessage.Action.OFFER_MODEL), self._offer_offer_model_callback) + + self._event_manager.register_event((nebula_pb2.ConnectionMessage, nebula_pb2.ConnectionMessage.Action.LATE_CONNECT), self._connection_late_connect_callback) + self._event_manager.register_event((nebula_pb2.ConnectionMessage, nebula_pb2.ConnectionMessage.Action.RESTRUCTURE), self._connection_late_connect_callback) + + self._event_manager.register_event((nebula_pb2.LinkMessage, 
nebula_pb2.LinkMessage.Action.CONNECT_TO), self._link_connect_to_callback) + self._event_manager.register_event((nebula_pb2.LinkMessage, nebula_pb2.LinkMessage.Action.DISCONNECT_FROM), self._link_disconnect_from_callback) @property def cm(self): @@ -163,6 +186,10 @@ def event_manager(self): @property def aggregator(self): return self._aggregator + + @property + def nm(self): + return self._node_manager def get_aggregator_type(self): return type(self.aggregator) @@ -278,6 +305,121 @@ async def _federation_models_included_callback(self, source, message): finally: await self.cm.get_connections_lock().release_async() + # Mobility callbacks + @event_handler(nebula_pb2.ConnectionMessage, nebula_pb2.ConnectionMessage.Action.LATE_CONNECT) + async def _connection_late_connect_callback(self, source, message): + logging.info(f"πŸ”— handle_connection_message | Trigger | Received late_connect message from {source}") + if self.nm.accept_connection(source, joining=True): + self.nm.add_weight_modifier(source) + ct_actions , df_actions = self.nm.get_actions() + + # connect to + for addr in ct_actions.split(): + cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECTO_TO, addr) + await self.cm.send_message(source, cnt_msg) + + # disconnect from + for addr in df_actions.split(): + df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, addr) + await self.cm.send_message(source, df_msg) + + await self.cm.connect(source, direct=True) + self.nm.update_neighbors(source) + + @event_handler(nebula_pb2.ConnectionMessage, nebula_pb2.ConnectionMessage.Action.RESTRUCTURE) + async def _connection_restructure_callback(self, source, message): + logging.info(f"πŸ”— handle_connection_message | Trigger | Received restructure message from {source}") + if self.nm.accept_connection(source, joining=False): + logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection accepted from {source}") + ct_actions , df_actions = 
self.nm.get_actions() + + for addr in ct_actions.split(): + cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECTO_TO, addr) + await self.cm.send_message(source, cnt_msg) + + for addr in df_actions.split(): + df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, addr) + await self.cm.send_message(source, df_msg) + else: + logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection denied from {source}") + await self.cm.disconnect(source, mutual_disconnection=False) + self.nm.update_neighbors(source, remove=True) + + @event_handler(nebula_pb2.DiscoverMessage, nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN) + async def _discover_discover_join_callback(self, source, message): + logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_join message from {source} ") + + self.nm.meet_node(source) + # if no neighbors means i'm new + if len(self.get_federation_nodes()) > 0: + model, rounds, round = self.cm.propagator.get_model_information(source, "stable") if self.get_round() > 0 else self.cm.propagator.get_model_information(source, "initialization") + epochs = self.config.participant["training_args"]["epochs"] + msg = self.cm.mm.generate_offer_message( + nebula_pb2.OfferMessage.Action.OFFER_MODEL, + len(self.get_federation_nodes()), + self.trainer.get_loss(), + model, + rounds, + round, + epochs + ) + await self.cm.send_message(source, msg) + + @event_handler(nebula_pb2.DiscoverMessage, nebula_pb2.DiscoverMessage.Action.DISCOVER_NODE) + async def _discover_discover_node_callback(self, source, message): + logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_node message from {source} ") + self.nm.meet_node(source) + msg = self.cm.mm.generate_offer_message(nebula_pb2.OfferMessage.Action.OFFER_METRIC, len(self.get_federation_nodes()), self.trainer.get_loss()) + await self.cm.send_message(source, msg) + + @event_handler(nebula_pb2.OfferMessage, 
nebula_pb2.OfferMessage.Action.OFFER_MODEL) + async def _offer_offer_model_callback(self, source, message): + logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_model message from {source}") + if not self.nm.get_restructure_process_lock().locked(): + decoded_model = self.trainer.deserialize_model(message.parameters) + self.nm.accept_model(source, decoded_model, message.rounds, message.round, message.epochs, message.n_neighbors, message.loss) + self.nm.add_candidate(source, message.n_neighbors, message.loss) + self.nm.meet_node(source) + + @event_handler(nebula_pb2.OfferMessage, nebula_pb2.OfferMessage.Action.OFFER_METRIC) + async def _offer_offer_metric_callback(self, source, message): + logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_metric message from {source}") + if not self.nm.get_restructure_process_lock().locked(): + n_neighbors, loss, _, _, _, _ = message.arguments + self.nm.add_candidate(source, n_neighbors, loss) + self.nm.meet_node(source) + + @event_handler(nebula_pb2.LinkMessage, nebula_pb2.LinkMessage.Action.CONNECTO_TO) + async def _link_connect_to_callback(self, source, message): + logging.info(f"πŸ”— handle_link_message | Trigger | Received connecto_to message from {source}") + addrs = message.arguments + for addr in addrs: + await self.cm.connect(addr, direct=True) + self.nm.update_neighbors(addr) + self.nm.meet_node(source) + + @event_handler(nebula_pb2.LinkMessage, nebula_pb2.LinkMessage.Action.DISCONNECT_FROM) + async def _link_disconnect_from_callback(self, source, message): + logging.info(f"πŸ”— handle_link_message | Trigger | Received disconnect_from message from {source}") + addrs = message.arguments + for addr in addrs: + await self.cm.disconnect(source, mutual_disconnection=False) + self.nm.update_neighbors(addr, remove=True) + + def create_trainer_service(self, round=0): + if self.trainer_service is None: + self.trainer_service = threading.Thread( + target=self._start_learning, + args=(round,), + 
daemon=True, + name="trainer_service_thread-" + self.addr, + ) + self.trainer_service.start() + logging.info(f"Started trainer service thread...") + + def get_trainer_service(self): + return self.trainer_service + async def create_trainer_module(self): asyncio.create_task(self._start_learning()) logging.info(f"Started trainer module...") @@ -505,7 +647,57 @@ async def send_reputation(self, malicious_nodes): logging.info(f"Sending REPUTATION to the rest of the topology: {malicious_nodes}") message = self.cm.mm.generate_federation_message(nebula_pb2.FederationMessage.Action.REPUTATION, malicious_nodes) await self.cm.send_message_to_neighbors(message) - + + def get_weight_modifier(self, addr): + return self.nm.get_weight_modifier(addr) if self.nm is not None else 1 + + async def receive_update_from_node(self, node, node_response_time): + if self.nm is not None: + self.nm.receive_update_from_node(node, node_response_time) + + def still_waiting_for_updates(self): + return not self._waiting_updates_lock.locked() + + def stop_waiting_for_updates(self): + self._waiting_updates_lock.acquire() + if self.nm is not None: + self.nm.adjust_timer() + + async def _set_updates_timer(self): + if self.nm is not None: + time_to_wait = self.nm.get_timer() + self.aggregator.set_timer(time_to_wait) + + def _init_late_node(self): + """ + Method to initialize a late connected node, creating its trainer and setting up the learning process + + First step broadcasting discover message, after that we select candidates and connect to them. + The information to create the trainer is recieved from nodes that are already on federation and answared the discover message. 
+ -model: params + -rounds: total rounds + -round: current round of the learning process + -epochs: epochs + """ + # sleep time before starting + sleep_time = self.config.participant["mobility_args"]["sleeping_time"] + asyncio.sleep(sleep_time) + + logging.info("🌐 Initializing late creation node life from Engine") + model, rounds, round, epochs = self.nm.start_late_connection_process() + + self.config.participant["scenario_args"]["rounds"] = rounds + self.config.participant["training_args"]["epochs"] = epochs + + self.round = round + + # self._trainer = trainer(model, self.dataset, config=self.config, logger=nebulalogger) + self.trainer.set_model_parameters(model, initialize=True) + + self.set_initialization_status(True) + self.get_federation_ready_lock().release() + self.create_trainer_service(round=round) + self.cm.start_external_connection_service() class MaliciousNode(Engine): def __init__(self, model, dataset, config=Config, trainer=Lightning, security=False, model_poisoning=False, poisoned_ratio=0, noise_type="gaussian"): diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 65c372c48..ffd51c201 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -239,6 +239,7 @@ async def handle_model_message(self, source, message): f.write("timestamp,source_ip,nodes,round,current_round,cosine,euclidean,minkowski,manhattan,pearson_correlation,jaccard\n") f.write(f"{datetime.now()}, {source}, {message.round}, {current_round}, {cosine_value}, {euclidean_value}, {minkowski_value}, {manhattan_value}, {pearson_correlation_value}, {jaccard_value}\n") + # getting modifier weight for update if self.engine.still_waiting_for_updates(): model_weight = message.weight * self.engine.get_weight_modifier(source) rt = time.time() diff --git a/nebula/frontend/config/participant.json.example b/nebula/frontend/config/participant.json.example index 864d30405..861e5deb6 100755 --- 
a/nebula/frontend/config/participant.json.example +++ b/nebula/frontend/config/participant.json.example @@ -72,7 +72,7 @@ }, "model_handler": "std", "late_creation": false, - "sleeping_time": 0, + "sleeping_time": 10, "weight_distance": 0.2, "weight_het": 0.8 }, diff --git a/nebula/scenarios.py b/nebula/scenarios.py index cefddf355..118b8e2de 100644 --- a/nebula/scenarios.py +++ b/nebula/scenarios.py @@ -552,6 +552,9 @@ def load_configurations_and_start_nodes(self, additional_participants=None, sche participant_config["device_args"]["uid"] = hashlib.sha1((str(participant_config["network_args"]["ip"]) + str(participant_config["network_args"]["port"]) + str(self.scenario_name)).encode()).hexdigest() participant_config["mobility_args"]["additional_node"]["status"] = True participant_config["mobility_args"]["additional_node"]["round_start"] = additional_participant["round"] + + # used for late creation nodes + participant_config["mobility_args"]["late_creation"] = True with open(additional_participant_file, "w") as f: json.dump(participant_config, f, sort_keys=False, indent=2) From cab1ba0ed49fef11bf980fccb07f6593117693fa Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 15 Nov 2024 11:37:06 +0100 Subject: [PATCH 012/233] update_nebula_p2b.py --- nebula/core/pb/nebula_pb2.py | 54 ++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 21 deletions(-) diff --git a/nebula/core/pb/nebula_pb2.py b/nebula/core/pb/nebula_pb2.py index 333970e06..763d840cf 100755 --- a/nebula/core/pb/nebula_pb2.py +++ b/nebula/core/pb/nebula_pb2.py @@ -22,25 +22,37 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None _globals['_WRAPPER']._serialized_start=25 - _globals['_WRAPPER']._serialized_end=381 - _globals['_DISCOVERYMESSAGE']._serialized_start=384 - _globals['_DISCOVERYMESSAGE']._serialized_end=542 - _globals['_DISCOVERYMESSAGE_ACTION']._serialized_start=490 - _globals['_DISCOVERYMESSAGE_ACTION']._serialized_end=542 - 
_globals['_CONTROLMESSAGE']._serialized_start=545 - _globals['_CONTROLMESSAGE']._serialized_end=699 - _globals['_CONTROLMESSAGE_ACTION']._serialized_start=623 - _globals['_CONTROLMESSAGE_ACTION']._serialized_end=699 - _globals['_FEDERATIONMESSAGE']._serialized_start=702 - _globals['_FEDERATIONMESSAGE']._serialized_end=907 - _globals['_FEDERATIONMESSAGE_ACTION']._serialized_start=807 - _globals['_FEDERATIONMESSAGE_ACTION']._serialized_end=907 - _globals['_MODELMESSAGE']._serialized_start=909 - _globals['_MODELMESSAGE']._serialized_end=974 - _globals['_CONNECTIONMESSAGE']._serialized_start=976 - _globals['_CONNECTIONMESSAGE']._serialized_end=1084 - _globals['_CONNECTIONMESSAGE_ACTION']._serialized_start=1047 - _globals['_CONNECTIONMESSAGE_ACTION']._serialized_end=1084 - _globals['_RESPONSEMESSAGE']._serialized_start=1086 - _globals['_RESPONSEMESSAGE']._serialized_end=1121 + _globals['_WRAPPER']._serialized_end=526 + _globals['_DISCOVERYMESSAGE']._serialized_start=529 + _globals['_DISCOVERYMESSAGE']._serialized_end=687 + _globals['_DISCOVERYMESSAGE_ACTION']._serialized_start=635 + _globals['_DISCOVERYMESSAGE_ACTION']._serialized_end=687 + _globals['_CONTROLMESSAGE']._serialized_start=690 + _globals['_CONTROLMESSAGE']._serialized_end=844 + _globals['_CONTROLMESSAGE_ACTION']._serialized_start=768 + _globals['_CONTROLMESSAGE_ACTION']._serialized_end=844 + _globals['_FEDERATIONMESSAGE']._serialized_start=847 + _globals['_FEDERATIONMESSAGE']._serialized_end=1030 + _globals['_FEDERATIONMESSAGE_ACTION']._serialized_start=952 + _globals['_FEDERATIONMESSAGE_ACTION']._serialized_end=1030 + _globals['_MODELMESSAGE']._serialized_start=1032 + _globals['_MODELMESSAGE']._serialized_end=1097 + _globals['_CONNECTIONMESSAGE']._serialized_start=1100 + _globals['_CONNECTIONMESSAGE']._serialized_end=1243 + _globals['_CONNECTIONMESSAGE_ACTION']._serialized_start=1171 + _globals['_CONNECTIONMESSAGE_ACTION']._serialized_end=1243 + _globals['_DISCOVERMESSAGE']._serialized_start=1245 + 
_globals['_DISCOVERMESSAGE']._serialized_end=1359 + _globals['_DISCOVERMESSAGE_ACTION']._serialized_start=1312 + _globals['_DISCOVERMESSAGE_ACTION']._serialized_end=1359 + _globals['_OFFERMESSAGE']._serialized_start=1362 + _globals['_OFFERMESSAGE']._serialized_end=1568 + _globals['_OFFERMESSAGE_ACTION']._serialized_start=1525 + _globals['_OFFERMESSAGE_ACTION']._serialized_end=1568 + _globals['_LINKMESSAGE']._serialized_start=1570 + _globals['_LINKMESSAGE']._serialized_end=1689 + _globals['_LINKMESSAGE_ACTION']._serialized_start=1644 + _globals['_LINKMESSAGE_ACTION']._serialized_end=1689 + _globals['_RESPONSEMESSAGE']._serialized_start=1691 + _globals['_RESPONSEMESSAGE']._serialized_end=1726 # @@protoc_insertion_point(module_scope) From e7ac6414854eefda2c56622748c3e449bb461708 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sun, 24 Nov 2024 12:11:13 +0100 Subject: [PATCH 013/233] feat_mobility_upgrade .External_connection_service working -Messages and callbacks integrated --- nebula/core/engine.py | 167 +++++++++++++++++- nebula/core/eventmanager.py | 1 + nebula/core/models/nebulamodel.py | 6 + .../candidateselection/candidateselector.py | 2 +- .../modelhandlers/modelhandler.py | 2 +- .../modelhandlers/stdmodelhandler.py | 1 + .../neighborpolicies/neighborpolicy.py | 10 +- nebula/core/neighbormanagement/nodemanager.py | 64 ++++--- .../core/neighbormanagement/timergenerator.py | 4 +- nebula/core/network/communications.py | 108 +++++++++++ nebula/core/network/messages.py | 4 +- nebula/core/network/nebulamulticasting.py | 28 ++- nebula/core/network/propagator.py | 23 +++ nebula/core/pb/nebula_pb2.py | 69 +++++--- nebula/node.py | 22 ++- 15 files changed, 426 insertions(+), 85 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 51ea21f63..c0ac88aaf 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -12,6 +12,7 @@ from nebula.core.network.communications import CommunicationsManager from nebula.core.pb import nebula_pb2 from 
nebula.core.utils.locker import Locker +from nebula.core.neighbormanagement.nodemanager import NodeManager logging.getLogger("requests").setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING) @@ -145,6 +146,15 @@ def __init__( self.trainer.model.set_communication_manager(self._cm) self._reporter = Reporter(config=self.config, trainer=self.trainer, cm=self.cm) + + # Mobility setup + self._node_manager = None + mob = self.config.participant["mobility_args"]["mobility"] + if mob == True: + topology = self.config.participant["mobility_args"]["mobility_type"] + model_handler = "std" #self.config.participant["mobility_args"]["model_handler"] + self._node_manager = NodeManager(topology, model_handler, engine=self) + self._event_manager = EventManager( default_callbacks=[ @@ -155,6 +165,14 @@ def __init__( self._federation_ready_callback, self._start_federation_callback, self._federation_models_included_callback, + self._discover_discover_join_callback, + self._discover_discover_nodes_callback, + self._connection_late_connect_callback, + self._connection_restructure_callback, + self._offer_offer_model_callback, + self._offer_offer_metric_callback, + self._link_connect_to_callback, + self._link_disconnect_from_callback, ] ) @@ -166,8 +184,9 @@ def __init__( ), self._reputation_callback, ) + # ... add more callbacks here - + @property def cm(self): return self._cm @@ -190,6 +209,17 @@ def get_aggregator_type(self): @property def trainer(self): return self._trainer + + @property + def nm(self): + return self._node_manager + + async def _aditional_node_start(self): + logging.info(f"{self.addr} is an aditional node going to stablish connection with federation") + await self.nm.start_late_connection_process() + # continue .. 
+ logging.info("Creating trainer service to start the federation process..") + #await self.cm.establish_connection_with_federation() def get_addr(self): return self.addr @@ -322,6 +352,141 @@ async def _federation_models_included_callback(self, source, message): finally: await self.cm.get_connections_lock().release_async() + # Mobility callbacks + @event_handler( + nebula_pb2.ConnectionMessage, + nebula_pb2.ConnectionMessage.Action.LATE_CONNECT, + ) + async def _connection_late_connect_callback(self, source, message): + logging.info(f"πŸ”— handle_connection_message | Trigger | Received late_connect message from {source}") + if self.nm.accept_connection(source, joining=True): + logging.info(f"πŸ”— Late connection acepted | source:{source}") + self.nm.add_weight_modifier(source) + ct_actions , df_actions = self.nm.get_actions() + + # connect to + for addr in ct_actions.split(): + cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECTO_TO, addr) + await self.cm.send_message(source, cnt_msg) + # disconnect from + for addr in df_actions.split(): + df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, addr) + await self.cm.send_message(source, df_msg) + + await self.cm.connect(source, direct=True) + self.nm.update_neighbors(source) + + @event_handler( + nebula_pb2.ConnectionMessage, + nebula_pb2.ConnectionMessage.Action.RESTRUCTURE, + ) + async def _connection_restructure_callback(self, source, message): + logging.info(f"πŸ”— handle_connection_message | Trigger | Received restructure message from {source}") + if self.nm.accept_connection(source, joining=False): + logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection accepted from {source}") + ct_actions , df_actions = self.nm.get_actions() + + for addr in ct_actions.split(): + cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECTO_TO, addr) + await self.cm.send_message(source, cnt_msg) + + for addr in 
df_actions.split(): + df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, addr) + await self.cm.send_message(source, df_msg) + else: + logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection denied from {source}") + await self.cm.disconnect(source, mutual_disconnection=False) + self.nm.update_neighbors(source, remove=True) + + @event_handler(nebula_pb2.DiscoverMessage, nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN) + async def _discover_discover_join_callback(self, source, message): + logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_join message from {source} ") + self.nm.meet_node(source) + if len(self.get_federation_nodes()) > 0: + #model, rounds, round = await self.cm.propagator.get_model_information(source, "stable") if self.get_round() > 0 else await self.cm.propagator.get_model_information(source, "initialization") + model, rounds, round = await self.cm.propagator.get_model_information(source, "initialization") + # Process not initiated yet + if round != -1: + epochs = self.config.participant["training_args"]["epochs"] + msg = self.cm.mm.generate_offer_message( + nebula_pb2.OfferMessage.Action.OFFER_MODEL, + len(self.get_federation_nodes()), + 0, #self.trainer.get_loss(), + model, + rounds, + round, + epochs + ) + await self.cm.send_offer_model(source, msg) + await asyncio.sleep(1) + await self.cm.remove_temporary_connection(source) + else: + # for the starter federation node + logging.info() + else: + logging.info(f"πŸ”— Dissmissing discover join from {source} | no active connections at the moment") + + @event_handler( + nebula_pb2.DiscoverMessage, + nebula_pb2.DiscoverMessage.Action.DISCOVER_NODES, + ) + async def _discover_discover_nodes_callback(self, source, message): + logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_node message from {source} ") + self.nm.meet_node(source) + msg = 
self.cm.mm.generate_offer_message(nebula_pb2.OfferMessage.Action.OFFER_METRIC, len(self.get_federation_nodes()), self.trainer.get_loss()) + await self.cm.send_message(source, msg) + + @event_handler( + nebula_pb2.OfferMessage, + nebula_pb2.OfferMessage.Action.OFFER_MODEL, + ) + async def _offer_offer_model_callback(self, source, message): + logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_model message from {source}") + if not self.nm.get_restructure_process_lock().locked(): + try: + decoded_model = self.trainer.deserialize_model(message.parameters) + self.nm.accept_model(source, decoded_model, message.rounds, message.round, message.epochs, message.n_neighbors, message.loss) + self.nm.add_candidate(source, message.n_neighbors, message.loss) + self.nm.meet_node(source) + except RuntimeError: + pass + await self.cm.remove_temporary_connection(source) + + @event_handler( + nebula_pb2.OfferMessage, + nebula_pb2.OfferMessage.Action.OFFER_METRIC, + ) + async def _offer_offer_metric_callback(self, source, message): + logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_metric message from {source}") + if not self.nm.get_restructure_process_lock().locked(): + n_neighbors, loss, _, _, _, _ = message.arguments + self.nm.add_candidate(source, n_neighbors, loss) + self.nm.meet_node(source) + + @event_handler( + nebula_pb2.LinkMessage, + nebula_pb2.LinkMessage.Action.CONNECT_TO, + ) + async def _link_connect_to_callback(self, source, message): + logging.info(f"πŸ”— handle_link_message | Trigger | Received connecto_to message from {source}") + addrs = message.arguments + for addr in addrs: + await self.cm.connect(addr, direct=True) + self.nm.update_neighbors(addr) + self.nm.meet_node(source) + + @event_handler( + nebula_pb2.LinkMessage, + nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, + ) + async def _link_disconnect_from_callback(self, source, message): + logging.info(f"πŸ”— handle_link_message | Trigger | Received disconnect_from message from 
{source}") + addrs = message.arguments + for addr in addrs: + await self.cm.disconnect(source, mutual_disconnection=False) + self.nm.update_neighbors(addr, remove=True) + + async def create_trainer_module(self): asyncio.create_task(self._start_learning()) logging.info("Started trainer module...") diff --git a/nebula/core/eventmanager.py b/nebula/core/eventmanager.py index 751bc1bd6..4d5b7a3c6 100755 --- a/nebula/core/eventmanager.py +++ b/nebula/core/eventmanager.py @@ -49,6 +49,7 @@ def register_event(self, handler_info, callback): self._event_callbacks[handler_info].append(callback) else: raise ValueError("The callback must be a callable function.") + def unregister_event(self, handler_info, callback): """Unregisters a previously registered callback for an event.""" diff --git a/nebula/core/models/nebulamodel.py b/nebula/core/models/nebulamodel.py index 412e178f6..0e9ea0c30 100755 --- a/nebula/core/models/nebulamodel.py +++ b/nebula/core/models/nebulamodel.py @@ -196,6 +196,8 @@ def __init__( # Communication manager for sending messages from the model (e.g., prototypes, gradients) # Model parameters are sent by default using network.propagator self.communication_manager = None + + self.current_loss = -1 def set_communication_manager(self, communication_manager): self.communication_manager = communication_manager @@ -222,8 +224,12 @@ def step(self, batch, batch_idx, phase): loss = self.criterion(y_pred, y) self.process_metrics(phase, y_pred, y, loss) + self.current_loss=loss return loss + def get_loss(self): + return self.current_loss + def training_step(self, batch, batch_idx): """ Training step for the model. 
diff --git a/nebula/core/neighbormanagement/candidateselection/candidateselector.py b/nebula/core/neighbormanagement/candidateselection/candidateselector.py index 8beb90099..0a3372dc0 100644 --- a/nebula/core/neighbormanagement/candidateselection/candidateselector.py +++ b/nebula/core/neighbormanagement/candidateselection/candidateselector.py @@ -34,5 +34,5 @@ def factory_CandidateSelector(topology) -> CandidateSelector: "random": HETCandidateSelector } - cs = options.get(topology) + cs = options.get(topology, FCCandidateSelector) return cs() \ No newline at end of file diff --git a/nebula/core/neighbormanagement/modelhandlers/modelhandler.py b/nebula/core/neighbormanagement/modelhandlers/modelhandler.py index 9a06de8fb..0431a0568 100644 --- a/nebula/core/neighbormanagement/modelhandlers/modelhandler.py +++ b/nebula/core/neighbormanagement/modelhandlers/modelhandler.py @@ -28,5 +28,5 @@ def factory_ModelHandler(model_handler) -> ModelHandler: "aggregator": AGGModelHandler } - cs = options.get(model_handler) + cs = options.get(model_handler, STDModelHandler) return cs() \ No newline at end of file diff --git a/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py b/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py index 028900176..85aff15cb 100644 --- a/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py +++ b/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py @@ -32,6 +32,7 @@ def accept_model(self, model): if not self.model_lock.locked(): self.model_lock.acquire() self.model = model + return self.model_lock.locked() def get_model(self, model): """ diff --git a/nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py b/nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py index 7fb4e9071..143736ccc 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py +++ b/nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py @@ -42,11 +42,11 @@ def 
factory_NeighborPolicy(topology) -> NeighborPolicy: from nebula.core.neighbormanagement.neighborpolicies.starneighborpolicy import STARNeighborPolicy options = { - 'random': IDLENeighborPolicy, # default value - 'fully': FCNeighborPolicy, - 'ring': RINGNeighborPolicy, - 'star': IDLENeighborPolicy + "random": IDLENeighborPolicy, # default value + "fully": FCNeighborPolicy, + "ring": RINGNeighborPolicy, + "star": IDLENeighborPolicy } - cs = options.get(topology) + cs = options.get(topology, IDLENeighborPolicy) return cs() \ No newline at end of file diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index df3d892d6..d4df9d0f1 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -11,6 +11,7 @@ from nebula.core.neighbormanagement.timergenerator import TimerGenerator from nebula.core.pb import nebula_pb2 from nebula.core.network.communications import CommunicationsManager +from nebula.addons.functions import print_msg_box from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -24,11 +25,15 @@ def __init__( model_handler, engine : "Engine" ): + print_msg_box(msg=f"Starting NodeManager module...\nTopology: {topology}", indent=2, title="NodeManager module") logging.info("🌐 Initializing Node Manager") self._engine = engine self.config = engine.get_config() - self._neighbor_policy = factory_NeighborPolicy(topology) - self._candidate_selector = factory_CandidateSelector(topology) + logging.info("Initializing Neighbor policy") + self._neighbor_policy = factory_NeighborPolicy(topology) + logging.info("Initializing Candidate Selector") + self._candidate_selector = factory_CandidateSelector(topology) + logging.info("Initializing Model Handler") self._model_handler = factory_ModelHandler(model_handler) self.late_connection_process_lock = Locker(name="late_connection_process_lock") self.weight_modifier = {} @@ -39,7 +44,8 @@ def __init__( 
self._restructure_process_lock = Locker(name="restructure_process_lock") self.restructure = False self.max_time_to_wait = 6 - self._timer_generator = TimerGenerator(self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False), self.max_time_to_wait, 80) + logging.info("Initializing Timer generator") + self._timer_generator = None #TimerGenerator(self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False), self.max_time_to_wait, 80) self.set_confings() @@ -83,9 +89,12 @@ def set_confings(self): - self weight distance - self weight hetereogeneity """ + logging.info("Building neighbor policy configuration..") self.neighbor_policy.set_config([self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False), self.engine.cm.get_addrs_current_connections(only_direct=False, myself=False), self.engine.addr]) #self.model_handler.set_config([self.engine.get_round(), self.engine.config.participant["training_args"]["epochs"]]) - self.candidate_selector.set_config([self.engine.get_loss(), self.config.participant["molibity_args"]["weight_distance"], self.config.participant["molibity_args"]["weight_het"]]) + logging.info("Building candidate selector configuration..") + self.candidate_selector.set_config([0, 0.2 , 0.8]) + #self.engine.trainer.get_loss(), self.config.participant["molibity_args"]["weight_distance"], self.config.participant["molibity_args"]["weight_het"] def get_timer(self): return self.timer_generator.get_timer(self.engine.get_round()) @@ -100,39 +109,39 @@ async def receive_update_from_node(self, node_id, node_response_time): await self.timer_generator.receive_update(node_id, node_response_time) def add_weight_modifier(self, addr): - self.weight_modifier_lock().acquire() + self.weight_modifier_lock.acquire() if not addr in self.weight_modifier: self.weight_modifier[addr] = self.new_node_weight_value - self.weight_modifier_lock().release() + self.weight_modifier_lock.release() def remove_weight_modifier(self, addr): - 
self.weight_modifier_lock().acquire() + self.weight_modifier_lock.acquire() if addr in self.weight_modifier: del self.weight_modifier[addr] - self.weight_modifier_lock().release() + self.weight_modifier_lock.release() def _update_weight_modifier(self, addr): - self.weight_modifier_lock().acquire() + self.weight_modifier_lock.acquire() if addr in self.weight_modifier: new_weight = self.weight_modifier[addr] - 1/self.engine.get_round()**2 if new_weight > 1: self.weight_modifier[addr] = new_weight else: self.remove_weight_modifier(addr) - self.weight_modifier_lock().release() + self.weight_modifier_lock.release() def get_weight_modifier(self, addr): - self.weight_modifier_lock().acquire() + self.weight_modifier_lock.acquire() if addr in self.weight_modifier: wm = self.weight_modifier[addr] self._update_weight_modifier(addr, self.engine.get_round()) else: wm = 1 - self.weight_modifier_lock().release() + self.weight_modifier_lock.release() return wm - def accept_connection(self,source): - if self.accept_candidates_lock().locked(): + def accept_connection(self, source, joining=False): + if self.accept_candidates_lock.locked(): return False return self.neighbor_policy.accept_connection(source) @@ -144,7 +153,7 @@ def get_actions(self): def update_neighbors(self, node, remove=False): self.neighbor_policy.update_neighbors(node, remove) - self.timer_generator.update_node(node, remove) + #self.timer_generator.update_node(node, remove) if remove: self.remove_weight_modifier(node) if not remove: @@ -160,13 +169,13 @@ def get_nodes_known(self, neighbors_too=False): return self.neighbor_policy.get_nodes_known(neighbors_too) def accept_model(self, source, decoded_model, rounds, round, epochs, n_neighbors, loss): - if not self.accept_candidates_lock().locked(): + if not self.accept_candidates_lock.locked(): self.model_handler.accept_model(decoded_model) self.model_handler.set_config(config=(rounds, round, epochs)) self.candidate_selector.add_candidate((source, n_neighbors, loss)) 
def add_candidate(self,source, n_neighbors, loss): - if not self.accept_candidates_lock().locked(): + if not self.accept_candidates_lock.locked(): self.candidate_selector.add_candidate((source, n_neighbors, loss)) async def start_late_connection_process(self): @@ -181,43 +190,44 @@ async def start_late_connection_process(self): Returns: data neccesary to create trainer """ - logging.info("🌐 Initializing start late connection process from Node Manager") + logging.info("🌐 Initializing late connection process..") self.late_connection_process_lock.acquire() best_candidates = [] self.candidate_selector.remove_candidates() - # send discover - msg = self.engine.cm.mm.generate_discover_message(nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN) - await self.engine.cm.establish_connection_with_federation(msg) + # find federation and send discover + await self.engine.cm.establish_connection_with_federation() # wait offer + logging.info(f"Waiting: {self.recieve_offer_timer}s to receive offers from federation") await asyncio.sleep(self.recieve_offer_timer) # acquire lock to not accept late candidates self.accept_candidates_lock.acquire() if self.candidate_selector.any_candidate(): - + logging.info("Candidates found to connect") # create message to send to new neightbors msg = self.engine.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.LATE_CONNECT) best_candidates = self.candidate_selector.select_candidates() - for addr, _, _ in best_candidates: - await self.engine.cm.connect(addr, direct=True) - await self.engine.cm.send_message(addr, msg) + #for addr, _, _ in best_candidates: + #await self.engine.cm.connect(addr, direct=True) + #await self.engine.cm.send_message(addr, msg) model, rounds, round, epochs = self.model_handler.get_model() - self.accept_candidates_lock().release() + self.accept_candidates_lock.release() self.late_connection_process_lock.release() return (model, rounds, round, epochs) # if no candidates, repeat process else: + logging.info("No 
Candidates found | repeating process") self.accept_candidates_lock.release() self.late_connection_process_lock.release() - return self.start_late_connection_process() + return await self.start_late_connection_process() diff --git a/nebula/core/neighbormanagement/timergenerator.py b/nebula/core/neighbormanagement/timergenerator.py index 60a5e1b87..10d3bae12 100644 --- a/nebula/core/neighbormanagement/timergenerator.py +++ b/nebula/core/neighbormanagement/timergenerator.py @@ -3,6 +3,7 @@ from collections import deque import time from nebula.core.utils.locker import Locker +import logging class TimerGenerator(): def __init__( @@ -15,6 +16,7 @@ def __init__( max_historic_size = 10, adaptative=False ): + logging.info("🌐 Initializing Timer Generator..") self.waiting_time = initial_timer_value if initial_timer_value != None else max_timer_value self.max_timer_value = max_timer_value self.acceptable_percent = acceptable_percent @@ -31,7 +33,7 @@ def __init__( #self.all_updates_received = asyncio.Condition() def get_stop_condition(self): - return self.all_updates_received + return True #self.all_updates_received def get_timer(self, round): self.round = round diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index b207d6a6e..09f11c1d4 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -18,6 +18,8 @@ from nebula.core.network.messages import MessagesManager from nebula.core.network.propagator import Propagator from nebula.core.pb import nebula_pb2 +from nebula.core.network.nebulamulticasting import NebulaConnectionService + from nebula.core.utils.helper import ( cosine_metric, euclidean_metric, @@ -77,6 +79,21 @@ def __init__(self, engine: "Engine"): self.loop = asyncio.get_event_loop() max_concurrent_tasks = 5 self.semaphore_send_model = asyncio.Semaphore(max_concurrent_tasks) + + # Connection service to communicate with external devices + self._external_connection_service = None + + # 
The line below is neccesary when mobility would be set up + mob = self.config.participant["mobility_args"]["mobility"] + aditional_node = self.config.participant["mobility_args"]["additional_node"]["status"] + if mob == True and not aditional_node: + self._external_connection_service = NebulaConnectionService(self.addr) + logging.info("Deploying External Connection Service") + self.ecs.start() + else: + logging.info("Deploying External Connection Service | No running") + self._external_connection_service = NebulaConnectionService(self.addr) + @property def engine(self): @@ -109,6 +126,10 @@ def propagator(self): @property def mobility(self): return self._mobility + + @property + def ecs(self): + return self._external_connection_service async def check_federation_ready(self): # Check if all my connections are in ready_connections @@ -154,6 +175,15 @@ async def handle_incoming_message(self, data, addr_from): await self.handle_model_message(source, message_wrapper.model_message) elif message_wrapper.HasField("connection_message"): await self.handle_connection_message(source, message_wrapper.connection_message) + elif message_wrapper.HasField("discover_message"): + if self.include_received_message_hash(hashlib.md5(data).hexdigest()): + await self.handle_discover_message(source, message_wrapper.discover_message) + elif message_wrapper.HasField("offer_message"): + if self.include_received_message_hash(hashlib.md5(data).hexdigest()): + await self.handle_offer_message(source, message_wrapper.offer_message) + elif message_wrapper.HasField("link_message"): + if self.include_received_message_hash(hashlib.md5(data).hexdigest()): + await self.handle_offer_message(source, message_wrapper.link_message) else: logging.info(f"Unknown handler for message: {message_wrapper}") except Exception as e: @@ -328,6 +358,60 @@ async def handle_connection_message(self, source, message): await self.engine.event_manager.trigger_event(source, message) except Exception as e: 
logging.exception(f"πŸ”— handle_connection_message | Error while processing: {message.action} | {e}") + + async def handle_discover_message(self, source, message): + logging.info(f"πŸ” handle_discover_message | Received [Action {message.action}] from {source}") + try: + await self.engine.event_manager.trigger_event(source, message) + except Exception as e: + logging.error(f"πŸ” handle_discover_message | Error while processing: {e}") + + async def handle_offer_message(self, source, message): + logging.info(f"πŸ” handle_offer_message | Received [Action {message.action}] from {source}") + try: + await self.engine.event_manager.trigger_event(source, message) + except Exception as e: + logging.error(f"πŸ” handle_offer_message | Error while processing: {message.action} {message.arguments} | {e}") + + async def handle_link_message(self, source, message): + logging.info(f"πŸ” handle_link_message | Received [Action {message.action}] from {source}") + try: + await self.engine.event_manager.trigger_event(source, message) + except Exception as e: + logging.error(f"πŸ” handle_link_message | Error while processing: {message.action} {message.arguments} | {e}") + + def start_external_connection_service(self): + self.ecs = NebulaConnectionService(self.addr) + self.ecs.start() + + def stop_external_connection_service(self): + self.ecs.stop() + + def init_external_connection_service(self): + self.ecs = NebulaConnectionService(self.addr) + self.start_external_connection_service() + + async def establish_connection_with_federation(self): + """ + Using ExternalConnectionService to get addrs on local network, after that + stablishment of TCP connection and send the message broadcasted + """ + logging.info("Searching federation process beginning..") + addrs = self.ecs.find_federation() + logging.info(f"Found federation devices | addrs {addrs}") + msg = self.mm.generate_discover_message(nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN) + logging.info("Starting communications with 
devices found") + for addr in addrs: + await self.connect(addr, direct=False) + await asyncio.sleep(1) + while not self.verify_connections(addrs): + await asyncio.sleep(1) + current_connections = await self.get_addrs_current_connections() + logging.info(f"Connections verified after searching: {current_connections}") + for addr in addrs: + logging.info(f"Sending discover join to --> {addr}") + asyncio.create_task(self.send_message(addr, msg)) + await asyncio.sleep(1) def get_connections_lock(self): return self.connections_lock @@ -666,6 +750,22 @@ async def send_model(self, dest_addr, round, serialized_model, weight=1): except Exception as e: logging.exception(f"❗️ Cannot send model to {dest_addr}: {e!s}") await self.disconnect(dest_addr, mutual_disconnection=False) + + async def send_offer_model(self, dest_addr, offer_message): + async with self.semaphore_send_model: + try: + conn = self.connections.get(dest_addr) + if conn is None: + logging.info(f"❗️ Connection with {dest_addr} not found") + return + logging.info( + f"Sending model to {dest_addr}" + ) + await conn.send(data=offer_message, is_compressed=True) + logging.info(f"Offer_Model sent to {dest_addr}") + except Exception as e: + logging.exception(f"❗️ Cannot send model to {dest_addr}: {e!s}") + await self.disconnect(dest_addr, mutual_disconnection=False) async def establish_connection(self, addr, direct=True, reconnect=False): logging.info(f"πŸ”— [outgoing] Establishing connection with {addr} (direct: {direct})") @@ -864,6 +964,14 @@ async def disconnect(self, dest_addr, mutual_disconnection=True): current_connections = set(current_connections) logging.info(f"Current connections: {current_connections}") self.config.update_neighbors_from_config(current_connections, dest_addr) + + async def remove_temporary_connection(self, temp_addr): + logging.info(f"Removing temporary conneciton:{temp_addr}..") + try: + await self.get_connections_lock().acquire_async() + self.connections.pop(temp_addr, None) + finally: + 
await self.get_connections_lock().release_async() async def get_all_addrs_current_connections(self, only_direct=False, only_undirected=False): try: diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index 57fddfa01..10573db0e 100755 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -88,7 +88,7 @@ def generate_discover_message(self, action): ) message_wrapper = nebula_pb2.Wrapper() message_wrapper.source = self.addr - message_wrapper.discovery_message.CopyFrom(message) + message_wrapper.discover_message.CopyFrom(message) data = message_wrapper.SerializeToString() return data @@ -105,7 +105,7 @@ def generate_offer_message(self, action, n_neighbors, loss, serialized_model=Non ) message_wrapper = nebula_pb2.Wrapper() message_wrapper.source = self.addr - message_wrapper.discovery_message.CopyFrom(message) + message_wrapper.offer_message.CopyFrom(message) data = message_wrapper.SerializeToString() return data diff --git a/nebula/core/network/nebulamulticasting.py b/nebula/core/network/nebulamulticasting.py index 0612a0789..0d4759e32 100644 --- a/nebula/core/network/nebulamulticasting.py +++ b/nebula/core/network/nebulamulticasting.py @@ -54,12 +54,13 @@ def listen(self): sock.close() return else: - if self._is_nebula_message(data): + if self._is_nebula_message(data): + logging.info("Nebula request recieved | response on the way..") self.respond(addr) #time.sleep(1) #self.stop() except Exception as e: - logging.info('Error in Nebula npnp server listening: %s', e) + logging.error('Error in Nebula npnp server listening: %s', e) def _is_nebula_message(self, msg): msg_str = msg.decode('utf-8') @@ -80,7 +81,7 @@ def respond(self, addr): outSock.sendto(UPNP_RESPOND.encode('ASCII'), addr) outSock.close() except Exception as e: - logging.info('Error in Nebula upnp response message to client %s', e) + logging.error('Error in Nebula upnp response message to client %s', e) class NebulaClient(threading.Thread): # 30 seconds for 
search_interval @@ -89,6 +90,7 @@ class NebulaClient(threading.Thread): BCAST_PORT = 1900 def __init__(self, nebula_service: "NebulaConnectionService"): + logging.info("Initializating Nebula Multicasting Client") threading.Thread.__init__(self) self.interrupted = False self.ns = nebula_service @@ -104,6 +106,7 @@ def keep_search(self): """ run search function every SEARCH_INTERVAL """ + logging.info("Federation searching loop start") try: while True: self.search() @@ -112,13 +115,14 @@ def keep_search(self): if self.interrupted: return except Exception as e: - logging.info('Error in Nebula upnp client keep search %s', e) + logging.error('Error in Nebula upnp client keep search %s', e) def search(self): """ broadcast SSDP DISCOVER message to LAN network filter our protocal and add to network """ + logging.info("Client thread searching for nodes..") try: SSDP_DISCOVER = ('M-SEARCH * HTTP/1.1\r\n' + 'HOST: 239.255.255.250:1900\r\n' + @@ -133,6 +137,7 @@ def search(self): while True: data, addr = sock.recvfrom(1024) if self._is_nebula_message(data): + logging.info("Recieved response from server") self.ns.response_recieved(data, addr) except: sock.close() @@ -145,6 +150,7 @@ class NebulaConnectionService(ExternalConnectionService): def __init__(self, addr): self.addrs_found_lock = Locker(name="addrs_found_lock") + self.get_nodes_lock= Locker(name="get_nodes_lock") self.nodes_found = [] self.repeatsearch_interval = 3 self.addr = addr @@ -167,28 +173,32 @@ def find_federation(self): self.client = NebulaClient(self) self.client.start() time.sleep(self.repeatsearch_interval) - while not len(self.get_nodes()): + while len(self.get_nodes()) == 0: + logging.info("Waiting for server response..") time.sleep(self.repeatsearch_interval) self.client.stop() + return self.get_nodes() def response_recieved(self, data, addr): + logging.info("Parsing response..") msg_str = data.decode('utf-8') self._add_addr(msg_str) def _add_addr(self, msg_str): - self.mutex.acquire() + 
self.addrs_found_lock.acquire() lineas = msg_str.splitlines() # Buscar la lΓ­nea que contiene "LOCATION: " for linea in lineas: if linea.strip().startswith("LOCATION:"): addr = linea.split(": ")[1].strip() break + logging.info(f"Device addr received: {addr}") self.nodes_found.append(addr) - self.mutex.release() + self.addrs_found_lock.release() def get_nodes(self): - self.mutex.acquire() + self.get_nodes_lock.acquire() cp = self.nodes_found.copy() - self.mutex.release() + self.get_nodes_lock.release() return cp \ No newline at end of file diff --git a/nebula/core/network/propagator.py b/nebula/core/network/propagator.py index 43d4ba9fe..f1499a530 100755 --- a/nebula/core/network/propagator.py +++ b/nebula/core/network/propagator.py @@ -168,3 +168,26 @@ async def propagate(self, strategy_id: str): await asyncio.sleep(self.interval) return True + + async def get_model_information(self, dest_addr, strategy_id: str): + if strategy_id not in self.strategies: + logging.info(f"Strategy {strategy_id} not found.") + return None + if self.get_round() is None: + logging.info("Propagation halted: round is not set.") + return None + + strategy = self.strategies[strategy_id] + logging.info(f"Preparing model information with strategy to make an offer: {strategy_id}") + + model_params, weight = strategy.prepare_model_payload(None) + rounds = self.engine.total_rounds + + if model_params: + serialized_model = ( + model_params if isinstance(model_params, bytes) else self.trainer.serialize_model(model_params) + ) + return (serialized_model, rounds, self.get_round()) + + return None + diff --git a/nebula/core/pb/nebula_pb2.py b/nebula/core/pb/nebula_pb2.py index bf6290e6e..d5c476808 100755 --- a/nebula/core/pb/nebula_pb2.py +++ b/nebula/core/pb/nebula_pb2.py @@ -1,47 +1,58 @@ +# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: nebula.proto # Protobuf Python Version: 4.25.3 """Generated protocol buffer code.""" - from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder - # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0cnebula.proto\x12\x06nebula"\xe4\x02\n\x07Wrapper\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x35\n\x11\x64iscovery_message\x18\x02 \x01(\x0b\x32\x18.nebula.DiscoveryMessageH\x00\x12\x31\n\x0f\x63ontrol_message\x18\x03 \x01(\x0b\x32\x16.nebula.ControlMessageH\x00\x12\x37\n\x12\x66\x65\x64\x65ration_message\x18\x04 \x01(\x0b\x32\x19.nebula.FederationMessageH\x00\x12-\n\rmodel_message\x18\x05 \x01(\x0b\x32\x14.nebula.ModelMessageH\x00\x12\x37\n\x12\x63onnection_message\x18\x06 \x01(\x0b\x32\x19.nebula.ConnectionMessageH\x00\x12\x33\n\x10response_message\x18\x07 \x01(\x0b\x32\x17.nebula.ResponseMessageH\x00\x42\t\n\x07message"\x9e\x01\n\x10\x44iscoveryMessage\x12/\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1f.nebula.DiscoveryMessage.Action\x12\x10\n\x08latitude\x18\x02 \x01(\x02\x12\x11\n\tlongitude\x18\x03 \x01(\x02"4\n\x06\x41\x63tion\x12\x0c\n\x08\x44ISCOVER\x10\x00\x12\x0c\n\x08REGISTER\x10\x01\x12\x0e\n\nDEREGISTER\x10\x02"\x9a\x01\n\x0e\x43ontrolMessage\x12-\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1d.nebula.ControlMessage.Action\x12\x0b\n\x03log\x18\x02 \x01(\t"L\n\x06\x41\x63tion\x12\t\n\x05\x41LIVE\x10\x00\x12\x0c\n\x08OVERHEAD\x10\x01\x12\x0c\n\x08MOBILITY\x10\x02\x12\x0c\n\x08RECOVERY\x10\x03\x12\r\n\tWEAK_LINK\x10\x04"\xcd\x01\n\x11\x46\x65\x64\x65rationMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.FederationMessage.Action\x12\x11\n\targuments\x18\x02 \x03(\t\x12\r\n\x05round\x18\x03 
\x01(\x05"d\n\x06\x41\x63tion\x12\x14\n\x10\x46\x45\x44\x45RATION_START\x10\x00\x12\x0e\n\nREPUTATION\x10\x01\x12\x1e\n\x1a\x46\x45\x44\x45RATION_MODELS_INCLUDED\x10\x02\x12\x14\n\x10\x46\x45\x44\x45RATION_READY\x10\x03"A\n\x0cModelMessage\x12\x12\n\nparameters\x18\x01 \x01(\x0c\x12\x0e\n\x06weight\x18\x02 \x01(\x03\x12\r\n\x05round\x18\x03 \x01(\x05"l\n\x11\x43onnectionMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.ConnectionMessage.Action"%\n\x06\x41\x63tion\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\x0e\n\nDISCONNECT\x10\x01"#\n\x0fResponseMessage\x12\x10\n\x08response\x18\x01 \x01(\tb\x06proto3' -) + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cnebula.proto\x12\x06nebula\"\xf5\x03\n\x07Wrapper\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x35\n\x11\x64iscovery_message\x18\x02 \x01(\x0b\x32\x18.nebula.DiscoveryMessageH\x00\x12\x31\n\x0f\x63ontrol_message\x18\x03 \x01(\x0b\x32\x16.nebula.ControlMessageH\x00\x12\x37\n\x12\x66\x65\x64\x65ration_message\x18\x04 \x01(\x0b\x32\x19.nebula.FederationMessageH\x00\x12-\n\rmodel_message\x18\x05 \x01(\x0b\x32\x14.nebula.ModelMessageH\x00\x12\x37\n\x12\x63onnection_message\x18\x06 \x01(\x0b\x32\x19.nebula.ConnectionMessageH\x00\x12\x33\n\x10response_message\x18\x07 \x01(\x0b\x32\x17.nebula.ResponseMessageH\x00\x12\x33\n\x10\x64iscover_message\x18\x08 \x01(\x0b\x32\x17.nebula.DiscoverMessageH\x00\x12-\n\roffer_message\x18\t \x01(\x0b\x32\x14.nebula.OfferMessageH\x00\x12+\n\x0clink_message\x18\n \x01(\x0b\x32\x13.nebula.LinkMessageH\x00\x42\t\n\x07message\"\x9e\x01\n\x10\x44iscoveryMessage\x12/\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1f.nebula.DiscoveryMessage.Action\x12\x10\n\x08latitude\x18\x02 \x01(\x02\x12\x11\n\tlongitude\x18\x03 \x01(\x02\"4\n\x06\x41\x63tion\x12\x0c\n\x08\x44ISCOVER\x10\x00\x12\x0c\n\x08REGISTER\x10\x01\x12\x0e\n\nDEREGISTER\x10\x02\"\x9a\x01\n\x0e\x43ontrolMessage\x12-\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1d.nebula.ControlMessage.Action\x12\x0b\n\x03log\x18\x02 
\x01(\t\"L\n\x06\x41\x63tion\x12\t\n\x05\x41LIVE\x10\x00\x12\x0c\n\x08OVERHEAD\x10\x01\x12\x0c\n\x08MOBILITY\x10\x02\x12\x0c\n\x08RECOVERY\x10\x03\x12\r\n\tWEAK_LINK\x10\x04\"\xcd\x01\n\x11\x46\x65\x64\x65rationMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.FederationMessage.Action\x12\x11\n\targuments\x18\x02 \x03(\t\x12\r\n\x05round\x18\x03 \x01(\x05\"d\n\x06\x41\x63tion\x12\x14\n\x10\x46\x45\x44\x45RATION_START\x10\x00\x12\x0e\n\nREPUTATION\x10\x01\x12\x1e\n\x1a\x46\x45\x44\x45RATION_MODELS_INCLUDED\x10\x02\x12\x14\n\x10\x46\x45\x44\x45RATION_READY\x10\x03\"A\n\x0cModelMessage\x12\x12\n\nparameters\x18\x01 \x01(\x0c\x12\x0e\n\x06weight\x18\x02 \x01(\x03\x12\r\n\x05round\x18\x03 \x01(\x05\"\x8f\x01\n\x11\x43onnectionMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.ConnectionMessage.Action\"H\n\x06\x41\x63tion\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\x0e\n\nDISCONNECT\x10\x01\x12\x10\n\x0cLATE_CONNECT\x10\x02\x12\x0f\n\x0bRESTRUCTURE\x10\x03\"r\n\x0f\x44iscoverMessage\x12.\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1e.nebula.DiscoverMessage.Action\"/\n\x06\x41\x63tion\x12\x11\n\rDISCOVER_JOIN\x10\x00\x12\x12\n\x0e\x44ISCOVER_NODES\x10\x01\"\xce\x01\n\x0cOfferMessage\x12+\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1b.nebula.OfferMessage.Action\x12\x13\n\x0bn_neighbors\x18\x02 \x01(\x02\x12\x0c\n\x04loss\x18\x03 \x01(\x02\x12\x12\n\nparameters\x18\x04 \x01(\x0c\x12\x0e\n\x06rounds\x18\x05 \x01(\x05\x12\r\n\x05round\x18\x06 \x01(\x05\x12\x0e\n\x06\x65pochs\x18\x07 \x01(\x05\"+\n\x06\x41\x63tion\x12\x0f\n\x0bOFFER_MODEL\x10\x00\x12\x10\n\x0cOFFER_METRIC\x10\x01\"w\n\x0bLinkMessage\x12*\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1a.nebula.LinkMessage.Action\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x01(\t\"-\n\x06\x41\x63tion\x12\x0e\n\nCONNECT_TO\x10\x00\x12\x13\n\x0f\x44ISCONNECT_FROM\x10\x01\"#\n\x0fResponseMessage\x12\x10\n\x08response\x18\x01 \x01(\tb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "nebula_pb2", _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'nebula_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _globals["_WRAPPER"]._serialized_start = 25 - _globals["_WRAPPER"]._serialized_end = 381 - _globals["_DISCOVERYMESSAGE"]._serialized_start = 384 - _globals["_DISCOVERYMESSAGE"]._serialized_end = 542 - _globals["_DISCOVERYMESSAGE_ACTION"]._serialized_start = 490 - _globals["_DISCOVERYMESSAGE_ACTION"]._serialized_end = 542 - _globals["_CONTROLMESSAGE"]._serialized_start = 545 - _globals["_CONTROLMESSAGE"]._serialized_end = 699 - _globals["_CONTROLMESSAGE_ACTION"]._serialized_start = 623 - _globals["_CONTROLMESSAGE_ACTION"]._serialized_end = 699 - _globals["_FEDERATIONMESSAGE"]._serialized_start = 702 - _globals["_FEDERATIONMESSAGE"]._serialized_end = 907 - _globals["_FEDERATIONMESSAGE_ACTION"]._serialized_start = 807 - _globals["_FEDERATIONMESSAGE_ACTION"]._serialized_end = 907 - _globals["_MODELMESSAGE"]._serialized_start = 909 - _globals["_MODELMESSAGE"]._serialized_end = 974 - _globals["_CONNECTIONMESSAGE"]._serialized_start = 976 - _globals["_CONNECTIONMESSAGE"]._serialized_end = 1084 - _globals["_CONNECTIONMESSAGE_ACTION"]._serialized_start = 1047 - _globals["_CONNECTIONMESSAGE_ACTION"]._serialized_end = 1084 - _globals["_RESPONSEMESSAGE"]._serialized_start = 1086 - _globals["_RESPONSEMESSAGE"]._serialized_end = 1121 + DESCRIPTOR._options = None + _globals['_WRAPPER']._serialized_start=25 + _globals['_WRAPPER']._serialized_end=526 + _globals['_DISCOVERYMESSAGE']._serialized_start=529 + _globals['_DISCOVERYMESSAGE']._serialized_end=687 + _globals['_DISCOVERYMESSAGE_ACTION']._serialized_start=635 + _globals['_DISCOVERYMESSAGE_ACTION']._serialized_end=687 + _globals['_CONTROLMESSAGE']._serialized_start=690 + _globals['_CONTROLMESSAGE']._serialized_end=844 + _globals['_CONTROLMESSAGE_ACTION']._serialized_start=768 + 
_globals['_CONTROLMESSAGE_ACTION']._serialized_end=844 + _globals['_FEDERATIONMESSAGE']._serialized_start=847 + _globals['_FEDERATIONMESSAGE']._serialized_end=1052 + _globals['_FEDERATIONMESSAGE_ACTION']._serialized_start=952 + _globals['_FEDERATIONMESSAGE_ACTION']._serialized_end=1052 + _globals['_MODELMESSAGE']._serialized_start=1054 + _globals['_MODELMESSAGE']._serialized_end=1119 + _globals['_CONNECTIONMESSAGE']._serialized_start=1122 + _globals['_CONNECTIONMESSAGE']._serialized_end=1265 + _globals['_CONNECTIONMESSAGE_ACTION']._serialized_start=1193 + _globals['_CONNECTIONMESSAGE_ACTION']._serialized_end=1265 + _globals['_DISCOVERMESSAGE']._serialized_start=1267 + _globals['_DISCOVERMESSAGE']._serialized_end=1381 + _globals['_DISCOVERMESSAGE_ACTION']._serialized_start=1334 + _globals['_DISCOVERMESSAGE_ACTION']._serialized_end=1381 + _globals['_OFFERMESSAGE']._serialized_start=1384 + _globals['_OFFERMESSAGE']._serialized_end=1590 + _globals['_OFFERMESSAGE_ACTION']._serialized_start=1547 + _globals['_OFFERMESSAGE_ACTION']._serialized_end=1590 + _globals['_LINKMESSAGE']._serialized_start=1592 + _globals['_LINKMESSAGE']._serialized_end=1711 + _globals['_LINKMESSAGE_ACTION']._serialized_start=1666 + _globals['_LINKMESSAGE_ACTION']._serialized_end=1711 + _globals['_RESPONSEMESSAGE']._serialized_start=1713 + _globals['_RESPONSEMESSAGE']._serialized_end=1748 # @@protoc_insertion_point(module_scope) diff --git a/nebula/node.py b/nebula/node.py index fbf2ec44a..7b81a379d 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -347,15 +347,19 @@ def randomize_value(value, variability): # In order to do that, it should request the current round to the controller if additional_node_status: logging.info(f"Waiting for round {additional_node_round} to start") - time.sleep(6000) # DEBUG purposes - import requests - - url = f"http://{node.config.participant['scenario_args']['controller']}/platform/{node.config.participant['scenario_args']['name']}/round" - current_round = 
int(requests.get(url).json()["round"]) - while current_round < additional_node_round: - logging.info(f"Waiting for round {additional_node_round} to start") - time.sleep(10) - logging.info(f"Round {additional_node_round} started, connecting to the network") + logging.info("Waiting 60s to start finding federation") + time.sleep(60) + #time.sleep(6000) # DEBUG purposes + #import requests + + #url = f"http://{node.config.participant['scenario_args']['controller']}/platform/{node.config.participant['scenario_args']['name']}/round" + #current_round = int(requests.get(url).json()["round"]) + #while current_round < additional_node_round: + # logging.info(f"Waiting for round {additional_node_round} to start") + # time.sleep(10) + #logging.info(f"Round {additional_node_round} started, connecting to the network") + + await node._aditional_node_start() if node.cm is not None: await node.cm.network_wait() From 2a081e1c761d3218a270d1e22128886865e0b7f8 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 25 Nov 2024 19:12:57 +0100 Subject: [PATCH 014/233] feat_late_starting_trainning -late nodes start trainning -topology structures set apropiately --- nebula/core/engine.py | 133 +++++++++++++----- nebula/core/neighbormanagement/README.txt | 2 +- .../candidateselection/candidateselector.py | 4 +- .../modelhandlers/modelhandler.py | 4 +- .../modelhandlers/stdmodelhandler.py | 4 +- .../neighborpolicies/fcneighborpolicy.py | 7 +- .../neighborpolicies/neighborpolicy.py | 2 +- .../neighborpolicies/ringneighborpolicy.py | 9 +- .../neighborpolicies/starneighborpolicy.py | 9 +- nebula/core/neighbormanagement/nodemanager.py | 118 ++++++++++++---- nebula/core/network/communications.py | 6 +- nebula/core/network/connection.py | 2 +- .../frontend/config/participant.json.example | 1 + nebula/scenarios.py | 1 + 14 files changed, 209 insertions(+), 93 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index c0ac88aaf..2a5e087a1 100644 --- a/nebula/core/engine.py +++ 
b/nebula/core/engine.py @@ -149,9 +149,10 @@ def __init__( # Mobility setup self._node_manager = None - mob = self.config.participant["mobility_args"]["mobility"] - if mob == True: - topology = self.config.participant["mobility_args"]["mobility_type"] + self.mobility = self.config.participant["mobility_args"]["mobility"] + if self.mobility == True: + topology = self.config.participant["mobility_args"]["topology_type"] + topology = topology.lower() model_handler = "std" #self.config.participant["mobility_args"]["model_handler"] self._node_manager = NodeManager(topology, model_handler, engine=self) @@ -214,13 +215,6 @@ def trainer(self): def nm(self): return self._node_manager - async def _aditional_node_start(self): - logging.info(f"{self.addr} is an aditional node going to stablish connection with federation") - await self.nm.start_late_connection_process() - # continue .. - logging.info("Creating trainer service to start the federation process..") - #await self.cm.establish_connection_with_federation() - def get_addr(self): return self.addr @@ -360,21 +354,25 @@ async def _federation_models_included_callback(self, source, message): async def _connection_late_connect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received late_connect message from {source}") if self.nm.accept_connection(source, joining=True): - logging.info(f"πŸ”— Late connection acepted | source:{source}") + logging.info(f"πŸ”— Late connection accepted | source: {source}") self.nm.add_weight_modifier(source) ct_actions , df_actions = self.nm.get_actions() - # connect to - for addr in ct_actions.split(): - cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECTO_TO, addr) - await self.cm.send_message(source, cnt_msg) - # disconnect from - for addr in df_actions.split(): - df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, addr) - await self.cm.send_message(source, df_msg) + if len(ct_actions): 
+ for addr in ct_actions.split(): + cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, addr) + #await self.cm.send_message(source, cnt_msg) + + if len(df_actions): + for addr in df_actions.split(): + df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, addr) + #await self.cm.send_message(source, df_msg) await self.cm.connect(source, direct=True) + self.nm.meet_node(source) self.nm.update_neighbors(source) + else: + logging.info(f"πŸ”— Late connection NOT accepted | source: {source}") @event_handler( nebula_pb2.ConnectionMessage, @@ -385,14 +383,17 @@ async def _connection_restructure_callback(self, source, message): if self.nm.accept_connection(source, joining=False): logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection accepted from {source}") ct_actions , df_actions = self.nm.get_actions() - - for addr in ct_actions.split(): - cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECTO_TO, addr) - await self.cm.send_message(source, cnt_msg) + + if len(ct_actions): + for addr in ct_actions.split(): + cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, addr) + pass + #await self.cm.send_message(source, cnt_msg) - for addr in df_actions.split(): - df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, addr) - await self.cm.send_message(source, df_msg) + if len(df_actions): + for addr in df_actions.split(): + df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, addr) + #await self.cm.send_message(source, df_msg) else: logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection denied from {source}") await self.cm.disconnect(source, mutual_disconnection=False) @@ -418,11 +419,9 @@ async def _discover_discover_join_callback(self, source, message): epochs ) await self.cm.send_offer_model(source, msg) - await asyncio.sleep(1) - 
await self.cm.remove_temporary_connection(source) else: - # for the starter federation node - logging.info() + logging.info("Discover join received before federation is running..") + # starter node is going to send info to the new node else: logging.info(f"πŸ”— Dissmissing discover join from {source} | no active connections at the moment") @@ -442,16 +441,18 @@ async def _discover_discover_nodes_callback(self, source, message): ) async def _offer_offer_model_callback(self, source, message): logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_model message from {source}") - if not self.nm.get_restructure_process_lock().locked(): + self.nm.meet_node(source) + if not self.nm.get_restructure_process_lock().locked() and not self.nm.still_waiting_for_candidates(): try: - decoded_model = self.trainer.deserialize_model(message.parameters) - self.nm.accept_model(source, decoded_model, message.rounds, message.round, message.epochs, message.n_neighbors, message.loss) - self.nm.add_candidate(source, message.n_neighbors, message.loss) - self.nm.meet_node(source) + model_compressed = message.parameters + if self.nm.accept_model_offer(source, model_compressed, message.rounds, message.round, message.epochs, message.n_neighbors, message.loss): + logging.info("Model accepted from offer") + else: + logging.info("Model offer discarded") + self.nm.add_to_discarded_offers(source) except RuntimeError: - pass - await self.cm.remove_temporary_connection(source) - + pass + @event_handler( nebula_pb2.OfferMessage, nebula_pb2.OfferMessage.Action.OFFER_METRIC, @@ -486,7 +487,58 @@ async def _link_disconnect_from_callback(self, source, message): await self.cm.disconnect(source, mutual_disconnection=False) self.nm.update_neighbors(addr, remove=True) - + async def _aditional_node_start(self): + logging.info(f"Aditional node | {self.addr} | going to stablish connection with federation") + await self.nm.start_late_connection_process() + # continue .. 
+ await self.nm.stop_not_selected_connections() + logging.info("Creating trainer service to start the federation process..") + asyncio.create_task(self._start_learning_late()) + #decoded_model = self.trainer.deserialize_model(message.parameters) + + + async def _start_learning_late(self): + await self.learning_cycle_lock.acquire_async() + try: + model_serialized, rounds, round, _epochs = self.nm.get_trainning_info() + self.total_rounds = rounds + epochs = _epochs + await self.get_round_lock().acquire_async() + self.round = round + await self.get_round_lock().release_async() + await self.learning_cycle_lock.release_async() + print_msg_box( + msg="Starting Federated Learning process...", + indent=2, + title="Start of the experiment late", + ) + logging.info(f"Trainning setup | total rounds: {rounds} | current round: {round} | epochs: {_epochs}") + direct_connections = await self.cm.get_addrs_current_connections(only_direct=True) + logging.info(f"Initial DIRECT connections: {direct_connections}") + self.trainer.set_epochs(epochs) + self.trainer.create_trainer() + try: + logging.info("πŸ€– Initializing model...") + model = self.trainer.deserialize_model(model_serialized) + self.trainer.set_model_parameters(model, initialize=True) + logging.info("Model Parameters Initialized") + self.set_initialization_status(True) + await ( + self.get_federation_ready_lock().release_async() + ) # Enable learning cycle once the initialization is done + try: + await ( + self.get_federation_ready_lock().release_async() + ) # Release the lock acquired at the beginning of the engine + except RuntimeError: + pass + except RuntimeError: + pass + await self._learning_cycle() + finally: + if await self.learning_cycle_lock.locked_async(): + await self.learning_cycle_lock.release_async() + async def create_trainer_module(self): asyncio.create_task(self._start_learning()) logging.info("Started trainer module...") @@ -507,6 +559,9 @@ async def start_communications(self): await asyncio.sleep(1) 
current_connections = await self.cm.get_addrs_current_connections() logging.info(f"Connections verified: {current_connections}") + if self.mobility: + logging.info("Building NodeManager configurations...") + await self.nm.set_confings() await self._reporter.start() await self.cm.deploy_additional_services() await asyncio.sleep(self.config.participant["misc_args"]["grace_time_connection"] // 2) diff --git a/nebula/core/neighbormanagement/README.txt b/nebula/core/neighbormanagement/README.txt index ee4b08556..3eb4afdbf 100644 --- a/nebula/core/neighbormanagement/README.txt +++ b/nebula/core/neighbormanagement/README.txt @@ -2,7 +2,7 @@ β–ˆβ–ˆβ•”β•β•β–ˆβ–ˆβ•—β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘β•šβ•β•β–ˆβ–ˆβ•”β•β•β•β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘β–ˆβ–ˆβ•”β•β•β•β–ˆβ–ˆβ•—β–ˆβ–ˆβ•”β•β•β–ˆβ–ˆβ•— β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•‘β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•‘β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•”β• β–ˆβ–ˆβ•”β•β•β–ˆβ–ˆβ•‘β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•”β•β•β–ˆβ–ˆβ•‘β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘β–ˆβ–ˆβ•”β•β–ˆβ–ˆβ•— -β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘β•šβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•”β• β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘β•šβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•”β•β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•— +β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘β•šβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•”β• β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•‘β•šβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ•”β•β–ˆβ–ˆβ•‘ β–ˆβ–ˆβ•— β•šβ•β• β•šβ•β• β•šβ•β•β•β•β•β• β•šβ•β• β•šβ•β• β•šβ•β• β•šβ•β•β•β•β•β• β•šβ•β• β•šβ•β• Alejandro AvilΓ©s Serrano. 
diff --git a/nebula/core/neighbormanagement/candidateselection/candidateselector.py b/nebula/core/neighbormanagement/candidateselection/candidateselector.py index 0a3372dc0..f2cb5d0e1 100644 --- a/nebula/core/neighbormanagement/candidateselection/candidateselector.py +++ b/nebula/core/neighbormanagement/candidateselection/candidateselector.py @@ -29,9 +29,9 @@ def factory_CandidateSelector(topology) -> CandidateSelector: from nebula.core.neighbormanagement.candidateselection.ringcandidateselector import RINGCandidateSelector options = { - 'ring': RINGCandidateSelector, + "ring": RINGCandidateSelector, "fully": FCCandidateSelector, - "random": HETCandidateSelector + "random": HETCandidateSelector, } cs = options.get(topology, FCCandidateSelector) diff --git a/nebula/core/neighbormanagement/modelhandlers/modelhandler.py b/nebula/core/neighbormanagement/modelhandlers/modelhandler.py index 0431a0568..0b418992a 100644 --- a/nebula/core/neighbormanagement/modelhandlers/modelhandler.py +++ b/nebula/core/neighbormanagement/modelhandlers/modelhandler.py @@ -24,8 +24,8 @@ def factory_ModelHandler(model_handler) -> ModelHandler: from nebula.core.neighbormanagement.modelhandlers.aggmodelhandler import AGGModelHandler options = { - 'std': STDModelHandler, - "aggregator": AGGModelHandler + "std": STDModelHandler, + "aggregator": AGGModelHandler, } cs = options.get(model_handler, STDModelHandler) diff --git a/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py b/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py index 85aff15cb..392a3c746 100644 --- a/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py +++ b/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py @@ -27,12 +27,12 @@ def set_config(self, config): def accept_model(self, model): """ - save only first model receive to set up own model later + save only first model received to set up own model later """ if not self.model_lock.locked(): self.model_lock.acquire() self.model = 
model - return self.model_lock.locked() + return True def get_model(self, model): """ diff --git a/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py b/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py index 47b679418..1940abf9a 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py +++ b/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py @@ -73,14 +73,11 @@ def get_actions(self): - First list represents addrs argument to LinkMessage to connect to - Second one represents the same but for disconnect from LinkMessage """ - actions = [] - actions.append(self._connect_to()) - actions.append(self._disconnect_from()) - return actions + return [self._connect_to(), self._disconnect_from()] def _disconnect_from(self): - return "" + return [] def _connect_to(self): ct = "" diff --git a/nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py b/nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py index 143736ccc..573ec7d01 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py +++ b/nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py @@ -45,7 +45,7 @@ def factory_NeighborPolicy(topology) -> NeighborPolicy: "random": IDLENeighborPolicy, # default value "fully": FCNeighborPolicy, "ring": RINGNeighborPolicy, - "star": IDLENeighborPolicy + "star": IDLENeighborPolicy, } cs = options.get(topology, IDLENeighborPolicy) diff --git a/nebula/core/neighbormanagement/neighborpolicies/ringneighborpolicy.py b/nebula/core/neighbormanagement/neighborpolicies/ringneighborpolicy.py index d3f0ed4c0..ad717524b 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/ringneighborpolicy.py +++ b/nebula/core/neighbormanagement/neighborpolicies/ringneighborpolicy.py @@ -75,15 +75,16 @@ def get_actions(self): - Second one represents the same but for disconnect from LinkMessage """ self.neighbors_lock.acquire() - actions = [] + ct_actions = [] + df_actions = [] 
if len(self.neighbors) < self.max_neighbors: list_neighbors = list(self.neighbors) index = random.randint(0, len(list_neighbors)-1) node = list_neighbors[index] - actions.append(node) # connect to - actions.append(self.addr) # disconnect from + ct_actions.append(node) # connect to + df_actions.append(self.addr) # disconnect from self.neighbors_lock.release() - return actions + return [ct_actions, df_actions] def update_neighbors(self, node, remove=False): self.neighbors_lock.acquire() diff --git a/nebula/core/neighbormanagement/neighborpolicies/starneighborpolicy.py b/nebula/core/neighbormanagement/neighborpolicies/starneighborpolicy.py index 75a21677a..40d2c567b 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/starneighborpolicy.py +++ b/nebula/core/neighbormanagement/neighborpolicies/starneighborpolicy.py @@ -68,12 +68,13 @@ def get_actions(self): - Second one represents the same but for disconnect from LinkMessage """ self.neighbors_lock.acquire() - actions = [] + ct_actions = [] + df_actions = [] if len(self.neighbors) < self.max_neighbors: - actions.append(self.neighbors[0]) # connect to star point - actions.append(self.addr) # disconnect from me + ct_actions.append(self.neighbors[0]) # connect to star point + df_actions.append(self.addr) # disconnect from me self.neighbors_lock.release() - return actions + return [ct_actions, df_actions] def update_neighbors(self, node, remove=False): pass \ No newline at end of file diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index d4df9d0f1..0f1295760 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -25,14 +25,15 @@ def __init__( model_handler, engine : "Engine" ): - print_msg_box(msg=f"Starting NodeManager module...\nTopology: {topology}", indent=2, title="NodeManager module") + self.topology = topology + print_msg_box(msg=f"Starting NodeManager module...\nTopology: {self.topology}", 
indent=2, title="NodeManager module") logging.info("🌐 Initializing Node Manager") self._engine = engine self.config = engine.get_config() logging.info("Initializing Neighbor policy") - self._neighbor_policy = factory_NeighborPolicy(topology) + self._neighbor_policy = factory_NeighborPolicy(self.topology) logging.info("Initializing Candidate Selector") - self._candidate_selector = factory_CandidateSelector(topology) + self._candidate_selector = factory_CandidateSelector(self.topology) logging.info("Initializing Model Handler") self._model_handler = factory_ModelHandler(model_handler) self.late_connection_process_lock = Locker(name="late_connection_process_lock") @@ -43,11 +44,14 @@ def __init__( self.recieve_offer_timer = 5 self._restructure_process_lock = Locker(name="restructure_process_lock") self.restructure = False + self.discarded_offers_addr_lock = Locker(name="discarded_offers_addr_lock") + self.discarded_offers_addr = [] + self.max_time_to_wait = 6 logging.info("Initializing Timer generator") self._timer_generator = None #TimerGenerator(self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False), self.max_time_to_wait, 80) - self.set_confings() + #self.set_confings() @property def engine(self): @@ -72,11 +76,15 @@ def timer_generator(self): def get_restructure_process_lock(self): return self._restructure_process_lock - def set_confings(self): + def still_waiting_for_candidates(self): + return self.accept_candidates_lock.locked() + + + async def set_confings(self): """ neighbor_policy config: - direct connections a.k.a neighbors - - non-direct connections + - all nodes known - self addr model_handler config: @@ -89,12 +97,29 @@ def set_confings(self): - self weight distance - self weight hetereogeneity """ - logging.info("Building neighbor policy configuration..") - self.neighbor_policy.set_config([self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False), self.engine.cm.get_addrs_current_connections(only_direct=False, 
myself=False), self.engine.addr]) - #self.model_handler.set_config([self.engine.get_round(), self.engine.config.participant["training_args"]["epochs"]]) - logging.info("Building candidate selector configuration..") - self.candidate_selector.set_config([0, 0.2 , 0.8]) + logging.info(f"Building neighbor policy configuration..") + self.neighbor_policy.set_config( + [ + await self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False), + await self.engine.cm.get_addrs_current_connections(only_direct=False, only_undirected=False, myself=False), + self.engine.addr + ] + ) + logging.info(f"Building candidate selector configuration..") + self.candidate_selector.set_config( + [ + 0, + 0.2, + 0.8 + ] + ) #self.engine.trainer.get_loss(), self.config.participant["molibity_args"]["weight_distance"], self.config.participant["molibity_args"]["weight_het"] + #self.model_handler.set_config([self.engine.get_round(), self.engine.config.participant["training_args"]["epochs"]]) + + def add_to_discarded_offers(self, addr_discarded): + self.discarded_offers_addr_lock.acquire() + self.discarded_offers_addr.append(addr_discarded) + self.discarded_offers_addr_lock.release() def get_timer(self): return self.timer_generator.get_timer(self.engine.get_round()) @@ -141,9 +166,12 @@ def get_weight_modifier(self, addr): return wm def accept_connection(self, source, joining=False): - if self.accept_candidates_lock.locked(): - return False - return self.neighbor_policy.accept_connection(source) + if not joining: + if self.get_restructure_process_lock().locked(): + logging.info("NOT accepting connections | Currently upgrading network Robustness") + return False + else: + return self.neighbor_policy.accept_connection(source) def need_more_neighbors(self): return self.neighbor_policy.need_more_neighbors() @@ -163,27 +191,46 @@ def no_neighbors_left(self): return len(self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False)) def meet_node(self, node): + 
logging.info(f"Update nodes known | addr: {node}") self.neighbor_policy.meet_node(node) def get_nodes_known(self, neighbors_too=False): return self.neighbor_policy.get_nodes_known(neighbors_too) - def accept_model(self, source, decoded_model, rounds, round, epochs, n_neighbors, loss): + def accept_model_offer(self, source, decoded_model, rounds, round, epochs, n_neighbors, loss): if not self.accept_candidates_lock.locked(): - self.model_handler.accept_model(decoded_model) - self.model_handler.set_config(config=(rounds, round, epochs)) - self.candidate_selector.add_candidate((source, n_neighbors, loss)) + model_accepted = self.model_handler.accept_model(decoded_model) + self.model_handler.set_config(config=(rounds, round, epochs)) + if model_accepted: + self.candidate_selector.add_candidate((source, n_neighbors, loss)) + return True + else: + return False + + def get_trainning_info(self): + return self.model_handler.get_model(None) def add_candidate(self,source, n_neighbors, loss): if not self.accept_candidates_lock.locked(): self.candidate_selector.add_candidate((source, n_neighbors, loss)) + async def stop_not_selected_connections(self): + try: + if len(self.discarded_offers_addr) > 0: + logging.info(f"Interrupting connections | discarded offers | nodes discarded: {self.discarded_offers_addr}") + for addr in self.discarded_offers_addr: + await self.engine.cm.disconnect(addr, mutual_disconnection=True) + await asyncio.sleep(1) + self.discarded_offers_addr = [] + except asyncio.CancelledError as e: + pass + async def start_late_connection_process(self): """ This function represents the process of discovering the federation and stablish the first connections with it. The first step is to send the DISCOVER_JOIN message to look for nodes, the ones that receive that message will send back a OFFER_MODEL message. It contains info to do - a selection process among candidates to later on connect do the best ones. 
+ a selection process among candidates to later on connect to the best ones. The process will repeat until at least one candidate is found and the process will be locked to avoid concurrency. @@ -207,27 +254,26 @@ async def start_late_connection_process(self): self.accept_candidates_lock.acquire() if self.candidate_selector.any_candidate(): - logging.info("Candidates found to connect") + logging.info("Candidates found to connect to...") # create message to send to new neightbors - msg = self.engine.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.LATE_CONNECT) - - best_candidates = self.candidate_selector.select_candidates() - - #for addr, _, _ in best_candidates: - #await self.engine.cm.connect(addr, direct=True) - #await self.engine.cm.send_message(addr, msg) + msg = self.engine.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.LATE_CONNECT) + best_candidates = self.candidate_selector.select_candidates() + logging.info(f"Candidates | {[addr for addr,_,_ in best_candidates]}") + # candidates not choosen --> disconnect + for addr, _, _ in best_candidates: + await self.engine.cm.connect(addr, direct=True) + await self.engine.cm.send_message(addr, msg) + await asyncio.sleep(1) - model, rounds, round, epochs = self.model_handler.get_model() self.accept_candidates_lock.release() - self.late_connection_process_lock.release() - return (model, rounds, round, epochs) + self.late_connection_process_lock.release() # if no candidates, repeat process else: logging.info("No Candidates found | repeating process") self.accept_candidates_lock.release() self.late_connection_process_lock.release() - return await self.start_late_connection_process() + await self.start_late_connection_process() @@ -235,6 +281,16 @@ async def start_late_connection_process(self): Retopology in progress """ + async def check_robustness(self): + logging.info("Analizing node network robustness...") + if len(self.engine.get_federation_nodes()) == 0: + 
logging.info("No Neighbors left | reconnecting with Federation") + elif self.neighbor_policy.need_more_neighbors(): + logging.info("Insufficient Robustness | searching for more connections") + else: + logging.info("Sufficient Robustness | no actions required") + + async def find_new_connections(self): logging.info("🌐 Initializing restructure process from Node Manager") self._restructure_process_lock.acquire() diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 09f11c1d4..4d92fc23f 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -781,7 +781,11 @@ async def process_establish_connection(addr, direct, reconnect): async with self.connections_manager_lock: if addr in self.connections: logging.info(f"πŸ”— [outgoing] Already connected with {self.connections[addr]}") - return False + if not self.connections[addr].get_direct() and (direct == True): + self.connections[addr].set_direct(direct) + return True + else: + return False if addr in self.pending_connections: logging.info(f"πŸ”— [outgoing] Connection with {addr} is already pending") if int(self.host.split(".")[3]) >= int(host.split(".")[3]): diff --git a/nebula/core/network/connection.py b/nebula/core/network/connection.py index 21d75519a..338340d67 100755 --- a/nebula/core/network/connection.py +++ b/nebula/core/network/connection.py @@ -287,7 +287,7 @@ async def handle_incoming_message(self) -> None: logging.info("Message handling cancelled") except ConnectionError as e: logging.exception(f"Connection closed while reading: {e}") - await self.reconnect() + #await self.reconnect() except Exception as e: logging.exception(f"Error handling incoming message: {e}") diff --git a/nebula/frontend/config/participant.json.example b/nebula/frontend/config/participant.json.example index c68266120..b7dd73268 100755 --- a/nebula/frontend/config/participant.json.example +++ b/nebula/frontend/config/participant.json.example @@ -61,6 +61,7 @@ 
"random_geo": true, "mobility": false, "mobility_type": "topology", + "topology_type": "", "radius_federation": 1000, "scheme_mobility": "random", "round_frequency": 1, diff --git a/nebula/scenarios.py b/nebula/scenarios.py index f4a34bb45..efa9b79f9 100644 --- a/nebula/scenarios.py +++ b/nebula/scenarios.py @@ -354,6 +354,7 @@ def __init__(self, scenario): participant_config["mobility_args"]["scheme_mobility"] = self.scenario.scheme_mobility participant_config["mobility_args"]["round_frequency"] = self.scenario.round_frequency participant_config["reporter_args"]["report_status_data_queue"] = self.scenario.report_status_data_queue + participant_config["mobility_args"]["topology_type"] = self.scenario.topology with open(participant_file, "w") as f: json.dump(participant_config, f, sort_keys=False, indent=2) From c0768cd61ba3f856d41d05c11d4ab348cce197a2 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 26 Nov 2024 21:06:19 +0100 Subject: [PATCH 015/233] feat_LateNodes_train Late-created nodes can now be integrated into the training process --- nebula/core/aggregation/aggregator.py | 28 ++++++++++++ nebula/core/engine.py | 44 ++++++++++++++----- nebula/core/neighbormanagement/nodemanager.py | 11 +++-- .../frontend/config/participant.json.example | 2 +- nebula/node.py | 2 +- 5 files changed, 72 insertions(+), 15 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index c42482d41..6281c5f1b 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -165,6 +165,8 @@ async def _add_pending_model(self, model, weight, source): if len(self.get_nodes_pending_models_to_aggregate()) >= len(self._federation_nodes): logging.info("πŸ”„ _add_pending_model | All models were added in the aggregation buffer. 
Run aggregation...") await self._aggregation_done_lock.release_async() + else: + await self.aggregation_push_available() await self._add_model_lock.release_async() return self.get_nodes_pending_models_to_aggregate() @@ -252,6 +254,32 @@ def print_model_size(self, model): total_memory_in_mb = total_memory / (1024**2) logging.info(f"print_model_size | Model size: {total_memory_in_mb} MB") + async def aggregation_push_available(self): + """ + If the node is not sinchronized with the federation, it may be possible to make a push + and try to catch the federation asap. + """ + logging.info(f"❗️ synchronized status: {self.engine.get_sinchronized_status()} | Analizing if an aggregation push is available...") + if not self.engine.get_sinchronized_status(): + n_fed_nodes = len(self._federation_nodes) + further_round = self.engine.get_round() + if len(self.get_nodes_pending_models_to_aggregate()) < n_fed_nodes: + for f_round, fm in self._future_models_to_aggregate.items(): + if len(fm) == n_fed_nodes: + further_round = f_round + push = self.engine.get_push_acceleration() + if push == "slow": + logging.info(f"❗️ FUTURE round: {round} is available | PUSH available") + logging.info("❗️ SLOW push selected | Start PUSHING slow") + self._aggregation_done_lock.release_async() + return + if further_round != self.engine.get_round() and push == "fast": + logging.info("❗️ FAST push selected | Start PUSHING fast") + + else: + self.engine.update_sinchronized_status(True) + else: + logging.info(f"All models updates are received | models number: {len(self.get_nodes_pending_models_to_aggregate())}") def create_malicious_aggregator(aggregator, attack): # It creates a partial function aggregate that wraps the aggregate method of the original aggregator. 
diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 2a5e087a1..37e573ada 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -147,6 +147,11 @@ def __init__( self._reporter = Reporter(config=self.config, trainer=self.trainer, cm=self.cm) + self._sinchronized_status = True + self.sinchronized_status_lock = Locker(name="sinchronized_status_lock") + + self.trainning_in_progress_lock = Locker(name="trainning_in_progress_lock", async_lock=True) + # Mobility setup self._node_manager = None self.mobility = self.config.participant["mobility_args"]["mobility"] @@ -241,6 +246,15 @@ def get_federation_setup_lock(self): def get_round_lock(self): return self.round_lock + + def get_sinchronized_status(self): + with self.sinchronized_status_lock: + return self._sinchronized_status + + def update_sinchronized_status(self, status): + with self.sinchronized_status_lock: + self._sinchronized_status = status + @event_handler(nebula_pb2.DiscoveryMessage, nebula_pb2.DiscoveryMessage.Action.DISCOVER) async def _discovery_discover_callback(self, source, message): @@ -404,9 +418,10 @@ async def _discover_discover_join_callback(self, source, message): logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_join message from {source} ") self.nm.meet_node(source) if len(self.get_federation_nodes()) > 0: - #model, rounds, round = await self.cm.propagator.get_model_information(source, "stable") if self.get_round() > 0 else await self.cm.propagator.get_model_information(source, "initialization") - model, rounds, round = await self.cm.propagator.get_model_information(source, "initialization") - # Process not initiated yet + await self.trainning_in_progress_lock.acquire_async() + model, rounds, round = await self.cm.propagator.get_model_information(source, "stable") if self.get_round() > 0 else await self.cm.propagator.get_model_information(source, "initialization") + await self.trainning_in_progress_lock.release_async() + #model, rounds, round = await 
self.cm.propagator.get_model_information(source, "initialization") if round != -1: epochs = self.config.participant["training_args"]["epochs"] msg = self.cm.mm.generate_offer_message( @@ -446,9 +461,9 @@ async def _offer_offer_model_callback(self, source, message): try: model_compressed = message.parameters if self.nm.accept_model_offer(source, model_compressed, message.rounds, message.round, message.epochs, message.n_neighbors, message.loss): - logging.info("Model accepted from offer") + logging.info("πŸ”§ Model accepted from offer") else: - logging.info("Model offer discarded") + logging.info("❗️ Model offer discarded") self.nm.add_to_discarded_offers(source) except RuntimeError: pass @@ -488,6 +503,7 @@ async def _link_disconnect_from_callback(self, source, message): self.nm.update_neighbors(addr, remove=True) async def _aditional_node_start(self): + self.update_sinchronized_status(False) logging.info(f"Aditional node | {self.addr} | going to stablish connection with federation") await self.nm.start_late_connection_process() # continue .. 
@@ -496,13 +512,15 @@ async def _aditional_node_start(self): asyncio.create_task(self._start_learning_late()) #decoded_model = self.trainer.deserialize_model(message.parameters) + def get_push_acceleration(self): + return self.nm.get_push_acceleration() async def _start_learning_late(self): await self.learning_cycle_lock.acquire_async() try: model_serialized, rounds, round, _epochs = self.nm.get_trainning_info() - self.total_rounds = rounds - epochs = _epochs + self.total_rounds = rounds # self.config.participant["scenario_args"]["rounds"] #rounds + epochs = _epochs # self.config.participant["training_args"]["epochs"] #_epochs await self.get_round_lock().acquire_async() self.round = round await self.get_round_lock().release_async() @@ -512,13 +530,13 @@ async def _start_learning_late(self): indent=2, title="Start of the experiment late", ) - logging.info(f"Trainning setup | total rounds: {rounds} | current round: {round} | epochs: {_epochs}") + logging.info(f"Trainning setup | total rounds: {rounds} | current round: {round} | epochs: {epochs}") direct_connections = await self.cm.get_addrs_current_connections(only_direct=True) logging.info(f"Initial DIRECT connections: {direct_connections}") - self.trainer.set_epochs(epochs) - self.trainer.create_trainer() + await asyncio.sleep(1) try: logging.info("πŸ€– Initializing model...") + await asyncio.sleep(1) model = self.trainer.deserialize_model(model_serialized) self.trainer.set_model_parameters(model, initialize=True) logging.info("Model Parameters Initialized") @@ -534,7 +552,11 @@ async def _start_learning_late(self): pass except RuntimeError: pass + + self.trainer.set_epochs(epochs) + self.trainer.create_trainer() await self._learning_cycle() + finally: if await self.learning_cycle_lock.locked_async(): await self.learning_cycle_lock.release_async() @@ -875,7 +897,9 @@ def __init__( async def _extended_learning_cycle(self): # Define the functionality of the aggregator node await self.trainer.test() + await 
self.trainning_in_progress_lock.acquire_async() await self.trainer.train() + await self.trainning_in_progress_lock.release_async() await self.aggregator.include_model_in_buffer( self.trainer.get_model_parameters(), diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 0f1295760..43c6af259 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -51,6 +51,9 @@ def __init__( logging.info("Initializing Timer generator") self._timer_generator = None #TimerGenerator(self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False), self.max_time_to_wait, 80) + self._push_acceleration = "slow" + self.init_model = None + #self.set_confings() @property @@ -73,12 +76,14 @@ def model_handler(self): def timer_generator(self): return self._timer_generator + def get_push_acceleration(self): + return self._push_acceleration + def get_restructure_process_lock(self): return self._restructure_process_lock def still_waiting_for_candidates(self): return self.accept_candidates_lock.locked() - async def set_confings(self): """ @@ -109,8 +114,8 @@ async def set_confings(self): self.candidate_selector.set_config( [ 0, - 0.2, - 0.8 + 0.5, + 0.5 ] ) #self.engine.trainer.get_loss(), self.config.participant["molibity_args"]["weight_distance"], self.config.participant["molibity_args"]["weight_het"] diff --git a/nebula/frontend/config/participant.json.example b/nebula/frontend/config/participant.json.example index b7dd73268..8c8c4b980 100755 --- a/nebula/frontend/config/participant.json.example +++ b/nebula/frontend/config/participant.json.example @@ -88,7 +88,7 @@ }, "aggregator_args": { "algorithm": "FedAvg", - "aggregation_timeout": 300 + "aggregation_timeout": 20 }, "defense_args": { "with_reputation": false, diff --git a/nebula/node.py b/nebula/node.py index 7b81a379d..2d7cf5514 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -348,7 +348,7 @@ def 
randomize_value(value, variability): if additional_node_status: logging.info(f"Waiting for round {additional_node_round} to start") logging.info("Waiting 60s to start finding federation") - time.sleep(60) + time.sleep(65) #time.sleep(6000) # DEBUG purposes #import requests From 71c7604838f982db9e25be3d7f13640296a9e6b4 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 27 Nov 2024 10:40:50 +0100 Subject: [PATCH 016/233] fix_fast_push fast push to sinchronize network integrated --- nebula/core/aggregation/aggregator.py | 17 ++++++++++++++++- nebula/core/engine.py | 3 +++ nebula/frontend/config/participant.json.example | 2 +- nebula/node.py | 2 +- 4 files changed, 21 insertions(+), 3 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 6281c5f1b..d72798d4a 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -269,12 +269,27 @@ async def aggregation_push_available(self): further_round = f_round push = self.engine.get_push_acceleration() if push == "slow": - logging.info(f"❗️ FUTURE round: {round} is available | PUSH available") + logging.info(f"❗️ FUTURE round: {further_round} is available | PUSH strategy ON") logging.info("❗️ SLOW push selected | Start PUSHING slow") + # Unlock aggregation self._aggregation_done_lock.release_async() return if further_round != self.engine.get_round() and push == "fast": + logging.info(f"❗️ FUTURE round: {further_round} is available | PUSH strategy ON") logging.info("❗️ FAST push selected | Start PUSHING fast") + (model, weight) = self._pending_models_to_aggregate.get(self.engine.get_addr()) + self._pending_models_to_aggregate.clear() + self._pending_models_to_aggregate.update({self.engine.get_addr(): (model, weight)}) + + for future_update in self._future_models_to_aggregate[further_round]: + (decoded_model, weight, source) = future_update + self._pending_models_to_aggregate.update({source: (decoded_model, weight)}) + + 
self.engine.set_round(further_round) + + # Unlock aggregation + self._aggregation_done_lock.release_async() + return else: self.engine.update_sinchronized_status(True) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 37e573ada..11409d2d2 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -254,6 +254,9 @@ def get_sinchronized_status(self): def update_sinchronized_status(self, status): with self.sinchronized_status_lock: self._sinchronized_status = status + + def set_round(self, new_round): + self.round = new_round @event_handler(nebula_pb2.DiscoveryMessage, nebula_pb2.DiscoveryMessage.Action.DISCOVER) diff --git a/nebula/frontend/config/participant.json.example b/nebula/frontend/config/participant.json.example index 8c8c4b980..dbea96f53 100755 --- a/nebula/frontend/config/participant.json.example +++ b/nebula/frontend/config/participant.json.example @@ -88,7 +88,7 @@ }, "aggregator_args": { "algorithm": "FedAvg", - "aggregation_timeout": 20 + "aggregation_timeout": 60 }, "defense_args": { "with_reputation": false, diff --git a/nebula/node.py b/nebula/node.py index 2d7cf5514..5f5f9cce4 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -348,7 +348,7 @@ def randomize_value(value, variability): if additional_node_status: logging.info(f"Waiting for round {additional_node_round} to start") logging.info("Waiting 60s to start finding federation") - time.sleep(65) + time.sleep(70) #time.sleep(6000) # DEBUG purposes #import requests From c495af68d251a17971e3a6b3fb34af24c91c9c73 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 27 Nov 2024 13:12:10 +0100 Subject: [PATCH 017/233] feat_weights_modifiers Weight modifiers applied to late connected nodes --- nebula/core/aggregation/aggregator.py | 4 ++ nebula/core/engine.py | 22 ++++++++- nebula/core/neighbormanagement/nodemanager.py | 46 ++++++++++++++----- .../frontend/config/participant.json.example | 3 +- 4 files changed, 60 insertions(+), 15 deletions(-) diff --git 
a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index b76a858cf..985eedea4 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -165,6 +165,7 @@ async def _add_pending_model(self, model, weight, source): if len(self.get_nodes_pending_models_to_aggregate()) >= len(self._federation_nodes): logging.info("πŸ”„ _add_pending_model | All models were added in the aggregation buffer. Run aggregation...") + self.engine.update_sinchronized_status(True) await self._aggregation_done_lock.release_async() else: await self.aggregation_push_available() @@ -234,6 +235,7 @@ async def get_aggregation(self): else: logging.info("πŸ”„ get_aggregation | All models accounted for, proceeding with aggregation.") + #self._pending_models_to_aggregate = self.engine.apply_weight_strategy(self._pending_models_to_aggregate) aggregated_result = self.run_aggregation(self._pending_models_to_aggregate) self._pending_models_to_aggregate.clear() return aggregated_result @@ -277,6 +279,7 @@ async def aggregation_push_available(self): logging.info(f"❗️ FUTURE round: {further_round} is available | PUSH strategy ON") logging.info("❗️ SLOW push selected | Start PUSHING slow") # Unlock aggregation + self.engine.set_pushed_done(self.engine.get_round() - further_round) self._aggregation_done_lock.release_async() return if further_round != self.engine.get_round() and push == "fast": @@ -290,6 +293,7 @@ async def aggregation_push_available(self): (decoded_model, weight, source) = future_update self._pending_models_to_aggregate.update({source: (decoded_model, weight)}) + self.engine.set_pushed_done(self.engine.get_round() - further_round) self.engine.set_round(further_round) # Unlock aggregation diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 11409d2d2..81043cfeb 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -159,7 +159,8 @@ def __init__( topology = 
self.config.participant["mobility_args"]["topology_type"] topology = topology.lower() model_handler = "std" #self.config.participant["mobility_args"]["model_handler"] - self._node_manager = NodeManager(topology, model_handler, engine=self) + acceleration_push = self.config.participant["aggregation_args"]["aggregation_push"] + self._node_manager = NodeManager(topology, model_handler, acceleration_push, engine=self) self._event_manager = EventManager( @@ -369,7 +370,7 @@ async def _federation_models_included_callback(self, source, message): nebula_pb2.ConnectionMessage.Action.LATE_CONNECT, ) async def _connection_late_connect_callback(self, source, message): - logging.info(f"πŸ”— handle_connection_message | Trigger | Received late_connect message from {source}") + logging.info(f"πŸ”— handle_connection_message | Trigger | Received late connect message from {source}") if self.nm.accept_connection(source, joining=True): logging.info(f"πŸ”— Late connection accepted | source: {source}") self.nm.add_weight_modifier(source) @@ -517,6 +518,16 @@ async def _aditional_node_start(self): def get_push_acceleration(self): return self.nm.get_push_acceleration() + + def set_pushed_done(self, rounds_push): + self.nm.set_rounds_pushed(rounds_push) + + def apply_weight_strategy(self, pending_models): + #if self.mobility: + # + #else: + return + async def _start_learning_late(self): await self.learning_cycle_lock.acquire_async() @@ -726,6 +737,7 @@ async def _learning_cycle(self): logging.info(f"[Role {self.role}] Starting learning cycle...") await self.aggregator.update_federation_nodes(self.federation_nodes) await self._extended_learning_cycle() + await self._additional_mobility_actions() await self.get_round_lock().acquire_async() print_msg_box( @@ -779,6 +791,12 @@ async def _extended_learning_cycle(self): functionalities. The method is called in the _learning_cycle method. 
""" pass + + async def _additional_mobility_actions(self): + if not self.mobility: + return + logging.info("πŸ”„ Starting additional mobility actions...") + #self.nm.update_weight_modifiers() def reputation_calculation(self, aggregated_models_weights): cossim_threshold = 0.5 diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 43c6af259..3915246f1 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -23,6 +23,7 @@ def __init__( self, topology, model_handler, + push_acceleration, engine : "Engine" ): self.topology = topology @@ -47,12 +48,12 @@ def __init__( self.discarded_offers_addr_lock = Locker(name="discarded_offers_addr_lock") self.discarded_offers_addr = [] - self.max_time_to_wait = 6 + self.max_time_to_wait = 20 logging.info("Initializing Timer generator") self._timer_generator = None #TimerGenerator(self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False), self.max_time_to_wait, 80) - self._push_acceleration = "slow" - self.init_model = None + self._push_acceleration = push_acceleration + self.rounds_pushed = 0 #self.set_confings() @@ -82,6 +83,9 @@ def get_push_acceleration(self): def get_restructure_process_lock(self): return self._restructure_process_lock + def set_rounds_pushed(self, rp): + self.rounds_pushed = rp + def still_waiting_for_candidates(self): return self.accept_candidates_lock.locked() @@ -141,30 +145,48 @@ async def receive_update_from_node(self, node_id, node_response_time): def add_weight_modifier(self, addr): self.weight_modifier_lock.acquire() if not addr in self.weight_modifier: - self.weight_modifier[addr] = self.new_node_weight_value + wv = self.new_node_weight_value + logging.info(f"πŸ“ Registering | Weight modifier registered for source {addr} | round: {round} | value: {wv}") + self.weight_modifier[addr] = wv self.weight_modifier_lock.release() def remove_weight_modifier(self, addr): 
self.weight_modifier_lock.acquire() if addr in self.weight_modifier: + logging.info(f"πŸ“ Removing | weight modifier registered for source {addr}") del self.weight_modifier[addr] self.weight_modifier_lock.release() - - def _update_weight_modifier(self, addr): - self.weight_modifier_lock.acquire() - if addr in self.weight_modifier: - new_weight = self.weight_modifier[addr] - 1/self.engine.get_round()**2 + + def apply_weight_strategy(self, updates): + logging.info(f"πŸ”„ Applying weight Strategy...") + # We must lower the weight_modifier value if a round jump has been occured + # as many times as rounds have been jumped + if self.rounds_pushed: + for i in range(0, self.rounds_pushed): + self._update_weight_modifiers() + self.rounds_pushed = 0 + for addr,update in updates.items(): + weight_modifier = self._get_weight_modifier(addr) + if weight_modifier != 1: + logging.info(f"πŸ“ addr found :{addr}") + logging.info (f"πŸ“ Appliying modified weight strategy | multiplier value: {weight_modifier}") + model, weight = update + updates.update({addr: (model, weight*weight_modifier)}) + + def _update_weight_modifiers(self): + self.weight_modifier_lock.acquire() + for addr,weight in self.weight_modifier.items(): + new_weight = weight - 1/(round**2) if new_weight > 1: self.weight_modifier[addr] = new_weight else: self.remove_weight_modifier(addr) self.weight_modifier_lock.release() - def get_weight_modifier(self, addr): + def _get_weight_modifier(self, addr): self.weight_modifier_lock.acquire() if addr in self.weight_modifier: - wm = self.weight_modifier[addr] - self._update_weight_modifier(addr, self.engine.get_round()) + wm = self.weight_modifier[addr] else: wm = 1 self.weight_modifier_lock.release() diff --git a/nebula/frontend/config/participant.json.example b/nebula/frontend/config/participant.json.example index dbea96f53..0a714c09b 100755 --- a/nebula/frontend/config/participant.json.example +++ b/nebula/frontend/config/participant.json.example @@ -88,7 +88,8 @@ }, 
"aggregator_args": { "algorithm": "FedAvg", - "aggregation_timeout": 60 + "aggregation_timeout": 60, + "aggregation_push": "slow" }, "defense_args": { "with_reputation": false, From ec553b6b2efa844b2776ef21e34d20e8a80e66cd Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 28 Nov 2024 12:13:58 +0100 Subject: [PATCH 018/233] fix_slow_push_strategy slow push strategy working properly fixing metrics from late connection nodes --- nebula/core/aggregation/aggregator.py | 37 +++++++++++++++---- nebula/core/engine.py | 24 +++++++++--- nebula/core/neighbormanagement/nodemanager.py | 19 +++++++--- nebula/core/network/communications.py | 2 +- nebula/core/training/lightning.py | 3 ++ 5 files changed, 65 insertions(+), 20 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 985eedea4..7a88c45a8 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -167,8 +167,8 @@ async def _add_pending_model(self, model, weight, source): logging.info("πŸ”„ _add_pending_model | All models were added in the aggregation buffer. Run aggregation...") self.engine.update_sinchronized_status(True) await self._aggregation_done_lock.release_async() - else: - await self.aggregation_push_available() + #else: + # await self.aggregation_push_available() await self._add_model_lock.release_async() return self.get_nodes_pending_models_to_aggregate() @@ -246,6 +246,7 @@ async def include_next_model_in_buffer(self, model, weight, source=None, round=N self._future_models_to_aggregate[round] = [] decoded_model = self.engine.trainer.deserialize_model(model) self._future_models_to_aggregate[round].append((decoded_model, weight, source)) + #await self.aggregation_push_available() def print_model_size(self, model): total_params = 0 @@ -267,21 +268,32 @@ async def aggregation_push_available(self): and try to catch the federation asap. 
""" logging.info(f"❗️ synchronized status: {self.engine.get_sinchronized_status()} | Analizing if an aggregation push is available...") - if not self.engine.get_sinchronized_status(): - n_fed_nodes = len(self._federation_nodes) + if not self.engine.get_sinchronized_status() and not self.engine.get_trainning_in_progress_lock().locked() and not self.engine.get_synchronizing_rounds(): + n_fed_nodes = len(self._federation_nodes) further_round = self.engine.get_round() + logging.info(f" Pending models: {len(self.get_nodes_pending_models_to_aggregate())} | federation: {n_fed_nodes}") if len(self.get_nodes_pending_models_to_aggregate()) < n_fed_nodes: for f_round, fm in self._future_models_to_aggregate.items(): + # future_models dont count self node + n_fed_nodes-=1 if len(fm) == n_fed_nodes: further_round = f_round push = self.engine.get_push_acceleration() if push == "slow": logging.info(f"❗️ FUTURE round: {further_round} is available | PUSH strategy ON") - logging.info("❗️ SLOW push selected | Start PUSHING slow") - # Unlock aggregation + logging.info("❗️ SLOW push selected | Start PUSHING slow") self.engine.set_pushed_done(self.engine.get_round() - further_round) - self._aggregation_done_lock.release_async() + # we wait until learning cycle reach aggregation point + while not self._aggregation_done_lock.locked_async(): + logging.info("πŸ”„ Waiting | aggregation step not reached yet...") + await asyncio.sleep(1) + # Unlock aggregation + logging.info("πŸ”„ Releasing aggregation lock...") + await self._aggregation_done_lock.release_async() return + # hay que revisar la sincronizacion bien de todo, saltar rondas puede ser complicado + # si en lo q se esta realizando este cambio llega un mensaje, que? tienen q estar las estructuras + # bloqueadas. 
Tengo que estudiarlo bien if further_round != self.engine.get_round() and push == "fast": logging.info(f"❗️ FUTURE round: {further_round} is available | PUSH strategy ON") logging.info("❗️ FAST push selected | Start PUSHING fast") @@ -297,13 +309,22 @@ async def aggregation_push_available(self): self.engine.set_round(further_round) # Unlock aggregation - self._aggregation_done_lock.release_async() + # we wait until learning cycle reach aggregation point + while not self._aggregation_done_lock.locked_async(): + await asyncio.sleep(1) + await self._aggregation_done_lock.release_async() return else: self.engine.update_sinchronized_status(True) else: logging.info(f"All models updates are received | models number: {len(self.get_nodes_pending_models_to_aggregate())}") + else: + if not self.engine.get_sinchronized_status(): + if self.engine.get_sinchronized_status(): + logging.info("❗️ Cannot analize push | Trainning in progress") + elif self.engine.get_synchronizing_rounds(): + logging.info("❗️ Cannot analize push | already pushing rounds") def create_malicious_aggregator(aggregator, attack): # It creates a partial function aggregate that wraps the aggregate method of the original aggregator. 
diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 81043cfeb..6cf440233 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -159,7 +159,7 @@ def __init__( topology = self.config.participant["mobility_args"]["topology_type"] topology = topology.lower() model_handler = "std" #self.config.participant["mobility_args"]["model_handler"] - acceleration_push = self.config.participant["aggregation_args"]["aggregation_push"] + acceleration_push = "slow" #self.config.participant["aggregation_args"]["aggregation_push"] self._node_manager = NodeManager(topology, model_handler, acceleration_push, engine=self) @@ -244,6 +244,9 @@ def get_federation_ready_lock(self): def get_federation_setup_lock(self): return self.federation_setup_lock + + def get_trainning_in_progress_lock(self): + return self.trainning_in_progress_lock def get_round_lock(self): return self.round_lock @@ -252,8 +255,13 @@ def get_sinchronized_status(self): with self.sinchronized_status_lock: return self._sinchronized_status + def get_synchronizing_rounds(self): + return self.nm.get_syncrhonizing_rounds() + def update_sinchronized_status(self, status): with self.sinchronized_status_lock: + if self.mobility: + self.nm.set_synchronizing_rounds(status) self._sinchronized_status = status def set_round(self, new_round): @@ -461,16 +469,19 @@ async def _discover_discover_nodes_callback(self, source, message): async def _offer_offer_model_callback(self, source, message): logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_model message from {source}") self.nm.meet_node(source) - if not self.nm.get_restructure_process_lock().locked() and not self.nm.still_waiting_for_candidates(): + if not self.nm.get_restructure_process_lock().locked() and self.nm.still_waiting_for_candidates(): try: model_compressed = message.parameters if self.nm.accept_model_offer(source, model_compressed, message.rounds, message.round, message.epochs, message.n_neighbors, message.loss): - 
logging.info("πŸ”§ Model accepted from offer") + logging.info(f"πŸ”§ Model accepted from offer | source: {source}") else: - logging.info("❗️ Model offer discarded") + logging.info(f"❗️ Model offer discarded | source: {source}") self.nm.add_to_discarded_offers(source) except RuntimeError: - pass + logging.info(f"❗️ Error proccesing offer model from {source}") + else: + logging.info(f"❗️ handfle_offer_message | NOT accepting offers | restructure: {self.nm.get_restructure_process_lock().locked()} | waiting candidates: {self.nm.still_waiting_for_candidates()}") + self.nm.add_to_discarded_offers(source) @event_handler( nebula_pb2.OfferMessage, @@ -528,7 +539,6 @@ def apply_weight_strategy(self, pending_models): #else: return - async def _start_learning_late(self): await self.learning_cycle_lock.acquire_async() try: @@ -568,6 +578,7 @@ async def _start_learning_late(self): pass self.trainer.set_epochs(epochs) + self.trainer.set_current_round(round) self.trainer.create_trainer() await self._learning_cycle() @@ -712,6 +723,7 @@ async def _dynamic_aggregator(self, aggregated_models_weights, malicious_nodes): async def _waiting_model_updates(self): logging.info(f"πŸ’€ Waiting convergence in round {self.round}.") + await self.aggregator.aggregation_push_available() params = await self.aggregator.get_aggregation() if params is not None: logging.info( diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 3915246f1..7ad043679 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -53,8 +53,11 @@ def __init__( self._timer_generator = None #TimerGenerator(self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False), self.max_time_to_wait, 80) self._push_acceleration = push_acceleration + self.rounds_pushed_lock = Locker(name="rounds_pushed_lock") self.rounds_pushed = 0 + self.synchronizing_rounds = False + #self.set_confings() @property @@ -83,11 
+86,18 @@ def get_push_acceleration(self): def get_restructure_process_lock(self): return self._restructure_process_lock + def set_synchronizing_rounds(self, status): + self.synchronizing_rounds = status + + def get_syncrhonizing_rounds(self): + return self.synchronizing_rounds + def set_rounds_pushed(self, rp): - self.rounds_pushed = rp + with self.rounds_pushed_lock: + self.rounds_pushed = rp def still_waiting_for_candidates(self): - return self.accept_candidates_lock.locked() + return not self.accept_candidates_lock.locked() async def set_confings(self): """ @@ -146,7 +156,7 @@ def add_weight_modifier(self, addr): self.weight_modifier_lock.acquire() if not addr in self.weight_modifier: wv = self.new_node_weight_value - logging.info(f"πŸ“ Registering | Weight modifier registered for source {addr} | round: {round} | value: {wv}") + logging.info(f"πŸ“ Registering | Weight modifier registered for source {addr} | round: {self.engine.get_round()} | value: {wv}") self.weight_modifier[addr] = wv self.weight_modifier_lock.release() @@ -168,8 +178,7 @@ def apply_weight_strategy(self, updates): for addr,update in updates.items(): weight_modifier = self._get_weight_modifier(addr) if weight_modifier != 1: - logging.info(f"πŸ“ addr found :{addr}") - logging.info (f"πŸ“ Appliying modified weight strategy | multiplier value: {weight_modifier}") + logging.info (f"πŸ“ Appliying modified weight strategy | addr: {addr}| multiplier value: {weight_modifier}") model, weight = update updates.update({addr: (model, weight*weight_modifier)}) diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 4d92fc23f..8facd7e2b 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -759,7 +759,7 @@ async def send_offer_model(self, dest_addr, offer_message): logging.info(f"❗️ Connection with {dest_addr} not found") return logging.info( - f"Sending model to {dest_addr}" + f"Sending offer model to {dest_addr}" ) 
await conn.send(data=offer_message, is_compressed=True) logging.info(f"Offer_Model sent to {dest_addr}") diff --git a/nebula/core/training/lightning.py b/nebula/core/training/lightning.py index a03ff0831..b96437755 100755 --- a/nebula/core/training/lightning.py +++ b/nebula/core/training/lightning.py @@ -239,6 +239,9 @@ def get_hash_model(self): def set_epochs(self, epochs): self.epochs = epochs + + def set_current_round(self, round): + self.round = round def serialize_model(self, model): # From https://pytorch.org/docs/stable/notes/serialization.html From ab1af93bea4266f2c8721ffdc4b4cac5a393cbf6 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 29 Nov 2024 11:53:46 +0100 Subject: [PATCH 019/233] fix_metric_delay Metric delay done for late creation nodes --- nebula/core/engine.py | 4 ++-- nebula/core/models/nebulamodel.py | 9 +++++++++ nebula/core/neighbormanagement/nodemanager.py | 7 ++++--- nebula/core/training/lightning.py | 5 +++++ nebula/core/utils/nebulalogger_tensorboard.py | 5 +++-- 5 files changed, 23 insertions(+), 7 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 6cf440233..511ed5147 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -439,7 +439,7 @@ async def _discover_discover_join_callback(self, source, message): msg = self.cm.mm.generate_offer_message( nebula_pb2.OfferMessage.Action.OFFER_MODEL, len(self.get_federation_nodes()), - 0, #self.trainer.get_loss(), + 0, #self.trainer.get_current_loss(), model, rounds, round, @@ -459,7 +459,7 @@ async def _discover_discover_join_callback(self, source, message): async def _discover_discover_nodes_callback(self, source, message): logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_node message from {source} ") self.nm.meet_node(source) - msg = self.cm.mm.generate_offer_message(nebula_pb2.OfferMessage.Action.OFFER_METRIC, len(self.get_federation_nodes()), self.trainer.get_loss()) + msg = 
self.cm.mm.generate_offer_message(nebula_pb2.OfferMessage.Action.OFFER_METRIC, len(self.get_federation_nodes()), self.trainer.get_current_loss()) await self.cm.send_message(source, msg) @event_handler( diff --git a/nebula/core/models/nebulamodel.py b/nebula/core/models/nebulamodel.py index 0e9ea0c30..84325c3f1 100755 --- a/nebula/core/models/nebulamodel.py +++ b/nebula/core/models/nebulamodel.py @@ -229,6 +229,15 @@ def step(self, batch, batch_idx, phase): def get_loss(self): return self.current_loss + + def set_updated_round(self, round): + self.round = round + self.global_number = { + "Train": round, + "Validation": round, + "Test (Local)": round, + "Test (Global)": round, + } def training_step(self, batch, batch_idx): """ diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 7ad043679..9c37c8b2d 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -172,17 +172,18 @@ def apply_weight_strategy(self, updates): # We must lower the weight_modifier value if a round jump has been occured # as many times as rounds have been jumped if self.rounds_pushed: + round = self.engine.get_round() for i in range(0, self.rounds_pushed): - self._update_weight_modifiers() + self._update_weight_modifiers((round + i)) self.rounds_pushed = 0 for addr,update in updates.items(): weight_modifier = self._get_weight_modifier(addr) if weight_modifier != 1: - logging.info (f"πŸ“ Appliying modified weight strategy | addr: {addr}| multiplier value: {weight_modifier}") + logging.info (f"πŸ“ Appliying modified weight strategy | addr: {addr} | multiplier value: {weight_modifier}") model, weight = update updates.update({addr: (model, weight*weight_modifier)}) - def _update_weight_modifiers(self): + def _update_weight_modifiers(self, round): self.weight_modifier_lock.acquire() for addr,weight in self.weight_modifier.items(): new_weight = weight - 1/(round**2) diff --git 
a/nebula/core/training/lightning.py b/nebula/core/training/lightning.py index b96437755..83232797b 100755 --- a/nebula/core/training/lightning.py +++ b/nebula/core/training/lightning.py @@ -241,7 +241,12 @@ def set_epochs(self, epochs): self.epochs = epochs def set_current_round(self, round): + logging.info(f"Update | current round = {round}") self.round = round + self.model.set_updated_round(round) + + def get_current_loss(self): + return self.model.get_loss() def serialize_model(self, model): # From https://pytorch.org/docs/stable/notes/serialization.html diff --git a/nebula/core/utils/nebulalogger_tensorboard.py b/nebula/core/utils/nebulalogger_tensorboard.py index d7c4ce124..a8dc5943a 100755 --- a/nebula/core/utils/nebulalogger_tensorboard.py +++ b/nebula/core/utils/nebulalogger_tensorboard.py @@ -17,7 +17,7 @@ def get_step(self): def log_data(self, data, step=None): if step is None: step = self.get_step() - # logging.debug(f"Logging data for global step {step} | local step {self.local_step} | global step {self.global_step}") + #logging.debug(f"Logging data for global step {step} | local step {self.local_step} | global step {self.global_step}") try: super().log_metrics(data, step) except ValueError: @@ -29,7 +29,7 @@ def log_metrics(self, metrics, step=None): if step is None: self.local_step += 1 step = self.global_step + self.local_step - # logging.debug(f"Logging metrics for global step {step} | local step {self.local_step} | global step {self.global_step}") + #logging.debug(f"Logging metrics for global step {step} | local step {self.local_step} | global step {self.global_step}") if "epoch" in metrics: metrics.pop("epoch") try: @@ -61,3 +61,4 @@ def set_logger_config(self, logger_config): self.global_step = logger_config["global_step"] except Exception as e: logging.exception(f"Error setting logger config: {e}") + From fc64e40f732de20c35532858a18e2c900a66714c Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 29 Nov 2024 18:45:07 +0100 Subject: [PATCH 
020/233] feat_fast_push fast push integrated weight strategy integrated fix modified weights --- nebula/core/aggregation/aggregator.py | 48 ++++++++++++++++--- nebula/core/engine.py | 9 ++-- nebula/core/neighbormanagement/nodemanager.py | 30 ++++++------ .../frontend/config/participant.json.example | 1 + 4 files changed, 62 insertions(+), 26 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 7a88c45a8..e0812021e 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -64,6 +64,7 @@ def __init__(self, config=None, engine=None): self._pending_models_to_aggregate = {} self._future_models_to_aggregate = {} self._add_model_lock = Locker(name="add_model_lock", async_lock=True) + self._add_next_model_lock = Locker(name="add_next_model_lock", async_lock=True) self._aggregation_done_lock = Locker(name="aggregation_done_lock", async_lock=True) def __str__(self): @@ -235,8 +236,11 @@ async def get_aggregation(self): else: logging.info("πŸ”„ get_aggregation | All models accounted for, proceeding with aggregation.") - #self._pending_models_to_aggregate = self.engine.apply_weight_strategy(self._pending_models_to_aggregate) + self._pending_models_to_aggregate = self.engine.apply_weight_strategy(self._pending_models_to_aggregate) aggregated_result = self.run_aggregation(self._pending_models_to_aggregate) + if not self.engine.get_sinchronized_status() and self.engine.get_push_acceleration() == "fast": + await self._add_model_lock.release_async() + await self._add_next_model_lock.release_async() self._pending_models_to_aggregate.clear() return aggregated_result @@ -245,7 +249,9 @@ async def include_next_model_in_buffer(self, model, weight, source=None, round=N if round not in self._future_models_to_aggregate: self._future_models_to_aggregate[round] = [] decoded_model = self.engine.trainer.deserialize_model(model) + await self._add_next_model_lock.acquire_async() 
self._future_models_to_aggregate[round].append((decoded_model, weight, source)) + await self._add_next_model_lock.release_async() #await self.aggregation_push_available() def print_model_size(self, model): @@ -273,9 +279,9 @@ async def aggregation_push_available(self): further_round = self.engine.get_round() logging.info(f" Pending models: {len(self.get_nodes_pending_models_to_aggregate())} | federation: {n_fed_nodes}") if len(self.get_nodes_pending_models_to_aggregate()) < n_fed_nodes: + n_fed_nodes-=1 for f_round, fm in self._future_models_to_aggregate.items(): - # future_models dont count self node - n_fed_nodes-=1 + # future_models dont count self node if len(fm) == n_fed_nodes: further_round = f_round push = self.engine.get_push_acceleration() @@ -291,20 +297,48 @@ async def aggregation_push_available(self): logging.info("πŸ”„ Releasing aggregation lock...") await self._aggregation_done_lock.release_async() return - # hay que revisar la sincronizacion bien de todo, saltar rondas puede ser complicado - # si en lo q se esta realizando este cambio llega un mensaje, que? tienen q estar las estructuras - # bloqueadas. 
Tengo que estudiarlo bien + if further_round != self.engine.get_round() and push == "fast": logging.info(f"❗️ FUTURE round: {further_round} is available | PUSH strategy ON") logging.info("❗️ FAST push selected | Start PUSHING fast") - (model, weight) = self._pending_models_to_aggregate.get(self.engine.get_addr()) + + if further_round == (self.engine.get_round()+1): + logging.info(f"πŸ”„ Rounds jumped: {1}...") + self.engine.set_pushed_done(self.engine.get_round() - further_round) + # we wait until learning cycle reach aggregation point + while not self._aggregation_done_lock.locked_async(): + logging.info("πŸ”„ Waiting | aggregation step not reached yet...") + await asyncio.sleep(1) + # Unlock aggregation + logging.info("πŸ”„ Releasing aggregation lock...") + await self._aggregation_done_lock.release_async() + return + + logging.info(f"πŸ”„ Rounds jumped: {self.engine.get_round() - further_round}...") + own_update = self._pending_models_to_aggregate.get(self.engine.get_addr()) + while own_update == None: + own_update = self._pending_models_to_aggregate.get(self.engine.get_addr()) + asyncio.sleep(1) + (model, weight) = own_update + + # Getting locks to avoid concurrency issues + await self._add_model_lock.acquire_async() + await self._add_next_model_lock.acquire_async() + + # Remove all pendings updates and add own_update self._pending_models_to_aggregate.clear() self._pending_models_to_aggregate.update({self.engine.get_addr(): (model, weight)}) + # Add to pendings the future round updates for future_update in self._future_models_to_aggregate[further_round]: (decoded_model, weight, source) = future_update self._pending_models_to_aggregate.update({source: (decoded_model, weight)}) + # Clear all rounds that are going to be jumped + for key in self._future_models_to_aggregate.keys(): + if key <= further_round: + del self._future_models_to_aggregate[key] + self.engine.set_pushed_done(self.engine.get_round() - further_round) self.engine.set_round(further_round) diff 
--git a/nebula/core/engine.py b/nebula/core/engine.py index 511ed5147..6a92163e0 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -159,7 +159,7 @@ def __init__( topology = self.config.participant["mobility_args"]["topology_type"] topology = topology.lower() model_handler = "std" #self.config.participant["mobility_args"]["model_handler"] - acceleration_push = "slow" #self.config.participant["aggregation_args"]["aggregation_push"] + acceleration_push = "slow" #self.config.participant["mobility_args"]["push_strategy"] self._node_manager = NodeManager(topology, model_handler, acceleration_push, engine=self) @@ -265,7 +265,9 @@ def update_sinchronized_status(self, status): self._sinchronized_status = status def set_round(self, new_round): + logging.info(f"πŸ€– Update round count | from: {self.round} | to round: {new_round}") self.round = new_round + self.trainer.set_current_round(new_round) @event_handler(nebula_pb2.DiscoveryMessage, nebula_pb2.DiscoveryMessage.Action.DISCOVER) @@ -535,9 +537,10 @@ def set_pushed_done(self, rounds_push): def apply_weight_strategy(self, pending_models): #if self.mobility: - # + # self.nm.apply_weight_strategy(pending_models) + # return pending_models #else: - return + return pending_models async def _start_learning_late(self): await self.learning_cycle_lock.acquire_async() diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 9c37c8b2d..58a4182e8 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -40,7 +40,7 @@ def __init__( self.late_connection_process_lock = Locker(name="late_connection_process_lock") self.weight_modifier = {} self.weight_modifier_lock = Locker(name="weight_modifier_lock") - self.new_node_weight_value = 3 + self.new_node_weight_multiplier = 3 self.accept_candidates_lock = Locker(name="accept_candidates_lock") self.recieve_offer_timer = 5 self._restructure_process_lock = 
Locker(name="restructure_process_lock") @@ -155,9 +155,9 @@ async def receive_update_from_node(self, node_id, node_response_time): def add_weight_modifier(self, addr): self.weight_modifier_lock.acquire() if not addr in self.weight_modifier: - wv = self.new_node_weight_value - logging.info(f"πŸ“ Registering | Weight modifier registered for source {addr} | round: {self.engine.get_round()} | value: {wv}") - self.weight_modifier[addr] = wv + wm = self.new_node_weight_multiplier + logging.info(f"πŸ“ Registering | Weight modifier registered for source {addr} | round: {self.engine.get_round()} | value: {wm}") + self.weight_modifier[addr] = (wm,1) self.weight_modifier_lock.release() def remove_weight_modifier(self, addr): @@ -172,33 +172,31 @@ def apply_weight_strategy(self, updates): # We must lower the weight_modifier value if a round jump has been occured # as many times as rounds have been jumped if self.rounds_pushed: - round = self.engine.get_round() for i in range(0, self.rounds_pushed): - self._update_weight_modifiers((round + i)) + self._update_weight_modifiers() self.rounds_pushed = 0 for addr,update in updates.items(): - weight_modifier = self._get_weight_modifier(addr) + weight_modifier, _ = self._get_weight_modifier(addr) if weight_modifier != 1: logging.info (f"πŸ“ Appliying modified weight strategy | addr: {addr} | multiplier value: {weight_modifier}") model, weight = update updates.update({addr: (model, weight*weight_modifier)}) - - def _update_weight_modifiers(self, round): + self._update_weight_modifiers() + + def _update_weight_modifiers(self): self.weight_modifier_lock.acquire() - for addr,weight in self.weight_modifier.items(): - new_weight = weight - 1/(round**2) + for addr,(weight,rounds) in self.weight_modifier.items(): + new_weight = weight - 1/(rounds**2) + rounds = rounds + 1 if new_weight > 1: - self.weight_modifier[addr] = new_weight + self.weight_modifier[addr] = (new_weight, rounds) else: self.remove_weight_modifier(addr) 
self.weight_modifier_lock.release() def _get_weight_modifier(self, addr): self.weight_modifier_lock.acquire() - if addr in self.weight_modifier: - wm = self.weight_modifier[addr] - else: - wm = 1 + wm = self.weight_modifier.get(addr, (1,0)) self.weight_modifier_lock.release() return wm diff --git a/nebula/frontend/config/participant.json.example b/nebula/frontend/config/participant.json.example index 0a714c09b..f53479a73 100755 --- a/nebula/frontend/config/participant.json.example +++ b/nebula/frontend/config/participant.json.example @@ -62,6 +62,7 @@ "mobility": false, "mobility_type": "topology", "topology_type": "", + "push_strategy": "slow", "radius_federation": 1000, "scheme_mobility": "random", "round_frequency": 1, From 9051a7865c6879ce12ed49f7f1250a864e06a091 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 2 Dec 2024 11:27:08 +0100 Subject: [PATCH 021/233] fix_fast_reboot fast reboot when device arrives fully integrated --- nebula/core/engine.py | 16 +++++++++++----- nebula/core/models/mnist/cnn.py | 1 + nebula/core/models/mnist/mlp.py | 1 + nebula/core/models/nebulamodel.py | 11 ++++++++--- nebula/core/neighbormanagement/nodemanager.py | 9 +++++++++ nebula/core/training/lightning.py | 3 +++ 6 files changed, 33 insertions(+), 8 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 6a92163e0..78d4353f0 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -399,6 +399,7 @@ async def _connection_late_connect_callback(self, source, message): await self.cm.connect(source, direct=True) self.nm.meet_node(source) self.nm.update_neighbors(source) + await self.update_model_learning_rate() else: logging.info(f"πŸ”— Late connection NOT accepted | source: {source}") @@ -536,11 +537,16 @@ def set_pushed_done(self, rounds_push): self.nm.set_rounds_pushed(rounds_push) def apply_weight_strategy(self, pending_models): - #if self.mobility: - # self.nm.apply_weight_strategy(pending_models) - # return pending_models - #else: - 
return pending_models + if self.mobility and self.nm.fast_reboot_on(): + self.nm.apply_weight_strategy(pending_models) + return pending_models + else: + return pending_models + + async def update_model_learning_rate(self): + await self.trainning_in_progress_lock.acquire_async() + self.trainer.update_model_learning_rate(self.nm.get_learning_rate_increase()) + await self.trainning_in_progress_lock.release_async() async def _start_learning_late(self): await self.learning_cycle_lock.acquire_async() diff --git a/nebula/core/models/mnist/cnn.py b/nebula/core/models/mnist/cnn.py index 28fb96cc3..7cec6b6c3 100755 --- a/nebula/core/models/mnist/cnn.py +++ b/nebula/core/models/mnist/cnn.py @@ -52,4 +52,5 @@ def configure_optimizers(self): betas=(self.config["beta1"], self.config["beta2"]), amsgrad=self.config["amsgrad"], ) + self._optimizer = optimizer return optimizer diff --git a/nebula/core/models/mnist/mlp.py b/nebula/core/models/mnist/mlp.py index 5d4e9ffa8..3e636c847 100755 --- a/nebula/core/models/mnist/mlp.py +++ b/nebula/core/models/mnist/mlp.py @@ -34,4 +34,5 @@ def forward(self, x): def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + self._optimizer = optimizer return optimizer diff --git a/nebula/core/models/nebulamodel.py b/nebula/core/models/nebulamodel.py index 84325c3f1..8495e660e 100755 --- a/nebula/core/models/nebulamodel.py +++ b/nebula/core/models/nebulamodel.py @@ -197,7 +197,8 @@ def __init__( # Model parameters are sent by default using network.propagator self.communication_manager = None - self.current_loss = -1 + self._current_loss = -1 + self._optimizer = None def set_communication_manager(self, communication_manager): self.communication_manager = communication_manager @@ -224,11 +225,15 @@ def step(self, batch, batch_idx, phase): loss = self.criterion(y_pred, y) self.process_metrics(phase, y_pred, y, loss) - self.current_loss=loss + self._current_loss=loss return loss def get_loss(self): - return 
self.current_loss + return self._current_loss + + def modify_learning_rate(self, new_lr): + for param_group in self._optimizer.param_groups: + param_group['lr'] = new_lr def set_updated_round(self, round): self.round = round diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 58a4182e8..9ac7a48ba 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -58,6 +58,9 @@ def __init__( self.synchronizing_rounds = False + self._fast_reboot = False + self._learning_rate=1e-3 + #self.set_confings() @property @@ -80,6 +83,12 @@ def model_handler(self): def timer_generator(self): return self._timer_generator + def get_learning_rate_increase(self): + return self._learning_rate + + def fast_reboot_on(self): + return self._fast_reboot + def get_push_acceleration(self): return self._push_acceleration diff --git a/nebula/core/training/lightning.py b/nebula/core/training/lightning.py index 83232797b..6266b5b7f 100755 --- a/nebula/core/training/lightning.py +++ b/nebula/core/training/lightning.py @@ -353,3 +353,6 @@ def on_round_end(self): def on_learning_cycle_end(self): self._logger.log_data({"A-Round": self.round}) # self.reporter.enqueue_data("Round", self.round) + + def update_model_learning_rate(self, new_lr): + self.model.modify_learning_rate(new_lr) From 207783296c8f16d1c0d0a46213bbcecbc0175e9e Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 2 Dec 2024 14:40:05 +0100 Subject: [PATCH 022/233] feat_upgrading_network_robustness mechanisms itnegrated: .-reconnect_to_federation .-upgrade_connection_robustness --- nebula/core/engine.py | 28 ++--- .../neighborpolicies/fcneighborpolicy.py | 2 +- nebula/core/neighbormanagement/nodemanager.py | 112 ++++++++++-------- nebula/core/network/communications.py | 27 +++-- 4 files changed, 93 insertions(+), 76 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 78d4353f0..d1fe2f32e 100644 
--- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -389,19 +389,19 @@ async def _connection_late_connect_callback(self, source, message): if len(ct_actions): for addr in ct_actions.split(): cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, addr) - #await self.cm.send_message(source, cnt_msg) + await self.cm.send_message(source, cnt_msg) if len(df_actions): for addr in df_actions.split(): df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, addr) - #await self.cm.send_message(source, df_msg) + await self.cm.send_message(source, df_msg) await self.cm.connect(source, direct=True) self.nm.meet_node(source) self.nm.update_neighbors(source) await self.update_model_learning_rate() else: - logging.info(f"πŸ”— Late connection NOT accepted | source: {source}") + logging.info(f"❗️ Late connection NOT accepted | source: {source}") @event_handler( nebula_pb2.ConnectionMessage, @@ -416,17 +416,15 @@ async def _connection_restructure_callback(self, source, message): if len(ct_actions): for addr in ct_actions.split(): cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, addr) - pass - #await self.cm.send_message(source, cnt_msg) + await self.cm.send_message(source, cnt_msg) if len(df_actions): for addr in df_actions.split(): df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, addr) - #await self.cm.send_message(source, df_msg) + await self.cm.send_message(source, df_msg) else: - logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection denied from {source}") - await self.cm.disconnect(source, mutual_disconnection=False) - self.nm.update_neighbors(source, remove=True) + logging.info(f"❗️ handle_connection_message | Trigger | restructure connection denied from {source}") + await self.cm.disconnect(source, mutual_disconnection=False) @event_handler(nebula_pb2.DiscoverMessage, 
nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN) async def _discover_discover_join_callback(self, source, message): @@ -504,10 +502,10 @@ async def _offer_offer_metric_callback(self, source, message): async def _link_connect_to_callback(self, source, message): logging.info(f"πŸ”— handle_link_message | Trigger | Received connecto_to message from {source}") addrs = message.arguments - for addr in addrs: - await self.cm.connect(addr, direct=True) - self.nm.update_neighbors(addr) - self.nm.meet_node(source) + for addr in addrs.split(): + #await self.cm.connect(addr, direct=True) + #self.nm.update_neighbors(addr) + self.nm.meet_node(addr) @event_handler( nebula_pb2.LinkMessage, @@ -516,7 +514,7 @@ async def _link_connect_to_callback(self, source, message): async def _link_disconnect_from_callback(self, source, message): logging.info(f"πŸ”— handle_link_message | Trigger | Received disconnect_from message from {source}") addrs = message.arguments - for addr in addrs: + for addr in addrs.split(): await self.cm.disconnect(source, mutual_disconnection=False) self.nm.update_neighbors(addr, remove=True) @@ -817,7 +815,7 @@ async def _additional_mobility_actions(self): if not self.mobility: return logging.info("πŸ”„ Starting additional mobility actions...") - #self.nm.update_weight_modifiers() + self.nm.check_robustness() def reputation_calculation(self, aggregated_models_weights): cossim_threshold = 0.5 diff --git a/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py b/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py index 1940abf9a..b56dedd47 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py +++ b/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py @@ -77,7 +77,7 @@ def get_actions(self): def _disconnect_from(self): - return [] + return "" def _connect_to(self): ct = "" diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 
9ac7a48ba..77ab21d1b 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -143,12 +143,7 @@ async def set_confings(self): ) #self.engine.trainer.get_loss(), self.config.participant["molibity_args"]["weight_distance"], self.config.participant["molibity_args"]["weight_het"] #self.model_handler.set_config([self.engine.get_round(), self.engine.config.participant["training_args"]["epochs"]]) - - def add_to_discarded_offers(self, addr_discarded): - self.discarded_offers_addr_lock.acquire() - self.discarded_offers_addr.append(addr_discarded) - self.discarded_offers_addr_lock.release() - + def get_timer(self): return self.timer_generator.get_timer(self.engine.get_round()) @@ -157,10 +152,12 @@ def adjust_timer(self): def get_stop_condition(self): return self.timer_generator.get_stop_condition() - - async def receive_update_from_node(self, node_id, node_response_time): - await self.timer_generator.receive_update(node_id, node_response_time) - + + + ############################## + # WEIGHT STRATEGY # + ############################## + def add_weight_modifier(self, addr): self.weight_modifier_lock.acquire() if not addr in self.weight_modifier: @@ -208,6 +205,12 @@ def _get_weight_modifier(self, addr): wm = self.weight_modifier.get(addr, (1,0)) self.weight_modifier_lock.release() return wm + + + ############################## + # CONNECTIONS # + ############################## + def accept_connection(self, source, joining=False): if not joining: @@ -216,6 +219,14 @@ def accept_connection(self, source, joining=False): return False else: return self.neighbor_policy.accept_connection(source) + + async def receive_update_from_node(self, node_id, node_response_time): + await self.timer_generator.receive_update(node_id, node_response_time) + + def add_to_discarded_offers(self, addr_discarded): + self.discarded_offers_addr_lock.acquire() + self.discarded_offers_addr.append(addr_discarded) + 
self.discarded_offers_addr_lock.release() def need_more_neighbors(self): return self.neighbor_policy.need_more_neighbors() @@ -269,17 +280,15 @@ async def stop_not_selected_connections(self): except asyncio.CancelledError as e: pass - async def start_late_connection_process(self): + #TODO NOT infinite loop, define n_tries + async def start_late_connection_process(self, connected=False, msg_type="discover_join", addrs_known=None): """ This function represents the process of discovering the federation and stablish the first - connections with it. The first step is to send the DISCOVER_JOIN message to look for nodes, - the ones that receive that message will send back a OFFER_MODEL message. It contains info to do + connections with it. The first step is to send the DISCOVER_JOIN/NODES message to look for nodes, + the ones that receive that message will send back a OFFER_MODEL/METRIC message. It contains info to do a selection process among candidates to later on connect to the best ones. The process will repeat until at least one candidate is found and the process will be locked to avoid concurrency. 
- - Returns: - data neccesary to create trainer """ logging.info("🌐 Initializing late connection process..") @@ -288,7 +297,7 @@ async def start_late_connection_process(self): self.candidate_selector.remove_candidates() # find federation and send discover - await self.engine.cm.establish_connection_with_federation() + await self.engine.cm.stablish_connection_to_federation(msg_type, addrs_known) # wait offer logging.info(f"Waiting: {self.recieve_offer_timer}s to receive offers from federation") @@ -299,8 +308,12 @@ async def start_late_connection_process(self): if self.candidate_selector.any_candidate(): logging.info("Candidates found to connect to...") - # create message to send to new neightbors - msg = self.engine.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.LATE_CONNECT) + # create message to send to candidates selected + if not connected: + msg = self.engine.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.LATE_CONNECT) + else: + msg = self.engine.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.RESTRUCTURE) + best_candidates = self.candidate_selector.select_candidates() logging.info(f"Candidates | {[addr for addr,_,_ in best_candidates]}") # candidates not choosen --> disconnect @@ -311,54 +324,51 @@ async def start_late_connection_process(self): self.accept_candidates_lock.release() self.late_connection_process_lock.release() - + self.candidate_selector.remove_candidates() # if no candidates, repeat process else: logging.info("No Candidates found | repeating process") self.accept_candidates_lock.release() self.late_connection_process_lock.release() - await self.start_late_connection_process() + await self.start_late_connection_process(connected, msg_type, addrs_known) + ############################## + # ROBUSTNESS # + ############################## - """ - Retopology in progress - """ async def check_robustness(self): - logging.info("Analizing node network robustness...") + logging.info("πŸ”„ 
Analizing node network robustness...") if len(self.engine.get_federation_nodes()) == 0: logging.info("No Neighbors left | reconnecting with Federation") + #await self.reconnect_to_federation() elif self.neighbor_policy.need_more_neighbors(): - logging.info("Insufficient Robustness | searching for more connections") + logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") + #asyncio.create_task(self.upgrade_connection_robustness()) else: logging.info("Sufficient Robustness | no actions required") - - async def find_new_connections(self): - logging.info("🌐 Initializing restructure process from Node Manager") + async def reconnect_to_federation(self): + # If we got some refs, try to reconnect to them self._restructure_process_lock.acquire() - # Update the config params of candidate_selector - self.candidate_selector.set_config([self.engine.get_loss(), self.engine.weight_distance, self.engine.weight_het]) - self.thread = threading.Thread(target=self._find_connections_thread, args=(self)) - self.thread.start() - self.restructure = True - while self.restructure: - await asyncio.sleep(1) + if self.neighbor_policy.get_nodes_known() > 0: + logging.info("Reconnecting | Addrs availables") + await self.start_late_connection_process(connected=False, msg_type="discover_nodes", addrs_known=self.neighbor_policy.get_nodes_known()) + # Otherwise stablish connection to federation sending discover nodes instead of join + else: + logging.info("Reconnecting | NO Addrs availables") + await self.start_late_connection_process(connected=False, msg_type="discover_nodes") self._restructure_process_lock.release() - - async def _find_connections_thread(self): - posible_connections = self.get_nodes_known(neighbors_too=False) - while self.restructure: - # out of federation but got info about nodes inside - if len(posible_connections) > 0: - msg = self.engine.cm.mm.generate_discover_message(nebula_pb2.DiscoverMessage.Action.DISCOVER_NODE) - for addr in 
posible_connections: - # send message to known nodes, wait for response and select - pass - # im out of federation without info about any nodes inside of it - else: - await self.start_late_connection_process() - - self.restructure = self.need_more_neighbors() \ No newline at end of file + async def upgrade_connection_robustness(self): + self._restructure_process_lock.acquire() + addrs_to_connect = self.neighbor_policy.get_nodes_known(neighbors_too=False) + # If we got some refs, try to connect to them + if len(addrs_to_connect) > 0: + await self.start_late_connection_process(connected=True, msg_type="discover_nodes", addrs_known=addrs_to_connect) + else: + await self.start_late_connection_process(connected=True, msg_type="discover_nodes") + self._restructure_process_lock.release() + + \ No newline at end of file diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 8facd7e2b..bd4c35af8 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -388,18 +388,27 @@ def stop_external_connection_service(self): self.ecs.stop() def init_external_connection_service(self): - self.ecs = NebulaConnectionService(self.addr) self.start_external_connection_service() - async def establish_connection_with_federation(self): + async def stablish_connection_to_federation(self, msg_type="discover_join", addrs_known=None): """ Using ExternalConnectionService to get addrs on local network, after that stablishment of TCP connection and send the message broadcasted """ - logging.info("Searching federation process beginning..") - addrs = self.ecs.find_federation() - logging.info(f"Found federation devices | addrs {addrs}") - msg = self.mm.generate_discover_message(nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN) + addrs = [] + if addrs_known == None: + logging.info("Searching federation process beginning...") + addrs = self.ecs.find_federation() + logging.info(f"Found federation devices | addrs {addrs}") + 
else: + logging.info("Searching federation process beginning... | Using addrs previously known") + addrs = addrs_known + + if msg_type=="discover_join": + msg = self.mm.generate_discover_message(nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN) + elif msg_type=="discover_nodes": + msg = self.mm.generate_discover_message(nebula_pb2.DiscoverMessage.Action.DISCOVER_NODES) + logging.info("Starting communications with devices found") for addr in addrs: await self.connect(addr, direct=False) @@ -409,10 +418,10 @@ async def establish_connection_with_federation(self): current_connections = await self.get_addrs_current_connections() logging.info(f"Connections verified after searching: {current_connections}") for addr in addrs: - logging.info(f"Sending discover join to --> {addr}") + logging.info(f"Sending {msg_type} to ---> {addr}") asyncio.create_task(self.send_message(addr, msg)) - await asyncio.sleep(1) - + await asyncio.sleep(1) + def get_connections_lock(self): return self.connections_lock From e86db89cb150dea41a5616f9b5de17907e99b8b8 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 3 Dec 2024 17:40:36 +0100 Subject: [PATCH 023/233] fix_ecs_run_shutdown --- nebula/core/engine.py | 11 ++++++++- nebula/core/neighbormanagement/nodemanager.py | 23 +++++++++++++++---- nebula/core/network/communications.py | 8 +++++-- .../core/network/externalconnectionservice.py | 4 ++++ nebula/core/network/nebulamulticasting.py | 6 +++++ 5 files changed, 44 insertions(+), 8 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index d1fe2f32e..d36465a54 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -317,7 +317,11 @@ async def _connection_connect_callback(self, source, message): @event_handler(nebula_pb2.ConnectionMessage, nebula_pb2.ConnectionMessage.Action.DISCONNECT) async def _connection_disconnect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received disconnection message from {source}") + if 
self.mobility: + if source in await self.cm.get_all_addrs_current_connections(only_direct=True): + self.nm.update_neighbors(source, remove=True) await self.cm.disconnect(source, mutual_disconnection=False) + @event_handler( nebula_pb2.FederationMessage, @@ -528,6 +532,8 @@ async def _aditional_node_start(self): asyncio.create_task(self._start_learning_late()) #decoded_model = self.trainer.deserialize_model(message.parameters) + + def get_push_acceleration(self): return self.nm.get_push_acceleration() @@ -815,7 +821,10 @@ async def _additional_mobility_actions(self): if not self.mobility: return logging.info("πŸ”„ Starting additional mobility actions...") - self.nm.check_robustness() + await self.nm.check_robustness() + action = await self.nm.check_external_connection_service_status() + if action: + action() def reputation_calculation(self, aggregated_models_weights): cossim_threshold = 0.5 diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 77ab21d1b..6a759ea3b 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -242,8 +242,8 @@ def update_neighbors(self, node, remove=False): if not remove: self.neighbor_policy.meet_node(node) - def no_neighbors_left(self): - return len(self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False)) + async def neighbors_left(self): + return len(await self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False)) > 0 def meet_node(self, node): logging.info(f"Update nodes known | addr: {node}") @@ -280,6 +280,17 @@ async def stop_not_selected_connections(self): except asyncio.CancelledError as e: pass + async def check_external_connection_service_status(self): + action = None + logging.info(f"πŸ”„ Checking external connection service status...") + if not self.neighbors_left() and self.engine.cm.is_external_connection_service_running(): + logging.info(f"❗️ Isolated node | Shutdowning 
service required") + action = lambda: self.engine.cm.stop_external_connection_service() + elif self.neighbors_left() and not self.engine.cm.is_external_connection_service_running() and self.engine.get_sinchronized_status(): + logging.info(f"πŸ”„ NOT isolated node | Service not running | Starting service...") + action = lambda: self.engine.cm.init_external_connection_service() + return action + #TODO NOT infinite loop, define n_tries async def start_late_connection_process(self, connected=False, msg_type="discover_join", addrs_known=None): """ @@ -327,10 +338,12 @@ async def start_late_connection_process(self, connected=False, msg_type="discove self.candidate_selector.remove_candidates() # if no candidates, repeat process else: - logging.info("No Candidates found | repeating process") + logging.info("❗️ No Candidates found | repeating process") self.accept_candidates_lock.release() self.late_connection_process_lock.release() - await self.start_late_connection_process(connected, msg_type, addrs_known) + if not connected: + await self.start_late_connection_process(connected, msg_type, addrs_known) + ############################## @@ -340,7 +353,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove async def check_robustness(self): logging.info("πŸ”„ Analizing node network robustness...") - if len(self.engine.get_federation_nodes()) == 0: + if self.no_neighbors_left(): logging.info("No Neighbors left | reconnecting with Federation") #await self.reconnect_to_federation() elif self.neighbor_policy.need_more_neighbors(): diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index bd4c35af8..e81fb259e 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -381,14 +381,18 @@ async def handle_link_message(self, source, message): logging.error(f"πŸ” handle_link_message | Error while processing: {message.action} {message.arguments} | {e}") def 
start_external_connection_service(self): - self.ecs = NebulaConnectionService(self.addr) + if self.ecs == None: + self.ecs = NebulaConnectionService(self.addr) self.ecs.start() def stop_external_connection_service(self): self.ecs.stop() def init_external_connection_service(self): - self.start_external_connection_service() + self.start_external_connection_service() + + async def is_external_connection_service_running(self): + return self.ecs.is_running() async def stablish_connection_to_federation(self, msg_type="discover_join", addrs_known=None): """ diff --git a/nebula/core/network/externalconnectionservice.py b/nebula/core/network/externalconnectionservice.py index 90b8d4ff1..d4bc00834 100644 --- a/nebula/core/network/externalconnectionservice.py +++ b/nebula/core/network/externalconnectionservice.py @@ -10,6 +10,10 @@ def start(self): def stop(self): pass + @abstractmethod + def is_running(self): + pass + @abstractmethod def find_federation(self): pass \ No newline at end of file diff --git a/nebula/core/network/nebulamulticasting.py b/nebula/core/network/nebulamulticasting.py index 0d4759e32..2bdbae680 100644 --- a/nebula/core/network/nebulamulticasting.py +++ b/nebula/core/network/nebulamulticasting.py @@ -29,6 +29,9 @@ def stop(self): self.interrupted = True logging.info("Nebula upnp server stop") + def is_running(self): + return not self.interrupted + def listen(self): """ Listen on broadcast addr with standard 1900 port @@ -163,6 +166,9 @@ def start(self): def stop(self): self.server.stop + + def is_running(self): + return self.server.is_running() def find_federation(self): """ From 11f37d8de747b6cc164f9ff71b6f44b88819d399 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 4 Dec 2024 09:54:42 +0100 Subject: [PATCH 024/233] fix_info_points --- nebula/core/engine.py | 1 + nebula/core/models/nebulamodel.py | 5 +++++ nebula/core/neighbormanagement/nodemanager.py | 2 ++ nebula/core/training/lightning.py | 3 +++ 4 files changed, 11 insertions(+) diff --git 
a/nebula/core/engine.py b/nebula/core/engine.py index d36465a54..9ed27531e 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -821,6 +821,7 @@ async def _additional_mobility_actions(self): if not self.mobility: return logging.info("πŸ”„ Starting additional mobility actions...") + self.trainer.show_current_learning_rate() await self.nm.check_robustness() action = await self.nm.check_external_connection_service_status() if action: diff --git a/nebula/core/models/nebulamodel.py b/nebula/core/models/nebulamodel.py index 8495e660e..1066b2e7c 100755 --- a/nebula/core/models/nebulamodel.py +++ b/nebula/core/models/nebulamodel.py @@ -232,9 +232,14 @@ def get_loss(self): return self._current_loss def modify_learning_rate(self, new_lr): + logging.info(f"Modifiying | learning rate, new value: {new_lr}") for param_group in self._optimizer.param_groups: param_group['lr'] = new_lr + def show_current_learning_rate(self): + for param_group in self._optimizer.param_groups: + logging.info(f"Showing | Learning rate current value: {param_group['lr']}") + def set_updated_round(self, round): self.round = round self.global_number = { diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 6a759ea3b..0de6276fc 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -379,8 +379,10 @@ async def upgrade_connection_robustness(self): addrs_to_connect = self.neighbor_policy.get_nodes_known(neighbors_too=False) # If we got some refs, try to connect to them if len(addrs_to_connect) > 0: + logging.info("Reestructuring | Addrs availables") await self.start_late_connection_process(connected=True, msg_type="discover_nodes", addrs_known=addrs_to_connect) else: + logging.info("Reestructuring | NO Addrs availables") await self.start_late_connection_process(connected=True, msg_type="discover_nodes") self._restructure_process_lock.release() diff --git 
a/nebula/core/training/lightning.py b/nebula/core/training/lightning.py index 6266b5b7f..5c5d16f53 100755 --- a/nebula/core/training/lightning.py +++ b/nebula/core/training/lightning.py @@ -356,3 +356,6 @@ def on_learning_cycle_end(self): def update_model_learning_rate(self, new_lr): self.model.modify_learning_rate(new_lr) + + def show_current_learning_rate(self): + self.model.show_current_learning_rate() From 3e8ed56ce5e059768831946bdbcf037a1c0b27a2 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 5 Dec 2024 17:08:46 +0100 Subject: [PATCH 025/233] feat_nebulamulticasting_on_off -fixed learnign rate error -fixed push sync error -fix sync errors --- nebula/core/aggregation/aggregator.py | 17 +++-- nebula/core/engine.py | 63 +++++++++++-------- nebula/core/models/nebulamodel.py | 1 + .../hetcandidateselector.py | 1 + nebula/core/neighbormanagement/nodemanager.py | 62 +++++++++++------- nebula/core/network/communications.py | 3 +- nebula/core/network/nebulamulticasting.py | 5 +- 7 files changed, 94 insertions(+), 58 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index e0812021e..1682dbc65 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -236,7 +236,7 @@ async def get_aggregation(self): else: logging.info("πŸ”„ get_aggregation | All models accounted for, proceeding with aggregation.") - self._pending_models_to_aggregate = self.engine.apply_weight_strategy(self._pending_models_to_aggregate) + self._pending_models_to_aggregate = await self.engine.apply_weight_strategy(self._pending_models_to_aggregate) aggregated_result = self.run_aggregation(self._pending_models_to_aggregate) if not self.engine.get_sinchronized_status() and self.engine.get_push_acceleration() == "fast": await self._add_model_lock.release_async() @@ -288,13 +288,15 @@ async def aggregation_push_available(self): if push == "slow": logging.info(f"❗️ FUTURE round: {further_round} is available | 
PUSH strategy ON") logging.info("❗️ SLOW push selected | Start PUSHING slow") - self.engine.set_pushed_done(self.engine.get_round() - further_round) + self.engine.set_pushed_done(further_round - self.engine.get_round()) # we wait until learning cycle reach aggregation point while not self._aggregation_done_lock.locked_async(): logging.info("πŸ”„ Waiting | aggregation step not reached yet...") await asyncio.sleep(1) # Unlock aggregation logging.info("πŸ”„ Releasing aggregation lock...") + self.engine.update_sinchronized_status(False) + self.engine.set_synchronizing_rounds(True) await self._aggregation_done_lock.release_async() return @@ -304,13 +306,15 @@ async def aggregation_push_available(self): if further_round == (self.engine.get_round()+1): logging.info(f"πŸ”„ Rounds jumped: {1}...") - self.engine.set_pushed_done(self.engine.get_round() - further_round) + self.engine.set_pushed_done(further_round - self.engine.get_round()) # we wait until learning cycle reach aggregation point while not self._aggregation_done_lock.locked_async(): logging.info("πŸ”„ Waiting | aggregation step not reached yet...") await asyncio.sleep(1) # Unlock aggregation logging.info("πŸ”„ Releasing aggregation lock...") + self.engine.update_sinchronized_status(False) + self.engine.set_synchronizing_rounds(True) await self._aggregation_done_lock.release_async() return @@ -338,8 +342,10 @@ async def aggregation_push_available(self): for key in self._future_models_to_aggregate.keys(): if key <= further_round: del self._future_models_to_aggregate[key] - - self.engine.set_pushed_done(self.engine.get_round() - further_round) + + self.engine.update_sinchronized_status(False) + self.engine.set_synchronizing_rounds(True) + self.engine.set_pushed_done(further_round - self.engine.get_round()) self.engine.set_round(further_round) # Unlock aggregation @@ -351,6 +357,7 @@ async def aggregation_push_available(self): else: self.engine.update_sinchronized_status(True) + 
self.engine.set_synchronizing_rounds(False) else: logging.info(f"All models updates are received | models number: {len(self.get_nodes_pending_models_to_aggregate())}") else: diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 9ed27531e..d0dd44894 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -260,10 +260,13 @@ def get_synchronizing_rounds(self): def update_sinchronized_status(self, status): with self.sinchronized_status_lock: - if self.mobility: - self.nm.set_synchronizing_rounds(status) + logging.info(f"Update | synchronized status from: {self._sinchronized_status} to {status}") self._sinchronized_status = status + def set_synchronizing_rounds(self, status): + if self.mobility: + self.nm.set_synchronizing_rounds(not status) + def set_round(self, new_round): logging.info(f"πŸ€– Update round count | from: {self.round} | to round: {new_round}") self.round = new_round @@ -386,19 +389,19 @@ async def _federation_models_included_callback(self, source, message): async def _connection_late_connect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received late connect message from {source}") if self.nm.accept_connection(source, joining=True): - logging.info(f"πŸ”— Late connection accepted | source: {source}") + logging.info(f"πŸ”— handle_connection_message | Late connection accepted | source: {source}") self.nm.add_weight_modifier(source) ct_actions , df_actions = self.nm.get_actions() if len(ct_actions): - for addr in ct_actions.split(): - cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, addr) - await self.cm.send_message(source, cnt_msg) + #for addr in ct_actions.split(): + cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, ct_actions) + await self.cm.send_message(source, cnt_msg) if len(df_actions): - for addr in df_actions.split(): - df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, addr) 
- await self.cm.send_message(source, df_msg) + #for addr in df_actions.split(): + df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, df_actions) + await self.cm.send_message(source, df_msg) await self.cm.connect(source, direct=True) self.nm.meet_node(source) @@ -418,17 +421,19 @@ async def _connection_restructure_callback(self, source, message): ct_actions , df_actions = self.nm.get_actions() if len(ct_actions): - for addr in ct_actions.split(): - cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, addr) - await self.cm.send_message(source, cnt_msg) + cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, ct_actions) + await self.cm.send_message(source, cnt_msg) if len(df_actions): - for addr in df_actions.split(): - df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, addr) - await self.cm.send_message(source, df_msg) + df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, df_actions) + await self.cm.send_message(source, df_msg) + + await self.cm.connect(source, direct=True) + self.nm.meet_node(source) + self.nm.update_neighbors(source) else: logging.info(f"❗️ handle_connection_message | Trigger | restructure connection denied from {source}") - await self.cm.disconnect(source, mutual_disconnection=False) + await self.cm.disconnect(source, mutual_disconnection=True) @event_handler(nebula_pb2.DiscoverMessage, nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN) async def _discover_discover_join_callback(self, source, message): @@ -438,7 +443,6 @@ async def _discover_discover_join_callback(self, source, message): await self.trainning_in_progress_lock.acquire_async() model, rounds, round = await self.cm.propagator.get_model_information(source, "stable") if self.get_round() > 0 else await self.cm.propagator.get_model_information(source, "initialization") await 
self.trainning_in_progress_lock.release_async() - #model, rounds, round = await self.cm.propagator.get_model_information(source, "initialization") if round != -1: epochs = self.config.participant["training_args"]["epochs"] msg = self.cm.mm.generate_offer_message( @@ -464,8 +468,11 @@ async def _discover_discover_join_callback(self, source, message): async def _discover_discover_nodes_callback(self, source, message): logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_node message from {source} ") self.nm.meet_node(source) - msg = self.cm.mm.generate_offer_message(nebula_pb2.OfferMessage.Action.OFFER_METRIC, len(self.get_federation_nodes()), self.trainer.get_current_loss()) - await self.cm.send_message(source, msg) + if len(self.get_federation_nodes()) > 0: + msg = self.cm.mm.generate_offer_message(nebula_pb2.OfferMessage.Action.OFFER_METRIC, len(self.get_federation_nodes()), self.trainer.get_current_loss()) + await self.cm.send_message(source, msg) + else: + logging.info(f"πŸ”— Dissmissing discover nodes from {source} | no active connections at the moment") @event_handler( nebula_pb2.OfferMessage, @@ -474,7 +481,7 @@ async def _discover_discover_nodes_callback(self, source, message): async def _offer_offer_model_callback(self, source, message): logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_model message from {source}") self.nm.meet_node(source) - if not self.nm.get_restructure_process_lock().locked() and self.nm.still_waiting_for_candidates(): + if self.nm.still_waiting_for_candidates(): try: model_compressed = message.parameters if self.nm.accept_model_offer(source, model_compressed, message.rounds, message.round, message.epochs, message.n_neighbors, message.loss): @@ -494,8 +501,9 @@ async def _offer_offer_model_callback(self, source, message): ) async def _offer_offer_metric_callback(self, source, message): logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_metric message from {source}") - if not 
self.nm.get_restructure_process_lock().locked(): - n_neighbors, loss, _, _, _, _ = message.arguments + if self.nm.still_waiting_for_candidates(): + n_neighbors = message.n_neighbors + loss = message.loss self.nm.add_candidate(source, n_neighbors, loss) self.nm.meet_node(source) @@ -505,7 +513,7 @@ async def _offer_offer_metric_callback(self, source, message): ) async def _link_connect_to_callback(self, source, message): logging.info(f"πŸ”— handle_link_message | Trigger | Received connecto_to message from {source}") - addrs = message.arguments + addrs = message.addrs for addr in addrs.split(): #await self.cm.connect(addr, direct=True) #self.nm.update_neighbors(addr) @@ -517,7 +525,7 @@ async def _link_connect_to_callback(self, source, message): ) async def _link_disconnect_from_callback(self, source, message): logging.info(f"πŸ”— handle_link_message | Trigger | Received disconnect_from message from {source}") - addrs = message.arguments + addrs = message.addrs for addr in addrs.split(): await self.cm.disconnect(source, mutual_disconnection=False) self.nm.update_neighbors(addr, remove=True) @@ -540,9 +548,9 @@ def get_push_acceleration(self): def set_pushed_done(self, rounds_push): self.nm.set_rounds_pushed(rounds_push) - def apply_weight_strategy(self, pending_models): + async def apply_weight_strategy(self, pending_models): if self.mobility and self.nm.fast_reboot_on(): - self.nm.apply_weight_strategy(pending_models) + await self.nm.apply_weight_strategy(pending_models) return pending_models else: return pending_models @@ -736,7 +744,8 @@ async def _dynamic_aggregator(self, aggregated_models_weights, malicious_nodes): async def _waiting_model_updates(self): logging.info(f"πŸ’€ Waiting convergence in round {self.round}.") - await self.aggregator.aggregation_push_available() + if self.mobility: + await self.aggregator.aggregation_push_available() params = await self.aggregator.get_aggregation() if params is not None: logging.info( diff --git 
a/nebula/core/models/nebulamodel.py b/nebula/core/models/nebulamodel.py index 1066b2e7c..8e6eb6cfc 100755 --- a/nebula/core/models/nebulamodel.py +++ b/nebula/core/models/nebulamodel.py @@ -233,6 +233,7 @@ def get_loss(self): def modify_learning_rate(self, new_lr): logging.info(f"Modifiying | learning rate, new value: {new_lr}") + self.learning_rate = new_lr for param_group in self._optimizer.param_groups: param_group['lr'] = new_lr diff --git a/nebula/core/neighbormanagement/candidateselection/hetcandidateselector.py b/nebula/core/neighbormanagement/candidateselection/hetcandidateselector.py index 61a302007..20ec939a1 100644 --- a/nebula/core/neighbormanagement/candidateselection/hetcandidateselector.py +++ b/nebula/core/neighbormanagement/candidateselection/hetcandidateselector.py @@ -55,6 +55,7 @@ def any_candidate(self): self.candidates_lock.release() return any + #TODO hay q descontar los vecinos propios ya establecidos def __calculate_ideal_neighbors(self): """ Returns: diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 0de6276fc..8df8d35ea 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -58,8 +58,8 @@ def __init__( self.synchronizing_rounds = False - self._fast_reboot = False - self._learning_rate=1e-3 + self._fast_reboot = True + self._learning_rate=2e-3 #self.set_confings() @@ -173,25 +173,26 @@ def remove_weight_modifier(self, addr): del self.weight_modifier[addr] self.weight_modifier_lock.release() - def apply_weight_strategy(self, updates): - logging.info(f"πŸ”„ Applying weight Strategy...") + async def apply_weight_strategy(self, updates: dict): + logging.info(f"πŸ”„ Applying weight Strategy...") # We must lower the weight_modifier value if a round jump has been occured # as many times as rounds have been jumped if self.rounds_pushed: + logging.info(f"πŸ”„ There are rounds being pushed...") for i in range(0, 
self.rounds_pushed): self._update_weight_modifiers() - self.rounds_pushed = 0 + self.rounds_pushed = 0 for addr,update in updates.items(): - weight_modifier, _ = self._get_weight_modifier(addr) - if weight_modifier != 1: - logging.info (f"πŸ“ Appliying modified weight strategy | addr: {addr} | multiplier value: {weight_modifier}") + weightmodifier, rounds = self._get_weight_modifier(addr) + if weightmodifier != 1: + logging.info (f"πŸ“ Appliying modified weight strategy | addr: {addr} | multiplier value: {weightmodifier}") model, weight = update - updates.update({addr: (model, weight*weight_modifier)}) + updates.update({addr: (model, weight*weightmodifier)}) self._update_weight_modifiers() def _update_weight_modifiers(self): self.weight_modifier_lock.acquire() - for addr,(weight,rounds) in self.weight_modifier.items(): + for addr, (weight,rounds) in self.weight_modifier.items(): new_weight = weight - 1/(rounds**2) rounds = rounds + 1 if new_weight > 1: @@ -202,7 +203,7 @@ def _update_weight_modifiers(self): def _get_weight_modifier(self, addr): self.weight_modifier_lock.acquire() - wm = self.weight_modifier.get(addr, (1,0)) + wm = self.weight_modifier.get(addr, (1,0)) self.weight_modifier_lock.release() return wm @@ -235,6 +236,7 @@ def get_actions(self): return self.neighbor_policy.get_actions() def update_neighbors(self, node, remove=False): + logging.info(f"Update neighbor | node addr: {node} | remove: {remove}") self.neighbor_policy.update_neighbors(node, remove) #self.timer_generator.update_node(node, remove) if remove: @@ -280,13 +282,17 @@ async def stop_not_selected_connections(self): except asyncio.CancelledError as e: pass - async def check_external_connection_service_status(self): - action = None + async def check_external_connection_service_status(self): logging.info(f"πŸ”„ Checking external connection service status...") - if not self.neighbors_left() and self.engine.cm.is_external_connection_service_running(): + n = await self.neighbors_left() + ecs 
= await self.engine.cm.is_external_connection_service_running() + ss = self.engine.get_sinchronized_status() + action = None + logging.info(f"Stats | neighbos: {n} | service running: {ecs} | synchronized status: {ss}") + if not await self.neighbors_left() and await self.engine.cm.is_external_connection_service_running(): logging.info(f"❗️ Isolated node | Shutdowning service required") action = lambda: self.engine.cm.stop_external_connection_service() - elif self.neighbors_left() and not self.engine.cm.is_external_connection_service_running() and self.engine.get_sinchronized_status(): + elif await self.neighbors_left() and not await self.engine.cm.is_external_connection_service_running() and self.engine.get_sinchronized_status(): logging.info(f"πŸ”„ NOT isolated node | Service not running | Starting service...") action = lambda: self.engine.cm.init_external_connection_service() return action @@ -331,6 +337,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove for addr, _, _ in best_candidates: await self.engine.cm.connect(addr, direct=True) await self.engine.cm.send_message(addr, msg) + self.update_neighbors(addr) await asyncio.sleep(1) self.accept_candidates_lock.release() @@ -338,10 +345,11 @@ async def start_late_connection_process(self, connected=False, msg_type="discove self.candidate_selector.remove_candidates() # if no candidates, repeat process else: - logging.info("❗️ No Candidates found | repeating process") + logging.info("❗️ No Candidates found...") self.accept_candidates_lock.release() self.late_connection_process_lock.release() if not connected: + logging.info("❗️ repeating process...") await self.start_late_connection_process(connected, msg_type, addrs_known) @@ -353,14 +361,20 @@ async def start_late_connection_process(self, connected=False, msg_type="discove async def check_robustness(self): logging.info("πŸ”„ Analizing node network robustness...") - if self.no_neighbors_left(): - logging.info("No Neighbors left | 
reconnecting with Federation") - #await self.reconnect_to_federation() - elif self.neighbor_policy.need_more_neighbors(): - logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") - #asyncio.create_task(self.upgrade_connection_robustness()) + if not self._restructure_process_lock.locked(): + if not self.neighbors_left(): + logging.info("No Neighbors left | reconnecting with Federation") + #await self.reconnect_to_federation() + elif self.neighbor_policy.need_more_neighbors() and self.engine.get_sinchronized_status(): + logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") + asyncio.create_task(self.upgrade_connection_robustness()) + else: + if self.engine.get_sinchronized_status(): + logging.info("Device not synchronized with federation") + else: + logging.info("Sufficient Robustness | no actions required") else: - logging.info("Sufficient Robustness | no actions required") + logging.info("❗️ Reestructure/Reconnecting process already running...") async def reconnect_to_federation(self): # If we got some refs, try to reconnect to them @@ -379,7 +393,7 @@ async def upgrade_connection_robustness(self): addrs_to_connect = self.neighbor_policy.get_nodes_known(neighbors_too=False) # If we got some refs, try to connect to them if len(addrs_to_connect) > 0: - logging.info("Reestructuring | Addrs availables") + logging.info(f"Reestructuring | Addrs availables | addr list: {addrs_to_connect}") await self.start_late_connection_process(connected=True, msg_type="discover_nodes", addrs_known=addrs_to_connect) else: logging.info("Reestructuring | NO Addrs availables") diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index e81fb259e..016c81707 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -414,12 +414,13 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr msg = 
self.mm.generate_discover_message(nebula_pb2.DiscoverMessage.Action.DISCOVER_NODES) logging.info("Starting communications with devices found") + #TODO filtrar para para quitar las que ya son vecinos for addr in addrs: await self.connect(addr, direct=False) await asyncio.sleep(1) while not self.verify_connections(addrs): await asyncio.sleep(1) - current_connections = await self.get_addrs_current_connections() + current_connections = await self.get_addrs_current_connections(only_undirected=True) logging.info(f"Connections verified after searching: {current_connections}") for addr in addrs: logging.info(f"Sending {msg_type} to ---> {addr}") diff --git a/nebula/core/network/nebulamulticasting.py b/nebula/core/network/nebulamulticasting.py index 2bdbae680..d8a81f6e9 100644 --- a/nebula/core/network/nebulamulticasting.py +++ b/nebula/core/network/nebulamulticasting.py @@ -168,7 +168,10 @@ def stop(self): self.server.stop def is_running(self): - return self.server.is_running() + if self.server: + return self.server.is_running() + else: + return False def find_federation(self): """ From b84818d73b332004dbfa430b25db6f51bb47575a Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 10 Dec 2024 11:06:17 +0100 Subject: [PATCH 026/233] fix_errors_reestructuring --- nebula/core/engine.py | 27 +++++++------ nebula/core/neighbormanagement/nodemanager.py | 39 ++++++++++++++++--- 2 files changed, 49 insertions(+), 17 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index d0dd44894..5764f1bec 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -316,13 +316,17 @@ async def _connection_connect_callback(self, source, message): if source not in current_connections: logging.info(f"πŸ”— handle_connection_message | Trigger | Connecting to {source}") await self.cm.connect(source, direct=True) + if self.mobility and self.nm.waiting_confirmation_from(source): + self.nm.confirmation_received(source, confirmation=True) 
@event_handler(nebula_pb2.ConnectionMessage, nebula_pb2.ConnectionMessage.Action.DISCONNECT) async def _connection_disconnect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received disconnection message from {source}") if self.mobility: - if source in await self.cm.get_all_addrs_current_connections(only_direct=True): - self.nm.update_neighbors(source, remove=True) + if self.nm.waiting_confirmation_from(source): + self.nm.confirmation_received(source, confirmation=False) + #if source in await self.cm.get_all_addrs_current_connections(only_direct=True): + self.nm.update_neighbors(source, remove=True) await self.cm.disconnect(source, mutual_disconnection=False) @@ -391,8 +395,9 @@ async def _connection_late_connect_callback(self, source, message): if self.nm.accept_connection(source, joining=True): logging.info(f"πŸ”— handle_connection_message | Late connection accepted | source: {source}") self.nm.add_weight_modifier(source) - ct_actions , df_actions = self.nm.get_actions() + await self.cm.connect(source, direct=True) + ct_actions , df_actions = self.nm.get_actions() if len(ct_actions): #for addr in ct_actions.split(): cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, ct_actions) @@ -403,7 +408,6 @@ async def _connection_late_connect_callback(self, source, message): df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, df_actions) await self.cm.send_message(source, df_msg) - await self.cm.connect(source, direct=True) self.nm.meet_node(source) self.nm.update_neighbors(source) await self.update_model_learning_rate() @@ -418,8 +422,9 @@ async def _connection_restructure_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received restructure message from {source}") if self.nm.accept_connection(source, joining=False): logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection accepted from 
{source}") - ct_actions , df_actions = self.nm.get_actions() - + await self.cm.connect(source, direct=True) + + ct_actions , df_actions = self.nm.get_actions() if len(ct_actions): cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, ct_actions) await self.cm.send_message(source, cnt_msg) @@ -427,18 +432,18 @@ async def _connection_restructure_callback(self, source, message): if len(df_actions): df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, df_actions) await self.cm.send_message(source, df_msg) - - await self.cm.connect(source, direct=True) + self.nm.meet_node(source) self.nm.update_neighbors(source) else: logging.info(f"❗️ handle_connection_message | Trigger | restructure connection denied from {source}") + await asyncio.sleep(1) await self.cm.disconnect(source, mutual_disconnection=True) @event_handler(nebula_pb2.DiscoverMessage, nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN) async def _discover_discover_join_callback(self, source, message): logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_join message from {source} ") - self.nm.meet_node(source) + #self.nm.meet_node(source) if len(self.get_federation_nodes()) > 0: await self.trainning_in_progress_lock.acquire_async() model, rounds, round = await self.cm.propagator.get_model_information(source, "stable") if self.get_round() > 0 else await self.cm.propagator.get_model_information(source, "initialization") @@ -467,7 +472,7 @@ async def _discover_discover_join_callback(self, source, message): ) async def _discover_discover_nodes_callback(self, source, message): logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_node message from {source} ") - self.nm.meet_node(source) + #self.nm.meet_node(source) if len(self.get_federation_nodes()) > 0: msg = self.cm.mm.generate_offer_message(nebula_pb2.OfferMessage.Action.OFFER_METRIC, len(self.get_federation_nodes()), self.trainer.get_current_loss()) await 
self.cm.send_message(source, msg) @@ -535,7 +540,7 @@ async def _aditional_node_start(self): logging.info(f"Aditional node | {self.addr} | going to stablish connection with federation") await self.nm.start_late_connection_process() # continue .. - await self.nm.stop_not_selected_connections() + asyncio.create_task(self.nm.stop_not_selected_connections()) logging.info("Creating trainer service to start the federation process..") asyncio.create_task(self._start_learning_late()) #decoded_model = self.trainer.deserialize_model(message.parameters) diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 8df8d35ea..c5876da5a 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -38,6 +38,8 @@ def __init__( logging.info("Initializing Model Handler") self._model_handler = factory_ModelHandler(model_handler) self.late_connection_process_lock = Locker(name="late_connection_process_lock") + self.pending_confirmation_from_nodes = [] + self.pending_confirmation_from_nodes_lock = Locker(name="pending_confirmation_from_nodes_lock") self.weight_modifier = {} self.weight_modifier_lock = Locker(name="weight_modifier_lock") self.new_node_weight_multiplier = 3 @@ -180,6 +182,7 @@ async def apply_weight_strategy(self, updates: dict): if self.rounds_pushed: logging.info(f"πŸ”„ There are rounds being pushed...") for i in range(0, self.rounds_pushed): + logging.info(f"πŸ”„ Update | weights being updated cause of push...") self._update_weight_modifiers() self.rounds_pushed = 0 for addr,update in updates.items(): @@ -221,6 +224,24 @@ def accept_connection(self, source, joining=False): else: return self.neighbor_policy.accept_connection(source) + def add_pending_connection_confirmation(self, addr): + logging.info(f" Addition | pending connection confirmation from: {addr}") + with self.pending_confirmation_from_nodes_lock: + self.pending_confirmation_from_nodes.append(addr) + + 
def waiting_confirmation_from(self, addr): + with self.pending_confirmation_from_nodes_lock: + return addr in self.pending_confirmation_from_nodes + + async def confirmation_received(self, addr, confirmation=False): + logging.info(f" Update | connection confirmation received from: {addr} | confirmation: {confirmation}") + if confirmation: + await self.engine.cm.connect(addr, direct=True) + self.update_neighbors(addr) + else: + with self.pending_confirmation_from_nodes_lock: + self.pending_confirmation_from_nodes.remove(addr) + async def receive_update_from_node(self, node_id, node_response_time): await self.timer_generator.receive_update(node_id, node_response_time) @@ -273,7 +294,9 @@ def add_candidate(self,source, n_neighbors, loss): async def stop_not_selected_connections(self): try: + await asyncio.sleep(20) if len(self.discarded_offers_addr) > 0: + self.discarded_offers_addr = self.discarded_offers_addr - self.engine.get_federation_nodes() logging.info(f"Interrupting connections | discarded offers | nodes discarded: {self.discarded_offers_addr}") for addr in self.discarded_offers_addr: await self.engine.cm.disconnect(addr, mutual_disconnection=True) @@ -334,12 +357,16 @@ async def start_late_connection_process(self, connected=False, msg_type="discove best_candidates = self.candidate_selector.select_candidates() logging.info(f"Candidates | {[addr for addr,_,_ in best_candidates]}") # candidates not choosen --> disconnect - for addr, _, _ in best_candidates: - await self.engine.cm.connect(addr, direct=True) - await self.engine.cm.send_message(addr, msg) - self.update_neighbors(addr) - await asyncio.sleep(1) - + try: + for addr, _, _ in best_candidates: + await self.engine.cm.send_message(addr, msg) + self.add_pending_connection_confirmation(addr) + #await self.engine.cm.connect(addr, direct=True) + #self.update_neighbors(addr) + await asyncio.sleep(1) + except asyncio.CancelledError as e: + self.update_neighbors(addr, remove=True) + pass 
self.accept_candidates_lock.release() self.late_connection_process_lock.release() self.candidate_selector.remove_candidates() From 7ebf7de3828fab1bd6dc20562e1bcb0af4f39e6d Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 10 Dec 2024 13:49:45 +0100 Subject: [PATCH 027/233] fix_restructure_errors fixed all restructure errors found service working --- nebula/core/engine.py | 26 ++++++++++++++----- nebula/core/neighbormanagement/README.txt | 10 ++++--- nebula/core/neighbormanagement/nodemanager.py | 26 ++++++++++++------- nebula/node.py | 2 +- 4 files changed, 44 insertions(+), 20 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 5764f1bec..803dad771 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -316,8 +316,6 @@ async def _connection_connect_callback(self, source, message): if source not in current_connections: logging.info(f"πŸ”— handle_connection_message | Trigger | Connecting to {source}") await self.cm.connect(source, direct=True) - if self.mobility and self.nm.waiting_confirmation_from(source): - self.nm.confirmation_received(source, confirmation=True) @event_handler(nebula_pb2.ConnectionMessage, nebula_pb2.ConnectionMessage.Action.DISCONNECT) async def _connection_disconnect_callback(self, source, message): @@ -391,12 +389,19 @@ async def _federation_models_included_callback(self, source, message): nebula_pb2.ConnectionMessage.Action.LATE_CONNECT, ) async def _connection_late_connect_callback(self, source, message): - logging.info(f"πŸ”— handle_connection_message | Trigger | Received late connect message from {source}") - if self.nm.accept_connection(source, joining=True): + logging.info(f"πŸ”— handle_connection_message | Trigger | Received late connect message from {source}") + # Verify if it's a confirmation message from a previous late connection message sent to source + if self.nm.waiting_confirmation_from(source): + await self.nm.confirmation_received(source, confirmation=True) + elif 
self.nm.accept_connection(source, joining=True): logging.info(f"πŸ”— handle_connection_message | Late connection accepted | source: {source}") self.nm.add_weight_modifier(source) await self.cm.connect(source, direct=True) + # Verify conenction is accepted + conf_msg = self.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.LATE_CONNECT) + await self.cm.send_message(source, conf_msg) + ct_actions , df_actions = self.nm.get_actions() if len(ct_actions): #for addr in ct_actions.split(): @@ -420,10 +425,16 @@ async def _connection_late_connect_callback(self, source, message): ) async def _connection_restructure_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received restructure message from {source}") - if self.nm.accept_connection(source, joining=False): + # Verify if it's a confirmation message from a previous restructure connection message sent to source + if self.nm.waiting_confirmation_from(source): + await self.nm.confirmation_received(source, confirmation=True) + elif self.nm.accept_connection(source, joining=False): logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection accepted from {source}") await self.cm.connect(source, direct=True) + conf_msg = self.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.RESTRUCTURE) + await self.cm.send_message(source, conf_msg) + ct_actions , df_actions = self.nm.get_actions() if len(ct_actions): cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, ct_actions) @@ -438,7 +449,7 @@ async def _connection_restructure_callback(self, source, message): else: logging.info(f"❗️ handle_connection_message | Trigger | restructure connection denied from {source}") await asyncio.sleep(1) - await self.cm.disconnect(source, mutual_disconnection=True) + #await self.cm.disconnect(source, mutual_disconnection=True) @event_handler(nebula_pb2.DiscoverMessage, nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN) 
async def _discover_discover_join_callback(self, source, message): @@ -506,11 +517,12 @@ async def _offer_offer_model_callback(self, source, message): ) async def _offer_offer_metric_callback(self, source, message): logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_metric message from {source}") + self.nm.meet_node(source) if self.nm.still_waiting_for_candidates(): n_neighbors = message.n_neighbors loss = message.loss self.nm.add_candidate(source, n_neighbors, loss) - self.nm.meet_node(source) + @event_handler( nebula_pb2.LinkMessage, diff --git a/nebula/core/neighbormanagement/README.txt b/nebula/core/neighbormanagement/README.txt index 3eb4afdbf..2ce51abc2 100644 --- a/nebula/core/neighbormanagement/README.txt +++ b/nebula/core/neighbormanagement/README.txt @@ -42,9 +42,13 @@ It is important to note that the receiving node can reject the connection. 2) Select Candidates and connect to them - __________ ____________________ ___________ -| New node | -------> | Candidate Selector | ----> *CONNECT* ----> | Candidate | -|__________| | __________________ | | _________ | + __________ ____________________ ___________ +| New node | -------> | Candidate Selector | ----> *LATE_CONNECT* ----> | Candidate | +|__________| | __________________ | | _________ | + + ___________ __________ +| Candidate | -------> *LATE_CONNECT* ----> | New Node | +|___________| | _________| Retopology works the same way but with diferent arguments on the messages. 
diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index c5876da5a..a9bcdde66 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -221,6 +221,8 @@ def accept_connection(self, source, joining=False): if self.get_restructure_process_lock().locked(): logging.info("NOT accepting connections | Currently upgrading network Robustness") return False + else: + return self.neighbor_policy.accept_connection(source) else: return self.neighbor_policy.accept_connection(source) @@ -229,6 +231,10 @@ def add_pending_connection_confirmation(self, addr): with self.pending_confirmation_from_nodes_lock: self.pending_confirmation_from_nodes.append(addr) + def clear_pending_confirmations(self): + with self.pending_confirmation_from_nodes_lock: + self.pending_confirmation_from_nodes.clear() + def waiting_confirmation_from(self, addr): with self.pending_confirmation_from_nodes_lock: return addr in self.pending_confirmation_from_nodes @@ -295,13 +301,14 @@ def add_candidate(self,source, n_neighbors, loss): async def stop_not_selected_connections(self): try: await asyncio.sleep(20) - if len(self.discarded_offers_addr) > 0: - self.discarded_offers_addr = self.discarded_offers_addr - self.engine.get_federation_nodes() - logging.info(f"Interrupting connections | discarded offers | nodes discarded: {self.discarded_offers_addr}") - for addr in self.discarded_offers_addr: - await self.engine.cm.disconnect(addr, mutual_disconnection=True) - await asyncio.sleep(1) - self.discarded_offers_addr = [] + with self.discarded_offers_addr_lock: + if len(self.discarded_offers_addr) > 0: + self.discarded_offers_addr = set(self.discarded_offers_addr) - await self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False) + logging.info(f"Interrupting connections | discarded offers | nodes discarded: {self.discarded_offers_addr}") + for addr in self.discarded_offers_addr: + await 
self.engine.cm.disconnect(addr, mutual_disconnection=True) + await asyncio.sleep(1) + self.discarded_offers_addr = [] except asyncio.CancelledError as e: pass @@ -311,7 +318,7 @@ async def check_external_connection_service_status(self): ecs = await self.engine.cm.is_external_connection_service_running() ss = self.engine.get_sinchronized_status() action = None - logging.info(f"Stats | neighbos: {n} | service running: {ecs} | synchronized status: {ss}") + logging.info(f"Stats | neighbors: {n} | service running: {ecs} | synchronized status: {ss}") if not await self.neighbors_left() and await self.engine.cm.is_external_connection_service_running(): logging.info(f"❗️ Isolated node | Shutdowning service required") action = lambda: self.engine.cm.stop_external_connection_service() @@ -335,6 +342,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove self.late_connection_process_lock.acquire() best_candidates = [] self.candidate_selector.remove_candidates() + self.clear_pending_confirmations() # find federation and send discover await self.engine.cm.stablish_connection_to_federation(msg_type, addrs_known) @@ -396,7 +404,7 @@ async def check_robustness(self): logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") asyncio.create_task(self.upgrade_connection_robustness()) else: - if self.engine.get_sinchronized_status(): + if not self.engine.get_sinchronized_status(): logging.info("Device not synchronized with federation") else: logging.info("Sufficient Robustness | no actions required") diff --git a/nebula/node.py b/nebula/node.py index 5f5f9cce4..58d312718 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -347,7 +347,7 @@ def randomize_value(value, variability): # In order to do that, it should request the current round to the controller if additional_node_status: logging.info(f"Waiting for round {additional_node_round} to start") - logging.info("Waiting 60s to start finding federation") + 
logging.info("Waiting time to start finding federation") time.sleep(70) #time.sleep(6000) # DEBUG purposes #import requests From 3b72cde36a396df8e4387d062220767eaa6b666e Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 10 Dec 2024 19:24:17 +0100 Subject: [PATCH 028/233] feat_connection_optimizator Connection optimizator to clear inactive connections --- nebula/core/neighbormanagement/nodemanager.py | 2 +- .../connectionoptimizer.py | 89 +++++++++++++++++++ nebula/node.py | 2 +- 3 files changed, 91 insertions(+), 2 deletions(-) create mode 100644 nebula/core/network/networkoptimization/connectionoptimizer.py diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index a9bcdde66..0f2a609e2 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -60,7 +60,7 @@ def __init__( self.synchronizing_rounds = False - self._fast_reboot = True + self._fast_reboot = False self._learning_rate=2e-3 #self.set_confings() diff --git a/nebula/core/network/networkoptimization/connectionoptimizer.py b/nebula/core/network/networkoptimization/connectionoptimizer.py new file mode 100644 index 000000000..3bec94fa1 --- /dev/null +++ b/nebula/core/network/networkoptimization/connectionoptimizer.py @@ -0,0 +1,89 @@ +import asyncio +import logging +from nebula.core.network.connection import Connection +from nebula.core.utils.locker import Locker +import heapq +import time + +PRIORITIES = {'HIGH': 1, 'MEDIUM': 2, 'LOW': 3} + + +class ConnectionOptimizer: + def __init__(self): + self.connection_heap = [] # Heap: (expire_time, priority, connection) + self.active_connections = {} + self.connection_heap_lock = Locker(name="connection_heap_lock", async_lock=True) + self._wake_up_event = asyncio.Event() + + async def update_connection_activity(self, connection: Connection, priority): + """ + Add new connection timeout to heap + """ + timeout = self._get_timeout_for_priority(priority) 
+ expire_time = time.time() + timeout + async with self.connection_heap_lock: + self.active_connections[connection] = (expire_time, priority, True) # Activa + heapq.heappush(self.connection_heap, (expire_time, PRIORITIES[priority], connection)) + self._wake_up_event.set() + + async def set_connection_inactivity(self, connection): + """ + Set inactive state to a connection + """ + async with self.connection_heap_lock: + if connection in self.active_connections: + # set conection as inactive + self.active_connections[connection] = (*self.active_connections[connection][:2], False) + self._wake_up_event.set() + + async def start_daemon(self): + while True: + logging.info("Wake up | Connection optimizer deamon...") + await self._check_timeouts() + await self._wait_for_next_expiration() + + async def _check_timeouts(self): + """ + Check to remove expired connections + """ + current_time = time.time() + async with self.connection_heap_lock: + while self.connection_heap and self.connection_heap[0][0] <= current_time: + expire_time, priority, connection = heapq.heappop(self.connection_heap) + # Revisa si la conexiΓ³n estΓ‘ activa + if connection in self.active_connections: + _, _, is_active = self.active_connections[connection] + if is_active and self.active_connections[connection][0] == expire_time: + logging.info(f"Closing | Connection: {connection.get_addr()} (priority: {priority}) has expired...") + del self.active_connections[connection] + + async def _wait_for_next_expiration(self): + """ + Sleep until new connection is stored or lower timeout passed + """ + async with self.connection_heap_lock: + if not self.connection_heap: + self._wake_up_event.clear() + await self._wake_up_event.wait() + return + next_expiration = self.connection_heap[0][0] + + sleep_duration = max(0, next_expiration - time.time()) + + try: + await asyncio.wait_for(self._wake_up_event.wait(), timeout=sleep_duration) + except asyncio.TimeoutError: + pass + finally: + self._wake_up_event.clear() + + 
def _get_timeout_for_priority(self, priority): + """ + Priority timeouts + """ + if priority == 'HIGH': + return 30 # HIGH = 30s + elif priority == 'MEDIUM': + return 20 # MEDIUM = 20s + elif priority == 'LOW': + return 10 # LOW = 10s \ No newline at end of file diff --git a/nebula/node.py b/nebula/node.py index 58d312718..f04851895 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -348,7 +348,7 @@ def randomize_value(value, variability): if additional_node_status: logging.info(f"Waiting for round {additional_node_round} to start") logging.info("Waiting time to start finding federation") - time.sleep(70) + time.sleep(500) #time.sleep(6000) # DEBUG purposes #import requests From b8f7c353751a04b692b2f93fab53cfd3e0ee9b3b Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 11 Dec 2024 12:00:47 +0100 Subject: [PATCH 029/233] feat_network_optimization networkoptimizer as a controller connectionoptimizer to clean inactive connections timergenerator to generate dynamic timeouts for aggregation --- nebula/core/neighbormanagement/nodemanager.py | 26 +----------- nebula/core/network/connection.py | 5 +++ .../connectionoptimizer.py | 28 ++++++++----- .../networkoptimization/networkoptimizer | 42 +++++++++++++++++++ .../networkoptimization}/timergenerator.py | 9 ++-- nebula/node.py | 2 +- 6 files changed, 71 insertions(+), 41 deletions(-) create mode 100644 nebula/core/network/networkoptimization/networkoptimizer rename nebula/core/{neighbormanagement => network/networkoptimization}/timergenerator.py (96%) diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 0f2a609e2..8f29d8d93 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -8,7 +8,6 @@ from nebula.core.neighbormanagement.candidateselection.candidateselector import factory_CandidateSelector from nebula.core.neighbormanagement.modelhandlers.modelhandler import factory_ModelHandler from 
nebula.core.neighbormanagement.neighborpolicies.neighborpolicy import factory_NeighborPolicy -from nebula.core.neighbormanagement.timergenerator import TimerGenerator from nebula.core.pb import nebula_pb2 from nebula.core.network.communications import CommunicationsManager from nebula.addons.functions import print_msg_box @@ -48,19 +47,14 @@ def __init__( self._restructure_process_lock = Locker(name="restructure_process_lock") self.restructure = False self.discarded_offers_addr_lock = Locker(name="discarded_offers_addr_lock") - self.discarded_offers_addr = [] - - self.max_time_to_wait = 20 - logging.info("Initializing Timer generator") - self._timer_generator = None #TimerGenerator(self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False), self.max_time_to_wait, 80) - + self.discarded_offers_addr = [] self._push_acceleration = push_acceleration self.rounds_pushed_lock = Locker(name="rounds_pushed_lock") self.rounds_pushed = 0 self.synchronizing_rounds = False - self._fast_reboot = False + self._fast_reboot = True self._learning_rate=2e-3 #self.set_confings() @@ -81,10 +75,6 @@ def candidate_selector(self): def model_handler(self): return self._model_handler - @property - def timer_generator(self): - return self._timer_generator - def get_learning_rate_increase(self): return self._learning_rate @@ -146,15 +136,6 @@ async def set_confings(self): #self.engine.trainer.get_loss(), self.config.participant["molibity_args"]["weight_distance"], self.config.participant["molibity_args"]["weight_het"] #self.model_handler.set_config([self.engine.get_round(), self.engine.config.participant["training_args"]["epochs"]]) - def get_timer(self): - return self.timer_generator.get_timer(self.engine.get_round()) - - def adjust_timer(self): - self.timer_generator.adjust_timer() - - def get_stop_condition(self): - return self.timer_generator.get_stop_condition() - ############################## # WEIGHT STRATEGY # @@ -248,9 +229,6 @@ async def 
confirmation_received(self, addr, confirmation=False): with self.pending_confirmation_from_nodes_lock: self.pending_confirmation_from_nodes.remove(addr) - async def receive_update_from_node(self, node_id, node_response_time): - await self.timer_generator.receive_update(node_id, node_response_time) - def add_to_discarded_offers(self, addr_discarded): self.discarded_offers_addr_lock.acquire() self.discarded_offers_addr.append(addr_discarded) diff --git a/nebula/core/network/connection.py b/nebula/core/network/connection.py index 338340d67..fc1a6ff6d 100755 --- a/nebula/core/network/connection.py +++ b/nebula/core/network/connection.py @@ -39,6 +39,7 @@ def __init__( active=True, compression="zlib", config=None, + prio="MEDIUM" ): self.cm = cm self.reader = reader @@ -61,6 +62,7 @@ def __init__( self.process_task = None self.pending_messages_queue = asyncio.Queue(maxsize=100) self.message_buffers: dict[bytes, dict[int, MessageChunk]] = {} + self._prio = prio self.EOT_CHAR = b"\x00\x00\x00\x04" self.COMPRESSION_CHAR = b"\x00\x00\x00\x01" @@ -89,6 +91,9 @@ def __del__(self): def get_addr(self): return self.addr + + def get_prio(self): + return self._prio def get_federated_round(self): return self.federated_round diff --git a/nebula/core/network/networkoptimization/connectionoptimizer.py b/nebula/core/network/networkoptimization/connectionoptimizer.py index 3bec94fa1..76d1cffbd 100644 --- a/nebula/core/network/networkoptimization/connectionoptimizer.py +++ b/nebula/core/network/networkoptimization/connectionoptimizer.py @@ -1,11 +1,12 @@ import asyncio import logging from nebula.core.network.connection import Connection +from nebula.core.network.networkoptimization.networkoptimizer import NetworkOptimizer from nebula.core.utils.locker import Locker import heapq import time -PRIORITIES = {'HIGH': 1, 'MEDIUM': 2, 'LOW': 3} +PRIORITIES = {'HIGH': 30, 'MEDIUM': 20, 'LOW': 10} class ConnectionOptimizer: @@ -13,17 +14,19 @@ def __init__(self): self.connection_heap = [] # Heap: 
(expire_time, priority, connection) self.active_connections = {} self.connection_heap_lock = Locker(name="connection_heap_lock", async_lock=True) - self._wake_up_event = asyncio.Event() + self._wake_up_event = asyncio.Event() + self._running = True - async def update_connection_activity(self, connection: Connection, priority): + async def update_connection_activity(self, connection: Connection): """ Add new connection timeout to heap """ + priority = connection.get_prio() timeout = self._get_timeout_for_priority(priority) expire_time = time.time() + timeout async with self.connection_heap_lock: self.active_connections[connection] = (expire_time, priority, True) # Activa - heapq.heappush(self.connection_heap, (expire_time, PRIORITIES[priority], connection)) + heapq.heappush(self.connection_heap, (expire_time, priority, connection)) self._wake_up_event.set() async def set_connection_inactivity(self, connection): @@ -36,8 +39,12 @@ async def set_connection_inactivity(self, connection): self.active_connections[connection] = (*self.active_connections[connection][:2], False) self._wake_up_event.set() + async def stop_daemon(self): + self._running = False + async def start_daemon(self): - while True: + self._running = True + while self._running: logging.info("Wake up | Connection optimizer deamon...") await self._check_timeouts() await self._wait_for_next_expiration() @@ -81,9 +88,8 @@ def _get_timeout_for_priority(self, priority): """ Priority timeouts """ - if priority == 'HIGH': - return 30 # HIGH = 30s - elif priority == 'MEDIUM': - return 20 # MEDIUM = 20s - elif priority == 'LOW': - return 10 # LOW = 10s \ No newline at end of file + try: + return PRIORITIES[priority] + except KeyError: + logging.info(f"Not allowed: {priority}. 
PRIORITIES: {list(PRIORITIES.keys())}") + raise ValueError() \ No newline at end of file diff --git a/nebula/core/network/networkoptimization/networkoptimizer b/nebula/core/network/networkoptimization/networkoptimizer new file mode 100644 index 000000000..6d0439a00 --- /dev/null +++ b/nebula/core/network/networkoptimization/networkoptimizer @@ -0,0 +1,42 @@ +import asyncio +import logging +import time +from nebula.core.network.connection import Connection +from nebula.core.network.networkoptimization.connectionoptimizer import ConnectionOptimizer +from nebula.core.network.networkoptimization.timergenerator import TimerGenerator +from nebula.core.network.communications import CommunicationsManager +from nebula.core.utils.locker import Locker + + +class NetworkOptimizer: + def __init__(self, communication_manager, vanilla_max_timer, adaptative_timeouts=False): + self._communications_manager = communication_manager + self._connection_optimizer = ConnectionOptimizer() + self._adaptative_timeouts = adaptative_timeouts + self.max_time_to_wait = vanilla_max_timer + self._timer_generator = None #TimerGenerator(self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False), self.max_time_to_wait, 80) + + @property + def cm(self) -> CommunicationsManager: + return self._communications_manager + + @property + def co(self) -> ConnectionOptimizer: + return self._connection_optimizer + + @property + def tg(self) -> TimerGenerator: + return self._timer_generator + + async def info_received_from_connection(self, connection : Connection): + await self.co.update_connection_activity(connection) + + async def model_update_received_from_connection(self, source): + arrived_time = time.time() + await self.tg.receive_update(source, arrived_time) + + async def on_closing_connection(self, connection : Connection): + await self.co.set_connection_inactivity(connection) + + async def start_connection_cleaner(self): + self.co.start_daemon() \ No newline at end of file diff --git 
a/nebula/core/neighbormanagement/timergenerator.py b/nebula/core/network/networkoptimization/timergenerator.py similarity index 96% rename from nebula/core/neighbormanagement/timergenerator.py rename to nebula/core/network/networkoptimization/timergenerator.py index 10d3bae12..bac91add7 100644 --- a/nebula/core/neighbormanagement/timergenerator.py +++ b/nebula/core/network/networkoptimization/timergenerator.py @@ -35,13 +35,13 @@ def __init__( def get_stop_condition(self): return True #self.all_updates_received - def get_timer(self, round): + async def get_timer(self, round): self.round = round sm = time.time() self.start_moment = round(sm, 2) return self.waiting_time - def update_node(self, node, remove=False): + async def update_node(self, node, remove=False): if remove: self.nodes_historic.pop(node, None) self.max_updates_number -= 1 @@ -51,8 +51,7 @@ def update_node(self, node, remove=False): async def receive_update(self, node_id, node_response_time): """ - In this function the response time is saved in the historic, structures are updated and - condition is checked to stop the process because al responses are being received + In this function the response time is saved in the historic, structures are updated Args: node_id : node addr @@ -69,7 +68,7 @@ async def receive_update(self, node_id, node_response_time): #async with self.all_updates_received: # self.all_updates_received.notify_all() - def adjust_timer(self): + async def adjust_timer(self): """ The process of adjusting the timer is simple. if adaptative is not set up it will use the MAX_TIMER all the time. If not, the strategy will depend on the percent of updates receive the last round. 
diff --git a/nebula/node.py b/nebula/node.py index f04851895..f56f3299a 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -348,7 +348,7 @@ def randomize_value(value, variability): if additional_node_status: logging.info(f"Waiting for round {additional_node_round} to start") logging.info("Waiting time to start finding federation") - time.sleep(500) + time.sleep(550) #time.sleep(6000) # DEBUG purposes #import requests From fbbccd07fc20488a803e428e2531569b85e0e746 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 12 Dec 2024 15:58:11 +0100 Subject: [PATCH 030/233] fix_minor_errors --- nebula/core/engine.py | 3 ++- nebula/core/neighbormanagement/nodemanager.py | 26 ++++++++++++------- .../connectionoptimizer.py | 14 +++++++--- .../{networkoptimizer => networkoptimizer.py} | 21 +++++++++++++-- .../networkoptimization/timergenerator.py | 15 ++++++----- nebula/node.py | 4 ++- 6 files changed, 60 insertions(+), 23 deletions(-) rename nebula/core/network/networkoptimization/{networkoptimizer => networkoptimizer.py} (74%) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 803dad771..12fc539de 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -574,7 +574,8 @@ async def apply_weight_strategy(self, pending_models): async def update_model_learning_rate(self): await self.trainning_in_progress_lock.acquire_async() - self.trainer.update_model_learning_rate(self.nm.get_learning_rate_increase()) + if self.get_round() < self.total_rounds: + self.trainer.update_model_learning_rate(self.nm.get_learning_rate_increase()) await self.trainning_in_progress_lock.release_async() async def _start_learning_late(self): diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 8f29d8d93..b0b2edc97 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -172,17 +172,22 @@ async def apply_weight_strategy(self, updates: dict): logging.info (f"πŸ“ 
Appliying modified weight strategy | addr: {addr} | multiplier value: {weightmodifier}") model, weight = update updates.update({addr: (model, weight*weightmodifier)}) - self._update_weight_modifiers() + await self._update_weight_modifiers() - def _update_weight_modifiers(self): - self.weight_modifier_lock.acquire() - for addr, (weight,rounds) in self.weight_modifier.items(): - new_weight = weight - 1/(rounds**2) - rounds = rounds + 1 - if new_weight > 1: - self.weight_modifier[addr] = (new_weight, rounds) - else: - self.remove_weight_modifier(addr) + async def _update_weight_modifiers(self): + self.weight_modifier_lock.acquire() + logging.info(f"πŸ”„ Update | weights being updated") + if self.weight_modifier: + for addr, (weight,rounds) in self.weight_modifier.items(): + new_weight = weight - 1/(rounds**2) + rounds = rounds + 1 + if new_weight > 1 and rounds <= 20: + self.weight_modifier[addr] = (new_weight, rounds) + else: + self.remove_weight_modifier(addr) + else: + self._learning_rate = 1e-3 + await self.engine.update_model_learning_rate() self.weight_modifier_lock.release() def _get_weight_modifier(self, addr): @@ -207,6 +212,7 @@ def accept_connection(self, source, joining=False): else: return self.neighbor_policy.accept_connection(source) + #TODO añadir un remove def add_pending_connection_confirmation(self, addr): logging.info(f" Addition | pending connection confirmation from: {addr}") with self.pending_confirmation_from_nodes_lock: diff --git a/nebula/core/network/networkoptimization/connectionoptimizer.py b/nebula/core/network/networkoptimization/connectionoptimizer.py index 76d1cffbd..56a210a4d 100644 --- a/nebula/core/network/networkoptimization/connectionoptimizer.py +++ b/nebula/core/network/networkoptimization/connectionoptimizer.py @@ -10,13 +10,21 @@ class ConnectionOptimizer: - def __init__(self): + def __init__( + self, + network_optimizer : NetworkOptimizer + ): + self._network_optimizer = network_optimizer self.connection_heap = [] # Heap: 
(expire_time, priority, connection) self.active_connections = {} self.connection_heap_lock = Locker(name="connection_heap_lock", async_lock=True) self._wake_up_event = asyncio.Event() self._running = True + @property + def no(self): + return self._network_optimizer + async def update_connection_activity(self, connection: Connection): """ Add new connection timeout to heap @@ -45,7 +53,7 @@ async def stop_daemon(self): async def start_daemon(self): self._running = True while self._running: - logging.info("Wake up | Connection optimizer deamon...") + logging.info("Wake up | Connection optimizer daemon...") await self._check_timeouts() await self._wait_for_next_expiration() @@ -91,5 +99,5 @@ def _get_timeout_for_priority(self, priority): try: return PRIORITIES[priority] except KeyError: - logging.info(f"Not allowed: {priority}. PRIORITIES: {list(PRIORITIES.keys())}") + logging.error(f"Not allowed: {priority}. PRIORITIES: {list(PRIORITIES.keys())}") raise ValueError() \ No newline at end of file diff --git a/nebula/core/network/networkoptimization/networkoptimizer b/nebula/core/network/networkoptimization/networkoptimizer.py similarity index 74% rename from nebula/core/network/networkoptimization/networkoptimizer rename to nebula/core/network/networkoptimization/networkoptimizer.py index 6d0439a00..d3780a283 100644 --- a/nebula/core/network/networkoptimization/networkoptimizer +++ b/nebula/core/network/networkoptimization/networkoptimizer.py @@ -9,7 +9,12 @@ class NetworkOptimizer: - def __init__(self, communication_manager, vanilla_max_timer, adaptative_timeouts=False): + def __init__( + self, + communication_manager, + vanilla_max_timer, + adaptative_timeouts=False + ): self._communications_manager = communication_manager self._connection_optimizer = ConnectionOptimizer() self._adaptative_timeouts = adaptative_timeouts @@ -38,5 +43,17 @@ async def model_update_received_from_connection(self, source): async def on_closing_connection(self, connection : Connection): 
await self.co.set_connection_inactivity(connection) + async def connection_timeout_expired(self, connection): + pass + async def start_connection_cleaner(self): - self.co.start_daemon() \ No newline at end of file + self.co.start_daemon() + + async def process_direct_connection(self, source, closed=False): + await self.tg.update_node(source, remove=closed) + + async def get_round_timeout(self): + return await self.tg.get_timer() + + async def on_round_end(self): + await self.tg.on_round_end() \ No newline at end of file diff --git a/nebula/core/network/networkoptimization/timergenerator.py b/nebula/core/network/networkoptimization/timergenerator.py index bac91add7..42ebb1289 100644 --- a/nebula/core/network/networkoptimization/timergenerator.py +++ b/nebula/core/network/networkoptimization/timergenerator.py @@ -21,7 +21,7 @@ def __init__( self.max_timer_value = max_timer_value self.acceptable_percent = acceptable_percent self.max_historic_size = max_historic_size - self.round = round + self.round_completed = round self.nodes_historic = {node_id: deque(maxlen=self.max_historic_size) for node_id in nodes} self.adaptative = adaptative self.max_updates_number = len(self.nodes_historic) @@ -35,8 +35,8 @@ def __init__( def get_stop_condition(self): return True #self.all_updates_received - async def get_timer(self, round): - self.round = round + async def get_timer(self): + self.round_completed = self.round_completed + 1 sm = time.time() self.start_moment = round(sm, 2) return self.waiting_time @@ -68,7 +68,10 @@ async def receive_update(self, node_id, node_response_time): #async with self.all_updates_received: # self.all_updates_received.notify_all() - async def adjust_timer(self): + async def on_round_end(self): + await self._adjust_timer() + + async def _adjust_timer(self): """ The process of adjusting the timer is simple. if adaptative is not set up it will use the MAX_TIMER all the time. 
If not, the strategy will depend on the percent of updates receive the last round. @@ -112,12 +115,12 @@ async def adjust_timer(self): ema = self._exponential_moving_average(times_deque, alpha=0.1) max_ema = max(max_ema, ema) if percentile < self.acceptable_percent: - if not self.round >= self.max_historic_size: + if not self.round_completed >= self.max_historic_size: self.waiting_time = self._change_timer_value(self.waiting_time*1.2) # timer + 20% from max else: self.waiting_time = self._change_timer_value(max_ema*1.25) # if enough data for historic EMA, EMA*1.25 else: - if not self.round >= self.max_historic_size: + if not self.round_completed >= self.max_historic_size: self.waiting_time = self._change_timer_value(self.waiting_time*1.05) # timer + 5% from max else: self.waiting_time = self._change_timer_value(max_ema*1.15) # if enough data for historic EMA, EMA*1.15 diff --git a/nebula/node.py b/nebula/node.py index f56f3299a..832cdf340 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -348,7 +348,9 @@ def randomize_value(value, variability): if additional_node_status: logging.info(f"Waiting for round {additional_node_round} to start") logging.info("Waiting time to start finding federation") - time.sleep(550) + # 385 r30 + # 615 r50 + time.sleep(385) #time.sleep(6000) # DEBUG purposes #import requests From 1fe818d99e77ef7cc885ce264e9eca9936e04e4b Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 12 Dec 2024 17:47:12 +0100 Subject: [PATCH 031/233] fix_remove_weight_error --- nebula/core/engine.py | 1 + nebula/core/neighbormanagement/nodemanager.py | 16 ++++++++++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 12fc539de..bcbf4d801 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -575,6 +575,7 @@ async def apply_weight_strategy(self, pending_models): async def update_model_learning_rate(self): await self.trainning_in_progress_lock.acquire_async() if self.get_round() 
< self.total_rounds: + logging.info("Update | learning rate modified...") self.trainer.update_model_learning_rate(self.nm.get_learning_rate_increase()) await self.trainning_in_progress_lock.release_async() diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index b0b2edc97..eb075f887 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -150,11 +150,9 @@ def add_weight_modifier(self, addr): self.weight_modifier_lock.release() def remove_weight_modifier(self, addr): - self.weight_modifier_lock.acquire() if addr in self.weight_modifier: logging.info(f"πŸ“ Removing | weight modifier registered for source {addr}") del self.weight_modifier[addr] - self.weight_modifier_lock.release() async def apply_weight_strategy(self, updates: dict): logging.info(f"πŸ”„ Applying weight Strategy...") @@ -169,7 +167,7 @@ async def apply_weight_strategy(self, updates: dict): for addr,update in updates.items(): weightmodifier, rounds = self._get_weight_modifier(addr) if weightmodifier != 1: - logging.info (f"πŸ“ Appliying modified weight strategy | addr: {addr} | multiplier value: {weightmodifier}") + logging.info (f"πŸ“ Appliying modified weight strategy | addr: {addr} | multiplier value: {weightmodifier}, rounds applied: {rounds}") model, weight = update updates.update({addr: (model, weight*weightmodifier)}) await self._update_weight_modifiers() @@ -178,16 +176,22 @@ async def _update_weight_modifiers(self): self.weight_modifier_lock.acquire() logging.info(f"πŸ”„ Update | weights being updated") if self.weight_modifier: + remove_addrs = [] for addr, (weight,rounds) in self.weight_modifier.items(): new_weight = weight - 1/(rounds**2) rounds = rounds + 1 if new_weight > 1 and rounds <= 20: self.weight_modifier[addr] = (new_weight, rounds) else: - self.remove_weight_modifier(addr) + remove_addrs.append(addr) + #self.remove_weight_modifier(addr) + for a in remove_addrs: + 
self.remove_weight_modifier(a) else: - self._learning_rate = 1e-3 - await self.engine.update_model_learning_rate() + if self._learning_rate == (2e-3): + logging.info(f"πŸ”„ Finishing | weight strategy is completed") + self._learning_rate = 1e-3 + await self.engine.update_model_learning_rate() self.weight_modifier_lock.release() def _get_weight_modifier(self, addr): From 33d95786a1dcdbf96433e3af5ac8564f87260df4 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 16 Dec 2024 12:43:55 +0100 Subject: [PATCH 032/233] fix_solving_distributions --- nebula/node.py | 1 + nebula/scenarios.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/nebula/node.py b/nebula/node.py index 832cdf340..eb6143d86 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -109,6 +109,7 @@ async def main(config): dataset_str = config.participant["data_args"]["dataset"] num_workers = config.participant["data_args"]["num_workers"] model = None + logging.info(f"Number of nodes on the scenario: {n_nodes}") if dataset_str == "MNIST": dataset = MNISTDataset( num_classes=10, diff --git a/nebula/scenarios.py b/nebula/scenarios.py index ea3803ef3..d96aeab24 100644 --- a/nebula/scenarios.py +++ b/nebula/scenarios.py @@ -459,11 +459,13 @@ def load_configurations_and_start_nodes(self, additional_participants=None, sche # Update participants configuration is_start_node = False config_participants = [] + ap = len(additional_participants) if additional_participants else 0 + logging.info(f"######## nodes: {self.n_nodes} + additionals: {ap}") for i in range(self.n_nodes): with open(f"{self.config_dir}/participant_" + str(i) + ".json") as f: participant_config = json.load(f) participant_config["scenario_args"]["federation"] = self.scenario.federation - participant_config["scenario_args"]["n_nodes"] = self.n_nodes + participant_config["scenario_args"]["n_nodes"] = self.n_nodes + ap participant_config["network_args"]["neighbors"] = self.topologymanager.get_neighbors_string(i) 
participant_config["scenario_args"]["name"] = self.scenario_name participant_config["scenario_args"]["start_time"] = self.start_date_scenario From 8325055874d79609c28286d9b3f08358c6d84690 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 16 Dec 2024 14:04:03 +0100 Subject: [PATCH 033/233] feat_additional_data_dist_png additional nodes now show their data distributions --- nebula/core/datasets/mnist/mnist.py | 5 ++ nebula/core/datasets/nebuladataset.py | 71 ++++++++++++++++++++++++++- nebula/node.py | 3 +- nebula/scenarios.py | 6 ++- 4 files changed, 81 insertions(+), 4 deletions(-) diff --git a/nebula/core/datasets/mnist/mnist.py b/nebula/core/datasets/mnist/mnist.py index 8061dca20..d58f1ceab 100755 --- a/nebula/core/datasets/mnist/mnist.py +++ b/nebula/core/datasets/mnist/mnist.py @@ -19,6 +19,7 @@ def __init__( partition_parameter=0.5, seed=42, config=None, + additional=False ): super().__init__( num_classes=num_classes, @@ -31,6 +32,7 @@ def __init__( partition_parameter=partition_parameter, seed=seed, config=config, + additional=additional ) if partition_id < 0 or partition_id >= partitions_number: raise ValueError(f"partition_id {partition_id} is out of range for partitions_number {partitions_number}") @@ -83,6 +85,9 @@ def generate_non_iid_map(self, dataset, partition="dirichlet", partition_paramet self.plot_data_distribution(dataset, partitions_map) self.plot_all_data_distribution(dataset, partitions_map) + if self.additional: + self.plot_data_distribution_for_additional_node(dataset, partitions_map) + return partitions_map[self.partition_id] def generate_iid_map(self, dataset, partition="balancediid", partition_parameter=2): diff --git a/nebula/core/datasets/nebuladataset.py b/nebula/core/datasets/nebuladataset.py index 6e264723d..fb30a61ca 100755 --- a/nebula/core/datasets/nebuladataset.py +++ b/nebula/core/datasets/nebuladataset.py @@ -40,6 +40,7 @@ def __init__( partition_parameter=0.5, seed=42, config=None, + additional=False ): 
super().__init__() @@ -56,7 +57,8 @@ def __init__( self.partition_parameter = partition_parameter self.seed = seed self.config = config - + self.additional = additional + self.train_set = None self.train_indices_map = None self.test_set = None @@ -202,6 +204,73 @@ def plot_data_distribution(self, dataset, partitions_map): if hasattr(self, "tsne") and self.tsne: self.visualize_tsne(dataset) + def plot_data_distribution_for_additional_node(self, dataset, partitions_map): + sns.set() + sns.set_style("whitegrid", {"axes.grid": False}) + sns.set_context("paper", font_scale=1.5) + sns.set_palette("Set2") + + indices = partitions_map[self.partition_id] + class_counts = [0] * self.num_classes + for idx in indices: + label = dataset.targets[idx] + class_counts[label] += 1 + logging_training.info(f"Participant {self.partition_id + 1} class distribution: {class_counts}") + plt.figure() + plt.bar(range(self.num_classes), class_counts) + plt.xlabel("Class") + plt.ylabel("Number of samples") + plt.xticks(range(self.num_classes)) + if self.iid: + plt.title(f"Participant {self.partition_id + 1} class distribution (IID)") + else: + plt.title( + f"Participant {self.partition_id + 1} class distribution (Non-IID - {self.partition}) - {self.partition_parameter}" + ) + plt.tight_layout() + path_to_save = f"{self.config.participant['tracking_args']['log_dir']}/{self.config.participant['scenario_args']['name']}/participant_{self.partition_id}_class_distribution_{'iid' if self.iid else 'non_iid'}{'_' + self.partition if not self.iid else ''}.png" + plt.savefig(path_to_save, dpi=300, bbox_inches="tight") + plt.close() + + plt.figure() + max_point_size = 500 + min_point_size = 0 + + for i in range(self.partitions_number): + class_counts = [0] * self.num_classes + indices = partitions_map[i] + for idx in indices: + label = dataset.targets[idx] + class_counts[label] += 1 + + # Normalize the point sizes for this partition + max_samples_partition = max(class_counts) + sizes = [ + (size / 
max_samples_partition) * (max_point_size - min_point_size) + min_point_size + for size in class_counts + ] + plt.scatter([i] * self.num_classes, range(self.num_classes), s=sizes, alpha=0.5) + + plt.xlabel("Participant") + plt.ylabel("Class") + plt.xticks(range(self.partitions_number)) + plt.yticks(range(self.num_classes)) + if self.iid: + plt.title(f"Participant {i + 1} class distribution (IID)") + else: + plt.title( + f"Participant {i + 1} class distribution (Non-IID - {self.partition}) - {self.partition_parameter}" + ) + plt.tight_layout() + + # Saves the distribution display with circles of different size + path_to_save = f"{self.config.participant['tracking_args']['log_dir']}/{self.config.participant['scenario_args']['name']}/class_distribution_additionals_{'iid' if self.iid else 'non_iid'}{'_' + self.partition if not self.iid else ''}.png" + plt.savefig(path_to_save, dpi=300, bbox_inches="tight") + plt.close() + + if hasattr(self, "tsne") and self.tsne: + self.visualize_tsne(dataset) + def visualize_tsne(self, dataset): X = [] # List for storing the characteristics of the samples y = [] # Ready to store the labels of the samples diff --git a/nebula/node.py b/nebula/node.py index eb6143d86..9460ffc92 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -120,6 +120,7 @@ async def main(config): partition_parameter=partition_parameter, seed=42, config=config, + additional=additional_node_status ) if model_name == "MLP": model = MNISTModelMLP() @@ -351,7 +352,7 @@ def randomize_value(value, variability): logging.info("Waiting time to start finding federation") # 385 r30 # 615 r50 - time.sleep(385) + time.sleep(70) #time.sleep(6000) # DEBUG purposes #import requests diff --git a/nebula/scenarios.py b/nebula/scenarios.py index d96aeab24..7927c6b19 100644 --- a/nebula/scenarios.py +++ b/nebula/scenarios.py @@ -459,13 +459,14 @@ def load_configurations_and_start_nodes(self, additional_participants=None, sche # Update participants configuration is_start_node = False 
config_participants = [] - ap = len(additional_participants) if additional_participants else 0 + #ap = len(additional_participants) if additional_participants else 0 + ap = 1 logging.info(f"######## nodes: {self.n_nodes} + additionals: {ap}") for i in range(self.n_nodes): with open(f"{self.config_dir}/participant_" + str(i) + ".json") as f: participant_config = json.load(f) participant_config["scenario_args"]["federation"] = self.scenario.federation - participant_config["scenario_args"]["n_nodes"] = self.n_nodes + ap + participant_config["scenario_args"]["n_nodes"] = self.n_nodes participant_config["network_args"]["neighbors"] = self.topologymanager.get_neighbors_string(i) participant_config["scenario_args"]["name"] = self.scenario_name participant_config["scenario_args"]["start_time"] = self.start_date_scenario @@ -531,6 +532,7 @@ def load_configurations_and_start_nodes(self, additional_participants=None, sche with open(additional_participant_file) as f: participant_config = json.load(f) + logging.info(f"Configuration | additional nodes | participant: {self.n_nodes + i + 1}") participant_config["scenario_args"]["n_nodes"] = self.n_nodes + i + 1 participant_config["device_args"]["idx"] = last_participant_index + i participant_config["network_args"]["neighbors"] = "" From 5b9ef3db7f48b005262dfe9ff4360fcab549b1ad Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 17 Dec 2024 12:10:15 +0100 Subject: [PATCH 034/233] fix_additional_nodes_ip --- nebula/scenarios.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/nebula/scenarios.py b/nebula/scenarios.py index 7927c6b19..bc020cd78 100644 --- a/nebula/scenarios.py +++ b/nebula/scenarios.py @@ -460,8 +460,8 @@ def load_configurations_and_start_nodes(self, additional_participants=None, sche is_start_node = False config_participants = [] #ap = len(additional_participants) if additional_participants else 0 - ap = 1 - logging.info(f"######## nodes: {self.n_nodes} + additionals: {ap}") + ap = 
len(additional_participants) if additional_participants else 0 + logging.info(f"######## nodes: {self.n_nodes} + additionals: {ap} ######") for i in range(self.n_nodes): with open(f"{self.config_dir}/participant_" + str(i) + ".json") as f: participant_config = json.load(f) @@ -533,14 +533,18 @@ def load_configurations_and_start_nodes(self, additional_participants=None, sche participant_config = json.load(f) logging.info(f"Configuration | additional nodes | participant: {self.n_nodes + i + 1}") + logging.info("Mensaje de prueba de modificaciΓ³n") + logging.info(f"Valores de la ultima ip: ( {participant_config["network_args"]["ip"]} )") participant_config["scenario_args"]["n_nodes"] = self.n_nodes + i + 1 participant_config["device_args"]["idx"] = last_participant_index + i participant_config["network_args"]["neighbors"] = "" participant_config["network_args"]["ip"] = ( participant_config["network_args"]["ip"].rsplit(".", 1)[0] + "." - + str(int(participant_config["network_args"]["ip"].rsplit(".", 1)[1]) + 1) + + str(int(participant_config["network_args"]["ip"].rsplit(".", 1)[1]) + i + 1) ) + ip = str(participant_config["network_args"]["ip"]) + logging.info(f"El valor almacenado en json es: {ip}") participant_config["device_args"]["uid"] = hashlib.sha1( ( str(participant_config["network_args"]["ip"]) From 8c177faec285f19cf6d70abb58fdf1254684a265 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 19 Dec 2024 10:15:39 +0100 Subject: [PATCH 035/233] ft_test_setup --- .../neighbormanagement/modelhandlers/stdmodelhandler.py | 2 ++ nebula/core/neighbormanagement/nodemanager.py | 3 ++- nebula/node.py | 7 ++++++- nebula/scenarios.py | 4 ++-- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py b/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py index 392a3c746..b23f62726 100644 --- a/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py +++ 
b/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py @@ -1,5 +1,6 @@ from nebula.core.neighbormanagement.modelhandlers.modelhandler import ModelHandler from nebula.core.utils.locker import Locker +import logging class STDModelHandler(ModelHandler): @@ -30,6 +31,7 @@ def accept_model(self, model): save only first model received to set up own model later """ if not self.model_lock.locked(): + logging.info(" ### First model acquire ###") self.model_lock.acquire() self.model = model return True diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index eb075f887..5d2683d25 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -188,7 +188,7 @@ async def _update_weight_modifiers(self): for a in remove_addrs: self.remove_weight_modifier(a) else: - if self._learning_rate == (2e-3): + if len(self.weight_modifier) == 0 and self._learning_rate == (2e-3): logging.info(f"πŸ”„ Finishing | weight strategy is completed") self._learning_rate = 1e-3 await self.engine.update_model_learning_rate() @@ -271,6 +271,7 @@ def get_nodes_known(self, neighbors_too=False): def accept_model_offer(self, source, decoded_model, rounds, round, epochs, n_neighbors, loss): if not self.accept_candidates_lock.locked(): + logging.info(f"πŸ”„ Processing offer from {source}...") model_accepted = self.model_handler.accept_model(decoded_model) self.model_handler.set_config(config=(rounds, round, epochs)) if model_accepted: diff --git a/nebula/node.py b/nebula/node.py index 9460ffc92..daea7b538 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -352,7 +352,12 @@ def randomize_value(value, variability): logging.info("Waiting time to start finding federation") # 385 r30 # 615 r50 - time.sleep(70) + if config.participant["network_args"]["ip"] == "192.168.50.9": + logging.info("Sleeping 385s...") + time.sleep(385) + elif config.participant["network_args"]["ip"] == "192.168.50.10": + 
logging.info("Sleeping 800s...") + time.sleep(615) #time.sleep(6000) # DEBUG purposes #import requests diff --git a/nebula/scenarios.py b/nebula/scenarios.py index bc020cd78..86954f96e 100644 --- a/nebula/scenarios.py +++ b/nebula/scenarios.py @@ -533,8 +533,8 @@ def load_configurations_and_start_nodes(self, additional_participants=None, sche participant_config = json.load(f) logging.info(f"Configuration | additional nodes | participant: {self.n_nodes + i + 1}") - logging.info("Mensaje de prueba de modificaciΓ³n") - logging.info(f"Valores de la ultima ip: ( {participant_config["network_args"]["ip"]} )") + last_ip = participant_config["network_args"]["ip"] + logging.info(f"Valores de la ultima ip: ({last_ip})") participant_config["scenario_args"]["n_nodes"] = self.n_nodes + i + 1 participant_config["device_args"]["idx"] = last_participant_index + i participant_config["network_args"]["neighbors"] = "" From 9a7d80963ec47dab71b810b53b103ce512eb24d3 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 19 Dec 2024 11:27:26 +0100 Subject: [PATCH 036/233] fix_no_coinciding_samples --- nebula/core/datasets/mnist/mnist.py | 4 +-- nebula/core/engine.py | 5 ++-- nebula/core/neighbormanagement/nodemanager.py | 25 ++++++++++++++----- nebula/node.py | 2 +- nebula/scenarios.py | 8 +++--- 5 files changed, 29 insertions(+), 15 deletions(-) diff --git a/nebula/core/datasets/mnist/mnist.py b/nebula/core/datasets/mnist/mnist.py index d58f1ceab..db9a1fef3 100755 --- a/nebula/core/datasets/mnist/mnist.py +++ b/nebula/core/datasets/mnist/mnist.py @@ -85,8 +85,8 @@ def generate_non_iid_map(self, dataset, partition="dirichlet", partition_paramet self.plot_data_distribution(dataset, partitions_map) self.plot_all_data_distribution(dataset, partitions_map) - if self.additional: - self.plot_data_distribution_for_additional_node(dataset, partitions_map) + #if self.additional: + # self.plot_data_distribution_for_additional_node(dataset, partitions_map) return 
partitions_map[self.partition_id] diff --git a/nebula/core/engine.py b/nebula/core/engine.py index bcbf4d801..e3cbb3af1 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -395,7 +395,7 @@ async def _connection_late_connect_callback(self, source, message): await self.nm.confirmation_received(source, confirmation=True) elif self.nm.accept_connection(source, joining=True): logging.info(f"πŸ”— handle_connection_message | Late connection accepted | source: {source}") - self.nm.add_weight_modifier(source) + await self.nm.add_weight_modifier(source) await self.cm.connect(source, direct=True) # Verify conenction is accepted @@ -576,7 +576,8 @@ async def update_model_learning_rate(self): await self.trainning_in_progress_lock.acquire_async() if self.get_round() < self.total_rounds: logging.info("Update | learning rate modified...") - self.trainer.update_model_learning_rate(self.nm.get_learning_rate_increase()) + new_lr = await self.nm.get_learning_rate_increase() + self.trainer.update_model_learning_rate(new_lr) await self.trainning_in_progress_lock.release_async() async def _start_learning_late(self): diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 5d2683d25..565d6df2c 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -16,6 +16,9 @@ if TYPE_CHECKING: from nebula.core.engine import Engine +VANILLA_LEARNING_RATE = 1e-3 +FT_LEARNING_RATE = 2e-3 + class NodeManager(): def __init__( @@ -55,7 +58,8 @@ def __init__( self.synchronizing_rounds = False self._fast_reboot = True - self._learning_rate=2e-3 + self._learning_rate = VANILLA_LEARNING_RATE + self.learning_rate_lock = Locker(name="learning_rate_lock", async_lock=True) #self.set_confings() @@ -75,8 +79,11 @@ def candidate_selector(self): def model_handler(self): return self._model_handler - def get_learning_rate_increase(self): - return self._learning_rate + async def 
get_learning_rate_increase(self): + await self.learning_rate_lock.acquire_async() + lr = self._learning_rate + await self.learning_rate_lock.release_async() + return lr def fast_reboot_on(self): return self._fast_reboot @@ -97,6 +104,11 @@ def set_rounds_pushed(self, rp): with self.rounds_pushed_lock: self.rounds_pushed = rp + async def _set_learning_rate(self, lr): + await self.learning_rate_lock.acquire_async() + self._learning_rate = lr + await self.learning_rate_lock.release_async() + def still_waiting_for_candidates(self): return not self.accept_candidates_lock.locked() @@ -141,12 +153,13 @@ async def set_confings(self): # WEIGHT STRATEGY # ############################## - def add_weight_modifier(self, addr): + async def add_weight_modifier(self, addr): self.weight_modifier_lock.acquire() if not addr in self.weight_modifier: wm = self.new_node_weight_multiplier logging.info(f"πŸ“ Registering | Weight modifier registered for source {addr} | round: {self.engine.get_round()} | value: {wm}") self.weight_modifier[addr] = (wm,1) + await self._set_learning_rate(FT_LEARNING_RATE) self.weight_modifier_lock.release() def remove_weight_modifier(self, addr): @@ -188,9 +201,9 @@ async def _update_weight_modifiers(self): for a in remove_addrs: self.remove_weight_modifier(a) else: - if len(self.weight_modifier) == 0 and self._learning_rate == (2e-3): + if len(self.weight_modifier) == 0 and await self.get_learning_rate_increase() == FT_LEARNING_RATE: logging.info(f"πŸ”„ Finishing | weight strategy is completed") - self._learning_rate = 1e-3 + await self._set_learning_rate(VANILLA_LEARNING_RATE) await self.engine.update_model_learning_rate() self.weight_modifier_lock.release() diff --git a/nebula/node.py b/nebula/node.py index daea7b538..e73bbb067 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -354,7 +354,7 @@ def randomize_value(value, variability): # 615 r50 if config.participant["network_args"]["ip"] == "192.168.50.9": logging.info("Sleeping 385s...") - 
time.sleep(385) + time.sleep(50) elif config.participant["network_args"]["ip"] == "192.168.50.10": logging.info("Sleeping 800s...") time.sleep(615) diff --git a/nebula/scenarios.py b/nebula/scenarios.py index 86954f96e..3aeac3822 100644 --- a/nebula/scenarios.py +++ b/nebula/scenarios.py @@ -460,13 +460,13 @@ def load_configurations_and_start_nodes(self, additional_participants=None, sche is_start_node = False config_participants = [] #ap = len(additional_participants) if additional_participants else 0 - ap = len(additional_participants) if additional_participants else 0 - logging.info(f"######## nodes: {self.n_nodes} + additionals: {ap} ######") + additional_nodes = len(additional_participants) if additional_participants else 0 + logging.info(f"######## nodes: {self.n_nodes} + additionals: {additional_nodes} ######") for i in range(self.n_nodes): with open(f"{self.config_dir}/participant_" + str(i) + ".json") as f: participant_config = json.load(f) participant_config["scenario_args"]["federation"] = self.scenario.federation - participant_config["scenario_args"]["n_nodes"] = self.n_nodes + participant_config["scenario_args"]["n_nodes"] = self.n_nodes + additional_nodes participant_config["network_args"]["neighbors"] = self.topologymanager.get_neighbors_string(i) participant_config["scenario_args"]["name"] = self.scenario_name participant_config["scenario_args"]["start_time"] = self.start_date_scenario @@ -535,7 +535,7 @@ def load_configurations_and_start_nodes(self, additional_participants=None, sche logging.info(f"Configuration | additional nodes | participant: {self.n_nodes + i + 1}") last_ip = participant_config["network_args"]["ip"] logging.info(f"Valores de la ultima ip: ({last_ip})") - participant_config["scenario_args"]["n_nodes"] = self.n_nodes + i + 1 + participant_config["scenario_args"]["n_nodes"] = self.n_nodes + additional_nodes #self.n_nodes + i + 1 participant_config["device_args"]["idx"] = last_participant_index + i 
participant_config["network_args"]["neighbors"] = "" participant_config["network_args"]["ip"] = ( From d05833b44c16a441c39344adc80cba7b6885f90d Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 24 Dec 2024 09:59:59 +0100 Subject: [PATCH 037/233] fix_update --- nebula/core/engine.py | 10 ++++------ nebula/core/models/cifar10/cnn.py | 1 + nebula/core/neighbormanagement/nodemanager.py | 6 +++++- nebula/node.py | 2 +- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index e3cbb3af1..ca9d9ec4d 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -415,7 +415,8 @@ async def _connection_late_connect_callback(self, source, message): self.nm.meet_node(source) self.nm.update_neighbors(source) - await self.update_model_learning_rate() + if self.nm.fast_reboot_on: + await self.update_model_learning_rate() else: logging.info(f"❗️ Late connection NOT accepted | source: {source}") @@ -457,7 +458,7 @@ async def _discover_discover_join_callback(self, source, message): #self.nm.meet_node(source) if len(self.get_federation_nodes()) > 0: await self.trainning_in_progress_lock.acquire_async() - model, rounds, round = await self.cm.propagator.get_model_information(source, "stable") if self.get_round() > 0 else await self.cm.propagator.get_model_information(source, "initialization") + model, rounds, round = await self.cm.propagator.get_model_information(source, "initialization") if self.get_round() > 0 else await self.cm.propagator.get_model_information(source, "initialization") await self.trainning_in_progress_lock.release_async() if round != -1: epochs = self.config.participant["training_args"]["epochs"] @@ -523,7 +524,6 @@ async def _offer_offer_metric_callback(self, source, message): loss = message.loss self.nm.add_candidate(source, n_neighbors, loss) - @event_handler( nebula_pb2.LinkMessage, nebula_pb2.LinkMessage.Action.CONNECT_TO, @@ -557,8 +557,6 @@ async def _aditional_node_start(self): 
asyncio.create_task(self._start_learning_late()) #decoded_model = self.trainer.deserialize_model(message.parameters) - - def get_push_acceleration(self): return self.nm.get_push_acceleration() @@ -850,7 +848,7 @@ async def _additional_mobility_actions(self): if not self.mobility: return logging.info("πŸ”„ Starting additional mobility actions...") - self.trainer.show_current_learning_rate() + #self.trainer.show_current_learning_rate() await self.nm.check_robustness() action = await self.nm.check_external_connection_service_status() if action: diff --git a/nebula/core/models/cifar10/cnn.py b/nebula/core/models/cifar10/cnn.py index c3c883a76..473ff3b93 100755 --- a/nebula/core/models/cifar10/cnn.py +++ b/nebula/core/models/cifar10/cnn.py @@ -43,4 +43,5 @@ def configure_optimizers(self): betas=(self.config["beta1"], self.config["beta2"]), amsgrad=self.config["amsgrad"], ) + self._optimizer = optimizer return optimizer diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 565d6df2c..d01e85f76 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -57,7 +57,7 @@ def __init__( self.synchronizing_rounds = False - self._fast_reboot = True + self._fast_reboot = False self._learning_rate = VANILLA_LEARNING_RATE self.learning_rate_lock = Locker(name="learning_rate_lock", async_lock=True) @@ -154,6 +154,8 @@ async def set_confings(self): ############################## async def add_weight_modifier(self, addr): + if not self.fast_reboot_on(): + return self.weight_modifier_lock.acquire() if not addr in self.weight_modifier: wm = self.new_node_weight_multiplier @@ -168,6 +170,8 @@ def remove_weight_modifier(self, addr): del self.weight_modifier[addr] async def apply_weight_strategy(self, updates: dict): + if not self.fast_reboot_on(): + return logging.info(f"πŸ”„ Applying weight Strategy...") # We must lower the weight_modifier value if a round jump has been 
occured # as many times as rounds have been jumped diff --git a/nebula/node.py b/nebula/node.py index e73bbb067..92ed638dc 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -354,7 +354,7 @@ def randomize_value(value, variability): # 615 r50 if config.participant["network_args"]["ip"] == "192.168.50.9": logging.info("Sleeping 385s...") - time.sleep(50) + time.sleep(240) elif config.participant["network_args"]["ip"] == "192.168.50.10": logging.info("Sleeping 800s...") time.sleep(615) From 7d6dbf45abdd403ef9b0e8d7429a2cb793a006fe Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 24 Dec 2024 12:26:37 +0100 Subject: [PATCH 038/233] feat_defaultMH default model handler integrated refactor fastreboot --- nebula/core/aggregation/aggregator.py | 6 +- nebula/core/engine.py | 21 ++- nebula/core/neighbormanagement/fastreboot.py | 121 +++++++++++++++++ .../modelhandlers/defaultmodelhandler.py | 47 +++++++ .../modelhandlers/modelhandler.py | 4 +- .../modelhandlers/stdmodelhandler.py | 2 +- nebula/core/neighbormanagement/nodemanager.py | 122 +++++------------- nebula/node.py | 2 +- 8 files changed, 215 insertions(+), 110 deletions(-) create mode 100644 nebula/core/neighbormanagement/fastreboot.py create mode 100644 nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 1682dbc65..f48c587c0 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -288,7 +288,7 @@ async def aggregation_push_available(self): if push == "slow": logging.info(f"❗️ FUTURE round: {further_round} is available | PUSH strategy ON") logging.info("❗️ SLOW push selected | Start PUSHING slow") - self.engine.set_pushed_done(further_round - self.engine.get_round()) + await self.engine.set_pushed_done(further_round - self.engine.get_round()) # we wait until learning cycle reach aggregation point while not self._aggregation_done_lock.locked_async(): 
logging.info("πŸ”„ Waiting | aggregation step not reached yet...") @@ -306,7 +306,7 @@ async def aggregation_push_available(self): if further_round == (self.engine.get_round()+1): logging.info(f"πŸ”„ Rounds jumped: {1}...") - self.engine.set_pushed_done(further_round - self.engine.get_round()) + await self.engine.set_pushed_done(further_round - self.engine.get_round()) # we wait until learning cycle reach aggregation point while not self._aggregation_done_lock.locked_async(): logging.info("πŸ”„ Waiting | aggregation step not reached yet...") @@ -345,7 +345,7 @@ async def aggregation_push_available(self): self.engine.update_sinchronized_status(False) self.engine.set_synchronizing_rounds(True) - self.engine.set_pushed_done(further_round - self.engine.get_round()) + await self.engine.set_pushed_done(further_round - self.engine.get_round()) self.engine.set_round(further_round) # Unlock aggregation diff --git a/nebula/core/engine.py b/nebula/core/engine.py index ca9d9ec4d..5a7877bc0 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -158,7 +158,7 @@ def __init__( if self.mobility == True: topology = self.config.participant["mobility_args"]["topology_type"] topology = topology.lower() - model_handler = "std" #self.config.participant["mobility_args"]["model_handler"] + model_handler = "default" #self.config.participant["mobility_args"]["model_handler"] acceleration_push = "slow" #self.config.participant["mobility_args"]["push_strategy"] self._node_manager = NodeManager(topology, model_handler, acceleration_push, engine=self) @@ -395,7 +395,6 @@ async def _connection_late_connect_callback(self, source, message): await self.nm.confirmation_received(source, confirmation=True) elif self.nm.accept_connection(source, joining=True): logging.info(f"πŸ”— handle_connection_message | Late connection accepted | source: {source}") - await self.nm.add_weight_modifier(source) await self.cm.connect(source, direct=True) # Verify conenction is accepted @@ -413,10 +412,8 @@ 
async def _connection_late_connect_callback(self, source, message): df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, df_actions) await self.cm.send_message(source, df_msg) - self.nm.meet_node(source) - self.nm.update_neighbors(source) - if self.nm.fast_reboot_on: - await self.update_model_learning_rate() + self.nm.register_late_neighbor(source, joinning_federation=True) + else: logging.info(f"❗️ Late connection NOT accepted | source: {source}") @@ -445,8 +442,8 @@ async def _connection_restructure_callback(self, source, message): df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, df_actions) await self.cm.send_message(source, df_msg) - self.nm.meet_node(source) - self.nm.update_neighbors(source) + self.nm.register_late_neighbor(source, joinning_federation=False) + else: logging.info(f"❗️ handle_connection_message | Trigger | restructure connection denied from {source}") await asyncio.sleep(1) @@ -458,7 +455,7 @@ async def _discover_discover_join_callback(self, source, message): #self.nm.meet_node(source) if len(self.get_federation_nodes()) > 0: await self.trainning_in_progress_lock.acquire_async() - model, rounds, round = await self.cm.propagator.get_model_information(source, "initialization") if self.get_round() > 0 else await self.cm.propagator.get_model_information(source, "initialization") + model, rounds, round = await self.cm.propagator.get_model_information(source, "stable") if self.get_round() > 0 else await self.cm.propagator.get_model_information(source, "initialization") await self.trainning_in_progress_lock.release_async() if round != -1: epochs = self.config.participant["training_args"]["epochs"] @@ -560,8 +557,8 @@ async def _aditional_node_start(self): def get_push_acceleration(self): return self.nm.get_push_acceleration() - def set_pushed_done(self, rounds_push): - self.nm.set_rounds_pushed(rounds_push) + async def set_pushed_done(self, rounds_push): + await 
self.nm.set_rounds_pushed(rounds_push) async def apply_weight_strategy(self, pending_models): if self.mobility and self.nm.fast_reboot_on(): @@ -581,7 +578,7 @@ async def update_model_learning_rate(self): async def _start_learning_late(self): await self.learning_cycle_lock.acquire_async() try: - model_serialized, rounds, round, _epochs = self.nm.get_trainning_info() + model_serialized, rounds, round, _epochs = await self.nm.get_trainning_info() self.total_rounds = rounds # self.config.participant["scenario_args"]["rounds"] #rounds epochs = _epochs # self.config.participant["training_args"]["epochs"] #_epochs await self.get_round_lock().acquire_async() diff --git a/nebula/core/neighbormanagement/fastreboot.py b/nebula/core/neighbormanagement/fastreboot.py new file mode 100644 index 000000000..ad9cb3af2 --- /dev/null +++ b/nebula/core/neighbormanagement/fastreboot.py @@ -0,0 +1,121 @@ +import asyncio +import logging + +from nebula.core.utils.locker import Locker +from nebula.core.neighbormanagement.nodemanager import NodeManager + +VANILLA_LEARNING_RATE = 1e-3 +FR_LEARNING_RATE = 2e-3 +MAX_ROUNDS = 20 +DEFAULT_WEIGHT_MODIFIER = 3 + +class FastReboot(): + + def __init__( + self, + node_manager : NodeManager, + max_rounds_application = MAX_ROUNDS, # Max rounds to be applied FastReboot + weight_modifier = DEFAULT_WEIGHT_MODIFIER, + default_learning_rate = VANILLA_LEARNING_RATE, # Stable value for learning rate + upgrade_learning_rate = FR_LEARNING_RATE, # Increased value for learning rate + ): + self._node_manager = node_manager + self._max_rounds = max_rounds_application + self._weight_modifier = weight_modifier + self._default_lr = default_learning_rate + self._upgrade_lr = upgrade_learning_rate + self._current_lr = default_learning_rate + self._learning_rate_lock = Locker(name="learning_rate_lock", async_lock=True) + + self._weight_modifier = {} + self._weight_modifier_lock = Locker(name="weight_modifier_lock", async_lock=True) + self._rounds_pushed_lock = 
Locker(name="rounds_pushed_lock", async_lock=True) + self._rounds_pushed = 0 + + + + @property + def nm(self): + return self._node_manager + + async def set_rounds_pushed(self, rp): + await self._rounds_pushed_lock.acquire_async() + self.rounds_pushed = rp + await self._rounds_pushed_lock.release_async() + + async def get_learning_rate_increase(self): + await self._learning_rate_lock.acquire_async() + lr = self._current_lr + await self._learning_rate_lock.release_async() + return lr + + async def _set_learning_rate(self, lr): + await self._learning_rate_lock.acquire_async() + self._current_lr = lr + await self._learning_rate_lock.release_async() + + async def add_fastReboot_addr(self, addr): + self._weight_modifier_lock.acquire() + if not addr in self._weight_modifier: + wm = self._weight_modifier + logging.info(f"πŸ“ Registering | FastReboot registered for source {addr} | round application: {self._max_rounds} | multiplier value: {wm}") + self._weight_modifier[addr] = (wm,0) + await self._set_learning_rate(self._upgrade_lr) + self._weight_modifier_lock.release() + + async def _remove_weight_modifier(self, addr): + if addr in self._weight_modifier: + logging.info(f"πŸ“ Removing | FastReboot registered for source {addr}") + del self._weight_modifier[addr] + + async def apply_weight_strategy(self, updates: dict): + logging.info(f"πŸ”„ Applying FastReboot Strategy...") + # We must lower the weight_modifier value if a round jump has been occured + # as many times as rounds have been jumped + if self.rounds_pushed: + logging.info(f"πŸ”„ There are rounds being pushed...") + for i in range(0, self.rounds_pushed): + logging.info(f"πŸ”„ Update | weights being updated cause of push...") + self._update_weight_modifiers() + self.rounds_pushed = 0 + for addr,update in updates.items(): + weightmodifier, rounds = self._get_weight_modifier(addr) + if weightmodifier != 1: + logging.info (f"πŸ“ Appliying FastReboot strategy | addr: {addr} | multiplier value: {weightmodifier}, 
rounds applied: {rounds}") + model, weight = update + updates.update({addr: (model, weight*weightmodifier)}) + await self._update_weight_modifiers() + + async def _update_weight_modifiers(self): + self._weight_modifier_lock.acquire_async() + logging.info(f"πŸ”„ Update | weights being updated") + if self._weight_modifier: + remove_addrs = [] + for addr, (weight,rounds) in self._weight_modifier.items(): + new_weight = weight - 1/(rounds**2) + rounds = rounds + 1 + if new_weight > 1 and rounds <= self._max_rounds: + self._weight_modifier[addr] = (new_weight, rounds) + else: + remove_addrs.append(addr) + #self.remove_weight_modifier(addr) + for a in remove_addrs: + self._remove_weight_modifier(a) + else: + if not self._weight_modifier and await self._is_lr_modified(): + logging.info(f"πŸ”„ Finishing | FastReboot is completed") + await self._set_learning_rate(self._default_lr) + await self.nm.update_learning_rate() + self._weight_modifier_lock.release_async() + + async def _get_weight_modifier(self, addr): + self._weight_modifier_lock.acquire_async() + wm = self._weight_modifier.get(addr, (1,0)) + self._weight_modifier_lock.release_async() + return wm + + async def _is_lr_modified(self): + await self._learning_rate_lock.acquire_async() + mod = self._current_lr == self._upgrade_lr + await self._learning_rate_lock.release_async() + return mod \ No newline at end of file diff --git a/nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py b/nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py new file mode 100644 index 000000000..25f698756 --- /dev/null +++ b/nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py @@ -0,0 +1,47 @@ +from nebula.core.neighbormanagement.modelhandlers.modelhandler import ModelHandler +from nebula.core.utils.locker import Locker +from nebula.core.neighbormanagement.nodemanager import NodeManager +import logging + +class DefaultModelHandler(ModelHandler): + + def __init__(self): + self.model = None + 
self.rounds = 0 + self.round = 0 + self.epochs = 0 + self.model_lock = Locker(name="model_lock") + self.params_lock = Locker(name="param_lock") + self._nm : NodeManager = None + + def set_config(self, config): + """ + Args: + config[0] -> total rounds + config[1] -> current round + config[2] -> epochs + """ + self.params_lock.acquire() + self.rounds = config[0] + if config[1] > self.round: + self.round = config[1] + self.epochs = config[2] + if not self._nm: + self._nm = config[3] + self.params_lock.release() + + def accept_model(self, model): + return True + + async def get_model(self, model): + """ + Returns: + model with default weights + """ + return await self._nm.engine.cm.propagator.get_model_information(None, "initialization") + + def pre_process_model(self): + """ + no pre-processing defined + """ + pass \ No newline at end of file diff --git a/nebula/core/neighbormanagement/modelhandlers/modelhandler.py b/nebula/core/neighbormanagement/modelhandlers/modelhandler.py index 0b418992a..1d4af4e5a 100644 --- a/nebula/core/neighbormanagement/modelhandlers/modelhandler.py +++ b/nebula/core/neighbormanagement/modelhandlers/modelhandler.py @@ -12,7 +12,7 @@ def accept_model(self, model): pass @abstractmethod - def get_model(self, model): + async def get_model(self, model): pass @abstractmethod @@ -22,9 +22,11 @@ def pre_process_model(self): def factory_ModelHandler(model_handler) -> ModelHandler: from nebula.core.neighbormanagement.modelhandlers.stdmodelhandler import STDModelHandler from nebula.core.neighbormanagement.modelhandlers.aggmodelhandler import AGGModelHandler + from nebula.core.neighbormanagement.modelhandlers.defaultmodelhandler import DefaultModelHandler options = { "std": STDModelHandler, + "default": DefaultModelHandler, "aggregator": AGGModelHandler, } diff --git a/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py b/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py index b23f62726..097f5fc0a 100644 --- 
a/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py +++ b/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py @@ -36,7 +36,7 @@ def accept_model(self, model): self.model = model return True - def get_model(self, model): + async def get_model(self, model): """ Returns: neccesary data to create trainer diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index d01e85f76..94ee654e2 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -1,8 +1,6 @@ import asyncio import logging import os -import asyncio -import threading from nebula.core.utils.locker import Locker from nebula.core.neighbormanagement.candidateselection.candidateselector import factory_CandidateSelector @@ -10,15 +8,13 @@ from nebula.core.neighbormanagement.neighborpolicies.neighborpolicy import factory_NeighborPolicy from nebula.core.pb import nebula_pb2 from nebula.core.network.communications import CommunicationsManager +from nebula.core.neighbormanagement.fastreboot import FastReboot from nebula.addons.functions import print_msg_box from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.engine import Engine -VANILLA_LEARNING_RATE = 1e-3 -FT_LEARNING_RATE = 2e-3 - class NodeManager(): def __init__( @@ -26,7 +22,8 @@ def __init__( topology, model_handler, push_acceleration, - engine : "Engine" + engine : "Engine", + fastreboot = False ): self.topology = topology print_msg_box(msg=f"Starting NodeManager module...\nTopology: {self.topology}", indent=2, title="NodeManager module") @@ -42,9 +39,6 @@ def __init__( self.late_connection_process_lock = Locker(name="late_connection_process_lock") self.pending_confirmation_from_nodes = [] self.pending_confirmation_from_nodes_lock = Locker(name="pending_confirmation_from_nodes_lock") - self.weight_modifier = {} - self.weight_modifier_lock = Locker(name="weight_modifier_lock") - self.new_node_weight_multiplier = 3 
self.accept_candidates_lock = Locker(name="accept_candidates_lock") self.recieve_offer_timer = 5 self._restructure_process_lock = Locker(name="restructure_process_lock") @@ -52,15 +46,13 @@ def __init__( self.discarded_offers_addr_lock = Locker(name="discarded_offers_addr_lock") self.discarded_offers_addr = [] self._push_acceleration = push_acceleration - self.rounds_pushed_lock = Locker(name="rounds_pushed_lock") - self.rounds_pushed = 0 self.synchronizing_rounds = False - self._fast_reboot = False - self._learning_rate = VANILLA_LEARNING_RATE - self.learning_rate_lock = Locker(name="learning_rate_lock", async_lock=True) - + self._fast_reboot_status = fastreboot + if (fastreboot): + self._fastreboot = FastReboot(self) + #self.set_confings() @property @@ -79,14 +71,12 @@ def candidate_selector(self): def model_handler(self): return self._model_handler - async def get_learning_rate_increase(self): - await self.learning_rate_lock.acquire_async() - lr = self._learning_rate - await self.learning_rate_lock.release_async() - return lr + @property + def fr(self): + return self._fastreboot def fast_reboot_on(self): - return self._fast_reboot + return self._fast_reboot_status def get_push_acceleration(self): return self._push_acceleration @@ -100,14 +90,9 @@ def set_synchronizing_rounds(self, status): def get_syncrhonizing_rounds(self): return self.synchronizing_rounds - def set_rounds_pushed(self, rp): - with self.rounds_pushed_lock: - self.rounds_pushed = rp - - async def _set_learning_rate(self, lr): - await self.learning_rate_lock.acquire_async() - self._learning_rate = lr - await self.learning_rate_lock.release_async() + async def set_rounds_pushed(self, rp): + if self.fast_reboot_on(): + self.fr.set_rounds_pushed(rp) def still_waiting_for_candidates(self): return not self.accept_candidates_lock.locked() @@ -150,72 +135,25 @@ async def set_confings(self): ############################## - # WEIGHT STRATEGY # + # FAST REBOOT # ############################## - - async def 
add_weight_modifier(self, addr): - if not self.fast_reboot_on(): - return - self.weight_modifier_lock.acquire() - if not addr in self.weight_modifier: - wm = self.new_node_weight_multiplier - logging.info(f"πŸ“ Registering | Weight modifier registered for source {addr} | round: {self.engine.get_round()} | value: {wm}") - self.weight_modifier[addr] = (wm,1) - await self._set_learning_rate(FT_LEARNING_RATE) - self.weight_modifier_lock.release() - - def remove_weight_modifier(self, addr): - if addr in self.weight_modifier: - logging.info(f"πŸ“ Removing | weight modifier registered for source {addr}") - del self.weight_modifier[addr] - + + + async def update_learning_rate(self, new_lr): + await self.engine.update_model_learning_rate(new_lr) + + async def register_late_neighbor(self, addr, joinning_federation=False): + self.meet_node(addr) + self.update_neighbors(addr) + if joinning_federation: + if self.fast_reboot_on(): + self.fr.add_fastReboot_addr(addr) + async def apply_weight_strategy(self, updates: dict): if not self.fast_reboot_on(): return - logging.info(f"πŸ”„ Applying weight Strategy...") - # We must lower the weight_modifier value if a round jump has been occured - # as many times as rounds have been jumped - if self.rounds_pushed: - logging.info(f"πŸ”„ There are rounds being pushed...") - for i in range(0, self.rounds_pushed): - logging.info(f"πŸ”„ Update | weights being updated cause of push...") - self._update_weight_modifiers() - self.rounds_pushed = 0 - for addr,update in updates.items(): - weightmodifier, rounds = self._get_weight_modifier(addr) - if weightmodifier != 1: - logging.info (f"πŸ“ Appliying modified weight strategy | addr: {addr} | multiplier value: {weightmodifier}, rounds applied: {rounds}") - model, weight = update - updates.update({addr: (model, weight*weightmodifier)}) - await self._update_weight_modifiers() + await self.fr.apply_weight_strategy(updates) - async def _update_weight_modifiers(self): - 
self.weight_modifier_lock.acquire() - logging.info(f"πŸ”„ Update | weights being updated") - if self.weight_modifier: - remove_addrs = [] - for addr, (weight,rounds) in self.weight_modifier.items(): - new_weight = weight - 1/(rounds**2) - rounds = rounds + 1 - if new_weight > 1 and rounds <= 20: - self.weight_modifier[addr] = (new_weight, rounds) - else: - remove_addrs.append(addr) - #self.remove_weight_modifier(addr) - for a in remove_addrs: - self.remove_weight_modifier(a) - else: - if len(self.weight_modifier) == 0 and await self.get_learning_rate_increase() == FT_LEARNING_RATE: - logging.info(f"πŸ”„ Finishing | weight strategy is completed") - await self._set_learning_rate(VANILLA_LEARNING_RATE) - await self.engine.update_model_learning_rate() - self.weight_modifier_lock.release() - - def _get_weight_modifier(self, addr): - self.weight_modifier_lock.acquire() - wm = self.weight_modifier.get(addr, (1,0)) - self.weight_modifier_lock.release() - return wm ############################## @@ -297,8 +235,8 @@ def accept_model_offer(self, source, decoded_model, rounds, round, epochs, n_nei else: return False - def get_trainning_info(self): - return self.model_handler.get_model(None) + async def get_trainning_info(self): + return await self.model_handler.get_model(None) def add_candidate(self,source, n_neighbors, loss): if not self.accept_candidates_lock.locked(): diff --git a/nebula/node.py b/nebula/node.py index 92ed638dc..02db9723f 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -354,7 +354,7 @@ def randomize_value(value, variability): # 615 r50 if config.participant["network_args"]["ip"] == "192.168.50.9": logging.info("Sleeping 385s...") - time.sleep(240) + time.sleep(420) elif config.participant["network_args"]["ip"] == "192.168.50.10": logging.info("Sleeping 800s...") time.sleep(615) From 111568c4e7086d5e296fd2c440f85777b3bd58b9 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 24 Dec 2024 14:05:31 +0100 Subject: [PATCH 039/233] fix_error_defaultHM 
--- nebula/core/neighbormanagement/fastreboot.py | 13 +++++++------ .../modelhandlers/defaultmodelhandler.py | 3 ++- nebula/core/neighbormanagement/nodemanager.py | 7 ++++--- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/nebula/core/neighbormanagement/fastreboot.py b/nebula/core/neighbormanagement/fastreboot.py index ad9cb3af2..76930e819 100644 --- a/nebula/core/neighbormanagement/fastreboot.py +++ b/nebula/core/neighbormanagement/fastreboot.py @@ -2,7 +2,9 @@ import logging from nebula.core.utils.locker import Locker -from nebula.core.neighbormanagement.nodemanager import NodeManager +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from nebula.core.neighbormanagement.nodemanager import NodeManager VANILLA_LEARNING_RATE = 1e-3 FR_LEARNING_RATE = 2e-3 @@ -13,7 +15,7 @@ class FastReboot(): def __init__( self, - node_manager : NodeManager, + node_manager : "NodeManager", max_rounds_application = MAX_ROUNDS, # Max rounds to be applied FastReboot weight_modifier = DEFAULT_WEIGHT_MODIFIER, default_learning_rate = VANILLA_LEARNING_RATE, # Stable value for learning rate @@ -31,9 +33,7 @@ def __init__( self._weight_modifier_lock = Locker(name="weight_modifier_lock", async_lock=True) self._rounds_pushed_lock = Locker(name="rounds_pushed_lock", async_lock=True) self._rounds_pushed = 0 - - - + @property def nm(self): return self._node_manager @@ -43,7 +43,7 @@ async def set_rounds_pushed(self, rp): self.rounds_pushed = rp await self._rounds_pushed_lock.release_async() - async def get_learning_rate_increase(self): + async def get_current_learning_rate(self): await self._learning_rate_lock.acquire_async() lr = self._current_lr await self._learning_rate_lock.release_async() @@ -61,6 +61,7 @@ async def add_fastReboot_addr(self, addr): logging.info(f"πŸ“ Registering | FastReboot registered for source {addr} | round application: {self._max_rounds} | multiplier value: {wm}") self._weight_modifier[addr] = (wm,0) await self._set_learning_rate(self._upgrade_lr) 
+ await self.nm.update_learning_rate(await self.get_current_learning_rate()) self._weight_modifier_lock.release() async def _remove_weight_modifier(self, addr): diff --git a/nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py b/nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py index 25f698756..14440b0d9 100644 --- a/nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py +++ b/nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py @@ -38,7 +38,8 @@ async def get_model(self, model): Returns: model with default weights """ - return await self._nm.engine.cm.propagator.get_model_information(None, "initialization") + (sm, rounds, round) = await self._nm.engine.cm.propagator.get_model_information(None, "initialization") + return (sm, rounds, round, self.epochs) def pre_process_model(self): """ diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 94ee654e2..6da6e0c93 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -23,7 +23,7 @@ def __init__( model_handler, push_acceleration, engine : "Engine", - fastreboot = False + fastreboot=False, ): self.topology = topology print_msg_box(msg=f"Starting NodeManager module...\nTopology: {self.topology}", indent=2, title="NodeManager module") @@ -210,7 +210,8 @@ def update_neighbors(self, node, remove=False): self.neighbor_policy.update_neighbors(node, remove) #self.timer_generator.update_node(node, remove) if remove: - self.remove_weight_modifier(node) + pass #TODO + #self.remove_weight_modifier(node) if not remove: self.neighbor_policy.meet_node(node) @@ -228,7 +229,7 @@ def accept_model_offer(self, source, decoded_model, rounds, round, epochs, n_nei if not self.accept_candidates_lock.locked(): logging.info(f"πŸ”„ Processing offer from {source}...") model_accepted = self.model_handler.accept_model(decoded_model) - 
self.model_handler.set_config(config=(rounds, round, epochs)) + self.model_handler.set_config(config=(rounds, round, epochs, self)) if model_accepted: self.candidate_selector.add_candidate((source, n_neighbors, loss)) return True From 882e4eb84a842ba9e4806b822726ed1b7e4e503f Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 24 Dec 2024 17:48:54 +0100 Subject: [PATCH 040/233] fix_wrong_payload --- .../modelhandlers/defaultmodelhandler.py | 2 +- nebula/core/network/propagator.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py b/nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py index 14440b0d9..e209c0aa1 100644 --- a/nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py +++ b/nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py @@ -38,7 +38,7 @@ async def get_model(self, model): Returns: model with default weights """ - (sm, rounds, round) = await self._nm.engine.cm.propagator.get_model_information(None, "initialization") + (sm, rounds, round) = await self._nm.engine.cm.propagator.get_model_information(None, "initialization", init=True) return (sm, rounds, round, self.epochs) def pre_process_model(self): diff --git a/nebula/core/network/propagator.py b/nebula/core/network/propagator.py index f1499a530..f156090f2 100755 --- a/nebula/core/network/propagator.py +++ b/nebula/core/network/propagator.py @@ -169,13 +169,14 @@ async def propagate(self, strategy_id: str): await asyncio.sleep(self.interval) return True - async def get_model_information(self, dest_addr, strategy_id: str): - if strategy_id not in self.strategies: - logging.info(f"Strategy {strategy_id} not found.") - return None - if self.get_round() is None: - logging.info("Propagation halted: round is not set.") - return None + async def get_model_information(self, dest_addr, strategy_id: str, init=False): + if not init: + if strategy_id not in 
self.strategies: + logging.info(f"Strategy {strategy_id} not found.") + return None + if self.get_round() is None: + logging.info("Propagation halted: round is not set.") + return None strategy = self.strategies[strategy_id] logging.info(f"Preparing model information with strategy to make an offer: {strategy_id}") From ba6a397c58d3a0f8062e0330b762f4e761f8b274 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 25 Dec 2024 14:22:15 +0100 Subject: [PATCH 041/233] fix_keyerror_np --- nebula/core/engine.py | 5 +++-- .../neighbormanagement/modelhandlers/defaultmodelhandler.py | 5 +++-- .../neighbormanagement/neighborpolicies/fcneighborpolicy.py | 5 ++++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 5a7877bc0..e24ae908b 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -158,7 +158,7 @@ def __init__( if self.mobility == True: topology = self.config.participant["mobility_args"]["topology_type"] topology = topology.lower() - model_handler = "default" #self.config.participant["mobility_args"]["model_handler"] + model_handler = "std" #self.config.participant["mobility_args"]["model_handler"] acceleration_push = "slow" #self.config.participant["mobility_args"]["push_strategy"] self._node_manager = NodeManager(topology, model_handler, acceleration_push, engine=self) @@ -327,7 +327,6 @@ async def _connection_disconnect_callback(self, source, message): self.nm.update_neighbors(source, remove=True) await self.cm.disconnect(source, mutual_disconnection=False) - @event_handler( nebula_pb2.FederationMessage, nebula_pb2.FederationMessage.Action.FEDERATION_READY, @@ -383,6 +382,8 @@ async def _federation_models_included_callback(self, source, message): finally: await self.cm.get_connections_lock().release_async() + + # Mobility callbacks @event_handler( nebula_pb2.ConnectionMessage, diff --git a/nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py 
b/nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py index e209c0aa1..f8a90c62a 100644 --- a/nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py +++ b/nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py @@ -20,6 +20,7 @@ def set_config(self, config): config[0] -> total rounds config[1] -> current round config[2] -> epochs + config[3] -> NodeManager """ self.params_lock.acquire() self.rounds = config[0] @@ -38,8 +39,8 @@ async def get_model(self, model): Returns: model with default weights """ - (sm, rounds, round) = await self._nm.engine.cm.propagator.get_model_information(None, "initialization", init=True) - return (sm, rounds, round, self.epochs) + (sm, _, _) = await self._nm.engine.cm.propagator.get_model_information(None, "initialization", init=True) + return (sm, self.rounds, self.round, self.epochs) def pre_process_model(self): """ diff --git a/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py b/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py index b56dedd47..63bf746c5 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py +++ b/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py @@ -89,7 +89,10 @@ def _connect_to(self): def update_neighbors(self, node, remove=False): self.neighbors_lock.acquire() if remove: - self.neighbors.remove(node) + try: + self.neighbors.remove(node) + except KeyError: + pass else: self.neighbors.add(node) self.neighbors_lock.release() \ No newline at end of file From b70cc4167d9abbe78ecd79df8d93512d2eeb6c61 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 27 Dec 2024 18:59:35 +0100 Subject: [PATCH 042/233] fix_general_errors --- nebula/core/engine.py | 9 ++- nebula/core/neighbormanagement/fastreboot.py | 64 +++++++++++-------- nebula/core/neighbormanagement/nodemanager.py | 8 ++- nebula/node.py | 16 +++-- 4 files changed, 57 insertions(+), 40 deletions(-) diff --git a/nebula/core/engine.py 
b/nebula/core/engine.py index e24ae908b..61b63e6a0 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -413,7 +413,7 @@ async def _connection_late_connect_callback(self, source, message): df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, df_actions) await self.cm.send_message(source, df_msg) - self.nm.register_late_neighbor(source, joinning_federation=True) + await self.nm.register_late_neighbor(source, joinning_federation=True) else: logging.info(f"❗️ Late connection NOT accepted | source: {source}") @@ -443,7 +443,7 @@ async def _connection_restructure_callback(self, source, message): df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, df_actions) await self.cm.send_message(source, df_msg) - self.nm.register_late_neighbor(source, joinning_federation=False) + await self.nm.register_late_neighbor(source, joinning_federation=False) else: logging.info(f"❗️ handle_connection_message | Trigger | restructure connection denied from {source}") @@ -550,7 +550,7 @@ async def _aditional_node_start(self): logging.info(f"Aditional node | {self.addr} | going to stablish connection with federation") await self.nm.start_late_connection_process() # continue .. 
- asyncio.create_task(self.nm.stop_not_selected_connections()) + #asyncio.create_task(self.nm.stop_not_selected_connections()) logging.info("Creating trainer service to start the federation process..") asyncio.create_task(self._start_learning_late()) #decoded_model = self.trainer.deserialize_model(message.parameters) @@ -568,11 +568,10 @@ async def apply_weight_strategy(self, pending_models): else: return pending_models - async def update_model_learning_rate(self): + async def update_model_learning_rate(self, new_lr): await self.trainning_in_progress_lock.acquire_async() if self.get_round() < self.total_rounds: logging.info("Update | learning rate modified...") - new_lr = await self.nm.get_learning_rate_increase() self.trainer.update_model_learning_rate(new_lr) await self.trainning_in_progress_lock.release_async() diff --git a/nebula/core/neighbormanagement/fastreboot.py b/nebula/core/neighbormanagement/fastreboot.py index 76930e819..8ec43001a 100644 --- a/nebula/core/neighbormanagement/fastreboot.py +++ b/nebula/core/neighbormanagement/fastreboot.py @@ -7,7 +7,7 @@ from nebula.core.neighbormanagement.nodemanager import NodeManager VANILLA_LEARNING_RATE = 1e-3 -FR_LEARNING_RATE = 2e-3 +FR_LEARNING_RATE = 1e-3#2e-3 MAX_ROUNDS = 20 DEFAULT_WEIGHT_MODIFIER = 3 @@ -23,7 +23,7 @@ def __init__( ): self._node_manager = node_manager self._max_rounds = max_rounds_application - self._weight_modifier = weight_modifier + self._weight_mod_value = weight_modifier self._default_lr = default_learning_rate self._upgrade_lr = upgrade_learning_rate self._current_lr = default_learning_rate @@ -55,42 +55,52 @@ async def _set_learning_rate(self, lr): await self._learning_rate_lock.release_async() async def add_fastReboot_addr(self, addr): - self._weight_modifier_lock.acquire() + await self._weight_modifier_lock.acquire_async() if not addr in self._weight_modifier: - wm = self._weight_modifier + wm = self._weight_mod_value logging.info(f"πŸ“ Registering | FastReboot registered for 
source {addr} | round application: {self._max_rounds} | multiplier value: {wm}") - self._weight_modifier[addr] = (wm,0) + self._weight_modifier[addr] = (wm,1) await self._set_learning_rate(self._upgrade_lr) - await self.nm.update_learning_rate(await self.get_current_learning_rate()) - self._weight_modifier_lock.release() + current_lr = await self.get_current_learning_rate() + await self.nm.update_learning_rate(current_lr) + await self._weight_modifier_lock.release_async() async def _remove_weight_modifier(self, addr): - if addr in self._weight_modifier: - logging.info(f"πŸ“ Removing | FastReboot registered for source {addr}") - del self._weight_modifier[addr] + logging.info(f"πŸ“ Removing | FastReboot removed for source {addr}") + del self._weight_modifier[addr] + + async def _weight_modifiers_empty(self): + await self._weight_modifier_lock.acquire_async() + empty = False if self._weight_modifier else True + await self._weight_modifier_lock.release_async() + return empty - async def apply_weight_strategy(self, updates: dict): + async def apply_weight_strategy(self, updates: dict): + if await self._weight_modifiers_empty(): + await self._end_fastreboot() + return logging.info(f"πŸ”„ Applying FastReboot Strategy...") # We must lower the weight_modifier value if a round jump has been occured # as many times as rounds have been jumped - if self.rounds_pushed: + if self._rounds_pushed: logging.info(f"πŸ”„ There are rounds being pushed...") for i in range(0, self.rounds_pushed): logging.info(f"πŸ”„ Update | weights being updated cause of push...") self._update_weight_modifiers() self.rounds_pushed = 0 for addr,update in updates.items(): - weightmodifier, rounds = self._get_weight_modifier(addr) + weightmodifier, rounds = await self._get_weight_modifier(addr) if weightmodifier != 1: logging.info (f"πŸ“ Appliying FastReboot strategy | addr: {addr} | multiplier value: {weightmodifier}, rounds applied: {rounds}") model, weight = update updates.update({addr: (model, 
weight*weightmodifier)}) await self._update_weight_modifiers() - + + #TODO integrar en el get_wegith_modifier para que se actualice cuando se pide y ahi se compruebe si hay q eliminar una entrada async def _update_weight_modifiers(self): - self._weight_modifier_lock.acquire_async() - logging.info(f"πŸ”„ Update | weights being updated") + await self._weight_modifier_lock.acquire_async() if self._weight_modifier: + logging.info(f"πŸ”„ Update | weights being updated") remove_addrs = [] for addr, (weight,rounds) in self._weight_modifier.items(): new_weight = weight - 1/(rounds**2) @@ -99,20 +109,22 @@ async def _update_weight_modifiers(self): self._weight_modifier[addr] = (new_weight, rounds) else: remove_addrs.append(addr) - #self.remove_weight_modifier(addr) for a in remove_addrs: - self._remove_weight_modifier(a) - else: - if not self._weight_modifier and await self._is_lr_modified(): - logging.info(f"πŸ”„ Finishing | FastReboot is completed") - await self._set_learning_rate(self._default_lr) - await self.nm.update_learning_rate() - self._weight_modifier_lock.release_async() + await self._remove_weight_modifier(a) + await self._weight_modifier_lock.release_async() + + async def _end_fastreboot(self): + await self._weight_modifier_lock.acquire_async() + if not self._weight_modifier and await self._is_lr_modified(): + logging.info(f"πŸ”„ Finishing | FastReboot is completed") + await self._set_learning_rate(self._default_lr) + await self.nm.update_learning_rate(self._default_lr) + await self._weight_modifier_lock.release_async() async def _get_weight_modifier(self, addr): - self._weight_modifier_lock.acquire_async() + await self._weight_modifier_lock.acquire_async() wm = self._weight_modifier.get(addr, (1,0)) - self._weight_modifier_lock.release_async() + await self._weight_modifier_lock.release_async() return wm async def _is_lr_modified(self): diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 
6da6e0c93..716eeee0a 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -23,7 +23,7 @@ def __init__( model_handler, push_acceleration, engine : "Engine", - fastreboot=False, + fastreboot=True, ): self.topology = topology print_msg_box(msg=f"Starting NodeManager module...\nTopology: {self.topology}", indent=2, title="NodeManager module") @@ -147,7 +147,7 @@ async def register_late_neighbor(self, addr, joinning_federation=False): self.update_neighbors(addr) if joinning_federation: if self.fast_reboot_on(): - self.fr.add_fastReboot_addr(addr) + await self.fr.add_fastReboot_addr(addr) async def apply_weight_strategy(self, updates: dict): if not self.fast_reboot_on(): @@ -228,7 +228,9 @@ def get_nodes_known(self, neighbors_too=False): def accept_model_offer(self, source, decoded_model, rounds, round, epochs, n_neighbors, loss): if not self.accept_candidates_lock.locked(): logging.info(f"πŸ”„ Processing offer from {source}...") - model_accepted = self.model_handler.accept_model(decoded_model) + model_accepted = True#self.model_handler.accept_model(decoded_model) + if source == "192.168.50.8:45007": + self.model_handler.accept_model(decoded_model) self.model_handler.set_config(config=(rounds, round, epochs, self)) if model_accepted: self.candidate_selector.add_candidate((source, n_neighbors, loss)) diff --git a/nebula/node.py b/nebula/node.py index 02db9723f..75ed0d10b 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -350,14 +350,18 @@ def randomize_value(value, variability): if additional_node_status: logging.info(f"Waiting for round {additional_node_round} to start") logging.info("Waiting time to start finding federation") + # MNIST # 385 r30 # 615 r50 - if config.participant["network_args"]["ip"] == "192.168.50.9": - logging.info("Sleeping 385s...") - time.sleep(420) - elif config.participant["network_args"]["ip"] == "192.168.50.10": - logging.info("Sleeping 800s...") - time.sleep(615) + + # CIFAR + # 
420 r15 + #if config.participant["network_args"]["ip"] == "192.168.50.9": + # logging.info("Sleeping 385s...") + time.sleep(385) + #elif config.participant["network_args"]["ip"] == "192.168.50.10": + # logging.info("Sleeping 800s...") + # time.sleep(615) #time.sleep(6000) # DEBUG purposes #import requests From 82bc4bf57266d15fb73899ff0beccb89efae1fa5 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sat, 28 Dec 2024 18:57:52 +0100 Subject: [PATCH 043/233] fix_reestructure_loop --- nebula/core/neighbormanagement/fastreboot.py | 15 +++++- nebula/core/neighbormanagement/nodemanager.py | 54 ++++++++++--------- nebula/node.py | 2 +- 3 files changed, 44 insertions(+), 27 deletions(-) diff --git a/nebula/core/neighbormanagement/fastreboot.py b/nebula/core/neighbormanagement/fastreboot.py index 8ec43001a..4912d569d 100644 --- a/nebula/core/neighbormanagement/fastreboot.py +++ b/nebula/core/neighbormanagement/fastreboot.py @@ -33,6 +33,8 @@ def __init__( self._weight_modifier_lock = Locker(name="weight_modifier_lock", async_lock=True) self._rounds_pushed_lock = Locker(name="rounds_pushed_lock", async_lock=True) self._rounds_pushed = 0 + + self._fr_in_progress = False @property def nm(self): @@ -49,6 +51,14 @@ async def get_current_learning_rate(self): await self._learning_rate_lock.release_async() return lr + async def discard_fastreboot_for(self, addr): + await self._weight_modifier_lock.acquire_async() + try: + del self._weight_modifier[addr] + except KeyError as e: + pass + await self._weight_modifier_lock.release_async() + async def _set_learning_rate(self, lr): await self._learning_rate_lock.acquire_async() self._current_lr = lr @@ -57,6 +67,7 @@ async def _set_learning_rate(self, lr): async def add_fastReboot_addr(self, addr): await self._weight_modifier_lock.acquire_async() if not addr in self._weight_modifier: + self._fr_in_progress = True wm = self._weight_mod_value logging.info(f"πŸ“ Registering | FastReboot registered for source {addr} | round application: 
{self._max_rounds} | multiplier value: {wm}") self._weight_modifier[addr] = (wm,1) @@ -77,7 +88,8 @@ async def _weight_modifiers_empty(self): async def apply_weight_strategy(self, updates: dict): if await self._weight_modifiers_empty(): - await self._end_fastreboot() + if self._fr_in_progress: + await self._end_fastreboot() return logging.info(f"πŸ”„ Applying FastReboot Strategy...") # We must lower the weight_modifier value if a round jump has been occured @@ -117,6 +129,7 @@ async def _end_fastreboot(self): await self._weight_modifier_lock.acquire_async() if not self._weight_modifier and await self._is_lr_modified(): logging.info(f"πŸ”„ Finishing | FastReboot is completed") + self._fr_in_progress = False await self._set_learning_rate(self._default_lr) await self.nm.update_learning_rate(self._default_lr) await self._weight_modifier_lock.release_async() diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 716eeee0a..3901ad690 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -23,7 +23,7 @@ def __init__( model_handler, push_acceleration, engine : "Engine", - fastreboot=True, + fastreboot=False, ): self.topology = topology print_msg_box(msg=f"Starting NodeManager module...\nTopology: {self.topology}", indent=2, title="NodeManager module") @@ -36,8 +36,9 @@ def __init__( self._candidate_selector = factory_CandidateSelector(self.topology) logging.info("Initializing Model Handler") self._model_handler = factory_ModelHandler(model_handler) + self._update_neighbors_lock = Locker(name="_update_neighbors_lock") self.late_connection_process_lock = Locker(name="late_connection_process_lock") - self.pending_confirmation_from_nodes = [] + self.pending_confirmation_from_nodes = set() self.pending_confirmation_from_nodes_lock = Locker(name="pending_confirmation_from_nodes_lock") self.accept_candidates_lock = Locker(name="accept_candidates_lock") 
self.recieve_offer_timer = 5 @@ -119,7 +120,8 @@ async def set_confings(self): [ await self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False), await self.engine.cm.get_addrs_current_connections(only_direct=False, only_undirected=False, myself=False), - self.engine.addr + self.engine.addr, + self ] ) logging.info(f"Building candidate selector configuration..") @@ -162,20 +164,18 @@ async def apply_weight_strategy(self, updates: dict): def accept_connection(self, source, joining=False): - if not joining: - if self.get_restructure_process_lock().locked(): - logging.info("NOT accepting connections | Currently upgrading network Robustness") - return False - else: - return self.neighbor_policy.accept_connection(source) - else: - return self.neighbor_policy.accept_connection(source) - - #TODO aΓ±adir un remove + return self.neighbor_policy.accept_connection(source, joining) + def add_pending_connection_confirmation(self, addr): - logging.info(f" Addition | pending connection confirmation from: {addr}") + with self._update_neighbors_lock: + with self.pending_confirmation_from_nodes_lock: + if not addr in self.neighbor_policy.get_nodes_known(neighbors_too=True): + logging.info(f" Addition | pending connection confirmation from: {addr}") + self.pending_confirmation_from_nodes.add(addr) + + def _remove_pending_confirmation_from(self, addr): with self.pending_confirmation_from_nodes_lock: - self.pending_confirmation_from_nodes.append(addr) + self.pending_confirmation_from_nodes.discard(addr) def clear_pending_confirmations(self): with self.pending_confirmation_from_nodes_lock: @@ -191,8 +191,7 @@ async def confirmation_received(self, addr, confirmation=False): await self.engine.cm.connect(addr, direct=True) self.update_neighbors(addr) else: - with self.pending_confirmation_from_nodes_lock: - self.pending_confirmation_from_nodes.remove(addr) + self._remove_pending_confirmation_from(addr) def add_to_discarded_offers(self, addr_discarded): 
self.discarded_offers_addr_lock.acquire() @@ -207,13 +206,16 @@ def get_actions(self): def update_neighbors(self, node, remove=False): logging.info(f"Update neighbor | node addr: {node} | remove: {remove}") + self._update_neighbors_lock.acquire() self.neighbor_policy.update_neighbors(node, remove) #self.timer_generator.update_node(node, remove) if remove: - pass #TODO - #self.remove_weight_modifier(node) - if not remove: + if self._fast_reboot_status: + self.fr.discard_fastreboot_for(node) + else: self.neighbor_policy.meet_node(node) + self._remove_pending_confirmation_from(node) + self._update_neighbors_lock.release() async def neighbors_left(self): return len(await self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False)) > 0 @@ -228,9 +230,9 @@ def get_nodes_known(self, neighbors_too=False): def accept_model_offer(self, source, decoded_model, rounds, round, epochs, n_neighbors, loss): if not self.accept_candidates_lock.locked(): logging.info(f"πŸ”„ Processing offer from {source}...") - model_accepted = True#self.model_handler.accept_model(decoded_model) - if source == "192.168.50.8:45007": - self.model_handler.accept_model(decoded_model) + model_accepted = self.model_handler.accept_model(decoded_model) + #if source == "192.168.50.8:45007": + # self.model_handler.accept_model(decoded_model) self.model_handler.set_config(config=(rounds, round, epochs, self)) if model_accepted: self.candidate_selector.add_candidate((source, n_neighbors, loss)) @@ -244,6 +246,9 @@ async def get_trainning_info(self): def add_candidate(self,source, n_neighbors, loss): if not self.accept_candidates_lock.locked(): self.candidate_selector.add_candidate((source, n_neighbors, loss)) + + async def currently_reestructuring(self): + return self._restructure_process_lock.locked() async def stop_not_selected_connections(self): try: @@ -316,8 +321,6 @@ async def start_late_connection_process(self, connected=False, msg_type="discove for addr, _, _ in best_candidates: await 
self.engine.cm.send_message(addr, msg) self.add_pending_connection_confirmation(addr) - #await self.engine.cm.connect(addr, direct=True) - #self.update_neighbors(addr) await asyncio.sleep(1) except asyncio.CancelledError as e: self.update_neighbors(addr, remove=True) @@ -342,6 +345,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove async def check_robustness(self): + #TODO aΓ±adir un cd para que no se haga continuamente logging.info("πŸ”„ Analizing node network robustness...") if not self._restructure_process_lock.locked(): if not self.neighbors_left(): diff --git a/nebula/node.py b/nebula/node.py index 75ed0d10b..13551ca6f 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -358,7 +358,7 @@ def randomize_value(value, variability): # 420 r15 #if config.participant["network_args"]["ip"] == "192.168.50.9": # logging.info("Sleeping 385s...") - time.sleep(385) + time.sleep(430) #elif config.participant["network_args"]["ip"] == "192.168.50.10": # logging.info("Sleeping 800s...") # time.sleep(615) From 66d62bc253f20d3dff2c7c60f8470446b1d3b5b5 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 1 Jan 2025 13:28:00 +0100 Subject: [PATCH 044/233] fix_slow_push_issue --- nebula/core/aggregation/aggregator.py | 4 +- nebula/core/engine.py | 13 +++-- nebula/core/neighbormanagement/nodemanager.py | 50 +++++++++++-------- nebula/node.py | 2 +- 4 files changed, 38 insertions(+), 31 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index f48c587c0..68ff00c50 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -273,6 +273,7 @@ async def aggregation_push_available(self): If the node is not sinchronized with the federation, it may be possible to make a push and try to catch the federation asap. 
""" + #TODO it would be able to push even if not fullround updates are being received logging.info(f"❗️ synchronized status: {self.engine.get_sinchronized_status()} | Analizing if an aggregation push is available...") if not self.engine.get_sinchronized_status() and not self.engine.get_trainning_in_progress_lock().locked() and not self.engine.get_synchronizing_rounds(): n_fed_nodes = len(self._federation_nodes) @@ -282,7 +283,7 @@ async def aggregation_push_available(self): n_fed_nodes-=1 for f_round, fm in self._future_models_to_aggregate.items(): # future_models dont count self node - if len(fm) == n_fed_nodes: + if len(fm) == n_fed_nodes or (f_round-self.engine.get_round() >= 2): further_round = f_round push = self.engine.get_push_acceleration() if push == "slow": @@ -356,6 +357,7 @@ async def aggregation_push_available(self): return else: + logging.info("Info | No future rounds available, device is up to date...") self.engine.update_sinchronized_status(True) self.engine.set_synchronizing_rounds(False) else: diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 61b63e6a0..e2cf4171b 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -321,10 +321,10 @@ async def _connection_connect_callback(self, source, message): async def _connection_disconnect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received disconnection message from {source}") if self.mobility: - if self.nm.waiting_confirmation_from(source): - self.nm.confirmation_received(source, confirmation=False) + if await self.nm.waiting_confirmation_from(source): + await self.nm.confirmation_received(source, confirmation=False) #if source in await self.cm.get_all_addrs_current_connections(only_direct=True): - self.nm.update_neighbors(source, remove=True) + await self.nm.update_neighbors(source, remove=True) await self.cm.disconnect(source, mutual_disconnection=False) @event_handler( @@ -392,7 +392,7 @@ async def 
_federation_models_included_callback(self, source, message): async def _connection_late_connect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received late connect message from {source}") # Verify if it's a confirmation message from a previous late connection message sent to source - if self.nm.waiting_confirmation_from(source): + if await self.nm.waiting_confirmation_from(source): await self.nm.confirmation_received(source, confirmation=True) elif self.nm.accept_connection(source, joining=True): logging.info(f"πŸ”— handle_connection_message | Late connection accepted | source: {source}") @@ -425,7 +425,7 @@ async def _connection_late_connect_callback(self, source, message): async def _connection_restructure_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received restructure message from {source}") # Verify if it's a confirmation message from a previous restructure connection message sent to source - if self.nm.waiting_confirmation_from(source): + if await self.nm.waiting_confirmation_from(source): await self.nm.confirmation_received(source, confirmation=True) elif self.nm.accept_connection(source, joining=False): logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection accepted from {source}") @@ -543,7 +543,7 @@ async def _link_disconnect_from_callback(self, source, message): addrs = message.addrs for addr in addrs.split(): await self.cm.disconnect(source, mutual_disconnection=False) - self.nm.update_neighbors(addr, remove=True) + await self.nm.update_neighbors(addr, remove=True) async def _aditional_node_start(self): self.update_sinchronized_status(False) @@ -553,7 +553,6 @@ async def _aditional_node_start(self): #asyncio.create_task(self.nm.stop_not_selected_connections()) logging.info("Creating trainer service to start the federation process..") asyncio.create_task(self._start_learning_late()) - #decoded_model = 
self.trainer.deserialize_model(message.parameters) def get_push_acceleration(self): return self.nm.get_push_acceleration() diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 3901ad690..1e4b64d5e 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -36,10 +36,10 @@ def __init__( self._candidate_selector = factory_CandidateSelector(self.topology) logging.info("Initializing Model Handler") self._model_handler = factory_ModelHandler(model_handler) - self._update_neighbors_lock = Locker(name="_update_neighbors_lock") + self._update_neighbors_lock = Locker(name="_update_neighbors_lock", async_lock=True) self.late_connection_process_lock = Locker(name="late_connection_process_lock") self.pending_confirmation_from_nodes = set() - self.pending_confirmation_from_nodes_lock = Locker(name="pending_confirmation_from_nodes_lock") + self.pending_confirmation_from_nodes_lock = Locker(name="pending_confirmation_from_nodes_lock", async_lock=True) self.accept_candidates_lock = Locker(name="accept_candidates_lock") self.recieve_offer_timer = 5 self._restructure_process_lock = Locker(name="restructure_process_lock") @@ -146,7 +146,7 @@ async def update_learning_rate(self, new_lr): async def register_late_neighbor(self, addr, joinning_federation=False): self.meet_node(addr) - self.update_neighbors(addr) + await self.update_neighbors(addr) if joinning_federation: if self.fast_reboot_on(): await self.fr.add_fastReboot_addr(addr) @@ -166,24 +166,30 @@ async def apply_weight_strategy(self, updates: dict): def accept_connection(self, source, joining=False): return self.neighbor_policy.accept_connection(source, joining) - def add_pending_connection_confirmation(self, addr): - with self._update_neighbors_lock: - with self.pending_confirmation_from_nodes_lock: - if not addr in self.neighbor_policy.get_nodes_known(neighbors_too=True): - logging.info(f" Addition | pending 
connection confirmation from: {addr}") - self.pending_confirmation_from_nodes.add(addr) + async def add_pending_connection_confirmation(self, addr): + await self._update_neighbors_lock.acquire_async() + await self.pending_confirmation_from_nodes_lock.acquire_async() + if not addr in self.neighbor_policy.get_nodes_known(neighbors_too=True): + logging.info(f" Addition | pending connection confirmation from: {addr}") + self.pending_confirmation_from_nodes.add(addr) + await self.pending_confirmation_from_nodes_lock.release_async() + await self._update_neighbors_lock.release_async() - def _remove_pending_confirmation_from(self, addr): - with self.pending_confirmation_from_nodes_lock: - self.pending_confirmation_from_nodes.discard(addr) + async def _remove_pending_confirmation_from(self, addr): + await self.pending_confirmation_from_nodes_lock.acquire_async() + self.pending_confirmation_from_nodes.discard(addr) + await self.pending_confirmation_from_nodes_lock.release_async() - def clear_pending_confirmations(self): - with self.pending_confirmation_from_nodes_lock: - self.pending_confirmation_from_nodes.clear() + async def clear_pending_confirmations(self): + await self.pending_confirmation_from_nodes_lock.acquire_async() + self.pending_confirmation_from_nodes.clear() + await self.pending_confirmation_from_nodes_lock.release_async() - def waiting_confirmation_from(self, addr): - with self.pending_confirmation_from_nodes_lock: - return addr in self.pending_confirmation_from_nodes + async def waiting_confirmation_from(self, addr): + await self.pending_confirmation_from_nodes_lock.acquire_async() + found = addr in self.pending_confirmation_from_nodes + await self.pending_confirmation_from_nodes_lock.release_async() + return found async def confirmation_received(self, addr, confirmation=False): logging.info(f" Update | connection confirmation received from: {addr} | confirmation: {confirmation}") @@ -204,9 +210,9 @@ def need_more_neighbors(self): def get_actions(self): 
return self.neighbor_policy.get_actions() - def update_neighbors(self, node, remove=False): + async def update_neighbors(self, node, remove=False): logging.info(f"Update neighbor | node addr: {node} | remove: {remove}") - self._update_neighbors_lock.acquire() + await self._update_neighbors_lock.acquire_async() self.neighbor_policy.update_neighbors(node, remove) #self.timer_generator.update_node(node, remove) if remove: @@ -215,7 +221,7 @@ def update_neighbors(self, node, remove=False): else: self.neighbor_policy.meet_node(node) self._remove_pending_confirmation_from(node) - self._update_neighbors_lock.release() + await self._update_neighbors_lock.release_async() async def neighbors_left(self): return len(await self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False)) > 0 @@ -320,7 +326,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove try: for addr, _, _ in best_candidates: await self.engine.cm.send_message(addr, msg) - self.add_pending_connection_confirmation(addr) + await self.add_pending_connection_confirmation(addr) await asyncio.sleep(1) except asyncio.CancelledError as e: self.update_neighbors(addr, remove=True) diff --git a/nebula/node.py b/nebula/node.py index 13551ca6f..dc5e06b6d 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -358,7 +358,7 @@ def randomize_value(value, variability): # 420 r15 #if config.participant["network_args"]["ip"] == "192.168.50.9": # logging.info("Sleeping 385s...") - time.sleep(430) + time.sleep(420) #elif config.participant["network_args"]["ip"] == "192.168.50.10": # logging.info("Sleeping 800s...") # time.sleep(615) From 3445a76845a94e65042d0eaa4ff9ddd74c5922f0 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 2 Jan 2025 09:50:44 +0100 Subject: [PATCH 045/233] fix:concurrency_issue --- nebula/core/aggregation/aggregator.py | 2 +- nebula/core/engine.py | 29 ++++++++++++------- .../neighborpolicies/fcneighborpolicy.py | 7 ++++- 
nebula/core/neighbormanagement/nodemanager.py | 8 ++--- 4 files changed, 30 insertions(+), 16 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 68ff00c50..51424e871 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -283,7 +283,7 @@ async def aggregation_push_available(self): n_fed_nodes-=1 for f_round, fm in self._future_models_to_aggregate.items(): # future_models dont count self node - if len(fm) == n_fed_nodes or (f_round-self.engine.get_round() >= 2): + if len(fm) == n_fed_nodes: further_round = f_round push = self.engine.get_push_acceleration() if push == "slow": diff --git a/nebula/core/engine.py b/nebula/core/engine.py index e2cf4171b..bbcf15ee1 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -393,8 +393,14 @@ async def _connection_late_connect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received late connect message from {source}") # Verify if it's a confirmation message from a previous late connection message sent to source if await self.nm.waiting_confirmation_from(source): - await self.nm.confirmation_received(source, confirmation=True) - elif self.nm.accept_connection(source, joining=True): + await self.nm.confirmation_received(source, confirmation=True) + return + + if not self.get_initialization_status(): + logging.info(f"❗️ Connection refused | Device not initialized yet...") + return + + if self.nm.accept_connection(source, joining=True): logging.info(f"πŸ”— handle_connection_message | Late connection accepted | source: {source}") await self.cm.connect(source, direct=True) @@ -413,8 +419,7 @@ async def _connection_late_connect_callback(self, source, message): df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, df_actions) await self.cm.send_message(source, df_msg) - await self.nm.register_late_neighbor(source, joinning_federation=True) - + 
await self.nm.register_late_neighbor(source, joinning_federation=True) else: logging.info(f"❗️ Late connection NOT accepted | source: {source}") @@ -427,7 +432,13 @@ async def _connection_restructure_callback(self, source, message): # Verify if it's a confirmation message from a previous restructure connection message sent to source if await self.nm.waiting_confirmation_from(source): await self.nm.confirmation_received(source, confirmation=True) - elif self.nm.accept_connection(source, joining=False): + return + + if not self.get_initialization_status(): + logging.info(f"❗️ Connection refused | Device not initialized yet...") + return + + if self.nm.accept_connection(source, joining=False): logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection accepted from {source}") await self.cm.connect(source, direct=True) @@ -443,8 +454,7 @@ async def _connection_restructure_callback(self, source, message): df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, df_actions) await self.cm.send_message(source, df_msg) - await self.nm.register_late_neighbor(source, joinning_federation=False) - + await self.nm.register_late_neighbor(source, joinning_federation=False) else: logging.info(f"❗️ handle_connection_message | Trigger | restructure connection denied from {source}") await asyncio.sleep(1) @@ -569,9 +579,8 @@ async def apply_weight_strategy(self, pending_models): async def update_model_learning_rate(self, new_lr): await self.trainning_in_progress_lock.acquire_async() - if self.get_round() < self.total_rounds: - logging.info("Update | learning rate modified...") - self.trainer.update_model_learning_rate(new_lr) + logging.info("Update | learning rate modified...") + self.trainer.update_model_learning_rate(new_lr) await self.trainning_in_progress_lock.release_async() async def _start_learning_late(self): diff --git a/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py 
b/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py index 63bf746c5..ca295635e 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py +++ b/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py @@ -1,4 +1,5 @@ from nebula.core.neighbormanagement.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.neighbormanagement.nodemanager import NodeManager from nebula.core.utils.locker import Locker class FCNeighborPolicy(NeighborPolicy): @@ -9,6 +10,7 @@ def __init__(self): self.neighbors = set() self.neighbors_lock = Locker(name="neighbors_lock") self.nodes_known_lock = Locker(name="nodes_known_lock") + self._nm : NodeManager = None def need_more_neighbors(self): """ @@ -25,19 +27,22 @@ def set_config(self, config): Args: config[0] -> list of self neighbors config[1] -> list of nodes known on federation + config[2] -> self addr + config[3] -> NodeManager reference """ self.neighbors_lock.acquire() self.neighbors = config[0] self.neighbors_lock.release() for addr in config[1]: self.nodes_known.add(addr) + self._nm = config[3] def accept_connection(self, source, joining=False): """ return true if connection is accepted """ self.neighbors_lock.acquire() - ac = not source in self.neighbors + ac = (not source in self.neighbors) self.neighbors_lock.release() return ac diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 1e4b64d5e..4b04cf514 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -23,7 +23,7 @@ def __init__( model_handler, push_acceleration, engine : "Engine", - fastreboot=False, + fastreboot=True, ): self.topology = topology print_msg_box(msg=f"Starting NodeManager module...\nTopology: {self.topology}", indent=2, title="NodeManager module") @@ -195,7 +195,7 @@ async def confirmation_received(self, addr, confirmation=False): logging.info(f" Update | connection 
confirmation received from: {addr} | confirmation: {confirmation}") if confirmation: await self.engine.cm.connect(addr, direct=True) - self.update_neighbors(addr) + await self.update_neighbors(addr) else: self._remove_pending_confirmation_from(addr) @@ -300,7 +300,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove self.late_connection_process_lock.acquire() best_candidates = [] self.candidate_selector.remove_candidates() - self.clear_pending_confirmations() + await self.clear_pending_confirmations() # find federation and send discover await self.engine.cm.stablish_connection_to_federation(msg_type, addrs_known) @@ -329,7 +329,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove await self.add_pending_connection_confirmation(addr) await asyncio.sleep(1) except asyncio.CancelledError as e: - self.update_neighbors(addr, remove=True) + await self.update_neighbors(addr, remove=True) pass self.accept_candidates_lock.release() self.late_connection_process_lock.release() From 19e53f94c8607245faa726482baf9ee914f85c23 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 2 Jan 2025 10:31:07 +0100 Subject: [PATCH 046/233] fix_get_neighbors_np --- nebula/core/engine.py | 2 +- .../neighborpolicies/fcneighborpolicy.py | 8 +++++++- .../neighborpolicies/idleneighborpolicy.py | 2 +- .../neighbormanagement/neighborpolicies/neighborpolicy.py | 2 +- .../neighborpolicies/ringneighborpolicy.py | 2 +- .../neighborpolicies/starneighborpolicy.py | 2 +- nebula/core/neighbormanagement/nodemanager.py | 4 ++-- 7 files changed, 14 insertions(+), 8 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index bbcf15ee1..ccc4ab2e9 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -537,7 +537,7 @@ async def _offer_offer_metric_callback(self, source, message): nebula_pb2.LinkMessage.Action.CONNECT_TO, ) async def _link_connect_to_callback(self, source, message): - logging.info(f"πŸ”— 
handle_link_message | Trigger | Received connecto_to message from {source}") + logging.info(f"πŸ”— handle_link_message | Trigger | Received connect_to message from {source}") addrs = message.addrs for addr in addrs.split(): #await self.cm.connect(addr, direct=True) diff --git a/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py b/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py index ca295635e..af613feff 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py +++ b/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py @@ -54,7 +54,13 @@ def meet_node(self, node): self.nodes_known.add(node) self.nodes_known_lock.release() - def get_nodes_known(self, neighbors_too=False): + def get_nodes_known(self, neighbors_too=False, neighbors_only=False): + if neighbors_only: + self.neighbors_lock.acquire() + no = self.neighbors.copy() + self.neighbors_lock.release() + return no + self.nodes_known_lock.acquire() nk = self.nodes_known.copy() if not neighbors_too: diff --git a/nebula/core/neighbormanagement/neighborpolicies/idleneighborpolicy.py b/nebula/core/neighbormanagement/neighborpolicies/idleneighborpolicy.py index 475b368f6..63f4b8285 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/idleneighborpolicy.py +++ b/nebula/core/neighbormanagement/neighborpolicies/idleneighborpolicy.py @@ -23,7 +23,7 @@ def meet_node(self, node): def forget_nodes(self, node, forget_all=False): pass - def get_nodes_known(self, neighbors_too=False): + def get_nodes_known(self, neighbors_too=False, neighbors_only=False): return set() def update_neighbors(self, node, remove=False): diff --git a/nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py b/nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py index 573ec7d01..d97df6dae 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py +++ b/nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py @@ -28,7 +28,7 
@@ def forget_nodes(self, node, forget_all=False): pass @abstractmethod - def get_nodes_known(self, neighbors_too=False): + def get_nodes_known(self, neighbors_too=False, neighbors_only=False): pass @abstractmethod diff --git a/nebula/core/neighbormanagement/neighborpolicies/ringneighborpolicy.py b/nebula/core/neighbormanagement/neighborpolicies/ringneighborpolicy.py index ad717524b..c74d25919 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/ringneighborpolicy.py +++ b/nebula/core/neighbormanagement/neighborpolicies/ringneighborpolicy.py @@ -58,7 +58,7 @@ def forget_nodes(self, node, forget_all=False): self.nodes_known.discard(node) self.nodes_known_lock.release() - def get_nodes_known(self, neighbors_too=False): + def get_nodes_known(self, neighbors_too=False, neighbors_only=False): self.nodes_known_lock.acquire() nk = self.nodes_known.copy() if not neighbors_too: diff --git a/nebula/core/neighbormanagement/neighborpolicies/starneighborpolicy.py b/nebula/core/neighbormanagement/neighborpolicies/starneighborpolicy.py index 40d2c567b..c79e72ab4 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/starneighborpolicy.py +++ b/nebula/core/neighbormanagement/neighborpolicies/starneighborpolicy.py @@ -51,7 +51,7 @@ def forget_nodes(self, node, forget_all=False): self.nodes_known.discard(node) self.nodes_known_lock.release() - def get_nodes_known(self, neighbors_too=False): + def get_nodes_known(self, neighbors_too=False, neighbors_only=False): self.nodes_known_lock.acquire() nk = self.nodes_known.copy() if not neighbors_too: diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 4b04cf514..1d4262d48 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -169,7 +169,7 @@ def accept_connection(self, source, joining=False): async def add_pending_connection_confirmation(self, addr): await self._update_neighbors_lock.acquire_async() await 
self.pending_confirmation_from_nodes_lock.acquire_async() - if not addr in self.neighbor_policy.get_nodes_known(neighbors_too=True): + if not addr in self.neighbor_policy.get_nodes_known(neighbors_only=True): logging.info(f" Addition | pending connection confirmation from: {addr}") self.pending_confirmation_from_nodes.add(addr) await self.pending_confirmation_from_nodes_lock.release_async() @@ -325,8 +325,8 @@ async def start_late_connection_process(self, connected=False, msg_type="discove # candidates not choosen --> disconnect try: for addr, _, _ in best_candidates: - await self.engine.cm.send_message(addr, msg) await self.add_pending_connection_confirmation(addr) + await self.engine.cm.send_message(addr, msg) await asyncio.sleep(1) except asyncio.CancelledError as e: await self.update_neighbors(addr, remove=True) From fbf9eb14f94ba9b5bea560d7d642d1d6013f1c24 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 7 Jan 2025 16:35:48 +0100 Subject: [PATCH 047/233] daily_update --- nebula/core/engine.py | 3 ++- .../candidateselection/fccandidateselector.py | 7 +++++++ nebula/core/neighbormanagement/fastreboot.py | 2 +- .../neighborpolicies/fcneighborpolicy.py | 3 --- nebula/core/neighbormanagement/nodemanager.py | 10 +++++----- nebula/node.py | 18 +++++++++++------- 6 files changed, 26 insertions(+), 17 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index ccc4ab2e9..6913a3700 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -463,7 +463,7 @@ async def _connection_restructure_callback(self, source, message): @event_handler(nebula_pb2.DiscoverMessage, nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN) async def _discover_discover_join_callback(self, source, message): logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_join message from {source} ") - #self.nm.meet_node(source) + #TODO caso para el starter recibir antes de iniciar federacion if len(self.get_federation_nodes()) > 0: await 
self.trainning_in_progress_lock.acquire_async() model, rounds, round = await self.cm.propagator.get_model_information(source, "stable") if self.get_round() > 0 else await self.cm.propagator.get_model_information(source, "initialization") @@ -672,6 +672,7 @@ async def deploy_federation(self): await self.cm.send_message_to_neighbors(message) await self.get_federation_ready_lock().release_async() await self.create_trainer_module() + self.set_initialization_status(True) else: logging.info("Federation already started") diff --git a/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py b/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py index 0cfeded70..a2b6937bc 100644 --- a/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py +++ b/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py @@ -19,8 +19,15 @@ def select_candidates(self): """ In Fully-Connected topology all candidates should be selected """ + #0145 + listed = ["192.168.50.2:45001", "192.168.50.3:45002", "192.168.50.6:45005", "192.168.50.7:45006"] + defined = [] self.candidates_lock.acquire() cdts = self.candidates.copy() + for (addr,a,b) in cdts: + if addr in listed: + defined.append((addr,a,b)) + cdts = defined self.candidates_lock.release() return cdts diff --git a/nebula/core/neighbormanagement/fastreboot.py b/nebula/core/neighbormanagement/fastreboot.py index 4912d569d..af830906e 100644 --- a/nebula/core/neighbormanagement/fastreboot.py +++ b/nebula/core/neighbormanagement/fastreboot.py @@ -7,7 +7,7 @@ from nebula.core.neighbormanagement.nodemanager import NodeManager VANILLA_LEARNING_RATE = 1e-3 -FR_LEARNING_RATE = 1e-3#2e-3 +FR_LEARNING_RATE = 2e-3 MAX_ROUNDS = 20 DEFAULT_WEIGHT_MODIFIER = 3 diff --git a/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py b/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py index af613feff..9e4c1a77e 100644 --- 
a/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py +++ b/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py @@ -1,5 +1,4 @@ from nebula.core.neighbormanagement.neighborpolicies.neighborpolicy import NeighborPolicy -from nebula.core.neighbormanagement.nodemanager import NodeManager from nebula.core.utils.locker import Locker class FCNeighborPolicy(NeighborPolicy): @@ -10,7 +9,6 @@ def __init__(self): self.neighbors = set() self.neighbors_lock = Locker(name="neighbors_lock") self.nodes_known_lock = Locker(name="nodes_known_lock") - self._nm : NodeManager = None def need_more_neighbors(self): """ @@ -35,7 +33,6 @@ def set_config(self, config): self.neighbors_lock.release() for addr in config[1]: self.nodes_known.add(addr) - self._nm = config[3] def accept_connection(self, source, joining=False): """ diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 1d4262d48..85faf14b8 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -25,7 +25,7 @@ def __init__( engine : "Engine", fastreboot=True, ): - self.topology = topology + self.topology = "fully"#topology print_msg_box(msg=f"Starting NodeManager module...\nTopology: {self.topology}", indent=2, title="NodeManager module") logging.info("🌐 Initializing Node Manager") self._engine = engine @@ -236,9 +236,9 @@ def get_nodes_known(self, neighbors_too=False): def accept_model_offer(self, source, decoded_model, rounds, round, epochs, n_neighbors, loss): if not self.accept_candidates_lock.locked(): logging.info(f"πŸ”„ Processing offer from {source}...") - model_accepted = self.model_handler.accept_model(decoded_model) - #if source == "192.168.50.8:45007": - # self.model_handler.accept_model(decoded_model) + model_accepted = True#self.model_handler.accept_model(decoded_model) + if source == "192.168.50.8:45007": + self.model_handler.accept_model(decoded_model) 
self.model_handler.set_config(config=(rounds, round, epochs, self)) if model_accepted: self.candidate_selector.add_candidate((source, n_neighbors, loss)) @@ -359,7 +359,7 @@ async def check_robustness(self): #await self.reconnect_to_federation() elif self.neighbor_policy.need_more_neighbors() and self.engine.get_sinchronized_status(): logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") - asyncio.create_task(self.upgrade_connection_robustness()) + #asyncio.create_task(self.upgrade_connection_robustness()) else: if not self.engine.get_sinchronized_status(): logging.info("Device not synchronized with federation") diff --git a/nebula/node.py b/nebula/node.py index dc5e06b6d..488a95511 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -350,18 +350,22 @@ def randomize_value(value, variability): if additional_node_status: logging.info(f"Waiting for round {additional_node_round} to start") logging.info("Waiting time to start finding federation") + # MNIST # 385 r30 # 615 r50 - # CIFAR # 420 r15 - #if config.participant["network_args"]["ip"] == "192.168.50.9": - # logging.info("Sleeping 385s...") - time.sleep(420) - #elif config.participant["network_args"]["ip"] == "192.168.50.10": - # logging.info("Sleeping 800s...") - # time.sleep(615) + # 600 r22 + + #if config.participant["network_args"]["ip"] == "192.168.50.11": + #time.sleep(820) + + if config.participant["network_args"]["ip"] == "192.168.50.11": + time.sleep(820) + elif config.participant["network_args"]["ip"] == "192.168.50.12": + time.sleep(420) + #time.sleep(6000) # DEBUG purposes #import requests From 822a4700eab32e4610c237ad3ab44c9300ea7e22 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 9 Jan 2025 10:43:54 +0100 Subject: [PATCH 048/233] change_scenario_config --- nebula/core/engine.py | 2 +- .../candidateselection/fccandidateselector.py | 12 ++++++------ nebula/core/neighbormanagement/nodemanager.py | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) 
diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 6913a3700..a1cdb72c4 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -158,7 +158,7 @@ def __init__( if self.mobility == True: topology = self.config.participant["mobility_args"]["topology_type"] topology = topology.lower() - model_handler = "std" #self.config.participant["mobility_args"]["model_handler"] + model_handler = "default" #self.config.participant["mobility_args"]["model_handler"] acceleration_push = "slow" #self.config.participant["mobility_args"]["push_strategy"] self._node_manager = NodeManager(topology, model_handler, acceleration_push, engine=self) diff --git a/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py b/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py index a2b6937bc..46a7076ac 100644 --- a/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py +++ b/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py @@ -20,14 +20,14 @@ def select_candidates(self): In Fully-Connected topology all candidates should be selected """ #0145 - listed = ["192.168.50.2:45001", "192.168.50.3:45002", "192.168.50.6:45005", "192.168.50.7:45006"] - defined = [] + #listed = ["192.168.50.2:45001", "192.168.50.3:45002", "192.168.50.6:45005", "192.168.50.7:45006"] + #defined = [] self.candidates_lock.acquire() cdts = self.candidates.copy() - for (addr,a,b) in cdts: - if addr in listed: - defined.append((addr,a,b)) - cdts = defined + #for (addr,a,b) in cdts: + # if addr in listed: + # defined.append((addr,a,b)) + #cdts = defined self.candidates_lock.release() return cdts diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 85faf14b8..67c900e40 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -23,7 +23,7 @@ def __init__( model_handler, push_acceleration, engine : "Engine", - 
fastreboot=True, + fastreboot=False, ): self.topology = "fully"#topology print_msg_box(msg=f"Starting NodeManager module...\nTopology: {self.topology}", indent=2, title="NodeManager module") @@ -359,7 +359,7 @@ async def check_robustness(self): #await self.reconnect_to_federation() elif self.neighbor_policy.need_more_neighbors() and self.engine.get_sinchronized_status(): logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") - #asyncio.create_task(self.upgrade_connection_robustness()) + asyncio.create_task(self.upgrade_connection_robustness()) else: if not self.engine.get_sinchronized_status(): logging.info("Device not synchronized with federation") From 9b507401dab92539c61aede1a2c5cb07f6138aed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Avil=C3=A9s=20Serrano?= <80918548+AlejandroAvilesSerrano@users.noreply.github.com> Date: Fri, 10 Jan 2025 15:36:14 +0100 Subject: [PATCH 049/233] fix_momentum --- nebula/core/neighbormanagement/momentum.py | 43 +++++++++++++++++++ nebula/core/neighbormanagement/nodemanager.py | 1 + nebula/core/network/communications.py | 12 +++--- .../frontend/config/participant.json.example | 2 +- 4 files changed, 51 insertions(+), 7 deletions(-) create mode 100644 nebula/core/neighbormanagement/momentum.py diff --git a/nebula/core/neighbormanagement/momentum.py b/nebula/core/neighbormanagement/momentum.py new file mode 100644 index 000000000..b6bcb633f --- /dev/null +++ b/nebula/core/neighbormanagement/momentum.py @@ -0,0 +1,43 @@ +import asyncio +import logging +from collections import deque +import os + +from nebula.core.utils.locker import Locker +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from nebula.core.neighbormanagement.nodemanager import NodeManager + +MAX_HISTORIC_SIZE = 10 + +class Momentum(): + + def __init__( + self, + node_manager : "NodeManager", + nodes, + ): + self._node_manager = node_manager + self._similarities_historic = {node_id: deque(maxlen=MAX_HISTORIC_SIZE) for 
node_id in nodes} + self._similarities_historic_lock = Locker(name="__similarities_historic_lock", async_lock=True) + + @property + def nm(self): + return self._node_manager + + async def add_similarity_to_node(self, node_id, sim_value): + logging.info(f"Adding | node ID: {node_id}, cossine similarity value: {sim_value}") + self._similarities_historic_lock.acquire_async() + self._similarities_historic[node_id].append(sim_value) + self._similarities_historic_lock.release_async() + + async def update_node(self, node_id, remove=False): + self._similarities_historic_lock.acquire_async() + if remove: + self._similarities_historic.pop(node_id, None) + else: + self._similarities_historic.update({node_id: deque(maxlen=MAX_HISTORIC_SIZE)}) + self._similarities_historic_lock.release_async() + + async def apply_similarity_weights(self, updates: dict): + pass \ No newline at end of file diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 67c900e40..ff5eb83fd 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -24,6 +24,7 @@ def __init__( push_acceleration, engine : "Engine", fastreboot=False, + momentum=False, ): self.topology = "fully"#topology print_msg_box(msg=f"Starting NodeManager module...\nTopology: {self.topology}", indent=2, title="NodeManager module") diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 016c81707..c15a2df04 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -256,33 +256,33 @@ async def handle_model_message(self, source, message): if self.config.participant["adaptive_args"]["model_similarity"]: logging.info("πŸ€– handle_model_message | Checking model similarity") cosine_value = cosine_metric( - self.trainer.get_model_parameters(), + self.engine.trainer.get_model_parameters(), decoded_model, similarity=True, ) euclidean_value = 
euclidean_metric( - self.trainer.get_model_parameters(), + self.engine.trainer.get_model_parameters(), decoded_model, similarity=True, ) minkowski_value = minkowski_metric( - self.trainer.get_model_parameters(), + self.engine.trainer.get_model_parameters(), decoded_model, p=2, similarity=True, ) manhattan_value = manhattan_metric( - self.trainer.get_model_parameters(), + self.engine.trainer.get_model_parameters(), decoded_model, similarity=True, ) pearson_correlation_value = pearson_correlation_metric( - self.trainer.get_model_parameters(), + self.engine.trainer.get_model_parameters(), decoded_model, similarity=True, ) jaccard_value = jaccard_metric( - self.trainer.get_model_parameters(), + self.engine.trainer.get_model_parameters(), decoded_model, similarity=True, ) diff --git a/nebula/frontend/config/participant.json.example b/nebula/frontend/config/participant.json.example index f53479a73..b1156f5d6 100755 --- a/nebula/frontend/config/participant.json.example +++ b/nebula/frontend/config/participant.json.example @@ -51,7 +51,7 @@ "reordering": "0%" }, "adaptive_args": { - "model_similarity": false + "model_similarity": true }, "mobility_args": { "latitude": "", From 139dae7910f7b504b5f2612ed7e4e4dfc585d070 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Avil=C3=A9s=20Serrano?= <80918548+AlejandroAvilesSerrano@users.noreply.github.com> Date: Fri, 10 Jan 2025 17:49:04 +0100 Subject: [PATCH 050/233] fix_com_error --- nebula/core/network/communications.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index c15a2df04..bdcb40082 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -286,17 +286,19 @@ async def handle_model_message(self, source, message): decoded_model, similarity=True, ) - with open( - f"{self.log_dir}/participant_{self.idx}_similarity.csv", - "a+", - ) as f: - if 
os.stat(f"{self.log_dir}/participant_{self.idx}_similarity.csv").st_size == 0: - f.write( - "timestamp,source_ip,nodes,round,current_round,cosine,euclidean,minkowski,manhattan,pearson_correlation,jaccard\n" - ) - f.write( - f"{datetime.now()}, {source}, {message.round}, {current_round}, {cosine_value}, {euclidean_value}, {minkowski_value}, {manhattan_value}, {pearson_correlation_value}, {jaccard_value}\n" - ) + #with open( + # f"{self.config.participant["tracking_args"]["log_dir"]}/participant_{self.id}_similarity.csv", + # "a+", + #) as f: + # if os.stat(f"{self}/participant_{self.id}_similarity.csv").st_size == 0: + # f.write( + # "timestamp,source_ip,nodes,round,current_round,cosine,euclidean,minkowski,manhattan,pearson_correlation,jaccard\n" + # ) + # f.write( + # f"{datetime.now()}, {source}, {message.round}, {current_round}, {cosine_value}, {euclidean_value}, {minkowski_value}, {manhattan_value}, {pearson_correlation_value}, {jaccard_value}\n" + # ) + logging("Similarities between self model and model recieved...") + logging.info(f"{source}, {message.round}, {current_round}, {cosine_value}, {euclidean_value}, {minkowski_value}, {manhattan_value}, {pearson_correlation_value}, {jaccard_value}") await self.engine.aggregator.include_model_in_buffer( decoded_model, From 1177a4fbf6bcd55a5889d179f6f78f7bf228f4a5 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sat, 11 Jan 2025 23:31:54 +0100 Subject: [PATCH 051/233] updt_momentum --- nebula/core/engine.py | 2 +- nebula/core/neighbormanagement/momentum.py | 78 ++++++++++++++++--- nebula/core/neighbormanagement/nodemanager.py | 6 ++ nebula/node.py | 2 +- 4 files changed, 77 insertions(+), 11 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index a1cdb72c4..6913a3700 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -158,7 +158,7 @@ def __init__( if self.mobility == True: topology = self.config.participant["mobility_args"]["topology_type"] topology = topology.lower() - 
model_handler = "default" #self.config.participant["mobility_args"]["model_handler"] + model_handler = "std" #self.config.participant["mobility_args"]["model_handler"] acceleration_push = "slow" #self.config.participant["mobility_args"]["push_strategy"] self._node_manager = NodeManager(topology, model_handler, acceleration_push, engine=self) diff --git a/nebula/core/neighbormanagement/momentum.py b/nebula/core/neighbormanagement/momentum.py index b6bcb633f..876505b11 100644 --- a/nebula/core/neighbormanagement/momentum.py +++ b/nebula/core/neighbormanagement/momentum.py @@ -1,35 +1,66 @@ import asyncio import logging from collections import deque -import os - +from nebula.core.utils.helper import cosine_metric from nebula.core.utils.locker import Locker -from typing import TYPE_CHECKING +import numpy as np + +from typing import TYPE_CHECKING, Callable, OrderedDict, Optional if TYPE_CHECKING: from nebula.core.neighbormanagement.nodemanager import NodeManager -MAX_HISTORIC_SIZE = 10 +SimilarityMetricType = Callable[[OrderedDict, OrderedDict, bool], Optional[float]] + +MAX_HISTORIC_SIZE = 10 # Number of historic data storaged +GLOBAL_PRIORITY = 0.5 # Parameter to priorize global vs local metrics +K = 1 # Sigmoid smoother factor class Momentum(): def __init__( self, - node_manager : "NodeManager", + node_manager: "NodeManager", nodes, + variance=True, + similarity_metric : SimilarityMetricType = cosine_metric, + global_priority=GLOBAL_PRIORITY, ): self._node_manager = node_manager self._similarities_historic = {node_id: deque(maxlen=MAX_HISTORIC_SIZE) for node_id in nodes} - self._similarities_historic_lock = Locker(name="__similarities_historic_lock", async_lock=True) + self._similarities_historic_lock = Locker(name="_similarities_historic_lock", async_lock=True) + self._model_similarity_metric_lock = Locker(name="_model_similarity_metric_lock", async_lock=True) + self._model_similarity_metric = similarity_metric + self._global_prio = global_priority + 
self._variance_status = variance @property def nm(self): return self._node_manager + + @property + def msm(self): + return self._model_similarity_metric - async def add_similarity_to_node(self, node_id, sim_value): + async def _add_similarity_to_node(self, node_id, sim_value): logging.info(f"Adding | node ID: {node_id}, cossine similarity value: {sim_value}") self._similarities_historic_lock.acquire_async() self._similarities_historic[node_id].append(sim_value) self._similarities_historic_lock.release_async() + + async def _get_similarity_historic(self, addrs): + """ + Get historic storaged for node IDs on 'addrs' + + Args: + addrs (List)): List of node IDs that has sent update this round + """ + self._similarities_historic_lock.acquire_async() + historic = {} + for key, value in self._similarities_historic.items(): + if key in addrs: + historic[key] = value + self._similarities_historic_lock.release_async() + return historic async def update_node(self, node_id, remove=False): self._similarities_historic_lock.acquire_async() @@ -38,6 +69,35 @@ async def update_node(self, node_id, remove=False): else: self._similarities_historic.update({node_id: deque(maxlen=MAX_HISTORIC_SIZE)}) self._similarities_historic_lock.release_async() + + async def change_similarity_metric(self, new_metric: SimilarityMetricType): + self._model_similarity_metric_lock.acquire_async() + self.msm = new_metric + # maybe we should remove historic data due to incongruous data + self._model_similarity_metric_lock.release_async() + + async def _calculate_similarities(self, updates: dict): + """ + Function to calculate similarity between local model and models received + using metric function. 
The value is storaged on the historic + + Args: + updates (dict): {node ID: model} + """ + logging.info(f"Calculating | Model Similarity values are being calculated...") + model = self.nm.engine.trainer.get_model_parameters() + for addr,update in updates.items(): + cosine_value = self._model_similarity_metric( + model, + update, + similarity=True, + ) + await self._add_similarity_to_node(addr, cosine_value) - async def apply_similarity_weights(self, updates: dict): - pass \ No newline at end of file + async def calculate_similarity_weights(self, updates: dict): + logging.info("Calculating | Momemtum weights are being calculated...") + self._model_similarity_metric_lock.acquire_async() + await self._calculate_similarities(updates) + historic = await self._get_similarity_historic(updates.keys()) + similarities = [node_sim[-1] for node_sim in historic.values() if node_sim] + self._model_similarity_metric_lock.release_async() \ No newline at end of file diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index ff5eb83fd..b759d83fa 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -2,6 +2,8 @@ import logging import os +import importlib + from nebula.core.utils.locker import Locker from nebula.core.neighbormanagement.candidateselection.candidateselector import factory_CandidateSelector from nebula.core.neighbormanagement.modelhandlers.modelhandler import factory_ModelHandler @@ -392,5 +394,9 @@ async def upgrade_connection_robustness(self): logging.info("Reestructuring | NO Addrs availables") await self.start_late_connection_process(connected=True, msg_type="discover_nodes") self._restructure_process_lock.release() + + + + \ No newline at end of file diff --git a/nebula/node.py b/nebula/node.py index 488a95511..f5fbc2cb8 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -361,7 +361,7 @@ def randomize_value(value, variability): #if 
config.participant["network_args"]["ip"] == "192.168.50.11": #time.sleep(820) - if config.participant["network_args"]["ip"] == "192.168.50.11": + if config.participant["network_args"]["ip"] == "192.168.50.9": time.sleep(820) elif config.participant["network_args"]["ip"] == "192.168.50.12": time.sleep(420) From e4e5287e4d6a2ac009fc97cd7929f06638767368 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sun, 12 Jan 2025 11:22:18 +0100 Subject: [PATCH 052/233] feat_momentum momemtum logic implemented --- nebula/core/neighbormanagement/momentum.py | 32 +++++++++++++++++----- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/nebula/core/neighbormanagement/momentum.py b/nebula/core/neighbormanagement/momentum.py index 876505b11..8a489861c 100644 --- a/nebula/core/neighbormanagement/momentum.py +++ b/nebula/core/neighbormanagement/momentum.py @@ -14,6 +14,7 @@ MAX_HISTORIC_SIZE = 10 # Number of historic data storaged GLOBAL_PRIORITY = 0.5 # Parameter to priorize global vs local metrics K = 1 # Sigmoid smoother factor +EPSILON = 0.001 class Momentum(): @@ -21,9 +22,9 @@ def __init__( self, node_manager: "NodeManager", nodes, + global_priority=GLOBAL_PRIORITY, variance=True, similarity_metric : SimilarityMetricType = cosine_metric, - global_priority=GLOBAL_PRIORITY, ): self._node_manager = node_manager self._similarities_historic = {node_id: deque(maxlen=MAX_HISTORIC_SIZE) for node_id in nodes} @@ -92,12 +93,29 @@ async def _calculate_similarities(self, updates: dict): update, similarity=True, ) - await self._add_similarity_to_node(addr, cosine_value) - - async def calculate_similarity_weights(self, updates: dict): + await self._add_similarity_to_node(addr, cosine_value) + + async def calculate_momentum_weights(self, updates: dict): logging.info("Calculating | Momemtum weights are being calculated...") self._model_similarity_metric_lock.acquire_async() - await self._calculate_similarities(updates) - historic = await self._get_similarity_historic(updates.keys()) 
- similarities = [node_sim[-1] for node_sim in historic.values() if node_sim] + await self._calculate_similarities(updates) # Calculate similarity value between self model and updates received + historic = await self._get_similarity_historic(updates.keys()) # Get historic similarities values from nodes that has sent update this round + + round_similarities = [n_hist[-1] for n_hist in historic.values() if n_hist] + variance = np.var(round_similarities) if round_similarities else 0 + + # scaled_sigmoid = a + (bβˆ’a)β‹…sigmoid if u desire to get min_values < 0.5, define a = min_value + def sigmoid(similarity): + sigmoid = 1 / (1 + np.exp(-K * (similarity - 0.5))) + return sigmoid + + for node_id, n_hist in historic.items(): + if not n_hist: + continue + sim_value = n_hist[-1] # Get last similarity value + mapped_sim_value = EPSILON + ((sim_value + 1) / 2) # Mapped [-1, 1] -> [0, 1] + smoothed_value = sigmoid(mapped_sim_value) + adjusted_weight = smoothed_value * self._global_prio + (1 - self._global_prio) * mapped_sim_value + + self._model_similarity_metric_lock.release_async() \ No newline at end of file From 282ab2f3f0fbe6d00f328a5205d21516eaf3ecce Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sun, 12 Jan 2025 19:28:53 +0100 Subject: [PATCH 053/233] feat_momemtum_penalty --- nebula/core/neighbormanagement/momentum.py | 47 +++++++++++++------ nebula/core/neighbormanagement/nodemanager.py | 3 ++ 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/nebula/core/neighbormanagement/momentum.py b/nebula/core/neighbormanagement/momentum.py index 8a489861c..ad03bd967 100644 --- a/nebula/core/neighbormanagement/momentum.py +++ b/nebula/core/neighbormanagement/momentum.py @@ -13,7 +13,6 @@ MAX_HISTORIC_SIZE = 10 # Number of historic data storaged GLOBAL_PRIORITY = 0.5 # Parameter to priorize global vs local metrics -K = 1 # Sigmoid smoother factor EPSILON = 0.001 class Momentum(): @@ -23,7 +22,7 @@ def __init__( node_manager: "NodeManager", nodes, 
global_priority=GLOBAL_PRIORITY, - variance=True, + dispersion_penalty=True, similarity_metric : SimilarityMetricType = cosine_metric, ): self._node_manager = node_manager @@ -32,7 +31,7 @@ def __init__( self._model_similarity_metric_lock = Locker(name="_model_similarity_metric_lock", async_lock=True) self._model_similarity_metric = similarity_metric self._global_prio = global_priority - self._variance_status = variance + self._dispersion_penalty = dispersion_penalty @property def nm(self): @@ -94,28 +93,46 @@ async def _calculate_similarities(self, updates: dict): similarity=True, ) await self._add_similarity_to_node(addr, cosine_value) - + + def _calculate_dispersion_penalty(self, historic: dict, updates: dict): + logging.info("Calculating | Dispersion penalty") + round_similarities = [(addr, n_hist[-1]) for addr,n_hist in historic.items() if n_hist] + if round_similarities: + mean_similarity = np.mean(round_similarities) + std_similarity = np.std(round_similarities) + EPSILON + logging.info(f"Calculating | mean similarity: {mean_similarity}, standar similarity: {std_similarity}") + for addr,sim in round_similarities: + penalty = abs(sim - mean_similarity) / (std_similarity + EPSILON) # To avoid div by 0 + penalty = min(1.0, max(0.0, penalty)) + dispersion_penalty = 1 - penalty + async def calculate_momentum_weights(self, updates: dict): + if not updates: + return logging.info("Calculating | Momemtum weights are being calculated...") - self._model_similarity_metric_lock.acquire_async() + self._model_similarity_metric_lock.acquire_async() await self._calculate_similarities(updates) # Calculate similarity value between self model and updates received historic = await self._get_similarity_historic(updates.keys()) # Get historic similarities values from nodes that has sent update this round - - round_similarities = [n_hist[-1] for n_hist in historic.values() if n_hist] - variance = np.var(round_similarities) if round_similarities else 0 - - # scaled_sigmoid = a + 
(bβˆ’a)β‹…sigmoid if u desire to get min_values < 0.5, define a = min_value - def sigmoid(similarity): - sigmoid = 1 / (1 + np.exp(-K * (similarity - 0.5))) + + def sigmoid(similarity, k=2.5): + if similarity >= 0.92: + sigmoid = 1 + else: + sigmoid = 1 / (1 + np.exp(-k * (similarity))) return sigmoid - for node_id, n_hist in historic.items(): + def map_value(sim_value, e=EPSILON): + return e + ((sim_value + 1) / 2) + + for node_addr, n_hist in historic.items(): if not n_hist: continue - sim_value = n_hist[-1] # Get last similarity value - mapped_sim_value = EPSILON + ((sim_value + 1) / 2) # Mapped [-1, 1] -> [0, 1] + sim_value = n_hist[-1] # Get last similarity value + mapped_sim_value = map_value(sim_value) # Mapped [-1, 1] -> [0, 1] smoothed_value = sigmoid(mapped_sim_value) adjusted_weight = smoothed_value * self._global_prio + (1 - self._global_prio) * mapped_sim_value + if self._dispersion_penalty: + self._calculate_dispersion_penalty(historic, updates) self._model_similarity_metric_lock.release_async() \ No newline at end of file diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index b759d83fa..ce58c20d3 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -56,6 +56,9 @@ def __init__( self._fast_reboot_status = fastreboot if (fastreboot): self._fastreboot = FastReboot(self) + + if (momentum): + pass #self.set_confings() From 7f5e23010114a0d02da786b7d3835a8a5cf1d66f Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 13 Jan 2025 15:30:40 +0100 Subject: [PATCH 054/233] updt_momemtum_penalty_ext --- nebula/core/neighbormanagement/momentum.py | 55 +++++++++++++++------- 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/nebula/core/neighbormanagement/momentum.py b/nebula/core/neighbormanagement/momentum.py index ad03bd967..a4183341a 100644 --- a/nebula/core/neighbormanagement/momentum.py +++ 
b/nebula/core/neighbormanagement/momentum.py @@ -6,14 +6,18 @@ import numpy as np from typing import TYPE_CHECKING, Callable, OrderedDict, Optional +from typing_extensions import Annotated if TYPE_CHECKING: from nebula.core.neighbormanagement.nodemanager import NodeManager SimilarityMetricType = Callable[[OrderedDict, OrderedDict, bool], Optional[float]] +MappingSimilarityType = Callable[[float, float], Annotated[float, "Value in (0, 1]"]] MAX_HISTORIC_SIZE = 10 # Number of historic data storaged GLOBAL_PRIORITY = 0.5 # Parameter to priorize global vs local metrics EPSILON = 0.001 +TOLERANCE_THRESHOLD = 2 # Threshold to start appliying full penalty +SMOOTH_FACTOR = 0.5 class Momentum(): @@ -24,14 +28,17 @@ def __init__( global_priority=GLOBAL_PRIORITY, dispersion_penalty=True, similarity_metric : SimilarityMetricType = cosine_metric, + mapping_similarity : MappingSimilarityType = lambda sim_value, e=EPSILON: e + ((sim_value + 1) / 2), ): self._node_manager = node_manager self._similarities_historic = {node_id: deque(maxlen=MAX_HISTORIC_SIZE) for node_id in nodes} self._similarities_historic_lock = Locker(name="_similarities_historic_lock", async_lock=True) self._model_similarity_metric_lock = Locker(name="_model_similarity_metric_lock", async_lock=True) self._model_similarity_metric = similarity_metric + self._mapping_similarity_func = mapping_similarity self._global_prio = global_priority self._dispersion_penalty = dispersion_penalty + self._addr = self._node_manager.engine.addr @property def nm(self): @@ -40,6 +47,10 @@ def nm(self): @property def msm(self): return self._model_similarity_metric + + @property + def msf(self): + return self._mapping_similarity_func async def _add_similarity_to_node(self, node_id, sim_value): logging.info(f"Adding | node ID: {node_id}, cossine similarity value: {sim_value}") @@ -70,9 +81,10 @@ async def update_node(self, node_id, remove=False): self._similarities_historic.update({node_id: deque(maxlen=MAX_HISTORIC_SIZE)}) 
self._similarities_historic_lock.release_async() - async def change_similarity_metric(self, new_metric: SimilarityMetricType): + async def change_similarity_metric(self, new_metric: SimilarityMetricType, new_mapping: MappingSimilarityType): self._model_similarity_metric_lock.acquire_async() self.msm = new_metric + self.msf = new_mapping # maybe we should remove historic data due to incongruous data self._model_similarity_metric_lock.release_async() @@ -84,51 +96,62 @@ async def _calculate_similarities(self, updates: dict): Args: updates (dict): {node ID: model} """ - logging.info(f"Calculating | Model Similarity values are being calculated...") + logging.info(f"Calculate | Model Similarity values are being calculated...") model = self.nm.engine.trainer.get_model_parameters() for addr,update in updates.items(): - cosine_value = self._model_similarity_metric( + if addr == self._addr: + continue + sim_value = self.msm( model, update, similarity=True, ) - await self._add_similarity_to_node(addr, cosine_value) + await self._add_similarity_to_node(addr, sim_value) def _calculate_dispersion_penalty(self, historic: dict, updates: dict): - logging.info("Calculating | Dispersion penalty") + from math import sqrt + logging.info("Calculate | Dispersion penalty") round_similarities = [(addr, n_hist[-1]) for addr,n_hist in historic.items() if n_hist] if round_similarities: mean_similarity = np.mean(round_similarities) - std_similarity = np.std(round_similarities) + EPSILON - logging.info(f"Calculating | mean similarity: {mean_similarity}, standar similarity: {std_similarity}") + std_similarity = np.std(round_similarities) + n_updates = len(updates) - 1 + logging.info(f"Calculate | mean similarity: {mean_similarity}, standar deviation: {std_similarity}") for addr,sim in round_similarities: - penalty = abs(sim - mean_similarity) / (std_similarity + EPSILON) # To avoid div by 0 + if abs(sim - mean_similarity) < TOLERANCE_THRESHOLD * std_similarity: + logging.info(f"Penalty | 
Dispersion is lower than threshold, for node: {addr}") + penalty = (SMOOTH_FACTOR * (abs(sim - mean_similarity) / (std_similarity + EPSILON))) * (1/sqrt(n_updates)) + else: + logging.info(f"Penalty | Dispersion is higher than threshold, for node: {addr}") + penalty = (abs(sim - mean_similarity) / (std_similarity + EPSILON)) * (1/sqrt(n_updates)) + penalty = min(1.0, max(0.0, penalty)) + logging.info(f"Penalty value: {penalty}") dispersion_penalty = 1 - penalty + + def map_value(sim_value, e=EPSILON): + return e + ((sim_value + 1) / 2) async def calculate_momentum_weights(self, updates: dict): if not updates: return - logging.info("Calculating | Momemtum weights are being calculated...") + logging.info("Calculate | Momemtum weights are being calculated...") self._model_similarity_metric_lock.acquire_async() await self._calculate_similarities(updates) # Calculate similarity value between self model and updates received historic = await self._get_similarity_historic(updates.keys()) # Get historic similarities values from nodes that has sent update this round def sigmoid(similarity, k=2.5): - if similarity >= 0.92: + if similarity >= 0.92: # threshold to consider better updates sigmoid = 1 else: sigmoid = 1 / (1 + np.exp(-k * (similarity))) return sigmoid - - def map_value(sim_value, e=EPSILON): - return e + ((sim_value + 1) / 2) - + for node_addr, n_hist in historic.items(): - if not n_hist: + if not n_hist or node_addr == self._addr: continue sim_value = n_hist[-1] # Get last similarity value - mapped_sim_value = map_value(sim_value) # Mapped [-1, 1] -> [0, 1] + mapped_sim_value = self.msf(sim_value) # Mapped into [0, 1] interval smoothed_value = sigmoid(mapped_sim_value) adjusted_weight = smoothed_value * self._global_prio + (1 - self._global_prio) * mapped_sim_value From e57767c40e5955b72442cca86f83d4d2983add49 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 23 Jan 2025 23:28:40 +0100 Subject: [PATCH 055/233] update momemtum --- 
.../candidateselection/fccandidateselector.py | 4 ++-- nebula/core/neighbormanagement/fastreboot.py | 2 +- nebula/core/neighbormanagement/momentum.py | 3 ++- nebula/core/neighbormanagement/nodemanager.py | 8 ++++---- nebula/core/network/communications.py | 2 +- nebula/node.py | 12 ++++++++---- 6 files changed, 18 insertions(+), 13 deletions(-) diff --git a/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py b/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py index 46a7076ac..e097e5be9 100644 --- a/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py +++ b/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py @@ -20,9 +20,9 @@ def select_candidates(self): In Fully-Connected topology all candidates should be selected """ #0145 - #listed = ["192.168.50.2:45001", "192.168.50.3:45002", "192.168.50.6:45005", "192.168.50.7:45006"] + #listed = ["192.168.51.2:45001", "192.168.51.3:45002", "192.168.51.6:45005", "192.168.51.7:45006"] #defined = [] - self.candidates_lock.acquire() + #self.candidates_lock.acquire() cdts = self.candidates.copy() #for (addr,a,b) in cdts: # if addr in listed: diff --git a/nebula/core/neighbormanagement/fastreboot.py b/nebula/core/neighbormanagement/fastreboot.py index af830906e..2fdc0f153 100644 --- a/nebula/core/neighbormanagement/fastreboot.py +++ b/nebula/core/neighbormanagement/fastreboot.py @@ -7,7 +7,7 @@ from nebula.core.neighbormanagement.nodemanager import NodeManager VANILLA_LEARNING_RATE = 1e-3 -FR_LEARNING_RATE = 2e-3 +FR_LEARNING_RATE = 1e-3 MAX_ROUNDS = 20 DEFAULT_WEIGHT_MODIFIER = 3 diff --git a/nebula/core/neighbormanagement/momentum.py b/nebula/core/neighbormanagement/momentum.py index a4183341a..8d4152821 100644 --- a/nebula/core/neighbormanagement/momentum.py +++ b/nebula/core/neighbormanagement/momentum.py @@ -16,6 +16,7 @@ MAX_HISTORIC_SIZE = 10 # Number of historic data storaged GLOBAL_PRIORITY = 0.5 # Parameter to priorize global vs local 
metrics EPSILON = 0.001 +SIGMOID_THRESHOLD = 0.92 TOLERANCE_THRESHOLD = 2 # Threshold to start appliying full penalty SMOOTH_FACTOR = 0.5 @@ -141,7 +142,7 @@ async def calculate_momentum_weights(self, updates: dict): historic = await self._get_similarity_historic(updates.keys()) # Get historic similarities values from nodes that has sent update this round def sigmoid(similarity, k=2.5): - if similarity >= 0.92: # threshold to consider better updates + if similarity >= SIGMOID_THRESHOLD: # threshold to consider better updates sigmoid = 1 else: sigmoid = 1 / (1 + np.exp(-k * (similarity))) diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index ce58c20d3..415edc7b8 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -25,7 +25,7 @@ def __init__( model_handler, push_acceleration, engine : "Engine", - fastreboot=False, + fastreboot=True, momentum=False, ): self.topology = "fully"#topology @@ -242,9 +242,9 @@ def get_nodes_known(self, neighbors_too=False): def accept_model_offer(self, source, decoded_model, rounds, round, epochs, n_neighbors, loss): if not self.accept_candidates_lock.locked(): logging.info(f"πŸ”„ Processing offer from {source}...") - model_accepted = True#self.model_handler.accept_model(decoded_model) - if source == "192.168.50.8:45007": - self.model_handler.accept_model(decoded_model) + #model_accepted = True#self.model_handler.accept_model(decoded_model) + #if source == "192.168.50.8:45007": + model_accepted = self.model_handler.accept_model(decoded_model) self.model_handler.set_config(config=(rounds, round, epochs, self)) if model_accepted: self.candidate_selector.add_candidate((source, n_neighbors, loss)) diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index bdcb40082..ef5663a21 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -253,7 +253,7 @@ 
async def handle_model_message(self, source, message): # non-starting nodes receive the initialized model from the starting node if not self.engine.get_federation_ready_lock().locked() or self.engine.get_initialization_status(): decoded_model = self.engine.trainer.deserialize_model(message.parameters) - if self.config.participant["adaptive_args"]["model_similarity"]: + if False and self.config.participant["adaptive_args"]["model_similarity"]: logging.info("πŸ€– handle_model_message | Checking model similarity") cosine_value = cosine_metric( self.engine.trainer.get_model_parameters(), diff --git a/nebula/node.py b/nebula/node.py index f5fbc2cb8..ce5a8d683 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -361,10 +361,14 @@ def randomize_value(value, variability): #if config.participant["network_args"]["ip"] == "192.168.50.11": #time.sleep(820) - if config.participant["network_args"]["ip"] == "192.168.50.9": - time.sleep(820) - elif config.participant["network_args"]["ip"] == "192.168.50.12": - time.sleep(420) + time.sleep(800) + + #if config.participant["network_args"]["ip"] == "192.168.51.11": + # logging.info("waiting 385s") + + #elif config.participant["network_args"]["ip"] == "192.168.51.12": + # logging.info("waiting 800s") + # time.sleep(800) #time.sleep(6000) # DEBUG purposes #import requests From a0c9c242241e26da755fcae3800a8b24cd8bb762 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Avil=C3=A9s=20Serrano?= <80918548+AlejandroAvilesSerrano@users.noreply.github.com> Date: Fri, 24 Jan 2025 15:31:02 +0100 Subject: [PATCH 056/233] update_messages_refactor --- nebula/core/engine.py | 24 +++++--- nebula/core/neighbormanagement/nodemanager.py | 15 ++++- nebula/core/network/actions.py | 59 +++++++++++++++++++ nebula/core/network/communications.py | 5 +- nebula/core/network/messages.py | 47 +++++++++++++++ 5 files changed, 139 insertions(+), 11 deletions(-) create mode 100644 nebula/core/network/actions.py diff --git a/nebula/core/engine.py 
b/nebula/core/engine.py index 6913a3700..40d1b95e5 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -470,14 +470,22 @@ async def _discover_discover_join_callback(self, source, message): await self.trainning_in_progress_lock.release_async() if round != -1: epochs = self.config.participant["training_args"]["epochs"] - msg = self.cm.mm.generate_offer_message( - nebula_pb2.OfferMessage.Action.OFFER_MODEL, - len(self.get_federation_nodes()), - 0, #self.trainer.get_current_loss(), - model, - rounds, - round, - epochs + #msg = self.cm.mm.generate_offer_message( + # nebula_pb2.OfferMessage.Action.OFFER_MODEL, + # len(self.get_federation_nodes()), + # 0, #self.trainer.get_current_loss(), + # model, + # rounds, + # round, + # epochs + #) + msg = self.cm.create_message("offer", + "offer_model", + len(self.get_federation_nodes()), + model, + rounds, + round, + epochs ) await self.cm.send_offer_model(source, msg) else: diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 415edc7b8..c4e5fab1b 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -17,6 +17,8 @@ if TYPE_CHECKING: from nebula.core.engine import Engine +RESTRUCTURE_COOLDOWN = 5 + class NodeManager(): def __init__( @@ -47,6 +49,7 @@ def __init__( self.recieve_offer_timer = 5 self._restructure_process_lock = Locker(name="restructure_process_lock") self.restructure = False + self._restructure_cooldown = RESTRUCTURE_COOLDOWN self.discarded_offers_addr_lock = Locker(name="discarded_offers_addr_lock") self.discarded_offers_addr = [] self._push_acceleration = push_acceleration @@ -85,6 +88,13 @@ def fr(self): def fast_reboot_on(self): return self._fast_reboot_status + def _update_restructure_cooldown(self): + if self._restructure_cooldown: + self._restructure_cooldown = (self._restructure_cooldown + 1) % RESTRUCTURE_COOLDOWN + + def _restructure_available(self): + return 
self._restructure_cooldown == 0 + def get_push_acceleration(self): return self._push_acceleration @@ -291,7 +301,6 @@ async def check_external_connection_service_status(self): action = lambda: self.engine.cm.init_external_connection_service() return action - #TODO NOT infinite loop, define n_tries async def start_late_connection_process(self, connected=False, msg_type="discover_join", addrs_known=None): """ This function represents the process of discovering the federation and stablish the first @@ -323,6 +332,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove # create message to send to candidates selected if not connected: msg = self.engine.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.LATE_CONNECT) + msg = self.engine.cm.create_message("connection", "late_connect") else: msg = self.engine.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.RESTRUCTURE) @@ -363,8 +373,9 @@ async def check_robustness(self): if not self.neighbors_left(): logging.info("No Neighbors left | reconnecting with Federation") #await self.reconnect_to_federation() - elif self.neighbor_policy.need_more_neighbors() and self.engine.get_sinchronized_status(): + elif self.neighbor_policy.need_more_neighbors() and self.engine.get_sinchronized_status() and self._restructure_available(): logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") + self._update_restructure_cooldown() asyncio.create_task(self.upgrade_connection_robustness()) else: if not self.engine.get_sinchronized_status(): diff --git a/nebula/core/network/actions.py b/nebula/core/network/actions.py new file mode 100644 index 000000000..deef068ba --- /dev/null +++ b/nebula/core/network/actions.py @@ -0,0 +1,59 @@ +from nebula.core.pb import nebula_pb2 +from enum import Enum + + +def factory_message_action(message_type: str, action: str): + options = { + "connection": ConnectionAction, + "federation": FederationAction, + 
"discovery": DiscoveryAction, + "control": ControlAction, + "discover": DiscoverAction, + "offer": OfferAction, + "link": LinkAction, + } + + message_actions = options.get(message_type, None) + + if message_actions: + normalized_action = action.upper() + enum_action = message_actions[normalized_action] + return enum_action + else: + return None + +class ConnectionAction(Enum): + CONNECT = nebula_pb2.ConnectionMessage.Action.CONNECT + DISCONNECT = nebula_pb2.ConnectionMessage.Action.DISCONNECT + LATE_CONNECT = nebula_pb2.ConnectionMessage.Action.LATE_CONNECT + RESTRUCTURE = nebula_pb2.ConnectionMessage.Action.RESTRUCTURE + +class FederationAction(Enum): + FEDERATION_START = nebula_pb2.FederationMessage.Action.FEDERATION_START + REPUTATION = nebula_pb2.FederationMessage.Action.REPUTATION + FEDERATION_MODELS_INCLUDED = nebula_pb2.FederationMessage.Action.FEDERATION_MODELS_INCLUDED + FEDERATION_READY = nebula_pb2.FederationMessage.Action.FEDERATION_READY + +class DiscoveryAction(Enum): + DISCOVER = nebula_pb2.DiscoveryMessage.Action.DISCOVER + REGISTER = nebula_pb2.DiscoveryMessage.Action.REGISTER + DEREGISTER = nebula_pb2.DiscoveryMessage.Action.DEREGISTER + +class ControlAction(Enum): + ALIVE = nebula_pb2.ControlMessage.Action.ALIVE + OVERHEAD = nebula_pb2.ControlMessage.Action.OVERHEAD + MOBILITY = nebula_pb2.ControlMessage.Action.MOBILITY + RECOVERY = nebula_pb2.ControlMessage.Action.RECOVERY + WEAK_LINK = nebula_pb2.ControlMessage.Action.WEAK_LINK + +class DiscoverAction(Enum): + DISCOVER_JOIN = nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN + DISCOVER_NODES = nebula_pb2.DiscoverMessage.Action.DISCOVER_NODES + +class OfferAction(Enum): + OFFER_MODEL = nebula_pb2.OfferMessage.Action.OFFER_MODEL + OFFER_METRIC = nebula_pb2.OfferMessage.Action.OFFER_METRIC + +class LinkAction(Enum): + CONNECT_TO = nebula_pb2.LinkMessage.Action.CONNECT_TO + DISCONNECT_FROM = nebula_pb2.LinkMessage.Action.DISCONNECT_FROM diff --git a/nebula/core/network/communications.py 
b/nebula/core/network/communications.py index ef5663a21..00b1fae5c 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -381,7 +381,10 @@ async def handle_link_message(self, source, message): await self.engine.event_manager.trigger_event(source, message) except Exception as e: logging.error(f"πŸ” handle_link_message | Error while processing: {message.action} {message.arguments} | {e}") - + + def create_message(self, message_type: str, action: str = "", **kwargs): + return self.mm.create_message(message_type, action, kwargs) + def start_external_connection_service(self): if self.ecs == None: self.ecs = NebulaConnectionService(self.addr) diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index 10573db0e..3a2022d9a 100755 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING from nebula.core.pb import nebula_pb2 +from nebula.core.network.actions import factory_message_action +import inspect if TYPE_CHECKING: from nebula.core.network.communications import CommunicationsManager @@ -120,3 +122,48 @@ def generate_link_message(self, action, addrs): message_wrapper.link_message.CopyFrom(message) data = message_wrapper.SerializeToString() return data + + + def create_message(self, message_type: str, action: str = "", **kwargs): + message_action = None + if action: + message_action = factory_message_action(message_type, action) + + message_generators_map = { + "model": self.generate_model_message, + "reputation": self.generate_reputation_message, + "connection": self.generate_connection_message, + "federation": self.generate_federation_message, + "discovery": self.generate_discovery_message, + "control": self.generate_control_message, + "discover": self.generate_discover_message, + "offer": self.generate_offer_message, + "link": self.generate_link_message, + } + message_generator_function = message_generators_map.get(message_type) + 
if not message_generator_function: + raise ValueError(f"Invalid message type '{message_type}'") + + generator_signature = inspect.signature(message_generator_function) + generator_params = generator_signature.parameters + + generator_args = [] + generator_kwargs = {} + + if "action" in generator_params and message_action is not None: + generator_args.append(message_action) + + if "kwargs" in generator_params: + generator_kwargs.update(kwargs) + + if generator_kwargs: + message = message_generator_function(*generator_args, **generator_kwargs) + else: + message = message_generator_function(*generator_args) + + message_wrapper = nebula_pb2.Wrapper() + message_wrapper.source = self.addr + field_name = f"{message_type}_message" + getattr(message_wrapper, field_name).CopyFrom(message) + data = message_wrapper.SerializeToString() + return data From 7001c6c84cc739bd561f338b24f043af54941d19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Avil=C3=A9s=20Serrano?= <80918548+AlejandroAvilesSerrano@users.noreply.github.com> Date: Fri, 24 Jan 2025 17:18:38 +0100 Subject: [PATCH 057/233] fix_msg_errors --- nebula/core/engine.py | 11 ++++++----- .../candidateselection/fccandidateselector.py | 2 +- nebula/core/network/communications.py | 2 +- nebula/node.py | 2 +- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 40d1b95e5..05e4efc1d 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -481,11 +481,12 @@ async def _discover_discover_join_callback(self, source, message): #) msg = self.cm.create_message("offer", "offer_model", - len(self.get_federation_nodes()), - model, - rounds, - round, - epochs + n_neighbors=len(self.get_federation_nodes()), + loss=0, + parameters=model, + rounds=rounds, + round=round, + epochs=epochs ) await self.cm.send_offer_model(source, msg) else: diff --git a/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py 
b/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py index e097e5be9..6f1c71129 100644 --- a/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py +++ b/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py @@ -22,7 +22,7 @@ def select_candidates(self): #0145 #listed = ["192.168.51.2:45001", "192.168.51.3:45002", "192.168.51.6:45005", "192.168.51.7:45006"] #defined = [] - #self.candidates_lock.acquire() + self.candidates_lock.acquire() cdts = self.candidates.copy() #for (addr,a,b) in cdts: # if addr in listed: diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 00b1fae5c..bfaf5aeb2 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -383,7 +383,7 @@ async def handle_link_message(self, source, message): logging.error(f"πŸ” handle_link_message | Error while processing: {message.action} {message.arguments} | {e}") def create_message(self, message_type: str, action: str = "", **kwargs): - return self.mm.create_message(message_type, action, kwargs) + return self.mm.create_message(message_type, action, **kwargs) def start_external_connection_service(self): if self.ecs == None: diff --git a/nebula/node.py b/nebula/node.py index ce5a8d683..faa1f6500 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -361,7 +361,7 @@ def randomize_value(value, variability): #if config.participant["network_args"]["ip"] == "192.168.50.11": #time.sleep(820) - time.sleep(800) + time.sleep(300) #if config.participant["network_args"]["ip"] == "192.168.51.11": # logging.info("waiting 385s") From 614af0f8a37409e60a09e22f478f44f3d00f1dd7 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sat, 25 Jan 2025 09:41:09 +0100 Subject: [PATCH 058/233] fix_messages_factory --- nebula/core/engine.py | 19 ++++++++---- nebula/core/neighbormanagement/nodemanager.py | 2 +- nebula/core/network/actions.py | 2 +- nebula/core/network/communications.py | 6 ++-- 
nebula/core/network/messages.py | 30 +++++++++++-------- nebula/node.py | 2 +- 6 files changed, 38 insertions(+), 23 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 05e4efc1d..c7a336b37 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -411,12 +411,14 @@ async def _connection_late_connect_callback(self, source, message): ct_actions , df_actions = self.nm.get_actions() if len(ct_actions): #for addr in ct_actions.split(): - cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, ct_actions) + #cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, ct_actions) + cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) await self.cm.send_message(source, cnt_msg) if len(df_actions): #for addr in df_actions.split(): - df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, df_actions) + #df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, df_actions) + df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) await self.cm.send_message(source, df_msg) await self.nm.register_late_neighbor(source, joinning_federation=True) @@ -481,13 +483,13 @@ async def _discover_discover_join_callback(self, source, message): #) msg = self.cm.create_message("offer", "offer_model", - n_neighbors=len(self.get_federation_nodes()), - loss=0, - parameters=model, + len(self.get_federation_nodes()), + 0, + serialized_model=model, rounds=rounds, round=round, epochs=epochs - ) + ) await self.cm.send_offer_model(source, msg) else: logging.info("Discover join received before federation is running..") @@ -504,6 +506,11 @@ async def _discover_discover_nodes_callback(self, source, message): #self.nm.meet_node(source) if len(self.get_federation_nodes()) > 0: msg = self.cm.mm.generate_offer_message(nebula_pb2.OfferMessage.Action.OFFER_METRIC, len(self.get_federation_nodes()), 
self.trainer.get_current_loss()) + msg = self.cm.create_message("offer", + "offer_metric", + n_neighbors=len(self.get_federation_nodes()), + loss=self.trainer.get_current_loss() + ) await self.cm.send_message(source, msg) else: logging.info(f"πŸ”— Dissmissing discover nodes from {source} | no active connections at the moment") diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index c4e5fab1b..62752c758 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -332,7 +332,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove # create message to send to candidates selected if not connected: msg = self.engine.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.LATE_CONNECT) - msg = self.engine.cm.create_message("connection", "late_connect") + #msg = self.engine.cm.create_message("connection", "late_connect") else: msg = self.engine.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.RESTRUCTURE) diff --git a/nebula/core/network/actions.py b/nebula/core/network/actions.py index deef068ba..11cc71fe4 100644 --- a/nebula/core/network/actions.py +++ b/nebula/core/network/actions.py @@ -18,7 +18,7 @@ def factory_message_action(message_type: str, action: str): if message_actions: normalized_action = action.upper() enum_action = message_actions[normalized_action] - return enum_action + return enum_action.value else: return None diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index bfaf5aeb2..5107e7c93 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -414,9 +414,11 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr addrs = addrs_known if msg_type=="discover_join": - msg = self.mm.generate_discover_message(nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN) + #msg = 
self.mm.generate_discover_message(nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN) + msg = self.create_message("discover", "discover_join") elif msg_type=="discover_nodes": - msg = self.mm.generate_discover_message(nebula_pb2.DiscoverMessage.Action.DISCOVER_NODES) + #msg = self.mm.generate_discover_message(nebula_pb2.DiscoverMessage.Action.DISCOVER_NODES) + msg = self.create_message("discover", "discover_nodes") logging.info("Starting communications with devices found") #TODO filtrar para para quitar las que ya son vecinos diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index 3a2022d9a..23176b748 100755 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -88,6 +88,7 @@ def generate_discover_message(self, action): message = nebula_pb2.DiscoverMessage( action=action, ) + return message message_wrapper = nebula_pb2.Wrapper() message_wrapper.source = self.addr message_wrapper.discover_message.CopyFrom(message) @@ -105,6 +106,7 @@ def generate_offer_message(self, action, n_neighbors, loss, serialized_model=Non round = round, epochs = epochs ) + return message message_wrapper = nebula_pb2.Wrapper() message_wrapper.source = self.addr message_wrapper.offer_message.CopyFrom(message) @@ -117,6 +119,7 @@ def generate_link_message(self, action, addrs): action=action, addrs = addrs, ) + return message message_wrapper = nebula_pb2.Wrapper() message_wrapper.source = self.addr message_wrapper.link_message.CopyFrom(message) @@ -124,7 +127,7 @@ def generate_link_message(self, action, addrs): return data - def create_message(self, message_type: str, action: str = "", **kwargs): + def create_message(self, message_type: str, action: str = "", *args, **kwargs): message_action = None if action: message_action = factory_message_action(message_type, action) @@ -144,22 +147,25 @@ def create_message(self, message_type: str, action: str = "", **kwargs): if not message_generator_function: raise ValueError(f"Invalid message type 
'{message_type}'") + class_name = message_type.capitalize() + "Message" + message_class = getattr(nebula_pb2, class_name, None) + + if message_class is None: + raise AttributeError(f"Message type {message_type} not found on the protocol") + generator_signature = inspect.signature(message_generator_function) generator_params = generator_signature.parameters - generator_args = [] - generator_kwargs = {} - + logging.info(f"Parameters in message: {generator_params}") + if "action" in generator_params and message_action is not None: - generator_args.append(message_action) - - if "kwargs" in generator_params: - generator_kwargs.update(kwargs) + kwargs["action"] = message_action - if generator_kwargs: - message = message_generator_function(*generator_args, **generator_kwargs) - else: - message = message_generator_function(*generator_args) + if args: + for param_name, arg_value in zip(generator_params, args): + kwargs[param_name] = arg_value + + message = message_class(**kwargs) message_wrapper = nebula_pb2.Wrapper() message_wrapper.source = self.addr diff --git a/nebula/node.py b/nebula/node.py index faa1f6500..acfcbb2db 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -361,7 +361,7 @@ def randomize_value(value, variability): #if config.participant["network_args"]["ip"] == "192.168.50.11": #time.sleep(820) - time.sleep(300) + time.sleep(250) #if config.participant["network_args"]["ip"] == "192.168.51.11": # logging.info("waiting 385s") From 8adf02c1cd92dfdf335d3f65ed5ba9b2a857b95e Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sat, 25 Jan 2025 10:30:54 +0100 Subject: [PATCH 059/233] fix_error --- nebula/core/engine.py | 2 +- nebula/core/network/communications.py | 4 ++-- nebula/core/network/messages.py | 2 ++ 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index c7a336b37..2832d6e6e 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -485,7 +485,7 @@ async def 
_discover_discover_join_callback(self, source, message): "offer_model", len(self.get_federation_nodes()), 0, - serialized_model=model, + parameters=model, rounds=rounds, round=round, epochs=epochs diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 5107e7c93..83c7b4ddb 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -382,8 +382,8 @@ async def handle_link_message(self, source, message): except Exception as e: logging.error(f"πŸ” handle_link_message | Error while processing: {message.action} {message.arguments} | {e}") - def create_message(self, message_type: str, action: str = "", **kwargs): - return self.mm.create_message(message_type, action, **kwargs) + def create_message(self, message_type: str, action: str = "", *args, **kwargs): + return self.mm.create_message(message_type, action, *args, **kwargs) def start_external_connection_service(self): if self.ecs == None: diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index 23176b748..5bfead059 100755 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -131,6 +131,8 @@ def create_message(self, message_type: str, action: str = "", *args, **kwargs): message_action = None if action: message_action = factory_message_action(message_type, action) + logging.info(f"action defined: {message_action}") + message_generators_map = { "model": self.generate_model_message, From 25ca56d864ec2daa917b351e1d253e32681bf4c5 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sat, 25 Jan 2025 10:43:29 +0100 Subject: [PATCH 060/233] fix_factory_message_action --- nebula/core/network/actions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nebula/core/network/actions.py b/nebula/core/network/actions.py index 11cc71fe4..deef068ba 100644 --- a/nebula/core/network/actions.py +++ b/nebula/core/network/actions.py @@ -18,7 +18,7 @@ def factory_message_action(message_type: str, 
action: str): if message_actions: normalized_action = action.upper() enum_action = message_actions[normalized_action] - return enum_action.value + return enum_action else: return None From 55bfb2e29a81bd17199f3f66fca159db09fb518c Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sun, 26 Jan 2025 13:30:18 +0100 Subject: [PATCH 061/233] fix_message_factory feat message_template --- nebula/core/network/actions.py | 4 +- nebula/core/network/messages.py | 123 ++++++++++++++++++++++++-------- nebula/node.py | 2 +- 3 files changed, 96 insertions(+), 33 deletions(-) diff --git a/nebula/core/network/actions.py b/nebula/core/network/actions.py index deef068ba..8a0bfea0b 100644 --- a/nebula/core/network/actions.py +++ b/nebula/core/network/actions.py @@ -1,5 +1,6 @@ from nebula.core.pb import nebula_pb2 from enum import Enum +import logging def factory_message_action(message_type: str, action: str): @@ -18,7 +19,8 @@ def factory_message_action(message_type: str, action: str): if message_actions: normalized_action = action.upper() enum_action = message_actions[normalized_action] - return enum_action + logging.info(f"Message action: {enum_action}, value: {enum_action.value}") + return enum_action.value else: return None diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index 5bfead059..6387cc068 100755 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -14,6 +14,66 @@ def __init__(self, addr, config, cm: "CommunicationsManager"): self.addr = addr self.config = config self.cm = cm + self._message_templates = {} + self._define_message_templates() + + def _define_message_templates(self): + # Dictionary that maps message types to their required parameters and default values + self._message_templates = { + "offer": { + "parameters": ["action", "n_neighbors", "loss", "parameters", "rounds", "round", "epochs"], + "defaults": { + "parameters": None, + "rounds": 1, + "round": -1, + "epochs": 1, + } + }, + "connection": { + 
"parameters": ["action"], + "defaults": {} + }, + "discovery": { + "parameters": ["action", "latitude", "longitude"], + "defaults": { + "latitude": 0.0, + "longitude": 0.0, + } + }, + "control": { + "parameters": ["action", "log"], + "defaults": { + "log": "Control message", + } + }, + "federation": { + "parameters": ["action", "arguments", "round"], + "defaults": { + "arguments": [], + "round": None, + } + }, + "model": { + "parameters": ["action", "round", "parameters", "weight"], + "defaults": { + "weight": 1, + } + }, + "reputation": { + "parameters": ["reputation"], + "defaults": {} + }, + "discover": { + "parameters": ["action"], + "defaults": {} + }, + "link": { + "parameters": ["action", "addrs"], + "defaults": {} + }, + # Add additional message types here + } + def generate_discovery_message(self, action, latitude=0.0, longitude=0.0): message = nebula_pb2.DiscoveryMessage( @@ -52,10 +112,10 @@ def generate_federation_message(self, action, arguments=[], round=None): data = message_wrapper.SerializeToString() return data - def generate_model_message(self, round, serialized_model, weight=1): + def generate_model_message(self, round, parameters, weight=1): message = nebula_pb2.ModelMessage( round=round, - parameters=serialized_model, + parameters=parameters, weight=weight, ) message_wrapper = nebula_pb2.Wrapper() @@ -95,13 +155,12 @@ def generate_discover_message(self, action): data = message_wrapper.SerializeToString() return data - - def generate_offer_message(self, action, n_neighbors, loss, serialized_model=None, rounds=1, round=-1, epochs = 1): + def generate_offer_message(self, action, n_neighbors, loss, parameters=None, rounds=1, round=-1, epochs = 1): message = nebula_pb2.OfferMessage( action=action, n_neighbors = n_neighbors, loss = loss, - parameters = serialized_model, + parameters = parameters, rounds = rounds, round = round, epochs = epochs @@ -113,8 +172,8 @@ def generate_offer_message(self, action, n_neighbors, loss, serialized_model=Non data = 
message_wrapper.SerializeToString() return data - def generate_link_message(self, action, addrs): + pass message = nebula_pb2.LinkMessage( action=action, addrs = addrs, @@ -128,45 +187,47 @@ def generate_link_message(self, action, addrs): def create_message(self, message_type: str, action: str = "", *args, **kwargs): + # If an action is provided, convert it to its corresponding enum value using the factory message_action = None if action: message_action = factory_message_action(message_type, action) - logging.info(f"action defined: {message_action}") - - - message_generators_map = { - "model": self.generate_model_message, - "reputation": self.generate_reputation_message, - "connection": self.generate_connection_message, - "federation": self.generate_federation_message, - "discovery": self.generate_discovery_message, - "control": self.generate_control_message, - "discover": self.generate_discover_message, - "offer": self.generate_offer_message, - "link": self.generate_link_message, - } - message_generator_function = message_generators_map.get(message_type) - if not message_generator_function: + + # Retrieve the template for the provided message type + message_template = self._message_templates.get(message_type) + if not message_template: raise ValueError(f"Invalid message type '{message_type}'") + # Extract parameters and defaults from the template + template_params = message_template["parameters"] + default_values: dict = message_template.get("defaults", {}) + + # Dynamically retrieve the class for the protobuf message (e.g., OfferMessage) class_name = message_type.capitalize() + "Message" message_class = getattr(nebula_pb2, class_name, None) if message_class is None: raise AttributeError(f"Message type {message_type} not found on the protocol") - generator_signature = inspect.signature(message_generator_function) - generator_params = generator_signature.parameters - - logging.info(f"Parameters in message: {generator_params}") - - if "action" in generator_params 
and message_action is not None: + # Set the 'action' parameter if required and if the message_action is available + if "action" in template_params and message_action is not None: kwargs["action"] = message_action - + + # Map positional arguments to template parameters + remaining_params = [param_name for param_name in template_params if param_name not in kwargs] if args: - for param_name, arg_value in zip(generator_params, args): + for param_name, arg_value in zip(remaining_params, args): + if param_name in kwargs: + continue kwargs[param_name] = arg_value - + + # Fill in missing parameters with their default values + # logging.info(f"kwargs parameters: {kwargs.keys()}") + for param_name in template_params: + if param_name not in kwargs: + logging.info(f"Filling parameter '{param_name}' with default value: {default_values.get(param_name)}") + kwargs[param_name] = default_values.get(param_name) + + # Create an instance of the protobuf message class using the constructed kwargs message = message_class(**kwargs) message_wrapper = nebula_pb2.Wrapper() diff --git a/nebula/node.py b/nebula/node.py index acfcbb2db..8fca023b4 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -361,7 +361,7 @@ def randomize_value(value, variability): #if config.participant["network_args"]["ip"] == "192.168.50.11": #time.sleep(820) - time.sleep(250) + time.sleep(200) #if config.participant["network_args"]["ip"] == "192.168.51.11": # logging.info("waiting 385s") From 13547ba09c097a07c82a926d1f5c9f01967b0f3f Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sun, 26 Jan 2025 16:24:22 +0100 Subject: [PATCH 062/233] feat_refactor_messages --- nebula/core/aggregation/aggregator.py | 9 +- nebula/core/engine.py | 44 ++-- nebula/core/neighbormanagement/nodemanager.py | 5 +- nebula/core/network/communications.py | 13 +- nebula/core/network/discoverer.py | 11 +- nebula/core/network/health.py | 7 +- nebula/core/network/messages.py | 224 +++++++++--------- 7 files changed, 158 insertions(+), 155 
deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 51424e871..1f027fa90 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -199,10 +199,11 @@ async def include_model_in_buffer(self, model, weight, source=None, round=None, logging.info( f"πŸ”„ include_model_in_buffer | Broadcasting MODELS_INCLUDED for round {self.engine.get_round()}" ) - message = self.cm.mm.generate_federation_message( - nebula_pb2.FederationMessage.Action.FEDERATION_MODELS_INCLUDED, - [self.engine.get_round()], - ) + #message = self.cm.mm.generate_federation_message( + # nebula_pb2.FederationMessage.Action.FEDERATION_MODELS_INCLUDED, + # [self.engine.get_round()], + #) + message = self.cm.create_message("federation", "federation_models_included", [str(arg) for arg in [self.engine.get_round()]]) await self.cm.send_message_to_neighbors(message) return diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 2832d6e6e..dce7802e1 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -383,8 +383,11 @@ async def _federation_models_included_callback(self, source, message): await self.cm.get_connections_lock().release_async() + """ ############################## + # Mobility callbacks # + ############################## + """ - # Mobility callbacks @event_handler( nebula_pb2.ConnectionMessage, nebula_pb2.ConnectionMessage.Action.LATE_CONNECT, @@ -405,19 +408,17 @@ async def _connection_late_connect_callback(self, source, message): await self.cm.connect(source, direct=True) # Verify conenction is accepted - conf_msg = self.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.LATE_CONNECT) + conf_msg = self.cm.create_message("connection", "late_connect") await self.cm.send_message(source, conf_msg) ct_actions , df_actions = self.nm.get_actions() if len(ct_actions): #for addr in ct_actions.split(): - #cnt_msg = 
self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, ct_actions) cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) await self.cm.send_message(source, cnt_msg) if len(df_actions): #for addr in df_actions.split(): - #df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, df_actions) df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) await self.cm.send_message(source, df_msg) @@ -444,16 +445,20 @@ async def _connection_restructure_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection accepted from {source}") await self.cm.connect(source, direct=True) - conf_msg = self.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.RESTRUCTURE) + #conf_msg = self.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.RESTRUCTURE) + conf_msg = self.cm.create_message("connection", "restructure") + await self.cm.send_message(source, conf_msg) ct_actions , df_actions = self.nm.get_actions() if len(ct_actions): - cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, ct_actions) + #cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, ct_actions) + cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) await self.cm.send_message(source, cnt_msg) if len(df_actions): - df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, df_actions) + #df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, df_actions) + df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) await self.cm.send_message(source, df_msg) await self.nm.register_late_neighbor(source, joinning_federation=False) @@ -472,15 +477,6 @@ async def _discover_discover_join_callback(self, source, message): await self.trainning_in_progress_lock.release_async() if round != -1: 
epochs = self.config.participant["training_args"]["epochs"] - #msg = self.cm.mm.generate_offer_message( - # nebula_pb2.OfferMessage.Action.OFFER_MODEL, - # len(self.get_federation_nodes()), - # 0, #self.trainer.get_current_loss(), - # model, - # rounds, - # round, - # epochs - #) msg = self.cm.create_message("offer", "offer_model", len(self.get_federation_nodes()), @@ -505,7 +501,7 @@ async def _discover_discover_nodes_callback(self, source, message): logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_node message from {source} ") #self.nm.meet_node(source) if len(self.get_federation_nodes()) > 0: - msg = self.cm.mm.generate_offer_message(nebula_pb2.OfferMessage.Action.OFFER_METRIC, len(self.get_federation_nodes()), self.trainer.get_current_loss()) + #msg = self.cm.mm.generate_offer_message(nebula_pb2.OfferMessage.Action.OFFER_METRIC, len(self.get_federation_nodes()), self.trainer.get_current_loss()) msg = self.cm.create_message("offer", "offer_metric", n_neighbors=len(self.get_federation_nodes()), @@ -570,7 +566,12 @@ async def _link_disconnect_from_callback(self, source, message): for addr in addrs.split(): await self.cm.disconnect(source, mutual_disconnection=False) await self.nm.update_neighbors(addr, remove=True) - + + """ ############################## + # ENGINE FUNCTIONALITY # + ############################## + """ + async def _aditional_node_start(self): self.update_sinchronized_status(False) logging.info(f"Aditional node | {self.addr} | going to stablish connection with federation") @@ -684,7 +685,8 @@ async def deploy_federation(self): while not await self.cm.check_federation_ready(): await asyncio.sleep(1) logging.info("Sending FEDERATION_START to neighbors...") - message = self.cm.mm.generate_federation_message(nebula_pb2.FederationMessage.Action.FEDERATION_START) + #message = self.cm.mm.generate_federation_message(nebula_pb2.FederationMessage.Action.FEDERATION_START) + message = self.cm.create_message("federation", 
"federation_start") await self.cm.send_message_to_neighbors(message) await self.get_federation_ready_lock().release_async() await self.create_trainer_module() @@ -694,7 +696,8 @@ async def deploy_federation(self): else: logging.info("Sending FEDERATION_READY to neighbors...") - message = self.cm.mm.generate_federation_message(nebula_pb2.FederationMessage.Action.FEDERATION_READY) + #message = self.cm.mm.generate_federation_message(nebula_pb2.FederationMessage.Action.FEDERATION_READY) + message = self.cm.create_message("federation", "federation_ready") await self.cm.send_message_to_neighbors(message) logging.info("πŸ’€ Waiting until receiving the start signal from the start node") @@ -919,6 +922,7 @@ async def send_reputation(self, malicious_nodes): message = self.cm.mm.generate_federation_message( nebula_pb2.FederationMessage.Action.REPUTATION, malicious_nodes ) + message = self.cm.create_message("federation","reputation", arguments=[str(arg) for arg in (malicious_nodes)]) await self.cm.send_message_to_neighbors(message) diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 62752c758..dc635750e 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -331,10 +331,9 @@ async def start_late_connection_process(self, connected=False, msg_type="discove logging.info("Candidates found to connect to...") # create message to send to candidates selected if not connected: - msg = self.engine.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.LATE_CONNECT) - #msg = self.engine.cm.create_message("connection", "late_connect") + msg = self.engine.cm.create_message("connection", "late_connect") else: - msg = self.engine.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.RESTRUCTURE) + msg = self.engine.cm.create_message("connection", "restructure") best_candidates = self.candidate_selector.select_candidates() 
logging.info(f"Candidates | {[addr for addr,_,_ in best_candidates]}") diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 83c7b4ddb..3d9cbc651 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -413,12 +413,7 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr logging.info("Searching federation process beginning... | Using addrs previously known") addrs = addrs_known - if msg_type=="discover_join": - #msg = self.mm.generate_discover_message(nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN) - msg = self.create_message("discover", "discover_join") - elif msg_type=="discover_nodes": - #msg = self.mm.generate_discover_message(nebula_pb2.DiscoverMessage.Action.DISCOVER_NODES) - msg = self.create_message("discover", "discover_nodes") + msg = self.create_message("discover", msg_type) logging.info("Starting communications with devices found") #TODO filtrar para para quitar las que ya son vecinos @@ -765,7 +760,8 @@ async def send_model(self, dest_addr, round, serialized_model, weight=1): logging.info( f"Sending model to {dest_addr} with round {round}: weight={weight} |Β size={sys.getsizeof(serialized_model) / (1024** 2) if serialized_model is not None else 0} MB" ) - message = self.mm.generate_model_message(round, serialized_model, weight) + #message = self.mm.generate_model_message(round, serialized_model, weight) + message = self.create_message("model", round, serialized_model, weight) await conn.send(data=message, is_compressed=True) logging.info(f"Model sent to {dest_addr} with round {round}") except Exception as e: @@ -976,7 +972,8 @@ async def disconnect(self, dest_addr, mutual_disconnection=True): try: if mutual_disconnection: await self.connections[dest_addr].send( - data=self.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.DISCONNECT) + 
#data=self.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.DISCONNECT) + data=self.create_message("connection", "disconnect") ) await asyncio.sleep(1) self.connections[dest_addr].stop() diff --git a/nebula/core/network/discoverer.py b/nebula/core/network/discoverer.py index f341ea29e..352413815 100755 --- a/nebula/core/network/discoverer.py +++ b/nebula/core/network/discoverer.py @@ -31,11 +31,12 @@ async def run_discover(self): if len(self.cm.connections) > 0: latitude = self.config.participant["mobility_args"]["latitude"] longitude = self.config.participant["mobility_args"]["longitude"] - message = self.cm.mm.generate_discovery_message( - action=nebula_pb2.DiscoveryMessage.Action.DISCOVER, - latitude=latitude, - longitude=longitude, - ) + #message = self.cm.mm.generate_discovery_message( + # action=nebula_pb2.DiscoveryMessage.Action.DISCOVER, + # latitude=latitude, + # longitude=longitude, + #) + message = self.cm.create_message("discovery", "discover", latitude=latitude, longitude=longitude) try: logging.debug("πŸ” Sending discovery message to neighbors...") current_connections = await self.cm.get_addrs_current_connections(only_direct=True) diff --git a/nebula/core/network/health.py b/nebula/core/network/health.py index b294ee65c..1b4cd1d7e 100755 --- a/nebula/core/network/health.py +++ b/nebula/core/network/health.py @@ -32,9 +32,10 @@ async def run_send_alive(self): conn.set_active(True) while True: if len(self.cm.connections) > 0: - message = self.cm.mm.generate_control_message( - nebula_pb2.ControlMessage.Action.ALIVE, log="Alive message" - ) + #message = self.cm.mm.generate_control_message( + # nebula_pb2.ControlMessage.Action.ALIVE, log="Alive message" + #) + message = self.cm.create_message("control", "alive", log="Alive message") current_connections = list(self.cm.connections.values()) for conn in current_connections: if conn.get_direct(): diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index 
6387cc068..3f0f6aca0 100755 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -54,7 +54,7 @@ def _define_message_templates(self): } }, "model": { - "parameters": ["action", "round", "parameters", "weight"], + "parameters": ["round", "parameters", "weight"], "defaults": { "weight": 1, } @@ -75,117 +75,117 @@ def _define_message_templates(self): } - def generate_discovery_message(self, action, latitude=0.0, longitude=0.0): - message = nebula_pb2.DiscoveryMessage( - action=action, - latitude=latitude, - longitude=longitude, - ) - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.source = self.addr - message_wrapper.discovery_message.CopyFrom(message) - data = message_wrapper.SerializeToString() - return data - - def generate_control_message(self, action, log="Control message"): - message = nebula_pb2.ControlMessage( - action=action, - log=log, - ) - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.source = self.addr - message_wrapper.control_message.CopyFrom(message) - data = message_wrapper.SerializeToString() - return data - - def generate_federation_message(self, action, arguments=[], round=None): - logging.info(f"Building federation message with [Action {action}], arguments {arguments}, and round {round}") - - message = nebula_pb2.FederationMessage( - action=action, - arguments=[str(arg) for arg in (arguments or [])], - round=round, - ) - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.source = self.addr - message_wrapper.federation_message.CopyFrom(message) - data = message_wrapper.SerializeToString() - return data - - def generate_model_message(self, round, parameters, weight=1): - message = nebula_pb2.ModelMessage( - round=round, - parameters=parameters, - weight=weight, - ) - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.source = self.addr - message_wrapper.model_message.CopyFrom(message) - data = message_wrapper.SerializeToString() - return data - - def generate_connection_message(self, action): 
- message = nebula_pb2.ConnectionMessage( - action=action, - ) - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.source = self.addr - message_wrapper.connection_message.CopyFrom(message) - data = message_wrapper.SerializeToString() - return data - - def generate_reputation_message(self, reputation): - message = nebula_pb2.ReputationMessage( - reputation=reputation, - ) - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.source = self.addr - message_wrapper.reputation_message.CopyFrom(message) - data = message_wrapper.SerializeToString() - return data - - def generate_discover_message(self, action): - message = nebula_pb2.DiscoverMessage( - action=action, - ) - return message - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.source = self.addr - message_wrapper.discover_message.CopyFrom(message) - data = message_wrapper.SerializeToString() - return data - - def generate_offer_message(self, action, n_neighbors, loss, parameters=None, rounds=1, round=-1, epochs = 1): - message = nebula_pb2.OfferMessage( - action=action, - n_neighbors = n_neighbors, - loss = loss, - parameters = parameters, - rounds = rounds, - round = round, - epochs = epochs - ) - return message - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.source = self.addr - message_wrapper.offer_message.CopyFrom(message) - data = message_wrapper.SerializeToString() - return data - - def generate_link_message(self, action, addrs): - pass - message = nebula_pb2.LinkMessage( - action=action, - addrs = addrs, - ) - return message - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.source = self.addr - message_wrapper.link_message.CopyFrom(message) - data = message_wrapper.SerializeToString() - return data - - + """ def generate_discovery_message(self, action, latitude=0.0, longitude=0.0): + message = nebula_pb2.DiscoveryMessage( + action=action, + latitude=latitude, + longitude=longitude, + ) + message_wrapper = nebula_pb2.Wrapper() + message_wrapper.source = self.addr + 
message_wrapper.discovery_message.CopyFrom(message) + data = message_wrapper.SerializeToString() + return data + + def generate_control_message(self, action, log="Control message"): + message = nebula_pb2.ControlMessage( + action=action, + log=log, + ) + message_wrapper = nebula_pb2.Wrapper() + message_wrapper.source = self.addr + message_wrapper.control_message.CopyFrom(message) + data = message_wrapper.SerializeToString() + return data + + def generate_federation_message(self, action, arguments=[], round=None): + logging.info(f"Building federation message with [Action {action}], arguments {arguments}, and round {round}") + + message = nebula_pb2.FederationMessage( + action=action, + arguments=[str(arg) for arg in (arguments or [])], + round=round, + ) + message_wrapper = nebula_pb2.Wrapper() + message_wrapper.source = self.addr + message_wrapper.federation_message.CopyFrom(message) + data = message_wrapper.SerializeToString() + return data + + def generate_model_message(self, round, parameters, weight=1): + message = nebula_pb2.ModelMessage( + round=round, + parameters=parameters, + weight=weight, + ) + message_wrapper = nebula_pb2.Wrapper() + message_wrapper.source = self.addr + message_wrapper.model_message.CopyFrom(message) + data = message_wrapper.SerializeToString() + return data + + def generate_connection_message(self, action): + message = nebula_pb2.ConnectionMessage( + action=action, + ) + message_wrapper = nebula_pb2.Wrapper() + message_wrapper.source = self.addr + message_wrapper.connection_message.CopyFrom(message) + data = message_wrapper.SerializeToString() + return data + + def generate_reputation_message(self, reputation): + message = nebula_pb2.ReputationMessage( + reputation=reputation, + ) + message_wrapper = nebula_pb2.Wrapper() + message_wrapper.source = self.addr + message_wrapper.reputation_message.CopyFrom(message) + data = message_wrapper.SerializeToString() + return data + + def generate_discover_message(self, action): + message = 
nebula_pb2.DiscoverMessage( + action=action, + ) + return message + message_wrapper = nebula_pb2.Wrapper() + message_wrapper.source = self.addr + message_wrapper.discover_message.CopyFrom(message) + data = message_wrapper.SerializeToString() + return data + + def generate_offer_message(self, action, n_neighbors, loss, parameters=None, rounds=1, round=-1, epochs = 1): + message = nebula_pb2.OfferMessage( + action=action, + n_neighbors = n_neighbors, + loss = loss, + parameters = parameters, + rounds = rounds, + round = round, + epochs = epochs + ) + return message + message_wrapper = nebula_pb2.Wrapper() + message_wrapper.source = self.addr + message_wrapper.offer_message.CopyFrom(message) + data = message_wrapper.SerializeToString() + return data + + def generate_link_message(self, action, addrs): + pass + message = nebula_pb2.LinkMessage( + action=action, + addrs = addrs, + ) + return message + message_wrapper = nebula_pb2.Wrapper() + message_wrapper.source = self.addr + message_wrapper.link_message.CopyFrom(message) + data = message_wrapper.SerializeToString() + return data + + """ def create_message(self, message_type: str, action: str = "", *args, **kwargs): # If an action is provided, convert it to its corresponding enum value using the factory message_action = None From 849d1b8ea6f6e61e62046c351e55fa442b446453 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sun, 26 Jan 2025 17:15:00 +0100 Subject: [PATCH 063/233] fix_clean_code --- nebula/core/engine.py | 6 +- nebula/core/network/communications.py | 3 +- nebula/core/network/messages.py | 117 +------------------------- 3 files changed, 8 insertions(+), 118 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index dce7802e1..f306c87d9 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -919,9 +919,9 @@ def reputation_calculation(self, aggregated_models_weights): async def send_reputation(self, malicious_nodes): logging.info(f"Sending REPUTATION to the rest of the topology: 
{malicious_nodes}") - message = self.cm.mm.generate_federation_message( - nebula_pb2.FederationMessage.Action.REPUTATION, malicious_nodes - ) + #message = self.cm.mm.generate_federation_message( + # nebula_pb2.FederationMessage.Action.REPUTATION, malicious_nodes + #) message = self.cm.create_message("federation","reputation", arguments=[str(arg) for arg in (malicious_nodes)]) await self.cm.send_message_to_neighbors(message) diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 3d9cbc651..1bd39f218 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -761,7 +761,8 @@ async def send_model(self, dest_addr, round, serialized_model, weight=1): f"Sending model to {dest_addr} with round {round}: weight={weight} |Β size={sys.getsizeof(serialized_model) / (1024** 2) if serialized_model is not None else 0} MB" ) #message = self.mm.generate_model_message(round, serialized_model, weight) - message = self.create_message("model", round, serialized_model, weight) + parameters = serialized_model + message = self.create_message("model", "", round, parameters, weight) await conn.send(data=message, is_compressed=True) logging.info(f"Model sent to {dest_addr} with round {round}") except Exception as e: diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index 3f0f6aca0..84fcdc6f0 100755 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -74,121 +74,10 @@ def _define_message_templates(self): # Add additional message types here } - - """ def generate_discovery_message(self, action, latitude=0.0, longitude=0.0): - message = nebula_pb2.DiscoveryMessage( - action=action, - latitude=latitude, - longitude=longitude, - ) - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.source = self.addr - message_wrapper.discovery_message.CopyFrom(message) - data = message_wrapper.SerializeToString() - return data - - def generate_control_message(self, 
action, log="Control message"): - message = nebula_pb2.ControlMessage( - action=action, - log=log, - ) - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.source = self.addr - message_wrapper.control_message.CopyFrom(message) - data = message_wrapper.SerializeToString() - return data - - def generate_federation_message(self, action, arguments=[], round=None): - logging.info(f"Building federation message with [Action {action}], arguments {arguments}, and round {round}") - - message = nebula_pb2.FederationMessage( - action=action, - arguments=[str(arg) for arg in (arguments or [])], - round=round, - ) - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.source = self.addr - message_wrapper.federation_message.CopyFrom(message) - data = message_wrapper.SerializeToString() - return data - - def generate_model_message(self, round, parameters, weight=1): - message = nebula_pb2.ModelMessage( - round=round, - parameters=parameters, - weight=weight, - ) - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.source = self.addr - message_wrapper.model_message.CopyFrom(message) - data = message_wrapper.SerializeToString() - return data - - def generate_connection_message(self, action): - message = nebula_pb2.ConnectionMessage( - action=action, - ) - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.source = self.addr - message_wrapper.connection_message.CopyFrom(message) - data = message_wrapper.SerializeToString() - return data - - def generate_reputation_message(self, reputation): - message = nebula_pb2.ReputationMessage( - reputation=reputation, - ) - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.source = self.addr - message_wrapper.reputation_message.CopyFrom(message) - data = message_wrapper.SerializeToString() - return data - - def generate_discover_message(self, action): - message = nebula_pb2.DiscoverMessage( - action=action, - ) - return message - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.source = self.addr - 
message_wrapper.discover_message.CopyFrom(message) - data = message_wrapper.SerializeToString() - return data - - def generate_offer_message(self, action, n_neighbors, loss, parameters=None, rounds=1, round=-1, epochs = 1): - message = nebula_pb2.OfferMessage( - action=action, - n_neighbors = n_neighbors, - loss = loss, - parameters = parameters, - rounds = rounds, - round = round, - epochs = epochs - ) - return message - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.source = self.addr - message_wrapper.offer_message.CopyFrom(message) - data = message_wrapper.SerializeToString() - return data - - def generate_link_message(self, action, addrs): - pass - message = nebula_pb2.LinkMessage( - action=action, - addrs = addrs, - ) - return message - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.source = self.addr - message_wrapper.link_message.CopyFrom(message) - data = message_wrapper.SerializeToString() - return data - - """ def create_message(self, message_type: str, action: str = "", *args, **kwargs): + #logging.info(f"Creating message | type: {message_type}, action: {action}, positionals: {args}, explicits: {kwargs.keys()}") # If an action is provided, convert it to its corresponding enum value using the factory - message_action = None + message_action = None if action: message_action = factory_message_action(message_type, action) @@ -221,7 +110,7 @@ def create_message(self, message_type: str, action: str = "", *args, **kwargs): kwargs[param_name] = arg_value # Fill in missing parameters with their default values - # logging.info(f"kwargs parameters: {kwargs.keys()}") + logging.info(f"kwargs parameters: {kwargs.keys()}") for param_name in template_params: if param_name not in kwargs: logging.info(f"Filling parameter '{param_name}' with default value: {default_values.get(param_name)}") From 5812aa3019134d9df25db37fd37b4f752489c9ac Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 27 Jan 2025 10:35:53 +0100 Subject: [PATCH 064/233] 
fix_refactor_communciations --- nebula/core/network/communications.py | 17 ++++++++ nebula/core/network/messages.py | 60 ++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 2 deletions(-) diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 1bd39f218..3ec0d325d 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -143,6 +143,8 @@ async def add_ready_connection(self, addr): self.ready_connections.add(addr) async def handle_incoming_message(self, data, addr_from): + await self.mm.process_message(data, addr_from) + return try: message_wrapper = nebula_pb2.Wrapper() message_wrapper.ParseFromString(data) @@ -190,6 +192,21 @@ async def handle_incoming_message(self, data, addr_from): logging.exception(f"πŸ“₯ handle_incoming_message | Error while processing: {e}") logging.exception(traceback.format_exc()) + + async def forward_message(self, data, addr_from): + await self.forwarder.forward(data, addr_from=addr_from) + + # generic point to handle messages + async def handle_message(self, source, msg_type, message): + logging.info( + f"πŸ” handle_{msg_type} | Received [Action {message.action}] from {source}" + ) + try: + await self.engine.event_manager.trigger_event(source, message) + except Exception as e: + logging.exception(f"πŸ” handle_{msg_type} | Error while processing: {e}") + + async def handle_discovery_message(self, source, message): logging.info( f"πŸ” handle_discovery_message | Received [Action {message.action}] from {source} (network propagation)" diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index 84fcdc6f0..163138a94 100755 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -3,7 +3,8 @@ from nebula.core.pb import nebula_pb2 from nebula.core.network.actions import factory_message_action -import inspect +import hashlib +import traceback if TYPE_CHECKING: from nebula.core.network.communications import 
CommunicationsManager @@ -73,7 +74,62 @@ def _define_message_templates(self): }, # Add additional message types here } + + async def process_message(self, data, addr_from): + not_processing_messages = {"control_message", "connection_message"} + special_processing_messages = {"discovery_message", "federation_message", "model_message"} + + try: + message_wrapper = nebula_pb2.Wrapper() + message_wrapper.ParseFromString(data) + source = message_wrapper.source + logging.debug(f"πŸ“₯ handle_incoming_message | Received message from {addr_from} with source {source}") + if source == self.addr: + return + + # Extract the active message from the oneof field + message_type = message_wrapper.WhichOneof("message") + if not message_type: + logging.warning("Received message with no active field in the 'oneof'") + return + message_data = getattr(message_wrapper, message_type) + + # Not required processing messages + if message_type in not_processing_messages: + await self.cm.handle_message(source, message_type, message_data) + + # Message-specific forwarding and processing + elif message_type in special_processing_messages: + if await self.cm.include_received_message_hash(hashlib.md5(data).hexdigest()): + # Forward the message if required + if self._should_forward_message(message_type, message_wrapper): + self.cm.forward_message(data, addr_from) + + if message_type == "model_message": + self.cm.handle_model_message(source, message_data) + else: + await self.cm.handle_message(source, message_type, message_data) + + # Rest of messages + else: + if await self.cm.include_received_message_hash(hashlib.md5(data).hexdigest()): + await self.cm.handle_message(source, message_type, message_data) + except Exception as e: + logging.exception(f"πŸ“₯ handle_incoming_message | Error while processing: {e}") + logging.exception(traceback.format_exc()) + + def _should_forward_message(self, message_type, message_wrapper): + if self.cm.config.participant["device_args"]["proxy"]: + return True + # 
TODO: Improve the technique. Now only forward model messages if the node is a proxy + # Need to update the expected model messages receiving during the round + # Round -1 is the initialization round --> all nodes should receive the model + if message_type == "model_message" and message_wrapper.model_message.round == -1: + return True + if message_type == "federation_message" and message_wrapper.federation_message.action == nebula_pb2.FederationMessage.Action.Value("FEDERATION_START"): + return True + def create_message(self, message_type: str, action: str = "", *args, **kwargs): #logging.info(f"Creating message | type: {message_type}, action: {action}, positionals: {args}, explicits: {kwargs.keys()}") # If an action is provided, convert it to its corresponding enum value using the factory @@ -110,7 +166,7 @@ def create_message(self, message_type: str, action: str = "", *args, **kwargs): kwargs[param_name] = arg_value # Fill in missing parameters with their default values - logging.info(f"kwargs parameters: {kwargs.keys()}") + # logging.info(f"kwargs parameters: {kwargs.keys()}") for param_name in template_params: if param_name not in kwargs: logging.info(f"Filling parameter '{param_name}' with default value: {default_values.get(param_name)}") From ed2d764c9e0187e726ab0b78636eaf15c4e253d1 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 27 Jan 2025 11:43:00 +0100 Subject: [PATCH 065/233] fix_handle_model_error await was required --- nebula/core/network/actions.py | 2 +- nebula/core/network/messages.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/nebula/core/network/actions.py b/nebula/core/network/actions.py index 8a0bfea0b..743a7eb8c 100644 --- a/nebula/core/network/actions.py +++ b/nebula/core/network/actions.py @@ -19,7 +19,7 @@ def factory_message_action(message_type: str, action: str): if message_actions: normalized_action = action.upper() enum_action = message_actions[normalized_action] - logging.info(f"Message action: 
{enum_action}, value: {enum_action.value}") + #logging.info(f"Message action: {enum_action}, value: {enum_action.value}") return enum_action.value else: return None diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index 163138a94..b4b1f5d61 100755 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -92,6 +92,7 @@ async def process_message(self, data, addr_from): if not message_type: logging.warning("Received message with no active field in the 'oneof'") return + logging.info(f"Message type received: {message_type}") message_data = getattr(message_wrapper, message_type) @@ -107,7 +108,7 @@ async def process_message(self, data, addr_from): self.cm.forward_message(data, addr_from) if message_type == "model_message": - self.cm.handle_model_message(source, message_data) + await self.cm.handle_model_message(source, message_data) else: await self.cm.handle_message(source, message_type, message_data) From bfef62ca25897919dbd8d8eda902b0c4c7e61715 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Avil=C3=A9s=20Serrano?= <80918548+AlejandroAvilesSerrano@users.noreply.github.com> Date: Wed, 29 Jan 2025 15:50:27 +0100 Subject: [PATCH 066/233] feat_message_events --- nebula/core/engine.py | 23 ++++++++++++++++++++++- nebula/core/eventmanager.py | 24 +++++++++++++++++++++++- nebula/core/network/actions.py | 20 ++++++++++++++++++++ nebula/core/network/communications.py | 21 +++++++++++++-------- nebula/core/network/messages.py | 26 +++++++++++++++++++++----- 5 files changed, 99 insertions(+), 15 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index f306c87d9..52e45edb3 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -273,6 +273,11 @@ def set_round(self, new_round): self.trainer.set_current_round(new_round) + """ ############################## + # General callbacks # + ############################## + """ + @event_handler(nebula_pb2.DiscoveryMessage, 
nebula_pb2.DiscoveryMessage.Action.DISCOVER) async def _discovery_discover_callback(self, source, message): logging.info( @@ -571,7 +576,23 @@ async def _link_disconnect_from_callback(self, source, message): # ENGINE FUNCTIONALITY # ############################## """ - + + + def register_message_events_callbacks(self): + me_dict = self.cm.get_messages_events() + message_events = [(message_name, message_action) for (message_name, message_actions) in me_dict.items() for message_action in message_actions] + logging.info(f"{message_events}") + for event_type, action in message_events: + callback_name = f"_ {event_type}_{action}_callback" + method = getattr(self, callback_name, None) + + if callable(method): + self._event_manager.subscribe((event_type, action), method) + + async def trigger_event(self, message_event): + logging.info(f"Publishing MessageEvent: {message_event.message_type}") + await self._event_manager.publish(message_event) + async def _aditional_node_start(self): self.update_sinchronized_status(False) logging.info(f"Aditional node | {self.addr} | going to stablish connection with federation") diff --git a/nebula/core/eventmanager.py b/nebula/core/eventmanager.py index 4fc8f02f7..1042d4376 100755 --- a/nebula/core/eventmanager.py +++ b/nebula/core/eventmanager.py @@ -3,7 +3,7 @@ import logging from collections import defaultdict from functools import wraps - +from nebula.core.network.messages import MessageEvent def event_handler(message_type, action): """Decorator for registering an event handler.""" @@ -33,6 +33,28 @@ class EventManager: def __init__(self, default_callbacks=None): self._event_callbacks = defaultdict(list) self._register_default_callbacks(default_callbacks or []) + self._subscribers: dict[tuple[str,str], list] = {} + + def subscribe(self, event_type: tuple[str,str], callback: callable): + """Register a callback for a specific event type.""" + if event_type not in self._subscribers: + self._subscribers[event_type] = [] + 
self._subscribers[event_type].append(callback) + logging.info(f"EventManager | Subscribed callback for event: {event_type}") + + async def publish(self, message_event: MessageEvent): + """Trigger all callbacks registered for a specific event type.""" + event_type = message_event.message_type + if event_type not in self._subscribers: + logging.error(f"EventManager | No subscribers for event: {event_type}") + return + + for callback in self._subscribers[event_type]: + try: + logging.info(f"EventManager | Triggering callback for event: {event_type}, from source: {message_event.source}") + await callback(message_event.source, message_event.message) + except Exception as e: + logging.error(f"EventManager | Error in callback for event {event_type}: {e}") def _register_default_callbacks(self, default_callbacks): """Registers default callbacks for events.""" diff --git a/nebula/core/network/actions.py b/nebula/core/network/actions.py index 743a7eb8c..c51041f25 100644 --- a/nebula/core/network/actions.py +++ b/nebula/core/network/actions.py @@ -3,6 +3,26 @@ import logging + + +def get_actions_names(message_type: str): + options = { + "connection": ConnectionAction, + "federation": FederationAction, + "discovery": DiscoveryAction, + "control": ControlAction, + "discover": DiscoverAction, + "offer": OfferAction, + "link": LinkAction, + } + + message_actions = options.get(message_type) + if not message_actions: + raise ValueError(f"Invalid message type: {message_type}") + + return [action.name.lower() for action in message_actions] + + def factory_message_action(message_type: str, action: str): options = { "connection": ConnectionAction, diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 3ec0d325d..2e2b5d70b 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -197,14 +197,16 @@ async def forward_message(self, data, addr_from): await self.forwarder.forward(data, addr_from=addr_from) # 
generic point to handle messages - async def handle_message(self, source, msg_type, message): - logging.info( - f"πŸ” handle_{msg_type} | Received [Action {message.action}] from {source}" - ) - try: - await self.engine.event_manager.trigger_event(source, message) - except Exception as e: - logging.exception(f"πŸ” handle_{msg_type} | Error while processing: {e}") + #async def handle_message(self, source, msg_type, message): + async def handle_message(self, message_event): + #logging.info( + # f"πŸ” handle_{msg_type} | Received [Action {message.action}] from {source}" + #) + #try: + #await self.engine.event_manager.trigger_event(source, message) + await self.engine.trigger_event(message_event) + #except Exception as e: + # logging.exception(f"πŸ” handle_{msg_type} | Error while processing: {e}") async def handle_discovery_message(self, source, message): @@ -401,6 +403,9 @@ async def handle_link_message(self, source, message): def create_message(self, message_type: str, action: str = "", *args, **kwargs): return self.mm.create_message(message_type, action, *args, **kwargs) + + def get_messages_events(self): + return self.mm.get_messages_events() def start_external_connection_service(self): if self.ecs == None: diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index b4b1f5d61..63322ab3f 100755 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING from nebula.core.pb import nebula_pb2 -from nebula.core.network.actions import factory_message_action +from nebula.core.network.actions import factory_message_action, get_actions_names import hashlib import traceback @@ -11,6 +11,7 @@ class MessagesManager: + def __init__(self, addr, config, cm: "CommunicationsManager"): self.addr = addr self.config = config @@ -74,7 +75,13 @@ def _define_message_templates(self): }, # Add additional message types here } - + + def get_messages_events(self): + message_events = {} + for 
message_name in self._message_templates.keys(): + message_events[message_name] = get_actions_names(message_name) + return message_events + async def process_message(self, data, addr_from): not_processing_messages = {"control_message", "connection_message"} special_processing_messages = {"discovery_message", "federation_message", "model_message"} @@ -98,7 +105,8 @@ async def process_message(self, data, addr_from): # Not required processing messages if message_type in not_processing_messages: - await self.cm.handle_message(source, message_type, message_data) + #await self.cm.handle_message(source, message_type, message_data) + await self.cm.handle_message(MessageEvent(message_type, message_data.action), source, message_data) # Message-specific forwarding and processing elif message_type in special_processing_messages: @@ -110,12 +118,14 @@ async def process_message(self, data, addr_from): if message_type == "model_message": await self.cm.handle_model_message(source, message_data) else: - await self.cm.handle_message(source, message_type, message_data) + #await self.cm.handle_message(source, message_type, message_data) + await self.cm.handle_message(MessageEvent(message_type, message_data.action), source, message_data) # Rest of messages else: if await self.cm.include_received_message_hash(hashlib.md5(data).hexdigest()): - await self.cm.handle_message(source, message_type, message_data) + #await self.cm.handle_message(source, message_type, message_data) + await self.cm.handle_message(MessageEvent(message_type, message_data.action), source, message_data) except Exception as e: logging.exception(f"πŸ“₯ handle_incoming_message | Error while processing: {e}") logging.exception(traceback.format_exc()) @@ -182,3 +192,9 @@ def create_message(self, message_type: str, action: str = "", *args, **kwargs): getattr(message_wrapper, field_name).CopyFrom(message) data = message_wrapper.SerializeToString() return data + +class MessageEvent: + def __init__(self, message_type, source, 
message): + self.source = source + self.message_type = message_type + self.message = message \ No newline at end of file From ff2a3d10e4291af748a79ba7eec3e14096e5482d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Avil=C3=A9s=20Serrano?= <80918548+AlejandroAvilesSerrano@users.noreply.github.com> Date: Wed, 29 Jan 2025 16:01:24 +0100 Subject: [PATCH 067/233] fix_error_msg --- nebula/core/network/messages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index 63322ab3f..47ad9ad1f 100755 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -119,7 +119,7 @@ async def process_message(self, data, addr_from): await self.cm.handle_model_message(source, message_data) else: #await self.cm.handle_message(source, message_type, message_data) - await self.cm.handle_message(MessageEvent(message_type, message_data.action), source, message_data) + await self.cm.handle_message(MessageEvent((message_type, message_data.action), source, message_data)) # Rest of messages else: From 7f9d4be03bc60da895443ab88013d4b6ec4e09ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Avil=C3=A9s=20Serrano?= <80918548+AlejandroAvilesSerrano@users.noreply.github.com> Date: Wed, 29 Jan 2025 16:11:40 +0100 Subject: [PATCH 068/233] fix_event_error --- nebula/core/engine.py | 2 ++ nebula/core/network/messages.py | 7 ++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 52e45edb3..5e33560b5 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -144,6 +144,8 @@ def __init__( self._cm = CommunicationsManager(engine=self) # Set the communication manager in the model (send messages from there) self.trainer.model.set_communication_manager(self._cm) + logging.info("Registering callbacks for MessageEvents...") + self.register_message_events_callbacks() self._reporter = Reporter(config=self.config, 
trainer=self.trainer, cm=self.cm) diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index 47ad9ad1f..d591aedbb 100755 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -96,6 +96,7 @@ async def process_message(self, data, addr_from): # Extract the active message from the oneof field message_type = message_wrapper.WhichOneof("message") + msg_name = message_type.split('_')[0] if not message_type: logging.warning("Received message with no active field in the 'oneof'") return @@ -106,7 +107,7 @@ async def process_message(self, data, addr_from): # Not required processing messages if message_type in not_processing_messages: #await self.cm.handle_message(source, message_type, message_data) - await self.cm.handle_message(MessageEvent(message_type, message_data.action), source, message_data) + await self.cm.handle_message(MessageEvent(msg_name, message_data.action), source, message_data) # Message-specific forwarding and processing elif message_type in special_processing_messages: @@ -119,13 +120,13 @@ async def process_message(self, data, addr_from): await self.cm.handle_model_message(source, message_data) else: #await self.cm.handle_message(source, message_type, message_data) - await self.cm.handle_message(MessageEvent((message_type, message_data.action), source, message_data)) + await self.cm.handle_message(MessageEvent((msg_name, message_data.action), source, message_data)) # Rest of messages else: if await self.cm.include_received_message_hash(hashlib.md5(data).hexdigest()): #await self.cm.handle_message(source, message_type, message_data) - await self.cm.handle_message(MessageEvent(message_type, message_data.action), source, message_data) + await self.cm.handle_message(MessageEvent(msg_name, message_data.action), source, message_data) except Exception as e: logging.exception(f"πŸ“₯ handle_incoming_message | Error while processing: {e}") logging.exception(traceback.format_exc()) From 
18e91834821f3b2ce354cfddde79d75380767d28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Avil=C3=A9s=20Serrano?= <80918548+AlejandroAvilesSerrano@users.noreply.github.com> Date: Wed, 29 Jan 2025 17:43:55 +0100 Subject: [PATCH 069/233] fix_errors --- nebula/core/engine.py | 42 +++++++++++++++++---------------- nebula/core/network/actions.py | 23 ++++++++++++++++++ nebula/core/network/messages.py | 15 +++++++----- 3 files changed, 54 insertions(+), 26 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 5e33560b5..b41580f14 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -144,9 +144,7 @@ def __init__( self._cm = CommunicationsManager(engine=self) # Set the communication manager in the model (send messages from there) self.trainer.model.set_communication_manager(self._cm) - logging.info("Registering callbacks for MessageEvents...") - self.register_message_events_callbacks() - + self._reporter = Reporter(config=self.config, trainer=self.trainer, cm=self.cm) self._sinchronized_status = True @@ -171,9 +169,9 @@ def __init__( self._control_alive_callback, self._connection_connect_callback, self._connection_disconnect_callback, - self._federation_ready_callback, - self._start_federation_callback, - self._federation_models_included_callback, + #self._federation_ready_callback, + #self._start_federation_callback, + #self._federation_models_included_callback, self._discover_discover_join_callback, self._discover_discover_nodes_callback, self._connection_late_connect_callback, @@ -185,14 +183,17 @@ def __init__( ] ) + logging.info("Registering callbacks for MessageEvents...") + self.register_message_events_callbacks() + # Register additional callbacks - self._event_manager.register_event( - ( - nebula_pb2.FederationMessage, - nebula_pb2.FederationMessage.Action.REPUTATION, - ), - self._reputation_callback, - ) + #self._event_manager.register_event( + # ( + # nebula_pb2.FederationMessage, + # 
nebula_pb2.FederationMessage.Action.REPUTATION, + # ), + # self._reputation_callback, + #) # ... add more callbacks here @@ -338,7 +339,7 @@ async def _connection_disconnect_callback(self, source, message): nebula_pb2.FederationMessage, nebula_pb2.FederationMessage.Action.FEDERATION_READY, ) - async def _federation_ready_callback(self, source, message): + async def _federation_federation_ready_callback(self, source, message): logging.info(f"πŸ“ handle_federation_message | Trigger | Received ready federation message from {source}") if self.config.participant["device_args"]["start"]: logging.info(f"πŸ“ handle_federation_message | Trigger | Adding ready connection {source}") @@ -348,12 +349,12 @@ async def _federation_ready_callback(self, source, message): nebula_pb2.FederationMessage, nebula_pb2.FederationMessage.Action.FEDERATION_START, ) - async def _start_federation_callback(self, source, message): + async def _federation_federation_start_callback(self, source, message): logging.info(f"πŸ“ handle_federation_message | Trigger | Received start federation message from {source}") await self.create_trainer_module() @event_handler(nebula_pb2.FederationMessage, nebula_pb2.FederationMessage.Action.REPUTATION) - async def _reputation_callback(self, source, message): + async def _federation_reputation_callback(self, source, message): malicious_nodes = message.arguments # List of malicious nodes if self.with_reputation: if len(malicious_nodes) > 0 and not self._is_malicious: @@ -369,7 +370,7 @@ async def _reputation_callback(self, source, message): nebula_pb2.FederationMessage, nebula_pb2.FederationMessage.Action.FEDERATION_MODELS_INCLUDED, ) - async def _federation_models_included_callback(self, source, message): + async def _federation_federation_models_included_callback(self, source, message): logging.info(f"πŸ“ handle_federation_message | Trigger | Received aggregation finished message from {source}") try: await self.cm.get_connections_lock().acquire_async() @@ 
-585,15 +586,16 @@ def register_message_events_callbacks(self): message_events = [(message_name, message_action) for (message_name, message_actions) in me_dict.items() for message_action in message_actions] logging.info(f"{message_events}") for event_type, action in message_events: - callback_name = f"_ {event_type}_{action}_callback" + callback_name = f"_{event_type}_{action}_callback" + logging.info(f"Searching callback named: {callback_name}") method = getattr(self, callback_name, None) if callable(method): - self._event_manager.subscribe((event_type, action), method) + self.event_manager.subscribe((event_type, action), method) async def trigger_event(self, message_event): logging.info(f"Publishing MessageEvent: {message_event.message_type}") - await self._event_manager.publish(message_event) + await self.event_manager.publish(message_event) async def _aditional_node_start(self): self.update_sinchronized_status(False) diff --git a/nebula/core/network/actions.py b/nebula/core/network/actions.py index c51041f25..77a4ebd82 100644 --- a/nebula/core/network/actions.py +++ b/nebula/core/network/actions.py @@ -3,6 +3,29 @@ import logging +def get_action_name_from_value(message_type: str, action_value: int) -> str: + # Diccionario que asocia cada tipo de mensaje con su Enum correspondiente + action_classes = { + "connection": ConnectionAction, + "federation": FederationAction, + "discovery": DiscoveryAction, + "control": ControlAction, + "discover": DiscoverAction, + "offer": OfferAction, + "link": LinkAction, + } + + # Obtener el Enum correspondiente al tipo de mensaje + enum_class = action_classes.get(message_type) + if not enum_class: + raise ValueError(f"Unknown message type: {message_type}") + + # Buscar el nombre de la acciΓ³n a partir del valor + for action in enum_class: + if action.value == action_value: + return action.name.lower() # Convertimos a lowercase para mantener el formato "late_connect" + + raise ValueError(f"Unknown action value {action_value} for 
message type {message_type}") def get_actions_names(message_type: str): diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index d591aedbb..181944f13 100755 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING from nebula.core.pb import nebula_pb2 -from nebula.core.network.actions import factory_message_action, get_actions_names +from nebula.core.network.actions import factory_message_action, get_actions_names, get_action_name_from_value import hashlib import traceback @@ -79,7 +79,8 @@ def _define_message_templates(self): def get_messages_events(self): message_events = {} for message_name in self._message_templates.keys(): - message_events[message_name] = get_actions_names(message_name) + if message_name != "model" and message_name != "reputation": + message_events[message_name] = get_actions_names(message_name) return message_events async def process_message(self, data, addr_from): @@ -107,7 +108,8 @@ async def process_message(self, data, addr_from): # Not required processing messages if message_type in not_processing_messages: #await self.cm.handle_message(source, message_type, message_data) - await self.cm.handle_message(MessageEvent(msg_name, message_data.action), source, message_data) + me = MessageEvent((msg_name,get_action_name_from_value(msg_name, message_data.action)), source, message_data) + await self.cm.handle_message(me) # Message-specific forwarding and processing elif message_type in special_processing_messages: @@ -120,13 +122,14 @@ async def process_message(self, data, addr_from): await self.cm.handle_model_message(source, message_data) else: #await self.cm.handle_message(source, message_type, message_data) - await self.cm.handle_message(MessageEvent((msg_name, message_data.action), source, message_data)) - + me = MessageEvent((msg_name,get_action_name_from_value(msg_name, message_data.action)), source, message_data) + await 
self.cm.handle_message(me) # Rest of messages else: if await self.cm.include_received_message_hash(hashlib.md5(data).hexdigest()): #await self.cm.handle_message(source, message_type, message_data) - await self.cm.handle_message(MessageEvent(msg_name, message_data.action), source, message_data) + me = MessageEvent((msg_name,get_action_name_from_value(msg_name, message_data.action)), source, message_data) + await self.cm.handle_message(me) except Exception as e: logging.exception(f"πŸ“₯ handle_incoming_message | Error while processing: {e}") logging.exception(traceback.format_exc()) From 6a1384a4bf7e7a48c4c7cf5d6d8779b34a97e68f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Avil=C3=A9s=20Serrano?= <80918548+AlejandroAvilesSerrano@users.noreply.github.com> Date: Thu, 30 Jan 2025 11:52:37 +0100 Subject: [PATCH 070/233] update_momemtum --- nebula/core/neighbormanagement/momentum.py | 3 ++- nebula/core/neighbormanagement/nodemanager.py | 8 ++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/nebula/core/neighbormanagement/momentum.py b/nebula/core/neighbormanagement/momentum.py index 8d4152821..2b94af257 100644 --- a/nebula/core/neighbormanagement/momentum.py +++ b/nebula/core/neighbormanagement/momentum.py @@ -26,8 +26,8 @@ def __init__( self, node_manager: "NodeManager", nodes, - global_priority=GLOBAL_PRIORITY, dispersion_penalty=True, + global_priority=GLOBAL_PRIORITY, similarity_metric : SimilarityMetricType = cosine_metric, mapping_similarity : MappingSimilarityType = lambda sim_value, e=EPSILON: e + ((sim_value + 1) / 2), ): @@ -155,6 +155,7 @@ def sigmoid(similarity, k=2.5): mapped_sim_value = self.msf(sim_value) # Mapped into [0, 1] interval smoothed_value = sigmoid(mapped_sim_value) adjusted_weight = smoothed_value * self._global_prio + (1 - self._global_prio) * mapped_sim_value + logging.info(f"Momemtum values | adjusted_weight: {adjusted_weight}, map_value: {mapped_sim_value}, smoothed_value: {smoothed_value}") if 
self._dispersion_penalty: self._calculate_dispersion_penalty(historic, updates) diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index dc635750e..8b321a01c 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -11,6 +11,7 @@ from nebula.core.pb import nebula_pb2 from nebula.core.network.communications import CommunicationsManager from nebula.core.neighbormanagement.fastreboot import FastReboot +from nebula.core.neighbormanagement.momentum import Momentum from nebula.addons.functions import print_msg_box from typing import TYPE_CHECKING @@ -28,7 +29,7 @@ def __init__( push_acceleration, engine : "Engine", fastreboot=True, - momentum=False, + momentum=True, ): self.topology = "fully"#topology print_msg_box(msg=f"Starting NodeManager module...\nTopology: {self.topology}", indent=2, title="NodeManager module") @@ -60,8 +61,9 @@ def __init__( if (fastreboot): self._fastreboot = FastReboot(self) + self._momemtum = None if (momentum): - pass + self._momemtum = Momentum(self, self.neighbor_policy.get_nodes_known(neighbors_only=True), dispersion_penalty=False) #self.set_confings() @@ -171,6 +173,8 @@ async def apply_weight_strategy(self, updates: dict): if not self.fast_reboot_on(): return await self.fr.apply_weight_strategy(updates) + if self._momemtum: + self._momemtum.calculate_momentum_weights(updates) From db3710b20f2a3f5d2c9b09cb9587cecbb0355a78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Avil=C3=A9s=20Serrano?= <80918548+AlejandroAvilesSerrano@users.noreply.github.com> Date: Thu, 30 Jan 2025 13:16:33 +0100 Subject: [PATCH 071/233] updt --- nebula/core/neighbormanagement/nodemanager.py | 2 +- nebula/core/network/communications.py | 48 ------------------- 2 files changed, 1 insertion(+), 49 deletions(-) diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 8b321a01c..90476f5ec 100644 
--- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -174,7 +174,7 @@ async def apply_weight_strategy(self, updates: dict): return await self.fr.apply_weight_strategy(updates) if self._momemtum: - self._momemtum.calculate_momentum_weights(updates) + await self._momemtum.calculate_momentum_weights(updates) diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 2e2b5d70b..b629bdf8b 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -144,54 +144,6 @@ async def add_ready_connection(self, addr): async def handle_incoming_message(self, data, addr_from): await self.mm.process_message(data, addr_from) - return - try: - message_wrapper = nebula_pb2.Wrapper() - message_wrapper.ParseFromString(data) - source = message_wrapper.source - logging.debug(f"πŸ“₯ handle_incoming_message | Received message from {addr_from} with source {source}") - if source == self.addr: - return - if message_wrapper.HasField("discovery_message"): - if await self.include_received_message_hash(hashlib.md5(data).hexdigest()): - await self.forwarder.forward(data, addr_from=addr_from) - await self.handle_discovery_message(source, message_wrapper.discovery_message) - elif message_wrapper.HasField("control_message"): - await self.handle_control_message(source, message_wrapper.control_message) - elif message_wrapper.HasField("federation_message"): - if await self.include_received_message_hash(hashlib.md5(data).hexdigest()): - if self.config.participant["device_args"][ - "proxy" - ] or message_wrapper.federation_message.action == nebula_pb2.FederationMessage.Action.Value( - "FEDERATION_START" - ): - await self.forwarder.forward(data, addr_from=addr_from) - await self.handle_federation_message(source, message_wrapper.federation_message) - elif message_wrapper.HasField("model_message"): - if await self.include_received_message_hash(hashlib.md5(data).hexdigest()): - # 
TODO: Improve the technique. Now only forward model messages if the node is a proxy - # Need to update the expected model messages receiving during the round - # Round -1 is the initialization round --> all nodes should receive the model - if self.config.participant["device_args"]["proxy"] or message_wrapper.model_message.round == -1: - await self.forwarder.forward(data, addr_from=addr_from) - await self.handle_model_message(source, message_wrapper.model_message) - elif message_wrapper.HasField("connection_message"): - await self.handle_connection_message(source, message_wrapper.connection_message) - elif message_wrapper.HasField("discover_message"): - if self.include_received_message_hash(hashlib.md5(data).hexdigest()): - await self.handle_discover_message(source, message_wrapper.discover_message) - elif message_wrapper.HasField("offer_message"): - if self.include_received_message_hash(hashlib.md5(data).hexdigest()): - await self.handle_offer_message(source, message_wrapper.offer_message) - elif message_wrapper.HasField("link_message"): - if self.include_received_message_hash(hashlib.md5(data).hexdigest()): - await self.handle_offer_message(source, message_wrapper.link_message) - else: - logging.info(f"Unknown handler for message: {message_wrapper}") - except Exception as e: - logging.exception(f"πŸ“₯ handle_incoming_message | Error while processing: {e}") - logging.exception(traceback.format_exc()) - async def forward_message(self, data, addr_from): await self.forwarder.forward(data, addr_from=addr_from) From 3b3d52386dcedd41f6e6cca488a4a259a0488eba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Avil=C3=A9s=20Serrano?= <80918548+AlejandroAvilesSerrano@users.noreply.github.com> Date: Mon, 3 Feb 2025 10:08:05 +0100 Subject: [PATCH 072/233] fix_momemtum_config --- nebula/core/neighbormanagement/momentum.py | 1 + nebula/core/neighbormanagement/nodemanager.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git 
a/nebula/core/neighbormanagement/momentum.py b/nebula/core/neighbormanagement/momentum.py index 2b94af257..0f55f70bd 100644 --- a/nebula/core/neighbormanagement/momentum.py +++ b/nebula/core/neighbormanagement/momentum.py @@ -107,6 +107,7 @@ async def _calculate_similarities(self, updates: dict): update, similarity=True, ) + logging.info(f"Model similarity for node: {addr}, sim: {sim_value}") await self._add_similarity_to_node(addr, sim_value) def _calculate_dispersion_penalty(self, historic: dict, updates: dict): diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 90476f5ec..bb9910de0 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -58,13 +58,8 @@ def __init__( self.synchronizing_rounds = False self._fast_reboot_status = fastreboot - if (fastreboot): - self._fastreboot = FastReboot(self) - - self._momemtum = None - if (momentum): - self._momemtum = Momentum(self, self.neighbor_policy.get_nodes_known(neighbors_only=True), dispersion_penalty=False) - + self._momemtum_status = momentum + #self.set_confings() @property @@ -152,6 +147,13 @@ async def set_confings(self): ) #self.engine.trainer.get_loss(), self.config.participant["molibity_args"]["weight_distance"], self.config.participant["molibity_args"]["weight_het"] #self.model_handler.set_config([self.engine.get_round(), self.engine.config.participant["training_args"]["epochs"]]) + + if (self._fast_reboot_status): + self._fastreboot = FastReboot(self) + + self._momemtum = None + if (self._momemtum_status): + self._momemtum = Momentum(self, self.neighbor_policy.get_nodes_known(neighbors_only=True), dispersion_penalty=False) ############################## From 56b8997a20bf80e8f4307d56fa25c2bab5c69b60 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 6 Feb 2025 14:11:29 +0100 Subject: [PATCH 073/233] optimization_sinc --- nebula/core/aggregation/aggregator.py | 225 +++++++----- 
nebula/core/engine.py | 344 ++++++++---------- nebula/core/neighbormanagement/fastreboot.py | 103 +++--- .../modelhandlers/stdmodelhandler.py | 20 +- nebula/core/neighbormanagement/momentum.py | 144 +++++--- nebula/core/neighbormanagement/nodemanager.py | 340 ++++++++--------- nebula/core/network/actions.py | 121 +++--- nebula/core/network/communications.py | 134 ++----- nebula/core/network/messages.py | 122 +++---- nebula/node.py | 37 +- 10 files changed, 763 insertions(+), 827 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 1f027fa90..28842a478 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -2,37 +2,18 @@ import logging from abc import ABC, abstractmethod from functools import partial +from typing import TYPE_CHECKING -from nebula.core.pb import nebula_pb2 from nebula.core.utils.locker import Locker +if TYPE_CHECKING: + from nebula.core.engine import Engine + class AggregatorException(Exception): pass -def create_aggregator(config, engine): - from nebula.core.aggregation.blockchainReputation import BlockchainReputation - from nebula.core.aggregation.fedavg import FedAvg - from nebula.core.aggregation.krum import Krum - from nebula.core.aggregation.median import Median - from nebula.core.aggregation.trimmedmean import TrimmedMean - - ALGORITHM_MAP = { - "FedAvg": FedAvg, - "Krum": Krum, - "Median": Median, - "TrimmedMean": TrimmedMean, - "BlockchainReputation": BlockchainReputation, - } - algorithm = config.participant["aggregator_args"]["algorithm"] - aggregator = ALGORITHM_MAP.get(algorithm) - if aggregator: - return aggregator(config=config, engine=engine) - else: - raise AggregatorException(f"Aggregation algorithm {algorithm} not found.") - - def create_target_aggregator(config, engine): from nebula.core.aggregation.fedavg import FedAvg from nebula.core.aggregation.krum import Krum @@ -56,7 +37,7 @@ def create_target_aggregator(config, engine): 
class Aggregator(ABC): def __init__(self, config=None, engine=None): self.config = config - self.engine = engine + self.engine: Engine = engine self._addr = config.participant["network_args"]["addr"] logging.info(f"[{self.__class__.__name__}] Starting Aggregator") self._federation_nodes = set() @@ -66,6 +47,8 @@ def __init__(self, config=None, engine=None): self._add_model_lock = Locker(name="add_model_lock", async_lock=True) self._add_next_model_lock = Locker(name="add_next_model_lock", async_lock=True) self._aggregation_done_lock = Locker(name="aggregation_done_lock", async_lock=True) + self._aggregation_waiting_skip = asyncio.Event() + self._push_strategy_lock = Locker(name="push_strategy_lock", async_lock=True) def __str__(self): return self.__class__.__name__ @@ -164,12 +147,13 @@ async def _add_pending_model(self, model, weight, source): if future_round < self.engine.get_round(): del self._future_models_to_aggregate[future_round] + # TODO comprobar que los q faltan no estan en futuros + if len(self.get_nodes_pending_models_to_aggregate()) >= len(self._federation_nodes): logging.info("πŸ”„ _add_pending_model | All models were added in the aggregation buffer. 
Run aggregation...") - self.engine.update_sinchronized_status(True) + # self.engine.update_sinchronized_status(True) await self._aggregation_done_lock.release_async() - #else: - # await self.aggregation_push_available() + await self._add_model_lock.release_async() return self.get_nodes_pending_models_to_aggregate() @@ -199,11 +183,13 @@ async def include_model_in_buffer(self, model, weight, source=None, round=None, logging.info( f"πŸ”„ include_model_in_buffer | Broadcasting MODELS_INCLUDED for round {self.engine.get_round()}" ) - #message = self.cm.mm.generate_federation_message( + # message = self.cm.mm.generate_federation_message( # nebula_pb2.FederationMessage.Action.FEDERATION_MODELS_INCLUDED, # [self.engine.get_round()], - #) - message = self.cm.create_message("federation", "federation_models_included", [str(arg) for arg in [self.engine.get_round()]]) + # ) + message = self.cm.create_message( + "federation", "federation_models_included", [str(arg) for arg in [self.engine.get_round()]] + ) await self.cm.send_message_to_neighbors(message) return @@ -211,7 +197,24 @@ async def include_model_in_buffer(self, model, weight, source=None, round=None, async def get_aggregation(self): try: timeout = self.config.participant["aggregator_args"]["aggregation_timeout"] - await self._aggregation_done_lock.acquire_async(timeout=timeout) + logging.info(f"Aggregation timeout: {timeout} starts...") + lock_task = asyncio.create_task(self._aggregation_done_lock.acquire_async(timeout=timeout)) + skip_task = asyncio.create_task(self._aggregation_waiting_skip.wait()) + done, pending = await asyncio.wait( + [lock_task, skip_task], + return_when=asyncio.FIRST_COMPLETED, + ) + lock_acquired = lock_task in done + if skip_task in done: + logging.info("Skipping aggregation wait due to detected desynchronization") + self._aggregation_waiting_skip.clear() + if not lock_acquired: + lock_task.cancel() + try: + await lock_task # Clean cancel + except asyncio.CancelledError: + pass + except 
TimeoutError: logging.exception("πŸ”„ get_aggregation | Timeout reached for aggregation") except asyncio.CancelledError: @@ -219,7 +222,8 @@ async def get_aggregation(self): except Exception as e: logging.exception(f"πŸ”„ get_aggregation | Error acquiring lock: {e}") finally: - await self._aggregation_done_lock.release_async() + if lock_acquired: + await self._aggregation_done_lock.release_async() if self._waiting_global_update and len(self._pending_models_to_aggregate) == 1: logging.info( @@ -239,9 +243,6 @@ async def get_aggregation(self): self._pending_models_to_aggregate = await self.engine.apply_weight_strategy(self._pending_models_to_aggregate) aggregated_result = self.run_aggregation(self._pending_models_to_aggregate) - if not self.engine.get_sinchronized_status() and self.engine.get_push_acceleration() == "fast": - await self._add_model_lock.release_async() - await self._add_next_model_lock.release_async() self._pending_models_to_aggregate.clear() return aggregated_result @@ -253,7 +254,8 @@ async def include_next_model_in_buffer(self, model, weight, source=None, round=N await self._add_next_model_lock.acquire_async() self._future_models_to_aggregate[round].append((decoded_model, weight, source)) await self._add_next_model_lock.release_async() - #await self.aggregation_push_available() + # await self.aggregation_push_available() + # asyncio.create_task(self.aggregation_push_available()) def print_model_size(self, model): total_params = 0 @@ -271,104 +273,115 @@ def print_model_size(self, model): async def aggregation_push_available(self): """ - If the node is not sinchronized with the federation, it may be possible to make a push - and try to catch the federation asap. + If the node is not sinchronized with the federation, it may be possible to make a push + and try to catch the federation asap. 
""" - #TODO it would be able to push even if not fullround updates are being received - logging.info(f"❗️ synchronized status: {self.engine.get_sinchronized_status()} | Analizing if an aggregation push is available...") - if not self.engine.get_sinchronized_status() and not self.engine.get_trainning_in_progress_lock().locked() and not self.engine.get_synchronizing_rounds(): - n_fed_nodes = len(self._federation_nodes) - further_round = self.engine.get_round() - logging.info(f" Pending models: {len(self.get_nodes_pending_models_to_aggregate())} | federation: {n_fed_nodes}") + + # TODO verify if an already sinchronized node gets desinchronized + await self._push_strategy_lock.acquire_async() + + logging.info( + f"❗️ synchronized status: {self.engine.get_sinchronized_status()} | Analizing if an aggregation push is available..." + ) + if ( + not self.engine.get_sinchronized_status() + and not self.engine.get_trainning_in_progress_lock().locked() + and not self.engine.get_synchronizing_rounds() + ): + n_fed_nodes = len(self._federation_nodes) + current_round = self.engine.get_round() + further_round = current_round + logging.info( + f" Pending models: {len(self.get_nodes_pending_models_to_aggregate())} | federation: {n_fed_nodes}" + ) if len(self.get_nodes_pending_models_to_aggregate()) < n_fed_nodes: - n_fed_nodes-=1 + n_fed_nodes -= 1 for f_round, fm in self._future_models_to_aggregate.items(): - # future_models dont count self node - if len(fm) == n_fed_nodes: - further_round = f_round + # future_models dont count self node + if (f_round - current_round) > 1 or len(fm) == n_fed_nodes: + further_round = f_round push = self.engine.get_push_acceleration() if push == "slow": - logging.info(f"❗️ FUTURE round: {further_round} is available | PUSH strategy ON") - logging.info("❗️ SLOW push selected | Start PUSHING slow") - await self.engine.set_pushed_done(further_round - self.engine.get_round()) - # we wait until learning cycle reach aggregation point - while not 
self._aggregation_done_lock.locked_async(): - logging.info("πŸ”„ Waiting | aggregation step not reached yet...") - await asyncio.sleep(1) - # Unlock aggregation - logging.info("πŸ”„ Releasing aggregation lock...") + logging.info("❗️ SLOW push selected") + logging.info( + f"❗️ Federation is at least {(f_round - current_round)} rounds ahead, Pushing slow..." + ) + await self.engine.set_pushed_done(further_round - current_round) self.engine.update_sinchronized_status(False) self.engine.set_synchronizing_rounds(True) - await self._aggregation_done_lock.release_async() + self._aggregation_waiting_skip.set() + await self._push_strategy_lock.release_async() return - - if further_round != self.engine.get_round() and push == "fast": - logging.info(f"❗️ FUTURE round: {further_round} is available | PUSH strategy ON") - logging.info("❗️ FAST push selected | Start PUSHING fast") - - if further_round == (self.engine.get_round()+1): + + if further_round != current_round and push == "fast": + logging.info("❗️ FAST push selected") + logging.info(f"❗️ FUTURE round: {further_round} is available, Pushing fast...") + + if further_round == (current_round + 1): logging.info(f"πŸ”„ Rounds jumped: {1}...") - await self.engine.set_pushed_done(further_round - self.engine.get_round()) - # we wait until learning cycle reach aggregation point - while not self._aggregation_done_lock.locked_async(): - logging.info("πŸ”„ Waiting | aggregation step not reached yet...") - await asyncio.sleep(1) - # Unlock aggregation - logging.info("πŸ”„ Releasing aggregation lock...") + await self.engine.set_pushed_done(further_round - current_round) self.engine.update_sinchronized_status(False) self.engine.set_synchronizing_rounds(True) - await self._aggregation_done_lock.release_async() + self._aggregation_waiting_skip.set() + await self._push_strategy_lock.release_async() return - - logging.info(f"πŸ”„ Rounds jumped: {self.engine.get_round() - further_round}...") + + logging.info(f"πŸ”„ Number of rounds jumped: 
{further_round - current_round}...") own_update = self._pending_models_to_aggregate.get(self.engine.get_addr()) while own_update == None: own_update = self._pending_models_to_aggregate.get(self.engine.get_addr()) - asyncio.sleep(1) + asyncio.sleep(1) (model, weight) = own_update - + # Getting locks to avoid concurrency issues await self._add_model_lock.acquire_async() await self._add_next_model_lock.acquire_async() - + # Remove all pendings updates and add own_update self._pending_models_to_aggregate.clear() self._pending_models_to_aggregate.update({self.engine.get_addr(): (model, weight)}) - + # Add to pendings the future round updates for future_update in self._future_models_to_aggregate[further_round]: (decoded_model, weight, source) = future_update self._pending_models_to_aggregate.update({source: (decoded_model, weight)}) - + # Clear all rounds that are going to be jumped - for key in self._future_models_to_aggregate.keys(): - if key <= further_round: - del self._future_models_to_aggregate[key] - + self._future_models_to_aggregate = { + key: value for key, value in self._future_models_to_aggregate.items() if key > further_round + } + self.engine.update_sinchronized_status(False) self.engine.set_synchronizing_rounds(True) - await self.engine.set_pushed_done(further_round - self.engine.get_round()) + await self.engine.set_pushed_done(further_round - current_round) self.engine.set_round(further_round) - - # Unlock aggregation - # we wait until learning cycle reach aggregation point - while not self._aggregation_done_lock.locked_async(): - await asyncio.sleep(1) - await self._aggregation_done_lock.release_async() + await self._add_model_lock.release_async() + await self._add_next_model_lock.release_async() + await self._push_strategy_lock.release_async() + self._aggregation_waiting_skip.set() return - + else: - logging.info("Info | No future rounds available, device is up to date...") - self.engine.update_sinchronized_status(True) - 
self.engine.set_synchronizing_rounds(False) + if len(self._future_models_to_aggregate.items()) < 2: + logging.info("Info | No future rounds available, device is up to date...") + self.engine.update_sinchronized_status(True) + self.engine.set_synchronizing_rounds(False) + else: + pass + await self._push_strategy_lock.release_async() else: - logging.info(f"All models updates are received | models number: {len(self.get_nodes_pending_models_to_aggregate())}") + logging.info( + f"All models updates are received | models number: {len(self.get_nodes_pending_models_to_aggregate())}" + ) + await self._push_strategy_lock.release_async() else: if not self.engine.get_sinchronized_status(): if self.engine.get_sinchronized_status(): logging.info("❗️ Cannot analize push | Trainning in progress") elif self.engine.get_synchronizing_rounds(): logging.info("❗️ Cannot analize push | already pushing rounds") + await self._push_strategy_lock.release_async() + def create_malicious_aggregator(aggregator, attack): # It creates a partial function aggregate that wraps the aggregate method of the original aggregator. 
@@ -385,3 +398,25 @@ def malicious_aggregate(self, models): aggregator.run_aggregation = partial(malicious_aggregate, aggregator) return aggregator + + +def create_aggregator(config, engine) -> Aggregator: + from nebula.core.aggregation.blockchainReputation import BlockchainReputation + from nebula.core.aggregation.fedavg import FedAvg + from nebula.core.aggregation.krum import Krum + from nebula.core.aggregation.median import Median + from nebula.core.aggregation.trimmedmean import TrimmedMean + + ALGORITHM_MAP = { + "FedAvg": FedAvg, + "Krum": Krum, + "Median": Median, + "TrimmedMean": TrimmedMean, + "BlockchainReputation": BlockchainReputation, + } + algorithm = config.participant["aggregator_args"]["algorithm"] + aggregator = ALGORITHM_MAP.get(algorithm) + if aggregator: + return aggregator(config=config, engine=engine) + else: + raise AggregatorException(f"Aggregation algorithm {algorithm} not found.") diff --git a/nebula/core/engine.py b/nebula/core/engine.py index b41580f14..d124cc9fb 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -8,11 +8,10 @@ from nebula.addons.functions import print_msg_box from nebula.addons.reporter import Reporter from nebula.core.aggregation.aggregator import create_aggregator, create_malicious_aggregator, create_target_aggregator -from nebula.core.eventmanager import EventManager, event_handler +from nebula.core.eventmanager import EventManager +from nebula.core.neighbormanagement.nodemanager import NodeManager from nebula.core.network.communications import CommunicationsManager -from nebula.core.pb import nebula_pb2 from nebula.core.utils.locker import Locker -from nebula.core.neighbormanagement.nodemanager import NodeManager logging.getLogger("requests").setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING) @@ -144,59 +143,35 @@ def __init__( self._cm = CommunicationsManager(engine=self) # Set the communication manager in the model (send messages from there) 
self.trainer.model.set_communication_manager(self._cm) - + self._reporter = Reporter(config=self.config, trainer=self.trainer, cm=self.cm) - + self._sinchronized_status = True self.sinchronized_status_lock = Locker(name="sinchronized_status_lock") - - self.trainning_in_progress_lock = Locker(name="trainning_in_progress_lock", async_lock=True) - + + self.trainning_in_progress_lock = Locker(name="trainning_in_progress_lock", async_lock=True) + # Mobility setup self._node_manager = None self.mobility = self.config.participant["mobility_args"]["mobility"] if self.mobility == True: topology = self.config.participant["mobility_args"]["topology_type"] topology = topology.lower() - model_handler = "std" #self.config.participant["mobility_args"]["model_handler"] - acceleration_push = "slow" #self.config.participant["mobility_args"]["push_strategy"] - self._node_manager = NodeManager(topology, model_handler, acceleration_push, engine=self) - - - self._event_manager = EventManager( - default_callbacks=[ - self._discovery_discover_callback, - self._control_alive_callback, - self._connection_connect_callback, - self._connection_disconnect_callback, - #self._federation_ready_callback, - #self._start_federation_callback, - #self._federation_models_included_callback, - self._discover_discover_join_callback, - self._discover_discover_nodes_callback, - self._connection_late_connect_callback, - self._connection_restructure_callback, - self._offer_offer_model_callback, - self._offer_offer_metric_callback, - self._link_connect_to_callback, - self._link_disconnect_from_callback, - ] - ) + model_handler = "std" # self.config.participant["mobility_args"]["model_handler"] + acceleration_push = "slow" # self.config.participant["mobility_args"]["push_strategy"] + self._node_manager = NodeManager( + config.participant["mobility_args"]["additional_node"]["status"], + topology, + model_handler, + acceleration_push, + engine=self, + ) + + self._event_manager = EventManager() 
logging.info("Registering callbacks for MessageEvents...") self.register_message_events_callbacks() - # Register additional callbacks - #self._event_manager.register_event( - # ( - # nebula_pb2.FederationMessage, - # nebula_pb2.FederationMessage.Action.REPUTATION, - # ), - # self._reputation_callback, - #) - - # ... add more callbacks here - @property def cm(self): return self._cm @@ -219,7 +194,7 @@ def get_aggregator_type(self): @property def trainer(self): return self._trainer - + @property def nm(self): return self._node_manager @@ -247,41 +222,39 @@ def get_federation_ready_lock(self): def get_federation_setup_lock(self): return self.federation_setup_lock - + def get_trainning_in_progress_lock(self): return self.trainning_in_progress_lock def get_round_lock(self): return self.round_lock - + def get_sinchronized_status(self): with self.sinchronized_status_lock: return self._sinchronized_status - + def get_synchronizing_rounds(self): return self.nm.get_syncrhonizing_rounds() - + def update_sinchronized_status(self, status): with self.sinchronized_status_lock: logging.info(f"Update | synchronized status from: {self._sinchronized_status} to {status}") self._sinchronized_status = status - + def set_synchronizing_rounds(self, status): if self.mobility: self.nm.set_synchronizing_rounds(not status) - + def set_round(self, new_round): logging.info(f"πŸ€– Update round count | from: {self.round} | to round: {new_round}") self.round = new_round self.trainer.set_current_round(new_round) - """ ############################## # General callbacks # - ############################## + ############################## """ - @event_handler(nebula_pb2.DiscoveryMessage, nebula_pb2.DiscoveryMessage.Action.DISCOVER) async def _discovery_discover_callback(self, source, message): logging.info( f"πŸ” handle_discovery_message | Trigger | Received discovery message from {source} (network propagation)" @@ -305,7 +278,6 @@ async def _discovery_discover_callback(self, source, message): f"πŸ” 
Invalid geolocation received from {source}: latitude={message.latitude}, longitude={message.longitude}" ) - @event_handler(nebula_pb2.ControlMessage, nebula_pb2.ControlMessage.Action.ALIVE) async def _control_alive_callback(self, source, message): logging.info(f"πŸ”§ handle_control_message | Trigger | Received alive message from {source}") current_connections = await self.cm.get_addrs_current_connections(myself=True) @@ -317,7 +289,6 @@ async def _control_alive_callback(self, source, message): else: logging.error(f"❗️ Connection {source} not found in connections...") - @event_handler(nebula_pb2.ConnectionMessage, nebula_pb2.ConnectionMessage.Action.CONNECT) async def _connection_connect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received connection message from {source}") current_connections = await self.cm.get_addrs_current_connections(myself=True) @@ -325,35 +296,25 @@ async def _connection_connect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Connecting to {source}") await self.cm.connect(source, direct=True) - @event_handler(nebula_pb2.ConnectionMessage, nebula_pb2.ConnectionMessage.Action.DISCONNECT) async def _connection_disconnect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received disconnection message from {source}") if self.mobility: if await self.nm.waiting_confirmation_from(source): await self.nm.confirmation_received(source, confirmation=False) - #if source in await self.cm.get_all_addrs_current_connections(only_direct=True): + # if source in await self.cm.get_all_addrs_current_connections(only_direct=True): await self.nm.update_neighbors(source, remove=True) await self.cm.disconnect(source, mutual_disconnection=False) - - @event_handler( - nebula_pb2.FederationMessage, - nebula_pb2.FederationMessage.Action.FEDERATION_READY, - ) + async def _federation_federation_ready_callback(self, source, message): 
logging.info(f"πŸ“ handle_federation_message | Trigger | Received ready federation message from {source}") if self.config.participant["device_args"]["start"]: logging.info(f"πŸ“ handle_federation_message | Trigger | Adding ready connection {source}") await self.cm.add_ready_connection(source) - @event_handler( - nebula_pb2.FederationMessage, - nebula_pb2.FederationMessage.Action.FEDERATION_START, - ) async def _federation_federation_start_callback(self, source, message): logging.info(f"πŸ“ handle_federation_message | Trigger | Received start federation message from {source}") await self.create_trainer_module() - @event_handler(nebula_pb2.FederationMessage, nebula_pb2.FederationMessage.Action.REPUTATION) async def _federation_reputation_callback(self, source, message): malicious_nodes = message.arguments # List of malicious nodes if self.with_reputation: @@ -366,10 +327,6 @@ async def _federation_reputation_callback(self, source, message): malicious_nodes, ) - @event_handler( - nebula_pb2.FederationMessage, - nebula_pb2.FederationMessage.Action.FEDERATION_MODELS_INCLUDED, - ) async def _federation_federation_models_included_callback(self, source, message): logging.info(f"πŸ“ handle_federation_message | Trigger | Received aggregation finished message from {source}") try: @@ -390,54 +347,45 @@ async def _federation_federation_models_included_callback(self, source, message) finally: await self.cm.get_connections_lock().release_async() - """ ############################## # Mobility callbacks # - ############################## + ############################## """ - @event_handler( - nebula_pb2.ConnectionMessage, - nebula_pb2.ConnectionMessage.Action.LATE_CONNECT, - ) async def _connection_late_connect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received late connect message from {source}") # Verify if it's a confirmation message from a previous late connection message sent to source if await 
self.nm.waiting_confirmation_from(source): await self.nm.confirmation_received(source, confirmation=True) return - + if not self.get_initialization_status(): - logging.info(f"❗️ Connection refused | Device not initialized yet...") - return - + logging.info("❗️ Connection refused | Device not initialized yet...") + return + if self.nm.accept_connection(source, joining=True): - logging.info(f"πŸ”— handle_connection_message | Late connection accepted | source: {source}") + logging.info(f"πŸ”— handle_connection_message | Late connection accepted | source: {source}") await self.cm.connect(source, direct=True) - + # Verify conenction is accepted conf_msg = self.cm.create_message("connection", "late_connect") await self.cm.send_message(source, conf_msg) - - ct_actions , df_actions = self.nm.get_actions() - if len(ct_actions): - #for addr in ct_actions.split(): + await self.nm.register_late_neighbor(source, joinning_federation=True) + + ct_actions, df_actions = self.nm.get_actions() + if len(ct_actions): + # for addr in ct_actions.split(): cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) await self.cm.send_message(source, cnt_msg) - - if len(df_actions): - #for addr in df_actions.split(): + + if len(df_actions): + # for addr in df_actions.split(): df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) - await self.cm.send_message(source, df_msg) + await self.cm.send_message(source, df_msg) - await self.nm.register_late_neighbor(source, joinning_federation=True) else: - logging.info(f"❗️ Late connection NOT accepted | source: {source}") + logging.info(f"❗️ Late connection NOT accepted | source: {source}") - @event_handler( - nebula_pb2.ConnectionMessage, - nebula_pb2.ConnectionMessage.Action.RESTRUCTURE, - ) async def _connection_restructure_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received restructure message from {source}") # Verify if it's a confirmation message from a 
previous restructure connection message sent to source @@ -446,104 +394,107 @@ async def _connection_restructure_callback(self, source, message): return if not self.get_initialization_status(): - logging.info(f"❗️ Connection refused | Device not initialized yet...") - return + logging.info("❗️ Connection refused | Device not initialized yet...") + return if self.nm.accept_connection(source, joining=False): logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection accepted from {source}") await self.cm.connect(source, direct=True) - - #conf_msg = self.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.RESTRUCTURE) + + # conf_msg = self.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.RESTRUCTURE) conf_msg = self.cm.create_message("connection", "restructure") - + await self.cm.send_message(source, conf_msg) - - ct_actions , df_actions = self.nm.get_actions() - if len(ct_actions): - #cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, ct_actions) + + ct_actions, df_actions = self.nm.get_actions() + if len(ct_actions): + # cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, ct_actions) cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) await self.cm.send_message(source, cnt_msg) - - if len(df_actions): - #df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, df_actions) + + if len(df_actions): + # df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, df_actions) df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) await self.cm.send_message(source, df_msg) - - await self.nm.register_late_neighbor(source, joinning_federation=False) + + await self.nm.register_late_neighbor(source, joinning_federation=False) else: logging.info(f"❗️ handle_connection_message | Trigger | restructure connection denied from {source}") await 
asyncio.sleep(1) - #await self.cm.disconnect(source, mutual_disconnection=True) - - @event_handler(nebula_pb2.DiscoverMessage, nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN) + # await self.cm.disconnect(source, mutual_disconnection=True) + async def _discover_discover_join_callback(self, source, message): - logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_join message from {source} ") - #TODO caso para el starter recibir antes de iniciar federacion + logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_join message from {source} ") + # TODO caso para el starter recibir antes de iniciar federacion if len(self.get_federation_nodes()) > 0: await self.trainning_in_progress_lock.acquire_async() - model, rounds, round = await self.cm.propagator.get_model_information(source, "stable") if self.get_round() > 0 else await self.cm.propagator.get_model_information(source, "initialization") + model, rounds, round = ( + await self.cm.propagator.get_model_information(source, "stable") + if self.get_round() > 0 + else await self.cm.propagator.get_model_information(source, "initialization") + ) await self.trainning_in_progress_lock.release_async() if round != -1: epochs = self.config.participant["training_args"]["epochs"] - msg = self.cm.create_message("offer", - "offer_model", - len(self.get_federation_nodes()), - 0, - parameters=model, - rounds=rounds, - round=round, - epochs=epochs - ) + msg = self.cm.create_message( + "offer", + "offer_model", + len(self.get_federation_nodes()), + 0, + parameters=model, + rounds=rounds, + round=round, + epochs=epochs, + ) await self.cm.send_offer_model(source, msg) else: logging.info("Discover join received before federation is running..") # starter node is going to send info to the new node else: logging.info(f"πŸ”— Dissmissing discover join from {source} | no active connections at the moment") - - @event_handler( - nebula_pb2.DiscoverMessage, - nebula_pb2.DiscoverMessage.Action.DISCOVER_NODES, - 
) + async def _discover_discover_nodes_callback(self, source, message): logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_node message from {source} ") - #self.nm.meet_node(source) + # self.nm.meet_node(source) if len(self.get_federation_nodes()) > 0: - #msg = self.cm.mm.generate_offer_message(nebula_pb2.OfferMessage.Action.OFFER_METRIC, len(self.get_federation_nodes()), self.trainer.get_current_loss()) - msg = self.cm.create_message("offer", - "offer_metric", - n_neighbors=len(self.get_federation_nodes()), - loss=self.trainer.get_current_loss() - ) + # msg = self.cm.mm.generate_offer_message(nebula_pb2.OfferMessage.Action.OFFER_METRIC, len(self.get_federation_nodes()), self.trainer.get_current_loss()) + msg = self.cm.create_message( + "offer", + "offer_metric", + n_neighbors=len(self.get_federation_nodes()), + loss=self.trainer.get_current_loss(), + ) await self.cm.send_message(source, msg) else: - logging.info(f"πŸ”— Dissmissing discover nodes from {source} | no active connections at the moment") - - @event_handler( - nebula_pb2.OfferMessage, - nebula_pb2.OfferMessage.Action.OFFER_MODEL, - ) + logging.info(f"πŸ”— Dissmissing discover nodes from {source} | no active connections at the moment") + async def _offer_offer_model_callback(self, source, message): logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_model message from {source}") self.nm.meet_node(source) if self.nm.still_waiting_for_candidates(): try: model_compressed = message.parameters - if self.nm.accept_model_offer(source, model_compressed, message.rounds, message.round, message.epochs, message.n_neighbors, message.loss): + if self.nm.accept_model_offer( + source, + model_compressed, + message.rounds, + message.round, + message.epochs, + message.n_neighbors, + message.loss, + ): logging.info(f"πŸ”§ Model accepted from offer | source: {source}") else: logging.info(f"❗️ Model offer discarded | source: {source}") - self.nm.add_to_discarded_offers(source) + 
self.nm.add_to_discarded_offers(source) except RuntimeError: logging.info(f"❗️ Error proccesing offer model from {source}") else: - logging.info(f"❗️ handfle_offer_message | NOT accepting offers | restructure: {self.nm.get_restructure_process_lock().locked()} | waiting candidates: {self.nm.still_waiting_for_candidates()}") - self.nm.add_to_discarded_offers(source) - - @event_handler( - nebula_pb2.OfferMessage, - nebula_pb2.OfferMessage.Action.OFFER_METRIC, - ) + logging.info( + f"❗️ handfle_offer_message | NOT accepting offers | restructure: {self.nm.get_restructure_process_lock().locked()} | waiting candidates: {self.nm.still_waiting_for_candidates()}" + ) + self.nm.add_to_discarded_offers(source) + async def _offer_offer_metric_callback(self, source, message): logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_metric message from {source}") self.nm.meet_node(source) @@ -551,43 +502,38 @@ async def _offer_offer_metric_callback(self, source, message): n_neighbors = message.n_neighbors loss = message.loss self.nm.add_candidate(source, n_neighbors, loss) - - @event_handler( - nebula_pb2.LinkMessage, - nebula_pb2.LinkMessage.Action.CONNECT_TO, - ) + async def _link_connect_to_callback(self, source, message): logging.info(f"πŸ”— handle_link_message | Trigger | Received connect_to message from {source}") addrs = message.addrs for addr in addrs.split(): - #await self.cm.connect(addr, direct=True) - #self.nm.update_neighbors(addr) + # await self.cm.connect(addr, direct=True) + # self.nm.update_neighbors(addr) self.nm.meet_node(addr) - - @event_handler( - nebula_pb2.LinkMessage, - nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, - ) + async def _link_disconnect_from_callback(self, source, message): logging.info(f"πŸ”— handle_link_message | Trigger | Received disconnect_from message from {source}") addrs = message.addrs for addr in addrs.split(): await self.cm.disconnect(source, mutual_disconnection=False) - await self.nm.update_neighbors(addr, remove=True) - 
+ await self.nm.update_neighbors(addr, remove=True) + """ ############################## # ENGINE FUNCTIONALITY # - ############################## + ############################## """ - def register_message_events_callbacks(self): me_dict = self.cm.get_messages_events() - message_events = [(message_name, message_action) for (message_name, message_actions) in me_dict.items() for message_action in message_actions] - logging.info(f"{message_events}") + message_events = [ + (message_name, message_action) + for (message_name, message_actions) in me_dict.items() + for message_action in message_actions + ] + # logging.info(f"{message_events}") for event_type, action in message_events: callback_name = f"_{event_type}_{action}_callback" - logging.info(f"Searching callback named: {callback_name}") + # logging.info(f"Searching callback named: {callback_name}") method = getattr(self, callback_name, None) if callable(method): @@ -600,45 +546,46 @@ async def trigger_event(self, message_event): async def _aditional_node_start(self): self.update_sinchronized_status(False) logging.info(f"Aditional node | {self.addr} | going to stablish connection with federation") + self.nm.late_config() await self.nm.start_late_connection_process() # continue .. 
- #asyncio.create_task(self.nm.stop_not_selected_connections()) + # asyncio.create_task(self.nm.stop_not_selected_connections()) logging.info("Creating trainer service to start the federation process..") asyncio.create_task(self._start_learning_late()) def get_push_acceleration(self): return self.nm.get_push_acceleration() - + async def set_pushed_done(self, rounds_push): await self.nm.set_rounds_pushed(rounds_push) - + async def apply_weight_strategy(self, pending_models): - if self.mobility and self.nm.fast_reboot_on(): + if self.mobility: await self.nm.apply_weight_strategy(pending_models) return pending_models else: return pending_models - + async def update_model_learning_rate(self, new_lr): await self.trainning_in_progress_lock.acquire_async() logging.info("Update | learning rate modified...") self.trainer.update_model_learning_rate(new_lr) - await self.trainning_in_progress_lock.release_async() - + await self.trainning_in_progress_lock.release_async() + async def _start_learning_late(self): await self.learning_cycle_lock.acquire_async() try: model_serialized, rounds, round, _epochs = await self.nm.get_trainning_info() - self.total_rounds = rounds # self.config.participant["scenario_args"]["rounds"] #rounds - epochs = _epochs # self.config.participant["training_args"]["epochs"] #_epochs + self.total_rounds = rounds # self.config.participant["scenario_args"]["rounds"] #rounds + epochs = _epochs # self.config.participant["training_args"]["epochs"] #_epochs await self.get_round_lock().acquire_async() self.round = round await self.get_round_lock().release_async() await self.learning_cycle_lock.release_async() print_msg_box( - msg="Starting Federated Learning process...", - indent=2, - title="Start of the experiment late", + msg="Starting Federated Learning process...", + indent=2, + title="Start of the experiment late", ) logging.info(f"Trainning setup | total rounds: {rounds} | current round: {round} | epochs: {epochs}") direct_connections = await 
self.cm.get_addrs_current_connections(only_direct=True) @@ -662,16 +609,16 @@ async def _start_learning_late(self): pass except RuntimeError: pass - + self.trainer.set_epochs(epochs) self.trainer.set_current_round(round) self.trainer.create_trainer() await self._learning_cycle() - + finally: if await self.learning_cycle_lock.locked_async(): await self.learning_cycle_lock.release_async() - + async def create_trainer_module(self): asyncio.create_task(self._start_learning()) logging.info("Started trainer module...") @@ -694,7 +641,7 @@ async def start_communications(self): logging.info(f"Connections verified: {current_connections}") if self.mobility: logging.info("Building NodeManager configurations...") - await self.nm.set_confings() + await self.nm.set_configs() await self._reporter.start() await self.cm.deploy_additional_services() await asyncio.sleep(self.config.participant["misc_args"]["grace_time_connection"] // 2) @@ -710,7 +657,7 @@ async def deploy_federation(self): while not await self.cm.check_federation_ready(): await asyncio.sleep(1) logging.info("Sending FEDERATION_START to neighbors...") - #message = self.cm.mm.generate_federation_message(nebula_pb2.FederationMessage.Action.FEDERATION_START) + # message = self.cm.mm.generate_federation_message(nebula_pb2.FederationMessage.Action.FEDERATION_START) message = self.cm.create_message("federation", "federation_start") await self.cm.send_message_to_neighbors(message) await self.get_federation_ready_lock().release_async() @@ -721,7 +668,7 @@ async def deploy_federation(self): else: logging.info("Sending FEDERATION_READY to neighbors...") - #message = self.cm.mm.generate_federation_message(nebula_pb2.FederationMessage.Action.FEDERATION_READY) + # message = self.cm.mm.generate_federation_message(nebula_pb2.FederationMessage.Action.FEDERATION_READY) message = self.cm.create_message("federation", "federation_ready") await self.cm.send_message_to_neighbors(message) logging.info("πŸ’€ Waiting until receiving the 
start signal from the start node") @@ -823,6 +770,13 @@ async def _waiting_model_updates(self): else: logging.error("Aggregation finished with no parameters") + def print_round_information(self): + print_msg_box( + msg=f"Round {self.round} of {self.total_rounds} started.", + indent=2, + title="Round information", + ) + async def _learning_cycle(self): while self.round is not None and self.round < self.total_rounds: print_msg_box( @@ -893,12 +847,12 @@ async def _extended_learning_cycle(self): functionalities. The method is called in the _learning_cycle method. """ pass - + async def _additional_mobility_actions(self): if not self.mobility: return logging.info("πŸ”„ Starting additional mobility actions...") - #self.trainer.show_current_learning_rate() + # self.trainer.show_current_learning_rate() await self.nm.check_robustness() action = await self.nm.check_external_connection_service_status() if action: @@ -944,10 +898,10 @@ def reputation_calculation(self, aggregated_models_weights): async def send_reputation(self, malicious_nodes): logging.info(f"Sending REPUTATION to the rest of the topology: {malicious_nodes}") - #message = self.cm.mm.generate_federation_message( + # message = self.cm.mm.generate_federation_message( # nebula_pb2.FederationMessage.Action.REPUTATION, malicious_nodes - #) - message = self.cm.create_message("federation","reputation", arguments=[str(arg) for arg in (malicious_nodes)]) + # ) + message = self.cm.create_message("federation", "reputation", arguments=[str(arg) for arg in (malicious_nodes)]) await self.cm.send_message_to_neighbors(message) diff --git a/nebula/core/neighbormanagement/fastreboot.py b/nebula/core/neighbormanagement/fastreboot.py index 2fdc0f153..ef312315a 100644 --- a/nebula/core/neighbormanagement/fastreboot.py +++ b/nebula/core/neighbormanagement/fastreboot.py @@ -1,8 +1,8 @@ -import asyncio import logging +from typing import TYPE_CHECKING from nebula.core.utils.locker import Locker -from typing import TYPE_CHECKING + if 
TYPE_CHECKING: from nebula.core.neighbormanagement.nodemanager import NodeManager @@ -11,137 +11,142 @@ MAX_ROUNDS = 20 DEFAULT_WEIGHT_MODIFIER = 3 -class FastReboot(): - + +class FastReboot: def __init__( - self, - node_manager : "NodeManager", - max_rounds_application = MAX_ROUNDS, # Max rounds to be applied FastReboot - weight_modifier = DEFAULT_WEIGHT_MODIFIER, - default_learning_rate = VANILLA_LEARNING_RATE, # Stable value for learning rate - upgrade_learning_rate = FR_LEARNING_RATE, # Increased value for learning rate - ): + self, + node_manager: "NodeManager", + max_rounds_application=MAX_ROUNDS, # Max rounds to be applied FastReboot + weight_modifier=DEFAULT_WEIGHT_MODIFIER, + default_learning_rate=VANILLA_LEARNING_RATE, # Stable value for learning rate + upgrade_learning_rate=FR_LEARNING_RATE, # Increased value for learning rate + ): + logging.info("🌐 Initializing FastReboot") self._node_manager = node_manager self._max_rounds = max_rounds_application self._weight_mod_value = weight_modifier self._default_lr = default_learning_rate self._upgrade_lr = upgrade_learning_rate self._current_lr = default_learning_rate - self._learning_rate_lock = Locker(name="learning_rate_lock", async_lock=True) - + self._learning_rate_lock = Locker(name="learning_rate_lock", async_lock=True) + self._weight_modifier = {} self._weight_modifier_lock = Locker(name="weight_modifier_lock", async_lock=True) self._rounds_pushed_lock = Locker(name="rounds_pushed_lock", async_lock=True) self._rounds_pushed = 0 - + self._fr_in_progress = False - + @property def nm(self): return self._node_manager - + async def set_rounds_pushed(self, rp): await self._rounds_pushed_lock.acquire_async() self.rounds_pushed = rp await self._rounds_pushed_lock.release_async() - + async def get_current_learning_rate(self): await self._learning_rate_lock.acquire_async() lr = self._current_lr await self._learning_rate_lock.release_async() return lr - + async def discard_fastreboot_for(self, addr): await 
self._weight_modifier_lock.acquire_async() try: del self._weight_modifier[addr] - except KeyError as e: + except KeyError: pass - await self._weight_modifier_lock.release_async() - + await self._weight_modifier_lock.release_async() + async def _set_learning_rate(self, lr): await self._learning_rate_lock.acquire_async() self._current_lr = lr await self._learning_rate_lock.release_async() - + async def add_fastReboot_addr(self, addr): await self._weight_modifier_lock.acquire_async() - if not addr in self._weight_modifier: + if addr not in self._weight_modifier: self._fr_in_progress = True - wm = self._weight_mod_value - logging.info(f"πŸ“ Registering | FastReboot registered for source {addr} | round application: {self._max_rounds} | multiplier value: {wm}") - self._weight_modifier[addr] = (wm,1) + wm = self._weight_mod_value + logging.info( + f"πŸ“ Registering | FastReboot registered for source {addr} | round application: {self._max_rounds} | multiplier value: {wm}" + ) + self._weight_modifier[addr] = (wm, 1) await self._set_learning_rate(self._upgrade_lr) current_lr = await self.get_current_learning_rate() await self.nm.update_learning_rate(current_lr) await self._weight_modifier_lock.release_async() - + async def _remove_weight_modifier(self, addr): logging.info(f"πŸ“ Removing | FastReboot removed for source {addr}") del self._weight_modifier[addr] - + async def _weight_modifiers_empty(self): await self._weight_modifier_lock.acquire_async() empty = False if self._weight_modifier else True await self._weight_modifier_lock.release_async() return empty - - async def apply_weight_strategy(self, updates: dict): + + async def apply_weight_strategy(self, updates: dict): if await self._weight_modifiers_empty(): if self._fr_in_progress: await self._end_fastreboot() return - logging.info(f"πŸ”„ Applying FastReboot Strategy...") + logging.info("πŸ”„ Applying FastReboot Strategy...") # We must lower the weight_modifier value if a round jump has been occured # as many times 
as rounds have been jumped if self._rounds_pushed: - logging.info(f"πŸ”„ There are rounds being pushed...") + logging.info("πŸ”„ There are rounds being pushed...") for i in range(0, self.rounds_pushed): - logging.info(f"πŸ”„ Update | weights being updated cause of push...") + logging.info("πŸ”„ Update | weights being updated cause of push...") self._update_weight_modifiers() - self.rounds_pushed = 0 - for addr,update in updates.items(): + self.rounds_pushed = 0 + for addr, update in updates.items(): weightmodifier, rounds = await self._get_weight_modifier(addr) if weightmodifier != 1: - logging.info (f"πŸ“ Appliying FastReboot strategy | addr: {addr} | multiplier value: {weightmodifier}, rounds applied: {rounds}") + logging.info( + f"πŸ“ Appliying FastReboot strategy | addr: {addr} | multiplier value: {weightmodifier}, rounds applied: {rounds}" + ) model, weight = update - updates.update({addr: (model, weight*weightmodifier)}) + updates.update({addr: (model, weight * weightmodifier)}) await self._update_weight_modifiers() - - #TODO integrar en el get_wegith_modifier para que se actualice cuando se pide y ahi se compruebe si hay q eliminar una entrada + + # TODO integrar en el get_wegith_modifier para que se actualice cuando se pide y ahi se compruebe si hay q eliminar una entrada async def _update_weight_modifiers(self): await self._weight_modifier_lock.acquire_async() if self._weight_modifier: - logging.info(f"πŸ”„ Update | weights being updated") + logging.info("πŸ”„ Update | weights being updated") remove_addrs = [] - for addr, (weight,rounds) in self._weight_modifier.items(): - new_weight = weight - 1/(rounds**2) + for addr, (weight, rounds) in self._weight_modifier.items(): + new_weight = weight - 1 / (rounds**2) rounds = rounds + 1 if new_weight > 1 and rounds <= self._max_rounds: - self._weight_modifier[addr] = (new_weight, rounds) + self._weight_modifier[addr] = (new_weight, rounds) else: remove_addrs.append(addr) for a in remove_addrs: await 
self._remove_weight_modifier(a) await self._weight_modifier_lock.release_async() - + async def _end_fastreboot(self): await self._weight_modifier_lock.acquire_async() if not self._weight_modifier and await self._is_lr_modified(): - logging.info(f"πŸ”„ Finishing | FastReboot is completed") + logging.info("πŸ”„ Finishing | FastReboot is completed") self._fr_in_progress = False await self._set_learning_rate(self._default_lr) await self.nm.update_learning_rate(self._default_lr) await self._weight_modifier_lock.release_async() - + async def _get_weight_modifier(self, addr): await self._weight_modifier_lock.acquire_async() - wm = self._weight_modifier.get(addr, (1,0)) + wm = self._weight_modifier.get(addr, (1, 0)) await self._weight_modifier_lock.release_async() return wm - + async def _is_lr_modified(self): await self._learning_rate_lock.acquire_async() mod = self._current_lr == self._upgrade_lr await self._learning_rate_lock.release_async() - return mod \ No newline at end of file + return mod diff --git a/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py b/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py index 097f5fc0a..861edd1fa 100644 --- a/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py +++ b/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py @@ -1,9 +1,8 @@ from nebula.core.neighbormanagement.modelhandlers.modelhandler import ModelHandler from nebula.core.utils.locker import Locker -import logging + class STDModelHandler(ModelHandler): - def __init__(self): self.model = None self.rounds = 0 @@ -11,7 +10,7 @@ def __init__(self): self.epochs = 0 self.model_lock = Locker(name="model_lock") self.params_lock = Locker(name="param_lock") - + def set_config(self, config): """ Args: @@ -22,32 +21,31 @@ def set_config(self, config): self.params_lock.acquire() self.rounds = config[0] if config[1] > self.round: - self.round = config[1] + self.round = config[1] self.epochs = config[2] self.params_lock.release() - + def 
accept_model(self, model): """ - save only first model received to set up own model later + save only first model received to set up own model later """ if not self.model_lock.locked(): - logging.info(" ### First model acquire ###") self.model_lock.acquire() self.model = model return True - + async def get_model(self, model): """ Returns: neccesary data to create trainer """ - if self.model is not None: + if self.model is not None: return (self.model, self.rounds, self.round, self.epochs) else: return (None, 0, 0, 0) def pre_process_model(self): """ - no pre-processing defined + no pre-processing defined """ - pass \ No newline at end of file + pass diff --git a/nebula/core/neighbormanagement/momentum.py b/nebula/core/neighbormanagement/momentum.py index 0f55f70bd..7b1f8e199 100644 --- a/nebula/core/neighbormanagement/momentum.py +++ b/nebula/core/neighbormanagement/momentum.py @@ -1,37 +1,45 @@ -import asyncio import logging -from collections import deque +from collections import OrderedDict, deque +from collections.abc import Callable +from typing import TYPE_CHECKING, Annotated + +import numpy as np + from nebula.core.utils.helper import cosine_metric from nebula.core.utils.locker import Locker -import numpy as np -from typing import TYPE_CHECKING, Callable, OrderedDict, Optional -from typing_extensions import Annotated if TYPE_CHECKING: from nebula.core.neighbormanagement.nodemanager import NodeManager -SimilarityMetricType = Callable[[OrderedDict, OrderedDict, bool], Optional[float]] +SimilarityMetricType = Callable[[OrderedDict, OrderedDict, bool], float | None] MappingSimilarityType = Callable[[float, float], Annotated[float, "Value in (0, 1]"]] -MAX_HISTORIC_SIZE = 10 # Number of historic data storaged -GLOBAL_PRIORITY = 0.5 # Parameter to priorize global vs local metrics +MAX_HISTORIC_SIZE = 10 # Number of historic data storaged +GLOBAL_PRIORITY = 0.8 # Parameter to priorize global vs local metrics EPSILON = 0.001 SIGMOID_THRESHOLD = 0.92 
-TOLERANCE_THRESHOLD = 2 # Threshold to start appliying full penalty +TOLERANCE_THRESHOLD = 2 # Threshold to start appliying full penalty SMOOTH_FACTOR = 0.5 +# Maybe it should change according to number of updates received +MOMENTUM_ATENUATION_FACTOR = 0.85 # Proportion of past momentums +MOMENTUM_LEARNING_STEP = 1 -class Momentum(): +class Momentum: def __init__( - self, + self, node_manager: "NodeManager", nodes, dispersion_penalty=True, - global_priority=GLOBAL_PRIORITY, - similarity_metric : SimilarityMetricType = cosine_metric, - mapping_similarity : MappingSimilarityType = lambda sim_value, e=EPSILON: e + ((sim_value + 1) / 2), + global_priority=GLOBAL_PRIORITY, + similarity_metric: SimilarityMetricType = cosine_metric, + mapping_similarity: MappingSimilarityType = lambda sim_value, e=EPSILON: e + ((sim_value + 1) / 2), ): + logging.info("🌐 Initializing Momemtum strategy") self._node_manager = node_manager + self._momentum_historic = {node_id: deque(maxlen=MAX_HISTORIC_SIZE) for node_id in nodes} + self._momentum_historic_lock = Locker(name="_momentum_historic_lock", async_lock=True) + self._previous_momentum = 0 self._similarities_historic = {node_id: deque(maxlen=MAX_HISTORIC_SIZE) for node_id in nodes} self._similarities_historic_lock = Locker(name="_similarities_historic_lock", async_lock=True) self._model_similarity_metric_lock = Locker(name="_model_similarity_metric_lock", async_lock=True) @@ -44,21 +52,21 @@ def __init__( @property def nm(self): return self._node_manager - + @property def msm(self): return self._model_similarity_metric - + @property def msf(self): return self._mapping_similarity_func async def _add_similarity_to_node(self, node_id, sim_value): - logging.info(f"Adding | node ID: {node_id}, cossine similarity value: {sim_value}") + # logging.info(f"Adding | node ID: {node_id}, cossine similarity value: {sim_value}") self._similarities_historic_lock.acquire_async() self._similarities_historic[node_id].append(sim_value) 
self._similarities_historic_lock.release_async() - + async def _get_similarity_historic(self, addrs): """ Get historic storaged for node IDs on 'addrs' @@ -72,23 +80,27 @@ async def _get_similarity_historic(self, addrs): if key in addrs: historic[key] = value self._similarities_historic_lock.release_async() - return historic + return historic async def update_node(self, node_id, remove=False): self._similarities_historic_lock.acquire_async() + self._momentum_historic_lock.acquire_async() + logging.info(f"Update | addr: {node_id}, remove: {remove}") if remove: self._similarities_historic.pop(node_id, None) else: self._similarities_historic.update({node_id: deque(maxlen=MAX_HISTORIC_SIZE)}) + self._momentum_historic.update({node_id: deque(maxlen=MAX_HISTORIC_SIZE)}) + self._momentum_historic_lock.release_async() self._similarities_historic_lock.release_async() - + async def change_similarity_metric(self, new_metric: SimilarityMetricType, new_mapping: MappingSimilarityType): self._model_similarity_metric_lock.acquire_async() self.msm = new_metric self.msf = new_mapping # maybe we should remove historic data due to incongruous data self._model_similarity_metric_lock.release_async() - + async def _calculate_similarities(self, updates: dict): """ Function to calculate similarity between local model and models received @@ -97,68 +109,96 @@ async def _calculate_similarities(self, updates: dict): Args: updates (dict): {node ID: model} """ - logging.info(f"Calculate | Model Similarity values are being calculated...") + logging.info("Calculate | Model Similarity values are being calculated...") model = self.nm.engine.trainer.get_model_parameters() - for addr,update in updates.items(): + for addr, update in updates.items(): if addr == self._addr: continue + updt_model, _ = update sim_value = self.msm( model, - update, + updt_model, similarity=True, ) - logging.info(f"Model similarity for node: {addr}, sim: {sim_value}") + logging.info(f"Model similarity for node: {addr}, sim: 
{sim_value:.4f}") await self._add_similarity_to_node(addr, sim_value) - + def _calculate_dispersion_penalty(self, historic: dict, updates: dict): from math import sqrt + logging.info("Calculate | Dispersion penalty") - round_similarities = [(addr, n_hist[-1]) for addr,n_hist in historic.items() if n_hist] + round_similarities = [(addr, n_hist[-1]) for addr, n_hist in historic.items() if n_hist] if round_similarities: mean_similarity = np.mean(round_similarities) std_similarity = np.std(round_similarities) n_updates = len(updates) - 1 logging.info(f"Calculate | mean similarity: {mean_similarity}, standar deviation: {std_similarity}") - for addr,sim in round_similarities: + for addr, sim in round_similarities: if abs(sim - mean_similarity) < TOLERANCE_THRESHOLD * std_similarity: logging.info(f"Penalty | Dispersion is lower than threshold, for node: {addr}") - penalty = (SMOOTH_FACTOR * (abs(sim - mean_similarity) / (std_similarity + EPSILON))) * (1/sqrt(n_updates)) + penalty = (SMOOTH_FACTOR * (abs(sim - mean_similarity) / (std_similarity + EPSILON))) * ( + 1 / sqrt(n_updates) + ) else: logging.info(f"Penalty | Dispersion is higher than threshold, for node: {addr}") - penalty = (abs(sim - mean_similarity) / (std_similarity + EPSILON)) * (1/sqrt(n_updates)) - + penalty = (abs(sim - mean_similarity) / (std_similarity + EPSILON)) * (1 / sqrt(n_updates)) + penalty = min(1.0, max(0.0, penalty)) logging.info(f"Penalty value: {penalty}") - dispersion_penalty = 1 - penalty - + dispersion_penalty = 1 - penalty + def map_value(sim_value, e=EPSILON): - return e + ((sim_value + 1) / 2) - + return e + ((sim_value + 1) / 2) + async def calculate_momentum_weights(self, updates: dict): - if not updates: + if len(updates) == 1: return logging.info("Calculate | Momemtum weights are being calculated...") - self._model_similarity_metric_lock.acquire_async() - await self._calculate_similarities(updates) # Calculate similarity value between self model and updates received - historic = 
await self._get_similarity_historic(updates.keys()) # Get historic similarities values from nodes that has sent update this round - + self._model_similarity_metric_lock.acquire_async() + await self._calculate_similarities( + updates + ) # Calculate similarity value between self model and updates received + historic = await self._get_similarity_historic( + updates.keys() + ) # Get historic similarities values from nodes that has sent update this round + def sigmoid(similarity, k=2.5): if similarity >= SIGMOID_THRESHOLD: # threshold to consider better updates sigmoid = 1 else: sigmoid = 1 / (1 + np.exp(-k * (similarity))) - return sigmoid - + return sigmoid + + # Calculate round local momentum for each node for node_addr, n_hist in historic.items(): if not n_hist or node_addr == self._addr: - continue - sim_value = n_hist[-1] # Get last similarity value - mapped_sim_value = self.msf(sim_value) # Mapped into [0, 1] interval + continue + sim_value = n_hist[-1] # Get last similarity value + mapped_sim_value = self.msf(sim_value) # Mapped into [0, 1] interval smoothed_value = sigmoid(mapped_sim_value) - adjusted_weight = smoothed_value * self._global_prio + (1 - self._global_prio) * mapped_sim_value - logging.info(f"Momemtum values | adjusted_weight: {adjusted_weight}, map_value: {mapped_sim_value}, smoothed_value: {smoothed_value}") - - if self._dispersion_penalty: - self._calculate_dispersion_penalty(historic, updates) - - self._model_similarity_metric_lock.release_async() \ No newline at end of file + local_round_momentum = smoothed_value * self._global_prio + (1 - self._global_prio) * mapped_sim_value + self._momentum_historic[node_addr].append(local_round_momentum) + + # Calculate round neighborhood momentum + round_neighborhood_momentum = np.mean([ + self._momentum_historic[node_addr][-1] for node_addr, _ in historic.items() + ]) + if not self._previous_momentum: + self._previous_momentum = round_neighborhood_momentum + else: + self._previous_momentum = ( + 
MOMENTUM_ATENUATION_FACTOR * self._previous_momentum + + (1 - MOMENTUM_ATENUATION_FACTOR) * round_neighborhood_momentum + ) + + for node_addr in historic.keys(): + model, weight = updates[node_addr] + adjusted_weight = self._previous_momentum * weight # Aplicar momentum como factor de ajuste + + # updates[node_addr] = (model, adjusted_weight) + + logging.info( + f"Node {node_addr}: sim={sim_value:.3f}, momentum_vec={self._previous_momentum:.3f}, adjusted_weight={adjusted_weight:.3f}" + ) + + self._model_similarity_metric_lock.release_async() diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index bb9910de0..d00bcd5c8 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -1,44 +1,43 @@ import asyncio import logging -import os - -import importlib +from typing import TYPE_CHECKING -from nebula.core.utils.locker import Locker +from nebula.addons.functions import print_msg_box from nebula.core.neighbormanagement.candidateselection.candidateselector import factory_CandidateSelector -from nebula.core.neighbormanagement.modelhandlers.modelhandler import factory_ModelHandler -from nebula.core.neighbormanagement.neighborpolicies.neighborpolicy import factory_NeighborPolicy -from nebula.core.pb import nebula_pb2 -from nebula.core.network.communications import CommunicationsManager from nebula.core.neighbormanagement.fastreboot import FastReboot +from nebula.core.neighbormanagement.modelhandlers.modelhandler import factory_ModelHandler from nebula.core.neighbormanagement.momentum import Momentum -from nebula.addons.functions import print_msg_box +from nebula.core.neighbormanagement.neighborpolicies.neighborpolicy import factory_NeighborPolicy +from nebula.core.utils.locker import Locker -from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.engine import Engine RESTRUCTURE_COOLDOWN = 5 -class NodeManager(): - + +class NodeManager: def __init__( - 
self, + self, + aditional_participant, topology, model_handler, push_acceleration, - engine : "Engine", - fastreboot=True, - momentum=True, + engine: "Engine", + fastreboot=False, + momentum=False, ): - self.topology = "fully"#topology - print_msg_box(msg=f"Starting NodeManager module...\nTopology: {self.topology}", indent=2, title="NodeManager module") + self._aditional_participant = aditional_participant + self.topology = "fully" # topology + print_msg_box( + msg=f"Starting NodeManager module...\nTopology: {self.topology}", indent=2, title="NodeManager module" + ) logging.info("🌐 Initializing Node Manager") self._engine = engine self.config = engine.get_config() logging.info("Initializing Neighbor policy") self._neighbor_policy = factory_NeighborPolicy(self.topology) - logging.info("Initializing Candidate Selector") + logging.info("Initializing Candidate Selector") self._candidate_selector = factory_CandidateSelector(self.topology) logging.info("Initializing Model Handler") self._model_handler = factory_ModelHandler(model_handler) @@ -52,20 +51,18 @@ def __init__( self.restructure = False self._restructure_cooldown = RESTRUCTURE_COOLDOWN self.discarded_offers_addr_lock = Locker(name="discarded_offers_addr_lock") - self.discarded_offers_addr = [] + self.discarded_offers_addr = [] self._push_acceleration = push_acceleration - + self.synchronizing_rounds = False - + self._fast_reboot_status = fastreboot self._momemtum_status = momentum - - #self.set_confings() @property def engine(self): return self._engine - + @property def neighbor_policy(self): return self._neighbor_policy @@ -77,14 +74,18 @@ def candidate_selector(self): @property def model_handler(self): return self._model_handler - + @property def fr(self): return self._fastreboot - + + @property + def mom(self): + return self._momemtum + def fast_reboot_on(self): return self._fast_reboot_status - + def _update_restructure_cooldown(self): if self._restructure_cooldown: self._restructure_cooldown = 
(self._restructure_cooldown + 1) % RESTRUCTURE_COOLDOWN @@ -94,175 +95,175 @@ def _restructure_available(self): def get_push_acceleration(self): return self._push_acceleration - + def get_restructure_process_lock(self): return self._restructure_process_lock - + def set_synchronizing_rounds(self, status): self.synchronizing_rounds = status - + def get_syncrhonizing_rounds(self): - return self.synchronizing_rounds - + return self.synchronizing_rounds + async def set_rounds_pushed(self, rp): if self.fast_reboot_on(): self.fr.set_rounds_pushed(rp) - + def still_waiting_for_candidates(self): return not self.accept_candidates_lock.locked() - - async def set_confings(self): + + async def set_configs(self): """ - neighbor_policy config: - - direct connections a.k.a neighbors - - all nodes known - - self addr - - model_handler config: - - self total rounds - - self current round - - self epochs - - candidate_selector config: - - self model loss - - self weight distance - - self weight hetereogeneity + neighbor_policy config: + - direct connections a.k.a neighbors + - all nodes known + - self addr + + model_handler config: + - self total rounds + - self current round + - self epochs + + candidate_selector config: + - self model loss + - self weight distance + - self weight hetereogeneity """ - logging.info(f"Building neighbor policy configuration..") - self.neighbor_policy.set_config( - [ - await self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False), - await self.engine.cm.get_addrs_current_connections(only_direct=False, only_undirected=False, myself=False), - self.engine.addr, - self - ] - ) - logging.info(f"Building candidate selector configuration..") - self.candidate_selector.set_config( - [ - 0, - 0.5, - 0.5 - ] - ) - #self.engine.trainer.get_loss(), self.config.participant["molibity_args"]["weight_distance"], self.config.participant["molibity_args"]["weight_het"] - #self.model_handler.set_config([self.engine.get_round(), 
self.engine.config.participant["training_args"]["epochs"]]) + logging.info("Building neighbor policy configuration..") + self.neighbor_policy.set_config([ + await self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False), + await self.engine.cm.get_addrs_current_connections(only_direct=False, only_undirected=False, myself=False), + self.engine.addr, + self, + ]) + logging.info("Building candidate selector configuration..") + self.candidate_selector.set_config([0, 0.5, 0.5]) + # self.engine.trainer.get_loss(), self.config.participant["molibity_args"]["weight_distance"], self.config.participant["molibity_args"]["weight_het"] - if (self._fast_reboot_status): + if self._fast_reboot_status: self._fastreboot = FastReboot(self) - - self._momemtum = None - if (self._momemtum_status): - self._momemtum = Momentum(self, self.neighbor_policy.get_nodes_known(neighbors_only=True), dispersion_penalty=False) - - - ############################## - # FAST REBOOT # - ############################## - - + + self._momemtum = None + if self._momemtum_status and not self._aditional_participant: + self._momemtum = Momentum( + self, self.neighbor_policy.get_nodes_known(neighbors_only=True), dispersion_penalty=False + ) + + def late_config(self): + if self._momemtum_status: + self._momemtum = Momentum( + self, self.neighbor_policy.get_nodes_known(neighbors_only=True), dispersion_penalty=False + ) + + ############################## + # FAST REBOOT # + ############################## + async def update_learning_rate(self, new_lr): await self.engine.update_model_learning_rate(new_lr) - + async def register_late_neighbor(self, addr, joinning_federation=False): + logging.info(f"Registering | late neighbor: {addr}, joining: {joinning_federation}") self.meet_node(addr) await self.update_neighbors(addr) + if self._momemtum_status: + await self.mom.update_node(addr) if joinning_federation: if self.fast_reboot_on(): await self.fr.add_fastReboot_addr(addr) - + async def 
apply_weight_strategy(self, updates: dict): - if not self.fast_reboot_on(): - return - await self.fr.apply_weight_strategy(updates) + if self.fast_reboot_on(): + await self.fr.apply_weight_strategy(updates) if self._momemtum: await self._momemtum.calculate_momentum_weights(updates) - + ############################## + # CONNECTIONS # + ############################## - ############################## - # CONNECTIONS # - ############################## - - def accept_connection(self, source, joining=False): return self.neighbor_policy.accept_connection(source, joining) - + async def add_pending_connection_confirmation(self, addr): await self._update_neighbors_lock.acquire_async() await self.pending_confirmation_from_nodes_lock.acquire_async() - if not addr in self.neighbor_policy.get_nodes_known(neighbors_only=True): + if addr not in self.neighbor_policy.get_nodes_known(neighbors_only=True): logging.info(f" Addition | pending connection confirmation from: {addr}") self.pending_confirmation_from_nodes.add(addr) await self.pending_confirmation_from_nodes_lock.release_async() await self._update_neighbors_lock.release_async() - + async def _remove_pending_confirmation_from(self, addr): await self.pending_confirmation_from_nodes_lock.acquire_async() self.pending_confirmation_from_nodes.discard(addr) await self.pending_confirmation_from_nodes_lock.release_async() - + async def clear_pending_confirmations(self): await self.pending_confirmation_from_nodes_lock.acquire_async() self.pending_confirmation_from_nodes.clear() await self.pending_confirmation_from_nodes_lock.release_async() - + async def waiting_confirmation_from(self, addr): await self.pending_confirmation_from_nodes_lock.acquire_async() found = addr in self.pending_confirmation_from_nodes await self.pending_confirmation_from_nodes_lock.release_async() - return found - + return found + async def confirmation_received(self, addr, confirmation=False): logging.info(f" Update | connection confirmation received from: 
{addr} | confirmation: {confirmation}") if confirmation: - await self.engine.cm.connect(addr, direct=True) + await self.engine.cm.connect(addr, direct=True) await self.update_neighbors(addr) else: - self._remove_pending_confirmation_from(addr) - + self._remove_pending_confirmation_from(addr) + def add_to_discarded_offers(self, addr_discarded): self.discarded_offers_addr_lock.acquire() self.discarded_offers_addr.append(addr_discarded) self.discarded_offers_addr_lock.release() - + def need_more_neighbors(self): return self.neighbor_policy.need_more_neighbors() - + def get_actions(self): return self.neighbor_policy.get_actions() - + async def update_neighbors(self, node, remove=False): logging.info(f"Update neighbor | node addr: {node} | remove: {remove}") await self._update_neighbors_lock.acquire_async() self.neighbor_policy.update_neighbors(node, remove) - #self.timer_generator.update_node(node, remove) + # self.timer_generator.update_node(node, remove) if remove: if self._fast_reboot_status: self.fr.discard_fastreboot_for(node) + if self._momemtum_status: + await self.mom.update_node(node, remove=True) else: self.neighbor_policy.meet_node(node) + if self._momemtum_status: + await self.mom.update_node(node) self._remove_pending_confirmation_from(node) await self._update_neighbors_lock.release_async() - + async def neighbors_left(self): return len(await self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False)) > 0 - + def meet_node(self, node): logging.info(f"Update nodes known | addr: {node}") self.neighbor_policy.meet_node(node) - + def get_nodes_known(self, neighbors_too=False): return self.neighbor_policy.get_nodes_known(neighbors_too) - - def accept_model_offer(self, source, decoded_model, rounds, round, epochs, n_neighbors, loss): + + def accept_model_offer(self, source, decoded_model, rounds, round, epochs, n_neighbors, loss): if not self.accept_candidates_lock.locked(): logging.info(f"πŸ”„ Processing offer from {source}...") - 
#model_accepted = True#self.model_handler.accept_model(decoded_model) - #if source == "192.168.50.8:45007": + # model_accepted = True#self.model_handler.accept_model(decoded_model) + # if source == "192.168.50.8:45007": model_accepted = self.model_handler.accept_model(decoded_model) self.model_handler.set_config(config=(rounds, round, epochs, self)) - if model_accepted: + if model_accepted: self.candidate_selector.add_candidate((source, n_neighbors, loss)) return True else: @@ -271,10 +272,10 @@ def accept_model_offer(self, source, decoded_model, rounds, round, epochs, n_nei async def get_trainning_info(self): return await self.model_handler.get_model(None) - def add_candidate(self,source, n_neighbors, loss): + def add_candidate(self, source, n_neighbors, loss): if not self.accept_candidates_lock.locked(): self.candidate_selector.add_candidate((source, n_neighbors, loss)) - + async def currently_reestructuring(self): return self._restructure_process_lock.locked() @@ -283,78 +284,86 @@ async def stop_not_selected_connections(self): await asyncio.sleep(20) with self.discarded_offers_addr_lock: if len(self.discarded_offers_addr) > 0: - self.discarded_offers_addr = set(self.discarded_offers_addr) - await self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False) - logging.info(f"Interrupting connections | discarded offers | nodes discarded: {self.discarded_offers_addr}") + self.discarded_offers_addr = set( + self.discarded_offers_addr + ) - await self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False) + logging.info( + f"Interrupting connections | discarded offers | nodes discarded: {self.discarded_offers_addr}" + ) for addr in self.discarded_offers_addr: await self.engine.cm.disconnect(addr, mutual_disconnection=True) - await asyncio.sleep(1) + await asyncio.sleep(1) self.discarded_offers_addr = [] - except asyncio.CancelledError as e: + except asyncio.CancelledError: pass - async def 
check_external_connection_service_status(self): - logging.info(f"πŸ”„ Checking external connection service status...") + async def check_external_connection_service_status(self): + logging.info("πŸ”„ Checking external connection service status...") n = await self.neighbors_left() ecs = await self.engine.cm.is_external_connection_service_running() ss = self.engine.get_sinchronized_status() action = None logging.info(f"Stats | neighbors: {n} | service running: {ecs} | synchronized status: {ss}") if not await self.neighbors_left() and await self.engine.cm.is_external_connection_service_running(): - logging.info(f"❗️ Isolated node | Shutdowning service required") + logging.info("❗️ Isolated node | Shutdowning service required") action = lambda: self.engine.cm.stop_external_connection_service() - elif await self.neighbors_left() and not await self.engine.cm.is_external_connection_service_running() and self.engine.get_sinchronized_status(): - logging.info(f"πŸ”„ NOT isolated node | Service not running | Starting service...") + elif ( + await self.neighbors_left() + and not await self.engine.cm.is_external_connection_service_running() + and self.engine.get_sinchronized_status() + ): + logging.info("πŸ”„ NOT isolated node | Service not running | Starting service...") action = lambda: self.engine.cm.init_external_connection_service() return action async def start_late_connection_process(self, connected=False, msg_type="discover_join", addrs_known=None): """ - This function represents the process of discovering the federation and stablish the first - connections with it. The first step is to send the DISCOVER_JOIN/NODES message to look for nodes, - the ones that receive that message will send back a OFFER_MODEL/METRIC message. It contains info to do - a selection process among candidates to later on connect to the best ones. - The process will repeat until at least one candidate is found and the process will be locked - to avoid concurrency. 
+ This function represents the process of discovering the federation and stablish the first + connections with it. The first step is to send the DISCOVER_JOIN/NODES message to look for nodes, + the ones that receive that message will send back a OFFER_MODEL/METRIC message. It contains info to do + a selection process among candidates to later on connect to the best ones. + The process will repeat until at least one candidate is found and the process will be locked + to avoid concurrency. """ logging.info("🌐 Initializing late connection process..") - + self.late_connection_process_lock.acquire() best_candidates = [] self.candidate_selector.remove_candidates() await self.clear_pending_confirmations() - + # find federation and send discover await self.engine.cm.stablish_connection_to_federation(msg_type, addrs_known) - + # wait offer logging.info(f"Waiting: {self.recieve_offer_timer}s to receive offers from federation") await asyncio.sleep(self.recieve_offer_timer) - + # acquire lock to not accept late candidates self.accept_candidates_lock.acquire() - + if self.candidate_selector.any_candidate(): - logging.info("Candidates found to connect to...") + logging.info("Candidates found to connect to...") # create message to send to candidates selected if not connected: msg = self.engine.cm.create_message("connection", "late_connect") else: msg = self.engine.cm.create_message("connection", "restructure") - + best_candidates = self.candidate_selector.select_candidates() - logging.info(f"Candidates | {[addr for addr,_,_ in best_candidates]}") + logging.info(f"Candidates | {[addr for addr, _, _ in best_candidates]}") # candidates not choosen --> disconnect try: for addr, _, _ in best_candidates: await self.add_pending_connection_confirmation(addr) await self.engine.cm.send_message(addr, msg) - await asyncio.sleep(1) - except asyncio.CancelledError as e: + await asyncio.sleep(1) + except asyncio.CancelledError: await self.update_neighbors(addr, remove=True) - pass + pass 
self.accept_candidates_lock.release() - self.late_connection_process_lock.release() - self.candidate_selector.remove_candidates() + self.late_connection_process_lock.release() + self.candidate_selector.remove_candidates() # if no candidates, repeat process else: logging.info("❗️ No Candidates found...") @@ -363,22 +372,23 @@ async def start_late_connection_process(self, connected=False, msg_type="discove if not connected: logging.info("❗️ repeating process...") await self.start_late_connection_process(connected, msg_type, addrs_known) - - - - ############################## - # ROBUSTNESS # - ############################## - - + + ############################## + # ROBUSTNESS # + ############################## + async def check_robustness(self): - #TODO aΓ±adir un cd para que no se haga continuamente + # TODO aΓ±adir un cd para que no se haga continuamente logging.info("πŸ”„ Analizing node network robustness...") if not self._restructure_process_lock.locked(): if not self.neighbors_left(): logging.info("No Neighbors left | reconnecting with Federation") - #await self.reconnect_to_federation() - elif self.neighbor_policy.need_more_neighbors() and self.engine.get_sinchronized_status() and self._restructure_available(): + # await self.reconnect_to_federation() + elif ( + self.neighbor_policy.need_more_neighbors() + and self.engine.get_sinchronized_status() + and self._restructure_available() + ): logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") self._update_restructure_cooldown() asyncio.create_task(self.upgrade_connection_robustness()) @@ -389,33 +399,31 @@ async def check_robustness(self): logging.info("Sufficient Robustness | no actions required") else: logging.info("❗️ Reestructure/Reconnecting process already running...") - + async def reconnect_to_federation(self): # If we got some refs, try to reconnect to them self._restructure_process_lock.acquire() if self.neighbor_policy.get_nodes_known() > 0: 
logging.info("Reconnecting | Addrs availables") - await self.start_late_connection_process(connected=False, msg_type="discover_nodes", addrs_known=self.neighbor_policy.get_nodes_known()) - # Otherwise stablish connection to federation sending discover nodes instead of join + await self.start_late_connection_process( + connected=False, msg_type="discover_nodes", addrs_known=self.neighbor_policy.get_nodes_known() + ) + # Otherwise stablish connection to federation sending discover nodes instead of join else: logging.info("Reconnecting | NO Addrs availables") await self.start_late_connection_process(connected=False, msg_type="discover_nodes") self._restructure_process_lock.release() - + async def upgrade_connection_robustness(self): self._restructure_process_lock.acquire() addrs_to_connect = self.neighbor_policy.get_nodes_known(neighbors_too=False) # If we got some refs, try to connect to them if len(addrs_to_connect) > 0: logging.info(f"Reestructuring | Addrs availables | addr list: {addrs_to_connect}") - await self.start_late_connection_process(connected=True, msg_type="discover_nodes", addrs_known=addrs_to_connect) + await self.start_late_connection_process( + connected=True, msg_type="discover_nodes", addrs_known=addrs_to_connect + ) else: logging.info("Reestructuring | NO Addrs availables") await self.start_late_connection_process(connected=True, msg_type="discover_nodes") self._restructure_process_lock.release() - - - - - - \ No newline at end of file diff --git a/nebula/core/network/actions.py b/nebula/core/network/actions.py index 77a4ebd82..b8b9bb9b9 100644 --- a/nebula/core/network/actions.py +++ b/nebula/core/network/actions.py @@ -1,89 +1,28 @@ -from nebula.core.pb import nebula_pb2 from enum import Enum -import logging - - -def get_action_name_from_value(message_type: str, action_value: int) -> str: - # Diccionario que asocia cada tipo de mensaje con su Enum correspondiente - action_classes = { - "connection": ConnectionAction, - "federation": 
FederationAction, - "discovery": DiscoveryAction, - "control": ControlAction, - "discover": DiscoverAction, - "offer": OfferAction, - "link": LinkAction, - } - - # Obtener el Enum correspondiente al tipo de mensaje - enum_class = action_classes.get(message_type) - if not enum_class: - raise ValueError(f"Unknown message type: {message_type}") - - # Buscar el nombre de la acciΓ³n a partir del valor - for action in enum_class: - if action.value == action_value: - return action.name.lower() # Convertimos a lowercase para mantener el formato "late_connect" - raise ValueError(f"Unknown action value {action_value} for message type {message_type}") - - -def get_actions_names(message_type: str): - options = { - "connection": ConnectionAction, - "federation": FederationAction, - "discovery": DiscoveryAction, - "control": ControlAction, - "discover": DiscoverAction, - "offer": OfferAction, - "link": LinkAction, - } - - message_actions = options.get(message_type) - if not message_actions: - raise ValueError(f"Invalid message type: {message_type}") - - return [action.name.lower() for action in message_actions] +from nebula.core.pb import nebula_pb2 -def factory_message_action(message_type: str, action: str): - options = { - "connection": ConnectionAction, - "federation": FederationAction, - "discovery": DiscoveryAction, - "control": ControlAction, - "discover": DiscoverAction, - "offer": OfferAction, - "link": LinkAction, - } - - message_actions = options.get(message_type, None) - - if message_actions: - normalized_action = action.upper() - enum_action = message_actions[normalized_action] - #logging.info(f"Message action: {enum_action}, value: {enum_action.value}") - return enum_action.value - else: - return None - class ConnectionAction(Enum): CONNECT = nebula_pb2.ConnectionMessage.Action.CONNECT DISCONNECT = nebula_pb2.ConnectionMessage.Action.DISCONNECT LATE_CONNECT = nebula_pb2.ConnectionMessage.Action.LATE_CONNECT RESTRUCTURE = 
nebula_pb2.ConnectionMessage.Action.RESTRUCTURE + class FederationAction(Enum): FEDERATION_START = nebula_pb2.FederationMessage.Action.FEDERATION_START REPUTATION = nebula_pb2.FederationMessage.Action.REPUTATION FEDERATION_MODELS_INCLUDED = nebula_pb2.FederationMessage.Action.FEDERATION_MODELS_INCLUDED FEDERATION_READY = nebula_pb2.FederationMessage.Action.FEDERATION_READY + class DiscoveryAction(Enum): DISCOVER = nebula_pb2.DiscoveryMessage.Action.DISCOVER REGISTER = nebula_pb2.DiscoveryMessage.Action.REGISTER DEREGISTER = nebula_pb2.DiscoveryMessage.Action.DEREGISTER + class ControlAction(Enum): ALIVE = nebula_pb2.ControlMessage.Action.ALIVE OVERHEAD = nebula_pb2.ControlMessage.Action.OVERHEAD @@ -91,14 +30,62 @@ class ControlAction(Enum): RECOVERY = nebula_pb2.ControlMessage.Action.RECOVERY WEAK_LINK = nebula_pb2.ControlMessage.Action.WEAK_LINK + class DiscoverAction(Enum): - DISCOVER_JOIN = nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN + DISCOVER_JOIN = nebula_pb2.DiscoverMessage.Action.DISCOVER_JOIN DISCOVER_NODES = nebula_pb2.DiscoverMessage.Action.DISCOVER_NODES + class OfferAction(Enum): - OFFER_MODEL = nebula_pb2.OfferMessage.Action.OFFER_MODEL + OFFER_MODEL = nebula_pb2.OfferMessage.Action.OFFER_MODEL OFFER_METRIC = nebula_pb2.OfferMessage.Action.OFFER_METRIC + class LinkAction(Enum): CONNECT_TO = nebula_pb2.LinkMessage.Action.CONNECT_TO DISCONNECT_FROM = nebula_pb2.LinkMessage.Action.DISCONNECT_FROM + + +ACTION_CLASSES = { + "connection": ConnectionAction, + "federation": FederationAction, + "discovery": DiscoveryAction, + "control": ControlAction, + "discover": DiscoverAction, + "offer": OfferAction, + "link": LinkAction, +} + + +def get_action_name_from_value(message_type: str, action_value: int) -> str: + # Obtener el Enum correspondiente al tipo de mensaje + enum_class = ACTION_CLASSES.get(message_type) + if not enum_class: + raise ValueError(f"Unknown message type: {message_type}") + + # Buscar el nombre de la acciΓ³n a partir del valor + for 
action in enum_class: + if action.value == action_value: + return action.name.lower() # Convertimos a lowercase para mantener el formato "late_connect" + + raise ValueError(f"Unknown action value {action_value} for message type {message_type}") + + +def get_actions_names(message_type: str): + message_actions = ACTION_CLASSES.get(message_type) + if not message_actions: + raise ValueError(f"Invalid message type: {message_type}") + + return [action.name.lower() for action in message_actions] + + +def factory_message_action(message_type: str, action: str): + message_actions = ACTION_CLASSES.get(message_type) + + if message_actions: + normalized_action = action.upper() + enum_action = message_actions[normalized_action] + # logging.info(f"Message action: {enum_action}, value: {enum_action.value}") + return enum_action.value + else: + return None diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index b629bdf8b..34741522d 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -1,12 +1,9 @@ import asyncio import collections -import hashlib import logging -import os import subprocess import sys import traceback -from datetime import datetime from typing import TYPE_CHECKING import requests @@ -16,10 +13,8 @@ from nebula.core.network.discoverer import Discoverer from nebula.core.network.forwarder import Forwarder from nebula.core.network.messages import MessagesManager -from nebula.core.network.propagator import Propagator -from nebula.core.pb import nebula_pb2 from nebula.core.network.nebulamulticasting import NebulaConnectionService - +from nebula.core.network.propagator import Propagator from nebula.core.utils.helper import ( cosine_metric, euclidean_metric, @@ -79,10 +74,10 @@ def __init__(self, engine: "Engine"): self.loop = asyncio.get_event_loop() max_concurrent_tasks = 5 self.semaphore_send_model = asyncio.Semaphore(max_concurrent_tasks) - + # Connection service to communicate with 
external devices self._external_connection_service = None - + # The line below is neccesary when mobility would be set up mob = self.config.participant["mobility_args"]["mobility"] aditional_node = self.config.participant["mobility_args"]["additional_node"]["status"] @@ -93,7 +88,6 @@ def __init__(self, engine: "Engine"): else: logging.info("Deploying External Connection Service | No running") self._external_connection_service = NebulaConnectionService(self.addr) - @property def engine(self): @@ -126,7 +120,7 @@ def propagator(self): @property def mobility(self): return self._mobility - + @property def ecs(self): return self._external_connection_service @@ -146,51 +140,12 @@ async def handle_incoming_message(self, data, addr_from): await self.mm.process_message(data, addr_from) async def forward_message(self, data, addr_from): + logging.info("Forwarding message... ") await self.forwarder.forward(data, addr_from=addr_from) # generic point to handle messages - #async def handle_message(self, source, msg_type, message): async def handle_message(self, message_event): - #logging.info( - # f"πŸ” handle_{msg_type} | Received [Action {message.action}] from {source}" - #) - #try: - #await self.engine.event_manager.trigger_event(source, message) await self.engine.trigger_event(message_event) - #except Exception as e: - # logging.exception(f"πŸ” handle_{msg_type} | Error while processing: {e}") - - - async def handle_discovery_message(self, source, message): - logging.info( - f"πŸ” handle_discovery_message | Received [Action {message.action}] from {source} (network propagation)" - ) - try: - await self.engine.event_manager.trigger_event(source, message) - except Exception as e: - logging.exception(f"πŸ” handle_discovery_message | Error while processing: {e}") - - async def handle_control_message(self, source, message): - logging.info( - f"πŸ”§ handle_control_message | Received [Action {message.action}] from {source} with log {message.log}" - ) - try: - await 
self.engine.event_manager.trigger_event(source, message) - except Exception as e: - logging.exception( - f"πŸ”§ handle_control_message | Error while processing: {message.action} {message.log} | {e}" - ) - - async def handle_federation_message(self, source, message): - logging.info( - f"πŸ“ handle_federation_message | Received [Action {message.action}] from {source} with arguments {message.arguments}" - ) - try: - await self.engine.event_manager.trigger_event(source, message) - except Exception as e: - logging.exception( - f"πŸ“ handle_federation_message | Error while processing: {message.action} {message.arguments} | {e}" - ) async def handle_model_message(self, source, message): logging.info(f"πŸ€– handle_model_message | Received model from {source} with round {message.round}") @@ -257,10 +212,10 @@ async def handle_model_message(self, source, message): decoded_model, similarity=True, ) - #with open( + # with open( # f"{self.config.participant["tracking_args"]["log_dir"]}/participant_{self.id}_similarity.csv", # "a+", - #) as f: + # ) as f: # if os.stat(f"{self}/participant_{self.id}_similarity.csv").st_size == 0: # f.write( # "timestamp,source_ip,nodes,round,current_round,cosine,euclidean,minkowski,manhattan,pearson_correlation,jaccard\n" @@ -269,7 +224,9 @@ async def handle_model_message(self, source, message): # f"{datetime.now()}, {source}, {message.round}, {current_round}, {cosine_value}, {euclidean_value}, {minkowski_value}, {manhattan_value}, {pearson_correlation_value}, {jaccard_value}\n" # ) logging("Similarities between self model and model recieved...") - logging.info(f"{source}, {message.round}, {current_round}, {cosine_value}, {euclidean_value}, {minkowski_value}, {manhattan_value}, {pearson_correlation_value}, {jaccard_value}") + logging.info( + f"{source}, {message.round}, {current_round}, {cosine_value}, {euclidean_value}, {minkowski_value}, {manhattan_value}, {pearson_correlation_value}, {jaccard_value}" + ) await 
self.engine.aggregator.include_model_in_buffer( decoded_model, @@ -326,36 +283,9 @@ async def handle_model_message(self, source, message): ) return - async def handle_connection_message(self, source, message): - try: - await self.engine.event_manager.trigger_event(source, message) - except Exception as e: - logging.exception(f"πŸ”— handle_connection_message | Error while processing: {message.action} | {e}") - - async def handle_discover_message(self, source, message): - logging.info(f"πŸ” handle_discover_message | Received [Action {message.action}] from {source}") - try: - await self.engine.event_manager.trigger_event(source, message) - except Exception as e: - logging.error(f"πŸ” handle_discover_message | Error while processing: {e}") - - async def handle_offer_message(self, source, message): - logging.info(f"πŸ” handle_offer_message | Received [Action {message.action}] from {source}") - try: - await self.engine.event_manager.trigger_event(source, message) - except Exception as e: - logging.error(f"πŸ” handle_offer_message | Error while processing: {message.action} {message.arguments} | {e}") - - async def handle_link_message(self, source, message): - logging.info(f"πŸ” handle_link_message | Received [Action {message.action}] from {source}") - try: - await self.engine.event_manager.trigger_event(source, message) - except Exception as e: - logging.error(f"πŸ” handle_link_message | Error while processing: {message.action} {message.arguments} | {e}") - def create_message(self, message_type: str, action: str = "", *args, **kwargs): return self.mm.create_message(message_type, action, *args, **kwargs) - + def get_messages_events(self): return self.mm.get_messages_events() @@ -363,20 +293,20 @@ def start_external_connection_service(self): if self.ecs == None: self.ecs = NebulaConnectionService(self.addr) self.ecs.start() - + def stop_external_connection_service(self): - self.ecs.stop() - + self.ecs.stop() + def init_external_connection_service(self): 
self.start_external_connection_service() - + async def is_external_connection_service_running(self): - return self.ecs.is_running() - + return self.ecs.is_running() + async def stablish_connection_to_federation(self, msg_type="discover_join", addrs_known=None): """ - Using ExternalConnectionService to get addrs on local network, after that - stablishment of TCP connection and send the message broadcasted + Using ExternalConnectionService to get addrs on local network, after that + stablishment of TCP connection and send the message broadcasted """ addrs = [] if addrs_known == None: @@ -386,11 +316,11 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr else: logging.info("Searching federation process beginning... | Using addrs previously known") addrs = addrs_known - + msg = self.create_message("discover", msg_type) - + logging.info("Starting communications with devices found") - #TODO filtrar para para quitar las que ya son vecinos + # TODO filtrar para para quitar las que ya son vecinos for addr in addrs: await self.connect(addr, direct=False) await asyncio.sleep(1) @@ -402,7 +332,7 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr logging.info(f"Sending {msg_type} to ---> {addr}") asyncio.create_task(self.send_message(addr, msg)) await asyncio.sleep(1) - + def get_connections_lock(self): return self.connections_lock @@ -734,7 +664,7 @@ async def send_model(self, dest_addr, round, serialized_model, weight=1): logging.info( f"Sending model to {dest_addr} with round {round}: weight={weight} |Β size={sys.getsizeof(serialized_model) / (1024** 2) if serialized_model is not None else 0} MB" ) - #message = self.mm.generate_model_message(round, serialized_model, weight) + # message = self.mm.generate_model_message(round, serialized_model, weight) parameters = serialized_model message = self.create_message("model", "", round, parameters, weight) await conn.send(data=message, is_compressed=True) @@ -742,7 
+672,7 @@ async def send_model(self, dest_addr, round, serialized_model, weight=1): except Exception as e: logging.exception(f"❗️ Cannot send model to {dest_addr}: {e!s}") await self.disconnect(dest_addr, mutual_disconnection=False) - + async def send_offer_model(self, dest_addr, offer_message): async with self.semaphore_send_model: try: @@ -750,14 +680,12 @@ async def send_offer_model(self, dest_addr, offer_message): if conn is None: logging.info(f"❗️ Connection with {dest_addr} not found") return - logging.info( - f"Sending offer model to {dest_addr}" - ) + logging.info(f"Sending offer model to {dest_addr}") await conn.send(data=offer_message, is_compressed=True) logging.info(f"Offer_Model sent to {dest_addr}") except Exception as e: logging.exception(f"❗️ Cannot send model to {dest_addr}: {e!s}") - await self.disconnect(dest_addr, mutual_disconnection=False) + await self.disconnect(dest_addr, mutual_disconnection=False) async def establish_connection(self, addr, direct=True, reconnect=False): logging.info(f"πŸ”— [outgoing] Establishing connection with {addr} (direct: {direct})") @@ -776,7 +704,7 @@ async def process_establish_connection(addr, direct, reconnect): if not self.connections[addr].get_direct() and (direct == True): self.connections[addr].set_direct(direct) return True - else: + else: return False if addr in self.pending_connections: logging.info(f"πŸ”— [outgoing] Connection with {addr} is already pending") @@ -947,7 +875,7 @@ async def disconnect(self, dest_addr, mutual_disconnection=True): try: if mutual_disconnection: await self.connections[dest_addr].send( - #data=self.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.DISCONNECT) + # data=self.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.DISCONNECT) data=self.create_message("connection", "disconnect") ) await asyncio.sleep(1) @@ -961,14 +889,14 @@ async def disconnect(self, dest_addr, mutual_disconnection=True): current_connections = set(current_connections) 
logging.info(f"Current connections: {current_connections}") self.config.update_neighbors_from_config(current_connections, dest_addr) - + async def remove_temporary_connection(self, temp_addr): logging.info(f"Removing temporary conneciton:{temp_addr}..") try: await self.get_connections_lock().acquire_async() self.connections.pop(temp_addr, None) finally: - await self.get_connections_lock().release_async() + await self.get_connections_lock().release_async() async def get_all_addrs_current_connections(self, only_direct=False, only_undirected=False): try: diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index 181944f13..5beb25b4f 100755 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -1,27 +1,26 @@ +import hashlib import logging +import traceback from typing import TYPE_CHECKING +from nebula.core.network.actions import factory_message_action, get_action_name_from_value, get_actions_names from nebula.core.pb import nebula_pb2 -from nebula.core.network.actions import factory_message_action, get_actions_names, get_action_name_from_value -import hashlib -import traceback if TYPE_CHECKING: from nebula.core.network.communications import CommunicationsManager class MessagesManager: - def __init__(self, addr, config, cm: "CommunicationsManager"): self.addr = addr self.config = config self.cm = cm self._message_templates = {} - self._define_message_templates() + self._define_message_templates() def _define_message_templates(self): # Dictionary that maps message types to their required parameters and default values - self._message_templates = { + self._message_templates = { "offer": { "parameters": ["action", "n_neighbors", "loss", "parameters", "rounds", "round", "epochs"], "defaults": { @@ -29,50 +28,38 @@ def _define_message_templates(self): "rounds": 1, "round": -1, "epochs": 1, - } - }, - "connection": { - "parameters": ["action"], - "defaults": {} + }, }, + "connection": {"parameters": ["action"], "defaults": 
{}}, "discovery": { "parameters": ["action", "latitude", "longitude"], "defaults": { "latitude": 0.0, "longitude": 0.0, - } + }, }, "control": { "parameters": ["action", "log"], "defaults": { "log": "Control message", - } + }, }, "federation": { "parameters": ["action", "arguments", "round"], "defaults": { "arguments": [], "round": None, - } + }, }, "model": { "parameters": ["round", "parameters", "weight"], "defaults": { "weight": 1, - } - }, - "reputation": { - "parameters": ["reputation"], - "defaults": {} - }, - "discover": { - "parameters": ["action"], - "defaults": {} - }, - "link": { - "parameters": ["action", "addrs"], - "defaults": {} + }, }, + "reputation": {"parameters": ["reputation"], "defaults": {}}, + "discover": {"parameters": ["action"], "defaults": {}}, + "link": {"parameters": ["action", "addrs"], "defaults": {}}, # Add additional message types here } @@ -86,7 +73,7 @@ def get_messages_events(self): async def process_message(self, data, addr_from): not_processing_messages = {"control_message", "connection_message"} special_processing_messages = {"discovery_message", "federation_message", "model_message"} - + try: message_wrapper = nebula_pb2.Wrapper() message_wrapper.ParseFromString(data) @@ -94,46 +81,52 @@ async def process_message(self, data, addr_from): logging.debug(f"πŸ“₯ handle_incoming_message | Received message from {addr_from} with source {source}") if source == self.addr: return - + # Extract the active message from the oneof field message_type = message_wrapper.WhichOneof("message") - msg_name = message_type.split('_')[0] + msg_name = message_type.split("_")[0] if not message_type: logging.warning("Received message with no active field in the 'oneof'") return logging.info(f"Message type received: {message_type}") message_data = getattr(message_wrapper, message_type) - + # Not required processing messages if message_type in not_processing_messages: - #await self.cm.handle_message(source, message_type, message_data) - me = 
MessageEvent((msg_name,get_action_name_from_value(msg_name, message_data.action)), source, message_data) + # await self.cm.handle_message(source, message_type, message_data) + me = MessageEvent( + (msg_name, get_action_name_from_value(msg_name, message_data.action)), source, message_data + ) await self.cm.handle_message(me) - + # Message-specific forwarding and processing elif message_type in special_processing_messages: if await self.cm.include_received_message_hash(hashlib.md5(data).hexdigest()): # Forward the message if required if self._should_forward_message(message_type, message_wrapper): - self.cm.forward_message(data, addr_from) - + await self.cm.forward_message(data, addr_from) + if message_type == "model_message": await self.cm.handle_model_message(source, message_data) else: - #await self.cm.handle_message(source, message_type, message_data) - me = MessageEvent((msg_name,get_action_name_from_value(msg_name, message_data.action)), source, message_data) + # await self.cm.handle_message(source, message_type, message_data) + me = MessageEvent( + (msg_name, get_action_name_from_value(msg_name, message_data.action)), source, message_data + ) await self.cm.handle_message(me) # Rest of messages else: if await self.cm.include_received_message_hash(hashlib.md5(data).hexdigest()): - #await self.cm.handle_message(source, message_type, message_data) - me = MessageEvent((msg_name,get_action_name_from_value(msg_name, message_data.action)), source, message_data) + # await self.cm.handle_message(source, message_type, message_data) + me = MessageEvent( + (msg_name, get_action_name_from_value(msg_name, message_data.action)), source, message_data + ) await self.cm.handle_message(me) except Exception as e: logging.exception(f"πŸ“₯ handle_incoming_message | Error while processing: {e}") logging.exception(traceback.format_exc()) - + def _should_forward_message(self, message_type, message_wrapper): if self.cm.config.participant["device_args"]["proxy"]: return True @@ -142,52 
+135,56 @@ def _should_forward_message(self, message_type, message_wrapper): # Round -1 is the initialization round --> all nodes should receive the model if message_type == "model_message" and message_wrapper.model_message.round == -1: return True - if message_type == "federation_message" and message_wrapper.federation_message.action == nebula_pb2.FederationMessage.Action.Value("FEDERATION_START"): + if ( + message_type == "federation_message" + and message_wrapper.federation_message.action + == nebula_pb2.FederationMessage.Action.Value("FEDERATION_START") + ): return True - + def create_message(self, message_type: str, action: str = "", *args, **kwargs): - #logging.info(f"Creating message | type: {message_type}, action: {action}, positionals: {args}, explicits: {kwargs.keys()}") + # logging.info(f"Creating message | type: {message_type}, action: {action}, positionals: {args}, explicits: {kwargs.keys()}") # If an action is provided, convert it to its corresponding enum value using the factory - message_action = None + message_action = None if action: message_action = factory_message_action(message_type, action) - + # Retrieve the template for the provided message type message_template = self._message_templates.get(message_type) if not message_template: raise ValueError(f"Invalid message type '{message_type}'") - + # Extract parameters and defaults from the template template_params = message_template["parameters"] default_values: dict = message_template.get("defaults", {}) - + # Dynamically retrieve the class for the protobuf message (e.g., OfferMessage) - class_name = message_type.capitalize() + "Message" + class_name = message_type.capitalize() + "Message" message_class = getattr(nebula_pb2, class_name, None) - + if message_class is None: raise AttributeError(f"Message type {message_type} not found on the protocol") - + # Set the 'action' parameter if required and if the message_action is available if "action" in template_params and message_action is not None: 
kwargs["action"] = message_action - + # Map positional arguments to template parameters - remaining_params = [param_name for param_name in template_params if param_name not in kwargs] + remaining_params = [param_name for param_name in template_params if param_name not in kwargs] if args: - for param_name, arg_value in zip(remaining_params, args): + for param_name, arg_value in zip(remaining_params, args, strict=False): if param_name in kwargs: - continue + continue kwargs[param_name] = arg_value - + # Fill in missing parameters with their default values # logging.info(f"kwargs parameters: {kwargs.keys()}") for param_name in template_params: if param_name not in kwargs: logging.info(f"Filling parameter '{param_name}' with default value: {default_values.get(param_name)}") kwargs[param_name] = default_values.get(param_name) - - # Create an instance of the protobuf message class using the constructed kwargs + + # Create an instance of the protobuf message class using the constructed kwargs message = message_class(**kwargs) message_wrapper = nebula_pb2.Wrapper() @@ -197,8 +194,9 @@ def create_message(self, message_type: str, action: str = "", *args, **kwargs): data = message_wrapper.SerializeToString() return data + class MessageEvent: - def __init__(self, message_type, source, message): - self.source = source - self.message_type = message_type - self.message = message \ No newline at end of file + def __init__(self, message_type, source, message): + self.source = source + self.message_type = message_type + self.message = message diff --git a/nebula/node.py b/nebula/node.py index 8fca023b4..1cd037862 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -120,7 +120,7 @@ async def main(config): partition_parameter=partition_parameter, seed=42, config=config, - additional=additional_node_status + additional=additional_node_status, ) if model_name == "MLP": model = MNISTModelMLP() @@ -350,35 +350,18 @@ def randomize_value(value, variability): if additional_node_status: 
logging.info(f"Waiting for round {additional_node_round} to start") logging.info("Waiting time to start finding federation") - - # MNIST - # 385 r30 - # 615 r50 - # CIFAR - # 420 r15 - # 600 r22 - - #if config.participant["network_args"]["ip"] == "192.168.50.11": - #time.sleep(820) - + time.sleep(200) - - #if config.participant["network_args"]["ip"] == "192.168.51.11": - # logging.info("waiting 385s") - - #elif config.participant["network_args"]["ip"] == "192.168.51.12": - # logging.info("waiting 800s") - # time.sleep(800) - - #time.sleep(6000) # DEBUG purposes - #import requests - - #url = f"http://{node.config.participant['scenario_args']['controller']}/platform/{node.config.participant['scenario_args']['name']}/round" - #current_round = int(requests.get(url).json()["round"]) - #while current_round < additional_node_round: + + # time.sleep(6000) # DEBUG purposes + # import requests + + # url = f"http://{node.config.participant['scenario_args']['controller']}/platform/{node.config.participant['scenario_args']['name']}/round" + # current_round = int(requests.get(url).json()["round"]) + # while current_round < additional_node_round: # logging.info(f"Waiting for round {additional_node_round} to start") # time.sleep(10) - #logging.info(f"Round {additional_node_round} started, connecting to the network") + # logging.info(f"Round {additional_node_round} started, connecting to the network") await node._aditional_node_start() From fd1e7340bb3d11b6f5fa6f35483a8dfd61b86bc7 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 6 Feb 2025 17:05:43 +0100 Subject: [PATCH 074/233] opt_sinc++ --- nebula/core/aggregation/aggregator.py | 28 +++++++++++++++++---------- nebula/core/engine.py | 2 +- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 28842a478..89284f034 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -147,11 +147,8 @@ 
async def _add_pending_model(self, model, weight, source): if future_round < self.engine.get_round(): del self._future_models_to_aggregate[future_round] - # TODO comprobar que los q faltan no estan en futuros - if len(self.get_nodes_pending_models_to_aggregate()) >= len(self._federation_nodes): logging.info("πŸ”„ _add_pending_model | All models were added in the aggregation buffer. Run aggregation...") - # self.engine.update_sinchronized_status(True) await self._aggregation_done_lock.release_async() await self._add_model_lock.release_async() @@ -183,10 +180,6 @@ async def include_model_in_buffer(self, model, weight, source=None, round=None, logging.info( f"πŸ”„ include_model_in_buffer | Broadcasting MODELS_INCLUDED for round {self.engine.get_round()}" ) - # message = self.cm.mm.generate_federation_message( - # nebula_pb2.FederationMessage.Action.FEDERATION_MODELS_INCLUDED, - # [self.engine.get_round()], - # ) message = self.cm.create_message( "federation", "federation_models_included", [str(arg) for arg in [self.engine.get_round()]] ) @@ -254,8 +247,23 @@ async def include_next_model_in_buffer(self, model, weight, source=None, round=N await self._add_next_model_lock.acquire_async() self._future_models_to_aggregate[round].append((decoded_model, weight, source)) await self._add_next_model_lock.release_async() - # await self.aggregation_push_available() - # asyncio.create_task(self.aggregation_push_available()) + + # Verify if we are waiting an update that maybe we wont received + if self._aggregation_done_lock.locked(): + pending_nodes: set = self._federation_nodes - self.get_nodes_pending_models_to_aggregate() + if pending_nodes: + for f_round, future_updates in self._future_models_to_aggregate.items(): + for _, _, source in future_updates: + logging.info(f"a ver q haya qui {source}") + if source in pending_nodes: + logging.info( + f"Waiting update from source: {source}, but future update storaged for round: {f_round}" + ) + pending_nodes.discard(source) + + if not 
pending_nodes: + logging.info("Received advanced updates for all sources missing this round") + await self._aggregation_done_lock.release_async() def print_model_size(self, model): total_params = 0 @@ -367,7 +375,7 @@ async def aggregation_push_available(self): self.engine.update_sinchronized_status(True) self.engine.set_synchronizing_rounds(False) else: - pass + logging.info("No rounds can be pushed...") await self._push_strategy_lock.release_async() else: logging.info( diff --git a/nebula/core/engine.py b/nebula/core/engine.py index d124cc9fb..37146403f 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -158,7 +158,7 @@ def __init__( topology = self.config.participant["mobility_args"]["topology_type"] topology = topology.lower() model_handler = "std" # self.config.participant["mobility_args"]["model_handler"] - acceleration_push = "slow" # self.config.participant["mobility_args"]["push_strategy"] + acceleration_push = "fast" # self.config.participant["mobility_args"]["push_strategy"] self._node_manager = NodeManager( config.participant["mobility_args"]["additional_node"]["status"], topology, From bc6f7b8c244f1550b824ede49cb05f0ba208ec1b Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 7 Feb 2025 13:57:43 +0100 Subject: [PATCH 075/233] fix_momentum --- nebula/core/aggregation/aggregator.py | 20 ++++++++++++++++++-- nebula/core/engine.py | 12 +++++++++++- nebula/core/network/communications.py | 10 ++++++++++ nebula/core/network/connection.py | 5 +++-- 4 files changed, 42 insertions(+), 5 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 89284f034..02bcfa812 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -43,6 +43,7 @@ def __init__(self, config=None, engine=None): self._federation_nodes = set() self._waiting_global_update = False self._pending_models_to_aggregate = {} + self._pending_models_to_aggregate_lock = 
Locker(name="pending_models_to_aggregate_lock", async_lock=True) self._future_models_to_aggregate = {} self._add_model_lock = Locker(name="add_model_lock", async_lock=True) self._add_next_model_lock = Locker(name="add_next_model_lock", async_lock=True) @@ -74,7 +75,22 @@ async def update_federation_nodes(self, federation_nodes): timeout=self.config.participant["aggregator_args"]["aggregation_timeout"] ) else: - raise Exception("It is not possible to set nodes to aggregate when the aggregation is running.") + # Neighbor has been removed + if len(self._federation_nodes) - len(federation_nodes) > 0: + nodes_removed = self._federation_nodes - federation_nodes + pending_nodes = self.get_nodes_pending_models_to_aggregate() + shouldnt_waited_model = [] + shouldnt_waited_model = [source for source in nodes_removed if source in pending_nodes] + logging.info(f"Waiting models from removed neighbors: {shouldnt_waited_model}") + if shouldnt_waited_model: + await self._pending_models_to_aggregate_lock.acquire_async() + self._pending_models_to_aggregate.difference_update(shouldnt_waited_model) + await self._pending_models_to_aggregate_lock.release_async() + if self._aggregation_done_lock.locked(): + if not self._pending_models_to_aggregate: + await self._aggregation_done_lock.release_async() + else: + raise Exception("It is not possible to set nodes to aggregate when the aggregation is running.") def set_waiting_global_update(self): self._waiting_global_update = True @@ -389,7 +405,7 @@ async def aggregation_push_available(self): elif self.engine.get_synchronizing_rounds(): logging.info("❗️ Cannot analize push | already pushing rounds") await self._push_strategy_lock.release_async() - + def create_malicious_aggregator(aggregator, attack): # It creates a partial function aggregate that wraps the aggregate method of the original aggregator. 
diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 37146403f..506b96788 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -294,7 +294,11 @@ async def _connection_connect_callback(self, source, message): current_connections = await self.cm.get_addrs_current_connections(myself=True) if source not in current_connections: logging.info(f"πŸ”— handle_connection_message | Trigger | Connecting to {source}") - await self.cm.connect(source, direct=True) + #TODO remove conditional + if not source == "192.168.53.4:45003": + await self.cm.connect(source, direct=True) + else: + logging.info("### DEBUGGING ###") async def _connection_disconnect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received disconnection message from {source}") @@ -565,6 +569,12 @@ async def apply_weight_strategy(self, pending_models): return pending_models else: return pending_models + + async def update_neighbors(self, removed_neighbor_addr, neighbors, remove=False): + if self.mobility: + self.federation_nodes = neighbors + await self.nm.update_neighbors(removed_neighbor_addr, remove=remove) + await self.aggregator.update_federation_nodes(self.federation_nodes) async def update_model_learning_rate(self, new_lr): await self.trainning_in_progress_lock.acquire_async() diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 34741522d..391120e97 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -467,6 +467,16 @@ async def process_connection(reader, writer): await process_connection(reader, writer) + async def terminate_failed_reconnection(self, conn: Connection): + # Remove failed connection + connected_with = conn.addr + await self.get_connections_lock().acquire_async() + for key, val in list(self.connections.items()): + if val == conn: + del self.connections[key] + await self.get_connections_lock().release_async() + await 
self.engine.update_neighbors(connected_with, await self.get_addrs_current_connections(only_direct=True, myself=True), remove=True) + async def stop(self): logging.info("🌐 Stopping Communications Manager... [Removing connections and stopping network engine]") connections = list(self.connections.values()) diff --git a/nebula/core/network/connection.py b/nebula/core/network/connection.py index fc1a6ff6d..abcd656e4 100755 --- a/nebula/core/network/connection.py +++ b/nebula/core/network/connection.py @@ -195,7 +195,8 @@ async def reconnect(self, max_retries: int = 5, delay: int = 5) -> None: logging.exception(f"Reconnection attempt {attempt + 1} failed: {e}") await asyncio.sleep(delay) logging.error(f"Failed to reconnect to {self.addr} after {max_retries} attempts. Stopping connection...") - await self.stop() + #await self.stop() + await self.cm.terminate_failed_reconnection(self) async def send( self, @@ -292,7 +293,7 @@ async def handle_incoming_message(self) -> None: logging.info("Message handling cancelled") except ConnectionError as e: logging.exception(f"Connection closed while reading: {e}") - #await self.reconnect() + await self.reconnect() except Exception as e: logging.exception(f"Error handling incoming message: {e}") From 81edea1ed59907834ee15841115fe9777bfd3494 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 7 Feb 2025 16:07:21 +0100 Subject: [PATCH 076/233] fix_disconnect_error --- nebula/core/aggregation/aggregator.py | 1 - nebula/core/engine.py | 11 ++--------- nebula/core/neighbormanagement/nodemanager.py | 13 ++++++++++--- nebula/core/network/communications.py | 19 +++++++++++++++++-- 4 files changed, 29 insertions(+), 15 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 02bcfa812..6d8b8e389 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -270,7 +270,6 @@ async def include_next_model_in_buffer(self, model, weight, source=None, 
round=N if pending_nodes: for f_round, future_updates in self._future_models_to_aggregate.items(): for _, _, source in future_updates: - logging.info(f"a ver q haya qui {source}") if source in pending_nodes: logging.info( f"Waiting update from source: {source}, but future update storaged for round: {f_round}" diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 506b96788..e29981e63 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -294,12 +294,8 @@ async def _connection_connect_callback(self, source, message): current_connections = await self.cm.get_addrs_current_connections(myself=True) if source not in current_connections: logging.info(f"πŸ”— handle_connection_message | Trigger | Connecting to {source}") - #TODO remove conditional - if not source == "192.168.53.4:45003": - await self.cm.connect(source, direct=True) - else: - logging.info("### DEBUGGING ###") - + await self.cm.connect(source, direct=True) + async def _connection_disconnect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received disconnection message from {source}") if self.mobility: @@ -908,9 +904,6 @@ def reputation_calculation(self, aggregated_models_weights): async def send_reputation(self, malicious_nodes): logging.info(f"Sending REPUTATION to the rest of the topology: {malicious_nodes}") - # message = self.cm.mm.generate_federation_message( - # nebula_pb2.FederationMessage.Action.REPUTATION, malicious_nodes - # ) message = self.cm.create_message("federation", "reputation", arguments=[str(arg) for arg in (malicious_nodes)]) await self.cm.send_message_to_neighbors(message) diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index d00bcd5c8..81ae1ed50 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -281,7 +281,6 @@ async def currently_reestructuring(self): async def stop_not_selected_connections(self): 
try: - await asyncio.sleep(20) with self.discarded_offers_addr_lock: if len(self.discarded_offers_addr) > 0: self.discarded_offers_addr = set( @@ -364,6 +363,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove self.accept_candidates_lock.release() self.late_connection_process_lock.release() self.candidate_selector.remove_candidates() + asyncio.create_task(self.stop_connections_with_federation()) # if no candidates, repeat process else: logging.info("❗️ No Candidates found...") @@ -378,12 +378,11 @@ async def start_late_connection_process(self, connected=False, msg_type="discove ############################## async def check_robustness(self): - # TODO aΓ±adir un cd para que no se haga continuamente logging.info("πŸ”„ Analizing node network robustness...") if not self._restructure_process_lock.locked(): if not self.neighbors_left(): logging.info("No Neighbors left | reconnecting with Federation") - # await self.reconnect_to_federation() + await self.reconnect_to_federation() elif ( self.neighbor_policy.need_more_neighbors() and self.engine.get_sinchronized_status() @@ -427,3 +426,11 @@ async def upgrade_connection_robustness(self): logging.info("Reestructuring | NO Addrs availables") await self.start_late_connection_process(connected=True, msg_type="discover_nodes") self._restructure_process_lock.release() + + async def stop_connections_with_federation(self): + asyncio.sleep(120) + logging.info("### DISCONNECTING FROM FEDERATON ###") + neighbors = self.neighbor_policy.get_nodes_known(neighbors_only=True) + await self.engine.cm.add_to_blacklist(neighbors) + for n in neighbors: + await self.engine.cm.disconnect(n, mutual_disconnection=False) diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 391120e97..365583796 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -74,6 +74,9 @@ def __init__(self, engine: "Engine"): self.loop = 
asyncio.get_event_loop() max_concurrent_tasks = 5 self.semaphore_send_model = asyncio.Semaphore(max_concurrent_tasks) + + self._blacklisted_nodes = set() + self._blacklisted_nodes_lock = Locker(name="_blacklisted_nodes_lock", async_lock=True) # Connection service to communicate with external devices self._external_connection_service = None @@ -88,7 +91,7 @@ def __init__(self, engine: "Engine"): else: logging.info("Deploying External Connection Service | No running") self._external_connection_service = NebulaConnectionService(self.addr) - + @property def engine(self): return self._engine @@ -137,7 +140,11 @@ async def add_ready_connection(self, addr): self.ready_connections.add(addr) async def handle_incoming_message(self, data, addr_from): - await self.mm.process_message(data, addr_from) + self._blacklisted_nodes_lock.acquire_async() + blacklist = self._blacklisted_nodes.copy() + self._blacklisted_nodes_lock.release_async() + if not addr_from in blacklist: + await self.mm.process_message(data, addr_from) async def forward_message(self, data, addr_from): logging.info("Forwarding message... 
") @@ -289,6 +296,12 @@ def create_message(self, message_type: str, action: str = "", *args, **kwargs): def get_messages_events(self): return self.mm.get_messages_events() + async def add_to_blacklist(self, addr): + logging.info(f"Update blackList | addr listed: {addr}") + self._blacklisted_nodes_lock.acquire_async() + self._blacklisted_nodes.add(addr) + self._blacklisted_nodes_lock.release_async() + def start_external_connection_service(self): if self.ecs == None: self.ecs = NebulaConnectionService(self.addr) @@ -894,6 +907,8 @@ async def disconnect(self, dest_addr, mutual_disconnection=True): logging.exception(f"❗️ Error while disconnecting {dest_addr}: {e!s}") if dest_addr in self.connections: logging.info(f"Removing {dest_addr} from connections") + #del self.connections[dest_addr] + self.connections[dest_addr].stop() del self.connections[dest_addr] current_connections = await self.get_all_addrs_current_connections(only_direct=True) current_connections = set(current_connections) From 8c1fc0f0a6912c2d3d5b389d1adae57767a89402 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 7 Feb 2025 21:29:46 +0100 Subject: [PATCH 077/233] fix_disconnection_node --- nebula/core/aggregation/aggregator.py | 40 ++++++++++--------- nebula/core/engine.py | 2 +- nebula/core/eventmanager.py | 2 +- nebula/core/neighbormanagement/nodemanager.py | 9 +++-- nebula/core/network/communications.py | 36 +++++++++++++---- nebula/core/network/connection.py | 15 ++++++- nebula/core/network/messages.py | 2 +- nebula/node.py | 2 +- 8 files changed, 72 insertions(+), 36 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 6d8b8e389..71f8990e5 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -67,30 +67,34 @@ def run_aggregation(self, models): logging.error("Trying to aggregate models when there are no models") return None - async def update_federation_nodes(self, federation_nodes): + async 
def update_federation_nodes(self, federation_nodes: set): if not self._aggregation_done_lock.locked(): self._federation_nodes = federation_nodes self._pending_models_to_aggregate.clear() await self._aggregation_done_lock.acquire_async( timeout=self.config.participant["aggregator_args"]["aggregation_timeout"] - ) + ) else: - # Neighbor has been removed - if len(self._federation_nodes) - len(federation_nodes) > 0: - nodes_removed = self._federation_nodes - federation_nodes - pending_nodes = self.get_nodes_pending_models_to_aggregate() - shouldnt_waited_model = [] - shouldnt_waited_model = [source for source in nodes_removed if source in pending_nodes] - logging.info(f"Waiting models from removed neighbors: {shouldnt_waited_model}") - if shouldnt_waited_model: - await self._pending_models_to_aggregate_lock.acquire_async() - self._pending_models_to_aggregate.difference_update(shouldnt_waited_model) - await self._pending_models_to_aggregate_lock.release_async() - if self._aggregation_done_lock.locked(): - if not self._pending_models_to_aggregate: - await self._aggregation_done_lock.release_async() - else: - raise Exception("It is not possible to set nodes to aggregate when the aggregation is running.") + raise Exception("It is not possible to set nodes to aggregate when the aggregation is running.") + + async def notify_federation_nodes_removed(self, federation_nodes: set): + # Neighbor has been removed + logging.info(f"Updating expected updates, current models: {self._federation_nodes}, new expectation: {federation_nodes}") + if len(self._federation_nodes) - len(federation_nodes) > 0: + nodes_removed = self._federation_nodes.symmetric_difference(federation_nodes) + logging.info(f"Nodes removed from aggregation: {nodes_removed}") + pending_nodes = self.get_nodes_pending_models_to_aggregate() + logging.info(f"Pending models to aggregato: {self.get_nodes_pending_models_to_aggregate()}") + shouldnt_waited_model = [] + shouldnt_waited_model = [source for source in 
nodes_removed if source in pending_nodes] + logging.info(f"Waiting models from removed neighbors: {shouldnt_waited_model}") + if shouldnt_waited_model: + await self._pending_models_to_aggregate_lock.acquire_async() + self._pending_models_to_aggregate.difference_update(shouldnt_waited_model) + await self._pending_models_to_aggregate_lock.release_async() + if self._aggregation_done_lock.locked(): + if not self._pending_models_to_aggregate: + await self._aggregation_done_lock.release_async() def set_waiting_global_update(self): self._waiting_global_update = True diff --git a/nebula/core/engine.py b/nebula/core/engine.py index e29981e63..4beb0fa81 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -570,7 +570,7 @@ async def update_neighbors(self, removed_neighbor_addr, neighbors, remove=False) if self.mobility: self.federation_nodes = neighbors await self.nm.update_neighbors(removed_neighbor_addr, remove=remove) - await self.aggregator.update_federation_nodes(self.federation_nodes) + await self.aggregator.notify_federation_nodes_removed(self.federation_nodes) async def update_model_learning_rate(self, new_lr): await self.trainning_in_progress_lock.acquire_async() diff --git a/nebula/core/eventmanager.py b/nebula/core/eventmanager.py index 1042d4376..fd44b6534 100755 --- a/nebula/core/eventmanager.py +++ b/nebula/core/eventmanager.py @@ -51,7 +51,7 @@ async def publish(self, message_event: MessageEvent): for callback in self._subscribers[event_type]: try: - logging.info(f"EventManager | Triggering callback for event: {event_type}, from source: {message_event.source}") + #logging.info(f"EventManager | Triggering callback for event: {event_type}, from source: {message_event.source}") await callback(message_event.source, message_event.message) except Exception as e: logging.error(f"EventManager | Error in callback for event {event_type}: {e}") diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 
81ae1ed50..4847864d7 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -382,7 +382,7 @@ async def check_robustness(self): if not self._restructure_process_lock.locked(): if not self.neighbors_left(): logging.info("No Neighbors left | reconnecting with Federation") - await self.reconnect_to_federation() + #await self.reconnect_to_federation() elif ( self.neighbor_policy.need_more_neighbors() and self.engine.get_sinchronized_status() @@ -428,9 +428,10 @@ async def upgrade_connection_robustness(self): self._restructure_process_lock.release() async def stop_connections_with_federation(self): - asyncio.sleep(120) + asyncio.sleep(150) logging.info("### DISCONNECTING FROM FEDERATON ###") neighbors = self.neighbor_policy.get_nodes_known(neighbors_only=True) - await self.engine.cm.add_to_blacklist(neighbors) + for n in neighbors: + await self.engine.cm.add_to_blacklist(n) for n in neighbors: - await self.engine.cm.disconnect(n, mutual_disconnection=False) + await self.engine.cm.disconnect(n, mutual_disconnection=False, forced=True) diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 365583796..ed9d3af0d 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -301,6 +301,13 @@ async def add_to_blacklist(self, addr): self._blacklisted_nodes_lock.acquire_async() self._blacklisted_nodes.add(addr) self._blacklisted_nodes_lock.release_async() + + async def get_blacklist(self): + bl = None + self._blacklisted_nodes_lock.acquire_async() + bl = self._blacklisted_nodes.copy() + self._blacklisted_nodes_lock.release_async() + return bl def start_external_connection_service(self): if self.ecs == None: @@ -333,7 +340,6 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr msg = self.create_message("discover", msg_type) logging.info("Starting communications with devices found") - # TODO filtrar para para quitar 
las que ya son vecinos for addr in addrs: await self.connect(addr, direct=False) await asyncio.sleep(1) @@ -375,6 +381,17 @@ async def handle_connection(self, reader, writer): async def process_connection(reader, writer): try: addr = writer.get_extra_info("peername") + + + peer_ip, peer_port = addr + addr_str = f"{peer_ip}:{peer_port}" + blacklist = await self.get_blacklist() + if addr_str in blacklist: + logging.info(f"πŸ”— [incoming] Rejecting connection from {addr}, it is blacklisted.") + writer.close() + await writer.wait_closed() + return + connected_node_id = await reader.readline() connected_node_id = connected_node_id.decode("utf-8").strip() connected_node_port = addr[1] @@ -481,13 +498,8 @@ async def process_connection(reader, writer): await process_connection(reader, writer) async def terminate_failed_reconnection(self, conn: Connection): - # Remove failed connection connected_with = conn.addr - await self.get_connections_lock().acquire_async() - for key, val in list(self.connections.items()): - if val == conn: - del self.connections[key] - await self.get_connections_lock().release_async() + await self.disconnect(connected_with, mutual_disconnection=False) await self.engine.update_neighbors(connected_with, await self.get_addrs_current_connections(only_direct=True, myself=True), remove=True) async def stop(self): @@ -890,7 +902,12 @@ async def wait_for_controller(self): logging.info("Waiting for controller signal...") await asyncio.sleep(1) - async def disconnect(self, dest_addr, mutual_disconnection=True): + async def disconnect(self, dest_addr, mutual_disconnection=True, forced=False): + removed = False + + if forced: + self.add_to_blacklist(dest_addr) + logging.info(f"Trying to disconnect {dest_addr}") if dest_addr not in self.connections: logging.info(f"Connection {dest_addr} not found") @@ -910,10 +927,13 @@ async def disconnect(self, dest_addr, mutual_disconnection=True): #del self.connections[dest_addr] self.connections[dest_addr].stop() del 
self.connections[dest_addr] + removed = True current_connections = await self.get_all_addrs_current_connections(only_direct=True) current_connections = set(current_connections) logging.info(f"Current connections: {current_connections}") self.config.update_neighbors_from_config(current_connections, dest_addr) + if removed: + await self.engine.update_neighbors(dest_addr, current_connections, remove=removed) async def remove_temporary_connection(self, temp_addr): logging.info(f"Removing temporary conneciton:{temp_addr}..") diff --git a/nebula/core/network/connection.py b/nebula/core/network/connection.py index abcd656e4..773fb96d4 100755 --- a/nebula/core/network/connection.py +++ b/nebula/core/network/connection.py @@ -23,10 +23,11 @@ class MessageChunk: data: bytes is_last: bool +MAX_INCOMPLETED_RECONNECTIONS = 3 class Connection: DEFAULT_FEDERATED_ROUND = -1 - + def __init__( self, cm: "CommunicationsManager", @@ -75,6 +76,8 @@ def __init__( self.HEADER_SIZE = 21 self.MAX_CHUNK_SIZE = 1024 # 1 KB self.BUFFER_SIZE = 1024 # 1 KB + + self.incompleted_reconnections = 0 logging.info( f"Connection [established]: {self.addr} (id: {self.id}) (active: {self.active}) (direct: {self.direct})" @@ -177,10 +180,16 @@ async def stop(self): await self.writer.wait_closed() async def reconnect(self, max_retries: int = 5, delay: int = 5) -> None: + if self.incompleted_reconnections == MAX_INCOMPLETED_RECONNECTIONS: + logging.info(f"Reconnection failed...") + await self.cm.terminate_failed_reconnection(self) + return for attempt in range(max_retries): try: logging.info(f"Attempting to reconnect to {self.addr} (attempt {attempt + 1}/{max_retries})") await self.cm.connect(self.addr) + await asyncio.sleep(1) + self.read_task = asyncio.create_task( self.handle_incoming_message(), name=f"Connection {self.addr} reader", @@ -286,13 +295,15 @@ async def handle_incoming_message(self) -> None: chunk_data = await self._read_chunk(reusable_buffer) self._store_chunk(message_id, chunk_index, 
chunk_data, is_last_chunk) # logging.debug(f"Received chunk {chunk_index} of message {message_id.hex()} | size: {len(chunk_data)} bytes") - + # Active connection without fails + self.incompleted_reconnections = 0 if is_last_chunk: await self._process_complete_message(message_id) except asyncio.CancelledError: logging.info("Message handling cancelled") except ConnectionError as e: logging.exception(f"Connection closed while reading: {e}") + self.incompleted_reconnections += 1 await self.reconnect() except Exception as e: logging.exception(f"Error handling incoming message: {e}") diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index 5beb25b4f..5953d3551 100755 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -88,7 +88,7 @@ async def process_message(self, data, addr_from): if not message_type: logging.warning("Received message with no active field in the 'oneof'") return - logging.info(f"Message type received: {message_type}") + #logging.info(f"Message type received: {message_type}") message_data = getattr(message_wrapper, message_type) diff --git a/nebula/node.py b/nebula/node.py index 1cd037862..fdf9f74d4 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -351,7 +351,7 @@ def randomize_value(value, variability): logging.info(f"Waiting for round {additional_node_round} to start") logging.info("Waiting time to start finding federation") - time.sleep(200) + time.sleep(150) # time.sleep(6000) # DEBUG purposes # import requests From e6413cdd909e3ac640beaf656ebe7072df286ba7 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 7 Feb 2025 22:31:50 +0100 Subject: [PATCH 078/233] fix_TCP_temporary_port --- nebula/core/aggregation/aggregator.py | 3 ++- nebula/core/neighbormanagement/nodemanager.py | 2 +- nebula/core/network/communications.py | 25 ++++++++++--------- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 
71f8990e5..579de0b12 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -90,7 +90,8 @@ async def notify_federation_nodes_removed(self, federation_nodes: set): logging.info(f"Waiting models from removed neighbors: {shouldnt_waited_model}") if shouldnt_waited_model: await self._pending_models_to_aggregate_lock.acquire_async() - self._pending_models_to_aggregate.difference_update(shouldnt_waited_model) + for swm in shouldnt_waited_model: + self._pending_models_to_aggregate.pop(swm) await self._pending_models_to_aggregate_lock.release_async() if self._aggregation_done_lock.locked(): if not self._pending_models_to_aggregate: diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 4847864d7..3a341b4ab 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -428,7 +428,7 @@ async def upgrade_connection_robustness(self): self._restructure_process_lock.release() async def stop_connections_with_federation(self): - asyncio.sleep(150) + asyncio.sleep(400) logging.info("### DISCONNECTING FROM FEDERATON ###") neighbors = self.neighbor_policy.get_nodes_known(neighbors_only=True) for n in neighbors: diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index ed9d3af0d..7e7c3ae35 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -296,6 +296,7 @@ def create_message(self, message_type: str, action: str = "", *args, **kwargs): def get_messages_events(self): return self.mm.get_messages_events() + #TODO limpiar la blacklist periodicamente async def add_to_blacklist(self, addr): logging.info(f"Update blackList | addr listed: {addr}") self._blacklisted_nodes_lock.acquire_async() @@ -381,17 +382,7 @@ async def handle_connection(self, reader, writer): async def process_connection(reader, writer): try: addr = writer.get_extra_info("peername") - - - 
peer_ip, peer_port = addr - addr_str = f"{peer_ip}:{peer_port}" - blacklist = await self.get_blacklist() - if addr_str in blacklist: - logging.info(f"πŸ”— [incoming] Rejecting connection from {addr}, it is blacklisted.") - writer.close() - await writer.wait_closed() - return - + connected_node_id = await reader.readline() connected_node_id = connected_node_id.decode("utf-8").strip() connected_node_port = addr[1] @@ -404,6 +395,16 @@ async def process_connection(reader, writer): logging.info( f"πŸ”— [incoming] Connection from {addr} - {connection_addr} [id {connected_node_id} | port {connected_node_port} | direct {direct}] (incoming)" ) + + blacklist = await self.get_blacklist() + if blacklist: + logging.info(f"blacklist: {blacklist}, source trying to connect: {connection_addr}") + + if connected_node_id in blacklist: + logging.info(f"πŸ”— [incoming] Rejecting connection from {connection_addr}, it is blacklisted.") + writer.close() + await writer.wait_closed() + return if self.id == connected_node_id: logging.info("πŸ”— [incoming] Connection with yourself is not allowed") @@ -500,7 +501,7 @@ async def process_connection(reader, writer): async def terminate_failed_reconnection(self, conn: Connection): connected_with = conn.addr await self.disconnect(connected_with, mutual_disconnection=False) - await self.engine.update_neighbors(connected_with, await self.get_addrs_current_connections(only_direct=True, myself=True), remove=True) + #await self.engine.update_neighbors(connected_with, await self.get_addrs_current_connections(only_direct=True, myself=True), remove=True) async def stop(self): logging.info("🌐 Stopping Communications Manager... 
[Removing connections and stopping network engine]") From 85fabd2070dc4168bc03631b101c16cc9b96b0ab Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 7 Feb 2025 22:39:48 +0100 Subject: [PATCH 079/233] fix_notself_agg --- nebula/core/aggregation/aggregator.py | 2 +- nebula/core/network/communications.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 579de0b12..57acb9bca 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -84,7 +84,7 @@ async def notify_federation_nodes_removed(self, federation_nodes: set): nodes_removed = self._federation_nodes.symmetric_difference(federation_nodes) logging.info(f"Nodes removed from aggregation: {nodes_removed}") pending_nodes = self.get_nodes_pending_models_to_aggregate() - logging.info(f"Pending models to aggregato: {self.get_nodes_pending_models_to_aggregate()}") + logging.info(f"Pending models to aggregate: {self.get_nodes_pending_models_to_aggregate()}") shouldnt_waited_model = [] shouldnt_waited_model = [source for source in nodes_removed if source in pending_nodes] logging.info(f"Waiting models from removed neighbors: {shouldnt_waited_model}") diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 7e7c3ae35..475c1e22b 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -934,6 +934,7 @@ async def disconnect(self, dest_addr, mutual_disconnection=True, forced=False): logging.info(f"Current connections: {current_connections}") self.config.update_neighbors_from_config(current_connections, dest_addr) if removed: + current_connections = await self.get_addrs_current_connections(only_direct=True, myself=True) await self.engine.update_neighbors(dest_addr, current_connections, remove=removed) async def remove_temporary_connection(self, temp_addr): From 4ecb5bd78b25024b9d598a4acd5a2b6b13267554 Mon 
Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 7 Feb 2025 23:38:40 +0100 Subject: [PATCH 080/233] fix_tcp_ports --- nebula/core/aggregation/aggregator.py | 3 +++ nebula/core/network/communications.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 57acb9bca..90f6ad380 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -79,6 +79,9 @@ async def update_federation_nodes(self, federation_nodes: set): async def notify_federation_nodes_removed(self, federation_nodes: set): # Neighbor has been removed + #TODO revisar no esta bien + # nodes_removed tiene el valor correcto del nodo q no hay q esperar, + # falta actualizar federation_nodes para eliminar los que ya no se deben esperar logging.info(f"Updating expected updates, current models: {self._federation_nodes}, new expectation: {federation_nodes}") if len(self._federation_nodes) - len(federation_nodes) > 0: nodes_removed = self._federation_nodes.symmetric_difference(federation_nodes) diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 475c1e22b..8ca9ee9b3 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -399,8 +399,8 @@ async def process_connection(reader, writer): blacklist = await self.get_blacklist() if blacklist: logging.info(f"blacklist: {blacklist}, source trying to connect: {connection_addr}") - - if connected_node_id in blacklist: + + if connection_addr in blacklist: logging.info(f"πŸ”— [incoming] Rejecting connection from {connection_addr}, it is blacklisted.") writer.close() await writer.wait_closed() From 2275f651e68faa180b22220920518ee62e1e8f13 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sat, 8 Feb 2025 12:51:50 +0100 Subject: [PATCH 081/233] feat_node_disconnection --- nebula/core/aggregation/aggregator.py | 25 +++++++--------- nebula/core/engine.py 
| 1 - nebula/core/eventmanager.py | 1 + nebula/core/neighbormanagement/nodemanager.py | 12 +++++--- nebula/core/network/communications.py | 11 +++---- nebula/core/network/connection.py | 29 +++++++++++++++---- nebula/core/network/nebulamulticasting.py | 2 +- 7 files changed, 50 insertions(+), 31 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 90f6ad380..048b59e4d 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -78,28 +78,27 @@ async def update_federation_nodes(self, federation_nodes: set): raise Exception("It is not possible to set nodes to aggregate when the aggregation is running.") async def notify_federation_nodes_removed(self, federation_nodes: set): - # Neighbor has been removed - #TODO revisar no esta bien - # nodes_removed tiene el valor correcto del nodo q no hay q esperar, - # falta actualizar federation_nodes para eliminar los que ya no se deben esperar - logging.info(f"Updating expected updates, current models: {self._federation_nodes}, new expectation: {federation_nodes}") + # Neighbor has been removed if len(self._federation_nodes) - len(federation_nodes) > 0: nodes_removed = self._federation_nodes.symmetric_difference(federation_nodes) logging.info(f"Nodes removed from aggregation: {nodes_removed}") - pending_nodes = self.get_nodes_pending_models_to_aggregate() - logging.info(f"Pending models to aggregate: {self.get_nodes_pending_models_to_aggregate()}") + pending_nodes = (self._federation_nodes - self.get_nodes_pending_models_to_aggregate()) + #logging.info(f"Pending models to aggregate: {pending_nodes}") shouldnt_waited_model = [] shouldnt_waited_model = [source for source in nodes_removed if source in pending_nodes] logging.info(f"Waiting models from removed neighbors: {shouldnt_waited_model}") if shouldnt_waited_model: - await self._pending_models_to_aggregate_lock.acquire_async() for swm in shouldnt_waited_model: - 
self._pending_models_to_aggregate.pop(swm) - await self._pending_models_to_aggregate_lock.release_async() + logging.info(f"Removing model from waiting: {swm}") + pending_nodes.discard(swm) if self._aggregation_done_lock.locked(): - if not self._pending_models_to_aggregate: + if not pending_nodes: + logging.info("No model updates required left | releasing aggregation lock...") + self._federation_nodes = federation_nodes await self._aggregation_done_lock.release_async() + self._federation_nodes = federation_nodes + def set_waiting_global_update(self): self._waiting_global_update = True @@ -279,9 +278,7 @@ async def include_next_model_in_buffer(self, model, weight, source=None, round=N for f_round, future_updates in self._future_models_to_aggregate.items(): for _, _, source in future_updates: if source in pending_nodes: - logging.info( - f"Waiting update from source: {source}, but future update storaged for round: {f_round}" - ) + #logging.info(f"Waiting update from source: {source}, but future update storaged for round: {f_round}") pending_nodes.discard(source) if not pending_nodes: diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 4beb0fa81..fcf9af0ec 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -540,7 +540,6 @@ def register_message_events_callbacks(self): self.event_manager.subscribe((event_type, action), method) async def trigger_event(self, message_event): - logging.info(f"Publishing MessageEvent: {message_event.message_type}") await self.event_manager.publish(message_event) async def _aditional_node_start(self): diff --git a/nebula/core/eventmanager.py b/nebula/core/eventmanager.py index fd44b6534..72d24ffaf 100755 --- a/nebula/core/eventmanager.py +++ b/nebula/core/eventmanager.py @@ -44,6 +44,7 @@ def subscribe(self, event_type: tuple[str,str], callback: callable): async def publish(self, message_event: MessageEvent): """Trigger all callbacks registered for a specific event type.""" + #logging.info(f"Publishing MessageEvent: 
{message_event.message_type}") event_type = message_event.message_type if event_type not in self._subscribers: logging.error(f"EventManager | No subscribers for event: {event_type}") diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 3a341b4ab..b8bbdf703 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -49,7 +49,7 @@ def __init__( self.recieve_offer_timer = 5 self._restructure_process_lock = Locker(name="restructure_process_lock") self.restructure = False - self._restructure_cooldown = RESTRUCTURE_COOLDOWN + self._restructure_cooldown = 0 self.discarded_offers_addr_lock = Locker(name="discarded_offers_addr_lock") self.discarded_offers_addr = [] self._push_acceleration = push_acceleration @@ -379,9 +379,12 @@ async def start_late_connection_process(self, connected=False, msg_type="discove async def check_robustness(self): logging.info("πŸ”„ Analizing node network robustness...") + logging.info(f"Synchronization status: {self.engine.get_sinchronized_status()} | got neighbors: {await self.neighbors_left()}") if not self._restructure_process_lock.locked(): - if not self.neighbors_left(): + if not await self.neighbors_left(): logging.info("No Neighbors left | reconnecting with Federation") + #TODO comprobar q funcione correctamente + #TODO actualizar estado a desincronizado #await self.reconnect_to_federation() elif ( self.neighbor_policy.need_more_neighbors() @@ -390,7 +393,8 @@ async def check_robustness(self): ): logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") self._update_restructure_cooldown() - asyncio.create_task(self.upgrade_connection_robustness()) + #TODO comprobar q los posibles vecinos no sean nodos de los que recientemente te has desconectado + #asyncio.create_task(self.upgrade_connection_robustness()) else: if not self.engine.get_sinchronized_status(): logging.info("Device not 
synchronized with federation") @@ -428,7 +432,7 @@ async def upgrade_connection_robustness(self): self._restructure_process_lock.release() async def stop_connections_with_federation(self): - asyncio.sleep(400) + await asyncio.sleep(100) logging.info("### DISCONNECTING FROM FEDERATON ###") neighbors = self.neighbor_policy.get_nodes_known(neighbors_only=True) for n in neighbors: diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 8ca9ee9b3..98402d38b 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -150,7 +150,6 @@ async def forward_message(self, data, addr_from): logging.info("Forwarding message... ") await self.forwarder.forward(data, addr_from=addr_from) - # generic point to handle messages async def handle_message(self, message_event): await self.engine.trigger_event(message_event) @@ -916,7 +915,6 @@ async def disconnect(self, dest_addr, mutual_disconnection=True, forced=False): try: if mutual_disconnection: await self.connections[dest_addr].send( - # data=self.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.DISCONNECT) data=self.create_message("connection", "disconnect") ) await asyncio.sleep(1) @@ -926,9 +924,12 @@ async def disconnect(self, dest_addr, mutual_disconnection=True, forced=False): if dest_addr in self.connections: logging.info(f"Removing {dest_addr} from connections") #del self.connections[dest_addr] - self.connections[dest_addr].stop() - del self.connections[dest_addr] - removed = True + try: + removed = True + await self.connections[dest_addr].stop() + del self.connections[dest_addr] + except Exception as e: + logging.exception(f"❗️ Error while removing connection {dest_addr}: {e!s}") current_connections = await self.get_all_addrs_current_connections(only_direct=True) current_connections = set(current_connections) logging.info(f"Current connections: {current_connections}") diff --git a/nebula/core/network/connection.py 
b/nebula/core/network/connection.py index 773fb96d4..21a082181 100755 --- a/nebula/core/network/connection.py +++ b/nebula/core/network/connection.py @@ -78,6 +78,7 @@ def __init__( self.BUFFER_SIZE = 1024 # 1 KB self.incompleted_reconnections = 0 + self.forced_disconnection = False logging.info( f"Connection [established]: {self.addr} (id: {self.id}) (active: {self.active}) (direct: {self.direct})" @@ -166,6 +167,7 @@ async def start(self): async def stop(self): logging.info(f"❗️ Connection [stopped]: {self.addr} (id: {self.id})") + self.forced_disconnection = True tasks = [self.read_task, self.process_task] for task in tasks: if task is not None: @@ -176,14 +178,23 @@ async def stop(self): logging.exception(f"❗️ {self} cancelled...") if self.writer is not None: - self.writer.close() - await self.writer.wait_closed() + try: + self.writer.close() + await self.writer.wait_closed() + except Exception as e: + logging.exception(f"❗️ Error ocurred when closing pipe: {e}") async def reconnect(self, max_retries: int = 5, delay: int = 5) -> None: + if self.forced_disconnection: + return + + self.incompleted_reconnections += 1 if self.incompleted_reconnections == MAX_INCOMPLETED_RECONNECTIONS: - logging.info(f"Reconnection failed...") + logging.info(f"Reconnection with {self.addr} failed...") + self.forced_disconnection = True await self.cm.terminate_failed_reconnection(self) return + for attempt in range(max_retries): try: logging.info(f"Attempting to reconnect to {self.addr} (attempt {attempt + 1}/{max_retries})") @@ -198,7 +209,8 @@ async def reconnect(self, max_retries: int = 5, delay: int = 5) -> None: self.process_message_queue(), name=f"Connection {self.addr} processor", ) - logging.info(f"Reconnected to {self.addr}") + if not self.forced_disconnection: + logging.info(f"Reconnected to {self.addr}") return except Exception as e: logging.exception(f"Reconnection attempt {attempt + 1} failed: {e}") @@ -303,10 +315,12 @@ async def handle_incoming_message(self) -> None: 
logging.info("Message handling cancelled") except ConnectionError as e: logging.exception(f"Connection closed while reading: {e}") - self.incompleted_reconnections += 1 - await self.reconnect() except Exception as e: logging.exception(f"Error handling incoming message: {e}") + except BrokenPipeError: + logging.exception(f"Error handling incoming message: {e}") + finally: + await self.reconnect() async def _read_exactly(self, num_bytes: int, max_retries: int = 3) -> bytes: data = b"" @@ -324,6 +338,9 @@ async def _read_exactly(self, num_bytes: int, max_retries: int = 3) -> bytes: if _ == max_retries - 1: raise logging.warning(f"Retrying read after IncompleteReadError: {e}") + except BrokenPipeError as e: + if not self.forced_disconnection: + logging.exception(f"Broken PIPE while reading: {e}") raise RuntimeError("Max retries reached in _read_exactly") def _parse_header(self, header: bytes) -> tuple[bytes, int, bool]: diff --git a/nebula/core/network/nebulamulticasting.py b/nebula/core/network/nebulamulticasting.py index d8a81f6e9..e573ca7bf 100644 --- a/nebula/core/network/nebulamulticasting.py +++ b/nebula/core/network/nebulamulticasting.py @@ -165,7 +165,7 @@ def start(self): self.server.start() def stop(self): - self.server.stop + self.server.stop() def is_running(self): if self.server: From cef7757023372ceffe43f11c493f339e7798b730 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 10 Feb 2025 14:54:17 +0100 Subject: [PATCH 082/233] feat_blacklist --- nebula/core/neighbormanagement/nodemanager.py | 38 ++--- nebula/core/network/blacklist.py | 148 ++++++++++++++++++ nebula/core/network/communications.py | 71 ++++++--- 3 files changed, 215 insertions(+), 42 deletions(-) create mode 100644 nebula/core/network/blacklist.py diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index b8bbdf703..5e46724c7 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ 
b/nebula/core/neighbormanagement/nodemanager.py @@ -91,6 +91,8 @@ def _update_restructure_cooldown(self): self._restructure_cooldown = (self._restructure_cooldown + 1) % RESTRUCTURE_COOLDOWN def _restructure_available(self): + if self._restructure_cooldown: + logging.info("Reestructure on cooldown") return self._restructure_cooldown == 0 def get_push_acceleration(self): @@ -156,7 +158,7 @@ def late_config(self): ) ############################## - # FAST REBOOT # + # WEIGHT STRATEGIES # ############################## async def update_learning_rate(self, new_lr): @@ -384,8 +386,9 @@ async def check_robustness(self): if not await self.neighbors_left(): logging.info("No Neighbors left | reconnecting with Federation") #TODO comprobar q funcione correctamente - #TODO actualizar estado a desincronizado - #await self.reconnect_to_federation() + self.engine.update_sinchronized_status(False) + await asyncio.sleep(120) + await self.reconnect_to_federation() elif ( self.neighbor_policy.need_more_neighbors() and self.engine.get_sinchronized_status() @@ -393,11 +396,14 @@ async def check_robustness(self): ): logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") self._update_restructure_cooldown() - #TODO comprobar q los posibles vecinos no sean nodos de los que recientemente te has desconectado - #asyncio.create_task(self.upgrade_connection_robustness()) + possible_neighbors = self.neighbor_policy.get_nodes_known(neighbors_too=False) + possible_neighbors = await self.engine.cm.apply_restrictions(possible_neighbors) + if not possible_neighbors: + logging.info("All possible neighbors using nodes known are restricted...") + #asyncio.create_task(self.upgrade_connection_robustness(possible_neighbors)) else: if not self.engine.get_sinchronized_status(): - logging.info("Device not synchronized with federation") + logging.info("Device not synchronized with federation") else: logging.info("Sufficient Robustness | no actions required") else: @@ 
-406,26 +412,22 @@ async def check_robustness(self): async def reconnect_to_federation(self): # If we got some refs, try to reconnect to them self._restructure_process_lock.acquire() - if self.neighbor_policy.get_nodes_known() > 0: + await self.engine.cm.clear_restrictions() + if len(self.neighbor_policy.get_nodes_known()) > 0: logging.info("Reconnecting | Addrs availables") - await self.start_late_connection_process( - connected=False, msg_type="discover_nodes", addrs_known=self.neighbor_policy.get_nodes_known() - ) - # Otherwise stablish connection to federation sending discover nodes instead of join + await self.start_late_connection_process(connected=False, msg_type="discover_nodes", addrs_known=self.neighbor_policy.get_nodes_known()) else: logging.info("Reconnecting | NO Addrs availables") await self.start_late_connection_process(connected=False, msg_type="discover_nodes") self._restructure_process_lock.release() - async def upgrade_connection_robustness(self): + async def upgrade_connection_robustness(self, possible_neighbors): self._restructure_process_lock.acquire() - addrs_to_connect = self.neighbor_policy.get_nodes_known(neighbors_too=False) + #addrs_to_connect = self.neighbor_policy.get_nodes_known(neighbors_too=False) # If we got some refs, try to connect to them - if len(addrs_to_connect) > 0: - logging.info(f"Reestructuring | Addrs availables | addr list: {addrs_to_connect}") - await self.start_late_connection_process( - connected=True, msg_type="discover_nodes", addrs_known=addrs_to_connect - ) + if len(possible_neighbors) > 0: + logging.info(f"Reestructuring | Addrs availables | addr list: {possible_neighbors}") + await self.start_late_connection_process(connected=True, msg_type="discover_nodes", addrs_known=possible_neighbors) else: logging.info("Reestructuring | NO Addrs availables") await self.start_late_connection_process(connected=True, msg_type="discover_nodes") diff --git a/nebula/core/network/blacklist.py b/nebula/core/network/blacklist.py 
new file mode 100644 index 000000000..32b2e398f --- /dev/null +++ b/nebula/core/network/blacklist.py @@ -0,0 +1,148 @@ +import asyncio +import logging +import time +from nebula.core.utils.locker import Locker + +BLACKLIST_EXPIRATION_TIME = 240 +RECENTLY_DISCONNECTED_EXPIRE_TIME = 60 + +class BlackList: + def __init__( + self, + max_time_listed = BLACKLIST_EXPIRATION_TIME + ): + self._max_time_listed = max_time_listed + self._blacklisted_nodes: dict = {} + self._recently_disconnected: set = set() # para no inentar coenctarse a recently disconnected + self._recently_disconnected_lock = Locker(name="recently_disconnected_lock", async_lock=True) + self._blacklisted_nodes_lock = Locker(name="blacklisted_nodes_lock", async_lock=True) + self._bl_cleaner_running = False + self._blacklist_cleaner_wake_up = asyncio.Event() + self._running = False + + async def apply_restrictions(self, nodes) -> set | None: + nodes_allowed = await self.verify_allowed_nodes(nodes) + if nodes_allowed: + nodes_allowed = await self.verify_not_recently_disc(nodes_allowed) + return nodes_allowed + + async def clear_restrictions(self): + await self.clear_blacklist() + await self.clear_recently_disconected() + + """ ############################## + # BLACKLIST # + ############################## + """ + + async def add_to_blacklist(self, addr): + logging.info(f"Update blackList | addr listed: {addr}") + await self._blacklisted_nodes_lock.acquire_async() + expiration_time = time.time() + self._blacklisted_nodes[addr] = expiration_time + if not self._running: + self._running = True + asyncio.create_task(self._start_blacklist_cleaner()) + await self._blacklisted_nodes_lock.release_async() + + async def get_blacklist(self) -> set: + bl = None + await self._blacklisted_nodes_lock.acquire_async() + if self._blacklisted_nodes: + bl = set(self._blacklisted_nodes.keys()) + await self._blacklisted_nodes_lock.release_async() + return bl + + async def clear_blacklist(self): + await 
self._blacklisted_nodes_lock.acquire_async() + self._blacklisted_nodes.clear() + await self._blacklisted_nodes_lock.release_async() + + async def _start_blacklist_cleaner(self): + while self._running: + await self._blacklist_clean() + await self._blacklist_cleaner_wait() + + async def _blacklist_clean(self): + await self._blacklisted_nodes_lock.acquire_async() + logging.info("BlackList cleaner has waken up") + now = time.time() + new_bl = {} + + for addr,timer in self._blacklisted_nodes.items(): + if timer + self._max_time_listed >= now: + new_bl[addr] = timer + else: + logging.info(f"Removing addr{addr} from blacklisted nodes...") + + self._blacklisted_nodes = new_bl + if not new_bl: + self._running = False + await self._blacklisted_nodes_lock.release_async() + + async def _blacklist_cleaner_wait(self): + try: + await asyncio.sleep(self._max_time_listed) + except asyncio.TimeoutError: + pass + + async def node_in_blacklist(self, addr): + blacklisted = False + await self._blacklisted_nodes_lock.acquire_async() + if self._blacklisted_nodes: + blacklisted = addr in self._blacklisted_nodes.keys() + await self._blacklisted_nodes_lock.release_async() + return blacklisted + + async def verify_allowed_nodes(self, nodes) -> set | None: + if not nodes: + return None + nodes_not_listed = nodes + await self._blacklisted_nodes_lock.acquire_async() + blacklist = self._blacklisted_nodes + if blacklist: + nodes_not_listed = set(nodes).difference_update(blacklist) + await self._blacklisted_nodes_lock.release_async() + return nodes_not_listed + + """ ############################## + # RECENTLY DISCONNECTED # + ############################## + """ + + async def add_recently_disconnected(self, addr): + logging.info(f"Recently disconnected from: {addr}") + self._recently_disconnected_lock.acquire_async() + self._recently_disconnected.add(addr) + self._recently_disconnected_lock.release_async() + asyncio.create_task(self._remove_recently_disc(addr)) + + async def 
clear_recently_disconected(self): + self._recently_disconnected_lock.acquire_async() + self._recently_disconnected.clear() + self._recently_disconnected_lock.release_async() + + async def get_recently_disconnected(self): + rd = None + self._recently_disconnected_lock.acquire_async() + rd = self._recently_disconnected.copy() + self._recently_disconnected_lock.release_async() + return rd + + async def _remove_recently_disc(self,addr): + await asyncio.sleep(RECENTLY_DISCONNECTED_EXPIRE_TIME) + self._recently_disconnected_lock.acquire_async() + self._recently_disconnected.discard(addr) + logging.info(f"Recently disconnection timeout expired for souce: {addr}") + self._recently_disconnected_lock.release_async() + + async def verify_not_recently_disc(self, nodes) -> set | None: + if not nodes: + return None + nodes_not_listed = nodes + self._recently_disconnected_lock.acquire_async() + rec_disc = self._recently_disconnected + if rec_disc: + nodes_not_listed = set(nodes).difference_update(rec_disc) + self._recently_disconnected_lock.release_async() + return nodes_not_listed \ No newline at end of file diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 98402d38b..8064915bc 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -4,6 +4,7 @@ import subprocess import sys import traceback +import time from typing import TYPE_CHECKING import requests @@ -25,11 +26,15 @@ ) from nebula.core.utils.locker import Locker +from nebula.core.network.blacklist import BlackList + if TYPE_CHECKING: from nebula.core.engine import Engine +BLACKLIST_EXPIRATION_TIME = 60 class CommunicationsManager: + def __init__(self, engine: "Engine"): logging.info("🌐 Initializing Communications Manager") self._engine = engine @@ -75,8 +80,7 @@ def __init__(self, engine: "Engine"): max_concurrent_tasks = 5 self.semaphore_send_model = asyncio.Semaphore(max_concurrent_tasks) - self._blacklisted_nodes = set() - 
self._blacklisted_nodes_lock = Locker(name="_blacklisted_nodes_lock", async_lock=True) + self._blacklist = BlackList() # Connection service to communicate with external devices self._external_connection_service = None @@ -127,6 +131,10 @@ def mobility(self): @property def ecs(self): return self._external_connection_service + + @property + def bl(self): + return self._blacklist async def check_federation_ready(self): # Check if all my connections are in ready_connections @@ -140,10 +148,7 @@ async def add_ready_connection(self, addr): self.ready_connections.add(addr) async def handle_incoming_message(self, data, addr_from): - self._blacklisted_nodes_lock.acquire_async() - blacklist = self._blacklisted_nodes.copy() - self._blacklisted_nodes_lock.release_async() - if not addr_from in blacklist: + if not await self.bl.node_in_blacklist(addr_from): await self.mm.process_message(data, addr_from) async def forward_message(self, data, addr_from): @@ -295,19 +300,31 @@ def create_message(self, message_type: str, action: str = "", *args, **kwargs): def get_messages_events(self): return self.mm.get_messages_events() - #TODO limpiar la blacklist periodicamente + """ ############################## + # BLACKLIST # + ############################## + """ + + async def add_to_recently_disconnected(self, addr): + await self.bl.add_recently_disconnected(addr) + async def add_to_blacklist(self, addr): - logging.info(f"Update blackList | addr listed: {addr}") - self._blacklisted_nodes_lock.acquire_async() - self._blacklisted_nodes.add(addr) - self._blacklisted_nodes_lock.release_async() + await self.bl.add_to_blacklist(addr) async def get_blacklist(self): - bl = None - self._blacklisted_nodes_lock.acquire_async() - bl = self._blacklisted_nodes.copy() - self._blacklisted_nodes_lock.release_async() - return bl + return await self.bl.get_blacklist() + + async def apply_restrictions(self, nodes): + return await self.bl.apply_restrictions(nodes) + + async def clear_restrictions(self): + 
await self.bl.clear_restrictions() + + + """ ############################### + # EXTERNAL CONNECTION SERVICE # + ############################### + """ def start_external_connection_service(self): if self.ecs == None: @@ -352,6 +369,12 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr asyncio.create_task(self.send_message(addr, msg)) await asyncio.sleep(1) + + """ ############################## + # OTHER FUNCTIONALITIES # + ############################## + """ + def get_connections_lock(self): return self.connections_lock @@ -395,15 +418,14 @@ async def process_connection(reader, writer): f"πŸ”— [incoming] Connection from {addr} - {connection_addr} [id {connected_node_id} | port {connected_node_port} | direct {direct}] (incoming)" ) - blacklist = await self.get_blacklist() + blacklist = await self.bl.get_blacklist() if blacklist: - logging.info(f"blacklist: {blacklist}, source trying to connect: {connection_addr}") - - if connection_addr in blacklist: - logging.info(f"πŸ”— [incoming] Rejecting connection from {connection_addr}, it is blacklisted.") - writer.close() - await writer.wait_closed() - return + logging.info(f"blacklist: {blacklist}, source trying to connect: {connection_addr}") + if connection_addr in blacklist: + logging.info(f"πŸ”— [incoming] Rejecting connection from {connection_addr}, it is blacklisted.") + writer.close() + await writer.wait_closed() + return if self.id == connected_node_id: logging.info("πŸ”— [incoming] Connection with yourself is not allowed") @@ -499,6 +521,7 @@ async def process_connection(reader, writer): async def terminate_failed_reconnection(self, conn: Connection): connected_with = conn.addr + await self.bl.add_recently_disconnected(connected_with) await self.disconnect(connected_with, mutual_disconnection=False) #await self.engine.update_neighbors(connected_with, await self.get_addrs_current_connections(only_direct=True, myself=True), remove=True) From c8c3656d54a98bf9b9bde5434ff23e75e5b6c62d 
Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 10 Feb 2025 16:33:53 +0100 Subject: [PATCH 083/233] fix_resinc_error --- nebula/core/aggregation/aggregator.py | 2 ++ .../neighborpolicies/fcneighborpolicy.py | 7 ++++++- nebula/core/neighbormanagement/nodemanager.py | 6 ++++-- nebula/core/network/communications.py | 2 +- 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 048b59e4d..16c303c27 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -306,6 +306,8 @@ async def aggregation_push_available(self): """ # TODO verify if an already sinchronized node gets desinchronized + # TODO sinc -> disconnect -> sinc no funciona + # TODO comprobar que se pare el proceso a mitad await self._push_strategy_lock.acquire_async() logging.info( diff --git a/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py b/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py index 9e4c1a77e..2ff6e148c 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py +++ b/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py @@ -7,6 +7,7 @@ def __init__(self): self.max_neighbors = None self.nodes_known = set() self.neighbors = set() + self.addr = None self.neighbors_lock = Locker(name="neighbors_lock") self.nodes_known_lock = Locker(name="nodes_known_lock") @@ -33,6 +34,7 @@ def set_config(self, config): self.neighbors_lock.release() for addr in config[1]: self.nodes_known.add(addr) + self.addr def accept_connection(self, source, joining=False): """ @@ -48,7 +50,8 @@ def meet_node(self, node): Update the list of nodes known on federation """ self.nodes_known_lock.acquire() - self.nodes_known.add(node) + if node != self.addr: + self.nodes_known.add(node) self.nodes_known_lock.release() def get_nodes_known(self, neighbors_too=False, neighbors_only=False): @@ -95,6 +98,8 @@ def 
_connect_to(self): return ct def update_neighbors(self, node, remove=False): + if node == self.addr: + return self.neighbors_lock.acquire() if remove: try: diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 5e46724c7..83f70a4d9 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -252,8 +252,9 @@ async def neighbors_left(self): return len(await self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False)) > 0 def meet_node(self, node): - logging.info(f"Update nodes known | addr: {node}") - self.neighbor_policy.meet_node(node) + if node != self.engine.addr: + logging.info(f"Update nodes known | addr: {node}") + self.neighbor_policy.meet_node(node) def get_nodes_known(self, neighbors_too=False): return self.neighbor_policy.get_nodes_known(neighbors_too) @@ -387,6 +388,7 @@ async def check_robustness(self): logging.info("No Neighbors left | reconnecting with Federation") #TODO comprobar q funcione correctamente self.engine.update_sinchronized_status(False) + self.set_synchronizing_rounds(True) await asyncio.sleep(120) await self.reconnect_to_federation() elif ( diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 8064915bc..055a162f7 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -351,7 +351,7 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr addrs = self.ecs.find_federation() logging.info(f"Found federation devices | addrs {addrs}") else: - logging.info("Searching federation process beginning... | Using addrs previously known") + logging.info(f"Searching federation process beginning... 
| Using addrs previously known {addrs_known}") addrs = addrs_known msg = self.create_message("discover", msg_type) From d8fbb3182fcee0ac6eaba4c9bbde3e1b4f992890 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 11 Feb 2025 13:56:10 +0100 Subject: [PATCH 084/233] fix_resinc_after_disc --- nebula/core/aggregation/aggregator.py | 33 +++++++++++++++---- nebula/core/engine.py | 4 +-- nebula/core/neighbormanagement/nodemanager.py | 13 +++++--- nebula/core/network/communications.py | 1 - 4 files changed, 37 insertions(+), 14 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 16c303c27..0d46cbae0 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -50,6 +50,7 @@ def __init__(self, config=None, engine=None): self._aggregation_done_lock = Locker(name="aggregation_done_lock", async_lock=True) self._aggregation_waiting_skip = asyncio.Event() self._push_strategy_lock = Locker(name="push_strategy_lock", async_lock=True) + self._end_round_push = 0 def __str__(self): return self.__class__.__name__ @@ -299,27 +300,42 @@ def print_model_size(self, model): total_memory_in_mb = total_memory / (1024**2) logging.info(f"print_model_size | Model size: {total_memory_in_mb} MB") + def verify_push_done(self, current_round): + logging.info("Verifying if round push is done") + current_round = self.engine.get_round() + if self.engine.get_synchronizing_rounds(): + logging.info(f"end round push: {self._end_round_push}, current round: {current_round}") + if self._end_round_push <= current_round: + logging.info("Push done...") + self.engine.set_synchronizing_rounds(False) + self._end_round_push = 0 + if len(self._future_models_to_aggregate.items()) < 2: + logging.info("Device is sinchronized") + self.engine.update_sinchronized_status(True) + else: + logging.info("Device is not sinchronized yet | more actions required...") + async def aggregation_push_available(self): """ If the node is 
not sinchronized with the federation, it may be possible to make a push and try to catch the federation asap. """ - # TODO verify if an already sinchronized node gets desinchronized # TODO sinc -> disconnect -> sinc no funciona # TODO comprobar que se pare el proceso a mitad + current_round = self.engine.get_round() + if self.engine.get_synchronizing_rounds(): + self.verify_push_done(current_round) + await self._push_strategy_lock.acquire_async() - logging.info( - f"❗️ synchronized status: {self.engine.get_sinchronized_status()} | Analizing if an aggregation push is available..." - ) + logging.info(f"❗️ synchronized status: {self.engine.get_sinchronized_status()} | Analizing if an aggregation push is available...") if ( not self.engine.get_sinchronized_status() and not self.engine.get_trainning_in_progress_lock().locked() and not self.engine.get_synchronizing_rounds() ): n_fed_nodes = len(self._federation_nodes) - current_round = self.engine.get_round() further_round = current_round logging.info( f" Pending models: {len(self.get_nodes_pending_models_to_aggregate())} | federation: {n_fed_nodes}" @@ -339,6 +355,7 @@ async def aggregation_push_available(self): await self.engine.set_pushed_done(further_round - current_round) self.engine.update_sinchronized_status(False) self.engine.set_synchronizing_rounds(True) + self._end_round_push = further_round self._aggregation_waiting_skip.set() await self._push_strategy_lock.release_async() return @@ -352,6 +369,7 @@ async def aggregation_push_available(self): await self.engine.set_pushed_done(further_round - current_round) self.engine.update_sinchronized_status(False) self.engine.set_synchronizing_rounds(True) + self._end_round_push = further_round self._aggregation_waiting_skip.set() await self._push_strategy_lock.release_async() return @@ -384,6 +402,7 @@ async def aggregation_push_available(self): self.engine.update_sinchronized_status(False) self.engine.set_synchronizing_rounds(True) await 
self.engine.set_pushed_done(further_round - current_round) + self._end_round_push = further_round self.engine.set_round(further_round) await self._add_model_lock.release_async() await self._add_next_model_lock.release_async() @@ -406,10 +425,10 @@ async def aggregation_push_available(self): await self._push_strategy_lock.release_async() else: if not self.engine.get_sinchronized_status(): - if self.engine.get_sinchronized_status(): + if self.engine.get_trainning_in_progress_lock().locked(): logging.info("❗️ Cannot analize push | Trainning in progress") elif self.engine.get_synchronizing_rounds(): - logging.info("❗️ Cannot analize push | already pushing rounds") + logging.info("❗️ Cannot analize push | Already pushing rounds") await self._push_strategy_lock.release_async() diff --git a/nebula/core/engine.py b/nebula/core/engine.py index fcf9af0ec..a134b0399 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -243,7 +243,8 @@ def update_sinchronized_status(self, status): def set_synchronizing_rounds(self, status): if self.mobility: - self.nm.set_synchronizing_rounds(not status) + logging.info(f"Set sinchronizing rounds: {status}") + self.nm.set_synchronizing_rounds(status) def set_round(self, new_round): logging.info(f"πŸ€– Update round count | from: {self.round} | to round: {new_round}") @@ -857,7 +858,6 @@ async def _additional_mobility_actions(self): if not self.mobility: return logging.info("πŸ”„ Starting additional mobility actions...") - # self.trainer.show_current_learning_rate() await self.nm.check_robustness() action = await self.nm.check_external_connection_service_status() if action: diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 83f70a4d9..0b96aad3d 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -59,6 +59,8 @@ def __init__( self._fast_reboot_status = fastreboot self._momemtum_status = momentum + self._desc_done 
= False #TODO remove + @property def engine(self): return self._engine @@ -366,7 +368,9 @@ async def start_late_connection_process(self, connected=False, msg_type="discove self.accept_candidates_lock.release() self.late_connection_process_lock.release() self.candidate_selector.remove_candidates() - asyncio.create_task(self.stop_connections_with_federation()) + if not self._desc_done: #TODO remove + self._desc_done = True + asyncio.create_task(self.stop_connections_with_federation()) # if no candidates, repeat process else: logging.info("❗️ No Candidates found...") @@ -388,7 +392,6 @@ async def check_robustness(self): logging.info("No Neighbors left | reconnecting with Federation") #TODO comprobar q funcione correctamente self.engine.update_sinchronized_status(False) - self.set_synchronizing_rounds(True) await asyncio.sleep(120) await self.reconnect_to_federation() elif ( @@ -412,9 +415,11 @@ async def check_robustness(self): logging.info("❗️ Reestructure/Reconnecting process already running...") async def reconnect_to_federation(self): - # If we got some refs, try to reconnect to them self._restructure_process_lock.acquire() - await self.engine.cm.clear_restrictions() + await self.engine.cm.clear_restrictions() + if await self.engine.cm.is_external_connection_service_running(): + self.engine.cm.stop_external_connection_service() + # If we got some refs, try to reconnect to them if len(self.neighbor_policy.get_nodes_known()) > 0: logging.info("Reconnecting | Addrs availables") await self.start_late_connection_process(connected=False, msg_type="discover_nodes", addrs_known=self.neighbor_policy.get_nodes_known()) diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 055a162f7..213799535 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -523,7 +523,6 @@ async def terminate_failed_reconnection(self, conn: Connection): connected_with = conn.addr await 
self.bl.add_recently_disconnected(connected_with) await self.disconnect(connected_with, mutual_disconnection=False) - #await self.engine.update_neighbors(connected_with, await self.get_addrs_current_connections(only_direct=True, myself=True), remove=True) async def stop(self): logging.info("🌐 Stopping Communications Manager... [Removing connections and stopping network engine]") From 6348ceeec1e288e667c62a41bad130e026d431cb Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 11 Feb 2025 15:40:46 +0100 Subject: [PATCH 085/233] fix_resinc_Node --- nebula/core/aggregation/aggregator.py | 8 ++------ nebula/core/engine.py | 1 - 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 0d46cbae0..d040836a7 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -301,10 +301,9 @@ def print_model_size(self, model): logging.info(f"print_model_size | Model size: {total_memory_in_mb} MB") def verify_push_done(self, current_round): - logging.info("Verifying if round push is done") current_round = self.engine.get_round() if self.engine.get_synchronizing_rounds(): - logging.info(f"end round push: {self._end_round_push}, current round: {current_round}") + logging.info("Verifying if round push is done") if self._end_round_push <= current_round: logging.info("Push done...") self.engine.set_synchronizing_rounds(False) @@ -321,11 +320,8 @@ async def aggregation_push_available(self): and try to catch the federation asap. 
""" # TODO verify if an already sinchronized node gets desinchronized - # TODO sinc -> disconnect -> sinc no funciona - # TODO comprobar que se pare el proceso a mitad current_round = self.engine.get_round() - if self.engine.get_synchronizing_rounds(): - self.verify_push_done(current_round) + self.verify_push_done(current_round) await self._push_strategy_lock.acquire_async() diff --git a/nebula/core/engine.py b/nebula/core/engine.py index a134b0399..d64a28208 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -426,7 +426,6 @@ async def _connection_restructure_callback(self, source, message): async def _discover_discover_join_callback(self, source, message): logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_join message from {source} ") - # TODO caso para el starter recibir antes de iniciar federacion if len(self.get_federation_nodes()) > 0: await self.trainning_in_progress_lock.acquire_async() model, rounds, round = ( From 0669f8f81387ca5d1825cebcbf4a08dea2a1ade8 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 12 Feb 2025 16:43:21 +0100 Subject: [PATCH 086/233] feat_target_attacks --- .../communications/communicationattack.py | 26 +++++++++++++++++-- .../attacks/communications/delayerattack.py | 10 ++++++- nebula/core/network/communications.py | 3 ++- 3 files changed, 35 insertions(+), 4 deletions(-) diff --git a/nebula/addons/attacks/communications/communicationattack.py b/nebula/addons/attacks/communications/communicationattack.py index 53732cd58..128ac020d 100644 --- a/nebula/addons/attacks/communications/communicationattack.py +++ b/nebula/addons/attacks/communications/communicationattack.py @@ -1,12 +1,21 @@ import logging import types from abc import abstractmethod +import random from nebula.addons.attacks.attacks import Attack class CommunicationAttack(Attack): - def __init__(self, engine, target_class, target_method, round_start_attack, round_stop_attack, decorator_args=None): + def __init__(self, engine, + 
target_class, + target_method, + round_start_attack, + round_stop_attack, + decorator_args=None, + selectivity_percentage: int = 100, + selection_interval: int = None + ): super().__init__() self.engine = engine self.target_class = target_class @@ -15,6 +24,10 @@ def __init__(self, engine, target_class, target_method, round_start_attack, roun self.round_start_attack = round_start_attack self.round_stop_attack = round_stop_attack self.original_method = getattr(target_class, target_method, None) + self.selectivity_percentage = selectivity_percentage + self.selection_interval = selection_interval + self.last_selection_round = 0 + self.targets = set() if not self.original_method: raise AttributeError(f"Method {target_method} not found in class {target_class}") @@ -23,7 +36,16 @@ def __init__(self, engine, target_class, target_method, round_start_attack, roun def decorator(self, *args): """Decorator that adds malicious behavior to the execution of the original method.""" pass - + + async def select_targets(self): + if not self.selection_interval and not self.targets: + self.targets = await self.engine.cm.get_addrs_current_connections(only_direct=True) + elif self.last_selection_round % self.selection_interval == 0: + all_nodes = await self.engine.cm.get_addrs_current_connections(only_direct=True) + num_targets = max(1, int(len(all_nodes) * (self.selectivity_percentage / 100))) + self.selected_targets = set(random.sample(all_nodes, num_targets)) + logging.info(f"Selected targets: {self.selected_targets}") + async def _inject_malicious_behaviour(self): """Inject malicious behavior into the target method.""" logging.info("Injecting malicious behavior") diff --git a/nebula/addons/attacks/communications/delayerattack.py b/nebula/addons/attacks/communications/delayerattack.py index f3352ab60..743403291 100644 --- a/nebula/addons/attacks/communications/delayerattack.py +++ b/nebula/addons/attacks/communications/delayerattack.py @@ -22,6 +22,8 @@ def __init__(self, engine, 
attack_params: dict): self.delay = int(attack_params["delay"]) round_start = int(attack_params["round_start_attack"]) round_stop = int(attack_params["round_stop_attack"]) + self.target_percentage = int(attack_params["target_percentage"]) + self.selection_interval = int(attack_params["selection_interval"]) except KeyError as e: raise ValueError(f"Missing required attack parameter: {e}") except ValueError: @@ -29,13 +31,18 @@ def __init__(self, engine, attack_params: dict): super().__init__( engine, - engine._cm._propagator, + engine._cm._propagator, #TODO modificar por send_model de communciations "propagate", round_start, round_stop, self.delay, ) + @abstractmethod + async def is_attack_selective(self): + """Obliga a todas las subclases de CommunicationAttack a implementarlo""" + return True + def decorator(self, delay: int): """ Decorator that adds a delay to the execution of the original method. @@ -50,6 +57,7 @@ def decorator(self, delay: int): def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): + await self.select_targets() logging.info(f"[DelayerAttack] Adding delay of {delay} seconds to {func.__name__}") await asyncio.sleep(delay) _, *new_args = args # Exclude self argument diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index e25e07b84..0339784ed 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -556,7 +556,8 @@ async def deploy_additional_services(self): self._generate_network_conditions() await self._forwarder.start() if self.config.participant["mobility_args"]["mobility"]: - await self._discoverer.start() + pass + #await self._discoverer.start() # await self._health.start() self._propagator.start() await self._mobility.start() From 3031cfeec4517e6dd268b7452552d19939e13346 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 13 Feb 2025 09:57:36 +0100 Subject: [PATCH 087/233] feature select changing targets --- 
.../communications/communicationattack.py | 26 ++++++++++++++----- .../attacks/communications/delayerattack.py | 25 +++++++++--------- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/nebula/addons/attacks/communications/communicationattack.py b/nebula/addons/attacks/communications/communicationattack.py index 128ac020d..9344e4e59 100644 --- a/nebula/addons/attacks/communications/communicationattack.py +++ b/nebula/addons/attacks/communications/communicationattack.py @@ -38,13 +38,24 @@ def decorator(self, *args): pass async def select_targets(self): - if not self.selection_interval and not self.targets: - self.targets = await self.engine.cm.get_addrs_current_connections(only_direct=True) - elif self.last_selection_round % self.selection_interval == 0: - all_nodes = await self.engine.cm.get_addrs_current_connections(only_direct=True) - num_targets = max(1, int(len(all_nodes) * (self.selectivity_percentage / 100))) - self.selected_targets = set(random.sample(all_nodes, num_targets)) - logging.info(f"Selected targets: {self.selected_targets}") + if self.selectivity_percentage != 100: + if self.selection_interval: + if self.last_selection_round % self.selection_interval == 0: + logging.info("Recalculating targets...") + all_nodes = await self.engine.cm.get_addrs_current_connections(only_direct=True) + num_targets = max(1, int(len(all_nodes) * (self.selectivity_percentage / 100))) + self.targets = set(random.sample(list(all_nodes), num_targets)) + elif not self.targets: + logging.info("Calculating targets...") + all_nodes = await self.engine.cm.get_addrs_current_connections(only_direct=True) + num_targets = max(1, int(len(all_nodes) * (self.selectivity_percentage / 100))) + self.targets = set(random.sample(list(all_nodes), num_targets)) + else: + logging.info("All neighbors selected as targets") + self.targets = await self.engine.cm.get_addrs_current_connections(only_direct=True) + + logging.info(f"Selected {self.selectivity_percentage}% targets from 
neighbors: {self.targets}") + self.last_selection_round+=1 async def _inject_malicious_behaviour(self): """Inject malicious behavior into the target method.""" @@ -69,5 +80,6 @@ async def attack(self): logging.info(f"[{self.__class__.__name__}] Restoring original behavior") await self._restore_original_behaviour() elif self.engine.round == self.round_start_attack: + await self.select_targets() logging.info(f"[{self.__class__.__name__}] Injecting malicious behavior") await self._inject_malicious_behaviour() diff --git a/nebula/addons/attacks/communications/delayerattack.py b/nebula/addons/attacks/communications/delayerattack.py index 743403291..6b0ab9c67 100644 --- a/nebula/addons/attacks/communications/delayerattack.py +++ b/nebula/addons/attacks/communications/delayerattack.py @@ -22,8 +22,8 @@ def __init__(self, engine, attack_params: dict): self.delay = int(attack_params["delay"]) round_start = int(attack_params["round_start_attack"]) round_stop = int(attack_params["round_stop_attack"]) - self.target_percentage = int(attack_params["target_percentage"]) - self.selection_interval = int(attack_params["selection_interval"]) + self.target_percentage = 50#int(attack_params["target_percentage"]) + self.selection_interval = 1#int(attack_params["selection_interval"]) except KeyError as e: raise ValueError(f"Missing required attack parameter: {e}") except ValueError: @@ -31,18 +31,15 @@ def __init__(self, engine, attack_params: dict): super().__init__( engine, - engine._cm._propagator, #TODO modificar por send_model de communciations - "propagate", + engine._cm, + "send_model", round_start, round_stop, self.delay, + self.target_percentage, + self.selection_interval, ) - @abstractmethod - async def is_attack_selective(self): - """Obliga a todas las subclases de CommunicationAttack a implementarlo""" - return True - def decorator(self, delay: int): """ Decorator that adds a delay to the execution of the original method. 
@@ -57,9 +54,13 @@ def decorator(self, delay: int): def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): - await self.select_targets() - logging.info(f"[DelayerAttack] Adding delay of {delay} seconds to {func.__name__}") - await asyncio.sleep(delay) + if len(args) > 1: + dest_addr = args[1] + if dest_addr in self.targets: + logging.info(f"[DelayerAttack] Delaying model propagation to {dest_addr} by {delay} seconds") + await asyncio.sleep(delay) + #logging.info(f"[DelayerAttack] Adding delay of {delay} seconds to {func.__name__}") + #await asyncio.sleep(delay) _, *new_args = args # Exclude self argument return await func(*new_args) From 43cce2b42d60b2e8ec429d1ea4ff1a796d6a913b Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 13 Feb 2025 11:17:36 +0100 Subject: [PATCH 088/233] feat standar mobility strategies --- .../candidateselection/candidateselector.py | 4 +- .../stdcandidateselector.py | 37 ++++++ .../neighborpolicies/idleneighborpolicy.py | 113 +++++++++++++++--- nebula/core/neighbormanagement/nodemanager.py | 12 +- nebula/core/network/connection.py | 6 +- 5 files changed, 145 insertions(+), 27 deletions(-) create mode 100644 nebula/core/neighbormanagement/candidateselection/stdcandidateselector.py diff --git a/nebula/core/neighbormanagement/candidateselection/candidateselector.py b/nebula/core/neighbormanagement/candidateselection/candidateselector.py index f2cb5d0e1..16546c89a 100644 --- a/nebula/core/neighbormanagement/candidateselection/candidateselector.py +++ b/nebula/core/neighbormanagement/candidateselection/candidateselector.py @@ -24,6 +24,7 @@ def any_candidate(self): pass def factory_CandidateSelector(topology) -> CandidateSelector: + from nebula.core.neighbormanagement.candidateselection.stdcandidateselector import STDandidateSelector from nebula.core.neighbormanagement.candidateselection.fccandidateselector import FCCandidateSelector from nebula.core.neighbormanagement.candidateselection.hetcandidateselector import 
HETCandidateSelector from nebula.core.neighbormanagement.candidateselection.ringcandidateselector import RINGCandidateSelector @@ -31,7 +32,8 @@ def factory_CandidateSelector(topology) -> CandidateSelector: options = { "ring": RINGCandidateSelector, "fully": FCCandidateSelector, - "random": HETCandidateSelector, + "random": STDandidateSelector, + "het": HETCandidateSelector, } cs = options.get(topology, FCCandidateSelector) diff --git a/nebula/core/neighbormanagement/candidateselection/stdcandidateselector.py b/nebula/core/neighbormanagement/candidateselection/stdcandidateselector.py new file mode 100644 index 000000000..7418e8dbb --- /dev/null +++ b/nebula/core/neighbormanagement/candidateselection/stdcandidateselector.py @@ -0,0 +1,37 @@ +from nebula.core.neighbormanagement.candidateselection.candidateselector import CandidateSelector +from nebula.core.utils.locker import Locker + +class STDandidateSelector(CandidateSelector): + + def __init__(self): + self.candidates = [] + self.candidates_lock = Locker(name="candidates_lock") + + def set_config(self, config): + pass + + def add_candidate(self, candidate): + self.candidates_lock.acquire() + self.candidates.append(candidate) + self.candidates_lock.release() + + def select_candidates(self): + """ + Select mean number of neighbors + """ + self.candidates_lock.acquire() + mean_neighbors = sum(n for n, _ in self.candidates) / len(self.candidates) if self.candidates else 0 + cdts = self.candidates[:mean_neighbors] + self.candidates_lock.release() + return cdts + + def remove_candidates(self): + self.candidates_lock.acquire() + self.candidates = [] + self.candidates_lock.release() + + def any_candidate(self): + self.candidates_lock.acquire() + any = True if len(self.candidates) > 0 else False + self.candidates_lock.release() + return any \ No newline at end of file diff --git a/nebula/core/neighbormanagement/neighborpolicies/idleneighborpolicy.py b/nebula/core/neighbormanagement/neighborpolicies/idleneighborpolicy.py 
index 63f4b8285..20a7fbbc2 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/idleneighborpolicy.py +++ b/nebula/core/neighbormanagement/neighborpolicies/idleneighborpolicy.py @@ -1,30 +1,111 @@ from nebula.core.neighbormanagement.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.utils.locker import Locker class IDLENeighborPolicy(NeighborPolicy): def __init__(self): - pass - - def set_config(self, config): - pass - + self.max_neighbors = None + self.nodes_known = set() + self.neighbors = set() + self.addr = None + self.neighbors_lock = Locker(name="neighbors_lock") + self.nodes_known_lock = Locker(name="nodes_known_lock") + def need_more_neighbors(self): - return False - + """ + Fully connected network requires to be connected to all devices, therefore, + if there are more nodes known that self.neighbors, more neighbors are required + """ + self.neighbors_lock.acquire() + need_more = (len(self.neighbors) <= 0) + self.neighbors_lock.release() + return need_more + + def set_config(self, config): + """ + Args: + config[0] -> list of self neighbors + config[1] -> list of nodes known on federation + config[2] -> self addr + config[3] -> NodeManager reference + """ + self.neighbors_lock.acquire() + self.neighbors = config[0] + self.neighbors_lock.release() + for addr in config[1]: + self.nodes_known.add(addr) + self.addr + def accept_connection(self, source, joining=False): - return False + """ + return true if connection is accepted + """ + self.neighbors_lock.acquire() + ac = (not source in self.neighbors) + self.neighbors_lock.release() + return ac - def get_actions(self): - return [[],[]] - def meet_node(self, node): - pass + """ + Update the list of nodes known on federation + """ + self.nodes_known_lock.acquire() + if node != self.addr: + self.nodes_known.add(node) + self.nodes_known_lock.release() + + def get_nodes_known(self, neighbors_too=False, neighbors_only=False): + if neighbors_only: + self.neighbors_lock.acquire() + no = 
self.neighbors.copy() + self.neighbors_lock.release() + return no + + self.nodes_known_lock.acquire() + nk = self.nodes_known.copy() + if not neighbors_too: + self.neighbors_lock.acquire() + nk = self.nodes_known - self.neighbors + self.neighbors_lock.release() + self.nodes_known_lock.release() + return nk def forget_nodes(self, node, forget_all=False): - pass + self.nodes_known_lock.acquire() + if forget_all: + self.nodes_known.clear() + else: + self.nodes_known.discard(node) + self.nodes_known_lock.release() + + def get_actions(self): + """ + return list of actions to do in response to connection + - First list represents addrs argument to LinkMessage to connect to + - Second one represents the same but for disconnect from LinkMessage + """ + return [self._connect_to(), self._disconnect_from()] + + + def _disconnect_from(self): + return "" - def get_nodes_known(self, neighbors_too=False, neighbors_only=False): - return set() + def _connect_to(self): + ct = "" + self.neighbors_lock.acquire() + ct = " ".join(self.neighbors) + self.neighbors_lock.release() + return ct def update_neighbors(self, node, remove=False): - pass \ No newline at end of file + if node == self.addr: + return + self.neighbors_lock.acquire() + if remove: + try: + self.neighbors.remove(node) + except KeyError: + pass + else: + self.neighbors.add(node) + self.neighbors_lock.release() \ No newline at end of file diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index 0b96aad3d..ed27da094 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -28,7 +28,7 @@ def __init__( momentum=False, ): self._aditional_participant = aditional_participant - self.topology = "fully" # topology + self.topology = topology print_msg_box( msg=f"Starting NodeManager module...\nTopology: {self.topology}", indent=2, title="NodeManager module" ) @@ -264,8 +264,6 @@ def get_nodes_known(self, 
neighbors_too=False): def accept_model_offer(self, source, decoded_model, rounds, round, epochs, n_neighbors, loss): if not self.accept_candidates_lock.locked(): logging.info(f"πŸ”„ Processing offer from {source}...") - # model_accepted = True#self.model_handler.accept_model(decoded_model) - # if source == "192.168.50.8:45007": model_accepted = self.model_handler.accept_model(decoded_model) self.model_handler.set_config(config=(rounds, round, epochs, self)) if model_accepted: @@ -368,9 +366,9 @@ async def start_late_connection_process(self, connected=False, msg_type="discove self.accept_candidates_lock.release() self.late_connection_process_lock.release() self.candidate_selector.remove_candidates() - if not self._desc_done: #TODO remove - self._desc_done = True - asyncio.create_task(self.stop_connections_with_federation()) + #if not self._desc_done: #TODO remove + # self._desc_done = True + # asyncio.create_task(self.stop_connections_with_federation()) # if no candidates, repeat process else: logging.info("❗️ No Candidates found...") @@ -390,9 +388,7 @@ async def check_robustness(self): if not self._restructure_process_lock.locked(): if not await self.neighbors_left(): logging.info("No Neighbors left | reconnecting with Federation") - #TODO comprobar q funcione correctamente self.engine.update_sinchronized_status(False) - await asyncio.sleep(120) await self.reconnect_to_federation() elif ( self.neighbor_policy.need_more_neighbors() diff --git a/nebula/core/network/connection.py b/nebula/core/network/connection.py index ff8e5e0f2..b2b341b8c 100755 --- a/nebula/core/network/connection.py +++ b/nebula/core/network/connection.py @@ -247,7 +247,8 @@ async def send( await self._send_chunks(message_id, data_to_send) except Exception as e: logging.exception(f"Error sending data: {e}") - await self.reconnect() + if self.direct: + await self.reconnect() def _prepare_data(self, data: Any, pb: bool, encoding_type: str) -> tuple[bytes, bytes]: if pb: @@ -322,7 +323,8 @@ async 
def handle_incoming_message(self) -> None: except BrokenPipeError: logging.exception(f"Error handling incoming message: {e}") finally: - await self.reconnect() + if self.direct: + await self.reconnect() async def _read_exactly(self, num_bytes: int, max_retries: int = 3) -> bytes: data = b"" From d5cdb7daf4ded97fb6098e3327d71529c8d16e87 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 13 Feb 2025 13:49:57 +0100 Subject: [PATCH 089/233] feataure update storage --- nebula/core/aggregation/aggregator.py | 24 ++--- nebula/core/aggregation/updatestorage.py | 108 +++++++++++++++++++++++ 2 files changed, 113 insertions(+), 19 deletions(-) create mode 100644 nebula/core/aggregation/updatestorage.py diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 636a3ecf1..2def6d6e0 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -2,10 +2,10 @@ import logging from abc import ABC, abstractmethod from functools import partial -from typing import TYPE_CHECKING - from nebula.core.utils.locker import Locker +from nebula.core.aggregation.updatestorage import UpdateStorage +from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.engine import Engine @@ -51,6 +51,7 @@ def __init__(self, config=None, engine=None): self._aggregation_waiting_skip = asyncio.Event() self._push_strategy_lock = Locker(name="push_strategy_lock", async_lock=True) self._end_round_push = 0 + self._update_storage = UpdateStorage(aggregator=self, addr=self._addr) def __str__(self): return self.__class__.__name__ @@ -429,23 +430,8 @@ async def aggregation_push_available(self): logging.info("❗️ Cannot analize push | Already pushing rounds") await self._push_strategy_lock.release_async() - -def create_malicious_aggregator(aggregator, attack): - # It creates a partial function aggregate that wraps the aggregate method of the original aggregator. 
- run_aggregation = partial(aggregator.run_aggregation) # None is the self (not used) - - # This function will replace the original aggregate method of the aggregator. - def malicious_aggregate(self, models): - accum = run_aggregation(models) - logging.info(f"malicious_aggregate | original aggregation result={accum}") - if models is not None: - accum = attack(accum) - logging.info(f"malicious_aggregate | attack aggregation result={accum}") - return accum - - aggregator.run_aggregation = partial(malicious_aggregate, aggregator) - return aggregator - + def notify_all_updates_received(self): + self._aggregation_waiting_skip.set() def create_aggregator(config, engine) -> Aggregator: from nebula.core.aggregation.blockchainReputation import BlockchainReputation diff --git a/nebula/core/aggregation/updatestorage.py b/nebula/core/aggregation/updatestorage.py new file mode 100644 index 000000000..883bc64a0 --- /dev/null +++ b/nebula/core/aggregation/updatestorage.py @@ -0,0 +1,108 @@ +import asyncio +import logging +from collections import deque +from typing import Dict, Tuple, Deque +from nebula.core.utils.locker import Locker +import time +from nebula.core.aggregation.aggregator import Aggregator + +class Update(): + def __init__(self, model, weight, source, round, time_received): + self.model = model + self.weight = weight + self.source = source + self.round = round + self.time_received = time_received + +MAX_UPDATE_BUFFER_SIZE = 1 # Modify to create an historic + +class UpdateStorage(): + def __init__( + self, + aggregator: Aggregator, + addr, + buffersize=MAX_UPDATE_BUFFER_SIZE + ): + self._addr = addr + self._aggregator = aggregator + self._buffersize = buffersize + self._updates_storage: Dict[str, Tuple[Update, Deque[Update]]] = {} + self._updates_storage_lock = Locker(name="updates_storage_lock", async_lock=True) + self._sources_expected = set() + self._sources_received = set() + + @property + def us(self): + return self._updates_storage + + @property + def 
agg(self): + return self._aggregator + + async def round_expected_updates(self, federation_nodes: set): + self._updates_storage_lock.acquire_async() + self._sources_expected = federation_nodes + self._sources_received.clear() + for fn in federation_nodes: + if fn not in self.us: + self.us[fn] = (None, deque(maxlen=self._buffersize)) + self._updates_storage_lock.release_async() + + async def storage_update(self, model, weight, source, round): + #TODO verificar duplicados + time_received = time.time() + if source in self._sources_expected: + updt = Update(model, weight, source, round, time_received) + self._updates_storage_lock.acquire_async() + self.us[source][1].append(updt) + self.us[source][0] = updt + logging.info(f"Storage Update | source={source} | round={round} | weight={weight} | federation nodes: {self._sources_expected}") + + self._sources_received.add(source) + updates_left = self._sources_expected.difference(self._sources_received) + logging.info(f"Updates received ({self._sources_received}/{self._sources_expected}) | Missing nodes: {updates_left}") + if not updates_left: + await self.all_updates_received() + self._updates_storage_lock.release_async() + else: + logging.info(f"source: {source} not in expected updates for this Round") + + async def update_source(self, source, remove=False): + logging.info(f"πŸ”„ Update | remove: {remove} | soure: {source}") + self._updates_storage_lock.acquire_async() + if remove: + self._sources_expected.discard(source) + del self.us[source] + else: + self.us[source] = (None, deque(maxlen=self._buffersize)) + self._sources_expected.add(source) + self._updates_storage_lock.release_async() + + async def get_round_updates(self): + self._updates_storage_lock.acquire_async() + updates_missing = self._sources_expected.difference(self._sources_received) + if updates_missing: + logging.info(f"Missing updates from sources: {updates_missing}") + updates = {} + for sr in self._sources_received: + source_historic = self.us[sr][1] + 
last_updt_received = self.us[sr][0] + updt: Update = None + if source_historic: + updt = source_historic[-1] + elif last_updt_received: + logging.info(f"Missing update source: {sr}, using last update received..") + updt = last_updt_received + updates[sr] = (updt.model, updt.weight) + self._updates_storage_lock.release_async() + return updates + + async def get_round_missing_nodes(self): + self._updates_storage_lock.acquire_async() + updates_left = self._sources_expected.difference(self._sources_received) + self._updates_storage_lock.release_async() + return updates_left + + async def all_updates_received(self): + logging.info("πŸ”„ Notify | All updates received") + self.agg.notify_all_updates_received() \ No newline at end of file From d3abf62f232aaccdde52050034e3da0e451d9054 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 13 Feb 2025 15:47:40 +0100 Subject: [PATCH 090/233] fix update storage errors --- nebula/core/aggregation/aggregator.py | 3 +- nebula/core/aggregation/updatestorage.py | 149 ++++++++++++++++++----- 2 files changed, 119 insertions(+), 33 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 2def6d6e0..e0e458acf 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -216,6 +216,7 @@ async def get_aggregation(self): try: timeout = self.config.participant["aggregator_args"]["aggregation_timeout"] logging.info(f"Aggregation timeout: {timeout} starts...") + #TODO notificar a updatestorage lock_task = asyncio.create_task(self._aggregation_done_lock.acquire_async(timeout=timeout)) skip_task = asyncio.create_task(self._aggregation_waiting_skip.wait()) done, pending = await asyncio.wait( @@ -430,7 +431,7 @@ async def aggregation_push_available(self): logging.info("❗️ Cannot analize push | Already pushing rounds") await self._push_strategy_lock.release_async() - def notify_all_updates_received(self): + async def notify_all_updates_received(self): 
self._aggregation_waiting_skip.set() def create_aggregator(config, engine) -> Aggregator: diff --git a/nebula/core/aggregation/updatestorage.py b/nebula/core/aggregation/updatestorage.py index 883bc64a0..355556f4e 100644 --- a/nebula/core/aggregation/updatestorage.py +++ b/nebula/core/aggregation/updatestorage.py @@ -13,6 +13,9 @@ def __init__(self, model, weight, source, round, time_received): self.source = source self.round = round self.time_received = time_received + + def __eq__(self, other): + return self.round == other.round MAX_UPDATE_BUFFER_SIZE = 1 # Modify to create an historic @@ -30,6 +33,8 @@ def __init__( self._updates_storage_lock = Locker(name="updates_storage_lock", async_lock=True) self._sources_expected = set() self._sources_received = set() + self._round_updates_lock = Locker(name="round_updates_lock", async_lock=True) # se coge cuando se empieza a comprobar si estan todas las updates + self._update_federation_lock = Locker(name="update_federation_lock", async_lock=True) @property def us(self): @@ -40,46 +45,92 @@ def agg(self): return self._aggregator async def round_expected_updates(self, federation_nodes: set): - self._updates_storage_lock.acquire_async() - self._sources_expected = federation_nodes + """ + Initializes the expected updates for the current round. + + This method updates the list of expected sources (`_sources_expected`) for the current training round + and ensures that their respective update storage is initialized if they were not previously registered. + + Args: + federation_nodes (set): A set of node identifiers expected to provide updates in the current round. 
+ """ + await self._update_federation_lock.acquire_async() + await self._updates_storage_lock.acquire_async() + self._sources_expected = federation_nodes.copy() self._sources_received.clear() + + # Initialize new nodes for fn in federation_nodes: if fn not in self.us: self.us[fn] = (None, deque(maxlen=self._buffersize)) - self._updates_storage_lock.release_async() + + # Clear removed nodes + removed_nodes = [node for node in self._updates_storage.keys() if node not in federation_nodes] + for rn in removed_nodes: + del self._updates_storage[rn] + + await self._updates_storage_lock.release_async() + await self._update_federation_lock.release_async() + + # Lock to check if all updates received + if self._round_updates_lock.locked(): + self._round_updates_lock.release_async() async def storage_update(self, model, weight, source, round): - #TODO verificar duplicados + """ + Stores an update in the update queue if it has not been previously received for the same source and round. + + This method ensures that only one update per source and round is stored, avoiding duplicates. + If all expected updates for the current round have been received, it triggers the `all_updates_received()` method. + + Args: + model: The model associated with the update. + weight: The weight assigned to the update. + source (str): The source identifier of the update. + round (int): The training round in which the update was received. + + Logs: + - Stores and logs the update if it's new for the round. + - Logs a duplicate update if an identical one already exists. + - Logs missing sources if not all expected updates have been received. 
+ """ time_received = time.time() if source in self._sources_expected: updt = Update(model, weight, source, round, time_received) - self._updates_storage_lock.acquire_async() - self.us[source][1].append(updt) - self.us[source][0] = updt - logging.info(f"Storage Update | source={source} | round={round} | weight={weight} | federation nodes: {self._sources_expected}") - - self._sources_received.add(source) - updates_left = self._sources_expected.difference(self._sources_received) - logging.info(f"Updates received ({self._sources_received}/{self._sources_expected}) | Missing nodes: {updates_left}") - if not updates_left: - await self.all_updates_received() - self._updates_storage_lock.release_async() + await self._updates_storage_lock.acquire_async() + if updt in self.us[source][1]: + logging.info(f"Discard | Alerady received update from source: {source} for round: {round}") + else: + self.us[source][1].append(updt) + self.us[source][0] = updt + logging.info(f"Storage Update | source={source} | round={round} | weight={weight} | federation nodes: {self._sources_expected}") + + self._sources_received.add(source) + updates_left = self._sources_expected.difference(self._sources_received) + logging.info(f"Updates received ({self._sources_received}/{self._sources_expected}) | Missing nodes: {updates_left}") + if self._round_updates_lock.locked() and not updates_left: + await self.all_updates_received() + await self._updates_storage_lock.release_async() else: logging.info(f"source: {source} not in expected updates for this Round") - async def update_source(self, source, remove=False): - logging.info(f"πŸ”„ Update | remove: {remove} | soure: {source}") - self._updates_storage_lock.acquire_async() - if remove: - self._sources_expected.discard(source) - del self.us[source] - else: - self.us[source] = (None, deque(maxlen=self._buffersize)) - self._sources_expected.add(source) - self._updates_storage_lock.release_async() - async def get_round_updates(self): - 
self._updates_storage_lock.acquire_async() + """ + Retrieves the latest updates received in the current round. + + This method collects updates from all received sources, prioritizing the most recent update + stored in the queue. If an expected update is missing, it attempts to use the last received + update from that source instead. + + Returns: + dict: A dictionary mapping each source to a tuple `(model, weight)`, containing + the most recent update for that source. + + Logs: + - Logs missing sources if expected updates have not been received. + - Logs when a missing update is replaced by the last received update. + """ + await self._updates_storage_lock.acquire_async() updates_missing = self._sources_expected.difference(self._sources_received) if updates_missing: logging.info(f"Missing updates from sources: {updates_missing}") @@ -94,15 +145,49 @@ async def get_round_updates(self): logging.info(f"Missing update source: {sr}, using last update received..") updt = last_updt_received updates[sr] = (updt.model, updt.weight) - self._updates_storage_lock.release_async() + await self._updates_storage_lock.release_async() return updates + + async def notify_federation_update(self, source, remove=False): + if not remove: + if self._round_updates_lock.locked(): + logging.info(f"Source: {source} will be count next round") + else: + self._update_source(source, remove) + # si se quita + else: + pass + # comprobar si he recibido la updt + # si no la he recibido se descarta de los esperados + # si la he recibido + # si antes de la agregacion de esta ronda la uso + + async def _clear_removed_sources(self): + pass + + async def _update_source(self, source, remove=False): + logging.info(f"πŸ”„ Update | remove: {remove} | source: {source}") + await self._updates_storage_lock.acquire_async() + if remove: + self._sources_expected.discard(source) + else: + self.us[source] = (None, deque(maxlen=self._buffersize)) + self._sources_expected.add(source) + logging.info(f"federation nodes 
expected this round: {self._sources_expected}") + await self._updates_storage_lock.release_async() async def get_round_missing_nodes(self): - self._updates_storage_lock.acquire_async() + await self._updates_storage_lock.acquire_async() updates_left = self._sources_expected.difference(self._sources_received) - self._updates_storage_lock.release_async() + await self._updates_storage_lock.release_async() return updates_left + async def notify_if_all_updates_received(self): + await self._round_updates_lock.acquire_async() + async def all_updates_received(self): - logging.info("πŸ”„ Notify | All updates received") - self.agg.notify_all_updates_received() \ No newline at end of file + updates_left = self._sources_expected.difference(self._sources_received) + if len(updates_left) == 0: + await self._round_updates_lock.release_async() + await self.agg.notify_all_updates_received() + \ No newline at end of file From 16ca2dad7e145dcff02078a8e8f460673202e1ac Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 14 Feb 2025 13:32:00 +0100 Subject: [PATCH 091/233] fix_no_round_mechs --- nebula/core/aggregation/aggregator.py | 58 ++++++++----- nebula/core/aggregation/updatestorage.py | 50 +++++++---- nebula/core/engine.py | 43 +++++++++- nebula/core/neighbormanagement/nodemanager.py | 7 +- nebula/core/network/communications.py | 85 +++++-------------- 5 files changed, 135 insertions(+), 108 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index e0e458acf..8eb487a57 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -62,6 +62,10 @@ def __repr__(self): @property def cm(self): return self.engine.cm + + @property + def us(self): + return self._update_storage @abstractmethod def run_aggregation(self, models): @@ -70,6 +74,8 @@ def run_aggregation(self, models): return None async def update_federation_nodes(self, federation_nodes: set): + await 
self.us.round_expected_updates(federation_nodes=federation_nodes) + if not self._aggregation_done_lock.locked(): self._federation_nodes = federation_nodes self._pending_models_to_aggregate.clear() @@ -79,27 +85,32 @@ async def update_federation_nodes(self, federation_nodes: set): else: raise Exception("It is not possible to set nodes to aggregate when the aggregation is running.") - async def notify_federation_nodes_removed(self, federation_nodes: set): - # Neighbor has been removed - if len(self._federation_nodes) - len(federation_nodes) > 0: - nodes_removed = self._federation_nodes.symmetric_difference(federation_nodes) - logging.info(f"Nodes removed from aggregation: {nodes_removed}") - pending_nodes = self._federation_nodes - self.get_nodes_pending_models_to_aggregate() - # logging.info(f"Pending models to aggregate: {pending_nodes}") - shouldnt_waited_model = [] - shouldnt_waited_model = [source for source in nodes_removed if source in pending_nodes] - logging.info(f"Waiting models from removed neighbors: {shouldnt_waited_model}") - if shouldnt_waited_model: - for swm in shouldnt_waited_model: - logging.info(f"Removing model from waiting: {swm}") - pending_nodes.discard(swm) - if self._aggregation_done_lock.locked(): - if not pending_nodes: - logging.info("No model updates required left | releasing aggregation lock...") - self._federation_nodes = federation_nodes - await self._aggregation_done_lock.release_async() + async def update_received_from_source(self, model, weight, source, round): + pass - self._federation_nodes = federation_nodes + + async def notify_federation_nodes_removed(self, federation_node, remove=False): + # Neighbor has been removed + #if len(self._federation_nodes) - len(federation_nodes) > 0: + # nodes_removed = self._federation_nodes.symmetric_difference(federation_nodes) + # logging.info(f"Nodes removed from aggregation: {nodes_removed}") + # pending_nodes = self._federation_nodes - self.get_nodes_pending_models_to_aggregate() + # # 
logging.info(f"Pending models to aggregate: {pending_nodes}") + # shouldnt_waited_model = [] + # shouldnt_waited_model = [source for source in nodes_removed if source in pending_nodes] + # logging.info(f"Waiting models from removed neighbors: {shouldnt_waited_model}") + # if shouldnt_waited_model: + # for swm in shouldnt_waited_model: + # logging.info(f"Removing model from waiting: {swm}") + # pending_nodes.discard(swm) + # if self._aggregation_done_lock.locked(): + # if not pending_nodes: + # logging.info("No model updates required left | releasing aggregation lock...") + # self._federation_nodes = federation_nodes + # await self._aggregation_done_lock.release_async() + + # self._federation_nodes = federation_nodes + await self.us.notify_federation_update(federation_node, remove=remove) def set_waiting_global_update(self): self._waiting_global_update = True @@ -180,6 +191,7 @@ async def _add_pending_model(self, model, weight, source): return self.get_nodes_pending_models_to_aggregate() async def include_model_in_buffer(self, model, weight, source=None, round=None, local=False): + await self.us.storage_update(model, weight, source, round) await self._add_model_lock.acquire_async() logging.info( f"πŸ”„ include_model_in_buffer | source={source} | round={round} | weight={weight} |--| __models={self._pending_models_to_aggregate.keys()} | federation_nodes={self._federation_nodes} | pending_models_to_aggregate={self.get_nodes_pending_models_to_aggregate()}" @@ -216,7 +228,7 @@ async def get_aggregation(self): try: timeout = self.config.participant["aggregator_args"]["aggregation_timeout"] logging.info(f"Aggregation timeout: {timeout} starts...") - #TODO notificar a updatestorage + await self.us.notify_if_all_updates_received() lock_task = asyncio.create_task(self._aggregation_done_lock.acquire_async(timeout=timeout)) skip_task = asyncio.create_task(self._aggregation_waiting_skip.wait()) done, pending = await asyncio.wait( @@ -244,6 +256,9 @@ async def 
get_aggregation(self): if lock_acquired: await self._aggregation_done_lock.release_async() + await self.us.stop_notifying_updates() + updates = await self.us.get_round_updates() + if self._waiting_global_update and len(self._pending_models_to_aggregate) == 1: logging.info( "πŸ”„ get_aggregation | Received an global model. Overwriting my model with the aggregated model." @@ -266,6 +281,7 @@ async def get_aggregation(self): return aggregated_result async def include_next_model_in_buffer(self, model, weight, source=None, round=None): + await self.us.storage_update(model, weight, source, round) logging.info(f"πŸ”„ include_next_model_in_buffer | source={source} | round={round} | weight={weight}") if round not in self._future_models_to_aggregate: self._future_models_to_aggregate[round] = [] diff --git a/nebula/core/aggregation/updatestorage.py b/nebula/core/aggregation/updatestorage.py index 355556f4e..d84a95637 100644 --- a/nebula/core/aggregation/updatestorage.py +++ b/nebula/core/aggregation/updatestorage.py @@ -4,7 +4,10 @@ from typing import Dict, Tuple, Deque from nebula.core.utils.locker import Locker import time -from nebula.core.aggregation.aggregator import Aggregator + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from nebula.core.aggregation.aggregator import Aggregator class Update(): def __init__(self, model, weight, source, round, time_received): @@ -22,12 +25,12 @@ def __eq__(self, other): class UpdateStorage(): def __init__( self, - aggregator: Aggregator, + aggregator, addr, buffersize=MAX_UPDATE_BUFFER_SIZE ): self._addr = addr - self._aggregator = aggregator + self._aggregator: Aggregator = aggregator self._buffersize = buffersize self._updates_storage: Dict[str, Tuple[Update, Deque[Update]]] = {} self._updates_storage_lock = Locker(name="updates_storage_lock", async_lock=True) @@ -69,12 +72,22 @@ async def round_expected_updates(self, federation_nodes: set): for rn in removed_nodes: del self._updates_storage[rn] + # Check already received 
updates + await self._check_updates_already_received() + await self._updates_storage_lock.release_async() await self._update_federation_lock.release_async() # Lock to check if all updates received if self._round_updates_lock.locked(): self._round_updates_lock.release_async() + + async def _check_updates_already_received(self): + for se in self._sources_expected: + (_,node_storage) = self._updates_storage[se] + if len(node_storage): + logging.info(f"Update already received from source: {se} | ({len(self._sources_received)}/{len(self._sources_expected)}) Updates received") + self._sources_received.add(se) async def storage_update(self, model, weight, source, round): """ @@ -102,12 +115,12 @@ async def storage_update(self, model, weight, source, round): logging.info(f"Discard | Alerady received update from source: {source} for round: {round}") else: self.us[source][1].append(updt) - self.us[source][0] = updt + self.us[source] = (updt, self.us[source][1]) logging.info(f"Storage Update | source={source} | round={round} | weight={weight} | federation nodes: {self._sources_expected}") self._sources_received.add(source) updates_left = self._sources_expected.difference(self._sources_received) - logging.info(f"Updates received ({self._sources_received}/{self._sources_expected}) | Missing nodes: {updates_left}") + logging.info(f"Updates received ({len(self._sources_received)}/{len(self._sources_expected)}) | Missing nodes: {updates_left}") if self._round_updates_lock.locked() and not updates_left: await self.all_updates_received() await self._updates_storage_lock.release_async() @@ -139,8 +152,8 @@ async def get_round_updates(self): source_historic = self.us[sr][1] last_updt_received = self.us[sr][0] updt: Update = None - if source_historic: - updt = source_historic[-1] + if len(source_historic): + updt = source_historic.pop() # [-1] last update received from this node elif last_updt_received: logging.info(f"Missing update source: {sr}, using last update received..") updt = 
last_updt_received @@ -153,17 +166,11 @@ async def notify_federation_update(self, source, remove=False): if self._round_updates_lock.locked(): logging.info(f"Source: {source} will be count next round") else: - self._update_source(source, remove) - # si se quita + await self._update_source(source, remove) else: - pass - # comprobar si he recibido la updt - # si no la he recibido se descarta de los esperados - # si la he recibido - # si antes de la agregacion de esta ronda la uso - - async def _clear_removed_sources(self): - pass + # Not received update from this source yet + if not source in self._sources_received: + await self._update_source(source, remove=True) async def _update_source(self, source, remove=False): logging.info(f"πŸ”„ Update | remove: {remove} | source: {source}") @@ -183,11 +190,18 @@ async def get_round_missing_nodes(self): return updates_left async def notify_if_all_updates_received(self): + logging.info("Set notification when all expected updates received") await self._round_updates_lock.acquire_async() + + async def stop_notifying_updates(self): + logging.info("Stop notifications updates") + if self._round_updates_lock.locked(): + await self._round_updates_lock.release_async() async def all_updates_received(self): updates_left = self._sources_expected.difference(self._sources_received) if len(updates_left) == 0: + logging.info("All updates have been received this round | releasing aggregation lock") await self._round_updates_lock.release_async() - await self.agg.notify_all_updates_received() + #await self.agg.notify_all_updates_received() \ No newline at end of file diff --git a/nebula/core/engine.py b/nebula/core/engine.py index ad75385b4..cce63e236 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -166,6 +166,9 @@ def __init__( logging.info("Registering callbacks for MessageEvents...") self.register_message_events_callbacks() + # Additional callbacks not registered automatically + 
self.register_message_callback(("model","initialization"), "model_initialization_callback") + @property def cm(self): return self._cm @@ -245,6 +248,33 @@ def set_round(self, new_round): self.round = new_round self.trainer.set_current_round(new_round) + """ ############################## + # MODEL CALLBACKS # + ############################## + """ + + async def model_initialization_callback(self, source, message): + try: + model = self.trainer.deserialize_model(message.parameters) + self.trainer.set_model_parameters(model, initialize=True) + logging.info("πŸ€– Init Model | Model Parameters Initialized") + self.set_initialization_status(True) + await ( + self.get_federation_ready_lock().release_async() + ) # Enable learning cycle once the initialization is done + try: + await ( + self.get_federation_ready_lock().release_async() + ) # Release the lock acquired at the beginning of the engine + except RuntimeError: + pass + except RuntimeError: + pass + + async def model_update_callback(self, source, message): + pass + + """ ############################## # General callbacks # ############################## @@ -512,6 +542,9 @@ async def _link_disconnect_from_callback(self, source, message): await self.cm.disconnect(source, mutual_disconnection=False) await self.nm.update_neighbors(addr, remove=True) + + + """ ############################## # ENGINE FUNCTIONALITY # ############################## @@ -533,6 +566,12 @@ def register_message_events_callbacks(self): if callable(method): self.event_manager.subscribe((event_type, action), method) + def register_message_callback(self, message_event: tuple[str, str], callback: str): + event_type, action = message_event + method = getattr(self, callback, None) + if callable(method): + self.event_manager.subscribe((event_type, action), method) + async def trigger_event(self, message_event): await self.event_manager.publish(message_event) @@ -563,7 +602,7 @@ async def update_neighbors(self, removed_neighbor_addr, neighbors, 
remove=False) if self.mobility: self.federation_nodes = neighbors await self.nm.update_neighbors(removed_neighbor_addr, remove=remove) - await self.aggregator.notify_federation_nodes_removed(self.federation_nodes) + await self.aggregator.notify_federation_nodes_removed(removed_neighbor_addr, remove=remove) async def update_model_learning_rate(self, new_lr): await self.trainning_in_progress_lock.acquire_async() @@ -759,7 +798,7 @@ async def _dynamic_aggregator(self, aggregated_models_weights, malicious_nodes): async def _waiting_model_updates(self): logging.info(f"πŸ’€ Waiting convergence in round {self.round}.") if self.mobility: - await self.aggregator.aggregation_push_available() + await self.aggregator.aggregation_push_available() #TODO params = await self.aggregator.get_aggregation() if params is not None: logging.info( diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/neighbormanagement/nodemanager.py index ed27da094..177650870 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/neighbormanagement/nodemanager.py @@ -366,9 +366,9 @@ async def start_late_connection_process(self, connected=False, msg_type="discove self.accept_candidates_lock.release() self.late_connection_process_lock.release() self.candidate_selector.remove_candidates() - #if not self._desc_done: #TODO remove - # self._desc_done = True - # asyncio.create_task(self.stop_connections_with_federation()) + if not self._desc_done: #TODO remove + self._desc_done = True + asyncio.create_task(self.stop_connections_with_federation()) # if no candidates, repeat process else: logging.info("❗️ No Candidates found...") @@ -413,6 +413,7 @@ async def check_robustness(self): async def reconnect_to_federation(self): self._restructure_process_lock.acquire() await self.engine.cm.clear_restrictions() + await asyncio.sleep(120) if await self.engine.cm.is_external_connection_service_running(): self.engine.cm.stop_external_connection_service() # If we got some refs, try 
to reconnect to them diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 0339784ed..5268a84da 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -16,6 +16,7 @@ from nebula.core.network.messages import MessagesManager from nebula.core.network.nebulamulticasting import NebulaConnectionService from nebula.core.network.propagator import Propagator +from nebula.core.network.messages import MessageEvent from nebula.core.utils.helper import ( cosine_metric, euclidean_metric, @@ -189,54 +190,8 @@ async def handle_model_message(self, source, message): if not self.engine.get_federation_ready_lock().locked() or self.engine.get_initialization_status(): decoded_model = self.engine.trainer.deserialize_model(message.parameters) if False and self.config.participant["adaptive_args"]["model_similarity"]: - logging.info("πŸ€– handle_model_message | Checking model similarity") - cosine_value = cosine_metric( - self.engine.trainer.get_model_parameters(), - decoded_model, - similarity=True, - ) - euclidean_value = euclidean_metric( - self.engine.trainer.get_model_parameters(), - decoded_model, - similarity=True, - ) - minkowski_value = minkowski_metric( - self.engine.trainer.get_model_parameters(), - decoded_model, - p=2, - similarity=True, - ) - manhattan_value = manhattan_metric( - self.engine.trainer.get_model_parameters(), - decoded_model, - similarity=True, - ) - pearson_correlation_value = pearson_correlation_metric( - self.engine.trainer.get_model_parameters(), - decoded_model, - similarity=True, - ) - jaccard_value = jaccard_metric( - self.engine.trainer.get_model_parameters(), - decoded_model, - similarity=True, - ) - # with open( - # f"{self.config.participant["tracking_args"]["log_dir"]}/participant_{self.id}_similarity.csv", - # "a+", - # ) as f: - # if os.stat(f"{self}/participant_{self.id}_similarity.csv").st_size == 0: - # f.write( - # 
"timestamp,source_ip,nodes,round,current_round,cosine,euclidean,minkowski,manhattan,pearson_correlation,jaccard\n" - # ) - # f.write( - # f"{datetime.now()}, {source}, {message.round}, {current_round}, {cosine_value}, {euclidean_value}, {minkowski_value}, {manhattan_value}, {pearson_correlation_value}, {jaccard_value}\n" - # ) - logging("Similarities between self model and model recieved...") - logging.info( - f"{source}, {message.round}, {current_round}, {cosine_value}, {euclidean_value}, {minkowski_value}, {manhattan_value}, {pearson_correlation_value}, {jaccard_value}" - ) - + pass + await self.engine.aggregator.include_model_in_buffer( decoded_model, message.weight, @@ -258,22 +213,24 @@ async def handle_model_message(self, source, message): ) return logging.info(f"πŸ€– handle_model_message | Initializing model (executed by {source})") - try: - model = self.engine.trainer.deserialize_model(message.parameters) - self.engine.trainer.set_model_parameters(model, initialize=True) - logging.info("πŸ€– handle_model_message | Model Parameters Initialized") - self.engine.set_initialization_status(True) - await ( - self.engine.get_federation_ready_lock().release_async() - ) # Enable learning cycle once the initialization is done - try: - await ( - self.engine.get_federation_ready_lock().release_async() - ) # Release the lock acquired at the beginning of the engine - except RuntimeError: - pass - except RuntimeError: - pass + model_init_event = MessageEvent(("model","initialization"), source, message) + await self.engine.trigger_event(model_init_event) + #try: +# model = self.engine.trainer.deserialize_model(message.parameters) +# self.engine.trainer.set_model_parameters(model, initialize=True) +# logging.info("πŸ€– handle_model_message | Model Parameters Initialized") +# self.engine.set_initialization_status(True) +# await ( +# self.engine.get_federation_ready_lock().release_async() +# ) # Enable learning cycle once the initialization is done +# try: +# await ( +# 
self.engine.get_federation_ready_lock().release_async() +# ) # Release the lock acquired at the beginning of the engine +# except RuntimeError: +# pass +# except RuntimeError: +# pass except Exception as e: logging.exception(f"πŸ€– handle_model_message | Unknown error adding model: {e}") From 2dafe3abb4ff3edb9f1b41e147e6c4f3309334a2 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 14 Feb 2025 16:54:56 +0100 Subject: [PATCH 092/233] fix_error --- nebula/core/aggregation/aggregator.py | 3 +-- nebula/core/aggregation/updatestorage.py | 25 ++++++++++++--------- nebula/core/engine.py | 8 ++++++- nebula/core/network/communications.py | 28 ++++++++++-------------- 4 files changed, 35 insertions(+), 29 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 8eb487a57..bae1d18a5 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -86,8 +86,7 @@ async def update_federation_nodes(self, federation_nodes: set): raise Exception("It is not possible to set nodes to aggregate when the aggregation is running.") async def update_received_from_source(self, model, weight, source, round): - pass - + await self.us.storage_update(model, weight, source, round) async def notify_federation_nodes_removed(self, federation_node, remove=False): # Neighbor has been removed diff --git a/nebula/core/aggregation/updatestorage.py b/nebula/core/aggregation/updatestorage.py index d84a95637..5d5fc0b60 100644 --- a/nebula/core/aggregation/updatestorage.py +++ b/nebula/core/aggregation/updatestorage.py @@ -16,6 +16,7 @@ def __init__(self, model, weight, source, round, time_received): self.source = source self.round = round self.time_received = time_received + self.used = False def __eq__(self, other): return self.round == other.round @@ -84,10 +85,11 @@ async def round_expected_updates(self, federation_nodes: set): async def _check_updates_already_received(self): for se in self._sources_expected: - 
(_,node_storage) = self._updates_storage[se] + (last_updt, node_storage) = self._updates_storage[se] if len(node_storage): - logging.info(f"Update already received from source: {se} | ({len(self._sources_received)}/{len(self._sources_expected)}) Updates received") - self._sources_received.add(se) + if last_updt != node_storage[-1]: + logging.info(f"Update already received from source: {se} | ({len(self._sources_received)}/{len(self._sources_expected)}) Updates received") + self._sources_received.add(se) async def storage_update(self, model, weight, source, round): """ @@ -113,9 +115,10 @@ async def storage_update(self, model, weight, source, round): await self._updates_storage_lock.acquire_async() if updt in self.us[source][1]: logging.info(f"Discard | Alerady received update from source: {source} for round: {round}") - else: + else: + last_update_used = self.us[source][0] self.us[source][1].append(updt) - self.us[source] = (updt, self.us[source][1]) + self.us[source] = (last_update_used, self.us[source][1]) logging.info(f"Storage Update | source={source} | round={round} | weight={weight} | federation nodes: {self._sources_expected}") self._sources_received.add(source) @@ -125,7 +128,7 @@ async def storage_update(self, model, weight, source, round): await self.all_updates_received() await self._updates_storage_lock.release_async() else: - logging.info(f"source: {source} not in expected updates for this Round") + logging.info(f"Discard update | source: {source} not in expected updates for this Round") async def get_round_updates(self): """ @@ -152,12 +155,14 @@ async def get_round_updates(self): source_historic = self.us[sr][1] last_updt_received = self.us[sr][0] updt: Update = None - if len(source_historic): - updt = source_historic.pop() # [-1] last update received from this node - elif last_updt_received: + updt = source_historic[-1] # Get last update received + if last_updt_received == updt: logging.info(f"Missing update source: {sr}, using last update 
received..") - updt = last_updt_received + else: + last_updt_received = updt + self.us[sr] = (last_updt_received, source_historic) # Update storage with new last update used updates[sr] = (updt.model, updt.weight) + await self._updates_storage_lock.release_async() return updates diff --git a/nebula/core/engine.py b/nebula/core/engine.py index cce63e236..c01f0f0ba 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -168,6 +168,7 @@ def __init__( # Additional callbacks not registered automatically self.register_message_callback(("model","initialization"), "model_initialization_callback") + self.register_message_callback(("model","update"), "model_update_callback") @property def cm(self): @@ -272,7 +273,12 @@ async def model_initialization_callback(self, source, message): pass async def model_update_callback(self, source, message): - pass + #TODO gestionar situaciones aqui + logging.info(f"πŸ€– handle_model_message | Received model from {source} with round {message.round}") + if not self.get_federation_ready_lock().locked() and len(self.get_federation_nodes()) == 0: + logging.info("πŸ€– handle_model_message | There are no defined federation nodes") + return + await self.aggregator.update_received_from_source(message.parameters, message.weight, source, message.round) """ ############################## diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 5268a84da..d01e5474a 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -146,6 +146,11 @@ async def check_federation_ready(self): async def add_ready_connection(self, addr): self.ready_connections.add(addr) + """ ############################## + # PROCESSING MESSAGES # + ############################## + """ + async def handle_incoming_message(self, data, addr_from): if not await self.bl.node_in_blacklist(addr_from): await self.mm.process_message(data, addr_from) @@ -158,6 +163,7 @@ async def handle_message(self, 
message_event): await self.engine.trigger_event(message_event) async def handle_model_message(self, source, message): + #TODO modificar para generar eventos y gestionar en engine logging.info(f"πŸ€– handle_model_message | Received model from {source} with round {message.round}") if self.get_round() is not None: await self.engine.get_round_lock().acquire_async() @@ -172,6 +178,7 @@ async def handle_model_message(self, source, message): logging.info( f"πŸ€– handle_model_message | Saving model from {source} for future round {message.round}" ) + logging.info("### ENTRO 1 ###") await self.engine.aggregator.include_next_model_in_buffer( message.parameters, message.weight, @@ -192,6 +199,7 @@ async def handle_model_message(self, source, message): if False and self.config.participant["adaptive_args"]["model_similarity"]: pass + logging.info("### ENTRO 2 ###") await self.engine.aggregator.include_model_in_buffer( decoded_model, message.weight, @@ -205,6 +213,7 @@ async def handle_model_message(self, source, message): logging.info( f"πŸ€– handle_model_message | Saving model from {source} for future round {message.round}" ) + logging.info("### ENTRO 3 ###") await self.engine.aggregator.include_next_model_in_buffer( message.parameters, message.weight, @@ -215,22 +224,7 @@ async def handle_model_message(self, source, message): logging.info(f"πŸ€– handle_model_message | Initializing model (executed by {source})") model_init_event = MessageEvent(("model","initialization"), source, message) await self.engine.trigger_event(model_init_event) - #try: -# model = self.engine.trainer.deserialize_model(message.parameters) -# self.engine.trainer.set_model_parameters(model, initialize=True) -# logging.info("πŸ€– handle_model_message | Model Parameters Initialized") -# self.engine.set_initialization_status(True) -# await ( -# self.engine.get_federation_ready_lock().release_async() -# ) # Enable learning cycle once the initialization is done -# try: -# await ( -# 
self.engine.get_federation_ready_lock().release_async() -# ) # Release the lock acquired at the beginning of the engine -# except RuntimeError: -# pass -# except RuntimeError: -# pass + except Exception as e: logging.exception(f"πŸ€– handle_model_message | Unknown error adding model: {e}") @@ -275,6 +269,7 @@ async def apply_restrictions(self, nodes): async def clear_restrictions(self): await self.bl.clear_restrictions() + """ ############################### # EXTERNAL CONNECTION SERVICE # ############################### @@ -323,6 +318,7 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr asyncio.create_task(self.send_message(addr, msg)) await asyncio.sleep(1) + """ ############################## # OTHER FUNCTIONALITIES # ############################## From 646c9a32e757f7796892353136be68c9a3feb4c7 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sat, 15 Feb 2025 11:48:38 +0100 Subject: [PATCH 093/233] Feature update handlers interface --- nebula/core/aggregation/aggregator.py | 4 +- .../dflupdatehandler.py} | 7 +- .../updatehandlers/updatehandler.py | 113 ++++++++++++++++++ 3 files changed, 119 insertions(+), 5 deletions(-) rename nebula/core/aggregation/{updatestorage.py => updatehandlers/dflupdatehandler.py} (97%) create mode 100644 nebula/core/aggregation/updatehandlers/updatehandler.py diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index bae1d18a5..7294c98b1 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from functools import partial from nebula.core.utils.locker import Locker -from nebula.core.aggregation.updatestorage import UpdateStorage +from nebula.core.aggregation.updatehandlers.updatehandler import factory_update_handler from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -51,7 +51,7 @@ def __init__(self, config=None, engine=None): self._aggregation_waiting_skip = 
asyncio.Event() self._push_strategy_lock = Locker(name="push_strategy_lock", async_lock=True) self._end_round_push = 0 - self._update_storage = UpdateStorage(aggregator=self, addr=self._addr) + self._update_storage = factory_update_handler("dfl", self, self._addr) def __str__(self): return self.__class__.__name__ diff --git a/nebula/core/aggregation/updatestorage.py b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py similarity index 97% rename from nebula/core/aggregation/updatestorage.py rename to nebula/core/aggregation/updatehandlers/dflupdatehandler.py index 5d5fc0b60..3141f1008 100644 --- a/nebula/core/aggregation/updatestorage.py +++ b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py @@ -4,6 +4,7 @@ from typing import Dict, Tuple, Deque from nebula.core.utils.locker import Locker import time +from nebula.core.aggregation.updatehandlers.updatehandler import UpdateHandler from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -23,7 +24,7 @@ def __eq__(self, other): MAX_UPDATE_BUFFER_SIZE = 1 # Modify to create an historic -class UpdateStorage(): +class DFLUpdateHandler(UpdateHandler): def __init__( self, aggregator, @@ -125,7 +126,7 @@ async def storage_update(self, model, weight, source, round): updates_left = self._sources_expected.difference(self._sources_received) logging.info(f"Updates received ({len(self._sources_received)}/{len(self._sources_expected)}) | Missing nodes: {updates_left}") if self._round_updates_lock.locked() and not updates_left: - await self.all_updates_received() + await self._all_updates_received() await self._updates_storage_lock.release_async() else: logging.info(f"Discard update | source: {source} not in expected updates for this Round") @@ -203,7 +204,7 @@ async def stop_notifying_updates(self): if self._round_updates_lock.locked(): await self._round_updates_lock.release_async() - async def all_updates_received(self): + async def _all_updates_received(self): updates_left = 
self._sources_expected.difference(self._sources_received) if len(updates_left) == 0: logging.info("All updates have been received this round | releasing aggregation lock") diff --git a/nebula/core/aggregation/updatehandlers/updatehandler.py b/nebula/core/aggregation/updatehandlers/updatehandler.py new file mode 100644 index 000000000..fbbff98fe --- /dev/null +++ b/nebula/core/aggregation/updatehandlers/updatehandler.py @@ -0,0 +1,113 @@ +from abc import ABC, abstractmethod + +class UpdateHandlerException(Exception): + pass + +class UpdateHandler(ABC): + """ + Abstract base class for managing update storage and retrieval in a federated learning setting. + + This class defines the required methods for handling updates from multiple sources, + ensuring they are properly stored, retrieved, and processed during the aggregation process. + """ + + @abstractmethod + async def round_expected_updates(self, federation_nodes: set): + """ + Initializes the expected updates for the current round. + + This method sets up the expected sources (`federation_nodes`) that should provide updates + in the current training round. It ensures that each source has an entry in the storage + and resets any previous tracking of received updates. + + Args: + federation_nodes (set): A set of node identifiers expected to provide updates. + """ + raise NotImplementedError + + @abstractmethod + async def storage_update(self, model, weight, source, round): + """ + Stores an update from a source in the update storage. + + This method ensures that an update received from a source is properly stored in the buffer, + avoiding duplicates and managing update history if necessary. + + Args: + model: The model associated with the update. + weight: The weight assigned to the update (e.g., based on the amount of data used in training). + source (str): The identifier of the node sending the update. + round (int): The current device local training round when the update was done. 
+ """ + raise NotImplementedError + + @abstractmethod + async def get_round_updates(self) -> dict[str, tuple[object, float]]: + """ + Retrieves the latest updates from all received sources in the current round. + + This method collects updates from all sources that have sent updates, + prioritizing the most recent update available in the buffer. + + Returns: + dict: A dictionary where keys are source identifiers and values are tuples `(model, weight)`, + representing the latest updates received from each source. + """ + raise NotImplementedError + + @abstractmethod + async def notify_federation_update(self, source, remove=False): + """ + Notifies the system of a change in the federation regarding a specific source. + + If a source leaves the federation, it is removed from the list of expected updates. + If a source is newly added, it is registered for future updates. + + Args: + source (str): The identifier of the source node. + remove (bool, optional): Whether to remove the source from the federation. Defaults to `False`. + """ + raise NotImplementedError + + @abstractmethod + async def get_round_missing_nodes(self) -> set[str]: + """ + Identifies sources that have not yet provided updates in the current round. + + Returns: + set: A set of source identifiers that are expected to send updates but have not yet been received. + """ + raise NotImplementedError + + @abstractmethod + async def notify_if_all_updates_received(self): + """ + Notifies the system when all expected updates for the current round have been received. + """ + raise NotImplementedError + + @abstractmethod + async def stop_notifying_updates(self): + """ + Stops notifications related to update reception. + + This method can be used to reset any notification mechanisms or stop tracking updates + if the aggregation process is halted. 
+ """ + raise NotImplementedError + +def factory_update_handler(updt_handler, aggregator, addr) -> UpdateHandler: + from nebula.core.aggregation.updatehandlers.dflupdatehandler import DFLUpdateHandler + + UPDATE_HANDLERS = { + "dfl": DFLUpdateHandler + } + + update_handler = UPDATE_HANDLERS.get(updt_handler, None) + + if update_handler: + return update_handler(aggregator, addr) + else: + raise UpdateHandlerException(f"Update Handler {updt_handler} not found") + + \ No newline at end of file From ceb516f97cbbaa0c4c0c8e96f8173134c5906cc4 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 17 Feb 2025 13:28:10 +0100 Subject: [PATCH 094/233] feature dfl no rounds --- nebula/core/aggregation/aggregator.py | 60 ++++--- .../updatehandlers/dflupdatehandler.py | 96 +++++------ .../updatehandlers/updatehandler.py | 3 +- nebula/core/engine.py | 51 ++++-- nebula/core/network/communications.py | 159 +++++++++--------- nebula/core/network/messages.py | 2 +- .../README.txt | 0 .../__init__.py | 0 .../connectionoptimizer.py | 0 .../networkoptimization/networkoptimizer.py | 0 .../networkoptimization/timergenerator.py | 0 .../topologymanagement/awareness/samodule.py | 2 + .../candidateselection/__init__.py | 0 .../candidateselection/candidateselector.py | 8 +- .../candidateselection/fccandidateselector.py | 2 +- .../hetcandidateselector.py | 2 +- .../ringcandidateselector.py | 2 +- .../stdcandidateselector.py | 2 +- .../fastreboot.py | 2 +- .../modelhandlers/__init__.py | 0 .../modelhandlers/aggmodelhandler.py | 2 +- .../modelhandlers/defaultmodelhandler.py | 4 +- .../modelhandlers/modelhandler.py | 6 +- .../modelhandlers/stdmodelhandler.py | 2 +- .../momentum.py | 2 +- .../neighborpolicies/__init__.py | 0 .../neighborpolicies/fcneighborpolicy.py | 2 +- .../neighborpolicies/idleneighborpolicy.py | 2 +- .../neighborpolicies/neighborpolicy.py | 8 +- .../neighborpolicies/ringneighborpolicy.py | 2 +- .../neighborpolicies/starneighborpolicy.py | 2 +- .../nodemanager.py | 14 +- 32 
files changed, 236 insertions(+), 201 deletions(-) rename nebula/core/{neighbormanagement => topologymanagement}/README.txt (100%) rename nebula/core/{neighbormanagement => topologymanagement}/__init__.py (100%) rename nebula/core/{network => topologymanagement/awareness}/networkoptimization/connectionoptimizer.py (100%) rename nebula/core/{network => topologymanagement/awareness}/networkoptimization/networkoptimizer.py (100%) rename nebula/core/{network => topologymanagement/awareness}/networkoptimization/timergenerator.py (100%) create mode 100644 nebula/core/topologymanagement/awareness/samodule.py rename nebula/core/{neighbormanagement => topologymanagement}/candidateselection/__init__.py (100%) rename nebula/core/{neighbormanagement => topologymanagement}/candidateselection/candidateselector.py (78%) rename nebula/core/{neighbormanagement => topologymanagement}/candidateselection/fccandidateselector.py (95%) rename nebula/core/{neighbormanagement => topologymanagement}/candidateselection/hetcandidateselector.py (98%) rename nebula/core/{neighbormanagement => topologymanagement}/candidateselection/ringcandidateselector.py (94%) rename nebula/core/{neighbormanagement => topologymanagement}/candidateselection/stdcandidateselector.py (94%) rename nebula/core/{neighbormanagement => topologymanagement}/fastreboot.py (99%) rename nebula/core/{neighbormanagement => topologymanagement}/modelhandlers/__init__.py (100%) rename nebula/core/{neighbormanagement => topologymanagement}/modelhandlers/aggmodelhandler.py (95%) rename nebula/core/{neighbormanagement => topologymanagement}/modelhandlers/defaultmodelhandler.py (91%) rename nebula/core/{neighbormanagement => topologymanagement}/modelhandlers/modelhandler.py (79%) rename nebula/core/{neighbormanagement => topologymanagement}/modelhandlers/stdmodelhandler.py (95%) rename nebula/core/{neighbormanagement => topologymanagement}/momentum.py (99%) rename nebula/core/{neighbormanagement => 
topologymanagement}/neighborpolicies/__init__.py (100%) rename nebula/core/{neighbormanagement => topologymanagement}/neighborpolicies/fcneighborpolicy.py (98%) rename nebula/core/{neighbormanagement => topologymanagement}/neighborpolicies/idleneighborpolicy.py (98%) rename nebula/core/{neighbormanagement => topologymanagement}/neighborpolicies/neighborpolicy.py (82%) rename nebula/core/{neighbormanagement => topologymanagement}/neighborpolicies/ringneighborpolicy.py (98%) rename nebula/core/{neighbormanagement => topologymanagement}/neighborpolicies/starneighborpolicy.py (97%) rename nebula/core/{neighbormanagement => topologymanagement}/nodemanager.py (97%) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 7294c98b1..acacf27f4 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -85,8 +85,8 @@ async def update_federation_nodes(self, federation_nodes: set): else: raise Exception("It is not possible to set nodes to aggregate when the aggregation is running.") - async def update_received_from_source(self, model, weight, source, round): - await self.us.storage_update(model, weight, source, round) + async def update_received_from_source(self, model, weight, source, round, local=False): + await self.us.storage_update(model, weight, source, round, local=False) async def notify_federation_nodes_removed(self, federation_node, remove=False): # Neighbor has been removed @@ -190,7 +190,6 @@ async def _add_pending_model(self, model, weight, source): return self.get_nodes_pending_models_to_aggregate() async def include_model_in_buffer(self, model, weight, source=None, round=None, local=False): - await self.us.storage_update(model, weight, source, round) await self._add_model_lock.acquire_async() logging.info( f"πŸ”„ include_model_in_buffer | source={source} | round={round} | weight={weight} |--| __models={self._pending_models_to_aggregate.keys()} | 
federation_nodes={self._federation_nodes} | pending_models_to_aggregate={self.get_nodes_pending_models_to_aggregate()}" @@ -236,7 +235,7 @@ async def get_aggregation(self): ) lock_acquired = lock_task in done if skip_task in done: - logging.info("Skipping aggregation wait due to detected desynchronization") + logging.info("Skipping aggregation timeout, updates received before grace time") self._aggregation_waiting_skip.clear() if not lock_acquired: lock_task.cancel() @@ -257,30 +256,47 @@ async def get_aggregation(self): await self.us.stop_notifying_updates() updates = await self.us.get_round_updates() - - if self._waiting_global_update and len(self._pending_models_to_aggregate) == 1: - logging.info( - "πŸ”„ get_aggregation | Received an global model. Overwriting my model with the aggregated model." - ) - aggregated_model = next(iter(self._pending_models_to_aggregate.values()))[0] - self._pending_models_to_aggregate.clear() - return aggregated_model - - unique_nodes_involved = set(node for key in self._pending_models_to_aggregate for node in key.split()) - - if len(unique_nodes_involved) != len(self._federation_nodes): - missing_nodes = self._federation_nodes - unique_nodes_involved + + missing_nodes = await self.us.get_round_missing_nodes() + + if missing_nodes: logging.info(f"πŸ”„ get_aggregation | Aggregation incomplete, missing models from: {missing_nodes}") else: logging.info("πŸ”„ get_aggregation | All models accounted for, proceeding with aggregation.") - - self._pending_models_to_aggregate = await self.engine.apply_weight_strategy(self._pending_models_to_aggregate) - aggregated_result = self.run_aggregation(self._pending_models_to_aggregate) - self._pending_models_to_aggregate.clear() + + logging.info( + f"πŸ”„ Broadcasting MODELS_INCLUDED for round {self.engine.get_round()}" + ) + message = self.cm.create_message( + "federation", "federation_models_included", [str(arg) for arg in [self.engine.get_round()]] + ) + await 
self.cm.send_message_to_neighbors(message) + + # if self._waiting_global_update and len(self._pending_models_to_aggregate) == 1: + # logging.info( + # "πŸ”„ get_aggregation | Received an global model. Overwriting my model with the aggregated model." + # ) + # aggregated_model = next(iter(self._pending_models_to_aggregate.values()))[0] + # self._pending_models_to_aggregate.clear() + # return aggregated_model + + # unique_nodes_involved = set(node for key in self._pending_models_to_aggregate for node in key.split()) + + # if len(unique_nodes_involved) != len(self._federation_nodes): + # missing_nodes = self._federation_nodes - unique_nodes_involved + # logging.info(f"πŸ”„ get_aggregation | Aggregation incomplete, missing models from: {missing_nodes}") + # else: + # logging.info("πŸ”„ get_aggregation | All models accounted for, proceeding with aggregation.") + + # self._pending_models_to_aggregate = await self.engine.apply_weight_strategy(self._pending_models_to_aggregate) + # aggregated_result = self.run_aggregation(self._pending_models_to_aggregate) + # self._pending_models_to_aggregate.clear() + + updates = await self.engine.apply_weight_strategy(updates) + aggregated_result = self.run_aggregation(updates) return aggregated_result async def include_next_model_in_buffer(self, model, weight, source=None, round=None): - await self.us.storage_update(model, weight, source, round) logging.info(f"πŸ”„ include_next_model_in_buffer | source={source} | round={round} | weight={weight}") if round not in self._future_models_to_aggregate: self._future_models_to_aggregate[round] = [] diff --git a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py index 3141f1008..4cc9d5409 100644 --- a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py +++ b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py @@ -40,6 +40,10 @@ def __init__( self._sources_received = set() self._round_updates_lock = 
Locker(name="round_updates_lock", async_lock=True) # se coge cuando se empieza a comprobar si estan todas las updates self._update_federation_lock = Locker(name="update_federation_lock", async_lock=True) + self._notification_sent_lock = Locker(name="notification_sent_lock", async_lock=True) + self._notification = False + self._missing_ones = set() + self._nodes_using_historic = set() @property def us(self): @@ -50,15 +54,6 @@ def agg(self): return self._aggregator async def round_expected_updates(self, federation_nodes: set): - """ - Initializes the expected updates for the current round. - - This method updates the list of expected sources (`_sources_expected`) for the current training round - and ensures that their respective update storage is initialized if they were not previously registered. - - Args: - federation_nodes (set): A set of node identifiers expected to provide updates in the current round. - """ await self._update_federation_lock.acquire_async() await self._updates_storage_lock.acquire_async() self._sources_expected = federation_nodes.copy() @@ -83,6 +78,8 @@ async def round_expected_updates(self, federation_nodes: set): # Lock to check if all updates received if self._round_updates_lock.locked(): self._round_updates_lock.release_async() + + self._notification = False async def _check_updates_already_received(self): for se in self._sources_expected: @@ -92,24 +89,7 @@ async def _check_updates_already_received(self): logging.info(f"Update already received from source: {se} | ({len(self._sources_received)}/{len(self._sources_expected)}) Updates received") self._sources_received.add(se) - async def storage_update(self, model, weight, source, round): - """ - Stores an update in the update queue if it has not been previously received for the same source and round. - - This method ensures that only one update per source and round is stored, avoiding duplicates. 
- If all expected updates for the current round have been received, it triggers the `all_updates_received()` method. - - Args: - model: The model associated with the update. - weight: The weight assigned to the update. - source (str): The source identifier of the update. - round (int): The training round in which the update was received. - - Logs: - - Stores and logs the update if it's new for the round. - - Logs a duplicate update if an identical one already exists. - - Logs missing sources if not all expected updates have been received. - """ + async def storage_update(self, model, weight, source, round, local=False): time_received = time.time() if source in self._sources_expected: updt = Update(model, weight, source, round, time_received) @@ -126,30 +106,19 @@ async def storage_update(self, model, weight, source, round): updates_left = self._sources_expected.difference(self._sources_received) logging.info(f"Updates received ({len(self._sources_received)}/{len(self._sources_expected)}) | Missing nodes: {updates_left}") if self._round_updates_lock.locked() and not updates_left: - await self._all_updates_received() + all_rec = await self._all_updates_received() + if all_rec: + await self._notify() await self._updates_storage_lock.release_async() else: - logging.info(f"Discard update | source: {source} not in expected updates for this Round") + if not source in self._sources_received: + logging.info(f"Discard update | source: {source} not in expected updates for this Round") async def get_round_updates(self): - """ - Retrieves the latest updates received in the current round. - - This method collects updates from all received sources, prioritizing the most recent update - stored in the queue. If an expected update is missing, it attempts to use the last received - update from that source instead. - - Returns: - dict: A dictionary mapping each source to a tuple `(model, weight)`, containing - the most recent update for that source. 
- - Logs: - - Logs missing sources if expected updates have not been received. - - Logs when a missing update is replaced by the last received update. - """ await self._updates_storage_lock.acquire_async() updates_missing = self._sources_expected.difference(self._sources_received) if updates_missing: + self._missing_ones = updates_missing logging.info(f"Missing updates from sources: {updates_missing}") updates = {} for sr in self._sources_received: @@ -157,7 +126,7 @@ async def get_round_updates(self): last_updt_received = self.us[sr][0] updt: Update = None updt = source_historic[-1] # Get last update received - if last_updt_received == updt: + if last_updt_received and last_updt_received == updt: logging.info(f"Missing update source: {sr}, using last update received..") else: last_updt_received = updt @@ -190,24 +159,45 @@ async def _update_source(self, source, remove=False): await self._updates_storage_lock.release_async() async def get_round_missing_nodes(self): - await self._updates_storage_lock.acquire_async() - updates_left = self._sources_expected.difference(self._sources_received) - await self._updates_storage_lock.release_async() - return updates_left + # await self._updates_storage_lock.acquire_async() + # updates_left = self._sources_expected.difference(self._sources_received) + # await self._updates_storage_lock.release_async() + return self._missing_ones async def notify_if_all_updates_received(self): logging.info("Set notification when all expected updates received") await self._round_updates_lock.acquire_async() - + await self._updates_storage_lock.acquire_async() + all_received = await self._all_updates_received() + await self._updates_storage_lock.release_async() + if all_received: + await self._notify() + + async def stop_notifying_updates(self): - logging.info("Stop notifications updates") if self._round_updates_lock.locked(): + logging.info("Stop notification updates") await self._round_updates_lock.release_async() + + async def _notify(self): + 
await self._notification_sent_lock.acquire_async() + if self._notification: + await self._notification_sent_lock.release_async() + return + self._notification = True + await self.stop_notifying_updates() + await self._notification_sent_lock.release_async() + logging.info("πŸ”„ Notifying aggregator to release aggregation") + await self.agg.notify_all_updates_received() + async def _all_updates_received(self): updates_left = self._sources_expected.difference(self._sources_received) + all_received = False if len(updates_left) == 0: - logging.info("All updates have been received this round | releasing aggregation lock") + logging.info("All updates have been received this round") await self._round_updates_lock.release_async() - #await self.agg.notify_all_updates_received() + all_received = True + return all_received + \ No newline at end of file diff --git a/nebula/core/aggregation/updatehandlers/updatehandler.py b/nebula/core/aggregation/updatehandlers/updatehandler.py index fbbff98fe..f74f6862a 100644 --- a/nebula/core/aggregation/updatehandlers/updatehandler.py +++ b/nebula/core/aggregation/updatehandlers/updatehandler.py @@ -26,7 +26,7 @@ async def round_expected_updates(self, federation_nodes: set): raise NotImplementedError @abstractmethod - async def storage_update(self, model, weight, source, round): + async def storage_update(self, model, weight, source, round, local=False): """ Stores an update from a source in the update storage. @@ -38,6 +38,7 @@ async def storage_update(self, model, weight, source, round): weight: The weight assigned to the update (e.g., based on the amount of data used in training). source (str): The identifier of the node sending the update. round (int): The current device local training round when the update was done. 
+ local (boolean): Local update """ raise NotImplementedError diff --git a/nebula/core/engine.py b/nebula/core/engine.py index c01f0f0ba..3a24455d6 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -9,7 +9,7 @@ from nebula.addons.reporter import Reporter from nebula.core.aggregation.aggregator import create_aggregator, create_target_aggregator from nebula.core.eventmanager import EventManager -from nebula.core.neighbormanagement.nodemanager import NodeManager +from nebula.core.topologymanagement.nodemanager import NodeManager from nebula.core.network.communications import CommunicationsManager from nebula.core.utils.locker import Locker @@ -255,6 +255,7 @@ def set_round(self, new_round): """ async def model_initialization_callback(self, source, message): + logging.info(f"πŸ€– handle_model_message | Received model initialization from {source}") try: model = self.trainer.deserialize_model(message.parameters) self.trainer.set_model_parameters(model, initialize=True) @@ -273,12 +274,12 @@ async def model_initialization_callback(self, source, message): pass async def model_update_callback(self, source, message): - #TODO gestionar situaciones aqui - logging.info(f"πŸ€– handle_model_message | Received model from {source} with round {message.round}") + logging.info(f"πŸ€– handle_model_message | Received model update from {source} with round {message.round}") if not self.get_federation_ready_lock().locked() and len(self.get_federation_nodes()) == 0: logging.info("πŸ€– handle_model_message | There are no defined federation nodes") return - await self.aggregator.update_received_from_source(message.parameters, message.weight, source, message.round) + decoded_model = self.trainer.deserialize_model(message.parameters) + await self.aggregator.update_received_from_source(decoded_model, message.weight, source, message.round) """ ############################## @@ -405,12 +406,10 @@ async def _connection_late_connect_callback(self, source, message): ct_actions, 
df_actions = self.nm.get_actions() if len(ct_actions): - # for addr in ct_actions.split(): cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) await self.cm.send_message(source, cnt_msg) if len(df_actions): - # for addr in df_actions.split(): df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) await self.cm.send_message(source, df_msg) @@ -432,19 +431,16 @@ async def _connection_restructure_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection accepted from {source}") await self.cm.connect(source, direct=True) - # conf_msg = self.cm.mm.generate_connection_message(nebula_pb2.ConnectionMessage.Action.RESTRUCTURE) conf_msg = self.cm.create_message("connection", "restructure") await self.cm.send_message(source, conf_msg) ct_actions, df_actions = self.nm.get_actions() if len(ct_actions): - # cnt_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.CONNECT_TO, ct_actions) cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) await self.cm.send_message(source, cnt_msg) if len(df_actions): - # df_msg = self.cm.mm.generate_link_message(nebula_pb2.LinkMessage.Action.DISCONNECT_FROM, df_actions) df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) await self.cm.send_message(source, df_msg) @@ -620,8 +616,8 @@ async def _start_learning_late(self): await self.learning_cycle_lock.acquire_async() try: model_serialized, rounds, round, _epochs = await self.nm.get_trainning_info() - self.total_rounds = rounds # self.config.participant["scenario_args"]["rounds"] #rounds - epochs = _epochs # self.config.participant["training_args"]["epochs"] #_epochs + self.total_rounds = rounds + epochs = _epochs await self.get_round_lock().acquire_async() self.round = round await self.get_round_lock().release_async() @@ -803,8 +799,8 @@ async def _dynamic_aggregator(self, aggregated_models_weights, malicious_nodes): async def 
_waiting_model_updates(self): logging.info(f"πŸ’€ Waiting convergence in round {self.round}.") - if self.mobility: - await self.aggregator.aggregation_push_available() #TODO + # if self.mobility: + # await self.aggregator.aggregation_push_available() #TODO params = await self.aggregator.get_aggregation() if params is not None: logging.info( @@ -1003,7 +999,14 @@ async def _extended_learning_cycle(self): await self.trainer.train() await self.trainning_in_progress_lock.release_async() - await self.aggregator.include_model_in_buffer( + # await self.aggregator.include_model_in_buffer( + # self.trainer.get_model_parameters(), + # self.trainer.get_model_weight(), + # source=self.addr, + # round=self.round, + # ) + + await self.aggregator.update_received_from_source( self.trainer.get_model_parameters(), self.trainer.get_model_weight(), source=self.addr, @@ -1036,12 +1039,20 @@ async def _extended_learning_cycle(self): await self.trainer.test() # In the first round, the server node doest take into account the initial model parameters for the aggregation - await self.aggregator.include_model_in_buffer( + # await self.aggregator.include_model_in_buffer( + # self.trainer.get_model_parameters(), + # self.trainer.BYPASS_MODEL_WEIGHT, + # source=self.addr, + # round=self.round, + # ) + + await self.aggregator.update_received_from_source( self.trainer.get_model_parameters(), self.trainer.BYPASS_MODEL_WEIGHT, source=self.addr, round=self.round, ) + await self._waiting_model_updates() await self.cm.propagator.propagate("stable") @@ -1071,7 +1082,15 @@ async def _extended_learning_cycle(self): await self.trainer.test() await self.trainer.train() - await self.aggregator.include_model_in_buffer( + # await self.aggregator.include_model_in_buffer( + # self.trainer.get_model_parameters(), + # self.trainer.get_model_weight(), + # source=self.addr, + # round=self.round, + # local=True, + # ) + + await self.aggregator.update_received_from_source( self.trainer.get_model_parameters(), 
self.trainer.get_model_weight(), source=self.addr, diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index d01e5474a..08ebdc28f 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -163,85 +163,89 @@ async def handle_message(self, message_event): await self.engine.trigger_event(message_event) async def handle_model_message(self, source, message): - #TODO modificar para generar eventos y gestionar en engine logging.info(f"πŸ€– handle_model_message | Received model from {source} with round {message.round}") - if self.get_round() is not None: - await self.engine.get_round_lock().acquire_async() - current_round = self.get_round() - await self.engine.get_round_lock().release_async() - - if message.round != current_round and message.round != -1: - logging.info( - f"❗️ handle_model_message | Received a model from a different round | Model round: {message.round} | Current round: {current_round}" - ) - if message.round > current_round: - logging.info( - f"πŸ€– handle_model_message | Saving model from {source} for future round {message.round}" - ) - logging.info("### ENTRO 1 ###") - await self.engine.aggregator.include_next_model_in_buffer( - message.parameters, - message.weight, - source=source, - round=message.round, - ) - else: - logging.info(f"❗️ handle_model_message | Ignoring model from {source} from a previous round") - return - if not self.engine.get_federation_ready_lock().locked() and len(self.engine.get_federation_nodes()) == 0: - logging.info("πŸ€– handle_model_message | There are no defined federation nodes") - return - try: - # get_federation_ready_lock() is locked when the model is being initialized (first round) - # non-starting nodes receive the initialized model from the starting node - if not self.engine.get_federation_ready_lock().locked() or self.engine.get_initialization_status(): - decoded_model = self.engine.trainer.deserialize_model(message.parameters) - if False 
and self.config.participant["adaptive_args"]["model_similarity"]: - pass - - logging.info("### ENTRO 2 ###") - await self.engine.aggregator.include_model_in_buffer( - decoded_model, - message.weight, - source=source, - round=message.round, - ) - - else: - if message.round != -1: - # Be sure that the model message is from the initialization round (round = -1) - logging.info( - f"πŸ€– handle_model_message | Saving model from {source} for future round {message.round}" - ) - logging.info("### ENTRO 3 ###") - await self.engine.aggregator.include_next_model_in_buffer( - message.parameters, - message.weight, - source=source, - round=message.round, - ) - return - logging.info(f"πŸ€– handle_model_message | Initializing model (executed by {source})") - model_init_event = MessageEvent(("model","initialization"), source, message) - await self.engine.trigger_event(model_init_event) - - - except Exception as e: - logging.exception(f"πŸ€– handle_model_message | Unknown error adding model: {e}") - logging.exception(traceback.format_exc()) - + if message.round == -1: + model_init_event = MessageEvent(("model","initialization"), source, message) + await self.engine.trigger_event(model_init_event) else: - logging.info("πŸ€– handle_model_message | Tried to add a model while learning is not running") - if message.round != -1: - # Be sure that the model message is from the initialization round (round = -1) - logging.info(f"πŸ€– handle_model_message | Saving model from {source} for future round {message.round}") - await self.engine.aggregator.include_next_model_in_buffer( - message.parameters, - message.weight, - source=source, - round=message.round, - ) - return + model_updt_event = MessageEvent(("model","update"), source, message) + await self.engine.trigger_event(model_updt_event) + + # if self.get_round() is not None: + # await self.engine.get_round_lock().acquire_async() + # current_round = self.get_round() + # await self.engine.get_round_lock().release_async() + + # if 
message.round != current_round and message.round != -1: + # logging.info( + # f"❗️ handle_model_message | Received a model from a different round | Model round: {message.round} | Current round: {current_round}" + # ) + # if message.round > current_round: + # logging.info( + # f"πŸ€– handle_model_message | Saving model from {source} for future round {message.round}" + # ) + # logging.info("### ENTRO 1 ###") + # await self.engine.aggregator.include_next_model_in_buffer( + # message.parameters, + # message.weight, + # source=source, + # round=message.round, + # ) + # else: + # logging.info(f"❗️ handle_model_message | Ignoring model from {source} from a previous round") + # return + # if not self.engine.get_federation_ready_lock().locked() and len(self.engine.get_federation_nodes()) == 0: + # logging.info("πŸ€– handle_model_message | There are no defined federation nodes") + # return + # try: + # # get_federation_ready_lock() is locked when the model is being initialized (first round) + # # non-starting nodes receive the initialized model from the starting node + # if not self.engine.get_federation_ready_lock().locked() or self.engine.get_initialization_status(): + # decoded_model = self.engine.trainer.deserialize_model(message.parameters) + # if False and self.config.participant["adaptive_args"]["model_similarity"]: + # pass + + # logging.info("### ENTRO 2 ###") + # await self.engine.aggregator.include_model_in_buffer( + # decoded_model, + # message.weight, + # source=source, + # round=message.round, + # ) + + # else: + # if message.round != -1: + # # Be sure that the model message is from the initialization round (round = -1) + # logging.info( + # f"πŸ€– handle_model_message | Saving model from {source} for future round {message.round}" + # ) + # logging.info("### ENTRO 3 ###") + # await self.engine.aggregator.include_next_model_in_buffer( + # message.parameters, + # message.weight, + # source=source, + # round=message.round, + # ) + # return + + + # except Exception 
as e: + # logging.exception(f"πŸ€– handle_model_message | Unknown error adding model: {e}") + # logging.exception(traceback.format_exc()) + + # else: + # logging.info("πŸ€– handle_model_message | Tried to add a model while learning is not running") + # if message.round != -1: + # # Be sure that the model message is from the initialization round (round = -1) + # logging.info("### ENTRO 4 ###") + # logging.info(f"πŸ€– handle_model_message | Saving model from {source} for future round {message.round}") + # await self.engine.aggregator.include_next_model_in_buffer( + # message.parameters, + # message.weight, + # source=source, + # round=message.round, + # ) + # return def create_message(self, message_type: str, action: str = "", *args, **kwargs): return self.mm.create_message(message_type, action, *args, **kwargs) @@ -289,6 +293,7 @@ def init_external_connection_service(self): async def is_external_connection_service_running(self): return self.ecs.is_running() + #TODO comprobar que el verify_connections no cree un bucle de espera infinito async def stablish_connection_to_federation(self, msg_type="discover_join", addrs_known=None): """ Using ExternalConnectionService to get addrs on local network, after that diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index b435264e0..ff75cee32 100644 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -181,7 +181,7 @@ def create_message(self, message_type: str, action: str = "", *args, **kwargs): # logging.info(f"kwargs parameters: {kwargs.keys()}") for param_name in template_params: if param_name not in kwargs: - logging.info(f"Filling parameter '{param_name}' with default value: {default_values.get(param_name)}") + #logging.info(f"Filling parameter '{param_name}' with default value: {default_values.get(param_name)}") kwargs[param_name] = default_values.get(param_name) # Create an instance of the protobuf message class using the constructed kwargs diff --git 
a/nebula/core/neighbormanagement/README.txt b/nebula/core/topologymanagement/README.txt similarity index 100% rename from nebula/core/neighbormanagement/README.txt rename to nebula/core/topologymanagement/README.txt diff --git a/nebula/core/neighbormanagement/__init__.py b/nebula/core/topologymanagement/__init__.py similarity index 100% rename from nebula/core/neighbormanagement/__init__.py rename to nebula/core/topologymanagement/__init__.py diff --git a/nebula/core/network/networkoptimization/connectionoptimizer.py b/nebula/core/topologymanagement/awareness/networkoptimization/connectionoptimizer.py similarity index 100% rename from nebula/core/network/networkoptimization/connectionoptimizer.py rename to nebula/core/topologymanagement/awareness/networkoptimization/connectionoptimizer.py diff --git a/nebula/core/network/networkoptimization/networkoptimizer.py b/nebula/core/topologymanagement/awareness/networkoptimization/networkoptimizer.py similarity index 100% rename from nebula/core/network/networkoptimization/networkoptimizer.py rename to nebula/core/topologymanagement/awareness/networkoptimization/networkoptimizer.py diff --git a/nebula/core/network/networkoptimization/timergenerator.py b/nebula/core/topologymanagement/awareness/networkoptimization/timergenerator.py similarity index 100% rename from nebula/core/network/networkoptimization/timergenerator.py rename to nebula/core/topologymanagement/awareness/networkoptimization/timergenerator.py diff --git a/nebula/core/topologymanagement/awareness/samodule.py b/nebula/core/topologymanagement/awareness/samodule.py new file mode 100644 index 000000000..e80542215 --- /dev/null +++ b/nebula/core/topologymanagement/awareness/samodule.py @@ -0,0 +1,2 @@ +import asyncio +import logging \ No newline at end of file diff --git a/nebula/core/neighbormanagement/candidateselection/__init__.py b/nebula/core/topologymanagement/candidateselection/__init__.py similarity index 100% rename from 
nebula/core/neighbormanagement/candidateselection/__init__.py rename to nebula/core/topologymanagement/candidateselection/__init__.py diff --git a/nebula/core/neighbormanagement/candidateselection/candidateselector.py b/nebula/core/topologymanagement/candidateselection/candidateselector.py similarity index 78% rename from nebula/core/neighbormanagement/candidateselection/candidateselector.py rename to nebula/core/topologymanagement/candidateselection/candidateselector.py index 16546c89a..30a5b09e8 100644 --- a/nebula/core/neighbormanagement/candidateselection/candidateselector.py +++ b/nebula/core/topologymanagement/candidateselection/candidateselector.py @@ -24,10 +24,10 @@ def any_candidate(self): pass def factory_CandidateSelector(topology) -> CandidateSelector: - from nebula.core.neighbormanagement.candidateselection.stdcandidateselector import STDandidateSelector - from nebula.core.neighbormanagement.candidateselection.fccandidateselector import FCCandidateSelector - from nebula.core.neighbormanagement.candidateselection.hetcandidateselector import HETCandidateSelector - from nebula.core.neighbormanagement.candidateselection.ringcandidateselector import RINGCandidateSelector + from nebula.core.topologymanagement.candidateselection.stdcandidateselector import STDandidateSelector + from nebula.core.topologymanagement.candidateselection.fccandidateselector import FCCandidateSelector + from nebula.core.topologymanagement.candidateselection.hetcandidateselector import HETCandidateSelector + from nebula.core.topologymanagement.candidateselection.ringcandidateselector import RINGCandidateSelector options = { "ring": RINGCandidateSelector, diff --git a/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py b/nebula/core/topologymanagement/candidateselection/fccandidateselector.py similarity index 95% rename from nebula/core/neighbormanagement/candidateselection/fccandidateselector.py rename to 
nebula/core/topologymanagement/candidateselection/fccandidateselector.py index 6f1c71129..e487737da 100644 --- a/nebula/core/neighbormanagement/candidateselection/fccandidateselector.py +++ b/nebula/core/topologymanagement/candidateselection/fccandidateselector.py @@ -1,4 +1,4 @@ -from nebula.core.neighbormanagement.candidateselection.candidateselector import CandidateSelector +from nebula.core.topologymanagement.candidateselection.candidateselector import CandidateSelector from nebula.core.utils.locker import Locker class FCCandidateSelector(CandidateSelector): diff --git a/nebula/core/neighbormanagement/candidateselection/hetcandidateselector.py b/nebula/core/topologymanagement/candidateselection/hetcandidateselector.py similarity index 98% rename from nebula/core/neighbormanagement/candidateselection/hetcandidateselector.py rename to nebula/core/topologymanagement/candidateselection/hetcandidateselector.py index 20ec939a1..345f63c5b 100644 --- a/nebula/core/neighbormanagement/candidateselection/hetcandidateselector.py +++ b/nebula/core/topologymanagement/candidateselection/hetcandidateselector.py @@ -1,4 +1,4 @@ -from nebula.core.neighbormanagement.candidateselection.candidateselector import CandidateSelector +from nebula.core.topologymanagement.candidateselection.candidateselector import CandidateSelector from nebula.core.utils.locker import Locker class HETCandidateSelector(CandidateSelector): diff --git a/nebula/core/neighbormanagement/candidateselection/ringcandidateselector.py b/nebula/core/topologymanagement/candidateselection/ringcandidateselector.py similarity index 94% rename from nebula/core/neighbormanagement/candidateselection/ringcandidateselector.py rename to nebula/core/topologymanagement/candidateselection/ringcandidateselector.py index 7171edf67..47990e88d 100644 --- a/nebula/core/neighbormanagement/candidateselection/ringcandidateselector.py +++ b/nebula/core/topologymanagement/candidateselection/ringcandidateselector.py @@ -1,4 +1,4 @@ -from 
nebula.core.neighbormanagement.candidateselection.candidateselector import CandidateSelector +from nebula.core.topologymanagement.candidateselection.candidateselector import CandidateSelector from nebula.core.utils.locker import Locker class RINGCandidateSelector(CandidateSelector): diff --git a/nebula/core/neighbormanagement/candidateselection/stdcandidateselector.py b/nebula/core/topologymanagement/candidateselection/stdcandidateselector.py similarity index 94% rename from nebula/core/neighbormanagement/candidateselection/stdcandidateselector.py rename to nebula/core/topologymanagement/candidateselection/stdcandidateselector.py index 7418e8dbb..ddbdbcf57 100644 --- a/nebula/core/neighbormanagement/candidateselection/stdcandidateselector.py +++ b/nebula/core/topologymanagement/candidateselection/stdcandidateselector.py @@ -1,4 +1,4 @@ -from nebula.core.neighbormanagement.candidateselection.candidateselector import CandidateSelector +from nebula.core.topologymanagement.candidateselection.candidateselector import CandidateSelector from nebula.core.utils.locker import Locker class STDandidateSelector(CandidateSelector): diff --git a/nebula/core/neighbormanagement/fastreboot.py b/nebula/core/topologymanagement/fastreboot.py similarity index 99% rename from nebula/core/neighbormanagement/fastreboot.py rename to nebula/core/topologymanagement/fastreboot.py index ef312315a..f49919601 100644 --- a/nebula/core/neighbormanagement/fastreboot.py +++ b/nebula/core/topologymanagement/fastreboot.py @@ -4,7 +4,7 @@ from nebula.core.utils.locker import Locker if TYPE_CHECKING: - from nebula.core.neighbormanagement.nodemanager import NodeManager + from nebula.core.topologymanagement.nodemanager import NodeManager VANILLA_LEARNING_RATE = 1e-3 FR_LEARNING_RATE = 1e-3 diff --git a/nebula/core/neighbormanagement/modelhandlers/__init__.py b/nebula/core/topologymanagement/modelhandlers/__init__.py similarity index 100% rename from nebula/core/neighbormanagement/modelhandlers/__init__.py 
rename to nebula/core/topologymanagement/modelhandlers/__init__.py diff --git a/nebula/core/neighbormanagement/modelhandlers/aggmodelhandler.py b/nebula/core/topologymanagement/modelhandlers/aggmodelhandler.py similarity index 95% rename from nebula/core/neighbormanagement/modelhandlers/aggmodelhandler.py rename to nebula/core/topologymanagement/modelhandlers/aggmodelhandler.py index 407020bd9..3f0f6331a 100644 --- a/nebula/core/neighbormanagement/modelhandlers/aggmodelhandler.py +++ b/nebula/core/topologymanagement/modelhandlers/aggmodelhandler.py @@ -1,4 +1,4 @@ -from nebula.core.neighbormanagement.modelhandlers.modelhandler import ModelHandler +from nebula.core.topologymanagement.modelhandlers.modelhandler import ModelHandler from nebula.core.utils.locker import Locker class AGGModelHandler(ModelHandler): diff --git a/nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py b/nebula/core/topologymanagement/modelhandlers/defaultmodelhandler.py similarity index 91% rename from nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py rename to nebula/core/topologymanagement/modelhandlers/defaultmodelhandler.py index f8a90c62a..b17910751 100644 --- a/nebula/core/neighbormanagement/modelhandlers/defaultmodelhandler.py +++ b/nebula/core/topologymanagement/modelhandlers/defaultmodelhandler.py @@ -1,6 +1,6 @@ -from nebula.core.neighbormanagement.modelhandlers.modelhandler import ModelHandler +from nebula.core.topologymanagement.modelhandlers.modelhandler import ModelHandler from nebula.core.utils.locker import Locker -from nebula.core.neighbormanagement.nodemanager import NodeManager +from nebula.core.topologymanagement.nodemanager import NodeManager import logging class DefaultModelHandler(ModelHandler): diff --git a/nebula/core/neighbormanagement/modelhandlers/modelhandler.py b/nebula/core/topologymanagement/modelhandlers/modelhandler.py similarity index 79% rename from nebula/core/neighbormanagement/modelhandlers/modelhandler.py rename to 
nebula/core/topologymanagement/modelhandlers/modelhandler.py index 1d4af4e5a..242dbd69a 100644 --- a/nebula/core/neighbormanagement/modelhandlers/modelhandler.py +++ b/nebula/core/topologymanagement/modelhandlers/modelhandler.py @@ -20,9 +20,9 @@ def pre_process_model(self): pass def factory_ModelHandler(model_handler) -> ModelHandler: - from nebula.core.neighbormanagement.modelhandlers.stdmodelhandler import STDModelHandler - from nebula.core.neighbormanagement.modelhandlers.aggmodelhandler import AGGModelHandler - from nebula.core.neighbormanagement.modelhandlers.defaultmodelhandler import DefaultModelHandler + from nebula.core.topologymanagement.modelhandlers.stdmodelhandler import STDModelHandler + from nebula.core.topologymanagement.modelhandlers.aggmodelhandler import AGGModelHandler + from nebula.core.topologymanagement.modelhandlers.defaultmodelhandler import DefaultModelHandler options = { "std": STDModelHandler, diff --git a/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py b/nebula/core/topologymanagement/modelhandlers/stdmodelhandler.py similarity index 95% rename from nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py rename to nebula/core/topologymanagement/modelhandlers/stdmodelhandler.py index 861edd1fa..b93d43276 100644 --- a/nebula/core/neighbormanagement/modelhandlers/stdmodelhandler.py +++ b/nebula/core/topologymanagement/modelhandlers/stdmodelhandler.py @@ -1,4 +1,4 @@ -from nebula.core.neighbormanagement.modelhandlers.modelhandler import ModelHandler +from nebula.core.topologymanagement.modelhandlers.modelhandler import ModelHandler from nebula.core.utils.locker import Locker diff --git a/nebula/core/neighbormanagement/momentum.py b/nebula/core/topologymanagement/momentum.py similarity index 99% rename from nebula/core/neighbormanagement/momentum.py rename to nebula/core/topologymanagement/momentum.py index 7b1f8e199..6d68dd5a6 100644 --- a/nebula/core/neighbormanagement/momentum.py +++ 
b/nebula/core/topologymanagement/momentum.py @@ -9,7 +9,7 @@ from nebula.core.utils.locker import Locker if TYPE_CHECKING: - from nebula.core.neighbormanagement.nodemanager import NodeManager + from nebula.core.topologymanagement.nodemanager import NodeManager SimilarityMetricType = Callable[[OrderedDict, OrderedDict, bool], float | None] MappingSimilarityType = Callable[[float, float], Annotated[float, "Value in (0, 1]"]] diff --git a/nebula/core/neighbormanagement/neighborpolicies/__init__.py b/nebula/core/topologymanagement/neighborpolicies/__init__.py similarity index 100% rename from nebula/core/neighbormanagement/neighborpolicies/__init__.py rename to nebula/core/topologymanagement/neighborpolicies/__init__.py diff --git a/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py b/nebula/core/topologymanagement/neighborpolicies/fcneighborpolicy.py similarity index 98% rename from nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py rename to nebula/core/topologymanagement/neighborpolicies/fcneighborpolicy.py index 2ff6e148c..b7f88c1d2 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/fcneighborpolicy.py +++ b/nebula/core/topologymanagement/neighborpolicies/fcneighborpolicy.py @@ -1,4 +1,4 @@ -from nebula.core.neighbormanagement.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.topologymanagement.neighborpolicies.neighborpolicy import NeighborPolicy from nebula.core.utils.locker import Locker class FCNeighborPolicy(NeighborPolicy): diff --git a/nebula/core/neighbormanagement/neighborpolicies/idleneighborpolicy.py b/nebula/core/topologymanagement/neighborpolicies/idleneighborpolicy.py similarity index 98% rename from nebula/core/neighbormanagement/neighborpolicies/idleneighborpolicy.py rename to nebula/core/topologymanagement/neighborpolicies/idleneighborpolicy.py index 20a7fbbc2..81c98435d 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/idleneighborpolicy.py +++ 
b/nebula/core/topologymanagement/neighborpolicies/idleneighborpolicy.py @@ -1,4 +1,4 @@ -from nebula.core.neighbormanagement.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.topologymanagement.neighborpolicies.neighborpolicy import NeighborPolicy from nebula.core.utils.locker import Locker class IDLENeighborPolicy(NeighborPolicy): diff --git a/nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py b/nebula/core/topologymanagement/neighborpolicies/neighborpolicy.py similarity index 82% rename from nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py rename to nebula/core/topologymanagement/neighborpolicies/neighborpolicy.py index d97df6dae..7436ae12f 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/neighborpolicy.py +++ b/nebula/core/topologymanagement/neighborpolicies/neighborpolicy.py @@ -36,10 +36,10 @@ def update_neighbors(self, node, remove=False): pass def factory_NeighborPolicy(topology) -> NeighborPolicy: - from nebula.core.neighbormanagement.neighborpolicies.idleneighborpolicy import IDLENeighborPolicy - from nebula.core.neighbormanagement.neighborpolicies.fcneighborpolicy import FCNeighborPolicy - from nebula.core.neighbormanagement.neighborpolicies.ringneighborpolicy import RINGNeighborPolicy - from nebula.core.neighbormanagement.neighborpolicies.starneighborpolicy import STARNeighborPolicy + from nebula.core.topologymanagement.neighborpolicies.idleneighborpolicy import IDLENeighborPolicy + from nebula.core.topologymanagement.neighborpolicies.fcneighborpolicy import FCNeighborPolicy + from nebula.core.topologymanagement.neighborpolicies.ringneighborpolicy import RINGNeighborPolicy + from nebula.core.topologymanagement.neighborpolicies.starneighborpolicy import STARNeighborPolicy options = { "random": IDLENeighborPolicy, # default value diff --git a/nebula/core/neighbormanagement/neighborpolicies/ringneighborpolicy.py b/nebula/core/topologymanagement/neighborpolicies/ringneighborpolicy.py 
similarity index 98% rename from nebula/core/neighbormanagement/neighborpolicies/ringneighborpolicy.py rename to nebula/core/topologymanagement/neighborpolicies/ringneighborpolicy.py index c74d25919..90310d6fc 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/ringneighborpolicy.py +++ b/nebula/core/topologymanagement/neighborpolicies/ringneighborpolicy.py @@ -1,4 +1,4 @@ -from nebula.core.neighbormanagement.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.topologymanagement.neighborpolicies.neighborpolicy import NeighborPolicy from nebula.core.utils.locker import Locker import random diff --git a/nebula/core/neighbormanagement/neighborpolicies/starneighborpolicy.py b/nebula/core/topologymanagement/neighborpolicies/starneighborpolicy.py similarity index 97% rename from nebula/core/neighbormanagement/neighborpolicies/starneighborpolicy.py rename to nebula/core/topologymanagement/neighborpolicies/starneighborpolicy.py index c79e72ab4..71b3f69c6 100644 --- a/nebula/core/neighbormanagement/neighborpolicies/starneighborpolicy.py +++ b/nebula/core/topologymanagement/neighborpolicies/starneighborpolicy.py @@ -1,4 +1,4 @@ -from nebula.core.neighbormanagement.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.topologymanagement.neighborpolicies.neighborpolicy import NeighborPolicy from nebula.core.utils.locker import Locker class STARNeighborPolicy(NeighborPolicy): diff --git a/nebula/core/neighbormanagement/nodemanager.py b/nebula/core/topologymanagement/nodemanager.py similarity index 97% rename from nebula/core/neighbormanagement/nodemanager.py rename to nebula/core/topologymanagement/nodemanager.py index 177650870..915905a31 100644 --- a/nebula/core/neighbormanagement/nodemanager.py +++ b/nebula/core/topologymanagement/nodemanager.py @@ -3,11 +3,11 @@ from typing import TYPE_CHECKING from nebula.addons.functions import print_msg_box -from nebula.core.neighbormanagement.candidateselection.candidateselector import 
factory_CandidateSelector -from nebula.core.neighbormanagement.fastreboot import FastReboot -from nebula.core.neighbormanagement.modelhandlers.modelhandler import factory_ModelHandler -from nebula.core.neighbormanagement.momentum import Momentum -from nebula.core.neighbormanagement.neighborpolicies.neighborpolicy import factory_NeighborPolicy +from nebula.core.topologymanagement.candidateselection.candidateselector import factory_CandidateSelector +from nebula.core.topologymanagement.fastreboot import FastReboot +from nebula.core.topologymanagement.modelhandlers.modelhandler import factory_ModelHandler +from nebula.core.topologymanagement.momentum import Momentum +from nebula.core.topologymanagement.neighborpolicies.neighborpolicy import factory_NeighborPolicy from nebula.core.utils.locker import Locker if TYPE_CHECKING: @@ -400,7 +400,9 @@ async def check_robustness(self): possible_neighbors = self.neighbor_policy.get_nodes_known(neighbors_too=False) possible_neighbors = await self.engine.cm.apply_restrictions(possible_neighbors) if not possible_neighbors: - logging.info("All possible neighbors using nodes known are restricted...") + logging.info("All possible neighbors using nodes known are restricted...") + else: + pass #asyncio.create_task(self.upgrade_connection_robustness(possible_neighbors)) else: if not self.engine.get_sinchronized_status(): From 9d96587da434f7addc07eb7e8a684cfff4a767b6 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 17 Feb 2025 14:51:33 +0100 Subject: [PATCH 095/233] fix updates handling and ecs service --- nebula/core/aggregation/aggregator.py | 293 +------------- .../updatehandlers/cflupdatehandler.py | 363 ++++++++++++++++++ .../updatehandlers/dflupdatehandler.py | 14 +- .../updatehandlers/updatehandler.py | 4 +- nebula/core/engine.py | 4 +- nebula/core/network/communications.py | 118 ++---- nebula/core/topologymanagement/nodemanager.py | 28 +- 7 files changed, 421 insertions(+), 403 deletions(-) create mode 100644 
nebula/core/aggregation/updatehandlers/cflupdatehandler.py diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index acacf27f4..404ca4d77 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -51,7 +51,7 @@ def __init__(self, config=None, engine=None): self._aggregation_waiting_skip = asyncio.Event() self._push_strategy_lock = Locker(name="push_strategy_lock", async_lock=True) self._end_round_push = 0 - self._update_storage = factory_update_handler("dfl", self, self._addr) + self._update_storage = factory_update_handler("DFL", self, self._addr) #TODO use json config def __str__(self): return self.__class__.__name__ @@ -89,26 +89,6 @@ async def update_received_from_source(self, model, weight, source, round, local= await self.us.storage_update(model, weight, source, round, local=False) async def notify_federation_nodes_removed(self, federation_node, remove=False): - # Neighbor has been removed - #if len(self._federation_nodes) - len(federation_nodes) > 0: - # nodes_removed = self._federation_nodes.symmetric_difference(federation_nodes) - # logging.info(f"Nodes removed from aggregation: {nodes_removed}") - # pending_nodes = self._federation_nodes - self.get_nodes_pending_models_to_aggregate() - # # logging.info(f"Pending models to aggregate: {pending_nodes}") - # shouldnt_waited_model = [] - # shouldnt_waited_model = [source for source in nodes_removed if source in pending_nodes] - # logging.info(f"Waiting models from removed neighbors: {shouldnt_waited_model}") - # if shouldnt_waited_model: - # for swm in shouldnt_waited_model: - # logging.info(f"Removing model from waiting: {swm}") - # pending_nodes.discard(swm) - # if self._aggregation_done_lock.locked(): - # if not pending_nodes: - # logging.info("No model updates required left | releasing aggregation lock...") - # self._federation_nodes = federation_nodes - # await self._aggregation_done_lock.release_async() - - # 
self._federation_nodes = federation_nodes await self.us.notify_federation_update(federation_node, remove=remove) def set_waiting_global_update(self): @@ -124,104 +104,6 @@ async def reset(self): pass await self._add_model_lock.release_async() - def get_nodes_pending_models_to_aggregate(self): - return {node for key in self._pending_models_to_aggregate.keys() for node in key.split()} - - async def _handle_global_update(self, model, source): - logging.info(f"πŸ”„ _handle_global_update | source={source}") - logging.info( - f"πŸ”„ _handle_global_update | Received a model from {source}. Overwriting __models with the aggregated model." - ) - self._pending_models_to_aggregate.clear() - self._pending_models_to_aggregate = {source: (model, 1)} - self._waiting_global_update = False - await self._add_model_lock.release_async() - await self._aggregation_done_lock.release_async() - - async def _add_pending_model(self, model, weight, source): - if len(self._federation_nodes) <= len(self.get_nodes_pending_models_to_aggregate()): - logging.info("πŸ”„ _add_pending_model | Ignoring model...") - await self._add_model_lock.release_async() - return None - - if source not in self._federation_nodes: - logging.info(f"πŸ”„ _add_pending_model | Can't add a model from ({source}), which is not in the federation.") - await self._add_model_lock.release_async() - return None - - elif source not in self.get_nodes_pending_models_to_aggregate(): - logging.info( - "πŸ”„ _add_pending_model | Node is not in the aggregation buffer --> Include model in the aggregation buffer." 
- ) - self._pending_models_to_aggregate.update({source: (model, weight)}) - - logging.info( - f"πŸ”„ _add_pending_model | Model added in aggregation buffer ({len(self.get_nodes_pending_models_to_aggregate())!s}/{len(self._federation_nodes)!s}) | Pending nodes: {self._federation_nodes - self.get_nodes_pending_models_to_aggregate()}" - ) - - # Check if _future_models_to_aggregate has models in the current round to include in the aggregation buffer - if self.engine.get_round() in self._future_models_to_aggregate: - logging.info( - f"πŸ”„ _add_pending_model | Including next models in the aggregation buffer for round {self.engine.get_round()}" - ) - for future_model in self._future_models_to_aggregate[self.engine.get_round()]: - if future_model is None: - continue - future_model, future_weight, future_source = future_model - if ( - future_source in self._federation_nodes - and future_source not in self.get_nodes_pending_models_to_aggregate() - ): - self._pending_models_to_aggregate.update({future_source: (future_model, future_weight)}) - logging.info( - f"πŸ”„ _add_pending_model | Next model added in aggregation buffer ({len(self.get_nodes_pending_models_to_aggregate())!s}/{len(self._federation_nodes)!s}) | Pending nodes: {self._federation_nodes - self.get_nodes_pending_models_to_aggregate()}" - ) - del self._future_models_to_aggregate[self.engine.get_round()] - - for future_round in list(self._future_models_to_aggregate.keys()): - if future_round < self.engine.get_round(): - del self._future_models_to_aggregate[future_round] - - if len(self.get_nodes_pending_models_to_aggregate()) >= len(self._federation_nodes): - logging.info("πŸ”„ _add_pending_model | All models were added in the aggregation buffer. 
Run aggregation...") - await self._aggregation_done_lock.release_async() - - await self._add_model_lock.release_async() - return self.get_nodes_pending_models_to_aggregate() - - async def include_model_in_buffer(self, model, weight, source=None, round=None, local=False): - await self._add_model_lock.acquire_async() - logging.info( - f"πŸ”„ include_model_in_buffer | source={source} | round={round} | weight={weight} |--| __models={self._pending_models_to_aggregate.keys()} | federation_nodes={self._federation_nodes} | pending_models_to_aggregate={self.get_nodes_pending_models_to_aggregate()}" - ) - if model is None: - logging.info("πŸ”„ include_model_in_buffer | Ignoring model bad formed...") - await self._add_model_lock.release_async() - return - - if round == -1: - # Be sure that the model message is not from the initialization round (round = -1) - logging.info("πŸ”„ include_model_in_buffer | Ignoring model with round -1") - await self._add_model_lock.release_async() - return - - if self._waiting_global_update and not local: - await self._handle_global_update(model, source) - return - - await self._add_pending_model(model, weight, source) - - if len(self.get_nodes_pending_models_to_aggregate()) >= len(self._federation_nodes): - logging.info( - f"πŸ”„ include_model_in_buffer | Broadcasting MODELS_INCLUDED for round {self.engine.get_round()}" - ) - message = self.cm.create_message( - "federation", "federation_models_included", [str(arg) for arg in [self.engine.get_round()]] - ) - await self.cm.send_message_to_neighbors(message) - - return - async def get_aggregation(self): try: timeout = self.config.participant["aggregator_args"]["aggregation_timeout"] @@ -271,53 +153,11 @@ async def get_aggregation(self): "federation", "federation_models_included", [str(arg) for arg in [self.engine.get_round()]] ) await self.cm.send_message_to_neighbors(message) - - # if self._waiting_global_update and len(self._pending_models_to_aggregate) == 1: - # logging.info( - # "πŸ”„ 
get_aggregation | Received an global model. Overwriting my model with the aggregated model." - # ) - # aggregated_model = next(iter(self._pending_models_to_aggregate.values()))[0] - # self._pending_models_to_aggregate.clear() - # return aggregated_model - - # unique_nodes_involved = set(node for key in self._pending_models_to_aggregate for node in key.split()) - - # if len(unique_nodes_involved) != len(self._federation_nodes): - # missing_nodes = self._federation_nodes - unique_nodes_involved - # logging.info(f"πŸ”„ get_aggregation | Aggregation incomplete, missing models from: {missing_nodes}") - # else: - # logging.info("πŸ”„ get_aggregation | All models accounted for, proceeding with aggregation.") - - # self._pending_models_to_aggregate = await self.engine.apply_weight_strategy(self._pending_models_to_aggregate) - # aggregated_result = self.run_aggregation(self._pending_models_to_aggregate) - # self._pending_models_to_aggregate.clear() - + updates = await self.engine.apply_weight_strategy(updates) aggregated_result = self.run_aggregation(updates) return aggregated_result - async def include_next_model_in_buffer(self, model, weight, source=None, round=None): - logging.info(f"πŸ”„ include_next_model_in_buffer | source={source} | round={round} | weight={weight}") - if round not in self._future_models_to_aggregate: - self._future_models_to_aggregate[round] = [] - decoded_model = self.engine.trainer.deserialize_model(model) - await self._add_next_model_lock.acquire_async() - self._future_models_to_aggregate[round].append((decoded_model, weight, source)) - await self._add_next_model_lock.release_async() - - # Verify if we are waiting an update that maybe we wont received - if self._aggregation_done_lock.locked(): - pending_nodes: set = self._federation_nodes - self.get_nodes_pending_models_to_aggregate() - if pending_nodes: - for f_round, future_updates in self._future_models_to_aggregate.items(): - for _, _, source in future_updates: - if source in pending_nodes: - 
# logging.info(f"Waiting update from source: {source}, but future update storaged for round: {f_round}") - pending_nodes.discard(source) - - if not pending_nodes: - logging.info("Received advanced updates for all sources missing this round") - await self._aggregation_done_lock.release_async() def print_model_size(self, model): total_params = 0 @@ -333,135 +173,6 @@ def print_model_size(self, model): total_memory_in_mb = total_memory / (1024**2) logging.info(f"print_model_size | Model size: {total_memory_in_mb} MB") - def verify_push_done(self, current_round): - current_round = self.engine.get_round() - if self.engine.get_synchronizing_rounds(): - logging.info("Verifying if round push is done") - if self._end_round_push <= current_round: - logging.info("Push done...") - self.engine.set_synchronizing_rounds(False) - self._end_round_push = 0 - if len(self._future_models_to_aggregate.items()) < 2: - logging.info("Device is sinchronized") - self.engine.update_sinchronized_status(True) - else: - logging.info("Device is not sinchronized yet | more actions required...") - - async def aggregation_push_available(self): - """ - If the node is not sinchronized with the federation, it may be possible to make a push - and try to catch the federation asap. - """ - # TODO verify if an already sinchronized node gets desinchronized - current_round = self.engine.get_round() - self.verify_push_done(current_round) - - await self._push_strategy_lock.acquire_async() - - logging.info( - f"❗️ synchronized status: {self.engine.get_sinchronized_status()} | Analizing if an aggregation push is available..." 
- ) - if ( - not self.engine.get_sinchronized_status() - and not self.engine.get_trainning_in_progress_lock().locked() - and not self.engine.get_synchronizing_rounds() - ): - n_fed_nodes = len(self._federation_nodes) - further_round = current_round - logging.info( - f" Pending models: {len(self.get_nodes_pending_models_to_aggregate())} | federation: {n_fed_nodes}" - ) - if len(self.get_nodes_pending_models_to_aggregate()) < n_fed_nodes: - n_fed_nodes -= 1 - for f_round, fm in self._future_models_to_aggregate.items(): - # future_models dont count self node - if (f_round - current_round) > 1 or len(fm) == n_fed_nodes: - further_round = f_round - push = self.engine.get_push_acceleration() - if push == "slow": - logging.info("❗️ SLOW push selected") - logging.info( - f"❗️ Federation is at least {(f_round - current_round)} rounds ahead, Pushing slow..." - ) - await self.engine.set_pushed_done(further_round - current_round) - self.engine.update_sinchronized_status(False) - self.engine.set_synchronizing_rounds(True) - self._end_round_push = further_round - self._aggregation_waiting_skip.set() - await self._push_strategy_lock.release_async() - return - - if further_round != current_round and push == "fast": - logging.info("❗️ FAST push selected") - logging.info(f"❗️ FUTURE round: {further_round} is available, Pushing fast...") - - if further_round == (current_round + 1): - logging.info(f"πŸ”„ Rounds jumped: {1}...") - await self.engine.set_pushed_done(further_round - current_round) - self.engine.update_sinchronized_status(False) - self.engine.set_synchronizing_rounds(True) - self._end_round_push = further_round - self._aggregation_waiting_skip.set() - await self._push_strategy_lock.release_async() - return - - logging.info(f"πŸ”„ Number of rounds jumped: {further_round - current_round}...") - own_update = self._pending_models_to_aggregate.get(self.engine.get_addr()) - while own_update == None: - own_update = self._pending_models_to_aggregate.get(self.engine.get_addr()) - 
asyncio.sleep(1) - (model, weight) = own_update - - # Getting locks to avoid concurrency issues - await self._add_model_lock.acquire_async() - await self._add_next_model_lock.acquire_async() - - # Remove all pendings updates and add own_update - self._pending_models_to_aggregate.clear() - self._pending_models_to_aggregate.update({self.engine.get_addr(): (model, weight)}) - - # Add to pendings the future round updates - for future_update in self._future_models_to_aggregate[further_round]: - (decoded_model, weight, source) = future_update - self._pending_models_to_aggregate.update({source: (decoded_model, weight)}) - - # Clear all rounds that are going to be jumped - self._future_models_to_aggregate = { - key: value for key, value in self._future_models_to_aggregate.items() if key > further_round - } - - self.engine.update_sinchronized_status(False) - self.engine.set_synchronizing_rounds(True) - await self.engine.set_pushed_done(further_round - current_round) - self._end_round_push = further_round - self.engine.set_round(further_round) - await self._add_model_lock.release_async() - await self._add_next_model_lock.release_async() - await self._push_strategy_lock.release_async() - self._aggregation_waiting_skip.set() - return - - else: - if len(self._future_models_to_aggregate.items()) < 2: - logging.info("Info | No future rounds available, device is up to date...") - self.engine.update_sinchronized_status(True) - self.engine.set_synchronizing_rounds(False) - else: - logging.info("No rounds can be pushed...") - await self._push_strategy_lock.release_async() - else: - logging.info( - f"All models updates are received | models number: {len(self.get_nodes_pending_models_to_aggregate())}" - ) - await self._push_strategy_lock.release_async() - else: - if not self.engine.get_sinchronized_status(): - if self.engine.get_trainning_in_progress_lock().locked(): - logging.info("❗️ Cannot analize push | Trainning in progress") - elif self.engine.get_synchronizing_rounds(): - 
logging.info("❗️ Cannot analize push | Already pushing rounds") - await self._push_strategy_lock.release_async() - async def notify_all_updates_received(self): self._aggregation_waiting_skip.set() diff --git a/nebula/core/aggregation/updatehandlers/cflupdatehandler.py b/nebula/core/aggregation/updatehandlers/cflupdatehandler.py new file mode 100644 index 000000000..5440bf8a9 --- /dev/null +++ b/nebula/core/aggregation/updatehandlers/cflupdatehandler.py @@ -0,0 +1,363 @@ +import asyncio +import logging +from nebula.core.utils.locker import Locker +from nebula.core.aggregation.updatehandlers.updatehandler import UpdateHandler + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from nebula.core.aggregation.aggregator import Aggregator + +class CFLUpdateHandler(UpdateHandler): + def __init__( + self, + aggregator, + addr + ): + pass + + async def round_expected_updates(self, federation_nodes: set): + raise NotImplementedError + + async def storage_update(self, model, weight, source, round, local=False): + raise NotImplementedError + + async def get_round_updates(self) -> dict[str, tuple[object, float]]: + raise NotImplementedError + + async def notify_federation_update(self, source, remove=False): + raise NotImplementedError + + async def get_round_missing_nodes(self) -> set[str]: + raise NotImplementedError + + async def notify_if_all_updates_received(self): + raise NotImplementedError + + async def stop_notifying_updates(self): + raise NotImplementedError + + +# def get_nodes_pending_models_to_aggregate(self): + # return {node for key in self._pending_models_to_aggregate.keys() for node in key.split()} + + # async def _handle_global_update(self, model, source): + # logging.info(f"πŸ”„ _handle_global_update | source={source}") + # logging.info( + # f"πŸ”„ _handle_global_update | Received a model from {source}. Overwriting __models with the aggregated model." 
+ # ) + # self._pending_models_to_aggregate.clear() + # self._pending_models_to_aggregate = {source: (model, 1)} + # self._waiting_global_update = False + # await self._add_model_lock.release_async() + # await self._aggregation_done_lock.release_async() + + # async def _add_pending_model(self, model, weight, source): + # if len(self._federation_nodes) <= len(self.get_nodes_pending_models_to_aggregate()): + # logging.info("πŸ”„ _add_pending_model | Ignoring model...") + # await self._add_model_lock.release_async() + # return None + + # if source not in self._federation_nodes: + # logging.info(f"πŸ”„ _add_pending_model | Can't add a model from ({source}), which is not in the federation.") + # await self._add_model_lock.release_async() + # return None + + # elif source not in self.get_nodes_pending_models_to_aggregate(): + # logging.info( + # "πŸ”„ _add_pending_model | Node is not in the aggregation buffer --> Include model in the aggregation buffer." + # ) + # self._pending_models_to_aggregate.update({source: (model, weight)}) + + # logging.info( + # f"πŸ”„ _add_pending_model | Model added in aggregation buffer ({len(self.get_nodes_pending_models_to_aggregate())!s}/{len(self._federation_nodes)!s}) | Pending nodes: {self._federation_nodes - self.get_nodes_pending_models_to_aggregate()}" + # ) + + # # Check if _future_models_to_aggregate has models in the current round to include in the aggregation buffer + # if self.engine.get_round() in self._future_models_to_aggregate: + # logging.info( + # f"πŸ”„ _add_pending_model | Including next models in the aggregation buffer for round {self.engine.get_round()}" + # ) + # for future_model in self._future_models_to_aggregate[self.engine.get_round()]: + # if future_model is None: + # continue + # future_model, future_weight, future_source = future_model + # if ( + # future_source in self._federation_nodes + # and future_source not in self.get_nodes_pending_models_to_aggregate() + # ): + # 
self._pending_models_to_aggregate.update({future_source: (future_model, future_weight)}) + # logging.info( + # f"πŸ”„ _add_pending_model | Next model added in aggregation buffer ({len(self.get_nodes_pending_models_to_aggregate())!s}/{len(self._federation_nodes)!s}) | Pending nodes: {self._federation_nodes - self.get_nodes_pending_models_to_aggregate()}" + # ) + # del self._future_models_to_aggregate[self.engine.get_round()] + + # for future_round in list(self._future_models_to_aggregate.keys()): + # if future_round < self.engine.get_round(): + # del self._future_models_to_aggregate[future_round] + + # if len(self.get_nodes_pending_models_to_aggregate()) >= len(self._federation_nodes): + # logging.info("πŸ”„ _add_pending_model | All models were added in the aggregation buffer. Run aggregation...") + # await self._aggregation_done_lock.release_async() + + # await self._add_model_lock.release_async() + # return self.get_nodes_pending_models_to_aggregate() + + # async def include_model_in_buffer(self, model, weight, source=None, round=None, local=False): + # await self._add_model_lock.acquire_async() + # logging.info( + # f"πŸ”„ include_model_in_buffer | source={source} | round={round} | weight={weight} |--| __models={self._pending_models_to_aggregate.keys()} | federation_nodes={self._federation_nodes} | pending_models_to_aggregate={self.get_nodes_pending_models_to_aggregate()}" + # ) + # if model is None: + # logging.info("πŸ”„ include_model_in_buffer | Ignoring model bad formed...") + # await self._add_model_lock.release_async() + # return + + # if round == -1: + # # Be sure that the model message is not from the initialization round (round = -1) + # logging.info("πŸ”„ include_model_in_buffer | Ignoring model with round -1") + # await self._add_model_lock.release_async() + # return + + # if self._waiting_global_update and not local: + # await self._handle_global_update(model, source) + # return + + # await self._add_pending_model(model, weight, source) + + # if 
len(self.get_nodes_pending_models_to_aggregate()) >= len(self._federation_nodes): + # logging.info( + # f"πŸ”„ include_model_in_buffer | Broadcasting MODELS_INCLUDED for round {self.engine.get_round()}" + # ) + # message = self.cm.create_message( + # "federation", "federation_models_included", [str(arg) for arg in [self.engine.get_round()]] + # ) + # await self.cm.send_message_to_neighbors(message) + + # return + + # async def get_aggregation(self): + # try: + # timeout = self.config.participant["aggregator_args"]["aggregation_timeout"] + # logging.info(f"Aggregation timeout: {timeout} starts...") + # await self.us.notify_if_all_updates_received() + # lock_task = asyncio.create_task(self._aggregation_done_lock.acquire_async(timeout=timeout)) + # skip_task = asyncio.create_task(self._aggregation_waiting_skip.wait()) + # done, pending = await asyncio.wait( + # [lock_task, skip_task], + # return_when=asyncio.FIRST_COMPLETED, + # ) + # lock_acquired = lock_task in done + # if skip_task in done: + # logging.info("Skipping aggregation timeout, updates received before grace time") + # self._aggregation_waiting_skip.clear() + # if not lock_acquired: + # lock_task.cancel() + # try: + # await lock_task # Clean cancel + # except asyncio.CancelledError: + # pass + + # except TimeoutError: + # logging.exception("πŸ”„ get_aggregation | Timeout reached for aggregation") + # except asyncio.CancelledError: + # logging.exception("πŸ”„ get_aggregation | Lock acquisition was cancelled") + # except Exception as e: + # logging.exception(f"πŸ”„ get_aggregation | Error acquiring lock: {e}") + # finally: + # if lock_acquired: + # await self._aggregation_done_lock.release_async() + + # await self.us.stop_notifying_updates() + # updates = await self.us.get_round_updates() + + # missing_nodes = await self.us.get_round_missing_nodes() + + # if missing_nodes: + # logging.info(f"πŸ”„ get_aggregation | Aggregation incomplete, missing models from: {missing_nodes}") + # else: + # logging.info("πŸ”„ 
get_aggregation | All models accounted for, proceeding with aggregation.") + + # logging.info( + # f"πŸ”„ Broadcasting MODELS_INCLUDED for round {self.engine.get_round()}" + # ) + # message = self.cm.create_message( + # "federation", "federation_models_included", [str(arg) for arg in [self.engine.get_round()]] + # ) + # await self.cm.send_message_to_neighbors(message) + + # if self._waiting_global_update and len(self._pending_models_to_aggregate) == 1: + # logging.info( + # "πŸ”„ get_aggregation | Received an global model. Overwriting my model with the aggregated model." + # ) + # aggregated_model = next(iter(self._pending_models_to_aggregate.values()))[0] + # self._pending_models_to_aggregate.clear() + # return aggregated_model + + # unique_nodes_involved = set(node for key in self._pending_models_to_aggregate for node in key.split()) + + # if len(unique_nodes_involved) != len(self._federation_nodes): + # missing_nodes = self._federation_nodes - unique_nodes_involved + # logging.info(f"πŸ”„ get_aggregation | Aggregation incomplete, missing models from: {missing_nodes}") + # else: + # logging.info("πŸ”„ get_aggregation | All models accounted for, proceeding with aggregation.") + + # self._pending_models_to_aggregate = await self.engine.apply_weight_strategy(self._pending_models_to_aggregate) + # aggregated_result = self.run_aggregation(self._pending_models_to_aggregate) + # self._pending_models_to_aggregate.clear() + + # updates = await self.engine.apply_weight_strategy(updates) + # aggregated_result = self.run_aggregation(updates) + # return aggregated_result + + # async def include_next_model_in_buffer(self, model, weight, source=None, round=None): + # logging.info(f"πŸ”„ include_next_model_in_buffer | source={source} | round={round} | weight={weight}") + # if round not in self._future_models_to_aggregate: + # self._future_models_to_aggregate[round] = [] + # decoded_model = self.engine.trainer.deserialize_model(model) + # await 
self._add_next_model_lock.acquire_async() + # self._future_models_to_aggregate[round].append((decoded_model, weight, source)) + # await self._add_next_model_lock.release_async() + + # # Verify if we are waiting an update that maybe we wont received + # if self._aggregation_done_lock.locked(): + # pending_nodes: set = self._federation_nodes - self.get_nodes_pending_models_to_aggregate() + # if pending_nodes: + # for f_round, future_updates in self._future_models_to_aggregate.items(): + # for _, _, source in future_updates: + # if source in pending_nodes: + # # logging.info(f"Waiting update from source: {source}, but future update storaged for round: {f_round}") + # pending_nodes.discard(source) + + # if not pending_nodes: + # logging.info("Received advanced updates for all sources missing this round") + # await self._aggregation_done_lock.release_async() + + + # def verify_push_done(self, current_round): + # current_round = self.engine.get_round() + # if self.engine.get_synchronizing_rounds(): + # logging.info("Verifying if round push is done") + # if self._end_round_push <= current_round: + # logging.info("Push done...") + # self.engine.set_synchronizing_rounds(False) + # self._end_round_push = 0 + # if len(self._future_models_to_aggregate.items()) < 2: + # logging.info("Device is sinchronized") + # self.engine.update_sinchronized_status(True) + # else: + # logging.info("Device is not sinchronized yet | more actions required...") + + # async def aggregation_push_available(self): + # """ + # If the node is not sinchronized with the federation, it may be possible to make a push + # and try to catch the federation asap. 
+ # """ + # # TODO verify if an already sinchronized node gets desinchronized + # current_round = self.engine.get_round() + # self.verify_push_done(current_round) + + # await self._push_strategy_lock.acquire_async() + + # logging.info( + # f"❗️ synchronized status: {self.engine.get_sinchronized_status()} | Analizing if an aggregation push is available..." + # ) + # if ( + # not self.engine.get_sinchronized_status() + # and not self.engine.get_trainning_in_progress_lock().locked() + # and not self.engine.get_synchronizing_rounds() + # ): + # n_fed_nodes = len(self._federation_nodes) + # further_round = current_round + # logging.info( + # f" Pending models: {len(self.get_nodes_pending_models_to_aggregate())} | federation: {n_fed_nodes}" + # ) + # if len(self.get_nodes_pending_models_to_aggregate()) < n_fed_nodes: + # n_fed_nodes -= 1 + # for f_round, fm in self._future_models_to_aggregate.items(): + # # future_models dont count self node + # if (f_round - current_round) > 1 or len(fm) == n_fed_nodes: + # further_round = f_round + # push = self.engine.get_push_acceleration() + # if push == "slow": + # logging.info("❗️ SLOW push selected") + # logging.info( + # f"❗️ Federation is at least {(f_round - current_round)} rounds ahead, Pushing slow..." 
+ # ) + # await self.engine.set_pushed_done(further_round - current_round) + # self.engine.update_sinchronized_status(False) + # self.engine.set_synchronizing_rounds(True) + # self._end_round_push = further_round + # self._aggregation_waiting_skip.set() + # await self._push_strategy_lock.release_async() + # return + + # if further_round != current_round and push == "fast": + # logging.info("❗️ FAST push selected") + # logging.info(f"❗️ FUTURE round: {further_round} is available, Pushing fast...") + + # if further_round == (current_round + 1): + # logging.info(f"πŸ”„ Rounds jumped: {1}...") + # await self.engine.set_pushed_done(further_round - current_round) + # self.engine.update_sinchronized_status(False) + # self.engine.set_synchronizing_rounds(True) + # self._end_round_push = further_round + # self._aggregation_waiting_skip.set() + # await self._push_strategy_lock.release_async() + # return + + # logging.info(f"πŸ”„ Number of rounds jumped: {further_round - current_round}...") + # own_update = self._pending_models_to_aggregate.get(self.engine.get_addr()) + # while own_update == None: + # own_update = self._pending_models_to_aggregate.get(self.engine.get_addr()) + # asyncio.sleep(1) + # (model, weight) = own_update + + # # Getting locks to avoid concurrency issues + # await self._add_model_lock.acquire_async() + # await self._add_next_model_lock.acquire_async() + + # # Remove all pendings updates and add own_update + # self._pending_models_to_aggregate.clear() + # self._pending_models_to_aggregate.update({self.engine.get_addr(): (model, weight)}) + + # # Add to pendings the future round updates + # for future_update in self._future_models_to_aggregate[further_round]: + # (decoded_model, weight, source) = future_update + # self._pending_models_to_aggregate.update({source: (decoded_model, weight)}) + + # # Clear all rounds that are going to be jumped + # self._future_models_to_aggregate = { + # key: value for key, value in self._future_models_to_aggregate.items() 
if key > further_round + # } + + # self.engine.update_sinchronized_status(False) + # self.engine.set_synchronizing_rounds(True) + # await self.engine.set_pushed_done(further_round - current_round) + # self._end_round_push = further_round + # self.engine.set_round(further_round) + # await self._add_model_lock.release_async() + # await self._add_next_model_lock.release_async() + # await self._push_strategy_lock.release_async() + # self._aggregation_waiting_skip.set() + # return + + # else: + # if len(self._future_models_to_aggregate.items()) < 2: + # logging.info("Info | No future rounds available, device is up to date...") + # self.engine.update_sinchronized_status(True) + # self.engine.set_synchronizing_rounds(False) + # else: + # logging.info("No rounds can be pushed...") + # await self._push_strategy_lock.release_async() + # else: + # logging.info( + # f"All models updates are received | models number: {len(self.get_nodes_pending_models_to_aggregate())}" + # ) + # await self._push_strategy_lock.release_async() + # else: + # if not self.engine.get_sinchronized_status(): + # if self.engine.get_trainning_in_progress_lock().locked(): + # logging.info("❗️ Cannot analize push | Trainning in progress") + # elif self.engine.get_synchronizing_rounds(): + # logging.info("❗️ Cannot analize push | Already pushing rounds") + # await self._push_strategy_lock.release_async() diff --git a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py index 4cc9d5409..dc3231ce2 100644 --- a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py +++ b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py @@ -17,7 +17,6 @@ def __init__(self, model, weight, source, round, time_received): self.source = source self.round = round self.time_received = time_received - self.used = False def __eq__(self, other): return self.round == other.round @@ -120,6 +119,7 @@ async def get_round_updates(self): if updates_missing: 
self._missing_ones = updates_missing logging.info(f"Missing updates from sources: {updates_missing}") + self._nodes_using_historic.clear() updates = {} for sr in self._sources_received: source_historic = self.us[sr][1] @@ -127,7 +127,8 @@ async def get_round_updates(self): updt: Update = None updt = source_historic[-1] # Get last update received if last_updt_received and last_updt_received == updt: - logging.info(f"Missing update source: {sr}, using last update received..") + logging.info(f"Missing update from source: {sr}, using last update received..") + self._nodes_using_historic.add(sr) else: last_updt_received = updt self.us[sr] = (last_updt_received, source_historic) # Update storage with new last update used @@ -159,9 +160,6 @@ async def _update_source(self, source, remove=False): await self._updates_storage_lock.release_async() async def get_round_missing_nodes(self): - # await self._updates_storage_lock.acquire_async() - # updates_left = self._sources_expected.difference(self._sources_received) - # await self._updates_storage_lock.release_async() return self._missing_ones async def notify_if_all_updates_received(self): @@ -172,8 +170,7 @@ async def notify_if_all_updates_received(self): await self._updates_storage_lock.release_async() if all_received: await self._notify() - - + async def stop_notifying_updates(self): if self._round_updates_lock.locked(): logging.info("Stop notification updates") @@ -189,8 +186,7 @@ async def _notify(self): await self._notification_sent_lock.release_async() logging.info("πŸ”„ Notifying aggregator to release aggregation") await self.agg.notify_all_updates_received() - - + async def _all_updates_received(self): updates_left = self._sources_expected.difference(self._sources_received) all_received = False diff --git a/nebula/core/aggregation/updatehandlers/updatehandler.py b/nebula/core/aggregation/updatehandlers/updatehandler.py index f74f6862a..74e64f634 100644 --- a/nebula/core/aggregation/updatehandlers/updatehandler.py +++ 
b/nebula/core/aggregation/updatehandlers/updatehandler.py @@ -99,9 +99,11 @@ async def stop_notifying_updates(self): def factory_update_handler(updt_handler, aggregator, addr) -> UpdateHandler: from nebula.core.aggregation.updatehandlers.dflupdatehandler import DFLUpdateHandler + from nebula.core.aggregation.updatehandlers.cflupdatehandler import CFLUpdateHandler UPDATE_HANDLERS = { - "dfl": DFLUpdateHandler + "DFL": DFLUpdateHandler, + "CFL": CFLUpdateHandler, } update_handler = UPDATE_HANDLERS.get(updt_handler, None) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 3a24455d6..7e62bd31f 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -229,9 +229,11 @@ def get_round_lock(self): def get_sinchronized_status(self): with self.sinchronized_status_lock: + return True return self._sinchronized_status def get_synchronizing_rounds(self): + return False return self.nm.get_syncrhonizing_rounds() def update_sinchronized_status(self, status): @@ -799,8 +801,6 @@ async def _dynamic_aggregator(self, aggregated_models_weights, malicious_nodes): async def _waiting_model_updates(self): logging.info(f"πŸ’€ Waiting convergence in round {self.round}.") - # if self.mobility: - # await self.aggregator.aggregation_push_available() #TODO params = await self.aggregator.get_aggregation() if params is not None: logging.info( diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 08ebdc28f..1f88842a4 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -170,82 +170,6 @@ async def handle_model_message(self, source, message): else: model_updt_event = MessageEvent(("model","update"), source, message) await self.engine.trigger_event(model_updt_event) - - # if self.get_round() is not None: - # await self.engine.get_round_lock().acquire_async() - # current_round = self.get_round() - # await self.engine.get_round_lock().release_async() - - # if message.round != current_round and 
message.round != -1: - # logging.info( - # f"❗️ handle_model_message | Received a model from a different round | Model round: {message.round} | Current round: {current_round}" - # ) - # if message.round > current_round: - # logging.info( - # f"πŸ€– handle_model_message | Saving model from {source} for future round {message.round}" - # ) - # logging.info("### ENTRO 1 ###") - # await self.engine.aggregator.include_next_model_in_buffer( - # message.parameters, - # message.weight, - # source=source, - # round=message.round, - # ) - # else: - # logging.info(f"❗️ handle_model_message | Ignoring model from {source} from a previous round") - # return - # if not self.engine.get_federation_ready_lock().locked() and len(self.engine.get_federation_nodes()) == 0: - # logging.info("πŸ€– handle_model_message | There are no defined federation nodes") - # return - # try: - # # get_federation_ready_lock() is locked when the model is being initialized (first round) - # # non-starting nodes receive the initialized model from the starting node - # if not self.engine.get_federation_ready_lock().locked() or self.engine.get_initialization_status(): - # decoded_model = self.engine.trainer.deserialize_model(message.parameters) - # if False and self.config.participant["adaptive_args"]["model_similarity"]: - # pass - - # logging.info("### ENTRO 2 ###") - # await self.engine.aggregator.include_model_in_buffer( - # decoded_model, - # message.weight, - # source=source, - # round=message.round, - # ) - - # else: - # if message.round != -1: - # # Be sure that the model message is from the initialization round (round = -1) - # logging.info( - # f"πŸ€– handle_model_message | Saving model from {source} for future round {message.round}" - # ) - # logging.info("### ENTRO 3 ###") - # await self.engine.aggregator.include_next_model_in_buffer( - # message.parameters, - # message.weight, - # source=source, - # round=message.round, - # ) - # return - - - # except Exception as e: - # logging.exception(f"πŸ€– 
handle_model_message | Unknown error adding model: {e}") - # logging.exception(traceback.format_exc()) - - # else: - # logging.info("πŸ€– handle_model_message | Tried to add a model while learning is not running") - # if message.round != -1: - # # Be sure that the model message is from the initialization round (round = -1) - # logging.info("### ENTRO 4 ###") - # logging.info(f"πŸ€– handle_model_message | Saving model from {source} for future round {message.round}") - # await self.engine.aggregator.include_next_model_in_buffer( - # message.parameters, - # message.weight, - # source=source, - # round=message.round, - # ) - # return def create_message(self, message_type: str, action: str = "", *args, **kwargs): return self.mm.create_message(message_type, action, *args, **kwargs) @@ -293,7 +217,10 @@ def init_external_connection_service(self): async def is_external_connection_service_running(self): return self.ecs.is_running() - #TODO comprobar que el verify_connections no cree un bucle de espera infinito + #TODO + # si se utilizan addr conocidas y no se consigue conectar a ninguna quΓ© hacer + # -> funcion reentrante pero sin utilizar las conocidas + # S async def stablish_connection_to_federation(self, msg_type="discover_join", addrs_known=None): """ Using ExternalConnectionService to get addrs on local network, after that @@ -309,19 +236,30 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr addrs = addrs_known msg = self.create_message("discover", msg_type) - - logging.info("Starting communications with devices found") - for addr in addrs: - await self.connect(addr, direct=False) - await asyncio.sleep(1) - while not self.verify_connections(addrs): - await asyncio.sleep(1) - current_connections = await self.get_addrs_current_connections(only_undirected=True) - logging.info(f"Connections verified after searching: {current_connections}") - for addr in addrs: - logging.info(f"Sending {msg_type} to ---> {addr}") - 
asyncio.create_task(self.send_message(addr, msg)) - await asyncio.sleep(1) + + neighbors = await self.get_addrs_current_connections(only_undirected=True) + addrs = set(addrs) + if neighbors: + addrs.difference_update(neighbors) + + if addrs: + logging.info("Starting communications with devices found") + max_tries = 5 + for addr in addrs: + await self.connect(addr, direct=False) + await asyncio.sleep(1) + for i in range(0,max_tries): + if self.verify_connections(addrs): + break + await asyncio.sleep(1) + # while not self.verify_connections(addrs): + # await asyncio.sleep(1) + current_connections = await self.get_addrs_current_connections(only_undirected=True) + logging.info(f"Connections verified after searching: {current_connections}") + for addr in addrs: + logging.info(f"Sending {msg_type} to ---> {addr}") + asyncio.create_task(self.send_message(addr, msg)) + await asyncio.sleep(1) """ ############################## diff --git a/nebula/core/topologymanagement/nodemanager.py b/nebula/core/topologymanagement/nodemanager.py index 915905a31..27bb83ac9 100644 --- a/nebula/core/topologymanagement/nodemanager.py +++ b/nebula/core/topologymanagement/nodemanager.py @@ -159,9 +159,11 @@ def late_config(self): self, self.neighbor_policy.get_nodes_known(neighbors_only=True), dispersion_penalty=False ) - ############################## - # WEIGHT STRATEGIES # - ############################## + """ + ############################## + # WEIGHT STRATEGIES # + ############################## + """ async def update_learning_rate(self, new_lr): await self.engine.update_model_learning_rate(new_lr) @@ -182,9 +184,11 @@ async def apply_weight_strategy(self, updates: dict): if self._momemtum: await self._momemtum.calculate_momentum_weights(updates) - ############################## - # CONNECTIONS # - ############################## + """ + ############################## + # CONNECTIONS # + ############################## + """ def accept_connection(self, source, joining=False): return 
self.neighbor_policy.accept_connection(source, joining) @@ -299,6 +303,7 @@ async def stop_not_selected_connections(self): except asyncio.CancelledError: pass + #TODO todo esto es innecesario async def check_external_connection_service_status(self): logging.info("πŸ”„ Checking external connection service status...") n = await self.neighbors_left() @@ -338,6 +343,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove await self.engine.cm.stablish_connection_to_federation(msg_type, addrs_known) # wait offer + #TODO actualizar con la informacion de latencias logging.info(f"Waiting: {self.recieve_offer_timer}s to receive offers from federation") await asyncio.sleep(self.recieve_offer_timer) @@ -377,10 +383,12 @@ async def start_late_connection_process(self, connected=False, msg_type="discove if not connected: logging.info("❗️ repeating process...") await self.start_late_connection_process(connected, msg_type, addrs_known) - - ############################## - # ROBUSTNESS # - ############################## + + """ + ############################## + # ROBUSTNESS # + ############################## + """ async def check_robustness(self): logging.info("πŸ”„ Analizing node network robustness...") From b821b3701bed39581d4a9bd69b66c4a2be4721e3 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 18 Feb 2025 10:44:18 +0100 Subject: [PATCH 096/233] fix_propagator_error --- nebula/core/aggregation/aggregator.py | 23 +++++++++++-------- nebula/core/engine.py | 2 +- nebula/core/network/communications.py | 21 ++++++++++++----- nebula/core/network/propagator.py | 4 ++-- nebula/core/topologymanagement/nodemanager.py | 10 ++++---- 5 files changed, 36 insertions(+), 24 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 404ca4d77..695221865 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -90,19 +90,22 @@ async def 
update_received_from_source(self, model, weight, source, round, local= async def notify_federation_nodes_removed(self, federation_node, remove=False): await self.us.notify_federation_update(federation_node, remove=remove) + + def get_nodes_pending_models_to_aggregate(self): + return self._federation_nodes def set_waiting_global_update(self): self._waiting_global_update = True - async def reset(self): - await self._add_model_lock.acquire_async() - self._federation_nodes.clear() - self._pending_models_to_aggregate.clear() - try: - await self._aggregation_done_lock.release_async() - except: - pass - await self._add_model_lock.release_async() + # async def reset(self): + # await self._add_model_lock.acquire_async() + # self._federation_nodes.clear() + # self._pending_models_to_aggregate.clear() + # try: + # await self._aggregation_done_lock.release_async() + # except: + # pass + # await self._add_model_lock.release_async() async def get_aggregation(self): try: @@ -133,7 +136,7 @@ async def get_aggregation(self): except Exception as e: logging.exception(f"πŸ”„ get_aggregation | Error acquiring lock: {e}") finally: - if lock_acquired: + if lock_acquired or self._aggregation_done_lock.locked(): await self._aggregation_done_lock.release_async() await self.us.stop_notifying_updates() diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 7e62bd31f..1c33e6906 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -844,7 +844,7 @@ async def _learning_cycle(self): indent=2, title="Round information", ) - await self.aggregator.reset() + #await self.aggregator.reset() self.trainer.on_round_end() self.round = self.round + 1 self.config.participant["federation_args"]["round"] = ( diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 1f88842a4..4477fa45e 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -237,11 +237,13 @@ async def stablish_connection_to_federation(self, 
msg_type="discover_join", addr msg = self.create_message("discover", msg_type) - neighbors = await self.get_addrs_current_connections(only_undirected=True) + # Remove neighbors + neighbors = await self.get_addrs_current_connections(only_undirected=True, myself=True) addrs = set(addrs) if neighbors: addrs.difference_update(neighbors) + discovers_sent = 0 if addrs: logging.info("Starting communications with devices found") max_tries = 5 @@ -249,17 +251,18 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr await self.connect(addr, direct=False) await asyncio.sleep(1) for i in range(0,max_tries): - if self.verify_connections(addrs): + if self.verify_any_connections(addrs): break await asyncio.sleep(1) - # while not self.verify_connections(addrs): - # await asyncio.sleep(1) current_connections = await self.get_addrs_current_connections(only_undirected=True) logging.info(f"Connections verified after searching: {current_connections}") + for addr in addrs: logging.info(f"Sending {msg_type} to ---> {addr}") asyncio.create_task(self.send_message(addr, msg)) await asyncio.sleep(1) + discovers_sent += 1 + return discovers_sent """ ############################## @@ -438,6 +441,12 @@ async def run_reconnections(self): connection["tries"] += 1 await self.connect(connection["addr"]) + def verify_any_connections(self, neighbors): + # Return True if any neighbors are connected + if any(neighbor in self.connections for neighbor in neighbors): + return True + return False + def verify_connections(self, neighbors): # Return True if all neighbors are connected if all(neighbor in self.connections for neighbor in neighbors): @@ -452,8 +461,8 @@ async def deploy_additional_services(self): self._generate_network_conditions() await self._forwarder.start() if self.config.participant["mobility_args"]["mobility"]: - pass - #await self._discoverer.start() + #pass + await self._discoverer.start() # await self._health.start() self._propagator.start() await 
self._mobility.start() diff --git a/nebula/core/network/propagator.py b/nebula/core/network/propagator.py index f156090f2..1907c03f5 100755 --- a/nebula/core/network/propagator.py +++ b/nebula/core/network/propagator.py @@ -163,8 +163,8 @@ async def propagate(self, strategy_id: str): for neighbor_addr in eligible_neighbors: asyncio.create_task(self.cm.send_model(neighbor_addr, round_number, serialized_model, weight)) - if len(self.aggregator.get_nodes_pending_models_to_aggregate()) >= len(self.aggregator._federation_nodes): - return False + # if len(self.aggregator.get_nodes_pending_models_to_aggregate()) >= len(self.aggregator._federation_nodes): + # return False await asyncio.sleep(self.interval) return True diff --git a/nebula/core/topologymanagement/nodemanager.py b/nebula/core/topologymanagement/nodemanager.py index 27bb83ac9..ffa7a512f 100644 --- a/nebula/core/topologymanagement/nodemanager.py +++ b/nebula/core/topologymanagement/nodemanager.py @@ -313,7 +313,7 @@ async def check_external_connection_service_status(self): logging.info(f"Stats | neighbors: {n} | service running: {ecs} | synchronized status: {ss}") if not await self.neighbors_left() and await self.engine.cm.is_external_connection_service_running(): logging.info("❗️ Isolated node | Shutdowning service required") - action = lambda: self.engine.cm.stop_external_connection_service() + #action = lambda: self.engine.cm.stop_external_connection_service() elif ( await self.neighbors_left() and not await self.engine.cm.is_external_connection_service_running() @@ -374,7 +374,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove self.candidate_selector.remove_candidates() if not self._desc_done: #TODO remove self._desc_done = True - asyncio.create_task(self.stop_connections_with_federation()) + #asyncio.create_task(self.stop_connections_with_federation()) # if no candidates, repeat process else: logging.info("❗️ No Candidates found...") @@ -423,9 +423,9 @@ async def 
check_robustness(self): async def reconnect_to_federation(self): self._restructure_process_lock.acquire() await self.engine.cm.clear_restrictions() - await asyncio.sleep(120) - if await self.engine.cm.is_external_connection_service_running(): - self.engine.cm.stop_external_connection_service() + #await asyncio.sleep(120) + #if await self.engine.cm.is_external_connection_service_running(): + # self.engine.cm.stop_external_connection_service() # If we got some refs, try to reconnect to them if len(self.neighbor_policy.get_nodes_known()) > 0: logging.info("Reconnecting | Addrs availables") From 37a017c4e2ed880684a140ce12e34d3c210b964e Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 18 Feb 2025 12:42:52 +0100 Subject: [PATCH 097/233] opt_test_mobility --- nebula/addons/mobility.py | 3 +++ nebula/core/aggregation/updatehandlers/dflupdatehandler.py | 1 + nebula/core/network/communications.py | 5 +++++ nebula/core/network/nebulamulticasting.py | 5 +++-- nebula/core/topologymanagement/nodemanager.py | 2 +- 5 files changed, 13 insertions(+), 3 deletions(-) diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index de3d922e1..41ffd4038 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -346,6 +346,7 @@ async def change_connections_based_on_distance(self): ): logging.info(f"πŸ“ Node {addr} is close enough [{distance}], adding to direct connections") self.cm.connections[addr].set_direct(True) + await self.cm.update_neighbors(addr) else: # 10% margin to avoid oscillations if ( @@ -355,7 +356,9 @@ async def change_connections_based_on_distance(self): logging.info( f"πŸ“ Node {addr} is too far away [{distance}], removing from direct connections" ) + await asyncio.sleep(1) self.cm.connections[addr].set_direct(False) + await self.cm.update_neighbors(addr,remove=True) # Adapt network conditions of the connection based on distance for threshold in sorted(self.network_conditions.keys()): if distance < threshold: diff --git 
a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py index dc3231ce2..42ff199d7 100644 --- a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py +++ b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py @@ -153,6 +153,7 @@ async def _update_source(self, source, remove=False): await self._updates_storage_lock.acquire_async() if remove: self._sources_expected.discard(source) + await self._all_updates_received() else: self.us[source] = (None, deque(maxlen=self._buffersize)) self._sources_expected.add(source) diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 4477fa45e..9541290c6 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -269,6 +269,11 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr # OTHER FUNCTIONALITIES # ############################## """ + + #TODO remove + async def update_neighbors(self, addr, remove=False): + current_connections = await self.get_addrs_current_connections(only_direct=True, myself=True) + await self.engine.update_neighbors(addr, current_connections, remove=remove) def get_connections_lock(self): return self.connections_lock diff --git a/nebula/core/network/nebulamulticasting.py b/nebula/core/network/nebulamulticasting.py index e573ca7bf..d65a06dc7 100644 --- a/nebula/core/network/nebulamulticasting.py +++ b/nebula/core/network/nebulamulticasting.py @@ -201,8 +201,9 @@ def _add_addr(self, msg_str): if linea.strip().startswith("LOCATION:"): addr = linea.split(": ")[1].strip() break - logging.info(f"Device addr received: {addr}") - self.nodes_found.append(addr) + if addr != self.addr: + logging.info(f"Device addr received: {addr}") + self.nodes_found.append(addr) self.addrs_found_lock.release() def get_nodes(self): diff --git a/nebula/core/topologymanagement/nodemanager.py b/nebula/core/topologymanagement/nodemanager.py index 
ffa7a512f..4147cb2ad 100644 --- a/nebula/core/topologymanagement/nodemanager.py +++ b/nebula/core/topologymanagement/nodemanager.py @@ -397,7 +397,7 @@ async def check_robustness(self): if not await self.neighbors_left(): logging.info("No Neighbors left | reconnecting with Federation") self.engine.update_sinchronized_status(False) - await self.reconnect_to_federation() + #await self.reconnect_to_federation() elif ( self.neighbor_policy.need_more_neighbors() and self.engine.get_sinchronized_status() From 4432f46b1f602696a1b29a14132603fd323c5c26 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 18 Feb 2025 13:06:42 +0100 Subject: [PATCH 098/233] refactor situational awareness --- .../updatehandlers/dflupdatehandler.py | 9 ++++---- nebula/core/engine.py | 2 +- .../README.txt | 0 .../__init__.py | 0 .../connectionoptimizer.py | 0 .../networkoptimization/networkoptimizer.py | 0 .../networkoptimization/timergenerator.py | 0 .../awareness/samodule.py | 0 .../candidateselection/__init__.py | 0 .../candidateselection/candidateselector.py | 8 +++---- .../candidateselection/fccandidateselector.py | 2 +- .../hetcandidateselector.py | 2 +- .../ringcandidateselector.py | 2 +- .../stdcandidateselector.py | 2 +- .../fastreboot.py | 2 +- .../modelhandlers/__init__.py | 0 .../modelhandlers/aggmodelhandler.py | 2 +- .../modelhandlers/defaultmodelhandler.py | 4 ++-- .../modelhandlers/modelhandler.py | 6 ++--- .../modelhandlers/stdmodelhandler.py | 2 +- .../momentum.py | 2 +- .../neighborpolicies/__init__.py | 0 .../neighborpolicies/fcneighborpolicy.py | 2 +- .../neighborpolicies/idleneighborpolicy.py | 2 +- .../neighborpolicies/neighborpolicy.py | 8 +++---- .../neighborpolicies/ringneighborpolicy.py | 2 +- .../neighborpolicies/starneighborpolicy.py | 2 +- .../nodemanager.py | 23 +++++++++---------- 28 files changed, 41 insertions(+), 43 deletions(-) rename nebula/core/{topologymanagement => situationalawareness}/README.txt (100%) rename nebula/core/{topologymanagement => 
situationalawareness}/__init__.py (100%) rename nebula/core/{topologymanagement => situationalawareness}/awareness/networkoptimization/connectionoptimizer.py (100%) rename nebula/core/{topologymanagement => situationalawareness}/awareness/networkoptimization/networkoptimizer.py (100%) rename nebula/core/{topologymanagement => situationalawareness}/awareness/networkoptimization/timergenerator.py (100%) rename nebula/core/{topologymanagement => situationalawareness}/awareness/samodule.py (100%) rename nebula/core/{topologymanagement => situationalawareness}/candidateselection/__init__.py (100%) rename nebula/core/{topologymanagement => situationalawareness}/candidateselection/candidateselector.py (63%) rename nebula/core/{topologymanagement => situationalawareness}/candidateselection/fccandidateselector.py (92%) rename nebula/core/{topologymanagement => situationalawareness}/candidateselection/hetcandidateselector.py (97%) rename nebula/core/{topologymanagement => situationalawareness}/candidateselection/ringcandidateselector.py (91%) rename nebula/core/{topologymanagement => situationalawareness}/candidateselection/stdcandidateselector.py (91%) rename nebula/core/{topologymanagement => situationalawareness}/fastreboot.py (98%) rename nebula/core/{topologymanagement => situationalawareness}/modelhandlers/__init__.py (100%) rename nebula/core/{topologymanagement => situationalawareness}/modelhandlers/aggmodelhandler.py (94%) rename nebula/core/{topologymanagement => situationalawareness}/modelhandlers/defaultmodelhandler.py (89%) rename nebula/core/{topologymanagement => situationalawareness}/modelhandlers/modelhandler.py (68%) rename nebula/core/{topologymanagement => situationalawareness}/modelhandlers/stdmodelhandler.py (94%) rename nebula/core/{topologymanagement => situationalawareness}/momentum.py (99%) rename nebula/core/{topologymanagement => situationalawareness}/neighborpolicies/__init__.py (100%) rename nebula/core/{topologymanagement => 
situationalawareness}/neighborpolicies/fcneighborpolicy.py (97%) rename nebula/core/{topologymanagement => situationalawareness}/neighborpolicies/idleneighborpolicy.py (97%) rename nebula/core/{topologymanagement => situationalawareness}/neighborpolicies/neighborpolicy.py (72%) rename nebula/core/{topologymanagement => situationalawareness}/neighborpolicies/ringneighborpolicy.py (97%) rename nebula/core/{topologymanagement => situationalawareness}/neighborpolicies/starneighborpolicy.py (96%) rename nebula/core/{topologymanagement => situationalawareness}/nodemanager.py (95%) diff --git a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py index 42ff199d7..032f8f0b2 100644 --- a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py +++ b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py @@ -125,13 +125,13 @@ async def get_round_updates(self): source_historic = self.us[sr][1] last_updt_received = self.us[sr][0] updt: Update = None - updt = source_historic[-1] # Get last update received + updt = source_historic[-1] # Get last update received if last_updt_received and last_updt_received == updt: logging.info(f"Missing update from source: {sr}, using last update received..") self._nodes_using_historic.add(sr) else: last_updt_received = updt - self.us[sr] = (last_updt_received, source_historic) # Update storage with new last update used + self.us[sr] = (last_updt_received, source_historic) # Update storage with new last update used updates[sr] = (updt.model, updt.weight) await self._updates_storage_lock.release_async() @@ -144,16 +144,15 @@ async def notify_federation_update(self, source, remove=False): else: await self._update_source(source, remove) else: - # Not received update from this source yet - if not source in self._sources_received: + if not source in self._sources_received: # Not received update from this source yet await self._update_source(source, remove=True) + await 
self._all_updates_received() # Verify if discarding node aggregation could be done async def _update_source(self, source, remove=False): logging.info(f"πŸ”„ Update | remove: {remove} | source: {source}") await self._updates_storage_lock.acquire_async() if remove: self._sources_expected.discard(source) - await self._all_updates_received() else: self.us[source] = (None, deque(maxlen=self._buffersize)) self._sources_expected.add(source) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 1c33e6906..e2a324ea1 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -9,7 +9,7 @@ from nebula.addons.reporter import Reporter from nebula.core.aggregation.aggregator import create_aggregator, create_target_aggregator from nebula.core.eventmanager import EventManager -from nebula.core.topologymanagement.nodemanager import NodeManager +from nebula.core.situationalawareness.nodemanager import NodeManager from nebula.core.network.communications import CommunicationsManager from nebula.core.utils.locker import Locker diff --git a/nebula/core/topologymanagement/README.txt b/nebula/core/situationalawareness/README.txt similarity index 100% rename from nebula/core/topologymanagement/README.txt rename to nebula/core/situationalawareness/README.txt diff --git a/nebula/core/topologymanagement/__init__.py b/nebula/core/situationalawareness/__init__.py similarity index 100% rename from nebula/core/topologymanagement/__init__.py rename to nebula/core/situationalawareness/__init__.py diff --git a/nebula/core/topologymanagement/awareness/networkoptimization/connectionoptimizer.py b/nebula/core/situationalawareness/awareness/networkoptimization/connectionoptimizer.py similarity index 100% rename from nebula/core/topologymanagement/awareness/networkoptimization/connectionoptimizer.py rename to nebula/core/situationalawareness/awareness/networkoptimization/connectionoptimizer.py diff --git a/nebula/core/topologymanagement/awareness/networkoptimization/networkoptimizer.py 
b/nebula/core/situationalawareness/awareness/networkoptimization/networkoptimizer.py similarity index 100% rename from nebula/core/topologymanagement/awareness/networkoptimization/networkoptimizer.py rename to nebula/core/situationalawareness/awareness/networkoptimization/networkoptimizer.py diff --git a/nebula/core/topologymanagement/awareness/networkoptimization/timergenerator.py b/nebula/core/situationalawareness/awareness/networkoptimization/timergenerator.py similarity index 100% rename from nebula/core/topologymanagement/awareness/networkoptimization/timergenerator.py rename to nebula/core/situationalawareness/awareness/networkoptimization/timergenerator.py diff --git a/nebula/core/topologymanagement/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py similarity index 100% rename from nebula/core/topologymanagement/awareness/samodule.py rename to nebula/core/situationalawareness/awareness/samodule.py diff --git a/nebula/core/topologymanagement/candidateselection/__init__.py b/nebula/core/situationalawareness/candidateselection/__init__.py similarity index 100% rename from nebula/core/topologymanagement/candidateselection/__init__.py rename to nebula/core/situationalawareness/candidateselection/__init__.py diff --git a/nebula/core/topologymanagement/candidateselection/candidateselector.py b/nebula/core/situationalawareness/candidateselection/candidateselector.py similarity index 63% rename from nebula/core/topologymanagement/candidateselection/candidateselector.py rename to nebula/core/situationalawareness/candidateselection/candidateselector.py index 30a5b09e8..71c2b28c5 100644 --- a/nebula/core/topologymanagement/candidateselection/candidateselector.py +++ b/nebula/core/situationalawareness/candidateselection/candidateselector.py @@ -24,10 +24,10 @@ def any_candidate(self): pass def factory_CandidateSelector(topology) -> CandidateSelector: - from nebula.core.topologymanagement.candidateselection.stdcandidateselector import 
STDandidateSelector - from nebula.core.topologymanagement.candidateselection.fccandidateselector import FCCandidateSelector - from nebula.core.topologymanagement.candidateselection.hetcandidateselector import HETCandidateSelector - from nebula.core.topologymanagement.candidateselection.ringcandidateselector import RINGCandidateSelector + from nebula.core.situationalawareness.candidateselection.stdcandidateselector import STDandidateSelector + from nebula.core.situationalawareness.candidateselection.fccandidateselector import FCCandidateSelector + from nebula.core.situationalawareness.candidateselection.hetcandidateselector import HETCandidateSelector + from nebula.core.situationalawareness.candidateselection.ringcandidateselector import RINGCandidateSelector options = { "ring": RINGCandidateSelector, diff --git a/nebula/core/topologymanagement/candidateselection/fccandidateselector.py b/nebula/core/situationalawareness/candidateselection/fccandidateselector.py similarity index 92% rename from nebula/core/topologymanagement/candidateselection/fccandidateselector.py rename to nebula/core/situationalawareness/candidateselection/fccandidateselector.py index e487737da..3a91098d8 100644 --- a/nebula/core/topologymanagement/candidateselection/fccandidateselector.py +++ b/nebula/core/situationalawareness/candidateselection/fccandidateselector.py @@ -1,4 +1,4 @@ -from nebula.core.topologymanagement.candidateselection.candidateselector import CandidateSelector +from nebula.core.situationalawareness.candidateselection.candidateselector import CandidateSelector from nebula.core.utils.locker import Locker class FCCandidateSelector(CandidateSelector): diff --git a/nebula/core/topologymanagement/candidateselection/hetcandidateselector.py b/nebula/core/situationalawareness/candidateselection/hetcandidateselector.py similarity index 97% rename from nebula/core/topologymanagement/candidateselection/hetcandidateselector.py rename to 
nebula/core/situationalawareness/candidateselection/hetcandidateselector.py index 345f63c5b..84766581d 100644 --- a/nebula/core/topologymanagement/candidateselection/hetcandidateselector.py +++ b/nebula/core/situationalawareness/candidateselection/hetcandidateselector.py @@ -1,4 +1,4 @@ -from nebula.core.topologymanagement.candidateselection.candidateselector import CandidateSelector +from nebula.core.situationalawareness.candidateselection.candidateselector import CandidateSelector from nebula.core.utils.locker import Locker class HETCandidateSelector(CandidateSelector): diff --git a/nebula/core/topologymanagement/candidateselection/ringcandidateselector.py b/nebula/core/situationalawareness/candidateselection/ringcandidateselector.py similarity index 91% rename from nebula/core/topologymanagement/candidateselection/ringcandidateselector.py rename to nebula/core/situationalawareness/candidateselection/ringcandidateselector.py index 47990e88d..02effc281 100644 --- a/nebula/core/topologymanagement/candidateselection/ringcandidateselector.py +++ b/nebula/core/situationalawareness/candidateselection/ringcandidateselector.py @@ -1,4 +1,4 @@ -from nebula.core.topologymanagement.candidateselection.candidateselector import CandidateSelector +from nebula.core.situationalawareness.candidateselection.candidateselector import CandidateSelector from nebula.core.utils.locker import Locker class RINGCandidateSelector(CandidateSelector): diff --git a/nebula/core/topologymanagement/candidateselection/stdcandidateselector.py b/nebula/core/situationalawareness/candidateselection/stdcandidateselector.py similarity index 91% rename from nebula/core/topologymanagement/candidateselection/stdcandidateselector.py rename to nebula/core/situationalawareness/candidateselection/stdcandidateselector.py index ddbdbcf57..022677fc6 100644 --- a/nebula/core/topologymanagement/candidateselection/stdcandidateselector.py +++ 
b/nebula/core/situationalawareness/candidateselection/stdcandidateselector.py @@ -1,4 +1,4 @@ -from nebula.core.topologymanagement.candidateselection.candidateselector import CandidateSelector +from nebula.core.situationalawareness.candidateselection.candidateselector import CandidateSelector from nebula.core.utils.locker import Locker class STDandidateSelector(CandidateSelector): diff --git a/nebula/core/topologymanagement/fastreboot.py b/nebula/core/situationalawareness/fastreboot.py similarity index 98% rename from nebula/core/topologymanagement/fastreboot.py rename to nebula/core/situationalawareness/fastreboot.py index f49919601..d11da965f 100644 --- a/nebula/core/topologymanagement/fastreboot.py +++ b/nebula/core/situationalawareness/fastreboot.py @@ -4,7 +4,7 @@ from nebula.core.utils.locker import Locker if TYPE_CHECKING: - from nebula.core.topologymanagement.nodemanager import NodeManager + from nebula.core.situationalawareness.nodemanager import NodeManager VANILLA_LEARNING_RATE = 1e-3 FR_LEARNING_RATE = 1e-3 diff --git a/nebula/core/topologymanagement/modelhandlers/__init__.py b/nebula/core/situationalawareness/modelhandlers/__init__.py similarity index 100% rename from nebula/core/topologymanagement/modelhandlers/__init__.py rename to nebula/core/situationalawareness/modelhandlers/__init__.py diff --git a/nebula/core/topologymanagement/modelhandlers/aggmodelhandler.py b/nebula/core/situationalawareness/modelhandlers/aggmodelhandler.py similarity index 94% rename from nebula/core/topologymanagement/modelhandlers/aggmodelhandler.py rename to nebula/core/situationalawareness/modelhandlers/aggmodelhandler.py index 3f0f6331a..a55c5d1cd 100644 --- a/nebula/core/topologymanagement/modelhandlers/aggmodelhandler.py +++ b/nebula/core/situationalawareness/modelhandlers/aggmodelhandler.py @@ -1,4 +1,4 @@ -from nebula.core.topologymanagement.modelhandlers.modelhandler import ModelHandler +from nebula.core.situationalawareness.modelhandlers.modelhandler import 
ModelHandler from nebula.core.utils.locker import Locker class AGGModelHandler(ModelHandler): diff --git a/nebula/core/topologymanagement/modelhandlers/defaultmodelhandler.py b/nebula/core/situationalawareness/modelhandlers/defaultmodelhandler.py similarity index 89% rename from nebula/core/topologymanagement/modelhandlers/defaultmodelhandler.py rename to nebula/core/situationalawareness/modelhandlers/defaultmodelhandler.py index b17910751..bcf850972 100644 --- a/nebula/core/topologymanagement/modelhandlers/defaultmodelhandler.py +++ b/nebula/core/situationalawareness/modelhandlers/defaultmodelhandler.py @@ -1,6 +1,6 @@ -from nebula.core.topologymanagement.modelhandlers.modelhandler import ModelHandler +from nebula.core.situationalawareness.modelhandlers.modelhandler import ModelHandler from nebula.core.utils.locker import Locker -from nebula.core.topologymanagement.nodemanager import NodeManager +from nebula.core.situationalawareness.nodemanager import NodeManager import logging class DefaultModelHandler(ModelHandler): diff --git a/nebula/core/topologymanagement/modelhandlers/modelhandler.py b/nebula/core/situationalawareness/modelhandlers/modelhandler.py similarity index 68% rename from nebula/core/topologymanagement/modelhandlers/modelhandler.py rename to nebula/core/situationalawareness/modelhandlers/modelhandler.py index 242dbd69a..b60022c22 100644 --- a/nebula/core/topologymanagement/modelhandlers/modelhandler.py +++ b/nebula/core/situationalawareness/modelhandlers/modelhandler.py @@ -20,9 +20,9 @@ def pre_process_model(self): pass def factory_ModelHandler(model_handler) -> ModelHandler: - from nebula.core.topologymanagement.modelhandlers.stdmodelhandler import STDModelHandler - from nebula.core.topologymanagement.modelhandlers.aggmodelhandler import AGGModelHandler - from nebula.core.topologymanagement.modelhandlers.defaultmodelhandler import DefaultModelHandler + from nebula.core.situationalawareness.modelhandlers.stdmodelhandler import STDModelHandler + 
from nebula.core.situationalawareness.modelhandlers.aggmodelhandler import AGGModelHandler + from nebula.core.situationalawareness.modelhandlers.defaultmodelhandler import DefaultModelHandler options = { "std": STDModelHandler, diff --git a/nebula/core/topologymanagement/modelhandlers/stdmodelhandler.py b/nebula/core/situationalawareness/modelhandlers/stdmodelhandler.py similarity index 94% rename from nebula/core/topologymanagement/modelhandlers/stdmodelhandler.py rename to nebula/core/situationalawareness/modelhandlers/stdmodelhandler.py index b93d43276..24f3c59b5 100644 --- a/nebula/core/topologymanagement/modelhandlers/stdmodelhandler.py +++ b/nebula/core/situationalawareness/modelhandlers/stdmodelhandler.py @@ -1,4 +1,4 @@ -from nebula.core.topologymanagement.modelhandlers.modelhandler import ModelHandler +from nebula.core.situationalawareness.modelhandlers.modelhandler import ModelHandler from nebula.core.utils.locker import Locker diff --git a/nebula/core/topologymanagement/momentum.py b/nebula/core/situationalawareness/momentum.py similarity index 99% rename from nebula/core/topologymanagement/momentum.py rename to nebula/core/situationalawareness/momentum.py index 6d68dd5a6..7bfdefbb2 100644 --- a/nebula/core/topologymanagement/momentum.py +++ b/nebula/core/situationalawareness/momentum.py @@ -9,7 +9,7 @@ from nebula.core.utils.locker import Locker if TYPE_CHECKING: - from nebula.core.topologymanagement.nodemanager import NodeManager + from nebula.core.situationalawareness.nodemanager import NodeManager SimilarityMetricType = Callable[[OrderedDict, OrderedDict, bool], float | None] MappingSimilarityType = Callable[[float, float], Annotated[float, "Value in (0, 1]"]] diff --git a/nebula/core/topologymanagement/neighborpolicies/__init__.py b/nebula/core/situationalawareness/neighborpolicies/__init__.py similarity index 100% rename from nebula/core/topologymanagement/neighborpolicies/__init__.py rename to 
nebula/core/situationalawareness/neighborpolicies/__init__.py diff --git a/nebula/core/topologymanagement/neighborpolicies/fcneighborpolicy.py b/nebula/core/situationalawareness/neighborpolicies/fcneighborpolicy.py similarity index 97% rename from nebula/core/topologymanagement/neighborpolicies/fcneighborpolicy.py rename to nebula/core/situationalawareness/neighborpolicies/fcneighborpolicy.py index b7f88c1d2..0d12cee7d 100644 --- a/nebula/core/topologymanagement/neighborpolicies/fcneighborpolicy.py +++ b/nebula/core/situationalawareness/neighborpolicies/fcneighborpolicy.py @@ -1,4 +1,4 @@ -from nebula.core.topologymanagement.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.situationalawareness.neighborpolicies.neighborpolicy import NeighborPolicy from nebula.core.utils.locker import Locker class FCNeighborPolicy(NeighborPolicy): diff --git a/nebula/core/topologymanagement/neighborpolicies/idleneighborpolicy.py b/nebula/core/situationalawareness/neighborpolicies/idleneighborpolicy.py similarity index 97% rename from nebula/core/topologymanagement/neighborpolicies/idleneighborpolicy.py rename to nebula/core/situationalawareness/neighborpolicies/idleneighborpolicy.py index 81c98435d..fefff7dc8 100644 --- a/nebula/core/topologymanagement/neighborpolicies/idleneighborpolicy.py +++ b/nebula/core/situationalawareness/neighborpolicies/idleneighborpolicy.py @@ -1,4 +1,4 @@ -from nebula.core.topologymanagement.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.situationalawareness.neighborpolicies.neighborpolicy import NeighborPolicy from nebula.core.utils.locker import Locker class IDLENeighborPolicy(NeighborPolicy): diff --git a/nebula/core/topologymanagement/neighborpolicies/neighborpolicy.py b/nebula/core/situationalawareness/neighborpolicies/neighborpolicy.py similarity index 72% rename from nebula/core/topologymanagement/neighborpolicies/neighborpolicy.py rename to 
nebula/core/situationalawareness/neighborpolicies/neighborpolicy.py index 7436ae12f..43352c7d7 100644 --- a/nebula/core/topologymanagement/neighborpolicies/neighborpolicy.py +++ b/nebula/core/situationalawareness/neighborpolicies/neighborpolicy.py @@ -36,10 +36,10 @@ def update_neighbors(self, node, remove=False): pass def factory_NeighborPolicy(topology) -> NeighborPolicy: - from nebula.core.topologymanagement.neighborpolicies.idleneighborpolicy import IDLENeighborPolicy - from nebula.core.topologymanagement.neighborpolicies.fcneighborpolicy import FCNeighborPolicy - from nebula.core.topologymanagement.neighborpolicies.ringneighborpolicy import RINGNeighborPolicy - from nebula.core.topologymanagement.neighborpolicies.starneighborpolicy import STARNeighborPolicy + from nebula.core.situationalawareness.neighborpolicies.idleneighborpolicy import IDLENeighborPolicy + from nebula.core.situationalawareness.neighborpolicies.fcneighborpolicy import FCNeighborPolicy + from nebula.core.situationalawareness.neighborpolicies.ringneighborpolicy import RINGNeighborPolicy + from nebula.core.situationalawareness.neighborpolicies.starneighborpolicy import STARNeighborPolicy options = { "random": IDLENeighborPolicy, # default value diff --git a/nebula/core/topologymanagement/neighborpolicies/ringneighborpolicy.py b/nebula/core/situationalawareness/neighborpolicies/ringneighborpolicy.py similarity index 97% rename from nebula/core/topologymanagement/neighborpolicies/ringneighborpolicy.py rename to nebula/core/situationalawareness/neighborpolicies/ringneighborpolicy.py index 90310d6fc..602c85f5c 100644 --- a/nebula/core/topologymanagement/neighborpolicies/ringneighborpolicy.py +++ b/nebula/core/situationalawareness/neighborpolicies/ringneighborpolicy.py @@ -1,4 +1,4 @@ -from nebula.core.topologymanagement.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.situationalawareness.neighborpolicies.neighborpolicy import NeighborPolicy from nebula.core.utils.locker 
import Locker import random diff --git a/nebula/core/topologymanagement/neighborpolicies/starneighborpolicy.py b/nebula/core/situationalawareness/neighborpolicies/starneighborpolicy.py similarity index 96% rename from nebula/core/topologymanagement/neighborpolicies/starneighborpolicy.py rename to nebula/core/situationalawareness/neighborpolicies/starneighborpolicy.py index 71b3f69c6..561a48530 100644 --- a/nebula/core/topologymanagement/neighborpolicies/starneighborpolicy.py +++ b/nebula/core/situationalawareness/neighborpolicies/starneighborpolicy.py @@ -1,4 +1,4 @@ -from nebula.core.topologymanagement.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.situationalawareness.neighborpolicies.neighborpolicy import NeighborPolicy from nebula.core.utils.locker import Locker class STARNeighborPolicy(NeighborPolicy): diff --git a/nebula/core/topologymanagement/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py similarity index 95% rename from nebula/core/topologymanagement/nodemanager.py rename to nebula/core/situationalawareness/nodemanager.py index 4147cb2ad..69d7b2b8b 100644 --- a/nebula/core/topologymanagement/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -3,11 +3,11 @@ from typing import TYPE_CHECKING from nebula.addons.functions import print_msg_box -from nebula.core.topologymanagement.candidateselection.candidateselector import factory_CandidateSelector -from nebula.core.topologymanagement.fastreboot import FastReboot -from nebula.core.topologymanagement.modelhandlers.modelhandler import factory_ModelHandler -from nebula.core.topologymanagement.momentum import Momentum -from nebula.core.topologymanagement.neighborpolicies.neighborpolicy import factory_NeighborPolicy +from nebula.core.situationalawareness.candidateselection.candidateselector import factory_CandidateSelector +from nebula.core.situationalawareness.fastreboot import FastReboot +from nebula.core.situationalawareness.modelhandlers.modelhandler 
import factory_ModelHandler +from nebula.core.situationalawareness.momentum import Momentum +from nebula.core.situationalawareness.neighborpolicies.neighborpolicy import factory_NeighborPolicy from nebula.core.utils.locker import Locker if TYPE_CHECKING: @@ -340,12 +340,13 @@ async def start_late_connection_process(self, connected=False, msg_type="discove await self.clear_pending_confirmations() # find federation and send discover - await self.engine.cm.stablish_connection_to_federation(msg_type, addrs_known) + connections_stablished = await self.engine.cm.stablish_connection_to_federation(msg_type, addrs_known) # wait offer #TODO actualizar con la informacion de latencias - logging.info(f"Waiting: {self.recieve_offer_timer}s to receive offers from federation") - await asyncio.sleep(self.recieve_offer_timer) + if connections_stablished: + logging.info(f"Waiting: {self.recieve_offer_timer}s to receive offers from federation") + await asyncio.sleep(self.recieve_offer_timer) # acquire lock to not accept late candidates self.accept_candidates_lock.acquire() @@ -360,7 +361,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove best_candidates = self.candidate_selector.select_candidates() logging.info(f"Candidates | {[addr for addr, _, _ in best_candidates]}") - # candidates not choosen --> disconnect + #TODO candidates not choosen --> disconnect try: for addr, _, _ in best_candidates: await self.add_pending_connection_confirmation(addr) @@ -423,9 +424,7 @@ async def check_robustness(self): async def reconnect_to_federation(self): self._restructure_process_lock.acquire() await self.engine.cm.clear_restrictions() - #await asyncio.sleep(120) - #if await self.engine.cm.is_external_connection_service_running(): - # self.engine.cm.stop_external_connection_service() + #await asyncio.sleep(120) # If we got some refs, try to reconnect to them if len(self.neighbor_policy.get_nodes_known()) > 0: logging.info("Reconnecting | Addrs availables") From 
4f23e38a08dbe2c03cfe1abe32dcffdf0d0b3f29 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 18 Feb 2025 16:52:40 +0100 Subject: [PATCH 099/233] feature situational awareness module functionalities --- nebula/core/engine.py | 11 +- nebula/core/network/blacklist.py | 2 + nebula/core/network/communications.py | 4 +- nebula/core/network/connection.py | 5 +- .../awareness/samodule.py | 166 ++++++++++++++- .../core/situationalawareness/nodemanager.py | 194 +++--------------- 6 files changed, 199 insertions(+), 183 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index e2a324ea1..85abfcb5c 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -152,12 +152,10 @@ def __init__( topology = self.config.participant["mobility_args"]["topology_type"] topology = topology.lower() model_handler = "std" # self.config.participant["mobility_args"]["model_handler"] - acceleration_push = "fast" # self.config.participant["mobility_args"]["push_strategy"] self._node_manager = NodeManager( config.participant["mobility_args"]["additional_node"]["status"], topology, model_handler, - acceleration_push, engine=self, ) @@ -582,16 +580,12 @@ async def trigger_event(self, message_event): async def _aditional_node_start(self): self.update_sinchronized_status(False) logging.info(f"Aditional node | {self.addr} | going to stablish connection with federation") - self.nm.late_config() await self.nm.start_late_connection_process() # continue .. 
# asyncio.create_task(self.nm.stop_not_selected_connections()) logging.info("Creating trainer service to start the federation process..") asyncio.create_task(self._start_learning_late()) - def get_push_acceleration(self): - return self.nm.get_push_acceleration() - async def set_pushed_done(self, rounds_push): await self.nm.set_rounds_pushed(rounds_push) @@ -892,10 +886,7 @@ async def _additional_mobility_actions(self): if not self.mobility: return logging.info("πŸ”„ Starting additional mobility actions...") - await self.nm.check_robustness() - action = await self.nm.check_external_connection_service_status() - if action: - action() + await self.nm.mobility_actions() def reputation_calculation(self, aggregated_models_weights): cossim_threshold = 0.5 diff --git a/nebula/core/network/blacklist.py b/nebula/core/network/blacklist.py index 32b2e398f..ac7bdd531 100644 --- a/nebula/core/network/blacklist.py +++ b/nebula/core/network/blacklist.py @@ -55,6 +55,7 @@ async def get_blacklist(self) -> set: async def clear_blacklist(self): await self._blacklisted_nodes_lock.acquire_async() + logging.info(f"🧹 Removing nodes from blacklist") self._blacklisted_nodes.clear() await self._blacklisted_nodes_lock.release_async() @@ -119,6 +120,7 @@ async def add_recently_disconnected(self, addr): async def clear_recently_disconected(self): self._recently_disconnected_lock.acquire_async() + logging.info(f"🧹 Removing nodes from Recently Disconencted list") self._recently_disconnected.clear() self._recently_disconnected_lock.release_async() diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 9541290c6..6e6f4d15a 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -466,8 +466,8 @@ async def deploy_additional_services(self): self._generate_network_conditions() await self._forwarder.start() if self.config.participant["mobility_args"]["mobility"]: - #pass - await self._discoverer.start() + pass + #await 
self._discoverer.start() # await self._health.start() self._propagator.start() await self._mobility.start() diff --git a/nebula/core/network/connection.py b/nebula/core/network/connection.py index b2b341b8c..3794bdee4 100755 --- a/nebula/core/network/connection.py +++ b/nebula/core/network/connection.py @@ -314,8 +314,8 @@ async def handle_incoming_message(self) -> None: self.incompleted_reconnections = 0 if is_last_chunk: await self._process_complete_message(message_id) - except asyncio.CancelledError: - logging.info("Message handling cancelled") + except asyncio.CancelledError as e: + logging.exception(f"Message handling cancelled: {e}") except ConnectionError as e: logging.exception(f"Connection closed while reading: {e}") except Exception as e: @@ -324,6 +324,7 @@ async def handle_incoming_message(self) -> None: logging.exception(f"Error handling incoming message: {e}") finally: if self.direct: + #TODO tal vez una task? await self.reconnect() async def _read_exactly(self, num_bytes: int, max_retries: int = 3) -> bytes: diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index e80542215..03aea5345 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -1,2 +1,166 @@ import asyncio -import logging \ No newline at end of file +import logging +from nebula.core.situationalawareness.neighborpolicies.neighborpolicy import factory_NeighborPolicy +from nebula.core.utils.locker import Locker +from nebula.addons.functions import print_msg_box + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from nebula.core.situationalawareness.nodemanager import NodeManager + +RESTRUCTURE_COOLDOWN = 5 + +class SAModule(): + def __init__( + self, + nodemanager, + addr, + topology, + ): + print_msg_box( + msg=f"Starting Situational Awareness module...\nTopology: {topology}", indent=2, title="Situational Awareness module" + ) + 
logging.info("🌐 Initializing SAModule") + self._addr = addr + self._topology = topology + self._node_manager : NodeManager = nodemanager + self._neighbor_policy = factory_NeighborPolicy(topology) + self._restructure_process_lock = Locker(name="restructure_process_lock") + self._restructure_cooldown = 0 + + @property + def nm(self): + return self._node_manager + + @property + def np(self): + return self._neighbor_policy + + @property + def cm(self): + return self.nm.engine.cm + + async def init(self): + logging.info("Building neighbor policy configuration..") + self.np.set_config([ + await self.cm.get_addrs_current_connections(only_direct=True, myself=False), + await self.cm.get_addrs_current_connections(only_direct=False, only_undirected=False, myself=False), + self.nm.engine.addr, + self, + ]) + + + """ ############################### + # REESTRUCTURE TOPOLOGY # + ############################### + """ + + def _update_restructure_cooldown(self): + if self._restructure_cooldown: + self._restructure_cooldown = (self._restructure_cooldown + 1) % RESTRUCTURE_COOLDOWN + + def _restructure_available(self): + if self._restructure_cooldown: + logging.info("Reestructure on cooldown") + return self._restructure_cooldown == 0 + + def get_restructure_process_lock(self): + return self._restructure_process_lock + + + """ ############################### + # NEIGHBOR POLICY # + ############################### + """ + + def meet_node(self, node): + if node != self._addr: + logging.info(f"Update nodes known | addr: {node}") + self.np.meet_node(node) + + def update_neighbors(self, node, remove=False): + self.np.update_neighbors(node, remove) + if not remove: + self.np.meet_node(node) + + def get_nodes_known(self, neighbors_too=False, neighbors_only=False): + return self.np.get_nodes_known(neighbors_too, neighbors_only) + + async def neighbors_left(self): + return len(await self.cm.get_addrs_current_connections(only_direct=True, myself=False)) > 0 + + def accept_connection(self, 
source, joining=False): + return self.np.accept_connection(source, joining) + + def need_more_neighbors(self): + return self.np.need_more_neighbors() + + def get_actions(self): + return self.np.get_actions() + + + """ ############################### + # ROBUSTNESS # + ############################### + """ + + async def check_external_connection_service_status(self): + if not await self.nm.engine.cm.is_external_connection_service_running(): + logging.info("πŸ”„ External Service not running | Starting service...") + self.nm.engine.cm.init_external_connection_service() + + async def analize_topology_robustness(self): + logging.info("πŸ”„ Analizing node network robustness...") + if not self._restructure_process_lock.locked(): + if not await self.neighbors_left(): + logging.info("No Neighbors left | reconnecting with Federation") + #await self.reconnect_to_federation() + elif (self.np.need_more_neighbors() and self._restructure_available()): + logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") + self._update_restructure_cooldown() + possible_neighbors = self.np.get_nodes_known(neighbors_too=False) + possible_neighbors = await self.cm.apply_restrictions(possible_neighbors) + if not possible_neighbors: + logging.info("All possible neighbors using nodes known are restricted...") + else: + pass + #asyncio.create_task(self.upgrade_connection_robustness(possible_neighbors)) + else: + logging.info("Sufficient Robustness | no actions required") + else: + logging.info("❗️ Reestructure/Reconnecting process already running...") + + async def reconnect_to_federation(self): + self._restructure_process_lock.acquire() + await self.cm.clear_restrictions() + await asyncio.sleep(120) + # If we got some refs, try to reconnect to them + if len(self.np.get_nodes_known()) > 0: + logging.info("Reconnecting | Addrs availables") + await self.nm.start_late_connection_process(connected=False, msg_type="discover_nodes", 
addrs_known=self.np.get_nodes_known()) + else: + logging.info("Reconnecting | NO Addrs availables") + await self.nm.start_late_connection_process(connected=False, msg_type="discover_nodes") + self._restructure_process_lock.release() + + async def upgrade_connection_robustness(self, possible_neighbors): + self._restructure_process_lock.acquire() + #addrs_to_connect = self.neighbor_policy.get_nodes_known(neighbors_too=False) + # If we got some refs, try to connect to them + if len(possible_neighbors) > 0: + logging.info(f"Reestructuring | Addrs availables | addr list: {possible_neighbors}") + await self.nm.start_late_connection_process(connected=True, msg_type="discover_nodes", addrs_known=possible_neighbors) + else: + logging.info("Reestructuring | NO Addrs availables") + await self.nm.start_late_connection_process(connected=True, msg_type="discover_nodes") + self._restructure_process_lock.release() + + async def stop_connections_with_federation(self): + await asyncio.sleep(200) + logging.info("### DISCONNECTING FROM FEDERATON ###") + neighbors = self.np.get_nodes_known(neighbors_only=True) + for n in neighbors: + await self.cm.add_to_blacklist(n) + for n in neighbors: + await self.cm.disconnect(n, mutual_disconnection=False, forced=True) + \ No newline at end of file diff --git a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py index 69d7b2b8b..95140552c 100644 --- a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -7,7 +7,8 @@ from nebula.core.situationalawareness.fastreboot import FastReboot from nebula.core.situationalawareness.modelhandlers.modelhandler import factory_ModelHandler from nebula.core.situationalawareness.momentum import Momentum -from nebula.core.situationalawareness.neighborpolicies.neighborpolicy import factory_NeighborPolicy +#from nebula.core.situationalawareness.neighborpolicies.neighborpolicy import factory_NeighborPolicy +from 
nebula.core.situationalawareness.awareness.samodule import SAModule from nebula.core.utils.locker import Locker if TYPE_CHECKING: @@ -22,7 +23,6 @@ def __init__( aditional_participant, topology, model_handler, - push_acceleration, engine: "Engine", fastreboot=False, momentum=False, @@ -35,8 +35,6 @@ def __init__( logging.info("🌐 Initializing Node Manager") self._engine = engine self.config = engine.get_config() - logging.info("Initializing Neighbor policy") - self._neighbor_policy = factory_NeighborPolicy(self.topology) logging.info("Initializing Candidate Selector") self._candidate_selector = factory_CandidateSelector(self.topology) logging.info("Initializing Model Handler") @@ -47,27 +45,23 @@ def __init__( self.pending_confirmation_from_nodes_lock = Locker(name="pending_confirmation_from_nodes_lock", async_lock=True) self.accept_candidates_lock = Locker(name="accept_candidates_lock") self.recieve_offer_timer = 5 - self._restructure_process_lock = Locker(name="restructure_process_lock") - self.restructure = False - self._restructure_cooldown = 0 self.discarded_offers_addr_lock = Locker(name="discarded_offers_addr_lock") self.discarded_offers_addr = [] - self._push_acceleration = push_acceleration - - self.synchronizing_rounds = False self._fast_reboot_status = fastreboot self._momemtum_status = momentum self._desc_done = False #TODO remove + + self._situational_awareness_module = SAModule(self, self.engine.addr, topology) @property def engine(self): return self._engine - @property - def neighbor_policy(self): - return self._neighbor_policy + # @property + # def neighbor_policy(self): + # return self._neighbor_policy @property def candidate_selector(self): @@ -82,32 +76,14 @@ def fr(self): return self._fastreboot @property - def mom(self): - return self._momemtum + def sam(self): + return self._situational_awareness_module def fast_reboot_on(self): return self._fast_reboot_status - def _update_restructure_cooldown(self): - if self._restructure_cooldown: - 
self._restructure_cooldown = (self._restructure_cooldown + 1) % RESTRUCTURE_COOLDOWN - - def _restructure_available(self): - if self._restructure_cooldown: - logging.info("Reestructure on cooldown") - return self._restructure_cooldown == 0 - - def get_push_acceleration(self): - return self._push_acceleration - def get_restructure_process_lock(self): - return self._restructure_process_lock - - def set_synchronizing_rounds(self, status): - self.synchronizing_rounds = status - - def get_syncrhonizing_rounds(self): - return self.synchronizing_rounds + return self.sam.get_restructure_process_lock() async def set_rounds_pushed(self, rp): if self.fast_reboot_on(): @@ -118,11 +94,6 @@ def still_waiting_for_candidates(self): async def set_configs(self): """ - neighbor_policy config: - - direct connections a.k.a neighbors - - all nodes known - - self addr - model_handler config: - self total rounds - self current round @@ -133,13 +104,7 @@ async def set_configs(self): - self weight distance - self weight hetereogeneity """ - logging.info("Building neighbor policy configuration..") - self.neighbor_policy.set_config([ - await self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False), - await self.engine.cm.get_addrs_current_connections(only_direct=False, only_undirected=False, myself=False), - self.engine.addr, - self, - ]) + await self.sam.init() logging.info("Building candidate selector configuration..") self.candidate_selector.set_config([0, 0.5, 0.5]) # self.engine.trainer.get_loss(), self.config.participant["molibity_args"]["weight_distance"], self.config.participant["molibity_args"]["weight_het"] @@ -147,18 +112,6 @@ async def set_configs(self): if self._fast_reboot_status: self._fastreboot = FastReboot(self) - self._momemtum = None - if self._momemtum_status and not self._aditional_participant: - self._momemtum = Momentum( - self, self.neighbor_policy.get_nodes_known(neighbors_only=True), dispersion_penalty=False - ) - - def late_config(self): - if 
self._momemtum_status: - self._momemtum = Momentum( - self, self.neighbor_policy.get_nodes_known(neighbors_only=True), dispersion_penalty=False - ) - """ ############################## # WEIGHT STRATEGIES # @@ -172,8 +125,6 @@ async def register_late_neighbor(self, addr, joinning_federation=False): logging.info(f"Registering | late neighbor: {addr}, joining: {joinning_federation}") self.meet_node(addr) await self.update_neighbors(addr) - if self._momemtum_status: - await self.mom.update_node(addr) if joinning_federation: if self.fast_reboot_on(): await self.fr.add_fastReboot_addr(addr) @@ -181,8 +132,6 @@ async def register_late_neighbor(self, addr, joinning_federation=False): async def apply_weight_strategy(self, updates: dict): if self.fast_reboot_on(): await self.fr.apply_weight_strategy(updates) - if self._momemtum: - await self._momemtum.calculate_momentum_weights(updates) """ ############################## @@ -191,12 +140,12 @@ async def apply_weight_strategy(self, updates: dict): """ def accept_connection(self, source, joining=False): - return self.neighbor_policy.accept_connection(source, joining) + return self.sam.accept_connection(source, joining) async def add_pending_connection_confirmation(self, addr): await self._update_neighbors_lock.acquire_async() await self.pending_confirmation_from_nodes_lock.acquire_async() - if addr not in self.neighbor_policy.get_nodes_known(neighbors_only=True): + if addr not in self.sam.get_nodes_known(neighbors_only=True): logging.info(f" Addition | pending connection confirmation from: {addr}") self.pending_confirmation_from_nodes.add(addr) await self.pending_confirmation_from_nodes_lock.release_async() @@ -232,38 +181,28 @@ def add_to_discarded_offers(self, addr_discarded): self.discarded_offers_addr_lock.release() def need_more_neighbors(self): - return self.neighbor_policy.need_more_neighbors() + return self.sam.need_more_neighbors() def get_actions(self): - return self.neighbor_policy.get_actions() + return 
self.sam.get_actions() async def update_neighbors(self, node, remove=False): logging.info(f"Update neighbor | node addr: {node} | remove: {remove}") await self._update_neighbors_lock.acquire_async() - self.neighbor_policy.update_neighbors(node, remove) - # self.timer_generator.update_node(node, remove) + self.sam.update_neighbors(node, remove) if remove: if self._fast_reboot_status: self.fr.discard_fastreboot_for(node) - if self._momemtum_status: - await self.mom.update_node(node, remove=True) else: - self.neighbor_policy.meet_node(node) - if self._momemtum_status: - await self.mom.update_node(node) + self.sam.meet_node(node) self._remove_pending_confirmation_from(node) await self._update_neighbors_lock.release_async() - async def neighbors_left(self): - return len(await self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False)) > 0 - def meet_node(self, node): - if node != self.engine.addr: - logging.info(f"Update nodes known | addr: {node}") - self.neighbor_policy.meet_node(node) + self.sam.meet_node(node) def get_nodes_known(self, neighbors_too=False): - return self.neighbor_policy.get_nodes_known(neighbors_too) + return self.sam.get_nodes_known(neighbors_too) def accept_model_offer(self, source, decoded_model, rounds, round, epochs, n_neighbors, loss): if not self.accept_candidates_lock.locked(): @@ -283,9 +222,6 @@ def add_candidate(self, source, n_neighbors, loss): if not self.accept_candidates_lock.locked(): self.candidate_selector.add_candidate((source, n_neighbors, loss)) - async def currently_reestructuring(self): - return self._restructure_process_lock.locked() - async def stop_not_selected_connections(self): try: with self.discarded_offers_addr_lock: @@ -303,26 +239,6 @@ async def stop_not_selected_connections(self): except asyncio.CancelledError: pass - #TODO todo esto es innecesario - async def check_external_connection_service_status(self): - logging.info("πŸ”„ Checking external connection service status...") - n = await 
self.neighbors_left() - ecs = await self.engine.cm.is_external_connection_service_running() - ss = self.engine.get_sinchronized_status() - action = None - logging.info(f"Stats | neighbors: {n} | service running: {ecs} | synchronized status: {ss}") - if not await self.neighbors_left() and await self.engine.cm.is_external_connection_service_running(): - logging.info("❗️ Isolated node | Shutdowning service required") - #action = lambda: self.engine.cm.stop_external_connection_service() - elif ( - await self.neighbors_left() - and not await self.engine.cm.is_external_connection_service_running() - and self.engine.get_sinchronized_status() - ): - logging.info("πŸ”„ NOT isolated node | Service not running | Starting service...") - action = lambda: self.engine.cm.init_external_connection_service() - return action - async def start_late_connection_process(self, connected=False, msg_type="discover_join", addrs_known=None): """ This function represents the process of discovering the federation and stablish the first @@ -344,6 +260,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove # wait offer #TODO actualizar con la informacion de latencias + logging.info(f"Connections stablish after finding federation: {connections_stablished}") if connections_stablished: logging.info(f"Waiting: {self.recieve_offer_timer}s to receive offers from federation") await asyncio.sleep(self.recieve_offer_timer) @@ -369,13 +286,13 @@ async def start_late_connection_process(self, connected=False, msg_type="discove await asyncio.sleep(1) except asyncio.CancelledError: await self.update_neighbors(addr, remove=True) - pass + logging.info("Error during stablishment") self.accept_candidates_lock.release() self.late_connection_process_lock.release() self.candidate_selector.remove_candidates() if not self._desc_done: #TODO remove self._desc_done = True - #asyncio.create_task(self.stop_connections_with_federation()) + 
asyncio.create_task(self.sam.stop_connections_with_federation()) # if no candidates, repeat process else: logging.info("❗️ No Candidates found...") @@ -391,66 +308,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove ############################## """ - async def check_robustness(self): - logging.info("πŸ”„ Analizing node network robustness...") - logging.info(f"Synchronization status: {self.engine.get_sinchronized_status()} | got neighbors: {await self.neighbors_left()}") - if not self._restructure_process_lock.locked(): - if not await self.neighbors_left(): - logging.info("No Neighbors left | reconnecting with Federation") - self.engine.update_sinchronized_status(False) - #await self.reconnect_to_federation() - elif ( - self.neighbor_policy.need_more_neighbors() - and self.engine.get_sinchronized_status() - and self._restructure_available() - ): - logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") - self._update_restructure_cooldown() - possible_neighbors = self.neighbor_policy.get_nodes_known(neighbors_too=False) - possible_neighbors = await self.engine.cm.apply_restrictions(possible_neighbors) - if not possible_neighbors: - logging.info("All possible neighbors using nodes known are restricted...") - else: - pass - #asyncio.create_task(self.upgrade_connection_robustness(possible_neighbors)) - else: - if not self.engine.get_sinchronized_status(): - logging.info("Device not synchronized with federation") - else: - logging.info("Sufficient Robustness | no actions required") - else: - logging.info("❗️ Reestructure/Reconnecting process already running...") - - async def reconnect_to_federation(self): - self._restructure_process_lock.acquire() - await self.engine.cm.clear_restrictions() - #await asyncio.sleep(120) - # If we got some refs, try to reconnect to them - if len(self.neighbor_policy.get_nodes_known()) > 0: - logging.info("Reconnecting | Addrs availables") - await 
self.start_late_connection_process(connected=False, msg_type="discover_nodes", addrs_known=self.neighbor_policy.get_nodes_known()) - else: - logging.info("Reconnecting | NO Addrs availables") - await self.start_late_connection_process(connected=False, msg_type="discover_nodes") - self._restructure_process_lock.release() - - async def upgrade_connection_robustness(self, possible_neighbors): - self._restructure_process_lock.acquire() - #addrs_to_connect = self.neighbor_policy.get_nodes_known(neighbors_too=False) - # If we got some refs, try to connect to them - if len(possible_neighbors) > 0: - logging.info(f"Reestructuring | Addrs availables | addr list: {possible_neighbors}") - await self.start_late_connection_process(connected=True, msg_type="discover_nodes", addrs_known=possible_neighbors) - else: - logging.info("Reestructuring | NO Addrs availables") - await self.start_late_connection_process(connected=True, msg_type="discover_nodes") - self._restructure_process_lock.release() - - async def stop_connections_with_federation(self): - await asyncio.sleep(100) - logging.info("### DISCONNECTING FROM FEDERATON ###") - neighbors = self.neighbor_policy.get_nodes_known(neighbors_only=True) - for n in neighbors: - await self.engine.cm.add_to_blacklist(n) - for n in neighbors: - await self.engine.cm.disconnect(n, mutual_disconnection=False, forced=True) + async def mobility_actions(self): + await self.sam.check_external_connection_service_status() + await self.sam.analize_topology_robustness() + From f1364f053b8bfd332e8277ddcb3edf56bfb3f6b4 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 19 Feb 2025 14:00:55 +0100 Subject: [PATCH 100/233] feature nebula discover service asynchronous --- .../attacks/communications/floodingattack.py | 70 +++++++++ nebula/core/network/communications.py | 41 +++--- .../core/network/externalconnectionservice.py | 7 +- nebula/core/network/nebuladiscoveryservice.py | 137 ++++++++++++++++++ .../awareness/samodule.py | 11 +- 
.../core/situationalawareness/nodemanager.py | 13 +- 6 files changed, 250 insertions(+), 29 deletions(-) create mode 100644 nebula/addons/attacks/communications/floodingattack.py create mode 100644 nebula/core/network/nebuladiscoveryservice.py diff --git a/nebula/addons/attacks/communications/floodingattack.py b/nebula/addons/attacks/communications/floodingattack.py new file mode 100644 index 000000000..35379b2c8 --- /dev/null +++ b/nebula/addons/attacks/communications/floodingattack.py @@ -0,0 +1,70 @@ +import asyncio +import logging +from functools import wraps + +from nebula.addons.attacks.communications.communicationattack import CommunicationAttack + + +class FloodingAttack(CommunicationAttack): + """ + Implements an attack that delays the execution of a target method by a specified amount of time. + """ + + def __init__(self, engine, attack_params: dict): + """ + Initializes the DelayerAttack with the engine and attack parameters. + + Args: + engine: The engine managing the attack context. + attack_params (dict): Parameters for the attack, including the delay duration. + """ + try: + + round_start = int(attack_params["round_start_attack"]) + round_stop = int(attack_params["round_stop_attack"]) + self.flooding_factor = 100 #int(attack_params["flooding_factor"]) + self.target_percentage = 50#int(attack_params["target_percentage"]) + self.selection_interval = 1#int(attack_params["selection_interval"]) + except KeyError as e: + raise ValueError(f"Missing required attack parameter: {e}") + except ValueError: + raise ValueError("Invalid value in attack_params. Ensure all values are integers.") + + super().__init__( + engine, + engine._cm, + "send_message", + round_start, + round_stop, + self.flooding_factor, + self.target_percentage, + self.selection_interval, + ) + + def decorator(self, flooding_factor: int): + """ + Decorator that adds a delay to the execution of the original method. 
+ + Args: + flooding_factor (int): The number of times to repeat the function execution. + + Returns: + function: A decorator function that wraps the target method with the delay logic. + """ + + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + if len(args) > 1: + dest_addr = args[1] + if dest_addr in self.targets: + logging.info(f"[FloodingAttack] Flooding message to {dest_addr} by {flooding_factor} times") + for i in range(flooding_factor): + logging.info(f"[FloodingAttack] Sending duplicate {i+1}/{flooding_factor} to {dest_addr}") + await func(*args, **kwargs) + _, *new_args = args # Exclude self argument + return await func(*new_args) + + return wrapper + + return decorator \ No newline at end of file diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 6e6f4d15a..939a764b9 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -14,7 +14,9 @@ from nebula.core.network.discoverer import Discoverer from nebula.core.network.forwarder import Forwarder from nebula.core.network.messages import MessagesManager -from nebula.core.network.nebulamulticasting import NebulaConnectionService +#from nebula.core.network.nebulamulticasting import NebulaConnectionService +from nebula.core.network.externalconnectionservice import ExternalConnectionService +from nebula.core.network.nebuladiscoveryservice import NebulaConnectionService #TODO mover a mΓ©todo factoria from nebula.core.network.propagator import Propagator from nebula.core.network.messages import MessageEvent from nebula.core.utils.helper import ( @@ -82,18 +84,18 @@ def __init__(self, engine: "Engine"): self._blacklist = BlackList() # Connection service to communicate with external devices - self._external_connection_service = None + self._external_connection_service : ExternalConnectionService = None # The line below is neccesary when mobility would be set up - mob = 
self.config.participant["mobility_args"]["mobility"] - aditional_node = self.config.participant["mobility_args"]["additional_node"]["status"] - if mob == True and not aditional_node: - self._external_connection_service = NebulaConnectionService(self.addr) - logging.info("Deploying External Connection Service") - self.ecs.start() - else: - logging.info("Deploying External Connection Service | No running") - self._external_connection_service = NebulaConnectionService(self.addr) + # mob = self.config.participant["mobility_args"]["mobility"] + # aditional_node = self.config.participant["mobility_args"]["additional_node"]["status"] + # if mob == True and not aditional_node: + # self._external_connection_service = NebulaConnectionService(self.addr) + # ) + # self.ecs.start() + # else: + + # self._external_connection_service = NebulaConnectionService(self.addr) @property def engine(self): @@ -203,16 +205,17 @@ async def clear_restrictions(self): ############################### """ - def start_external_connection_service(self): + async def start_external_connection_service(self, run_service=True): if self.ecs == None: - self.ecs = NebulaConnectionService(self.addr) - self.ecs.start() + self._external_connection_service = NebulaConnectionService(self.addr) + if run_service: + await self.ecs.start() - def stop_external_connection_service(self): - self.ecs.stop() + async def stop_external_connection_service(self): + await self.ecs.stop() - def init_external_connection_service(self): - self.start_external_connection_service() + async def init_external_connection_service(self): + await self.start_external_connection_service() async def is_external_connection_service_running(self): return self.ecs.is_running() @@ -229,7 +232,7 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr addrs = [] if addrs_known == None: logging.info("Searching federation process beginning...") - addrs = self.ecs.find_federation() + addrs = await self.ecs.find_federation() 
logging.info(f"Found federation devices | addrs {addrs}") else: logging.info(f"Searching federation process beginning... | Using addrs previously known {addrs_known}") diff --git a/nebula/core/network/externalconnectionservice.py b/nebula/core/network/externalconnectionservice.py index d4bc00834..0c84563b0 100644 --- a/nebula/core/network/externalconnectionservice.py +++ b/nebula/core/network/externalconnectionservice.py @@ -1,13 +1,14 @@ +import asyncio from abc import ABC, abstractmethod class ExternalConnectionService(ABC): @abstractmethod - def start(self): + async def start(self): pass @abstractmethod - def stop(self): + async def stop(self): pass @abstractmethod @@ -15,5 +16,5 @@ def is_running(self): pass @abstractmethod - def find_federation(self): + async def find_federation(self): pass \ No newline at end of file diff --git a/nebula/core/network/nebuladiscoveryservice.py b/nebula/core/network/nebuladiscoveryservice.py new file mode 100644 index 000000000..2971d0cad --- /dev/null +++ b/nebula/core/network/nebuladiscoveryservice.py @@ -0,0 +1,137 @@ +import asyncio +import logging +import socket +import struct +from nebula.core.network.externalconnectionservice import ExternalConnectionService + +class NebulaServerProtocol(asyncio.DatagramProtocol): + BCAST_IP = '239.255.255.250' + UPNP_PORT = 1900 + + def __init__(self, nebula_service, addr): + self.nebula_service = nebula_service + self.addr = addr + self.transport = None + + def connection_made(self, transport): + self.transport = transport + logging.info("Nebula UPnP server is listening...") + + def datagram_received(self, data, addr): + logging.info("Server service receiving information") + if self._is_nebula_message(data): + logging.info("Nebula request received, responding...") + asyncio.create_task(self.respond(addr)) + + async def respond(self, addr): + try: + response = ("HTTP/1.1 200 OK\r\n" + "CACHE-CONTROL: max-age=1800\r\n" + "ST: urn:nebula-service\r\n" + "EXT:\r\n" + f"LOCATION: 
{self.addr}\r\n") + self.transport.sendto(response.encode('ASCII'), addr) + except Exception as e: + logging.error(f"Error responding to client: {e}") + + def _is_nebula_message(self, msg): + return "ST: urn:nebula-service" in msg.decode('utf-8') + +class NebulaClientProtocol(asyncio.DatagramProtocol): + BCAST_IP = '239.255.255.250' + BCAST_PORT = 1900 + SEARCH_TRIES = 5 + SEARCH_INTERVAL = 3 + + def __init__(self, nebula_service): + self.nebula_service : NebulaConnectionService = nebula_service + self.transport = None + self.search_done = asyncio.Event() + + def connection_made(self, transport): + self.transport = transport + sock = self.transport.get_extra_info('socket') + if sock is not None: + sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 2) + asyncio.create_task(self.keep_search()) + + async def keep_search(self): + logging.info("Federation searching loop started") + # while True: + for _ in range(self.SEARCH_TRIES): + await self.search() + await asyncio.sleep(self.SEARCH_INTERVAL) + self.search_done.set() + + async def wait_for_search(self): + await self.search_done.wait() + + async def search(self): + logging.info("Searching for nodes...") + try: + search_request = ("M-SEARCH * HTTP/1.1\r\n" + "HOST: 239.255.255.250:1900\r\n" + "MAN: \"ssdp:discover\"\r\n" + "MX: 1\r\n" + "ST: urn:nebula-service\r\n" + "\r\n") + self.transport.sendto(search_request.encode('ASCII'), (self.BCAST_IP, self.BCAST_PORT)) + except Exception as e: + logging.error(f"Error sending search request: {e}") + + def datagram_received(self, data, addr): + if "ST: urn:nebula-service" in data.decode('utf-8'): + logging.info("Received response from server") + self.nebula_service.response_received(data, addr) + +class NebulaConnectionService(ExternalConnectionService): + def __init__(self, addr): + self.nodes_found = set() + self.addr = addr + self.server : NebulaServerProtocol = None + self.client : NebulaClientProtocol = None + self.running = False + + async def start(self): + 
loop = asyncio.get_running_loop() + transport, self.server = await loop.create_datagram_endpoint( + lambda: NebulaServerProtocol(self, self.addr), + local_addr=('0.0.0.0', 1900)) + try: + # Advanced socket settings + sock = transport.get_extra_info('socket') + if sock is not None: + group = socket.inet_aton('239.255.255.250') # Multicast to binary format. + mreq = struct.pack('4sL', group, socket.INADDR_ANY) # Join multicast group in every interface available + sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq) # SO listen multicast packages + except Exception as e: + logging.exception(f"{e}") + self.running = True + + async def stop(self): + if self.server and self.server.transport: + self.server.transport.close() + self.running = False + + def is_running(self): + return self.running + + async def find_federation(self): + logging.info(f"Node {self.addr} trying to find federation...") + loop = asyncio.get_running_loop() + transport, self.client = await loop.create_datagram_endpoint( + lambda: NebulaClientProtocol(self), + local_addr=('0.0.0.0', 0)) # To listen on all network interfaces + await self.client.wait_for_search() + transport.close() + return self.nodes_found + + def response_received(self, data, addr): + logging.info("Parsing response...") + msg_str = data.decode('utf-8') + for line in msg_str.splitlines(): + if line.strip().startswith("LOCATION:"): + addr = line.split(": ")[1].strip() + if addr != self.addr: + logging.info(f"Device address received: {addr}") + self.nodes_found.add(addr) \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 03aea5345..005a6cb37 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -41,6 +41,13 @@ def cm(self): return self.nm.engine.cm async def init(self): + if not self.nm.is_additional_participant(): + logging.info("Deploying 
External Connection Service") + await self.cm.start_external_connection_service() + else: + logging.info("Deploying External Connection Service | No running") + await self.cm.start_external_connection_service(run_service=False) + logging.info("Building neighbor policy configuration..") self.np.set_config([ await self.cm.get_addrs_current_connections(only_direct=True, myself=False), @@ -105,9 +112,9 @@ def get_actions(self): """ async def check_external_connection_service_status(self): - if not await self.nm.engine.cm.is_external_connection_service_running(): + if not await self.cm.is_external_connection_service_running(): logging.info("πŸ”„ External Service not running | Starting service...") - self.nm.engine.cm.init_external_connection_service() + await self.cm.init_external_connection_service() async def analize_topology_robustness(self): logging.info("πŸ”„ Analizing node network robustness...") diff --git a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py index 95140552c..72175e090 100644 --- a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -82,15 +82,12 @@ def sam(self): def fast_reboot_on(self): return self._fast_reboot_status - def get_restructure_process_lock(self): - return self.sam.get_restructure_process_lock() - async def set_rounds_pushed(self, rp): if self.fast_reboot_on(): self.fr.set_rounds_pushed(rp) - def still_waiting_for_candidates(self): - return not self.accept_candidates_lock.locked() + def is_additional_participant(self): + return self._aditional_participant async def set_configs(self): """ @@ -139,8 +136,14 @@ async def apply_weight_strategy(self, updates: dict): ############################## """ + def get_restructure_process_lock(self): + return self.sam.get_restructure_process_lock() + def accept_connection(self, source, joining=False): return self.sam.accept_connection(source, joining) + + def still_waiting_for_candidates(self): + 
return not self.accept_candidates_lock.locked() async def add_pending_connection_confirmation(self, addr): await self._update_neighbors_lock.acquire_async() From c2e1d7c61950c078af8a51da02b1603e5341ecee Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 19 Feb 2025 17:09:33 +0100 Subject: [PATCH 101/233] optimization code --- nebula/core/aggregation/aggregator.py | 9 +++------ nebula/core/network/communications.py | 19 +++---------------- .../core/network/externalconnectionservice.py | 19 ++++++++++++++++++- nebula/core/network/nebuladiscoveryservice.py | 12 ++++++++---- 4 files changed, 32 insertions(+), 27 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 695221865..f395f140f 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -44,14 +44,11 @@ def __init__(self, config=None, engine=None): self._waiting_global_update = False self._pending_models_to_aggregate = {} self._pending_models_to_aggregate_lock = Locker(name="pending_models_to_aggregate_lock", async_lock=True) - self._future_models_to_aggregate = {} - self._add_model_lock = Locker(name="add_model_lock", async_lock=True) - self._add_next_model_lock = Locker(name="add_next_model_lock", async_lock=True) self._aggregation_done_lock = Locker(name="aggregation_done_lock", async_lock=True) self._aggregation_waiting_skip = asyncio.Event() - self._push_strategy_lock = Locker(name="push_strategy_lock", async_lock=True) - self._end_round_push = 0 - self._update_storage = factory_update_handler("DFL", self, self._addr) #TODO use json config + + scenario = self.config.participant["scenario_args"]["federation"] + self._update_storage = factory_update_handler(scenario, self, self._addr) def __str__(self): return self.__class__.__name__ diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 939a764b9..d57b10a49 100755 --- a/nebula/core/network/communications.py +++ 
b/nebula/core/network/communications.py @@ -14,9 +14,7 @@ from nebula.core.network.discoverer import Discoverer from nebula.core.network.forwarder import Forwarder from nebula.core.network.messages import MessagesManager -#from nebula.core.network.nebulamulticasting import NebulaConnectionService -from nebula.core.network.externalconnectionservice import ExternalConnectionService -from nebula.core.network.nebuladiscoveryservice import NebulaConnectionService #TODO mover a mΓ©todo factoria +from nebula.core.network.externalconnectionservice import factory_connection_service from nebula.core.network.propagator import Propagator from nebula.core.network.messages import MessageEvent from nebula.core.utils.helper import ( @@ -84,18 +82,7 @@ def __init__(self, engine: "Engine"): self._blacklist = BlackList() # Connection service to communicate with external devices - self._external_connection_service : ExternalConnectionService = None - - # The line below is neccesary when mobility would be set up - # mob = self.config.participant["mobility_args"]["mobility"] - # aditional_node = self.config.participant["mobility_args"]["additional_node"]["status"] - # if mob == True and not aditional_node: - # self._external_connection_service = NebulaConnectionService(self.addr) - # ) - # self.ecs.start() - # else: - - # self._external_connection_service = NebulaConnectionService(self.addr) + self._external_connection_service = factory_connection_service("nebula", self.addr) @property def engine(self): @@ -207,7 +194,7 @@ async def clear_restrictions(self): async def start_external_connection_service(self, run_service=True): if self.ecs == None: - self._external_connection_service = NebulaConnectionService(self.addr) + self._external_connection_service = factory_connection_service(self.addr) #NebulaConnectionService(self.addr) if run_service: await self.ecs.start() diff --git a/nebula/core/network/externalconnectionservice.py b/nebula/core/network/externalconnectionservice.py index 
0c84563b0..6df2aba59 100644 --- a/nebula/core/network/externalconnectionservice.py +++ b/nebula/core/network/externalconnectionservice.py @@ -17,4 +17,21 @@ def is_running(self): @abstractmethod async def find_federation(self): - pass \ No newline at end of file + pass + +class ExternalConnectionServiceException(Exception): + pass + +def factory_connection_service(con_serv, addr) -> ExternalConnectionService: + from nebula.core.network.nebuladiscoveryservice import NebulaConnectionService + + CONNECTION_SERVICES = { + "nebula": NebulaConnectionService, + } + + con_serv = CONNECTION_SERVICES.get(con_serv, NebulaConnectionService) + + if con_serv: + return con_serv(addr) + else: + raise ExternalConnectionServiceException(f"Connection Service {con_serv} not found") \ No newline at end of file diff --git a/nebula/core/network/nebuladiscoveryservice.py b/nebula/core/network/nebuladiscoveryservice.py index 2971d0cad..7b424384e 100644 --- a/nebula/core/network/nebuladiscoveryservice.py +++ b/nebula/core/network/nebuladiscoveryservice.py @@ -18,7 +18,6 @@ def connection_made(self, transport): logging.info("Nebula UPnP server is listening...") def datagram_received(self, data, addr): - logging.info("Server service receiving information") if self._is_nebula_message(data): logging.info("Nebula request received, responding...") asyncio.create_task(self.respond(addr)) @@ -80,10 +79,15 @@ async def search(self): logging.error(f"Error sending search request: {e}") def datagram_received(self, data, addr): - if "ST: urn:nebula-service" in data.decode('utf-8'): - logging.info("Received response from server") - self.nebula_service.response_received(data, addr) + try: + if "ST: urn:nebula-service" in data.decode('utf-8'): + logging.info("Received response from Node server-service") + self.nebula_service.response_received(data, addr) + except UnicodeDecodeError: + logging.warning(f"Received malformed message from {addr}, ignoring.") + +#TODO si la busqueda no devuelve nada nuevo, dejar 
de hacerla para eliminar trΓ‘fico inutil class NebulaConnectionService(ExternalConnectionService): def __init__(self, addr): self.nodes_found = set() From 1f28f77ec6f13d7b17e7ef5be9b6bc01e3654e60 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 20 Feb 2025 11:56:17 +0100 Subject: [PATCH 102/233] feature beacon service --- .../updatehandlers/dflupdatehandler.py | 2 + nebula/core/network/communications.py | 14 ++- .../core/network/externalconnectionservice.py | 16 +++ nebula/core/network/nebuladiscoveryservice.py | 98 +++++++++++++++++-- .../neighborpolicies/__init__.py | 0 .../neighborpolicies/fcneighborpolicy.py | 2 +- .../neighborpolicies/idleneighborpolicy.py | 2 +- .../neighborpolicies/neighborpolicy.py | 8 +- .../neighborpolicies/ringneighborpolicy.py | 2 +- .../neighborpolicies/starneighborpolicy.py | 2 +- .../awareness/samodule.py | 8 +- 11 files changed, 136 insertions(+), 18 deletions(-) rename nebula/core/situationalawareness/{ => awareness}/neighborpolicies/__init__.py (100%) rename nebula/core/situationalawareness/{ => awareness}/neighborpolicies/fcneighborpolicy.py (97%) rename nebula/core/situationalawareness/{ => awareness}/neighborpolicies/idleneighborpolicy.py (97%) rename nebula/core/situationalawareness/{ => awareness}/neighborpolicies/neighborpolicy.py (70%) rename nebula/core/situationalawareness/{ => awareness}/neighborpolicies/ringneighborpolicy.py (96%) rename nebula/core/situationalawareness/{ => awareness}/neighborpolicies/starneighborpolicy.py (96%) diff --git a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py index 032f8f0b2..9ba0a280b 100644 --- a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py +++ b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py @@ -147,6 +147,8 @@ async def notify_federation_update(self, source, remove=False): if not source in self._sources_received: # Not received update from this source yet await 
self._update_source(source, remove=True) await self._all_updates_received() # Verify if discarding node aggregation could be done + else: + logging.info(f"Already received update from: {source}, it will be discarded next round") async def _update_source(self, source, remove=False): logging.info(f"πŸ”„ Update | remove: {remove} | source: {source}") diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index d57b10a49..1f75bfaf0 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -194,7 +194,7 @@ async def clear_restrictions(self): async def start_external_connection_service(self, run_service=True): if self.ecs == None: - self._external_connection_service = factory_connection_service(self.addr) #NebulaConnectionService(self.addr) + self._external_connection_service = factory_connection_service(self.addr) if run_service: await self.ecs.start() @@ -206,6 +206,18 @@ async def init_external_connection_service(self): async def is_external_connection_service_running(self): return self.ecs.is_running() + + async def start_beacon(self): + await self.ecs.start_beacon() + + async def stop_beacon(self): + await self.ecs.stop_beacon() + + async def subscribe_beacon_listener(self, listener): + await self.ecs.subscribe_beacon_listener(listener) + + async def modify_beacon_frequency(self, frequency): + await self.ecs.modify_beacon_frequency(frequency) #TODO # si se utilizan addr conocidas y no se consigue conectar a ninguna quΓ© hacer diff --git a/nebula/core/network/externalconnectionservice.py b/nebula/core/network/externalconnectionservice.py index 6df2aba59..689d8e738 100644 --- a/nebula/core/network/externalconnectionservice.py +++ b/nebula/core/network/externalconnectionservice.py @@ -18,6 +18,22 @@ def is_running(self): @abstractmethod async def find_federation(self): pass + + @abstractmethod + async def start_beacon(self): + pass + + @abstractmethod + async def stop_beacon(self): + pass + + 
@abstractmethod + async def modify_beacon_frequency(self, frequency): + pass + + @abstractmethod + async def subscribe_beacon_listener(self, listener): + pass class ExternalConnectionServiceException(Exception): pass diff --git a/nebula/core/network/nebuladiscoveryservice.py b/nebula/core/network/nebuladiscoveryservice.py index 7b424384e..b9c50a9d8 100644 --- a/nebula/core/network/nebuladiscoveryservice.py +++ b/nebula/core/network/nebuladiscoveryservice.py @@ -3,13 +3,16 @@ import socket import struct from nebula.core.network.externalconnectionservice import ExternalConnectionService +from nebula.core.utils.locker import Locker class NebulaServerProtocol(asyncio.DatagramProtocol): BCAST_IP = '239.255.255.250' UPNP_PORT = 1900 + DISCOVER_MESSAGE = "TYPE: discover" + BEACON_MESSAGE = "TYPE: beacon" def __init__(self, nebula_service, addr): - self.nebula_service = nebula_service + self.nebula_service : NebulaConnectionService = nebula_service self.addr = addr self.transport = None @@ -18,23 +21,37 @@ def connection_made(self, transport): logging.info("Nebula UPnP server is listening...") def datagram_received(self, data, addr): - if self._is_nebula_message(data): - logging.info("Nebula request received, responding...") - asyncio.create_task(self.respond(addr)) + msg = data.decode('utf-8') + if self._is_nebula_message(msg): + logging.info("Nebula message received...") + if self.DISCOVER_MESSAGE in msg: + logging.info("Discovery request received, responding...") + asyncio.create_task(self.respond(addr)) + elif self.BEACON_MESSAGE in msg: + asyncio.create_task(self.handle_beacon_received(msg, addr)) async def respond(self, addr): try: response = ("HTTP/1.1 200 OK\r\n" "CACHE-CONTROL: max-age=1800\r\n" "ST: urn:nebula-service\r\n" - "EXT:\r\n" - f"LOCATION: {self.addr}\r\n") + "TYPE: response\r\n" + f"LOCATION: {self.addr}\r\n" + "\r\n") self.transport.sendto(response.encode('ASCII'), addr) except Exception as e: logging.error(f"Error responding to client: {e}") + async 
def handle_beacon_received(self, msg): + for line in msg.splitlines(): + if line.startswith("LOCATION:"): + beacon_addr = line.split(": ")[1].strip() + if beacon_addr != self.addr: + logging.info(f"Beacon received from: {beacon_addr}") + await self.nebula_service.notify_beacon_received(beacon_addr) + def _is_nebula_message(self, msg): - return "ST: urn:nebula-service" in msg.decode('utf-8') + return "ST: urn:nebula-service" in msg class NebulaClientProtocol(asyncio.DatagramProtocol): BCAST_IP = '239.255.255.250' @@ -73,6 +90,7 @@ async def search(self): "MAN: \"ssdp:discover\"\r\n" "MX: 1\r\n" "ST: urn:nebula-service\r\n" + "TYPE: discover\r\n" "\r\n") self.transport.sendto(search_request.encode('ASCII'), (self.BCAST_IP, self.BCAST_PORT)) except Exception as e: @@ -86,6 +104,42 @@ def datagram_received(self, data, addr): except UnicodeDecodeError: logging.warning(f"Received malformed message from {addr}, ignoring.") +class NebulaBeacon: + def __init__(self, addr, interval=7): + self.addr = addr + self.interval = interval # Intervalo de envΓ­o en segundos + self.running = False + + async def start(self): + logging.info("[NebulaBeacon]: Starting sending pressence beacon") + self.running = True + while self.running: + await self.send_beacon() + await asyncio.sleep(self.interval) + + async def stop(self): + logging.info("[NebulaBeacon]: Stop existance beacon") + self.running = False + + async def modify_beacon_frequency(self, frequency): + logging.info(f"[NebulaBeacon]: Changing beacon frequency from {self.interval}s to {frequency}s") + self.interval = frequency + + async def send_beacon(self): + try: + message = ("NOTIFY * HTTP/1.1\r\n" + "HOST: 239.255.255.250:1900\r\n" + "ST: urn:nebula-service\r\n" + "TYPE: beacon\r\n" + f"LOCATION: {self.addr}\r\n" + "\r\n") + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) + sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 2) + sock.sendto(message.encode('ASCII'), ('239.255.255.250', 1900)) 
+ sock.close() + logging.info("Beacon sent") + except Exception as e: + logging.error(f"Error sending beacon: {e}") #TODO si la busqueda no devuelve nada nuevo, dejar de hacerla para eliminar trΓ‘fico inutil class NebulaConnectionService(ExternalConnectionService): @@ -94,7 +148,10 @@ def __init__(self, addr): self.addr = addr self.server : NebulaServerProtocol = None self.client : NebulaClientProtocol = None + self.beacon : NebulaBeacon = NebulaBeacon(self.addr) self.running = False + self._beacon_listeners_lock = Locker(name="beacon_listeners_lock", async_lock=True) + self._beacon_listeners = [] async def start(self): loop = asyncio.get_running_loop() @@ -117,6 +174,20 @@ async def stop(self): self.server.transport.close() self.running = False + async def start_beacon(self): + if not self.beacon: + self.beacon = NebulaBeacon(self.addr) + asyncio.create_task(self.beacon.start()) + + async def stop_beacon(self): + if self.beacon: + await self.beacon.stop() + #self.beacon = None + + async def modify_beacon_frequency(self, frequency): + if self.beacon: + await self.beacon.modify_beacon_frequency(frequency=frequency) + def is_running(self): return self.running @@ -138,4 +209,15 @@ def response_received(self, data, addr): addr = line.split(": ")[1].strip() if addr != self.addr: logging.info(f"Device address received: {addr}") - self.nodes_found.add(addr) \ No newline at end of file + self.nodes_found.add(addr) + + async def subscribe_beacon_listener(self, listener : callable): + await self._beacon_listeners_lock.acquire_async() + self._beacon_listeners.append(listener) + await self._beacon_listeners_lock.release_async() + + async def notify_beacon_received(self, addr): + await self._beacon_listeners_lock.acquire_async() + for bec_listener in self._beacon_listeners: + await bec_listener(addr) + await self._beacon_listeners_lock.release_async() \ No newline at end of file diff --git a/nebula/core/situationalawareness/neighborpolicies/__init__.py 
b/nebula/core/situationalawareness/awareness/neighborpolicies/__init__.py similarity index 100% rename from nebula/core/situationalawareness/neighborpolicies/__init__.py rename to nebula/core/situationalawareness/awareness/neighborpolicies/__init__.py diff --git a/nebula/core/situationalawareness/neighborpolicies/fcneighborpolicy.py b/nebula/core/situationalawareness/awareness/neighborpolicies/fcneighborpolicy.py similarity index 97% rename from nebula/core/situationalawareness/neighborpolicies/fcneighborpolicy.py rename to nebula/core/situationalawareness/awareness/neighborpolicies/fcneighborpolicy.py index 0d12cee7d..17ded13e5 100644 --- a/nebula/core/situationalawareness/neighborpolicies/fcneighborpolicy.py +++ b/nebula/core/situationalawareness/awareness/neighborpolicies/fcneighborpolicy.py @@ -1,4 +1,4 @@ -from nebula.core.situationalawareness.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.situationalawareness.awareness.neighborpolicies.neighborpolicy import NeighborPolicy from nebula.core.utils.locker import Locker class FCNeighborPolicy(NeighborPolicy): diff --git a/nebula/core/situationalawareness/neighborpolicies/idleneighborpolicy.py b/nebula/core/situationalawareness/awareness/neighborpolicies/idleneighborpolicy.py similarity index 97% rename from nebula/core/situationalawareness/neighborpolicies/idleneighborpolicy.py rename to nebula/core/situationalawareness/awareness/neighborpolicies/idleneighborpolicy.py index fefff7dc8..02292aee9 100644 --- a/nebula/core/situationalawareness/neighborpolicies/idleneighborpolicy.py +++ b/nebula/core/situationalawareness/awareness/neighborpolicies/idleneighborpolicy.py @@ -1,4 +1,4 @@ -from nebula.core.situationalawareness.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.situationalawareness.awareness.neighborpolicies.neighborpolicy import NeighborPolicy from nebula.core.utils.locker import Locker class IDLENeighborPolicy(NeighborPolicy): diff --git 
a/nebula/core/situationalawareness/neighborpolicies/neighborpolicy.py b/nebula/core/situationalawareness/awareness/neighborpolicies/neighborpolicy.py similarity index 70% rename from nebula/core/situationalawareness/neighborpolicies/neighborpolicy.py rename to nebula/core/situationalawareness/awareness/neighborpolicies/neighborpolicy.py index 43352c7d7..c5e598b3e 100644 --- a/nebula/core/situationalawareness/neighborpolicies/neighborpolicy.py +++ b/nebula/core/situationalawareness/awareness/neighborpolicies/neighborpolicy.py @@ -36,10 +36,10 @@ def update_neighbors(self, node, remove=False): pass def factory_NeighborPolicy(topology) -> NeighborPolicy: - from nebula.core.situationalawareness.neighborpolicies.idleneighborpolicy import IDLENeighborPolicy - from nebula.core.situationalawareness.neighborpolicies.fcneighborpolicy import FCNeighborPolicy - from nebula.core.situationalawareness.neighborpolicies.ringneighborpolicy import RINGNeighborPolicy - from nebula.core.situationalawareness.neighborpolicies.starneighborpolicy import STARNeighborPolicy + from nebula.core.situationalawareness.awareness.neighborpolicies.idleneighborpolicy import IDLENeighborPolicy + from nebula.core.situationalawareness.awareness.neighborpolicies.fcneighborpolicy import FCNeighborPolicy + from nebula.core.situationalawareness.awareness.neighborpolicies.ringneighborpolicy import RINGNeighborPolicy + from nebula.core.situationalawareness.awareness.neighborpolicies.starneighborpolicy import STARNeighborPolicy options = { "random": IDLENeighborPolicy, # default value diff --git a/nebula/core/situationalawareness/neighborpolicies/ringneighborpolicy.py b/nebula/core/situationalawareness/awareness/neighborpolicies/ringneighborpolicy.py similarity index 96% rename from nebula/core/situationalawareness/neighborpolicies/ringneighborpolicy.py rename to nebula/core/situationalawareness/awareness/neighborpolicies/ringneighborpolicy.py index 602c85f5c..8db66e1f4 100644 --- 
a/nebula/core/situationalawareness/neighborpolicies/ringneighborpolicy.py +++ b/nebula/core/situationalawareness/awareness/neighborpolicies/ringneighborpolicy.py @@ -1,4 +1,4 @@ -from nebula.core.situationalawareness.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.situationalawareness.awareness.neighborpolicies.neighborpolicy import NeighborPolicy from nebula.core.utils.locker import Locker import random diff --git a/nebula/core/situationalawareness/neighborpolicies/starneighborpolicy.py b/nebula/core/situationalawareness/awareness/neighborpolicies/starneighborpolicy.py similarity index 96% rename from nebula/core/situationalawareness/neighborpolicies/starneighborpolicy.py rename to nebula/core/situationalawareness/awareness/neighborpolicies/starneighborpolicy.py index 561a48530..87931d2c7 100644 --- a/nebula/core/situationalawareness/neighborpolicies/starneighborpolicy.py +++ b/nebula/core/situationalawareness/awareness/neighborpolicies/starneighborpolicy.py @@ -1,4 +1,4 @@ -from nebula.core.situationalawareness.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.situationalawareness.awareness.neighborpolicies.neighborpolicy import NeighborPolicy from nebula.core.utils.locker import Locker class STARNeighborPolicy(NeighborPolicy): diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 005a6cb37..934935e17 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -1,6 +1,6 @@ import asyncio import logging -from nebula.core.situationalawareness.neighborpolicies.neighborpolicy import factory_NeighborPolicy +from nebula.core.situationalawareness.awareness.neighborpolicies.neighborpolicy import factory_NeighborPolicy from nebula.core.utils.locker import Locker from nebula.addons.functions import print_msg_box @@ -44,6 +44,8 @@ async def init(self): if not 
self.nm.is_additional_participant(): logging.info("Deploying External Connection Service") await self.cm.start_external_connection_service() + await self.cm.subscribe_beacon_listener(self.beacon_received) + await self.cm.start_beacon() else: logging.info("Deploying External Connection Service | No running") await self.cm.start_external_connection_service(run_service=False) @@ -56,6 +58,8 @@ async def init(self): self, ]) + async def beacon_received(self): + logging.info("Beacon received SAModule") """ ############################### # REESTRUCTURE TOPOLOGY # @@ -115,6 +119,8 @@ async def check_external_connection_service_status(self): if not await self.cm.is_external_connection_service_running(): logging.info("πŸ”„ External Service not running | Starting service...") await self.cm.init_external_connection_service() + await self.cm.subscribe_beacon_listener(None) + await self.cm.start_beacon() async def analize_topology_robustness(self): logging.info("πŸ”„ Analizing node network robustness...") From 3d5ff25321e3f291b1fce8c1ee61b3b03bb25b35 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 20 Feb 2025 13:10:17 +0100 Subject: [PATCH 103/233] fix fully integrated beacon --- nebula/core/engine.py | 1 + nebula/core/network/nebuladiscoveryservice.py | 9 +- nebula/core/network/nebulamulticasting.py | 214 ------------------ .../awareness/samodule.py | 13 +- .../core/situationalawareness/nodemanager.py | 3 + 5 files changed, 18 insertions(+), 222 deletions(-) delete mode 100644 nebula/core/network/nebulamulticasting.py diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 85abfcb5c..32235ab23 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -846,6 +846,7 @@ async def _learning_cycle(self): ) # Set current round in config (send to the controller) await self.get_round_lock().release_async() + await self.nm.experiment_finish() # End of the learning cycle self.trainer.on_learning_cycle_end() await self.trainer.test() diff --git 
a/nebula/core/network/nebuladiscoveryservice.py b/nebula/core/network/nebuladiscoveryservice.py index b9c50a9d8..58213512e 100644 --- a/nebula/core/network/nebuladiscoveryservice.py +++ b/nebula/core/network/nebuladiscoveryservice.py @@ -23,12 +23,12 @@ def connection_made(self, transport): def datagram_received(self, data, addr): msg = data.decode('utf-8') if self._is_nebula_message(msg): - logging.info("Nebula message received...") + #logging.info("Nebula message received...") if self.DISCOVER_MESSAGE in msg: logging.info("Discovery request received, responding...") asyncio.create_task(self.respond(addr)) elif self.BEACON_MESSAGE in msg: - asyncio.create_task(self.handle_beacon_received(msg, addr)) + asyncio.create_task(self.handle_beacon_received(msg)) async def respond(self, addr): try: @@ -105,7 +105,7 @@ def datagram_received(self, data, addr): logging.warning(f"Received malformed message from {addr}, ignoring.") class NebulaBeacon: - def __init__(self, addr, interval=7): + def __init__(self, addr, interval=20): self.addr = addr self.interval = interval # Intervalo de envΓ­o en segundos self.running = False @@ -170,8 +170,10 @@ async def start(self): self.running = True async def stop(self): + logging.info("Stop Nebula Connection Service") if self.server and self.server.transport: self.server.transport.close() + await self.beacon.stop() self.running = False async def start_beacon(self): @@ -213,6 +215,7 @@ def response_received(self, data, addr): async def subscribe_beacon_listener(self, listener : callable): await self._beacon_listeners_lock.acquire_async() + logging.info("Registering beacon listener...") self._beacon_listeners.append(listener) await self._beacon_listeners_lock.release_async() diff --git a/nebula/core/network/nebulamulticasting.py b/nebula/core/network/nebulamulticasting.py deleted file mode 100644 index d65a06dc7..000000000 --- a/nebula/core/network/nebulamulticasting.py +++ /dev/null @@ -1,214 +0,0 @@ - -import os -import socket -import 
sys -import platform -import time -import threading -import logging -from nebula.core.utils.locker import Locker -from nebula.core.network.externalconnectionservice import ExternalConnectionService - -class NebulaServer(threading.Thread): - - BCAST_IP = '239.255.255.250' - UPNP_PORT = 1900 - IP = '0.0.0.0' - M_SEARCH_REQ_MATCH = "M-SEARCH" - - def __init__(self, nebula_service: "NebulaConnectionService", addr): - threading.Thread.__init__(self) - self.interrupted = False - self.ns = nebula_service - self.addr = addr - - def run(self): - self.listen() - - def stop(self): - self.interrupted = True - logging.info("Nebula upnp server stop") - - def is_running(self): - return not self.interrupted - - def listen(self): - """ - Listen on broadcast addr with standard 1900 port - It will reponse a standard ssdp message with blockchain ip and port info if receive a M_SEARCH message - """ - try: - macro = socket.SO_REUSEPORT - os_name = platform.system() - if os_name == "Windows": - macro = socket.SO_REUSEADDR - - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - sock.setsockopt(socket.SOL_SOCKET, macro, 1) - sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(self.BCAST_IP) + socket.inet_aton(self.IP)) - sock.bind((self.IP, self.UPNP_PORT)) - sock.settimeout(1) - logging.info("Nebula upnp server is listening...") - while True: - try: - data, addr = sock.recvfrom(1024) - except socket.error: - if self.interrupted: - sock.close() - return - else: - if self._is_nebula_message(data): - logging.info("Nebula request recieved | response on the way..") - self.respond(addr) - #time.sleep(1) - #self.stop() - except Exception as e: - logging.error('Error in Nebula npnp server listening: %s', e) - - def _is_nebula_message(self, msg): - msg_str = msg.decode('utf-8') - return "ST: urn:nebula-service" in msg_str - - def respond(self, addr): - try: - #local_ip = # FIND THE IP - UPNP_RESPOND = """HTTP/1.1 200 OK - CACHE-CONTROL: max-age=1800 - ST: 
urn:nebula-service - EXT: - LOCATION: {} - """.format( - self.addr - ).replace("\n", "\r\n") - outSock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - outSock.sendto(UPNP_RESPOND.encode('ASCII'), addr) - outSock.close() - except Exception as e: - logging.error('Error in Nebula upnp response message to client %s', e) - -class NebulaClient(threading.Thread): - # 30 seconds for search_interval - SEARCH_INTERVAL = 5 - BCAST_IP = '239.255.255.250' - BCAST_PORT = 1900 - - def __init__(self, nebula_service: "NebulaConnectionService"): - logging.info("Initializating Nebula Multicasting Client") - threading.Thread.__init__(self) - self.interrupted = False - self.ns = nebula_service - - def run(self): - self.keep_search() - - def stop(self): - self.interrupted = True - logging.info(" Nebula upnp client stop") - - def keep_search(self): - """ - run search function every SEARCH_INTERVAL - """ - logging.info("Federation searching loop start") - try: - while True: - self.search() - for x in range(self.SEARCH_INTERVAL): - time.sleep(1) - if self.interrupted: - return - except Exception as e: - logging.error('Error in Nebula upnp client keep search %s', e) - - def search(self): - """ - broadcast SSDP DISCOVER message to LAN network - filter our protocal and add to network - """ - logging.info("Client thread searching for nodes..") - try: - SSDP_DISCOVER = ('M-SEARCH * HTTP/1.1\r\n' + - 'HOST: 239.255.255.250:1900\r\n' + - 'MAN: "ssdp:discover"\r\n' + - 'MX: 1\r\n' + - 'ST: urn:nebula-service\r\n' + - '\r\n') - - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - sock.sendto(SSDP_DISCOVER.encode('ASCII'), (self.BCAST_IP, self.BCAST_PORT)) - sock.settimeout(3) - while True: - data, addr = sock.recvfrom(1024) - if self._is_nebula_message(data): - logging.info("Recieved response from server") - self.ns.response_recieved(data, addr) - except: - sock.close() - - def _is_nebula_message(self, msg): - msg_str = msg.decode('utf-8') - return "ST: urn:nebula-service" in msg_str - 
-class NebulaConnectionService(ExternalConnectionService): - - def __init__(self, addr): - self.addrs_found_lock = Locker(name="addrs_found_lock") - self.get_nodes_lock= Locker(name="get_nodes_lock") - self.nodes_found = [] - self.repeatsearch_interval = 3 - self.addr = addr - self.server = None - self.client = None - - def start(self): - self.server = NebulaServer(self, self.addr) - self.server.start() - - def stop(self): - self.server.stop() - - def is_running(self): - if self.server: - return self.server.is_running() - else: - return False - - def find_federation(self): - """ - Initialization client thread to send broadcast discover to federation - """ - logging.info(f"Node {self.addr} trying to find federation..") - self.nodes_found = [] - self.client = NebulaClient(self) - self.client.start() - time.sleep(self.repeatsearch_interval) - while len(self.get_nodes()) == 0: - logging.info("Waiting for server response..") - time.sleep(self.repeatsearch_interval) - self.client.stop() - return self.get_nodes() - - def response_recieved(self, data, addr): - logging.info("Parsing response..") - msg_str = data.decode('utf-8') - self._add_addr(msg_str) - - def _add_addr(self, msg_str): - self.addrs_found_lock.acquire() - lineas = msg_str.splitlines() - # Buscar la lΓ­nea que contiene "LOCATION: " - for linea in lineas: - if linea.strip().startswith("LOCATION:"): - addr = linea.split(": ")[1].strip() - break - if addr != self.addr: - logging.info(f"Device addr received: {addr}") - self.nodes_found.append(addr) - self.addrs_found_lock.release() - - def get_nodes(self): - self.get_nodes_lock.acquire() - cp = self.nodes_found.copy() - self.get_nodes_lock.release() - return cp - \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 934935e17..fa672789a 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ 
b/nebula/core/situationalawareness/awareness/samodule.py @@ -57,9 +57,9 @@ async def init(self): self.nm.engine.addr, self, ]) - - async def beacon_received(self): - logging.info("Beacon received SAModule") + + async def experiment_finish(self): + await self.cm.stop_external_connection_service() """ ############################### # REESTRUCTURE TOPOLOGY # @@ -115,11 +115,14 @@ def get_actions(self): ############################### """ + async def beacon_received(self, addr): + logging.info(f"Beacon received SAModule, source:{addr}") + async def check_external_connection_service_status(self): if not await self.cm.is_external_connection_service_running(): logging.info("πŸ”„ External Service not running | Starting service...") await self.cm.init_external_connection_service() - await self.cm.subscribe_beacon_listener(None) + await self.cm.subscribe_beacon_listener(self.beacon_received) await self.cm.start_beacon() async def analize_topology_robustness(self): @@ -127,7 +130,7 @@ async def analize_topology_robustness(self): if not self._restructure_process_lock.locked(): if not await self.neighbors_left(): logging.info("No Neighbors left | reconnecting with Federation") - #await self.reconnect_to_federation() + await self.reconnect_to_federation() elif (self.np.need_more_neighbors() and self._restructure_available()): logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") self._update_restructure_cooldown() diff --git a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py index 72175e090..fd020f3f5 100644 --- a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -108,6 +108,9 @@ async def set_configs(self): if self._fast_reboot_status: self._fastreboot = FastReboot(self) + + async def experiment_finish(self): + await self.sam.experiment_finish() """ ############################## From f4951d6ff5909784070fd2e14bbb5c2c4a570a03 Mon Sep 
17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 20 Feb 2025 15:11:07 +0100 Subject: [PATCH 104/233] feature geolocalization in beacon --- nebula/addons/mobility.py | 28 +++++++++ nebula/core/engine.py | 3 + nebula/core/network/communications.py | 7 ++- .../core/network/externalconnectionservice.py | 4 +- nebula/core/network/nebuladiscoveryservice.py | 57 +++++++++++++------ .../awareness/samodule.py | 10 +++- .../core/situationalawareness/nodemanager.py | 3 + 7 files changed, 89 insertions(+), 23 deletions(-) diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index 41ffd4038..906cdb807 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -391,6 +391,34 @@ async def change_connections_based_on_distance(self): logging.exception("πŸ“ Error changing connections based on distance") return + async def calculate_network_conditions(self, distance): + thresholds = sorted(self.network_conditions.keys()) + + # Si la distancia es menor que el primer umbral, devolver la mejor condiciΓ³n + if distance < thresholds[0]: + return self.network_conditions[thresholds[0]] + + # Encontrar el tramo en el que se encuentra la distancia + for i in range(len(thresholds) - 1): + lower_bound = thresholds[i] + upper_bound = thresholds[i + 1] + + if lower_bound <= distance < upper_bound: + lower_cond = self.network_conditions[lower_bound] + upper_cond = self.network_conditions[upper_bound] + + # Calcular el progreso en el tramo (0 a 1) + progress = (distance - lower_bound) / (upper_bound - lower_bound) + + # InterpolaciΓ³n lineal de valores + bandwidth = lower_cond["bandwidth"] - progress * (lower_cond["bandwidth"] - upper_cond["bandwidth"]) + delay = lower_cond["delay"] + progress * (upper_cond["delay"] - lower_cond["delay"]) + + return {"bandwidth": round(bandwidth, 2), "delay": round(delay, 2)} + + # Si la distancia es infinita, devolver el ΓΊltimo valor + return self.network_conditions[float("inf")] + async def change_connections(self): """ Changes the 
connections of the entity based on the specified mobility scheme. diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 32235ab23..22f234954 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -577,6 +577,9 @@ def register_message_callback(self, message_event: tuple[str, str], callback: st async def trigger_event(self, message_event): await self.event_manager.publish(message_event) + async def get_geoloc(self): + return await self.nm.get_geoloc() + async def _aditional_node_start(self): self.update_sinchronized_status(False) logging.info(f"Aditional node | {self.addr} | going to stablish connection with federation") diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 1f75bfaf0..a0bf96086 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -82,7 +82,7 @@ def __init__(self, engine: "Engine"): self._blacklist = BlackList() # Connection service to communicate with external devices - self._external_connection_service = factory_connection_service("nebula", self.addr) + self._external_connection_service = factory_connection_service("nebula", self, self.addr) @property def engine(self): @@ -191,10 +191,13 @@ async def clear_restrictions(self): # EXTERNAL CONNECTION SERVICE # ############################### """ + + async def get_geoloc(self): + return await self.engine.get_geoloc() async def start_external_connection_service(self, run_service=True): if self.ecs == None: - self._external_connection_service = factory_connection_service(self.addr) + self._external_connection_service = factory_connection_service(self, self.addr) if run_service: await self.ecs.start() diff --git a/nebula/core/network/externalconnectionservice.py b/nebula/core/network/externalconnectionservice.py index 689d8e738..b5c864bb8 100644 --- a/nebula/core/network/externalconnectionservice.py +++ b/nebula/core/network/externalconnectionservice.py @@ -38,7 +38,7 @@ async def 
subscribe_beacon_listener(self, listener): class ExternalConnectionServiceException(Exception): pass -def factory_connection_service(con_serv, addr) -> ExternalConnectionService: +def factory_connection_service(con_serv, cm, addr) -> ExternalConnectionService: from nebula.core.network.nebuladiscoveryservice import NebulaConnectionService CONNECTION_SERVICES = { @@ -48,6 +48,6 @@ def factory_connection_service(con_serv, addr) -> ExternalConnectionService: con_serv = CONNECTION_SERVICES.get(con_serv, NebulaConnectionService) if con_serv: - return con_serv(addr) + return con_serv(cm, addr) else: raise ExternalConnectionServiceException(f"Connection Service {con_serv} not found") \ No newline at end of file diff --git a/nebula/core/network/nebuladiscoveryservice.py b/nebula/core/network/nebuladiscoveryservice.py index 58213512e..e47673d83 100644 --- a/nebula/core/network/nebuladiscoveryservice.py +++ b/nebula/core/network/nebuladiscoveryservice.py @@ -5,6 +5,10 @@ from nebula.core.network.externalconnectionservice import ExternalConnectionService from nebula.core.utils.locker import Locker +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from nebula.core.network.communications import CommunicationsManager + class NebulaServerProtocol(asyncio.DatagramProtocol): BCAST_IP = '239.255.255.250' UPNP_PORT = 1900 @@ -43,12 +47,22 @@ async def respond(self, addr): logging.error(f"Error responding to client: {e}") async def handle_beacon_received(self, msg): - for line in msg.splitlines(): - if line.startswith("LOCATION:"): - beacon_addr = line.split(": ")[1].strip() - if beacon_addr != self.addr: - logging.info(f"Beacon received from: {beacon_addr}") - await self.nebula_service.notify_beacon_received(beacon_addr) + lines = msg.split("\r\n") + beacon_data = {} + + for line in lines: + if ": " in line: + key, value = line.split(": ", 1) + beacon_data[key] = value + + # Verificar que no sea el propio beacon + beacon_addr = beacon_data.get("LOCATION") + if beacon_addr == 
self.addr: + return + + latitude = float(beacon_data.get("LATITUDE", 0.0)) + longitude = float(beacon_data.get("LONGITUDE", 0.0)) + await self.nebula_service.notify_beacon_received(beacon_addr, (latitude, longitude)) def _is_nebula_message(self, msg): return "ST: urn:nebula-service" in msg @@ -56,7 +70,7 @@ def _is_nebula_message(self, msg): class NebulaClientProtocol(asyncio.DatagramProtocol): BCAST_IP = '239.255.255.250' BCAST_PORT = 1900 - SEARCH_TRIES = 5 + SEARCH_TRIES = 3 SEARCH_INTERVAL = 3 def __init__(self, nebula_service): @@ -105,7 +119,8 @@ def datagram_received(self, data, addr): logging.warning(f"Received malformed message from {addr}, ignoring.") class NebulaBeacon: - def __init__(self, addr, interval=20): + def __init__(self, nebula_service, addr, interval=20): + self.nebula_service : NebulaConnectionService = nebula_service self.addr = addr self.interval = interval # Intervalo de envΓ­o en segundos self.running = False @@ -126,12 +141,15 @@ async def modify_beacon_frequency(self, frequency): self.interval = frequency async def send_beacon(self): + latitude, longitude = await self.nebula_service.cm.get_geoloc() try: message = ("NOTIFY * HTTP/1.1\r\n" "HOST: 239.255.255.250:1900\r\n" "ST: urn:nebula-service\r\n" "TYPE: beacon\r\n" f"LOCATION: {self.addr}\r\n" + f"LATITUDE: {latitude}\r\n" + f"LONGITUDE: {longitude}\r\n" "\r\n") sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 2) @@ -141,17 +159,21 @@ async def send_beacon(self): except Exception as e: logging.error(f"Error sending beacon: {e}") -#TODO si la busqueda no devuelve nada nuevo, dejar de hacerla para eliminar trΓ‘fico inutil class NebulaConnectionService(ExternalConnectionService): - def __init__(self, addr): + def __init__(self, cm: "CommunicationsManager", addr): + self._cm = cm self.nodes_found = set() self.addr = addr self.server : NebulaServerProtocol = None self.client : NebulaClientProtocol = None 
- self.beacon : NebulaBeacon = NebulaBeacon(self.addr) + self.beacon : NebulaBeacon = NebulaBeacon(self, self.addr) self.running = False self._beacon_listeners_lock = Locker(name="beacon_listeners_lock", async_lock=True) self._beacon_listeners = [] + + @property + def cm(self): + return self._cm async def start(self): loop = asyncio.get_running_loop() @@ -178,7 +200,7 @@ async def stop(self): async def start_beacon(self): if not self.beacon: - self.beacon = NebulaBeacon(self.addr) + self.beacon = NebulaBeacon(self, self.addr) asyncio.create_task(self.beacon.start()) async def stop_beacon(self): @@ -204,14 +226,15 @@ async def find_federation(self): return self.nodes_found def response_received(self, data, addr): - logging.info("Parsing response...") + #logging.info("Parsing response...") msg_str = data.decode('utf-8') for line in msg_str.splitlines(): if line.strip().startswith("LOCATION:"): addr = line.split(": ")[1].strip() if addr != self.addr: - logging.info(f"Device address received: {addr}") - self.nodes_found.add(addr) + if addr not in self.nodes_found: + logging.info(f"Device address received: {addr}") + self.nodes_found.add(addr) async def subscribe_beacon_listener(self, listener : callable): await self._beacon_listeners_lock.acquire_async() @@ -219,8 +242,8 @@ async def subscribe_beacon_listener(self, listener : callable): self._beacon_listeners.append(listener) await self._beacon_listeners_lock.release_async() - async def notify_beacon_received(self, addr): + async def notify_beacon_received(self, addr, geoloc): await self._beacon_listeners_lock.acquire_async() for bec_listener in self._beacon_listeners: - await bec_listener(addr) + await bec_listener(addr, geoloc) await self._beacon_listeners_lock.release_async() \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index fa672789a..069c80055 100644 --- 
a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -109,14 +109,20 @@ def need_more_neighbors(self): def get_actions(self): return self.np.get_actions() + async def get_geoloc(self): + latitude = self.nm.config.participant["mobility_args"]["latitude"] + longitude = self.nm.config.participant["mobility_args"]["longitude"] + return (latitude,longitude) + """ ############################### # ROBUSTNESS # ############################### """ - async def beacon_received(self, addr): - logging.info(f"Beacon received SAModule, source:{addr}") + async def beacon_received(self, addr, geoloc): + latitude, longitude = geoloc + logging.info(f"Beacon received SAModule, source: {addr}, geolocalization: {latitude},{longitude}") async def check_external_connection_service_status(self): if not await self.cm.is_external_connection_service_running(): diff --git a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py index fd020f3f5..3e3c58557 100644 --- a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -314,6 +314,9 @@ async def start_late_connection_process(self, connected=False, msg_type="discove ############################## """ + async def get_geoloc(self): + return await self.sam.get_geoloc() + async def mobility_actions(self): await self.sam.check_external_connection_service_status() await self.sam.analize_topology_robustness() From 4afccbca7fcf8efe66908e15b344b9fe4b3cb10b Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 20 Feb 2025 17:42:55 +0100 Subject: [PATCH 105/233] fix daily update --- nebula/addons/mobility.py | 52 ++++++------ nebula/core/engine.py | 31 ++++--- nebula/core/network/communications.py | 59 ++++++-------- .../awareness/samodule.py | 80 ++++++++++--------- 4 files changed, 106 insertions(+), 116 deletions(-) diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index 
906cdb807..5143ea967 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -340,25 +340,25 @@ async def change_connections_based_on_distance(self): # If the distance is not found, we skip the node continue # logging.info(f"πŸ“ Distance to node {addr}: {distance}") - if ( - not self.cm.connections[addr].get_direct() - and distance < self.max_distance_with_direct_connections - ): - logging.info(f"πŸ“ Node {addr} is close enough [{distance}], adding to direct connections") - self.cm.connections[addr].set_direct(True) - await self.cm.update_neighbors(addr) - else: - # 10% margin to avoid oscillations - if ( - self.cm.connections[addr].get_direct() - and distance > self.max_distance_with_direct_connections * 1.1 - ): - logging.info( - f"πŸ“ Node {addr} is too far away [{distance}], removing from direct connections" - ) - await asyncio.sleep(1) - self.cm.connections[addr].set_direct(False) - await self.cm.update_neighbors(addr,remove=True) + # if ( + # not self.cm.connections[addr].get_direct() + # and distance < self.max_distance_with_direct_connections + # ): + # logging.info(f"πŸ“ Node {addr} is close enough [{distance}], adding to direct connections") + # self.cm.connections[addr].set_direct(True) + # await self.cm.update_neighbors(addr) + # else: + # # 10% margin to avoid oscillations + # if ( + # self.cm.connections[addr].get_direct() + # and distance > self.max_distance_with_direct_connections * 1.1 + # ): + # logging.info( + # f"πŸ“ Node {addr} is too far away [{distance}], removing from direct connections" + # ) + # await asyncio.sleep(1) + # self.cm.connections[addr].set_direct(False) + # await self.cm.update_neighbors(addr,remove=True) # Adapt network conditions of the connection based on distance for threshold in sorted(self.network_conditions.keys()): if distance < threshold: @@ -393,29 +393,29 @@ async def change_connections_based_on_distance(self): async def calculate_network_conditions(self, distance): thresholds = 
sorted(self.network_conditions.keys()) - + # Si la distancia es menor que el primer umbral, devolver la mejor condiciΓ³n if distance < thresholds[0]: return self.network_conditions[thresholds[0]] - + # Encontrar el tramo en el que se encuentra la distancia for i in range(len(thresholds) - 1): lower_bound = thresholds[i] upper_bound = thresholds[i + 1] - + if lower_bound <= distance < upper_bound: lower_cond = self.network_conditions[lower_bound] upper_cond = self.network_conditions[upper_bound] - + # Calcular el progreso en el tramo (0 a 1) progress = (distance - lower_bound) / (upper_bound - lower_bound) - + # InterpolaciΓ³n lineal de valores bandwidth = lower_cond["bandwidth"] - progress * (lower_cond["bandwidth"] - upper_cond["bandwidth"]) delay = lower_cond["delay"] + progress * (upper_cond["delay"] - lower_cond["delay"]) - + return {"bandwidth": round(bandwidth, 2), "delay": round(delay, 2)} - + # Si la distancia es infinita, devolver el ΓΊltimo valor return self.network_conditions[float("inf")] diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 22f234954..3d6ae0a4c 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -9,8 +9,8 @@ from nebula.addons.reporter import Reporter from nebula.core.aggregation.aggregator import create_aggregator, create_target_aggregator from nebula.core.eventmanager import EventManager -from nebula.core.situationalawareness.nodemanager import NodeManager from nebula.core.network.communications import CommunicationsManager +from nebula.core.situationalawareness.nodemanager import NodeManager from nebula.core.utils.locker import Locker logging.getLogger("requests").setLevel(logging.WARNING) @@ -165,8 +165,8 @@ def __init__( self.register_message_events_callbacks() # Additional callbacks not registered automatically - self.register_message_callback(("model","initialization"), "model_initialization_callback") - self.register_message_callback(("model","update"), "model_update_callback") + 
self.register_message_callback(("model", "initialization"), "model_initialization_callback") + self.register_message_callback(("model", "update"), "model_update_callback") @property def cm(self): @@ -276,17 +276,17 @@ async def model_initialization_callback(self, source, message): async def model_update_callback(self, source, message): logging.info(f"πŸ€– handle_model_message | Received model update from {source} with round {message.round}") if not self.get_federation_ready_lock().locked() and len(self.get_federation_nodes()) == 0: - logging.info("πŸ€– handle_model_message | There are no defined federation nodes") - return + logging.info("πŸ€– handle_model_message | There are no defined federation nodes") + return decoded_model = self.trainer.deserialize_model(message.parameters) await self.aggregator.update_received_from_source(decoded_model, message.weight, source, message.round) - """ ############################## # General callbacks # ############################## """ + # TODO llevar a communications async def _discovery_discover_callback(self, source, message): logging.info( f"πŸ” handle_discovery_message | Trigger | Received discovery message from {source} (network propagation)" @@ -544,9 +544,6 @@ async def _link_disconnect_from_callback(self, source, message): await self.cm.disconnect(source, mutual_disconnection=False) await self.nm.update_neighbors(addr, remove=True) - - - """ ############################## # ENGINE FUNCTIONALITY # ############################## @@ -572,7 +569,7 @@ def register_message_callback(self, message_event: tuple[str, str], callback: st event_type, action = message_event method = getattr(self, callback, None) if callable(method): - self.event_manager.subscribe((event_type, action), method) + self.event_manager.subscribe((event_type, action), method) async def trigger_event(self, message_event): await self.event_manager.publish(message_event) @@ -615,8 +612,8 @@ async def _start_learning_late(self): await 
self.learning_cycle_lock.acquire_async() try: model_serialized, rounds, round, _epochs = await self.nm.get_trainning_info() - self.total_rounds = rounds - epochs = _epochs + self.total_rounds = rounds + epochs = _epochs await self.get_round_lock().acquire_async() self.round = round await self.get_round_lock().release_async() @@ -841,7 +838,7 @@ async def _learning_cycle(self): indent=2, title="Round information", ) - #await self.aggregator.reset() + # await self.aggregator.reset() self.trainer.on_round_end() self.round = self.round + 1 self.config.participant["federation_args"]["round"] = ( @@ -1000,7 +997,7 @@ async def _extended_learning_cycle(self): # source=self.addr, # round=self.round, # ) - + await self.aggregator.update_received_from_source( self.trainer.get_model_parameters(), self.trainer.get_model_weight(), @@ -1040,14 +1037,14 @@ async def _extended_learning_cycle(self): # source=self.addr, # round=self.round, # ) - + await self.aggregator.update_received_from_source( self.trainer.get_model_parameters(), self.trainer.BYPASS_MODEL_WEIGHT, source=self.addr, round=self.round, ) - + await self._waiting_model_updates() await self.cm.propagator.propagate("stable") @@ -1084,7 +1081,7 @@ async def _extended_learning_cycle(self): # round=self.round, # local=True, # ) - + await self.aggregator.update_received_from_source( self.trainer.get_model_parameters(), self.trainer.get_model_weight(), diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index a0bf96086..cc03a3d14 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -3,7 +3,6 @@ import logging import subprocess import sys -import traceback from typing import TYPE_CHECKING import requests @@ -12,19 +11,10 @@ from nebula.core.network.blacklist import BlackList from nebula.core.network.connection import Connection from nebula.core.network.discoverer import Discoverer -from nebula.core.network.forwarder import Forwarder -from 
nebula.core.network.messages import MessagesManager from nebula.core.network.externalconnectionservice import factory_connection_service +from nebula.core.network.forwarder import Forwarder +from nebula.core.network.messages import MessageEvent, MessagesManager from nebula.core.network.propagator import Propagator -from nebula.core.network.messages import MessageEvent -from nebula.core.utils.helper import ( - cosine_metric, - euclidean_metric, - jaccard_metric, - manhattan_metric, - minkowski_metric, - pearson_correlation_metric, -) from nebula.core.utils.locker import Locker if TYPE_CHECKING: @@ -154,10 +144,10 @@ async def handle_message(self, message_event): async def handle_model_message(self, source, message): logging.info(f"πŸ€– handle_model_message | Received model from {source} with round {message.round}") if message.round == -1: - model_init_event = MessageEvent(("model","initialization"), source, message) + model_init_event = MessageEvent(("model", "initialization"), source, message) await self.engine.trigger_event(model_init_event) else: - model_updt_event = MessageEvent(("model","update"), source, message) + model_updt_event = MessageEvent(("model", "update"), source, message) await self.engine.trigger_event(model_updt_event) def create_message(self, message_type: str, action: str = "", *args, **kwargs): @@ -186,12 +176,11 @@ async def apply_restrictions(self, nodes): async def clear_restrictions(self): await self.bl.clear_restrictions() - """ ############################### # EXTERNAL CONNECTION SERVICE # ############################### """ - + async def get_geoloc(self): return await self.engine.get_geoloc() @@ -209,23 +198,19 @@ async def init_external_connection_service(self): async def is_external_connection_service_running(self): return self.ecs.is_running() - + async def start_beacon(self): await self.ecs.start_beacon() - + async def stop_beacon(self): await self.ecs.stop_beacon() - + async def subscribe_beacon_listener(self, listener): await 
self.ecs.subscribe_beacon_listener(listener) - + async def modify_beacon_frequency(self, frequency): - await self.ecs.modify_beacon_frequency(frequency) + await self.ecs.modify_beacon_frequency(frequency) - #TODO - # si se utilizan addr conocidas y no se consigue conectar a ninguna quΓ© hacer - # -> funcion reentrante pero sin utilizar las conocidas - # S async def stablish_connection_to_federation(self, msg_type="discover_join", addrs_known=None): """ Using ExternalConnectionService to get addrs on local network, after that @@ -241,7 +226,7 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr addrs = addrs_known msg = self.create_message("discover", msg_type) - + # Remove neighbors neighbors = await self.get_addrs_current_connections(only_undirected=True, myself=True) addrs = set(addrs) @@ -255,13 +240,13 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr for addr in addrs: await self.connect(addr, direct=False) await asyncio.sleep(1) - for i in range(0,max_tries): + for i in range(0, max_tries): if self.verify_any_connections(addrs): break await asyncio.sleep(1) current_connections = await self.get_addrs_current_connections(only_undirected=True) logging.info(f"Connections verified after searching: {current_connections}") - + for addr in addrs: logging.info(f"Sending {msg_type} to ---> {addr}") asyncio.create_task(self.send_message(addr, msg)) @@ -269,16 +254,15 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr discovers_sent += 1 return discovers_sent - """ ############################## # OTHER FUNCTIONALITIES # ############################## """ - - #TODO remove - async def update_neighbors(self, addr, remove=False): - current_connections = await self.get_addrs_current_connections(only_direct=True, myself=True) - await self.engine.update_neighbors(addr, current_connections, remove=remove) + + # TODO remove + # async def update_neighbors(self, addr, remove=False): + 
# current_connections = await self.get_addrs_current_connections(only_direct=True, myself=True) + # await self.engine.update_neighbors(addr, current_connections, remove=remove) def get_connections_lock(self): return self.connections_lock @@ -456,7 +440,7 @@ def verify_any_connections(self, neighbors): if any(neighbor in self.connections for neighbor in neighbors): return True return False - + def verify_connections(self, neighbors): # Return True if all neighbors are connected if all(neighbor in self.connections for neighbor in neighbors): @@ -472,7 +456,7 @@ async def deploy_additional_services(self): await self._forwarder.start() if self.config.participant["mobility_args"]["mobility"]: pass - #await self._discoverer.start() + # await self._discoverer.start() # await self._health.start() self._propagator.start() await self._mobility.start() @@ -590,6 +574,9 @@ def _set_network_conditions( logging.exception(f"❗️ Network simulation error: {e}") return + async def update_connection_geolocalziation(source, latitude, longitude): + pass + async def include_received_message_hash(self, hash_message): try: await self.receive_messages_lock.acquire_async() diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 069c80055..9565d3ae0 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -1,16 +1,18 @@ import asyncio import logging +from typing import TYPE_CHECKING + +from nebula.addons.functions import print_msg_box from nebula.core.situationalawareness.awareness.neighborpolicies.neighborpolicy import factory_NeighborPolicy from nebula.core.utils.locker import Locker -from nebula.addons.functions import print_msg_box -from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.situationalawareness.nodemanager import NodeManager RESTRUCTURE_COOLDOWN = 5 -class SAModule(): + +class SAModule: def __init__( self, 
nodemanager, @@ -18,28 +20,30 @@ def __init__( topology, ): print_msg_box( - msg=f"Starting Situational Awareness module...\nTopology: {topology}", indent=2, title="Situational Awareness module" + msg=f"Starting Situational Awareness module...\nTopology: {topology}", + indent=2, + title="Situational Awareness module", ) logging.info("🌐 Initializing SAModule") self._addr = addr self._topology = topology - self._node_manager : NodeManager = nodemanager + self._node_manager: NodeManager = nodemanager self._neighbor_policy = factory_NeighborPolicy(topology) self._restructure_process_lock = Locker(name="restructure_process_lock") self._restructure_cooldown = 0 - + @property def nm(self): return self._node_manager - + @property def np(self): return self._neighbor_policy - + @property def cm(self): return self.nm.engine.cm - + async def init(self): if not self.nm.is_additional_participant(): logging.info("Deploying External Connection Service") @@ -57,9 +61,9 @@ async def init(self): self.nm.engine.addr, self, ]) - + async def experiment_finish(self): - await self.cm.stop_external_connection_service() + await self.cm.stop_external_connection_service() """ ############################### # REESTRUCTURE TOPOLOGY # @@ -74,70 +78,69 @@ def _restructure_available(self): if self._restructure_cooldown: logging.info("Reestructure on cooldown") return self._restructure_cooldown == 0 - + def get_restructure_process_lock(self): return self._restructure_process_lock - """ ############################### # NEIGHBOR POLICY # ############################### """ - + def meet_node(self, node): if node != self._addr: logging.info(f"Update nodes known | addr: {node}") self.np.meet_node(node) - + def update_neighbors(self, node, remove=False): self.np.update_neighbors(node, remove) if not remove: self.np.meet_node(node) - + def get_nodes_known(self, neighbors_too=False, neighbors_only=False): return self.np.get_nodes_known(neighbors_too, neighbors_only) - + async def neighbors_left(self): 
return len(await self.cm.get_addrs_current_connections(only_direct=True, myself=False)) > 0 - + def accept_connection(self, source, joining=False): return self.np.accept_connection(source, joining) - + def need_more_neighbors(self): return self.np.need_more_neighbors() def get_actions(self): return self.np.get_actions() - + async def get_geoloc(self): latitude = self.nm.config.participant["mobility_args"]["latitude"] longitude = self.nm.config.participant["mobility_args"]["longitude"] - return (latitude,longitude) - - + return (latitude, longitude) + """ ############################### # ROBUSTNESS # ############################### """ - + async def beacon_received(self, addr, geoloc): latitude, longitude = geoloc + await self.meet_node(addr) logging.info(f"Beacon received SAModule, source: {addr}, geolocalization: {latitude},{longitude}") - + async def check_external_connection_service_status(self): if not await self.cm.is_external_connection_service_running(): logging.info("πŸ”„ External Service not running | Starting service...") await self.cm.init_external_connection_service() await self.cm.subscribe_beacon_listener(self.beacon_received) await self.cm.start_beacon() - + async def analize_topology_robustness(self): logging.info("πŸ”„ Analizing node network robustness...") if not self._restructure_process_lock.locked(): if not await self.neighbors_left(): logging.info("No Neighbors left | reconnecting with Federation") await self.reconnect_to_federation() - elif (self.np.need_more_neighbors() and self._restructure_available()): + elif self.np.need_more_neighbors() and self._restructure_available(): logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") self._update_restructure_cooldown() possible_neighbors = self.np.get_nodes_known(neighbors_too=False) @@ -145,8 +148,8 @@ async def analize_topology_robustness(self): if not possible_neighbors: logging.info("All possible neighbors using nodes known are restricted...") else: 
- pass - #asyncio.create_task(self.upgrade_connection_robustness(possible_neighbors)) + pass + # asyncio.create_task(self.upgrade_connection_robustness(possible_neighbors)) else: logging.info("Sufficient Robustness | no actions required") else: @@ -155,11 +158,13 @@ async def analize_topology_robustness(self): async def reconnect_to_federation(self): self._restructure_process_lock.acquire() await self.cm.clear_restrictions() - await asyncio.sleep(120) - # If we got some refs, try to reconnect to them + await asyncio.sleep(120) + # If we got some refs, try to reconnect to them if len(self.np.get_nodes_known()) > 0: logging.info("Reconnecting | Addrs availables") - await self.nm.start_late_connection_process(connected=False, msg_type="discover_nodes", addrs_known=self.np.get_nodes_known()) + await self.nm.start_late_connection_process( + connected=False, msg_type="discover_nodes", addrs_known=self.np.get_nodes_known() + ) else: logging.info("Reconnecting | NO Addrs availables") await self.nm.start_late_connection_process(connected=False, msg_type="discover_nodes") @@ -167,22 +172,23 @@ async def reconnect_to_federation(self): async def upgrade_connection_robustness(self, possible_neighbors): self._restructure_process_lock.acquire() - #addrs_to_connect = self.neighbor_policy.get_nodes_known(neighbors_too=False) + # addrs_to_connect = self.neighbor_policy.get_nodes_known(neighbors_too=False) # If we got some refs, try to connect to them if len(possible_neighbors) > 0: logging.info(f"Reestructuring | Addrs availables | addr list: {possible_neighbors}") - await self.nm.start_late_connection_process(connected=True, msg_type="discover_nodes", addrs_known=possible_neighbors) + await self.nm.start_late_connection_process( + connected=True, msg_type="discover_nodes", addrs_known=possible_neighbors + ) else: logging.info("Reestructuring | NO Addrs availables") await self.nm.start_late_connection_process(connected=True, msg_type="discover_nodes") 
self._restructure_process_lock.release() - + async def stop_connections_with_federation(self): await asyncio.sleep(200) logging.info("### DISCONNECTING FROM FEDERATON ###") neighbors = self.np.get_nodes_known(neighbors_only=True) - for n in neighbors: + for n in neighbors: await self.cm.add_to_blacklist(n) for n in neighbors: await self.cm.disconnect(n, mutual_disconnection=False, forced=True) - \ No newline at end of file From 98e3e2e898b94e74e15a646124578ebc7e8767c5 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 21 Feb 2025 12:33:50 +0100 Subject: [PATCH 106/233] feature nebula gps service --- nebula/core/network/communications.py | 15 ++-- .../awareness/GPS/gpsmodule.py | 29 +++++++ .../awareness/GPS/nebulagps.py | 87 +++++++++++++++++++ .../awareness/samodule.py | 9 +- 4 files changed, 132 insertions(+), 8 deletions(-) create mode 100644 nebula/core/situationalawareness/awareness/GPS/gpsmodule.py create mode 100644 nebula/core/situationalawareness/awareness/GPS/nebulagps.py diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index cc03a3d14..520893fcc 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -259,10 +259,14 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr ############################## """ - # TODO remove - # async def update_neighbors(self, addr, remove=False): - # current_connections = await self.get_addrs_current_connections(only_direct=True, myself=True) - # await self.engine.update_neighbors(addr, current_connections, remove=remove) + #TODO setcondition para la direccion multicast + async def update_geolocalization(self, geoloc : dict): + async with self.get_connections_lock(): + #logging.info("Update geolocs to simulate network conditions") + for source in geoloc.keys(): + latitude, longitude = geoloc[source] + #logging.info(f"Update geolocs for source: {source}, geoloc: ({latitude},{longitude})") + 
#self.connections[source].update_geolocation(latitude, longitude) def get_connections_lock(self): return self.connections_lock @@ -574,9 +578,6 @@ def _set_network_conditions( logging.exception(f"❗️ Network simulation error: {e}") return - async def update_connection_geolocalziation(source, latitude, longitude): - pass - async def include_received_message_hash(self, hash_message): try: await self.receive_messages_lock.acquire_async() diff --git a/nebula/core/situationalawareness/awareness/GPS/gpsmodule.py b/nebula/core/situationalawareness/awareness/GPS/gpsmodule.py new file mode 100644 index 000000000..313a41ebe --- /dev/null +++ b/nebula/core/situationalawareness/awareness/GPS/gpsmodule.py @@ -0,0 +1,29 @@ +import asyncio +from abc import ABC, abstractmethod + +class GPSModule(ABC): + + @abstractmethod + async def start(self): + pass + + @abstractmethod + async def stop(self): + pass + +class GPSModuleException(Exception): + pass + +def factory_gpsmodule(gps_module, sam) -> GPSModule: + from nebula.core.situationalawareness.awareness.GPS.nebulagps import NebulaGPS + + GPS_SERVICES = { + "nebula": NebulaGPS, + } + + gps_module = GPS_SERVICES.get(gps_module, NebulaGPS) + + if gps_module: + return gps_module(sam) + else: + raise GPSModuleException(f"GPS Module {gps_module} not found") \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/GPS/nebulagps.py b/nebula/core/situationalawareness/awareness/GPS/nebulagps.py new file mode 100644 index 000000000..2c73c3832 --- /dev/null +++ b/nebula/core/situationalawareness/awareness/GPS/nebulagps.py @@ -0,0 +1,87 @@ +import asyncio +import logging +from nebula.core.situationalawareness.awareness.GPS.gpsmodule import GPSModule +import socket +from nebula.core.utils.locker import Locker + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from nebula.core.situationalawareness.awareness.samodule import SAModule + +class NebulaGPS(GPSModule): + BROADCAST_IP = "255.255.255.255" # Broadcast IP + 
BROADCAST_PORT = 50001 # Poort used for GPS + INTERFACE = "eth2" # Interface to avoid network conditions + + def __init__(self, sam: "SAModule", update_interval: float = 5.0): + self._situational_awareness_module = sam + self.update_interval = update_interval # Frecuencia de emisiΓ³n + self.running = False + self._node_locations = {} # Diccionario para almacenar ubicaciones de nodos + self._broadcast_socket = None + self._nodes_location_lock = Locker("nodes_location_lock", async_lock=True) + + @property + def sam(self): + return self._situational_awareness_module + + async def start(self): + """Inicia el servicio de GPS, enviando y recibiendo ubicaciones.""" + logging.info("Starting NebulaGPS service...") + self.running = True + + # Crear socket de broadcast + self._broadcast_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self._broadcast_socket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + + # Enlazar socket en eth2 para recibir tambiΓ©n datos + self._broadcast_socket.bind(("", self.BROADCAST_PORT)) + + # Iniciar tareas de envΓ­o y recepciΓ³n + asyncio.create_task(self._send_location_loop()) + asyncio.create_task(self._receive_location_loop()) + asyncio.create_task(self._notify_geolocs()) + + async def stop(self): + """Detiene el servicio de GPS.""" + #logging.info("Stopping NebulaGPS service...") + self.running = False + if self._broadcast_socket: + self._broadcast_socket.close() + self._broadcast_socket = None + + async def _send_location_loop(self): + """Envia la geolocalizaciΓ³n periΓ³dicamente por broadcast.""" + while self.running: + latitude, longitude = await self.sam.get_geoloc() # Obtener ubicaciΓ³n actual + message = f"GPS-UPDATE {latitude} {longitude}" + self._broadcast_socket.sendto(message.encode(), (self.BROADCAST_IP, self.BROADCAST_PORT)) + #logging.info(f"Sent GPS location: ({latitude}, {longitude})") + await asyncio.sleep(self.update_interval) + + async def _receive_location_loop(self): + """Escucha y almacena 
geolocalizaciones de otros nodos.""" + while self.running: + try: + data, addr = await asyncio.get_running_loop().run_in_executor( + None, self._broadcast_socket.recvfrom, 1024 + ) + message = data.decode().strip() + if message.startswith("GPS-UPDATE"): + _, lat, lon = message.split() + self._nodes_location_lock.acquire_async() + self._node_locations[addr[0]] = (float(lat), float(lon)) + self._nodes_location_lock.release_async() + #logging.info(f"Received GPS from {addr[0]}: {lat}, {lon}") + except Exception as e: + logging.error(f"Error receiving GPS update: {e}") + + async def _notify_geolocs(self): + while True: + await asyncio.sleep(self.update_interval) + self._nodes_location_lock.acquire_async() + geolocs = self._node_locations.copy() + self._nodes_location_lock.release_async() + if geolocs: + await self.sam.cm.update_geolocalization(geolocs) + \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 9565d3ae0..2fdfac9d3 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -4,6 +4,7 @@ from nebula.addons.functions import print_msg_box from nebula.core.situationalawareness.awareness.neighborpolicies.neighborpolicy import factory_NeighborPolicy +from nebula.core.situationalawareness.awareness.GPS.gpsmodule import factory_gpsmodule from nebula.core.utils.locker import Locker if TYPE_CHECKING: @@ -31,6 +32,7 @@ def __init__( self._neighbor_policy = factory_NeighborPolicy(topology) self._restructure_process_lock = Locker(name="restructure_process_lock") self._restructure_cooldown = 0 + self._gpsmodule = factory_gpsmodule("nebula", self) @property def nm(self): @@ -43,6 +45,10 @@ def np(self): @property def cm(self): return self.nm.engine.cm + + @property + def gps(self): + return self._gpsmodule async def init(self): if not self.nm.is_additional_participant(): @@ -50,6 +56,7 @@ async 
def init(self): await self.cm.start_external_connection_service() await self.cm.subscribe_beacon_listener(self.beacon_received) await self.cm.start_beacon() + await self.gps.start() else: logging.info("Deploying External Connection Service | No running") await self.cm.start_external_connection_service(run_service=False) @@ -124,7 +131,7 @@ async def get_geoloc(self): async def beacon_received(self, addr, geoloc): latitude, longitude = geoloc - await self.meet_node(addr) + self.meet_node(addr) logging.info(f"Beacon received SAModule, source: {addr}, geolocalization: {latitude},{longitude}") async def check_external_connection_service_status(self): From e6bb7b4e05f99c7f44029182e700777a696ad455 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 21 Feb 2025 13:31:18 +0100 Subject: [PATCH 107/233] fix mobility errors --- nebula/addons/mobility.py | 27 ++++++++++++++----- nebula/core/network/communications.py | 4 +-- .../awareness/GPS/gpsmodule.py | 4 +-- .../awareness/GPS/nebulagps.py | 13 ++++----- .../awareness/samodule.py | 2 +- .../frontend/config/participant.json.example | 2 +- 6 files changed, 33 insertions(+), 19 deletions(-) diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index 5143ea967..ca49c2227 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -67,7 +67,7 @@ def __init__(self, config, cm: "CommunicationsManager"): 100: {"bandwidth": "5Gbps", "delay": "5ms"}, 200: {"bandwidth": "2Gbps", "delay": "50ms"}, 300: {"bandwidth": "100Mbps", "delay": "200ms"}, - float("inf"): {"bandwidth": "10Mbps", "delay": "1000ms"}, + float("inf"): {"bandwidth": "10Mbps", "delay": "1000000000000ms"}, } # Current network conditions of each connection {addr: {bandwidth: "5Gbps", delay: "0ms"}} self.current_network_conditions = {} @@ -271,8 +271,7 @@ async def change_geo_location(self): direct_connections = await self.cm.get_direct_connections() undirect_connection = await self.cm.get_undirect_connections() - if 
len(undirect_connection) > len(direct_connections): - logging.info("πŸ“ Undirect Connections is higher than Direct Connections") + if True or len(undirect_connection) > len(direct_connections): # Get neighbor closer to me selected_neighbor = await self.cm.get_nearest_connections(top=1) logging.info(f"πŸ“ Selected neighbor: {selected_neighbor}") @@ -364,6 +363,7 @@ async def change_connections_based_on_distance(self): if distance < threshold: conditions = self.network_conditions[threshold] break + conditions = await self.calculate_network_conditions(distance) # Only update the network conditions if they have changed if ( addr not in self.current_network_conditions @@ -391,12 +391,16 @@ async def change_connections_based_on_distance(self): logging.exception("πŸ“ Error changing connections based on distance") return + #TODO corregir formato, no son float son float-mbps por ejemplo async def calculate_network_conditions(self, distance): thresholds = sorted(self.network_conditions.keys()) # Si la distancia es menor que el primer umbral, devolver la mejor condiciΓ³n if distance < thresholds[0]: - return self.network_conditions[thresholds[0]] + return { + "bandwidth": float(self.network_conditions[thresholds[0]]["bandwidth"]), + "delay": float(self.network_conditions[thresholds[0]]["delay"]) + } # Encontrar el tramo en el que se encuentra la distancia for i in range(len(thresholds) - 1): @@ -407,17 +411,26 @@ async def calculate_network_conditions(self, distance): lower_cond = self.network_conditions[lower_bound] upper_cond = self.network_conditions[upper_bound] + # Convertir a float antes de operar + lower_bandwidth = float(lower_cond["bandwidth"]) + upper_bandwidth = float(upper_cond["bandwidth"]) + lower_delay = float(lower_cond["delay"]) + upper_delay = float(upper_cond["delay"]) + # Calcular el progreso en el tramo (0 a 1) progress = (distance - lower_bound) / (upper_bound - lower_bound) # InterpolaciΓ³n lineal de valores - bandwidth = lower_cond["bandwidth"] - 
progress * (lower_cond["bandwidth"] - upper_cond["bandwidth"]) - delay = lower_cond["delay"] + progress * (upper_cond["delay"] - lower_cond["delay"]) + bandwidth = lower_bandwidth - progress * (lower_bandwidth - upper_bandwidth) + delay = lower_delay + progress * (upper_delay - lower_delay) return {"bandwidth": round(bandwidth, 2), "delay": round(delay, 2)} # Si la distancia es infinita, devolver el ΓΊltimo valor - return self.network_conditions[float("inf")] + return { + "bandwidth": float(self.network_conditions[float("inf")]["bandwidth"]), + "delay": float(self.network_conditions[float("inf")]["delay"]) + } async def change_connections(self): """ diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 520893fcc..6b955a7c4 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -36,7 +36,7 @@ def __init__(self, engine: "Engine"): self.register_endpoint = f"http://{self.config.participant['scenario_args']['controller']}/nebula/dashboard/{self.config.participant['scenario_args']['name']}/node/register" self.wait_endpoint = f"http://{self.config.participant['scenario_args']['controller']}/nebula/dashboard/{self.config.participant['scenario_args']['name']}/node/wait" - self._connections = {} + self._connections : dict[str, Connection] = {} self.connections_lock = Locker(name="connections_lock", async_lock=True) self.connections_manager_lock = Locker(name="connections_manager_lock", async_lock=True) self.connection_attempt_lock_incoming = Locker(name="connection_attempt_lock_incoming", async_lock=True) @@ -266,7 +266,7 @@ async def update_geolocalization(self, geoloc : dict): for source in geoloc.keys(): latitude, longitude = geoloc[source] #logging.info(f"Update geolocs for source: {source}, geoloc: ({latitude},{longitude})") - #self.connections[source].update_geolocation(latitude, longitude) + self.connections[source].update_geolocation(latitude, longitude) def 
get_connections_lock(self): return self.connections_lock diff --git a/nebula/core/situationalawareness/awareness/GPS/gpsmodule.py b/nebula/core/situationalawareness/awareness/GPS/gpsmodule.py index 313a41ebe..b6df44fb7 100644 --- a/nebula/core/situationalawareness/awareness/GPS/gpsmodule.py +++ b/nebula/core/situationalawareness/awareness/GPS/gpsmodule.py @@ -14,7 +14,7 @@ async def stop(self): class GPSModuleException(Exception): pass -def factory_gpsmodule(gps_module, sam) -> GPSModule: +def factory_gpsmodule(gps_module, sam, addr) -> GPSModule: from nebula.core.situationalawareness.awareness.GPS.nebulagps import NebulaGPS GPS_SERVICES = { @@ -24,6 +24,6 @@ def factory_gpsmodule(gps_module, sam) -> GPSModule: gps_module = GPS_SERVICES.get(gps_module, NebulaGPS) if gps_module: - return gps_module(sam) + return gps_module(sam, addr) else: raise GPSModuleException(f"GPS Module {gps_module} not found") \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/GPS/nebulagps.py b/nebula/core/situationalawareness/awareness/GPS/nebulagps.py index 2c73c3832..a0c5d89e4 100644 --- a/nebula/core/situationalawareness/awareness/GPS/nebulagps.py +++ b/nebula/core/situationalawareness/awareness/GPS/nebulagps.py @@ -13,7 +13,8 @@ class NebulaGPS(GPSModule): BROADCAST_PORT = 50001 # Poort used for GPS INTERFACE = "eth2" # Interface to avoid network conditions - def __init__(self, sam: "SAModule", update_interval: float = 5.0): + def __init__(self, sam: "SAModule", addr, update_interval: float = 5.0): + self._addr = addr self._situational_awareness_module = sam self.update_interval = update_interval # Frecuencia de emisiΓ³n self.running = False @@ -54,7 +55,7 @@ async def _send_location_loop(self): """Envia la geolocalizaciΓ³n periΓ³dicamente por broadcast.""" while self.running: latitude, longitude = await self.sam.get_geoloc() # Obtener ubicaciΓ³n actual - message = f"GPS-UPDATE {latitude} {longitude}" + message = f"GPS-UPDATE {self._addr} {latitude} 
{longitude}" self._broadcast_socket.sendto(message.encode(), (self.BROADCAST_IP, self.BROADCAST_PORT)) #logging.info(f"Sent GPS location: ({latitude}, {longitude})") await asyncio.sleep(self.update_interval) @@ -68,10 +69,10 @@ async def _receive_location_loop(self): ) message = data.decode().strip() if message.startswith("GPS-UPDATE"): - _, lat, lon = message.split() - self._nodes_location_lock.acquire_async() - self._node_locations[addr[0]] = (float(lat), float(lon)) - self._nodes_location_lock.release_async() + _, sender_addr, lat, lon = message.split() + if sender_addr != self._addr: + async with self._nodes_location_lock: + self._node_locations[sender_addr] = (float(lat), float(lon)) #logging.info(f"Received GPS from {addr[0]}: {lat}, {lon}") except Exception as e: logging.error(f"Error receiving GPS update: {e}") diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 2fdfac9d3..1ca92e1d4 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -32,7 +32,7 @@ def __init__( self._neighbor_policy = factory_NeighborPolicy(topology) self._restructure_process_lock = Locker(name="restructure_process_lock") self._restructure_cooldown = 0 - self._gpsmodule = factory_gpsmodule("nebula", self) + self._gpsmodule = factory_gpsmodule("nebula", self, self._addr) @property def nm(self): diff --git a/nebula/frontend/config/participant.json.example b/nebula/frontend/config/participant.json.example index db3ec0d34..3e2b916b9 100755 --- a/nebula/frontend/config/participant.json.example +++ b/nebula/frontend/config/participant.json.example @@ -41,7 +41,7 @@ "addr": "", "neighbors": "", "interface": "eth0", - "simulation": false, + "simulation": true, "bandwidth": "5Gbps", "delay": "0ms", "delay-distro": "0ms", From 2f8f2d847aaca56c65702a4fc01a570706596d5f Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 21 Feb 2025 
15:47:30 +0100 Subject: [PATCH 108/233] feature updating mobility module --- nebula/addons/mobility.py | 40 +++++++++++++------ .../awareness/samodule.py | 2 +- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index ca49c2227..23b04f39b 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -359,10 +359,10 @@ async def change_connections_based_on_distance(self): # self.cm.connections[addr].set_direct(False) # await self.cm.update_neighbors(addr,remove=True) # Adapt network conditions of the connection based on distance - for threshold in sorted(self.network_conditions.keys()): - if distance < threshold: - conditions = self.network_conditions[threshold] - break + # for threshold in sorted(self.network_conditions.keys()): + # if distance < threshold: + # conditions = self.network_conditions[threshold] + # break conditions = await self.calculate_network_conditions(distance) # Only update the network conditions if they have changed if ( @@ -391,8 +391,14 @@ async def change_connections_based_on_distance(self): logging.exception("πŸ“ Error changing connections based on distance") return - #TODO corregir formato, no son float son float-mbps por ejemplo async def calculate_network_conditions(self, distance): + def extract_number(value): + import re + match = re.match(r"([\d.]+)", value) + if not match: + raise ValueError(f"Formato invΓ‘lido: {value}") + return float(match.group(1)) + thresholds = sorted(self.network_conditions.keys()) # Si la distancia es menor que el primer umbral, devolver la mejor condiciΓ³n @@ -411,20 +417,28 @@ async def calculate_network_conditions(self, distance): lower_cond = self.network_conditions[lower_bound] upper_cond = self.network_conditions[upper_bound] - # Convertir a float antes de operar - lower_bandwidth = float(lower_cond["bandwidth"]) - upper_bandwidth = float(upper_cond["bandwidth"]) - lower_delay = float(lower_cond["delay"]) - upper_delay = 
float(upper_cond["delay"]) + # Extraer valores numΓ©ricos y unidades + lower_bandwidth_value = extract_number(lower_cond["bandwidth"]) + upper_bandwidth_value = extract_number(upper_cond["bandwidth"]) + lower_bandwidth_unit = lower_cond["bandwidth"].replace(str(lower_bandwidth_value), "") + upper_bandwidth_unit = upper_cond["bandwidth"].replace(str(upper_bandwidth_value), "") + + lower_delay_value = extract_number(lower_cond["delay"]) + upper_delay_value = extract_number(upper_cond["delay"]) + delay_unit = lower_cond["delay"].replace(str(lower_delay_value), "") # Calcular el progreso en el tramo (0 a 1) progress = (distance - lower_bound) / (upper_bound - lower_bound) # InterpolaciΓ³n lineal de valores - bandwidth = lower_bandwidth - progress * (lower_bandwidth - upper_bandwidth) - delay = lower_delay + progress * (upper_delay - lower_delay) + bandwidth_value = lower_bandwidth_value - progress * (lower_bandwidth_value - upper_bandwidth_value) + delay_value = lower_delay_value + progress * (upper_delay_value - lower_delay_value) + + # Reconstruir valores con unidades originales + bandwidth = f"{round(bandwidth_value, 2)}{lower_bandwidth_unit}" + delay = f"{round(delay_value, 2)}{delay_unit}" - return {"bandwidth": round(bandwidth, 2), "delay": round(delay, 2)} + return {"bandwidth": bandwidth, "delay": delay} # Si la distancia es infinita, devolver el ΓΊltimo valor return { diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 1ca92e1d4..f8c47dec4 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -192,7 +192,7 @@ async def upgrade_connection_robustness(self, possible_neighbors): self._restructure_process_lock.release() async def stop_connections_with_federation(self): - await asyncio.sleep(200) + await asyncio.sleep(400) logging.info("### DISCONNECTING FROM FEDERATON ###") neighbors = 
self.np.get_nodes_known(neighbors_only=True) for n in neighbors: From 474de4ca194d2a7f7a289cf3395ef48afe6f0f8b Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 21 Feb 2025 19:53:53 +0100 Subject: [PATCH 109/233] fix mobility errors --- nebula/addons/mobility.py | 86 ++++++++----------- nebula/core/network/communications.py | 12 ++- nebula/core/network/nebuladiscoveryservice.py | 2 +- .../awareness/samodule.py | 1 + 4 files changed, 45 insertions(+), 56 deletions(-) diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index 23b04f39b..728b813b7 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -67,7 +67,7 @@ def __init__(self, config, cm: "CommunicationsManager"): 100: {"bandwidth": "5Gbps", "delay": "5ms"}, 200: {"bandwidth": "2Gbps", "delay": "50ms"}, 300: {"bandwidth": "100Mbps", "delay": "200ms"}, - float("inf"): {"bandwidth": "10Mbps", "delay": "1000000000000ms"}, + float("inf"): {"bandwidth": "10Mbps", "delay": "100000ms"}, } # Current network conditions of each connection {addr: {bandwidth: "5Gbps", delay: "0ms"}} self.current_network_conditions = {} @@ -269,32 +269,33 @@ async def change_geo_location(self): latitude = float(self.config.participant["mobility_args"]["latitude"]) longitude = float(self.config.participant["mobility_args"]["longitude"]) - direct_connections = await self.cm.get_direct_connections() - undirect_connection = await self.cm.get_undirect_connections() - if True or len(undirect_connection) > len(direct_connections): + if True: # Get neighbor closer to me selected_neighbor = await self.cm.get_nearest_connections(top=1) - logging.info(f"πŸ“ Selected neighbor: {selected_neighbor}") - try: - ( - neighbor_latitude, - neighbor_longitude, - ) = selected_neighbor.get_geolocation() - distance = selected_neighbor.get_neighbor_distance() - if distance > self.max_initiate_approximation: - # If the distance is too big, we move towards the neighbor - await 
self.change_geo_location_nearest_neighbor_strategy( - distance, - latitude, - longitude, + if selected_neighbor: + logging.info(f"πŸ“ Selected neighbor: {selected_neighbor}") + try: + ( neighbor_latitude, neighbor_longitude, - ) - else: + ) = selected_neighbor.get_geolocation() + distance = selected_neighbor.get_neighbor_distance() + if distance > self.max_initiate_approximation: + # If the distance is too big, we move towards the neighbor + await self.change_geo_location_nearest_neighbor_strategy( + distance, + latitude, + longitude, + neighbor_latitude, + neighbor_longitude, + ) + else: + await self.change_geo_location_random_strategy(latitude, longitude) + except Exception as e: + logging.info(f"πŸ“ Neighbor location/distance not found for {selected_neighbor.get_addr()}: {e}") await self.change_geo_location_random_strategy(latitude, longitude) - except Exception as e: - logging.info(f"πŸ“ Neighbor location/distance not found for {selected_neighbor.get_addr()}: {e}") - await self.change_geo_location_random_strategy(latitude, longitude) + else: + await self.change_geo_location_random_strategy(latitude, longitude) else: await self.change_geo_location_random_strategy(latitude, longitude) else: @@ -338,36 +339,11 @@ async def change_connections_based_on_distance(self): if distance is None: # If the distance is not found, we skip the node continue - # logging.info(f"πŸ“ Distance to node {addr}: {distance}") - # if ( - # not self.cm.connections[addr].get_direct() - # and distance < self.max_distance_with_direct_connections - # ): - # logging.info(f"πŸ“ Node {addr} is close enough [{distance}], adding to direct connections") - # self.cm.connections[addr].set_direct(True) - # await self.cm.update_neighbors(addr) - # else: - # # 10% margin to avoid oscillations - # if ( - # self.cm.connections[addr].get_direct() - # and distance > self.max_distance_with_direct_connections * 1.1 - # ): - # logging.info( - # f"πŸ“ Node {addr} is too far away [{distance}], removing 
from direct connections" - # ) - # await asyncio.sleep(1) - # self.cm.connections[addr].set_direct(False) - # await self.cm.update_neighbors(addr,remove=True) - # Adapt network conditions of the connection based on distance - # for threshold in sorted(self.network_conditions.keys()): - # if distance < threshold: - # conditions = self.network_conditions[threshold] - # break conditions = await self.calculate_network_conditions(distance) + logging.info(f"Conditions for source: {addr}, | {conditions}") # Only update the network conditions if they have changed if ( - addr not in self.current_network_conditions - or self.current_network_conditions[addr] != conditions + addr not in self.current_network_conditions or self.current_network_conditions[addr] != conditions ): # eth1 is the interface of the container that connects to the node network - eth0 is the interface of the container that connects to the frontend/backend self.cm._set_network_conditions( @@ -383,6 +359,8 @@ async def change_connections_based_on_distance(self): reordering="0%", ) self.current_network_conditions[addr] = conditions + else: + logging.info("network conditions havent changed since last time") except KeyError: # Except when self.cm.connections[addr] is not found (disconnected during the process) logging.exception(f"πŸ“ Connection {addr} not found") @@ -392,6 +370,7 @@ async def change_connections_based_on_distance(self): return async def calculate_network_conditions(self, distance): + logging.info(f"Calculating conditions for distance: {distance}") def extract_number(value): import re match = re.match(r"([\d.]+)", value) @@ -413,7 +392,11 @@ def extract_number(value): lower_bound = thresholds[i] upper_bound = thresholds[i + 1] + if upper_bound == float("inf"): + break + if lower_bound <= distance < upper_bound: + logging.info(f"Bounds | lower: {lower_bound} | upper: {upper_bound}") lower_cond = self.network_conditions[lower_bound] upper_cond = self.network_conditions[upper_bound] @@ -429,6 
+412,7 @@ def extract_number(value): # Calcular el progreso en el tramo (0 a 1) progress = (distance - lower_bound) / (upper_bound - lower_bound) + logging.info(f"Progress between the bounds: {progress}") # InterpolaciΓ³n lineal de valores bandwidth_value = lower_bandwidth_value - progress * (lower_bandwidth_value - upper_bandwidth_value) @@ -442,8 +426,8 @@ def extract_number(value): # Si la distancia es infinita, devolver el ΓΊltimo valor return { - "bandwidth": float(self.network_conditions[float("inf")]["bandwidth"]), - "delay": float(self.network_conditions[float("inf")]["delay"]) + "bandwidth": self.network_conditions[float("inf")]["bandwidth"], + "delay": self.network_conditions[float("inf")]["delay"] } async def change_connections(self): diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 6b955a7c4..d87e8daf4 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -266,7 +266,8 @@ async def update_geolocalization(self, geoloc : dict): for source in geoloc.keys(): latitude, longitude = geoloc[source] #logging.info(f"Update geolocs for source: {source}, geoloc: ({latitude},{longitude})") - self.connections[source].update_geolocation(latitude, longitude) + if source in self.connections: + self.connections[source].update_geolocation(latitude, longitude) def get_connections_lock(self): return self.connections_lock @@ -927,10 +928,13 @@ async def get_nearest_connections(self, top: int = 1): conn.get_neighbor_distance() if conn.get_neighbor_distance() is not None else float("inf") ), ) - if top == 1: - return sorted_connections[0] + if sorted_connections: + if top == 1: + return sorted_connections[0] + else: + return sorted_connections[:top] else: - return sorted_connections[:top] + return None finally: await self.get_connections_lock().release_async() diff --git a/nebula/core/network/nebuladiscoveryservice.py b/nebula/core/network/nebuladiscoveryservice.py index 
e47673d83..8cc4908b3 100644 --- a/nebula/core/network/nebuladiscoveryservice.py +++ b/nebula/core/network/nebuladiscoveryservice.py @@ -113,7 +113,7 @@ async def search(self): def datagram_received(self, data, addr): try: if "ST: urn:nebula-service" in data.decode('utf-8'): - logging.info("Received response from Node server-service") + #logging.info("Received response from Node server-service") self.nebula_service.response_received(data, addr) except UnicodeDecodeError: logging.warning(f"Received malformed message from {addr}, ignoring.") diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index f8c47dec4..4cc9fd1de 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -60,6 +60,7 @@ async def init(self): else: logging.info("Deploying External Connection Service | No running") await self.cm.start_external_connection_service(run_service=False) + logging.info("Building neighbor policy configuration..") self.np.set_config([ From 56928c692c3f1205a50ec55fe385f2b6cd51049c Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 21 Feb 2025 21:11:48 +0100 Subject: [PATCH 110/233] fix missing await --- nebula/core/situationalawareness/awareness/GPS/nebulagps.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nebula/core/situationalawareness/awareness/GPS/nebulagps.py b/nebula/core/situationalawareness/awareness/GPS/nebulagps.py index a0c5d89e4..2951f868d 100644 --- a/nebula/core/situationalawareness/awareness/GPS/nebulagps.py +++ b/nebula/core/situationalawareness/awareness/GPS/nebulagps.py @@ -80,9 +80,9 @@ async def _receive_location_loop(self): async def _notify_geolocs(self): while True: await asyncio.sleep(self.update_interval) - self._nodes_location_lock.acquire_async() + await self._nodes_location_lock.acquire_async() geolocs = self._node_locations.copy() - 
self._nodes_location_lock.release_async() + await self._nodes_location_lock.release_async() if geolocs: await self.sam.cm.update_geolocalization(geolocs) \ No newline at end of file From 1d32273e01ae37480bb85d7b962b2cbff0796d5c Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sun, 23 Feb 2025 13:06:37 +0100 Subject: [PATCH 111/233] fix generate network conditions --- nebula/addons/mobility.py | 2 +- nebula/core/network/communications.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index 728b813b7..d71856a7b 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -347,7 +347,7 @@ async def change_connections_based_on_distance(self): ): # eth1 is the interface of the container that connects to the node network - eth0 is the interface of the container that connects to the frontend/backend self.cm._set_network_conditions( - interface="eth1", + interface="eth0", network=addr.split(":")[0], bandwidth=conditions["bandwidth"], delay=conditions["delay"], diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index d87e8daf4..d52c33306 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -457,7 +457,7 @@ async def network_wait(self): async def deploy_additional_services(self): logging.info("🌐 Deploying additional services...") - self._generate_network_conditions() + # self._generate_network_conditions() await self._forwarder.start() if self.config.participant["mobility_args"]["mobility"]: pass From 58c65b549cfd9dc81990f45279588b6abc612deb Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sun, 23 Feb 2025 13:49:00 +0100 Subject: [PATCH 112/233] fix mobility low threshold error --- nebula/addons/mobility.py | 14 ++++++------- nebula/core/network/blacklist.py | 2 +- nebula/core/network/communications.py | 30 +++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 8 deletions(-) diff --git 
a/nebula/addons/mobility.py b/nebula/addons/mobility.py index d71856a7b..a007403bb 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -67,7 +67,7 @@ def __init__(self, config, cm: "CommunicationsManager"): 100: {"bandwidth": "5Gbps", "delay": "5ms"}, 200: {"bandwidth": "2Gbps", "delay": "50ms"}, 300: {"bandwidth": "100Mbps", "delay": "200ms"}, - float("inf"): {"bandwidth": "10Mbps", "delay": "100000ms"}, + float("inf"): {"bandwidth": "10Mbps", "delay": "1000ms"}, } # Current network conditions of each connection {addr: {bandwidth: "5Gbps", delay: "0ms"}} self.current_network_conditions = {} @@ -340,13 +340,13 @@ async def change_connections_based_on_distance(self): # If the distance is not found, we skip the node continue conditions = await self.calculate_network_conditions(distance) - logging.info(f"Conditions for source: {addr}, | {conditions}") + #logging.info(f"Conditions for source: {addr}, | {conditions}") # Only update the network conditions if they have changed if ( addr not in self.current_network_conditions or self.current_network_conditions[addr] != conditions ): # eth1 is the interface of the container that connects to the node network - eth0 is the interface of the container that connects to the frontend/backend - self.cm._set_network_conditions( + self.cm.set_network_conditions( interface="eth0", network=addr.split(":")[0], bandwidth=conditions["bandwidth"], @@ -383,8 +383,8 @@ def extract_number(value): # Si la distancia es menor que el primer umbral, devolver la mejor condiciΓ³n if distance < thresholds[0]: return { - "bandwidth": float(self.network_conditions[thresholds[0]]["bandwidth"]), - "delay": float(self.network_conditions[thresholds[0]]["delay"]) + "bandwidth": self.network_conditions[thresholds[0]]["bandwidth"], + "delay": self.network_conditions[thresholds[0]]["delay"] } # Encontrar el tramo en el que se encuentra la distancia @@ -396,7 +396,7 @@ def extract_number(value): break if lower_bound <= distance < 
upper_bound: - logging.info(f"Bounds | lower: {lower_bound} | upper: {upper_bound}") + #logging.info(f"Bounds | lower: {lower_bound} | upper: {upper_bound}") lower_cond = self.network_conditions[lower_bound] upper_cond = self.network_conditions[upper_bound] @@ -412,7 +412,7 @@ def extract_number(value): # Calcular el progreso en el tramo (0 a 1) progress = (distance - lower_bound) / (upper_bound - lower_bound) - logging.info(f"Progress between the bounds: {progress}") + #logging.info(f"Progress between the bounds: {progress}") # InterpolaciΓ³n lineal de valores bandwidth_value = lower_bandwidth_value - progress * (lower_bandwidth_value - upper_bandwidth_value) diff --git a/nebula/core/network/blacklist.py b/nebula/core/network/blacklist.py index ac7bdd531..6f9364de2 100644 --- a/nebula/core/network/blacklist.py +++ b/nebula/core/network/blacklist.py @@ -13,7 +13,7 @@ def __init__( ): self._max_time_listed = max_time_listed self._blacklisted_nodes: dict = {} - self._recently_disconnected: set = set() # para no inentar coenctarse a recently disconnected + self._recently_disconnected: set = set() self._recently_disconnected_lock = Locker(name="recently_disconnected_lock", async_lock=True) self._blacklisted_nodes_lock = Locker(name="blacklisted_nodes_lock", async_lock=True) self._bl_cleaner_running = False diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index d52c33306..fc8920544 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -529,6 +529,21 @@ def _reset_network_conditions(self): logging.exception(f"❗️ Network simulation error: {e}") return + def set_network_conditions( + self, + interface="eth0", + network="192.168.50.2", + bandwidth="5Gbps", + delay="0ms", + delay_distro="0ms", + delay_distribution="normal", + loss="0%", + duplicate="0%", + corrupt="0%", + reordering="0%", + ): + self._set_network_conditions(self, interface, network, bandwidth, delay, delay_distro, 
delay_distribution, loss, duplicate, corrupt, reordering) + def _set_network_conditions( self, interface="eth0", @@ -578,6 +593,21 @@ def _set_network_conditions( except Exception as e: logging.exception(f"❗️ Network simulation error: {e}") return + + def _set_multicast_conditions( + self, + interface="eth0", + network="192.168.50.2", + bandwidth="5Gbps", + delay="0ms", + delay_distro="0ms", + delay_distribution="normal", + loss="0%", + duplicate="0%", + corrupt="0%", + reordering="0%", + ): + pass async def include_received_message_hash(self, hash_message): try: From 00eff3fa8ef0af88941fde3ae62685ae12caade5 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sun, 23 Feb 2025 16:40:08 +0100 Subject: [PATCH 113/233] feature network simulator --- nebula/core/network/communications.py | 2 +- .../externalconnectionservice.py | 2 +- .../nebuladiscoveryservice.py | 2 +- .../nebulanetworksimulator.py | 96 +++++++++++++++++++ .../networksimulation/networksimulator.py | 34 +++++++ 5 files changed, 133 insertions(+), 3 deletions(-) rename nebula/core/network/{ => externalconnection}/externalconnectionservice.py (91%) rename nebula/core/network/{ => externalconnection}/nebuladiscoveryservice.py (98%) create mode 100644 nebula/core/network/networksimulation/nebulanetworksimulator.py create mode 100644 nebula/core/network/networksimulation/networksimulator.py diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index fc8920544..12f5cf826 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -11,7 +11,7 @@ from nebula.core.network.blacklist import BlackList from nebula.core.network.connection import Connection from nebula.core.network.discoverer import Discoverer -from nebula.core.network.externalconnectionservice import factory_connection_service +from nebula.core.network.externalconnection.externalconnectionservice import factory_connection_service from nebula.core.network.forwarder import 
Forwarder from nebula.core.network.messages import MessageEvent, MessagesManager from nebula.core.network.propagator import Propagator diff --git a/nebula/core/network/externalconnectionservice.py b/nebula/core/network/externalconnection/externalconnectionservice.py similarity index 91% rename from nebula/core/network/externalconnectionservice.py rename to nebula/core/network/externalconnection/externalconnectionservice.py index b5c864bb8..1e2241bb2 100644 --- a/nebula/core/network/externalconnectionservice.py +++ b/nebula/core/network/externalconnection/externalconnectionservice.py @@ -39,7 +39,7 @@ class ExternalConnectionServiceException(Exception): pass def factory_connection_service(con_serv, cm, addr) -> ExternalConnectionService: - from nebula.core.network.nebuladiscoveryservice import NebulaConnectionService + from nebula.core.network.externalconnection.nebuladiscoveryservice import NebulaConnectionService CONNECTION_SERVICES = { "nebula": NebulaConnectionService, diff --git a/nebula/core/network/nebuladiscoveryservice.py b/nebula/core/network/externalconnection/nebuladiscoveryservice.py similarity index 98% rename from nebula/core/network/nebuladiscoveryservice.py rename to nebula/core/network/externalconnection/nebuladiscoveryservice.py index 8cc4908b3..037c9b535 100644 --- a/nebula/core/network/nebuladiscoveryservice.py +++ b/nebula/core/network/externalconnection/nebuladiscoveryservice.py @@ -2,7 +2,7 @@ import logging import socket import struct -from nebula.core.network.externalconnectionservice import ExternalConnectionService +from nebula.core.network.externalconnection.externalconnectionservice import ExternalConnectionService from nebula.core.utils.locker import Locker from typing import TYPE_CHECKING diff --git a/nebula/core/network/networksimulation/nebulanetworksimulator.py b/nebula/core/network/networksimulation/nebulanetworksimulator.py new file mode 100644 index 000000000..6c3562424 --- /dev/null +++ 
b/nebula/core/network/networksimulation/nebulanetworksimulator.py @@ -0,0 +1,96 @@ +import asyncio +import logging +from nebula.core.network.networksimulation.networksimulator import NetworkSimulator +from nebula.core.utils.locker import Locker + +class NebulaNS(NetworkSimulator): + NETWORK_CONDITIONS = { + 100: {"bandwidth": "5Gbps", "delay": "5ms"}, + 200: {"bandwidth": "2Gbps", "delay": "50ms"}, + 300: {"bandwidth": "100Mbps", "delay": "200ms"}, + float("inf"): {"bandwidth": "10Mbps", "delay": "1000ms"}, + } + + def __init__(self, verbose=False): + self._verbose = verbose + self._network_conditions = self.NETWORK_CONDITIONS.copy() + self._network_conditions_lock = Locker("network_conditions_lock", async_lock=True) + + async def set_thresholds(self, threshold : dict): + pass + + def set_network_conditions(self, dest_addr, distance): + pass + + def _set_network_condition_for_addr(self): + pass + + def _set_network_condition_for_multicast(self): + pass + + async def calculate_network_conditions(self, distance): + def extract_number(value): + import re + match = re.match(r"([\d.]+)", value) + if not match: + raise ValueError(f"Formato invΓ‘lido: {value}") + return float(match.group(1)) + + if self._verbose: logging.info(f"Calculating conditions for distance: {distance}") + conditions = {} + #TODO hacer copia del diccionario dentro de locks + thresholds = sorted(self.network_conditions.keys()) + + # Si la distancia es menor que el primer umbral, devolver la mejor condiciΓ³n + if distance < thresholds[0]: + return { + "bandwidth": self.network_conditions[thresholds[0]]["bandwidth"], + "delay": self.network_conditions[thresholds[0]]["delay"] + } + + # Encontrar el tramo en el que se encuentra la distancia + for i in range(len(thresholds) - 1): + lower_bound = thresholds[i] + upper_bound = thresholds[i + 1] + + if upper_bound == float("inf"): + break + + if lower_bound <= distance < upper_bound: + #logging.info(f"Bounds | lower: {lower_bound} | upper: {upper_bound}") + 
lower_cond = self.network_conditions[lower_bound] + upper_cond = self.network_conditions[upper_bound] + + # Extraer valores numΓ©ricos y unidades + lower_bandwidth_value = extract_number(lower_cond["bandwidth"]) + upper_bandwidth_value = extract_number(upper_cond["bandwidth"]) + lower_bandwidth_unit = lower_cond["bandwidth"].replace(str(lower_bandwidth_value), "") + upper_bandwidth_unit = upper_cond["bandwidth"].replace(str(upper_bandwidth_value), "") + + lower_delay_value = extract_number(lower_cond["delay"]) + upper_delay_value = extract_number(upper_cond["delay"]) + delay_unit = lower_cond["delay"].replace(str(lower_delay_value), "") + + # Calcular el progreso en el tramo (0 a 1) + progress = (distance - lower_bound) / (upper_bound - lower_bound) + #logging.info(f"Progress between the bounds: {progress}") + + # InterpolaciΓ³n lineal de valores + bandwidth_value = lower_bandwidth_value - progress * (lower_bandwidth_value - upper_bandwidth_value) + delay_value = lower_delay_value + progress * (upper_delay_value - lower_delay_value) + + # Reconstruir valores con unidades originales + bandwidth = f"{round(bandwidth_value, 2)}{lower_bandwidth_unit}" + delay = f"{round(delay_value, 2)}{delay_unit}" + + return {"bandwidth": bandwidth, "delay": delay} + + # Si la distancia es infinita, devolver el ΓΊltimo valor + return { + "bandwidth": self.network_conditions[float("inf")]["bandwidth"], + "delay": self.network_conditions[float("inf")]["delay"] + } + return conditions + + def clear_network_conditions(self): + pass \ No newline at end of file diff --git a/nebula/core/network/networksimulation/networksimulator.py b/nebula/core/network/networksimulation/networksimulator.py new file mode 100644 index 000000000..62c706917 --- /dev/null +++ b/nebula/core/network/networksimulation/networksimulator.py @@ -0,0 +1,34 @@ +import asyncio +from abc import ABC, abstractmethod + +class NetworkSimulator(ABC): + + @abstractmethod + async def set_thresholds(self, threshold : dict): + 
pass + + @abstractmethod + def set_network_conditions(self, dest_addr, distance): + pass + + @abstractmethod + def clear_network_conditions(self): + pass + + +class NetworkSimulatorException(Exception): + pass + +def factory_connection_service(net_sim, cm, addr) -> NetworkSimulator: + from nebula.core.network.networksimulation.nebulanetworksimulator import NebulaNS + + SIMULATION_SERVICES = { + "nebula": NebulaNS, + } + + con_serv = SIMULATION_SERVICES.get(net_sim, NebulaNS) + + if con_serv: + return con_serv(cm, addr) + else: + raise NetworkSimulatorException(f"Network Simulator {net_sim} not found") \ No newline at end of file From 8c4ec3b1a8670b7f60829e5a7d6ef985db91d15f Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 24 Feb 2025 13:08:01 +0100 Subject: [PATCH 114/233] feature integrated nebula network simulator --- nebula/addons/mobility.py | 2 +- nebula/core/network/communications.py | 287 +++++++++--------- .../nebulanetworksimulator.py | 232 ++++++++++++-- .../networksimulation/networksimulator.py | 22 +- 4 files changed, 363 insertions(+), 180 deletions(-) diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index a007403bb..dc8e22fdb 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -135,7 +135,7 @@ async def run_mobility(self): await asyncio.sleep(self.grace_time) while True: await self.change_geo_location() - await self.change_connections_based_on_distance() + #await self.change_connections_based_on_distance() await asyncio.sleep(self.period) async def change_geo_location_random_strategy(self, latitude, longitude): diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 12f5cf826..bf3fceec8 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -12,6 +12,7 @@ from nebula.core.network.connection import Connection from nebula.core.network.discoverer import Discoverer from 
nebula.core.network.externalconnection.externalconnectionservice import factory_connection_service +from nebula.core.network.networksimulation.networksimulator import factory_network_simulator from nebula.core.network.forwarder import Forwarder from nebula.core.network.messages import MessageEvent, MessagesManager from nebula.core.network.propagator import Propagator @@ -73,6 +74,10 @@ def __init__(self, engine: "Engine"): # Connection service to communicate with external devices self._external_connection_service = factory_connection_service("nebula", self, self.addr) + + # Network simulator service to deplay realistic network conditions + refresh_conditions_interval = 5 + self._network_simulator = factory_network_simulator("nebula", self, refresh_conditions_interval, "eth0", verbose=True) @property def engine(self): @@ -109,6 +114,10 @@ def mobility(self): @property def ecs(self): return self._external_connection_service + + @property + def ns(self): + return self._network_simulator @property def bl(self): @@ -254,12 +263,144 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr discovers_sent += 1 return discovers_sent + """ ############################## + # NETWORK CONDITIONS # + ############################## + """ + + async def get_network_conditions_grace_time(self): + return await self.config.participant["mobility_args"]["change_geo_interval"] + + def _generate_network_conditions(self): + # TODO: Implement selection of network conditions from frontend + if self.config.participant["network_args"]["simulation"]: + interface = self.config.participant["network_args"]["interface"] + bandwidth = self.config.participant["network_args"]["bandwidth"] + delay = self.config.participant["network_args"]["delay"] + delay_distro = self.config.participant["network_args"]["delay-distro"] + delay_distribution = self.config.participant["network_args"]["delay-distribution"] + loss = self.config.participant["network_args"]["loss"] + duplicate = 
self.config.participant["network_args"]["duplicate"] + corrupt = self.config.participant["network_args"]["corrupt"] + reordering = self.config.participant["network_args"]["reordering"] + logging.info( + f"🌐 Network simulation is enabled | Interface: {interface} | Bandwidth: {bandwidth} | Delay: {delay} | Delay Distro: {delay_distro} | Delay Distribution: {delay_distribution} | Loss: {loss} | Duplicate: {duplicate} | Corrupt: {corrupt} | Reordering: {reordering}" + ) + try: + results = subprocess.run( + [ + "tcset", + str(interface), + "--rate", + str(bandwidth), + "--delay", + str(delay), + "--delay-distro", + str(delay_distro), + "--delay-distribution", + str(delay_distribution), + "--loss", + str(loss), + "--duplicate", + str(duplicate), + "--corrupt", + str(corrupt), + "--reordering", + str(reordering), + ], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + except Exception as e: + logging.exception(f"🌐 Network simulation error: {e}") + return + else: + logging.info("🌐 Network simulation is disabled. 
Using default network conditions...") + + def _reset_network_conditions(self): + interface = self.config.participant["network_args"]["interface"] + logging.info("🌐 Resetting network conditions") + try: + results = subprocess.run( + ["tcdel", str(interface), "--all"], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + except Exception as e: + logging.exception(f"❗️ Network simulation error: {e}") + return + + async def set_network_conditions(self, addr, distance): + await self.ns.set_network_conditions(addr, distance) + #self._set_network_conditions(self, interface, network, bandwidth, delay, delay_distro, delay_distribution, loss, duplicate, corrupt, reordering) + + def clear_network_conditions(self): + self.ns.clear_network_conditions() + + async def set_network_conditions_thresholds(self, thresholds : dict): + await self.ns.set_thresholds(thresholds) + + def _set_network_conditions( + self, + interface="eth0", + network="192.168.50.2", + bandwidth="5Gbps", + delay="0ms", + delay_distro="0ms", + delay_distribution="normal", + loss="0%", + duplicate="0%", + corrupt="0%", + reordering="0%", + ): + logging.info( + f"🌐 Changing network conditions | Interface: {interface} | Network: {network} | Bandwidth: {bandwidth} | Delay: {delay} | Delay Distro: {delay_distro} | Delay Distribution: {delay_distribution} | Loss: {loss} | Duplicate: {duplicate} | Corrupt: {corrupt} | Reordering: {reordering}" + ) + try: + results = subprocess.run( + [ + "tcset", + str(interface), + "--network", + str(network) if network is not None else "", + "--rate", + str(bandwidth), + "--delay", + str(delay), + "--delay-distro", + str(delay_distro), + "--delay-distribution", + str(delay_distribution), + "--loss", + str(loss), + "--duplicate", + str(duplicate), + "--corrupt", + str(corrupt), + "--reordering", + str(reordering), + "--change", + ], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + except Exception as e: + 
logging.exception(f"❗️ Network simulation error: {e}") + return + + + """ ############################## # OTHER FUNCTIONALITIES # ############################## """ - #TODO setcondition para la direccion multicast async def update_geolocalization(self, geoloc : dict): async with self.get_connections_lock(): #logging.info("Update geolocs to simulate network conditions") @@ -460,154 +601,12 @@ async def deploy_additional_services(self): # self._generate_network_conditions() await self._forwarder.start() if self.config.participant["mobility_args"]["mobility"]: - pass + await self.ns.start() # await self._discoverer.start() # await self._health.start() self._propagator.start() await self._mobility.start() - - def _generate_network_conditions(self): - # TODO: Implement selection of network conditions from frontend - if self.config.participant["network_args"]["simulation"]: - interface = self.config.participant["network_args"]["interface"] - bandwidth = self.config.participant["network_args"]["bandwidth"] - delay = self.config.participant["network_args"]["delay"] - delay_distro = self.config.participant["network_args"]["delay-distro"] - delay_distribution = self.config.participant["network_args"]["delay-distribution"] - loss = self.config.participant["network_args"]["loss"] - duplicate = self.config.participant["network_args"]["duplicate"] - corrupt = self.config.participant["network_args"]["corrupt"] - reordering = self.config.participant["network_args"]["reordering"] - logging.info( - f"🌐 Network simulation is enabled | Interface: {interface} | Bandwidth: {bandwidth} | Delay: {delay} | Delay Distro: {delay_distro} | Delay Distribution: {delay_distribution} | Loss: {loss} | Duplicate: {duplicate} | Corrupt: {corrupt} | Reordering: {reordering}" - ) - try: - results = subprocess.run( - [ - "tcset", - str(interface), - "--rate", - str(bandwidth), - "--delay", - str(delay), - "--delay-distro", - str(delay_distro), - "--delay-distribution", - str(delay_distribution), - 
"--loss", - str(loss), - "--duplicate", - str(duplicate), - "--corrupt", - str(corrupt), - "--reordering", - str(reordering), - ], - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - except Exception as e: - logging.exception(f"🌐 Network simulation error: {e}") - return - else: - logging.info("🌐 Network simulation is disabled. Using default network conditions...") - - def _reset_network_conditions(self): - interface = self.config.participant["network_args"]["interface"] - logging.info("🌐 Resetting network conditions") - try: - results = subprocess.run( - ["tcdel", str(interface), "--all"], - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - except Exception as e: - logging.exception(f"❗️ Network simulation error: {e}") - return - - def set_network_conditions( - self, - interface="eth0", - network="192.168.50.2", - bandwidth="5Gbps", - delay="0ms", - delay_distro="0ms", - delay_distribution="normal", - loss="0%", - duplicate="0%", - corrupt="0%", - reordering="0%", - ): - self._set_network_conditions(self, interface, network, bandwidth, delay, delay_distro, delay_distribution, loss, duplicate, corrupt, reordering) - - def _set_network_conditions( - self, - interface="eth0", - network="192.168.50.2", - bandwidth="5Gbps", - delay="0ms", - delay_distro="0ms", - delay_distribution="normal", - loss="0%", - duplicate="0%", - corrupt="0%", - reordering="0%", - ): - logging.info( - f"🌐 Changing network conditions | Interface: {interface} | Network: {network} | Bandwidth: {bandwidth} | Delay: {delay} | Delay Distro: {delay_distro} | Delay Distribution: {delay_distribution} | Loss: {loss} | Duplicate: {duplicate} | Corrupt: {corrupt} | Reordering: {reordering}" - ) - try: - results = subprocess.run( - [ - "tcset", - str(interface), - "--network", - str(network) if network is not None else "", - "--rate", - str(bandwidth), - "--delay", - str(delay), - "--delay-distro", - str(delay_distro), - 
"--delay-distribution", - str(delay_distribution), - "--loss", - str(loss), - "--duplicate", - str(duplicate), - "--corrupt", - str(corrupt), - "--reordering", - str(reordering), - "--change", - ], - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - except Exception as e: - logging.exception(f"❗️ Network simulation error: {e}") - return - def _set_multicast_conditions( - self, - interface="eth0", - network="192.168.50.2", - bandwidth="5Gbps", - delay="0ms", - delay_distro="0ms", - delay_distribution="normal", - loss="0%", - duplicate="0%", - corrupt="0%", - reordering="0%", - ): - pass async def include_received_message_hash(self, hash_message): try: diff --git a/nebula/core/network/networksimulation/nebulanetworksimulator.py b/nebula/core/network/networksimulation/nebulanetworksimulator.py index 6c3562424..3e6058ad2 100644 --- a/nebula/core/network/networksimulation/nebulanetworksimulator.py +++ b/nebula/core/network/networksimulation/nebulanetworksimulator.py @@ -1,34 +1,195 @@ import asyncio +import subprocess import logging from nebula.core.network.networksimulation.networksimulator import NetworkSimulator from nebula.core.utils.locker import Locker +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from nebula.core.network.communications import CommunicationsManager class NebulaNS(NetworkSimulator): NETWORK_CONDITIONS = { 100: {"bandwidth": "5Gbps", "delay": "5ms"}, 200: {"bandwidth": "2Gbps", "delay": "50ms"}, 300: {"bandwidth": "100Mbps", "delay": "200ms"}, - float("inf"): {"bandwidth": "10Mbps", "delay": "1000ms"}, + float("inf"): {"bandwidth": "10Mbps", "delay": "100000ms"}, } + IP_MULTICAST = "239.255.255.250" - def __init__(self, verbose=False): + def __init__(self, communication_manager: "CommunicationsManager", changing_interval, interface, verbose=False): + self._cm = communication_manager + self._refresh_interval = changing_interval + self._node_interface = interface self._verbose = verbose 
self._network_conditions = self.NETWORK_CONDITIONS.copy() self._network_conditions_lock = Locker("network_conditions_lock", async_lock=True) + self._current_network_conditions = {} + self._running = False - async def set_thresholds(self, threshold : dict): - pass + async def start(self): + logging.info("🌐 Nebula Network Simulator starting...") + self._running = True + asyncio.create_task(self._change_network_conditions_based_on_distances()) + + async def stop(self): + self._running = False + + async def _change_network_conditions_based_on_distances(self): + grace_time = self._cm.config.participant["mobility_args"]["grace_time_mobility"] + if self._verbose: logging.info(f"Waiting {grace_time}s to start applying network conditions based on distances between devices") + await asyncio.sleep(grace_time) + + while self._running: + await asyncio.sleep(self._refresh_interval) + if self._verbose: logging.info("Refresh | conditions based on distances...") + current_connections = await self._cm.get_addrs_current_connections() + try: + for addr in current_connections: + distance = self._cm.connections[addr].get_neighbor_distance() + if distance is None: + # If the distance is not found, we skip the node + continue + conditions = await self._calculate_network_conditions(distance) + # Only update the network conditions if they have changed + if (addr not in self._current_network_conditions or self._current_network_conditions[addr] != conditions): + addr_ip = addr.split(":")[0] + self._set_network_condition_for_addr(self._node_interface, addr_ip, conditions["bandwidth"], conditions["delay"]) + self._set_network_condition_for_multicast(self._node_interface, addr_ip, self.IP_MULTICAST, conditions["bandwidth"], conditions["delay"]) + self._current_network_conditions[addr] = conditions + else: + logging.info("network conditions havent changed since last time") + except KeyError: + logging.exception(f"πŸ“ Connection {addr} not found") + except Exception: + logging.exception("πŸ“ 
Error changing connections based on distance") + + async def set_thresholds(self, thresholds : dict): + async with self._network_conditions_lock: + self._network_conditions = thresholds - def set_network_conditions(self, dest_addr, distance): - pass + async def set_network_conditions(self, dest_addr, distance): + conditions = await self._calculate_network_conditions(distance) + self._set_network_condition_for_addr(self, + interface=self._node_interface, + network=dest_addr, + bandwidth=conditions["bandwidth"], + delay=conditions["delay"] + ) + + self._set_network_condition_for_multicast(self, + interface=self._node_interface, + src_network=dest_addr, + dst_network=self.IP_MULTICAST, + bandwidth=conditions["bandwidth"], + delay=conditions["delay"] + ) - def _set_network_condition_for_addr(self): - pass + def _set_network_condition_for_addr( + self, + interface="eth0", + network="192.168.50.2", + bandwidth="5Gbps", + delay="0ms", + delay_distro="10ms", + delay_distribution="normal", + loss="0%", + duplicate="0%", + corrupt="0%", + reordering="0%", + ): + + if self._verbose: + logging.info(f"🌐 Changing network conditions | Interface: {interface} | Network: {network} | Bandwidth: {bandwidth} | Delay: {delay} | Delay Distro: {delay_distro} | Delay Distribution: {delay_distribution} | Loss: {loss} | Duplicate: {duplicate} | Corrupt: {corrupt} | Reordering: {reordering}") + try: + results = subprocess.run( + [ + "tcset", + str(interface), + "--network", + str(network) if network is not None else "", + "--rate", + str(bandwidth), + "--delay", + str(delay), + "--delay-distro", + str(delay_distro), + "--delay-distribution", + str(delay_distribution), + "--loss", + str(loss), + "--duplicate", + str(duplicate), + "--corrupt", + str(corrupt), + "--reordering", + str(reordering), + "--change", + ], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + except Exception as e: + logging.exception(f"❗️ Network simulation error: {e}") + return - def 
_set_network_condition_for_multicast(self): - pass + def _set_network_condition_for_multicast( + self, + interface="eth0", + src_network="", + dst_network="", + bandwidth="5Gbps", + delay="0ms", + delay_distro="10ms", + delay_distribution="normal", + loss="0%", + duplicate="0%", + corrupt="0%", + reordering="0%", + ): + if self._verbose: + logging.info(f"🌐 Changing multicast conditions | Interface: {interface} | Src Network: {src_network} | Bandwidth: {bandwidth} | Delay: {delay} | Delay Distro: {delay_distro} | Delay Distribution: {delay_distribution} | Loss: {loss} | Duplicate: {duplicate} | Corrupt: {corrupt} | Reordering: {reordering}") + + try: + results = subprocess.run( + [ + "tcset", + str(interface), + "--src-network", + str(src_network), + "--dst-network", + str(dst_network), + "--rate", + str(bandwidth), + "--delay", + str(delay), + "--delay-distro", + str(delay_distro), + "--delay-distribution", + str(delay_distribution), + "--loss", + str(loss), + "--duplicate", + str(duplicate), + "--corrupt", + str(corrupt), + "--reordering", + str(reordering), + "--direction", + "incoming", + "--change", + ], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + except Exception as e: + logging.exception(f"❗️ Network simulation error: {e}") + return - async def calculate_network_conditions(self, distance): + async def _calculate_network_conditions(self, distance): def extract_number(value): import re match = re.match(r"([\d.]+)", value) @@ -36,16 +197,18 @@ def extract_number(value): raise ValueError(f"Formato invΓ‘lido: {value}") return float(match.group(1)) - if self._verbose: logging.info(f"Calculating conditions for distance: {distance}") - conditions = {} - #TODO hacer copia del diccionario dentro de locks - thresholds = sorted(self.network_conditions.keys()) + if self._verbose: logging.info(f"Calculating conditions for distance: {distance}m") + conditions = None + async with self._network_conditions_lock: + th = 
self._network_conditions.copy() + + thresholds = sorted(th.keys()) # Si la distancia es menor que el primer umbral, devolver la mejor condiciΓ³n if distance < thresholds[0]: - return { - "bandwidth": self.network_conditions[thresholds[0]]["bandwidth"], - "delay": self.network_conditions[thresholds[0]]["delay"] + conditions = { + "bandwidth": th[thresholds[0]]["bandwidth"], + "delay": th[thresholds[0]]["delay"] } # Encontrar el tramo en el que se encuentra la distancia @@ -58,8 +221,8 @@ def extract_number(value): if lower_bound <= distance < upper_bound: #logging.info(f"Bounds | lower: {lower_bound} | upper: {upper_bound}") - lower_cond = self.network_conditions[lower_bound] - upper_cond = self.network_conditions[upper_bound] + lower_cond = th[lower_bound] + upper_cond = th[upper_bound] # Extraer valores numΓ©ricos y unidades lower_bandwidth_value = extract_number(lower_cond["bandwidth"]) @@ -73,7 +236,7 @@ def extract_number(value): # Calcular el progreso en el tramo (0 a 1) progress = (distance - lower_bound) / (upper_bound - lower_bound) - #logging.info(f"Progress between the bounds: {progress}") + if self._verbose: logging.info(f"Progress between the bounds: {progress}") # InterpolaciΓ³n lineal de valores bandwidth_value = lower_bandwidth_value - progress * (lower_bandwidth_value - upper_bandwidth_value) @@ -83,14 +246,27 @@ def extract_number(value): bandwidth = f"{round(bandwidth_value, 2)}{lower_bandwidth_unit}" delay = f"{round(delay_value, 2)}{delay_unit}" - return {"bandwidth": bandwidth, "delay": delay} + conditions = {"bandwidth": bandwidth, "delay": delay} # Si la distancia es infinita, devolver el ΓΊltimo valor - return { - "bandwidth": self.network_conditions[float("inf")]["bandwidth"], - "delay": self.network_conditions[float("inf")]["delay"] - } + if not conditions: + conditions = { + "bandwidth": th[float("inf")]["bandwidth"], + "delay": th[float("inf")]["delay"] + } + if self._verbose: logging.info(f"Network conditions: {conditions}") return 
conditions - def clear_network_conditions(self): - pass \ No newline at end of file + def clear_network_conditions(self, interface): + if self._verbose: logging.info("🌐 Resetting network conditions") + try: + results = subprocess.run( + ["tcdel", str(interface), "--all"], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + except Exception as e: + logging.exception(f"❗️ Network simulation error: {e}") + return \ No newline at end of file diff --git a/nebula/core/network/networksimulation/networksimulator.py b/nebula/core/network/networksimulation/networksimulator.py index 62c706917..5465683ed 100644 --- a/nebula/core/network/networksimulation/networksimulator.py +++ b/nebula/core/network/networksimulation/networksimulator.py @@ -3,32 +3,40 @@ class NetworkSimulator(ABC): + @abstractmethod + async def start(self): + pass + + @abstractmethod + async def stop(self): + pass + @abstractmethod - async def set_thresholds(self, threshold : dict): + async def set_thresholds(self, thresholds : dict): pass @abstractmethod - def set_network_conditions(self, dest_addr, distance): + async def set_network_conditions(self, dest_addr, distance): pass @abstractmethod - def clear_network_conditions(self): + def clear_network_conditions(self, interface): pass class NetworkSimulatorException(Exception): pass -def factory_connection_service(net_sim, cm, addr) -> NetworkSimulator: +def factory_network_simulator(net_sim, communication_manager, changing_interval, interface, verbose) -> NetworkSimulator: from nebula.core.network.networksimulation.nebulanetworksimulator import NebulaNS SIMULATION_SERVICES = { "nebula": NebulaNS, } - con_serv = SIMULATION_SERVICES.get(net_sim, NebulaNS) + net_serv = SIMULATION_SERVICES.get(net_sim, NebulaNS) - if con_serv: - return con_serv(cm, addr) + if net_serv: + return net_serv(communication_manager, changing_interval, interface, verbose) else: raise NetworkSimulatorException(f"Network Simulator {net_sim} not found") \ No 
newline at end of file From aaccf0ab8f4b8a2707f5d3809e487d7cbb6bd5d4 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 25 Feb 2025 13:56:50 +0100 Subject: [PATCH 115/233] feature SA submodules --- nebula/addons/mobility.py | 2 +- nebula/core/network/communications.py | 5 +- .../awareness/GPS/gpsmodule.py | 4 + .../awareness/GPS/nebulagps.py | 5 +- .../awareness/samodule.py | 248 +++++++++--------- .../neighborpolicies/__init__.py | 0 .../neighborpolicies/fcneighborpolicy.py | 7 +- .../neighborpolicies/idleneighborpolicy.py | 7 +- .../neighborpolicies/neighborpolicy.py | 8 +- .../neighborpolicies/ringneighborpolicy.py | 4 +- .../neighborpolicies/starneighborpolicy.py | 2 +- .../connectionoptimizer.py | 0 .../networkoptimization/networkoptimizer.py | 0 .../networkoptimization/timergenerator.py | 0 .../awareness/sanetwork/sanetwork.py | 200 ++++++++++++++ .../awareness/satraining/satraining.py | 25 ++ .../core/situationalawareness/nodemanager.py | 25 +- 17 files changed, 393 insertions(+), 149 deletions(-) rename nebula/core/situationalawareness/awareness/{ => sanetwork}/neighborpolicies/__init__.py (100%) rename nebula/core/situationalawareness/awareness/{ => sanetwork}/neighborpolicies/fcneighborpolicy.py (89%) rename nebula/core/situationalawareness/awareness/{ => sanetwork}/neighborpolicies/idleneighborpolicy.py (89%) rename nebula/core/situationalawareness/awareness/{ => sanetwork}/neighborpolicies/neighborpolicy.py (68%) rename nebula/core/situationalawareness/awareness/{ => sanetwork}/neighborpolicies/ringneighborpolicy.py (94%) rename nebula/core/situationalawareness/awareness/{ => sanetwork}/neighborpolicies/starneighborpolicy.py (96%) rename nebula/core/situationalawareness/awareness/{ => sanetwork}/networkoptimization/connectionoptimizer.py (100%) rename nebula/core/situationalawareness/awareness/{ => sanetwork}/networkoptimization/networkoptimizer.py (100%) rename nebula/core/situationalawareness/awareness/{ => 
sanetwork}/networkoptimization/timergenerator.py (100%) create mode 100644 nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py create mode 100644 nebula/core/situationalawareness/awareness/satraining/satraining.py diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index dc8e22fdb..0a516f591 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -273,7 +273,7 @@ async def change_geo_location(self): # Get neighbor closer to me selected_neighbor = await self.cm.get_nearest_connections(top=1) if selected_neighbor: - logging.info(f"πŸ“ Selected neighbor: {selected_neighbor}") + #logging.info(f"πŸ“ Selected neighbor: {selected_neighbor}") try: ( neighbor_latitude, diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index bf3fceec8..9296c208b 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -77,7 +77,7 @@ def __init__(self, engine: "Engine"): # Network simulator service to deplay realistic network conditions refresh_conditions_interval = 5 - self._network_simulator = factory_network_simulator("nebula", self, refresh_conditions_interval, "eth0", verbose=True) + self._network_simulator = factory_network_simulator("nebula", self, refresh_conditions_interval, "eth0", verbose=False) @property def engine(self): @@ -601,7 +601,8 @@ async def deploy_additional_services(self): # self._generate_network_conditions() await self._forwarder.start() if self.config.participant["mobility_args"]["mobility"]: - await self.ns.start() + if False: + await self.ns.start() # await self._discoverer.start() # await self._health.start() self._propagator.start() diff --git a/nebula/core/situationalawareness/awareness/GPS/gpsmodule.py b/nebula/core/situationalawareness/awareness/GPS/gpsmodule.py index b6df44fb7..18fbaf259 100644 --- a/nebula/core/situationalawareness/awareness/GPS/gpsmodule.py +++ b/nebula/core/situationalawareness/awareness/GPS/gpsmodule.py @@ 
-10,6 +10,10 @@ async def start(self): @abstractmethod async def stop(self): pass + + @abstractmethod + async def is_running(self): + pass class GPSModuleException(Exception): pass diff --git a/nebula/core/situationalawareness/awareness/GPS/nebulagps.py b/nebula/core/situationalawareness/awareness/GPS/nebulagps.py index 2951f868d..d6b7d4adf 100644 --- a/nebula/core/situationalawareness/awareness/GPS/nebulagps.py +++ b/nebula/core/situationalawareness/awareness/GPS/nebulagps.py @@ -45,11 +45,14 @@ async def start(self): async def stop(self): """Detiene el servicio de GPS.""" - #logging.info("Stopping NebulaGPS service...") + logging.info("Stopping NebulaGPS service...") self.running = False if self._broadcast_socket: self._broadcast_socket.close() self._broadcast_socket = None + + async def is_running(self): + return self.running async def _send_location_loop(self): """Envia la geolocalizaciΓ³n periΓ³dicamente por broadcast.""" diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 4cc9fd1de..2a22ad6b8 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -3,7 +3,8 @@ from typing import TYPE_CHECKING from nebula.addons.functions import print_msg_box -from nebula.core.situationalawareness.awareness.neighborpolicies.neighborpolicy import factory_NeighborPolicy +from nebula.core.situationalawareness.awareness.sanetwork.sanetwork import SANetwork +from nebula.core.situationalawareness.awareness.satraining.satraining import SATraining from nebula.core.situationalawareness.awareness.GPS.gpsmodule import factory_gpsmodule from nebula.core.utils.locker import Locker @@ -21,7 +22,7 @@ def __init__( topology, ): print_msg_box( - msg=f"Starting Situational Awareness module...\nTopology: {topology}", + msg=f"Starting Situational Awareness module...", indent=2, title="Situational Awareness module", ) @@ -29,7 +30,8 @@ def 
__init__( self._addr = addr self._topology = topology self._node_manager: NodeManager = nodemanager - self._neighbor_policy = factory_NeighborPolicy(topology) + self._situational_awareness_network = SANetwork(self, self.cm, self._addr, self._topology) + self._situational_awareness_trainning = SATraining(self,"hybrid", "fastreboot") self._restructure_process_lock = Locker(name="restructure_process_lock") self._restructure_cooldown = 0 self._gpsmodule = factory_gpsmodule("nebula", self, self._addr) @@ -39,8 +41,8 @@ def nm(self): return self._node_manager @property - def np(self): - return self._neighbor_policy + def san(self): + return self._situational_awareness_network @property def cm(self): @@ -51,152 +53,154 @@ def gps(self): return self._gpsmodule async def init(self): - if not self.nm.is_additional_participant(): - logging.info("Deploying External Connection Service") - await self.cm.start_external_connection_service() - await self.cm.subscribe_beacon_listener(self.beacon_received) - await self.cm.start_beacon() + if not self.is_additional_participant(): await self.gps.start() - else: - logging.info("Deploying External Connection Service | No running") - await self.cm.start_external_connection_service(run_service=False) + await self.san.init() + + def is_additional_participant(self): + return self.nm.is_additional_participant() + + async def experiment_finish(self): + await self.san.experiment_finish() + async def get_geoloc(self): + latitude = self.nm.config.participant["mobility_args"]["latitude"] + longitude = self.nm.config.participant["mobility_args"]["longitude"] + return (latitude, longitude) + + async def mobility_actions(self): + await self.verify_gps_service() + await self.san.module_actions() - logging.info("Building neighbor policy configuration..") - self.np.set_config([ - await self.cm.get_addrs_current_connections(only_direct=True, myself=False), - await self.cm.get_addrs_current_connections(only_direct=False, only_undirected=False, 
myself=False), - self.nm.engine.addr, - self, - ]) + """ ############################### + # GPS SERVICE # + ############################### + """ - async def experiment_finish(self): - await self.cm.stop_external_connection_service() + async def verify_gps_service(self): + if not await self.gps.is_running(): + await self.gps.start() """ ############################### # REESTRUCTURE TOPOLOGY # ############################### """ - def _update_restructure_cooldown(self): - if self._restructure_cooldown: - self._restructure_cooldown = (self._restructure_cooldown + 1) % RESTRUCTURE_COOLDOWN + # def _update_restructure_cooldown(self): + # if self._restructure_cooldown: + # self._restructure_cooldown = (self._restructure_cooldown + 1) % RESTRUCTURE_COOLDOWN - def _restructure_available(self): - if self._restructure_cooldown: - logging.info("Reestructure on cooldown") - return self._restructure_cooldown == 0 + # def _restructure_available(self): + # if self._restructure_cooldown: + # logging.info("Reestructure on cooldown") + # return self._restructure_cooldown == 0 def get_restructure_process_lock(self): - return self._restructure_process_lock + return self.san.get_restructure_process_lock() """ ############################### - # NEIGHBOR POLICY # + # SA NETWORK # ############################### """ + async def register_node(self, node, neighbor=False, remove=False): + await self.san.register_node(self, node, neighbor, remove) + def meet_node(self, node): - if node != self._addr: - logging.info(f"Update nodes known | addr: {node}") - self.np.meet_node(node) + self.san.meet_node(node) def update_neighbors(self, node, remove=False): - self.np.update_neighbors(node, remove) + self.san.update_neighbors(node, remove) if not remove: - self.np.meet_node(node) + self.san.meet_node(node) def get_nodes_known(self, neighbors_too=False, neighbors_only=False): - return self.np.get_nodes_known(neighbors_too, neighbors_only) + return self.san.get_nodes_known(neighbors_too, 
neighbors_only) async def neighbors_left(self): - return len(await self.cm.get_addrs_current_connections(only_direct=True, myself=False)) > 0 + return await self.san.neighbors_left() def accept_connection(self, source, joining=False): - return self.np.accept_connection(source, joining) + return self.san.accept_connection(source, joining) def need_more_neighbors(self): - return self.np.need_more_neighbors() + return self.san.need_more_neighbors() def get_actions(self): - return self.np.get_actions() - - async def get_geoloc(self): - latitude = self.nm.config.participant["mobility_args"]["latitude"] - longitude = self.nm.config.participant["mobility_args"]["longitude"] - return (latitude, longitude) - - """ ############################### - # ROBUSTNESS # - ############################### - """ - - async def beacon_received(self, addr, geoloc): - latitude, longitude = geoloc - self.meet_node(addr) - logging.info(f"Beacon received SAModule, source: {addr}, geolocalization: {latitude},{longitude}") - - async def check_external_connection_service_status(self): - if not await self.cm.is_external_connection_service_running(): - logging.info("πŸ”„ External Service not running | Starting service...") - await self.cm.init_external_connection_service() - await self.cm.subscribe_beacon_listener(self.beacon_received) - await self.cm.start_beacon() - - async def analize_topology_robustness(self): - logging.info("πŸ”„ Analizing node network robustness...") - if not self._restructure_process_lock.locked(): - if not await self.neighbors_left(): - logging.info("No Neighbors left | reconnecting with Federation") - await self.reconnect_to_federation() - elif self.np.need_more_neighbors() and self._restructure_available(): - logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") - self._update_restructure_cooldown() - possible_neighbors = self.np.get_nodes_known(neighbors_too=False) - possible_neighbors = await 
self.cm.apply_restrictions(possible_neighbors) - if not possible_neighbors: - logging.info("All possible neighbors using nodes known are restricted...") - else: - pass - # asyncio.create_task(self.upgrade_connection_robustness(possible_neighbors)) - else: - logging.info("Sufficient Robustness | no actions required") - else: - logging.info("❗️ Reestructure/Reconnecting process already running...") - - async def reconnect_to_federation(self): - self._restructure_process_lock.acquire() - await self.cm.clear_restrictions() - await asyncio.sleep(120) - # If we got some refs, try to reconnect to them - if len(self.np.get_nodes_known()) > 0: - logging.info("Reconnecting | Addrs availables") - await self.nm.start_late_connection_process( - connected=False, msg_type="discover_nodes", addrs_known=self.np.get_nodes_known() - ) - else: - logging.info("Reconnecting | NO Addrs availables") - await self.nm.start_late_connection_process(connected=False, msg_type="discover_nodes") - self._restructure_process_lock.release() - - async def upgrade_connection_robustness(self, possible_neighbors): - self._restructure_process_lock.acquire() - # addrs_to_connect = self.neighbor_policy.get_nodes_known(neighbors_too=False) - # If we got some refs, try to connect to them - if len(possible_neighbors) > 0: - logging.info(f"Reestructuring | Addrs availables | addr list: {possible_neighbors}") - await self.nm.start_late_connection_process( - connected=True, msg_type="discover_nodes", addrs_known=possible_neighbors - ) - else: - logging.info("Reestructuring | NO Addrs availables") - await self.nm.start_late_connection_process(connected=True, msg_type="discover_nodes") - self._restructure_process_lock.release() - - async def stop_connections_with_federation(self): - await asyncio.sleep(400) - logging.info("### DISCONNECTING FROM FEDERATON ###") - neighbors = self.np.get_nodes_known(neighbors_only=True) - for n in neighbors: - await self.cm.add_to_blacklist(n) - for n in neighbors: - await 
self.cm.disconnect(n, mutual_disconnection=False, forced=True) + return self.san.get_actions() + + # """ ############################### + # # ROBUSTNESS # + # ############################### + # """ + + # async def beacon_received(self, addr, geoloc): + # latitude, longitude = geoloc + # self.meet_node(addr) + # logging.info(f"Beacon received SAModule, source: {addr}, geolocalization: {latitude},{longitude}") + + # async def check_external_connection_service_status(self): + # if not await self.cm.is_external_connection_service_running(): + # logging.info("πŸ”„ External Service not running | Starting service...") + # await self.cm.init_external_connection_service() + # await self.cm.subscribe_beacon_listener(self.beacon_received) + # await self.cm.start_beacon() + + # async def analize_topology_robustness(self): + # logging.info("πŸ”„ Analizing node network robustness...") + # if not self._restructure_process_lock.locked(): + # if not await self.neighbors_left(): + # logging.info("No Neighbors left | reconnecting with Federation") + # await self.reconnect_to_federation() + # elif self.np.need_more_neighbors() and self._restructure_available(): + # logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") + # self._update_restructure_cooldown() + # possible_neighbors = self.np.get_nodes_known(neighbors_too=False) + # possible_neighbors = await self.cm.apply_restrictions(possible_neighbors) + # if not possible_neighbors: + # logging.info("All possible neighbors using nodes known are restricted...") + # else: + # pass + # # asyncio.create_task(self.upgrade_connection_robustness(possible_neighbors)) + # else: + # logging.info("Sufficient Robustness | no actions required") + # else: + # logging.info("❗️ Reestructure/Reconnecting process already running...") + + # async def reconnect_to_federation(self): + # self._restructure_process_lock.acquire() + # await self.cm.clear_restrictions() + # await asyncio.sleep(120) + # # If we got 
some refs, try to reconnect to them + # if len(self.np.get_nodes_known()) > 0: + # logging.info("Reconnecting | Addrs availables") + # await self.nm.start_late_connection_process( + # connected=False, msg_type="discover_nodes", addrs_known=self.np.get_nodes_known() + # ) + # else: + # logging.info("Reconnecting | NO Addrs availables") + # await self.nm.start_late_connection_process(connected=False, msg_type="discover_nodes") + # self._restructure_process_lock.release() + + # async def upgrade_connection_robustness(self, possible_neighbors): + # self._restructure_process_lock.acquire() + # # addrs_to_connect = self.neighbor_policy.get_nodes_known(neighbors_too=False) + # # If we got some refs, try to connect to them + # if len(possible_neighbors) > 0: + # logging.info(f"Reestructuring | Addrs availables | addr list: {possible_neighbors}") + # await self.nm.start_late_connection_process( + # connected=True, msg_type="discover_nodes", addrs_known=possible_neighbors + # ) + # else: + # logging.info("Reestructuring | NO Addrs availables") + # await self.nm.start_late_connection_process(connected=True, msg_type="discover_nodes") + # self._restructure_process_lock.release() + + # async def stop_connections_with_federation(self): + # await asyncio.sleep(400) + # logging.info("### DISCONNECTING FROM FEDERATON ###") + # neighbors = self.np.get_nodes_known(neighbors_only=True) + # for n in neighbors: + # await self.cm.add_to_blacklist(n) + # for n in neighbors: + # await self.cm.disconnect(n, mutual_disconnection=False, forced=True) diff --git a/nebula/core/situationalawareness/awareness/neighborpolicies/__init__.py b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/__init__.py similarity index 100% rename from nebula/core/situationalawareness/awareness/neighborpolicies/__init__.py rename to nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/__init__.py diff --git 
a/nebula/core/situationalawareness/awareness/neighborpolicies/fcneighborpolicy.py b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/fcneighborpolicy.py similarity index 89% rename from nebula/core/situationalawareness/awareness/neighborpolicies/fcneighborpolicy.py rename to nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/fcneighborpolicy.py index 17ded13e5..73d559b97 100644 --- a/nebula/core/situationalawareness/awareness/neighborpolicies/fcneighborpolicy.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/fcneighborpolicy.py @@ -1,5 +1,6 @@ -from nebula.core.situationalawareness.awareness.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.situationalawareness.awareness.sanetwork.neighborpolicies.neighborpolicy import NeighborPolicy from nebula.core.utils.locker import Locker +import logging class FCNeighborPolicy(NeighborPolicy): @@ -29,6 +30,7 @@ def set_config(self, config): config[2] -> self addr config[3] -> NodeManager reference """ + logging.info("Initializing Fully-Connected Topology Neighbor Policy") self.neighbors_lock.acquire() self.neighbors = config[0] self.neighbors_lock.release() @@ -51,6 +53,7 @@ def meet_node(self, node): """ self.nodes_known_lock.acquire() if node != self.addr: + if not node in self.nodes_known: logging.info(f"Update nodes known | addr: {node}") self.nodes_known.add(node) self.nodes_known_lock.release() @@ -104,8 +107,10 @@ def update_neighbors(self, node, remove=False): if remove: try: self.neighbors.remove(node) + logging.info(f"Remove neighbor | addr: {node}") except KeyError: pass else: self.neighbors.add(node) + logging.info(f"Add neighbor | addr: {node}") self.neighbors_lock.release() \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/neighborpolicies/idleneighborpolicy.py b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/idleneighborpolicy.py similarity index 89% rename from 
nebula/core/situationalawareness/awareness/neighborpolicies/idleneighborpolicy.py rename to nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/idleneighborpolicy.py index 02292aee9..565e63439 100644 --- a/nebula/core/situationalawareness/awareness/neighborpolicies/idleneighborpolicy.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/idleneighborpolicy.py @@ -1,5 +1,6 @@ -from nebula.core.situationalawareness.awareness.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.situationalawareness.awareness.sanetwork.neighborpolicies.neighborpolicy import NeighborPolicy from nebula.core.utils.locker import Locker +import logging class IDLENeighborPolicy(NeighborPolicy): @@ -29,6 +30,7 @@ def set_config(self, config): config[2] -> self addr config[3] -> NodeManager reference """ + logging.info("Initializing Random Topology Neighbor Policy") self.neighbors_lock.acquire() self.neighbors = config[0] self.neighbors_lock.release() @@ -51,6 +53,7 @@ def meet_node(self, node): """ self.nodes_known_lock.acquire() if node != self.addr: + if not node in self.nodes_known: logging.info(f"Update nodes known | addr: {node}") self.nodes_known.add(node) self.nodes_known_lock.release() @@ -104,8 +107,10 @@ def update_neighbors(self, node, remove=False): if remove: try: self.neighbors.remove(node) + logging.info(f"Remove neighbor | addr: {node}") except KeyError: pass else: self.neighbors.add(node) + logging.info(f"Add neighbor | addr: {node}") self.neighbors_lock.release() \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/neighborpolicies/neighborpolicy.py b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/neighborpolicy.py similarity index 68% rename from nebula/core/situationalawareness/awareness/neighborpolicies/neighborpolicy.py rename to nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/neighborpolicy.py index c5e598b3e..df2104c27 100644 --- 
a/nebula/core/situationalawareness/awareness/neighborpolicies/neighborpolicy.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/neighborpolicy.py @@ -36,10 +36,10 @@ def update_neighbors(self, node, remove=False): pass def factory_NeighborPolicy(topology) -> NeighborPolicy: - from nebula.core.situationalawareness.awareness.neighborpolicies.idleneighborpolicy import IDLENeighborPolicy - from nebula.core.situationalawareness.awareness.neighborpolicies.fcneighborpolicy import FCNeighborPolicy - from nebula.core.situationalawareness.awareness.neighborpolicies.ringneighborpolicy import RINGNeighborPolicy - from nebula.core.situationalawareness.awareness.neighborpolicies.starneighborpolicy import STARNeighborPolicy + from nebula.core.situationalawareness.awareness.sanetwork.neighborpolicies.idleneighborpolicy import IDLENeighborPolicy + from nebula.core.situationalawareness.awareness.sanetwork.neighborpolicies.fcneighborpolicy import FCNeighborPolicy + from nebula.core.situationalawareness.awareness.sanetwork.neighborpolicies.ringneighborpolicy import RINGNeighborPolicy + from nebula.core.situationalawareness.awareness.sanetwork.neighborpolicies.starneighborpolicy import STARNeighborPolicy options = { "random": IDLENeighborPolicy, # default value diff --git a/nebula/core/situationalawareness/awareness/neighborpolicies/ringneighborpolicy.py b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/ringneighborpolicy.py similarity index 94% rename from nebula/core/situationalawareness/awareness/neighborpolicies/ringneighborpolicy.py rename to nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/ringneighborpolicy.py index 8db66e1f4..3154aae07 100644 --- a/nebula/core/situationalawareness/awareness/neighborpolicies/ringneighborpolicy.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/ringneighborpolicy.py @@ -1,6 +1,7 @@ -from 
nebula.core.situationalawareness.awareness.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.situationalawareness.awareness.sanetwork.neighborpolicies.neighborpolicy import NeighborPolicy from nebula.core.utils.locker import Locker import random +import logging class RINGNeighborPolicy(NeighborPolicy): @@ -25,6 +26,7 @@ def set_config(self, config): config[1] -> list of nodes known on federation config[2] -> self.addr """ + logging.info("Initializing Ring Topology Neighbor Policy") self.neighbors_lock.acquire() self.neighbors = config[0] self.neighbors_lock.release() diff --git a/nebula/core/situationalawareness/awareness/neighborpolicies/starneighborpolicy.py b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/starneighborpolicy.py similarity index 96% rename from nebula/core/situationalawareness/awareness/neighborpolicies/starneighborpolicy.py rename to nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/starneighborpolicy.py index 87931d2c7..1eda9ba91 100644 --- a/nebula/core/situationalawareness/awareness/neighborpolicies/starneighborpolicy.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/starneighborpolicy.py @@ -1,4 +1,4 @@ -from nebula.core.situationalawareness.awareness.neighborpolicies.neighborpolicy import NeighborPolicy +from nebula.core.situationalawareness.awareness.sanetwork.neighborpolicies.neighborpolicy import NeighborPolicy from nebula.core.utils.locker import Locker class STARNeighborPolicy(NeighborPolicy): diff --git a/nebula/core/situationalawareness/awareness/networkoptimization/connectionoptimizer.py b/nebula/core/situationalawareness/awareness/sanetwork/networkoptimization/connectionoptimizer.py similarity index 100% rename from nebula/core/situationalawareness/awareness/networkoptimization/connectionoptimizer.py rename to nebula/core/situationalawareness/awareness/sanetwork/networkoptimization/connectionoptimizer.py diff --git 
a/nebula/core/situationalawareness/awareness/networkoptimization/networkoptimizer.py b/nebula/core/situationalawareness/awareness/sanetwork/networkoptimization/networkoptimizer.py similarity index 100% rename from nebula/core/situationalawareness/awareness/networkoptimization/networkoptimizer.py rename to nebula/core/situationalawareness/awareness/sanetwork/networkoptimization/networkoptimizer.py diff --git a/nebula/core/situationalawareness/awareness/networkoptimization/timergenerator.py b/nebula/core/situationalawareness/awareness/sanetwork/networkoptimization/timergenerator.py similarity index 100% rename from nebula/core/situationalawareness/awareness/networkoptimization/timergenerator.py rename to nebula/core/situationalawareness/awareness/sanetwork/networkoptimization/timergenerator.py diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py new file mode 100644 index 000000000..6ce28e4fc --- /dev/null +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -0,0 +1,200 @@ +import asyncio +import logging +from nebula.core.utils.locker import Locker +from nebula.core.situationalawareness.awareness.sanetwork.neighborpolicies.neighborpolicy import factory_NeighborPolicy +from nebula.addons.functions import print_msg_box +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from nebula.core.network.communications import CommunicationsManager + from nebula.core.situationalawareness.awareness.samodule import SAModule + +RESTRUCTURE_COOLDOWN = 5 + +class SANetwork(): + def __init__( + self, + sam: "SAModule", + communication_manager: "CommunicationsManager", + addr, + topology, + strict_topology=True + ): + print_msg_box( + msg=f"Starting Network SA\nTopology: {topology}\nStrict: {strict_topology}", + indent=2, + title="Network SA module", + ) + self._sam = sam + self._cm = communication_manager + self._addr = addr + self._topology = topology + 
self._strict_topology = strict_topology + self._neighbor_policy = factory_NeighborPolicy(topology) + self._restructure_process_lock = Locker(name="restructure_process_lock") + self._restructure_cooldown = 0 + + @property + def sam(self): + return self._sam + + @property + def cm(self): + return self._cm + + @property + def np(self): + return self._neighbor_policy + + async def init(self): + if not self.sam.is_additional_participant(): + logging.info("Deploying External Connection Service") + await self.cm.start_external_connection_service() + await self.cm.subscribe_beacon_listener(self.beacon_received) + await self.cm.start_beacon() + else: + logging.info("Deploying External Connection Service | No running") + await self.cm.start_external_connection_service(run_service=False) + + + logging.info("Building neighbor policy configuration..") + self.np.set_config([ + await self.cm.get_addrs_current_connections(only_direct=True, myself=False), + await self.cm.get_addrs_current_connections(only_direct=False, only_undirected=False, myself=False), + self._addr, + self, + ]) + + async def module_actions(self): + await self.check_external_connection_service_status() + await self.analize_topology_robustness() + + + """ ############################### + # NEIGHBOR POLICY # + ############################### + """ + async def register_node(self, node, neighbor=False, remove=False): + if not neighbor: + self.meet_node(node) + else: + self.update_neighbors(node, remove) + + def meet_node(self, node): + if node != self._addr: + self.np.meet_node(node) + + def update_neighbors(self, node, remove=False): + self.np.update_neighbors(node, remove) + if not remove: + self.np.meet_node(node) + + def get_nodes_known(self, neighbors_too=False, neighbors_only=False): + return self.np.get_nodes_known(neighbors_too, neighbors_only) + + async def neighbors_left(self): + return len(await self.cm.get_addrs_current_connections(only_direct=True, myself=False)) > 0 + + def accept_connection(self, 
source, joining=False): + return self.np.accept_connection(source, joining) + + def need_more_neighbors(self): + return self.np.need_more_neighbors() + + def get_actions(self): + return self.np.get_actions() + + """ ############################### + # EXTERNAL CONNECTION SERVICE # + ############################### + """ + + async def check_external_connection_service_status(self): + if not await self.cm.is_external_connection_service_running(): + logging.info("πŸ”„ External Service not running | Starting service...") + await self.cm.init_external_connection_service() + await self.cm.subscribe_beacon_listener(self.beacon_received) + await self.cm.start_beacon() + + async def experiment_finish(self): + await self.cm.stop_external_connection_service() + + async def beacon_received(self, addr, geoloc): + latitude, longitude = geoloc + self.meet_node(addr) + logging.info(f"Beacon received SANetwork, source: {addr}, geolocalization: {latitude},{longitude}") + + """ ############################### + # REESTRUCTURE TOPOLOGY # + ############################### + """ + + def _update_restructure_cooldown(self): + if self._restructure_cooldown: + self._restructure_cooldown = (self._restructure_cooldown + 1) % RESTRUCTURE_COOLDOWN + + def _restructure_available(self): + if self._restructure_cooldown: + logging.info("Reestructure on cooldown") + return self._restructure_cooldown == 0 + + def get_restructure_process_lock(self): + return self._restructure_process_lock + + async def analize_topology_robustness(self): + logging.info("πŸ”„ Analizing node network robustness...") + if not self._restructure_process_lock.locked(): + if not await self.neighbors_left(): + logging.info("No Neighbors left | reconnecting with Federation") + await self.reconnect_to_federation() + elif self.np.need_more_neighbors() and self._restructure_available(): + logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") + self._update_restructure_cooldown() + 
possible_neighbors = self.np.get_nodes_known(neighbors_too=False) + possible_neighbors = await self.cm.apply_restrictions(possible_neighbors) + if not possible_neighbors: + logging.info("All possible neighbors using nodes known are restricted...") + else: + pass + # asyncio.create_task(self.upgrade_connection_robustness(possible_neighbors)) + else: + logging.info("Sufficient Robustness | no actions required") + else: + logging.info("❗️ Reestructure/Reconnecting process already running...") + + async def reconnect_to_federation(self): + self._restructure_process_lock.acquire() + await self.cm.clear_restrictions() + await asyncio.sleep(120) + # If we got some refs, try to reconnect to them + if len(self.np.get_nodes_known()) > 0: + logging.info("Reconnecting | Addrs availables") + await self.sam.nm.start_late_connection_process( + connected=False, msg_type="discover_nodes", addrs_known=self.np.get_nodes_known() + ) + else: + logging.info("Reconnecting | NO Addrs availables") + await self.sam.nm.start_late_connection_process(connected=False, msg_type="discover_nodes") + self._restructure_process_lock.release() + + async def upgrade_connection_robustness(self, possible_neighbors): + self._restructure_process_lock.acquire() + # addrs_to_connect = self.neighbor_policy.get_nodes_known(neighbors_too=False) + # If we got some refs, try to connect to them + if len(possible_neighbors) > 0: + logging.info(f"Reestructuring | Addrs availables | addr list: {possible_neighbors}") + await self.sam.nm.start_late_connection_process( + connected=True, msg_type="discover_nodes", addrs_known=possible_neighbors + ) + else: + logging.info("Reestructuring | NO Addrs availables") + await self.sam.nm.start_late_connection_process(connected=True, msg_type="discover_nodes") + self._restructure_process_lock.release() + + async def stop_connections_with_federation(self): + await asyncio.sleep(400) + logging.info("### DISCONNECTING FROM FEDERATON ###") + neighbors = 
self.np.get_nodes_known(neighbors_only=True) + for n in neighbors: + await self.cm.add_to_blacklist(n) + for n in neighbors: + await self.cm.disconnect(n, mutual_disconnection=False, forced=True) \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/satraining.py b/nebula/core/situationalawareness/awareness/satraining/satraining.py new file mode 100644 index 000000000..ca00a1293 --- /dev/null +++ b/nebula/core/situationalawareness/awareness/satraining/satraining.py @@ -0,0 +1,25 @@ +import asyncio +import logging +from nebula.core.utils.locker import Locker +from nebula.addons.functions import print_msg_box +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from nebula.core.situationalawareness.awareness.samodule import SAModule + +RESTRUCTURE_COOLDOWN = 5 + +class SATraining(): + def __init__( + self, + sam: "SAModule", + training_policy, + weight_strategies + ): + print_msg_box( + msg=f"Starting Training SA\nTraining policy: {training_policy}\nWeight strategies: {weight_strategies}", + indent=2, + title="Training SA module", + ) + self._sam = sam + self._trainning_policy = training_policy + self._weight_strategies = weight_strategies \ No newline at end of file diff --git a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py index 3e3c58557..69caf6bbd 100644 --- a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -30,7 +30,7 @@ def __init__( self._aditional_participant = aditional_participant self.topology = topology print_msg_box( - msg=f"Starting NodeManager module...\nTopology: {self.topology}", indent=2, title="NodeManager module" + msg=f"Starting NodeManager module...", indent=2, title="NodeManager module" ) logging.info("🌐 Initializing Node Manager") self._engine = engine @@ -109,6 +109,12 @@ async def set_configs(self): if self._fast_reboot_status: self._fastreboot = FastReboot(self) + async def get_geoloc(self): 
+ return await self.sam.get_geoloc() + + async def mobility_actions(self): + await self.sam.mobility_actions() + async def experiment_finish(self): await self.sam.experiment_finish() @@ -123,7 +129,7 @@ async def update_learning_rate(self, new_lr): async def register_late_neighbor(self, addr, joinning_federation=False): logging.info(f"Registering | late neighbor: {addr}, joining: {joinning_federation}") - self.meet_node(addr) + self.sam.meet_node(addr) await self.update_neighbors(addr) if joinning_federation: if self.fast_reboot_on(): @@ -193,7 +199,7 @@ def get_actions(self): return self.sam.get_actions() async def update_neighbors(self, node, remove=False): - logging.info(f"Update neighbor | node addr: {node} | remove: {remove}") + #logging.info(f"Update neighbor | node addr: {node} | remove: {remove}") await self._update_neighbors_lock.acquire_async() self.sam.update_neighbors(node, remove) if remove: @@ -298,7 +304,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove self.candidate_selector.remove_candidates() if not self._desc_done: #TODO remove self._desc_done = True - asyncio.create_task(self.sam.stop_connections_with_federation()) + asyncio.create_task(self.sam.san.stop_connections_with_federation()) # if no candidates, repeat process else: logging.info("❗️ No Candidates found...") @@ -308,16 +314,5 @@ async def start_late_connection_process(self, connected=False, msg_type="discove logging.info("❗️ repeating process...") await self.start_late_connection_process(connected, msg_type, addrs_known) - """ - ############################## - # ROBUSTNESS # - ############################## - """ - async def get_geoloc(self): - return await self.sam.get_geoloc() - - async def mobility_actions(self): - await self.sam.check_external_connection_service_status() - await self.sam.analize_topology_robustness() From bda0596c402eece71bcde6723a3b19f4003303eb Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 26 Feb 2025 12:12:46 +0100 
Subject: [PATCH 116/233] fix additional node network conditions --- nebula/addons/mobility.py | 2 +- nebula/core/engine.py | 4 + nebula/core/network/communications.py | 8 +- .../awareness/GPS/gpsmodule.py | 4 + .../awareness/GPS/nebulagps.py | 5 + .../awareness/samodule.py | 95 ++----------------- .../awareness/sanetwork/sanetwork.py | 2 +- .../core/situationalawareness/nodemanager.py | 3 + 8 files changed, 32 insertions(+), 91 deletions(-) diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index 0a516f591..97baa2062 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -132,7 +132,7 @@ async def run_mobility(self): """ if not self.mobility: return - await asyncio.sleep(self.grace_time) + #await asyncio.sleep(self.grace_time) while True: await self.change_geo_location() #await self.change_connections_based_on_distance() diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 3d6ae0a4c..8258ef3aa 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -576,6 +576,10 @@ async def trigger_event(self, message_event): async def get_geoloc(self): return await self.nm.get_geoloc() + + async def calculate_distance(self, latitude, longitude): + return await self.nm.calculate_distance(latitude, longitude) + async def _aditional_node_start(self): self.update_sinchronized_status(False) diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 9296c208b..5dc63d539 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -77,7 +77,7 @@ def __init__(self, engine: "Engine"): # Network simulator service to deplay realistic network conditions refresh_conditions_interval = 5 - self._network_simulator = factory_network_simulator("nebula", self, refresh_conditions_interval, "eth0", verbose=False) + self._network_simulator = factory_network_simulator("nebula", self, refresh_conditions_interval, "eth0", verbose=True) @property def engine(self): @@ -406,9 
+406,13 @@ async def update_geolocalization(self, geoloc : dict): #logging.info("Update geolocs to simulate network conditions") for source in geoloc.keys(): latitude, longitude = geoloc[source] - #logging.info(f"Update geolocs for source: {source}, geoloc: ({latitude},{longitude})") + logging.info(f"Update geolocs for source: {source}, geoloc: ({latitude},{longitude})") if source in self.connections: self.connections[source].update_geolocation(latitude, longitude) + else: # When not connected to device yet + logging.info(f"Update conditions for not already connected source: {source})") + distance = await self.engine.calculate_distance(latitude, longitude) + await self.ns.set_network_conditions(source, distance) def get_connections_lock(self): return self.connections_lock diff --git a/nebula/core/situationalawareness/awareness/GPS/gpsmodule.py b/nebula/core/situationalawareness/awareness/GPS/gpsmodule.py index 18fbaf259..d80d33572 100644 --- a/nebula/core/situationalawareness/awareness/GPS/gpsmodule.py +++ b/nebula/core/situationalawareness/awareness/GPS/gpsmodule.py @@ -14,6 +14,10 @@ async def stop(self): @abstractmethod async def is_running(self): pass + + @abstractmethod + async def calculate_distance(self, self_lat, self_long, other_lat, other_long): + pass class GPSModuleException(Exception): pass diff --git a/nebula/core/situationalawareness/awareness/GPS/nebulagps.py b/nebula/core/situationalawareness/awareness/GPS/nebulagps.py index d6b7d4adf..c6d812fb0 100644 --- a/nebula/core/situationalawareness/awareness/GPS/nebulagps.py +++ b/nebula/core/situationalawareness/awareness/GPS/nebulagps.py @@ -3,6 +3,7 @@ from nebula.core.situationalawareness.awareness.GPS.gpsmodule import GPSModule import socket from nebula.core.utils.locker import Locker +from geopy import distance from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -53,6 +54,10 @@ async def stop(self): async def is_running(self): return self.running + + async def calculate_distance(self, self_lat, 
self_long, other_lat, other_long): + distance_m = distance.distance((self_lat, self_long), (other_lat, other_long)).m + return distance_m async def _send_location_loop(self): """Envia la geolocalizaciΓ³n periΓ³dicamente por broadcast.""" diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 2a22ad6b8..95227b717 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -31,7 +31,7 @@ def __init__( self._topology = topology self._node_manager: NodeManager = nodemanager self._situational_awareness_network = SANetwork(self, self.cm, self._addr, self._topology) - self._situational_awareness_trainning = SATraining(self,"hybrid", "fastreboot") + self._situational_awareness_trainning = SATraining(self, "hybrid", "fastreboot") self._restructure_process_lock = Locker(name="restructure_process_lock") self._restructure_cooldown = 0 self._gpsmodule = factory_gpsmodule("nebula", self, self._addr) @@ -53,8 +53,8 @@ def gps(self): return self._gpsmodule async def init(self): - if not self.is_additional_participant(): - await self.gps.start() + #if not self.is_additional_participant(): + await self.gps.start() await self.san.init() def is_additional_participant(self): @@ -76,6 +76,10 @@ async def mobility_actions(self): # GPS SERVICE # ############################### """ + + async def calculate_distance(self, other_latitude, other_longitude): + self_lat, self_long = await self.get_geoloc() + return await self.gps.calculate_distance(self_lat, self_long, other_latitude, other_longitude) async def verify_gps_service(self): if not await self.gps.is_running(): @@ -86,15 +90,6 @@ async def verify_gps_service(self): ############################### """ - # def _update_restructure_cooldown(self): - # if self._restructure_cooldown: - # self._restructure_cooldown = (self._restructure_cooldown + 1) % RESTRUCTURE_COOLDOWN - - # def 
_restructure_available(self): - # if self._restructure_cooldown: - # logging.info("Reestructure on cooldown") - # return self._restructure_cooldown == 0 - def get_restructure_process_lock(self): return self.san.get_restructure_process_lock() @@ -129,78 +124,4 @@ def need_more_neighbors(self): def get_actions(self): return self.san.get_actions() - # """ ############################### - # # ROBUSTNESS # - # ############################### - # """ - - # async def beacon_received(self, addr, geoloc): - # latitude, longitude = geoloc - # self.meet_node(addr) - # logging.info(f"Beacon received SAModule, source: {addr}, geolocalization: {latitude},{longitude}") - - # async def check_external_connection_service_status(self): - # if not await self.cm.is_external_connection_service_running(): - # logging.info("πŸ”„ External Service not running | Starting service...") - # await self.cm.init_external_connection_service() - # await self.cm.subscribe_beacon_listener(self.beacon_received) - # await self.cm.start_beacon() - - # async def analize_topology_robustness(self): - # logging.info("πŸ”„ Analizing node network robustness...") - # if not self._restructure_process_lock.locked(): - # if not await self.neighbors_left(): - # logging.info("No Neighbors left | reconnecting with Federation") - # await self.reconnect_to_federation() - # elif self.np.need_more_neighbors() and self._restructure_available(): - # logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") - # self._update_restructure_cooldown() - # possible_neighbors = self.np.get_nodes_known(neighbors_too=False) - # possible_neighbors = await self.cm.apply_restrictions(possible_neighbors) - # if not possible_neighbors: - # logging.info("All possible neighbors using nodes known are restricted...") - # else: - # pass - # # asyncio.create_task(self.upgrade_connection_robustness(possible_neighbors)) - # else: - # logging.info("Sufficient Robustness | no actions required") - # else: - # 
logging.info("❗️ Reestructure/Reconnecting process already running...") - - # async def reconnect_to_federation(self): - # self._restructure_process_lock.acquire() - # await self.cm.clear_restrictions() - # await asyncio.sleep(120) - # # If we got some refs, try to reconnect to them - # if len(self.np.get_nodes_known()) > 0: - # logging.info("Reconnecting | Addrs availables") - # await self.nm.start_late_connection_process( - # connected=False, msg_type="discover_nodes", addrs_known=self.np.get_nodes_known() - # ) - # else: - # logging.info("Reconnecting | NO Addrs availables") - # await self.nm.start_late_connection_process(connected=False, msg_type="discover_nodes") - # self._restructure_process_lock.release() - - # async def upgrade_connection_robustness(self, possible_neighbors): - # self._restructure_process_lock.acquire() - # # addrs_to_connect = self.neighbor_policy.get_nodes_known(neighbors_too=False) - # # If we got some refs, try to connect to them - # if len(possible_neighbors) > 0: - # logging.info(f"Reestructuring | Addrs availables | addr list: {possible_neighbors}") - # await self.nm.start_late_connection_process( - # connected=True, msg_type="discover_nodes", addrs_known=possible_neighbors - # ) - # else: - # logging.info("Reestructuring | NO Addrs availables") - # await self.nm.start_late_connection_process(connected=True, msg_type="discover_nodes") - # self._restructure_process_lock.release() - - # async def stop_connections_with_federation(self): - # await asyncio.sleep(400) - # logging.info("### DISCONNECTING FROM FEDERATON ###") - # neighbors = self.np.get_nodes_known(neighbors_only=True) - # for n in neighbors: - # await self.cm.add_to_blacklist(n) - # for n in neighbors: - # await self.cm.disconnect(n, mutual_disconnection=False, forced=True) + diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index 6ce28e4fc..f5c25a340 100644 --- 
a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -191,7 +191,7 @@ async def upgrade_connection_robustness(self, possible_neighbors): self._restructure_process_lock.release() async def stop_connections_with_federation(self): - await asyncio.sleep(400) + await asyncio.sleep(200) logging.info("### DISCONNECTING FROM FEDERATON ###") neighbors = self.np.get_nodes_known(neighbors_only=True) for n in neighbors: diff --git a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py index 69caf6bbd..f49a379c1 100644 --- a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -111,6 +111,9 @@ async def set_configs(self): async def get_geoloc(self): return await self.sam.get_geoloc() + + async def calculate_distance(self, latitude, longitude): + return await self.sam.calculate_distance(latitude, longitude) async def mobility_actions(self): await self.sam.mobility_actions() From 387ec6d0ed8e06b4223ef5828667732d6bb08477 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 26 Feb 2025 12:57:30 +0100 Subject: [PATCH 117/233] fix network conditions fist attemp --- nebula/core/network/communications.py | 11 +++--- .../nebulanetworksimulator.py | 35 ++++++++++--------- .../frontend/config/participant.json.example | 2 +- nebula/node.py | 3 +- 4 files changed, 28 insertions(+), 23 deletions(-) diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 5dc63d539..8cf4e9e47 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -406,13 +406,14 @@ async def update_geolocalization(self, geoloc : dict): #logging.info("Update geolocs to simulate network conditions") for source in geoloc.keys(): latitude, longitude = geoloc[source] - logging.info(f"Update geolocs for source: {source}, geoloc: ({latitude},{longitude})") + 
#logging.info(f"Update geolocs for source: {source}, geoloc: ({latitude},{longitude})") if source in self.connections: self.connections[source].update_geolocation(latitude, longitude) else: # When not connected to device yet - logging.info(f"Update conditions for not already connected source: {source})") - distance = await self.engine.calculate_distance(latitude, longitude) - await self.ns.set_network_conditions(source, distance) + #logging.info(f"Update conditions for not already connected source: {source})") + if self.config.participant["network_args"]["simulation"]: + distance = await self.engine.calculate_distance(latitude, longitude) + await self.ns.set_network_conditions(source, distance) def get_connections_lock(self): return self.connections_lock @@ -605,7 +606,7 @@ async def deploy_additional_services(self): # self._generate_network_conditions() await self._forwarder.start() if self.config.participant["mobility_args"]["mobility"]: - if False: + if self.config.participant["network_args"]["simulation"]: await self.ns.start() # await self._discoverer.start() # await self._health.start() diff --git a/nebula/core/network/networksimulation/nebulanetworksimulator.py b/nebula/core/network/networksimulation/nebulanetworksimulator.py index 3e6058ad2..e99a8315d 100644 --- a/nebula/core/network/networksimulation/nebulanetworksimulator.py +++ b/nebula/core/network/networksimulation/nebulanetworksimulator.py @@ -55,7 +55,8 @@ async def _change_network_conditions_based_on_distances(self): addr_ip = addr.split(":")[0] self._set_network_condition_for_addr(self._node_interface, addr_ip, conditions["bandwidth"], conditions["delay"]) self._set_network_condition_for_multicast(self._node_interface, addr_ip, self.IP_MULTICAST, conditions["bandwidth"], conditions["delay"]) - self._current_network_conditions[addr] = conditions + async with self._network_conditions_lock: + self._current_network_conditions[addr] = conditions else: logging.info("network conditions havent changed 
since last time") except KeyError: @@ -67,22 +68,24 @@ async def set_thresholds(self, thresholds : dict): async with self._network_conditions_lock: self._network_conditions = thresholds - async def set_network_conditions(self, dest_addr, distance): + async def set_network_conditions(self, dest_addr : str, distance): conditions = await self._calculate_network_conditions(distance) - self._set_network_condition_for_addr(self, - interface=self._node_interface, - network=dest_addr, - bandwidth=conditions["bandwidth"], - delay=conditions["delay"] - ) - - self._set_network_condition_for_multicast(self, - interface=self._node_interface, - src_network=dest_addr, - dst_network=self.IP_MULTICAST, - bandwidth=conditions["bandwidth"], - delay=conditions["delay"] - ) + addr_ip = dest_addr.split(":")[0] + if (dest_addr not in self._current_network_conditions or self._current_network_conditions[dest_addr] != conditions): + self._set_network_condition_for_addr(interface=self._node_interface, + network=addr_ip, + bandwidth=conditions["bandwidth"], + delay=conditions["delay"] + ) + + self._set_network_condition_for_multicast(interface=self._node_interface, + src_network=addr_ip, + dst_network=self.IP_MULTICAST, + bandwidth=conditions["bandwidth"], + delay=conditions["delay"] + ) + async with self._network_conditions_lock: + self._current_network_conditions[dest_addr] = conditions def _set_network_condition_for_addr( self, diff --git a/nebula/frontend/config/participant.json.example b/nebula/frontend/config/participant.json.example index 3e2b916b9..db3ec0d34 100755 --- a/nebula/frontend/config/participant.json.example +++ b/nebula/frontend/config/participant.json.example @@ -41,7 +41,7 @@ "addr": "", "neighbors": "", "interface": "eth0", - "simulation": true, + "simulation": false, "bandwidth": "5Gbps", "delay": "0ms", "delay-distro": "0ms", diff --git a/nebula/node.py b/nebula/node.py index 11b1b07c7..5e58c53b2 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -309,7 +309,8 @@ def 
randomize_value(value, variability): logging.info(f"Waiting for round {additional_node_round} to start") logging.info("Waiting time to start finding federation") - time.sleep(150) + #time.sleep(150) + await asyncio.sleep(150) # time.sleep(6000) # DEBUG purposes # import requests From e552734c02db971f62bb4124047fa9b529950f8d Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 26 Feb 2025 15:24:27 +0100 Subject: [PATCH 118/233] feature training policy interface --- .../awareness/sanetwork/sanetwork.py | 8 ++--- .../awareness/satraining/satraining.py | 3 +- .../trainingpolicy/bpstrainingpolicy.py | 6 ++++ .../trainingpolicy/htstrainingpolicy.py | 6 ++++ .../trainingpolicy/qdstrainingpolicy.py | 6 ++++ .../trainingpolicy/sostrainingpolicy.py | 6 ++++ .../trainingpolicy/trainingpolicy.py | 29 +++++++++++++++++++ 7 files changed, 59 insertions(+), 5 deletions(-) create mode 100644 nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py create mode 100644 nebula/core/situationalawareness/awareness/satraining/trainingpolicy/htstrainingpolicy.py create mode 100644 nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py create mode 100644 nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py create mode 100644 nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index f5c25a340..c0e032f99 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -65,8 +65,8 @@ async def init(self): ]) async def module_actions(self): - await self.check_external_connection_service_status() - await self.analize_topology_robustness() + await self._check_external_connection_service_status() + await 
self._analize_topology_robustness() """ ############################### @@ -108,7 +108,7 @@ def get_actions(self): ############################### """ - async def check_external_connection_service_status(self): + async def _check_external_connection_service_status(self): if not await self.cm.is_external_connection_service_running(): logging.info("πŸ”„ External Service not running | Starting service...") await self.cm.init_external_connection_service() @@ -140,7 +140,7 @@ def _restructure_available(self): def get_restructure_process_lock(self): return self._restructure_process_lock - async def analize_topology_robustness(self): + async def _analize_topology_robustness(self): logging.info("πŸ”„ Analizing node network robustness...") if not self._restructure_process_lock.locked(): if not await self.neighbors_left(): diff --git a/nebula/core/situationalawareness/awareness/satraining/satraining.py b/nebula/core/situationalawareness/awareness/satraining/satraining.py index ca00a1293..aa0139b39 100644 --- a/nebula/core/situationalawareness/awareness/satraining/satraining.py +++ b/nebula/core/situationalawareness/awareness/satraining/satraining.py @@ -1,6 +1,7 @@ import asyncio import logging from nebula.core.utils.locker import Locker +from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import factory_training_policy from nebula.addons.functions import print_msg_box from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -21,5 +22,5 @@ def __init__( title="Training SA module", ) self._sam = sam - self._trainning_policy = training_policy + self._trainning_policy = factory_training_policy(training_policy) self._weight_strategies = weight_strategies \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py new file mode 100644 index 000000000..9a0b34b86 --- /dev/null +++ 
b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py @@ -0,0 +1,6 @@ +from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy + +class BPSTrainingPolicy(TrainingPolicy): + + def __init__(): + pass \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/htstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/htstrainingpolicy.py new file mode 100644 index 000000000..522e5b783 --- /dev/null +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/htstrainingpolicy.py @@ -0,0 +1,6 @@ +from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy + +class HTSTrainingPolicy(TrainingPolicy): + + def __init__(): + pass \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py new file mode 100644 index 000000000..d5a2685c4 --- /dev/null +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py @@ -0,0 +1,6 @@ +from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy + +class QDSTrainingPolicy(TrainingPolicy): + + def __init__(): + pass \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py new file mode 100644 index 000000000..6664fba75 --- /dev/null +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py @@ -0,0 +1,6 @@ +from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy + +class 
SOSTrainingPolicy(TrainingPolicy): + + def __init__(): + pass \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py new file mode 100644 index 000000000..5498b3285 --- /dev/null +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod +from typing import Type + +class TrainingPolicy(ABC): + + @abstractmethod + async def update_neighbors(self, node, remove=False): + pass + + @abstractmethod + async def evaluate(self): + pass + + +def factory_training_policy(topology) -> TrainingPolicy: + from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.bpstrainingpolicy import BPSTrainingPolicy + from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.qdstrainingpolicy import QDSTrainingPolicy + from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.sostrainingpolicy import SOSTrainingPolicy + from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.htstrainingpolicy import HTSTrainingPolicy + + options = { + "bps": BPSTrainingPolicy, # "Broad-Propagation Strategy" (BPS) -- default value + "qds": QDSTrainingPolicy, # "Quality-Driven Selection" (QDS) + "sos": SOSTrainingPolicy, # "Speed-Oriented Selection" (SOS) + "hts": HTSTrainingPolicy, # "Hybrid Training Strategy" (HTS) + } + + cs = options.get(topology, BPSTrainingPolicy) + return cs() \ No newline at end of file From bc5dec0c2ade75b2918f634f0d21e2cc181d741f Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 27 Feb 2025 13:52:16 +0100 Subject: [PATCH 119/233] feature event system for addon functionalities --- .../awareness => addons}/GPS/gpsmodule.py | 17 +- .../awareness => addons}/GPS/nebulagps.py | 38 ++- nebula/addons/mobility.py | 245 +++--------------- .../nebulanetworksimulator.py | 148 
++++++++--- .../networksimulation/networksimulator.py | 6 +- nebula/core/addonmanager.py | 33 +++ nebula/core/engine.py | 14 +- nebula/core/eventmanager.py | 124 +++------ nebula/core/network/communications.py | 184 ++----------- nebula/core/network/connection.py | 28 -- .../nebuladiscoveryservice.py | 4 +- .../awareness/samodule.py | 19 -- .../awareness/satraining/satraining.py | 2 +- .../core/situationalawareness/nodemanager.py | 4 - .../frontend/config/participant.json.example | 2 +- 15 files changed, 298 insertions(+), 570 deletions(-) rename nebula/{core/situationalawareness/awareness => addons}/GPS/gpsmodule.py (54%) rename nebula/{core/situationalawareness/awareness => addons}/GPS/nebulagps.py (71%) rename nebula/{core/network => addons}/networksimulation/nebulanetworksimulator.py (59%) rename nebula/{core/network => addons}/networksimulation/networksimulator.py (71%) create mode 100644 nebula/core/addonmanager.py diff --git a/nebula/core/situationalawareness/awareness/GPS/gpsmodule.py b/nebula/addons/GPS/gpsmodule.py similarity index 54% rename from nebula/core/situationalawareness/awareness/GPS/gpsmodule.py rename to nebula/addons/GPS/gpsmodule.py index d80d33572..676af7dda 100644 --- a/nebula/core/situationalawareness/awareness/GPS/gpsmodule.py +++ b/nebula/addons/GPS/gpsmodule.py @@ -1,5 +1,6 @@ import asyncio from abc import ABC, abstractmethod +from nebula.core.eventmanager import AddonEvent class GPSModule(ABC): @@ -18,12 +19,22 @@ async def is_running(self): @abstractmethod async def calculate_distance(self, self_lat, self_long, other_lat, other_long): pass + +class GPSEvent(AddonEvent): + def __init__(self, distances : dict): + self.distances = distances + + def __str__(self): + return "GPSEvent" + + async def get_event_data(self) -> dict: + return self.distances.copy() class GPSModuleException(Exception): pass -def factory_gpsmodule(gps_module, sam, addr) -> GPSModule: - from nebula.core.situationalawareness.awareness.GPS.nebulagps import NebulaGPS 
+def factory_gpsmodule(gps_module, config, event_manager, addr, update_interval: float = 5.0, verbose=False) -> GPSModule: + from nebula.addons.GPS.nebulagps import NebulaGPS GPS_SERVICES = { "nebula": NebulaGPS, @@ -32,6 +43,6 @@ def factory_gpsmodule(gps_module, sam, addr) -> GPSModule: gps_module = GPS_SERVICES.get(gps_module, NebulaGPS) if gps_module: - return gps_module(sam, addr) + return gps_module(config, event_manager, addr, update_interval, verbose) else: raise GPSModuleException(f"GPS Module {gps_module} not found") \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/GPS/nebulagps.py b/nebula/addons/GPS/nebulagps.py similarity index 71% rename from nebula/core/situationalawareness/awareness/GPS/nebulagps.py rename to nebula/addons/GPS/nebulagps.py index c6d812fb0..e59ecfc89 100644 --- a/nebula/core/situationalawareness/awareness/GPS/nebulagps.py +++ b/nebula/addons/GPS/nebulagps.py @@ -1,31 +1,30 @@ import asyncio import logging -from nebula.core.situationalawareness.awareness.GPS.gpsmodule import GPSModule +from nebula.addons.GPS.gpsmodule import GPSModule, GPSEvent import socket from nebula.core.utils.locker import Locker from geopy import distance - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from nebula.core.situationalawareness.awareness.samodule import SAModule +from nebula.core.eventmanager import EventManager class NebulaGPS(GPSModule): BROADCAST_IP = "255.255.255.255" # Broadcast IP BROADCAST_PORT = 50001 # Poort used for GPS INTERFACE = "eth2" # Interface to avoid network conditions - def __init__(self, sam: "SAModule", addr, update_interval: float = 5.0): + def __init__(self, config, event_manager : EventManager, addr, update_interval: float = 5.0, verbose=False): + self._config = config self._addr = addr - self._situational_awareness_module = sam + self._event_manager = event_manager self.update_interval = update_interval # Frecuencia de emisiΓ³n self.running = False self._node_locations = {} # 
Diccionario para almacenar ubicaciones de nodos self._broadcast_socket = None self._nodes_location_lock = Locker("nodes_location_lock", async_lock=True) + self._verbose = verbose @property - def sam(self): - return self._situational_awareness_module + def em(self): + return self._event_manager async def start(self): """Inicia el servicio de GPS, enviando y recibiendo ubicaciones.""" @@ -55,6 +54,11 @@ async def stop(self): async def is_running(self): return self.running + async def get_geoloc(self): + latitude = self._config.participant["mobility_args"]["latitude"] + longitude = self._config.participant["mobility_args"]["longitude"] + return (latitude, longitude) + async def calculate_distance(self, self_lat, self_long, other_lat, other_long): distance_m = distance.distance((self_lat, self_long), (other_lat, other_long)).m return distance_m @@ -62,10 +66,10 @@ async def calculate_distance(self, self_lat, self_long, other_lat, other_long): async def _send_location_loop(self): """Envia la geolocalizaciΓ³n periΓ³dicamente por broadcast.""" while self.running: - latitude, longitude = await self.sam.get_geoloc() # Obtener ubicaciΓ³n actual + latitude, longitude = await self.get_geoloc() # Obtener ubicaciΓ³n actual message = f"GPS-UPDATE {self._addr} {latitude} {longitude}" self._broadcast_socket.sendto(message.encode(), (self.BROADCAST_IP, self.BROADCAST_PORT)) - #logging.info(f"Sent GPS location: ({latitude}, {longitude})") + if self._verbose: logging.info(f"Sent GPS location: ({latitude}, {longitude})") await asyncio.sleep(self.update_interval) async def _receive_location_loop(self): @@ -81,7 +85,7 @@ async def _receive_location_loop(self): if sender_addr != self._addr: async with self._nodes_location_lock: self._node_locations[sender_addr] = (float(lat), float(lon)) - #logging.info(f"Received GPS from {addr[0]}: {lat}, {lon}") + if self._verbose: logging.info(f"Received GPS from {addr[0]}: {lat}, {lon}") except Exception as e: logging.error(f"Error receiving GPS 
update: {e}") @@ -89,8 +93,14 @@ async def _notify_geolocs(self): while True: await asyncio.sleep(self.update_interval) await self._nodes_location_lock.acquire_async() - geolocs = self._node_locations.copy() + geolocs : dict = self._node_locations.copy() await self._nodes_location_lock.release_async() if geolocs: - await self.sam.cm.update_geolocalization(geolocs) + distances = {} + self_lat, self_long = await self.get_geoloc() + for addr, (lat, long) in geolocs.items(): + dist = await self.calculate_distance(self_lat, self_long, lat, long) + distances[addr] = (dist,(lat, long)) + gpsevent = GPSEvent(distances) + asyncio.create_task(self.em.publish_addonevent(gpsevent)) \ No newline at end of file diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index 97baa2062..72f5ffdd7 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -3,16 +3,17 @@ import math import random import time -from typing import TYPE_CHECKING - +from nebula.core.eventmanager import EventManager +from nebula.addons.GPS.gpsmodule import GPSEvent +from nebula.core.utils.locker import Locker from nebula.addons.functions import print_msg_box - +from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.network.communications import CommunicationsManager class Mobility: - def __init__(self, config, cm: "CommunicationsManager"): + def __init__(self, config, cm: "CommunicationsManager", event_manager: EventManager): """ Initializes the mobility module with specified configuration and communication manager. 
@@ -62,18 +63,12 @@ def __init__(self, config, cm: "CommunicationsManager"): self.max_movement_random_strategy = 100 # meters self.max_movement_nearest_strategy = 100 # meters self.max_initiate_approximation = self.max_distance_with_direct_connections * 1.2 - # Network conditions based on distance - self.network_conditions = { - 100: {"bandwidth": "5Gbps", "delay": "5ms"}, - 200: {"bandwidth": "2Gbps", "delay": "50ms"}, - 300: {"bandwidth": "100Mbps", "delay": "200ms"}, - float("inf"): {"bandwidth": "10Mbps", "delay": "1000ms"}, - } - # Current network conditions of each connection {addr: {bandwidth: "5Gbps", delay: "0ms"}} - self.current_network_conditions = {} # Logging box with mobility information mobility_msg = f"Mobility: {self.mobility}\nMobility type: {self.mobility_type}\nRadius federation: {self.radius_federation}\nScheme mobility: {self.scheme_mobility}\nEach {self.round_frequency} rounds" print_msg_box(msg=mobility_msg, indent=2, title="Mobility information") + self._event_manager = event_manager + self._nodes_distances = {} + self._nodes_distances_lock = Locker("nodes_distances_lock", async_lock=True) @property def round(self): @@ -103,9 +98,15 @@ async def start(self): asyncio.Task: An asyncio Task object representing the scheduled `run_mobility` operation. """ + await self._event_manager.subscribe_addonevent(GPSEvent, self.update_nodes_distances) task = asyncio.create_task(self.run_mobility()) return task + async def update_nodes_distances(self, gpsevent : GPSEvent): + distances = await gpsevent.get_event_data() + async with self._nodes_distances_lock: + self._nodes_distances = dict(distances) + async def run_mobility(self): """ Executes the mobility operations in a continuous loop. 
@@ -135,7 +136,6 @@ async def run_mobility(self): #await asyncio.sleep(self.grace_time) while True: await self.change_geo_location() - #await self.change_connections_based_on_distance() await asyncio.sleep(self.period) async def change_geo_location_random_strategy(self, latitude, longitude): @@ -268,214 +268,35 @@ async def change_geo_location(self): random.seed(time.time() + self.config.participant["device_args"]["idx"]) latitude = float(self.config.participant["mobility_args"]["latitude"]) longitude = float(self.config.participant["mobility_args"]["longitude"]) - if True: # Get neighbor closer to me - selected_neighbor = await self.cm.get_nearest_connections(top=1) + async with self._nodes_distances_lock: + sorted_list = sorted(self._nodes_distances.items(), key=lambda item: item[1][0]) + # Transformamos la lista para obtener solo direcciΓ³n y coordenadas + result = [(addr, dist, coords) for addr, (dist, coords) in sorted_list] + + selected_neighbor = result[0] if result else None if selected_neighbor: #logging.info(f"πŸ“ Selected neighbor: {selected_neighbor}") - try: - ( - neighbor_latitude, - neighbor_longitude, - ) = selected_neighbor.get_geolocation() - distance = selected_neighbor.get_neighbor_distance() - if distance > self.max_initiate_approximation: - # If the distance is too big, we move towards the neighbor - await self.change_geo_location_nearest_neighbor_strategy( - distance, - latitude, - longitude, - neighbor_latitude, - neighbor_longitude, - ) - else: - await self.change_geo_location_random_strategy(latitude, longitude) - except Exception as e: - logging.info(f"πŸ“ Neighbor location/distance not found for {selected_neighbor.get_addr()}: {e}") + addr, dist, (lat, long) = selected_neighbor + if dist > self.max_initiate_approximation: + # If the distance is too big, we move towards the neighbor + logging.info(f"Moving towards nearest neighbor: {addr}") + await self.change_geo_location_nearest_neighbor_strategy( + dist, + latitude, + longitude, + 
lat, + long, + ) + else: await self.change_geo_location_random_strategy(latitude, longitude) else: - await self.change_geo_location_random_strategy(latitude, longitude) + await self.change_geo_location_random_strategy(latitude, longitude) else: await self.change_geo_location_random_strategy(latitude, longitude) else: logging.error(f"πŸ“ Mobility type {self.mobility_type} not implemented") return - async def change_connections_based_on_distance(self): - """ - Changes the connections of the entity based on the distance to neighboring nodes. - This coroutine evaluates the current connections in the topology and adjusts their status to - either direct or undirected based on their distance from the entity. If a neighboring node is - within a certain distance, it is marked as a direct connection; otherwise, it is marked as - undirected. - - Additionally, it updates the network conditions for each connection based on the distance, - ensuring that the current state is reflected accurately. - - Args: - None: This function does not take any arguments. - - Raises: - KeyError: If a connection address is not found during the process. - Exception: For any other errors that may occur while changing connections. - - Notes: - - The method expects the mobility type to be either "topology" or "both". - - It logs the distance evaluations and changes made for tracking and debugging purposes. - """ - if self.mobility and (self.mobility_type == "topology" or self.mobility_type == "both"): - try: - # logging.info(f"πŸ“ Checking connections based on distance") - connections_topology = await self.cm.get_addrs_current_connections() - # logging.info(f"πŸ“ Connections of the topology: {connections_topology}") - if len(connections_topology) < 1: - # logging.error(f"πŸ“ Not enough connections for mobility") - return - # Nodes that are too far away should be marked as undirected connections, and closer nodes should be marked as directed connections. 
- for addr in connections_topology: - distance = self.cm.connections[addr].get_neighbor_distance() - if distance is None: - # If the distance is not found, we skip the node - continue - conditions = await self.calculate_network_conditions(distance) - #logging.info(f"Conditions for source: {addr}, | {conditions}") - # Only update the network conditions if they have changed - if ( - addr not in self.current_network_conditions or self.current_network_conditions[addr] != conditions - ): - # eth1 is the interface of the container that connects to the node network - eth0 is the interface of the container that connects to the frontend/backend - self.cm.set_network_conditions( - interface="eth0", - network=addr.split(":")[0], - bandwidth=conditions["bandwidth"], - delay=conditions["delay"], - delay_distro="10ms", - delay_distribution="normal", - loss="0%", - duplicate="0%", - corrupt="0%", - reordering="0%", - ) - self.current_network_conditions[addr] = conditions - else: - logging.info("network conditions havent changed since last time") - except KeyError: - # Except when self.cm.connections[addr] is not found (disconnected during the process) - logging.exception(f"πŸ“ Connection {addr} not found") - return - except Exception: - logging.exception("πŸ“ Error changing connections based on distance") - return - - async def calculate_network_conditions(self, distance): - logging.info(f"Calculating conditions for distance: {distance}") - def extract_number(value): - import re - match = re.match(r"([\d.]+)", value) - if not match: - raise ValueError(f"Formato invΓ‘lido: {value}") - return float(match.group(1)) - - thresholds = sorted(self.network_conditions.keys()) - - # Si la distancia es menor que el primer umbral, devolver la mejor condiciΓ³n - if distance < thresholds[0]: - return { - "bandwidth": self.network_conditions[thresholds[0]]["bandwidth"], - "delay": self.network_conditions[thresholds[0]]["delay"] - } - - # Encontrar el tramo en el que se encuentra la distancia 
- for i in range(len(thresholds) - 1): - lower_bound = thresholds[i] - upper_bound = thresholds[i + 1] - - if upper_bound == float("inf"): - break - - if lower_bound <= distance < upper_bound: - #logging.info(f"Bounds | lower: {lower_bound} | upper: {upper_bound}") - lower_cond = self.network_conditions[lower_bound] - upper_cond = self.network_conditions[upper_bound] - - # Extraer valores numΓ©ricos y unidades - lower_bandwidth_value = extract_number(lower_cond["bandwidth"]) - upper_bandwidth_value = extract_number(upper_cond["bandwidth"]) - lower_bandwidth_unit = lower_cond["bandwidth"].replace(str(lower_bandwidth_value), "") - upper_bandwidth_unit = upper_cond["bandwidth"].replace(str(upper_bandwidth_value), "") - - lower_delay_value = extract_number(lower_cond["delay"]) - upper_delay_value = extract_number(upper_cond["delay"]) - delay_unit = lower_cond["delay"].replace(str(lower_delay_value), "") - - # Calcular el progreso en el tramo (0 a 1) - progress = (distance - lower_bound) / (upper_bound - lower_bound) - #logging.info(f"Progress between the bounds: {progress}") - - # InterpolaciΓ³n lineal de valores - bandwidth_value = lower_bandwidth_value - progress * (lower_bandwidth_value - upper_bandwidth_value) - delay_value = lower_delay_value + progress * (upper_delay_value - lower_delay_value) - - # Reconstruir valores con unidades originales - bandwidth = f"{round(bandwidth_value, 2)}{lower_bandwidth_unit}" - delay = f"{round(delay_value, 2)}{delay_unit}" - - return {"bandwidth": bandwidth, "delay": delay} - - # Si la distancia es infinita, devolver el ΓΊltimo valor - return { - "bandwidth": self.network_conditions[float("inf")]["bandwidth"], - "delay": self.network_conditions[float("inf")]["delay"] - } - - async def change_connections(self): - """ - Changes the connections of the entity based on the specified mobility scheme. 
- - This coroutine evaluates the current and potential connections at specified intervals (based - on the round frequency) and makes adjustments according to the mobility scheme in use. If - the mobility type is appropriate and the current round is a multiple of the round frequency, - it will proceed to change connections. - - Args: - None: This function does not take any arguments. - - Raises: - None: This function does not raise exceptions, but it logs errors related to connection counts - and unsupported mobility schemes. - - Notes: - - The function currently supports a "random" mobility scheme, where it randomly selects - a current connection to disconnect and a potential connection to connect. - - If there are insufficient connections available, an error will be logged. - - All actions and decisions made by the function are logged for tracking purposes. - """ - if ( - self.mobility - and (self.mobility_type == "topology" or self.mobility_type == "both") - and self.round % self.round_frequency == 0 - ): - logging.info("πŸ“ Changing connections") - current_connections = await self.cm.get_addrs_current_connections(only_direct=True) - potential_connections = await self.cm.get_addrs_current_connections(only_undirected=True) - logging.info( - f"πŸ“ Current connections: {current_connections} | Potential future connections: {potential_connections}" - ) - if len(current_connections) < 1 or len(potential_connections) < 1: - logging.error("πŸ“ Not enough connections for mobility") - return - - if self.scheme_mobility == "random": - random_neighbor = random.choice(current_connections) # noqa: S311 - random_potential_neighbor = random.choice(potential_connections) # noqa: S311 - logging.info(f"πŸ“ Selected node(s) to disconnect: {random_neighbor}") - logging.info(f"πŸ“ Selected node(s) to connect: {random_potential_neighbor}") - await self.cm.disconnect(random_neighbor, mutual_disconnection=True) - await self.cm.connect(random_potential_neighbor, direct=True) - 
logging.info(f"πŸ“ New connections: {self.get_current_connections(only_direct=True)}") - logging.info(f"πŸ“ Neighbors in config: {self.config.participant['network_args']['neighbors']}") - else: - logging.error(f"πŸ“ Mobility scheme {self.scheme_mobility} not implemented") - return diff --git a/nebula/core/network/networksimulation/nebulanetworksimulator.py b/nebula/addons/networksimulation/nebulanetworksimulator.py similarity index 59% rename from nebula/core/network/networksimulation/nebulanetworksimulator.py rename to nebula/addons/networksimulation/nebulanetworksimulator.py index e99a8315d..152e83b65 100644 --- a/nebula/core/network/networksimulation/nebulanetworksimulator.py +++ b/nebula/addons/networksimulation/nebulanetworksimulator.py @@ -1,8 +1,10 @@ import asyncio import subprocess import logging -from nebula.core.network.networksimulation.networksimulator import NetworkSimulator +from nebula.addons.networksimulation.networksimulator import NetworkSimulator from nebula.core.utils.locker import Locker +from nebula.core.eventmanager import EventManager +from nebula.addons.GPS.gpsmodule import GPSEvent from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.network.communications import CommunicationsManager @@ -12,11 +14,12 @@ class NebulaNS(NetworkSimulator): 100: {"bandwidth": "5Gbps", "delay": "5ms"}, 200: {"bandwidth": "2Gbps", "delay": "50ms"}, 300: {"bandwidth": "100Mbps", "delay": "200ms"}, - float("inf"): {"bandwidth": "10Mbps", "delay": "100000ms"}, + float("inf"): {"bandwidth": "10Mbps", "delay": "10000ms"}, } IP_MULTICAST = "239.255.255.250" - def __init__(self, communication_manager: "CommunicationsManager", changing_interval, interface, verbose=False): + def __init__(self, event_manager : EventManager, communication_manager: "CommunicationsManager", changing_interval, interface, verbose=False): + self._event_manager = event_manager self._cm = communication_manager self._refresh_interval = changing_interval self._node_interface = 
interface @@ -26,43 +29,74 @@ def __init__(self, communication_manager: "CommunicationsManager", changing_inte self._current_network_conditions = {} self._running = False + @property + def em(self): + return self._event_manager + async def start(self): logging.info("🌐 Nebula Network Simulator starting...") self._running = True - asyncio.create_task(self._change_network_conditions_based_on_distances()) + grace_time = self._cm.config.participant["mobility_args"]["grace_time_mobility"] + # if self._verbose: logging.info(f"Waiting {grace_time}s to start applying network conditions based on distances between devices") + # await asyncio.sleep(grace_time) + await self.em.subscribe_addonevent(GPSEvent, self._change_network_conditions_based_on_distances) async def stop(self): self._running = False - async def _change_network_conditions_based_on_distances(self): - grace_time = self._cm.config.participant["mobility_args"]["grace_time_mobility"] - if self._verbose: logging.info(f"Waiting {grace_time}s to start applying network conditions based on distances between devices") - await asyncio.sleep(grace_time) + async def _change_network_conditions_based_on_distances(self, gpsevent : GPSEvent): + distances = await gpsevent.get_event_data() + await asyncio.sleep(self._refresh_interval) + if self._verbose: logging.info("Refresh | conditions based on distances...") + try: + for addr, (distance, _) in distances.items(): + if distance is None: + # If the distance is not found, we skip the node + continue + conditions = await self._calculate_network_conditions(distance) + # Only update the network conditions if they have changed + if (addr not in self._current_network_conditions or self._current_network_conditions[addr] != conditions): + addr_ip = addr.split(":")[0] + self._set_network_condition_for_addr(self._node_interface, addr_ip, conditions["bandwidth"], conditions["delay"]) + self._set_network_condition_for_multicast(self._node_interface, addr_ip, self.IP_MULTICAST, 
conditions["bandwidth"], conditions["delay"]) + async with self._network_conditions_lock: + self._current_network_conditions[addr] = conditions + else: + if self._verbose: logging.info("network conditions havent changed since last time") + except KeyError: + logging.exception(f"πŸ“ Connection {addr} not found") + except Exception: + logging.exception("πŸ“ Error changing connections based on distance") - while self._running: - await asyncio.sleep(self._refresh_interval) - if self._verbose: logging.info("Refresh | conditions based on distances...") - current_connections = await self._cm.get_addrs_current_connections() - try: - for addr in current_connections: - distance = self._cm.connections[addr].get_neighbor_distance() - if distance is None: - # If the distance is not found, we skip the node - continue - conditions = await self._calculate_network_conditions(distance) - # Only update the network conditions if they have changed - if (addr not in self._current_network_conditions or self._current_network_conditions[addr] != conditions): - addr_ip = addr.split(":")[0] - self._set_network_condition_for_addr(self._node_interface, addr_ip, conditions["bandwidth"], conditions["delay"]) - self._set_network_condition_for_multicast(self._node_interface, addr_ip, self.IP_MULTICAST, conditions["bandwidth"], conditions["delay"]) - async with self._network_conditions_lock: - self._current_network_conditions[addr] = conditions - else: - logging.info("network conditions havent changed since last time") - except KeyError: - logging.exception(f"πŸ“ Connection {addr} not found") - except Exception: - logging.exception("πŸ“ Error changing connections based on distance") + # async def _change_network_conditions_based_on_distances(self): + # grace_time = self._cm.config.participant["mobility_args"]["grace_time_mobility"] + # if self._verbose: logging.info(f"Waiting {grace_time}s to start applying network conditions based on distances between devices") + # await 
asyncio.sleep(grace_time) + + # while self._running: + # await asyncio.sleep(self._refresh_interval) + # if self._verbose: logging.info("Refresh | conditions based on distances...") + # current_connections = await self._cm.get_addrs_current_connections() + # try: + # for addr in current_connections: + # distance = self._cm.connections[addr].get_neighbor_distance() + # if distance is None: + # # If the distance is not found, we skip the node + # continue + # conditions = await self._calculate_network_conditions(distance) + # # Only update the network conditions if they have changed + # if (addr not in self._current_network_conditions or self._current_network_conditions[addr] != conditions): + # addr_ip = addr.split(":")[0] + # self._set_network_condition_for_addr(self._node_interface, addr_ip, conditions["bandwidth"], conditions["delay"]) + # self._set_network_condition_for_multicast(self._node_interface, addr_ip, self.IP_MULTICAST, conditions["bandwidth"], conditions["delay"]) + # async with self._network_conditions_lock: + # self._current_network_conditions[addr] = conditions + # else: + # logging.info("network conditions havent changed since last time") + # except KeyError: + # logging.exception(f"πŸ“ Connection {addr} not found") + # except Exception: + # logging.exception("πŸ“ Error changing connections based on distance") async def set_thresholds(self, thresholds : dict): async with self._network_conditions_lock: @@ -272,4 +306,52 @@ def clear_network_conditions(self, interface): ) except Exception as e: logging.exception(f"❗️ Network simulation error: {e}") - return \ No newline at end of file + return + + def _generate_network_conditions(self): + # TODO: Implement selection of network conditions from frontend + if self.config.participant["network_args"]["simulation"]: + interface = self.config.participant["network_args"]["interface"] + bandwidth = self.config.participant["network_args"]["bandwidth"] + delay = 
self.config.participant["network_args"]["delay"] + delay_distro = self.config.participant["network_args"]["delay-distro"] + delay_distribution = self.config.participant["network_args"]["delay-distribution"] + loss = self.config.participant["network_args"]["loss"] + duplicate = self.config.participant["network_args"]["duplicate"] + corrupt = self.config.participant["network_args"]["corrupt"] + reordering = self.config.participant["network_args"]["reordering"] + logging.info( + f"🌐 Network simulation is enabled | Interface: {interface} | Bandwidth: {bandwidth} | Delay: {delay} | Delay Distro: {delay_distro} | Delay Distribution: {delay_distribution} | Loss: {loss} | Duplicate: {duplicate} | Corrupt: {corrupt} | Reordering: {reordering}" + ) + try: + results = subprocess.run( + [ + "tcset", + str(interface), + "--rate", + str(bandwidth), + "--delay", + str(delay), + "--delay-distro", + str(delay_distro), + "--delay-distribution", + str(delay_distribution), + "--loss", + str(loss), + "--duplicate", + str(duplicate), + "--corrupt", + str(corrupt), + "--reordering", + str(reordering), + ], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + except Exception as e: + logging.exception(f"🌐 Network simulation error: {e}") + return + else: + logging.info("🌐 Network simulation is disabled. 
Using default network conditions...") \ No newline at end of file diff --git a/nebula/core/network/networksimulation/networksimulator.py b/nebula/addons/networksimulation/networksimulator.py similarity index 71% rename from nebula/core/network/networksimulation/networksimulator.py rename to nebula/addons/networksimulation/networksimulator.py index 5465683ed..d468b2bc5 100644 --- a/nebula/core/network/networksimulation/networksimulator.py +++ b/nebula/addons/networksimulation/networksimulator.py @@ -27,8 +27,8 @@ def clear_network_conditions(self, interface): class NetworkSimulatorException(Exception): pass -def factory_network_simulator(net_sim, communication_manager, changing_interval, interface, verbose) -> NetworkSimulator: - from nebula.core.network.networksimulation.nebulanetworksimulator import NebulaNS +def factory_network_simulator(net_sim, event_manager, communication_manager, changing_interval, interface, verbose) -> NetworkSimulator: + from nebula.addons.networksimulation.nebulanetworksimulator import NebulaNS SIMULATION_SERVICES = { "nebula": NebulaNS, @@ -37,6 +37,6 @@ def factory_network_simulator(net_sim, communication_manager, changing_interval, net_serv = SIMULATION_SERVICES.get(net_sim, NebulaNS) if net_serv: - return net_serv(communication_manager, changing_interval, interface, verbose) + return net_serv(event_manager, communication_manager, changing_interval, interface, verbose) else: raise NetworkSimulatorException(f"Network Simulator {net_sim} not found") \ No newline at end of file diff --git a/nebula/core/addonmanager.py b/nebula/core/addonmanager.py new file mode 100644 index 000000000..3401fbf90 --- /dev/null +++ b/nebula/core/addonmanager.py @@ -0,0 +1,33 @@ +import logging +import asyncio +from nebula.addons.functions import print_msg_box +from nebula.addons.mobility import Mobility +from nebula.addons.networksimulation.networksimulator import factory_network_simulator +from nebula.addons.GPS.gpsmodule import factory_gpsmodule +from 
typing import TYPE_CHECKING +if TYPE_CHECKING: + from nebula.core.engine import Engine + +class AddondManager(): + def __init__(self, engine : "Engine", config): + self._engine = engine + self._config = config + self._addons = [] + + async def deploy_additional_services(self): + print_msg_box(msg="Deploying Additional Services\n(='.'=)", indent=2, title="Addons Manager") + if self._config.participant["mobility_args"]["mobility"]: + mobility = Mobility(self._config, self._engine.cm, self._engine.event_manager) + self._addons.append(mobility) + if self._config.participant["network_args"]["simulation"]: + refresh_conditions_interval = 5 + network_simulation = factory_network_simulator("nebula", self._engine.event_manager, self._engine.cm, refresh_conditions_interval, "eth0", verbose=False) + self._addons.append(network_simulation) + update_interval = 5 + gps = factory_gpsmodule("nebula", self._config, self._engine.event_manager, self._engine.addr, update_interval, verbose=False) + self._addons.append(gps) + + for add in self._addons: + await add.start() + + \ No newline at end of file diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 8258ef3aa..ea111f179 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -12,6 +12,7 @@ from nebula.core.network.communications import CommunicationsManager from nebula.core.situationalawareness.nodemanager import NodeManager from nebula.core.utils.locker import Locker +from nebula.core.addonmanager import AddondManager logging.getLogger("requests").setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING) @@ -159,7 +160,7 @@ def __init__( engine=self, ) - self._event_manager = EventManager() + self._event_manager = EventManager(verbose=True) logging.info("Registering callbacks for MessageEvents...") self.register_message_events_callbacks() @@ -167,6 +168,8 @@ def __init__( # Additional callbacks not registered automatically self.register_message_callback(("model", "initialization"), 
"model_initialization_callback") self.register_message_callback(("model", "update"), "model_update_callback") + + self._addon_manager = AddondManager(self, self.config) @property def cm(self): @@ -545,7 +548,7 @@ async def _link_disconnect_from_callback(self, source, message): await self.nm.update_neighbors(addr, remove=True) """ ############################## - # ENGINE FUNCTIONALITY # + # REGISTERING CALLBACKS # ############################## """ @@ -577,9 +580,11 @@ async def trigger_event(self, message_event): async def get_geoloc(self): return await self.nm.get_geoloc() - async def calculate_distance(self, latitude, longitude): - return await self.nm.calculate_distance(latitude, longitude) + """ ############################## + # ENGINE FUNCTIONALITY # + ############################## + """ async def _aditional_node_start(self): self.update_sinchronized_status(False) @@ -684,6 +689,7 @@ async def start_communications(self): await self.nm.set_configs() await self._reporter.start() await self.cm.deploy_additional_services() + await self._addon_manager.deploy_additional_services() await asyncio.sleep(self.config.participant["misc_args"]["grace_time_connection"] // 2) async def deploy_federation(self): diff --git a/nebula/core/eventmanager.py b/nebula/core/eventmanager.py index 9c202ddb6..2566fdf4b 100755 --- a/nebula/core/eventmanager.py +++ b/nebula/core/eventmanager.py @@ -3,39 +3,21 @@ import logging from collections import defaultdict from functools import wraps - +from abc import ABC, abstractmethod from nebula.core.network.messages import MessageEvent +from nebula.core.utils.locker import Locker - -def event_handler(message_type, action): - """Decorator for registering an event handler.""" - - def decorator(func): - @wraps(func) - async def async_wrapper(*args, **kwargs): - return await func(*args, **kwargs) - - @wraps(func) - def sync_wrapper(*args, **kwargs): - return func(*args, **kwargs) - - if asyncio.iscoroutinefunction(func): - wrapper = 
async_wrapper - else: - wrapper = sync_wrapper - - action_name = message_type.Action.Name(action) if action is not None else "None" - wrapper._event_handler = (message_type.DESCRIPTOR.full_name, action_name) - return wrapper - - return decorator - +class AddonEvent(ABC): + @abstractmethod + async def get_event_data(self): + pass class EventManager: - def __init__(self, default_callbacks=None): - self._event_callbacks = defaultdict(list) - self._register_default_callbacks(default_callbacks or []) + def __init__(self, verbose=False): self._subscribers: dict[tuple[str, str], list] = {} + self._addons_events_subs : dict [AddonEvent, list] = {} + self._addons_event_lock = Locker("addons_event_lock", async_lock=True) + self._verbose = verbose def subscribe(self, event_type: tuple[str, str], callback: callable): """Register a callback for a specific event type.""" @@ -46,7 +28,7 @@ def subscribe(self, event_type: tuple[str, str], callback: callable): async def publish(self, message_event: MessageEvent): """Trigger all callbacks registered for a specific event type.""" - # logging.info(f"Publishing MessageEvent: {message_event.message_type}") + if self._verbose: logging.info(f"Publishing MessageEvent: {message_event.message_type}") event_type = message_event.message_type if event_type not in self._subscribers: logging.error(f"EventManager | No subscribers for event: {event_type}") @@ -54,66 +36,38 @@ async def publish(self, message_event: MessageEvent): for callback in self._subscribers[event_type]: try: - # logging.info(f"EventManager | Triggering callback for event: {event_type}, from source: {message_event.source}") - await callback(message_event.source, message_event.message) + if asyncio.iscoroutinefunction(callback) or inspect.iscoroutine(callback): + await callback(message_event.source, message_event.message) + else: + callback(message_event.source, message_event.message) + if self._verbose: logging.info(f"EventManager | Triggering callback for event: {event_type}, 
from source: {message_event.source}") except Exception as e: logging.exception(f"EventManager | Error in callback for event {event_type}: {e}") - - def _register_default_callbacks(self, default_callbacks): - """Registers default callbacks for events.""" - for callback in default_callbacks: - handler_info = getattr(callback, "_event_handler", None) - if handler_info is not None: - self.register_event(handler_info, callback) - else: - raise ValueError("The callback must be decorated with @event_handler.") - - def register_callback(self, callback): - """Registers a callback for an event.""" - handler_info = getattr(callback, "_event_handler", None) - if handler_info is not None: - self.register_event(handler_info, callback) - else: - raise ValueError("The callback must be decorated with @event_handler.") - - def register_event(self, handler_info, callback): - """Records a callback for a specific event.""" - if callable(callback): - self._event_callbacks[handler_info].append(callback) - else: - raise ValueError("The callback must be a callable function.") - - def unregister_event(self, handler_info, callback): - """Unregisters a previously registered callback for an event.""" - if callback in self._event_callbacks[handler_info]: - self._event_callbacks[handler_info].remove(callback) - - async def trigger_event(self, source, message, *args, **kwargs): - """Triggers an event, executing all associated callbacks.""" - message_type = message.DESCRIPTOR.full_name - if hasattr(message, "action"): - action_name = message.Action.Name(message.action) - else: - action_name = "None" - - handler_info = (message_type, action_name) - - if handler_info in self._event_callbacks: - for callback in self._event_callbacks[handler_info]: + + async def subscribe_addonevent(self, addonEventType: type[AddonEvent], callback: callable): + """Register a callback for a specific type of AddonEvent.""" + async with self._addons_event_lock: + if addonEventType not in self._addons_events_subs: + 
self._addons_events_subs[addonEventType] = [] + self._addons_events_subs[addonEventType].append(callback) + logging.info(f"EventManager | Subscribed callback for AddonEvent type: {addonEventType.__name__}") + + async def publish_addonevent(self, addonevent: AddonEvent): + """Trigger all callbacks registered for a specific type of AddonEvent.""" + if self._verbose: logging.info(f"Publishing AddonEvent: {addonevent}") + async with self._addons_event_lock: + event_type = type(addonevent) + if event_type not in self._addons_events_subs: + logging.error(f"EventManager | No subscribers for AddonEvent type: {event_type.__name__}") + return + + for callback in self._addons_events_subs[event_type]: try: if asyncio.iscoroutinefunction(callback) or inspect.iscoroutine(callback): - await callback(source, message, *args, **kwargs) + await callback(addonevent) else: - callback(source, message, *args, **kwargs) + callback(addonevent) + if self._verbose: logging.info(f"EventManager | Triggering callback for event type: {event_type.__name__}") except Exception as e: - logging.exception(f"Error executing callback for {handler_info}: {e}") - else: - logging.error(f"No callbacks registered for event {handler_info}") - - async def get_event_callbacks(self, event_name): - """Returns the callbacks for a specific event.""" - return self._event_callbacks[event_name] + logging.exception(f"EventManager | Error in callback for AddonEvent {event_type.__name__}: {e}") - def get_event_callbacks_names(self): - """Returns the names of the registered events.""" - return self._event_callbacks.keys() diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 8cf4e9e47..84349e8b7 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -7,12 +7,10 @@ import requests -from nebula.addons.mobility import Mobility from nebula.core.network.blacklist import BlackList from nebula.core.network.connection import Connection from 
nebula.core.network.discoverer import Discoverer from nebula.core.network.externalconnection.externalconnectionservice import factory_connection_service -from nebula.core.network.networksimulation.networksimulator import factory_network_simulator from nebula.core.network.forwarder import Forwarder from nebula.core.network.messages import MessageEvent, MessagesManager from nebula.core.network.propagator import Propagator @@ -58,7 +56,6 @@ def __init__(self, engine: "Engine"): # self._health = Health(addr=self.addr, config=self.config, cm=self) self._forwarder = Forwarder(config=self.config, cm=self) self._propagator = Propagator(cm=self) - self._mobility = Mobility(config=self.config, cm=self) # List of connections to reconnect {addr: addr, tries: 0} self.connections_reconnect = [] @@ -75,10 +72,6 @@ def __init__(self, engine: "Engine"): # Connection service to communicate with external devices self._external_connection_service = factory_connection_service("nebula", self, self.addr) - # Network simulator service to deplay realistic network conditions - refresh_conditions_interval = 5 - self._network_simulator = factory_network_simulator("nebula", self, refresh_conditions_interval, "eth0", verbose=True) - @property def engine(self): return self._engine @@ -107,17 +100,17 @@ def forwarder(self): def propagator(self): return self._propagator - @property - def mobility(self): - return self._mobility + # @property + # def mobility(self): + # return self._mobility @property def ecs(self): return self._external_connection_service - @property - def ns(self): - return self._network_simulator + # @property + # def ns(self): + # return self._network_simulator @property def bl(self): @@ -263,157 +256,25 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr discovers_sent += 1 return discovers_sent - """ ############################## - # NETWORK CONDITIONS # - ############################## - """ - - async def 
get_network_conditions_grace_time(self): - return await self.config.participant["mobility_args"]["change_geo_interval"] - - def _generate_network_conditions(self): - # TODO: Implement selection of network conditions from frontend - if self.config.participant["network_args"]["simulation"]: - interface = self.config.participant["network_args"]["interface"] - bandwidth = self.config.participant["network_args"]["bandwidth"] - delay = self.config.participant["network_args"]["delay"] - delay_distro = self.config.participant["network_args"]["delay-distro"] - delay_distribution = self.config.participant["network_args"]["delay-distribution"] - loss = self.config.participant["network_args"]["loss"] - duplicate = self.config.participant["network_args"]["duplicate"] - corrupt = self.config.participant["network_args"]["corrupt"] - reordering = self.config.participant["network_args"]["reordering"] - logging.info( - f"🌐 Network simulation is enabled | Interface: {interface} | Bandwidth: {bandwidth} | Delay: {delay} | Delay Distro: {delay_distro} | Delay Distribution: {delay_distribution} | Loss: {loss} | Duplicate: {duplicate} | Corrupt: {corrupt} | Reordering: {reordering}" - ) - try: - results = subprocess.run( - [ - "tcset", - str(interface), - "--rate", - str(bandwidth), - "--delay", - str(delay), - "--delay-distro", - str(delay_distro), - "--delay-distribution", - str(delay_distribution), - "--loss", - str(loss), - "--duplicate", - str(duplicate), - "--corrupt", - str(corrupt), - "--reordering", - str(reordering), - ], - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - except Exception as e: - logging.exception(f"🌐 Network simulation error: {e}") - return - else: - logging.info("🌐 Network simulation is disabled. 
Using default network conditions...") - - def _reset_network_conditions(self): - interface = self.config.participant["network_args"]["interface"] - logging.info("🌐 Resetting network conditions") - try: - results = subprocess.run( - ["tcdel", str(interface), "--all"], - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - except Exception as e: - logging.exception(f"❗️ Network simulation error: {e}") - return - - async def set_network_conditions(self, addr, distance): - await self.ns.set_network_conditions(addr, distance) - #self._set_network_conditions(self, interface, network, bandwidth, delay, delay_distro, delay_distribution, loss, duplicate, corrupt, reordering) - - def clear_network_conditions(self): - self.ns.clear_network_conditions() - - async def set_network_conditions_thresholds(self, thresholds : dict): - await self.ns.set_thresholds(thresholds) - - def _set_network_conditions( - self, - interface="eth0", - network="192.168.50.2", - bandwidth="5Gbps", - delay="0ms", - delay_distro="0ms", - delay_distribution="normal", - loss="0%", - duplicate="0%", - corrupt="0%", - reordering="0%", - ): - logging.info( - f"🌐 Changing network conditions | Interface: {interface} | Network: {network} | Bandwidth: {bandwidth} | Delay: {delay} | Delay Distro: {delay_distro} | Delay Distribution: {delay_distribution} | Loss: {loss} | Duplicate: {duplicate} | Corrupt: {corrupt} | Reordering: {reordering}" - ) - try: - results = subprocess.run( - [ - "tcset", - str(interface), - "--network", - str(network) if network is not None else "", - "--rate", - str(bandwidth), - "--delay", - str(delay), - "--delay-distro", - str(delay_distro), - "--delay-distribution", - str(delay_distribution), - "--loss", - str(loss), - "--duplicate", - str(duplicate), - "--corrupt", - str(corrupt), - "--reordering", - str(reordering), - "--change", - ], - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - except Exception as e: - 
logging.exception(f"❗️ Network simulation error: {e}") - return - - - + """ ############################## # OTHER FUNCTIONALITIES # ############################## """ - async def update_geolocalization(self, geoloc : dict): - async with self.get_connections_lock(): - #logging.info("Update geolocs to simulate network conditions") - for source in geoloc.keys(): - latitude, longitude = geoloc[source] - #logging.info(f"Update geolocs for source: {source}, geoloc: ({latitude},{longitude})") - if source in self.connections: - self.connections[source].update_geolocation(latitude, longitude) - else: # When not connected to device yet - #logging.info(f"Update conditions for not already connected source: {source})") - if self.config.participant["network_args"]["simulation"]: - distance = await self.engine.calculate_distance(latitude, longitude) - await self.ns.set_network_conditions(source, distance) + # async def update_geolocalization(self, geoloc : dict): + # async with self.get_connections_lock(): + # #logging.info("Update geolocs to simulate network conditions") + # for source in geoloc.keys(): + # latitude, longitude = geoloc[source] + # #logging.info(f"Update geolocs for source: {source}, geoloc: ({latitude},{longitude})") + # if source in self.connections: + # self.connections[source].update_geolocation(latitude, longitude) + # else: # When not connected to device yet + # #logging.info(f"Update conditions for not already connected source: {source})") + # if self.config.participant["network_args"]["simulation"]: + # distance = await self.engine.calculate_distance(latitude, longitude) + # await self.ns.set_network_conditions(source, distance) def get_connections_lock(self): return self.connections_lock @@ -607,11 +468,12 @@ async def deploy_additional_services(self): await self._forwarder.start() if self.config.participant["mobility_args"]["mobility"]: if self.config.participant["network_args"]["simulation"]: - await self.ns.start() + pass + #await self.ns.start() # 
await self._discoverer.start() # await self._health.start() self._propagator.start() - await self._mobility.start() + #await self._mobility.start() async def include_received_message_hash(self, hash_message): diff --git a/nebula/core/network/connection.py b/nebula/core/network/connection.py index 3794bdee4..8ea31e62b 100755 --- a/nebula/core/network/connection.py +++ b/nebula/core/network/connection.py @@ -58,8 +58,6 @@ def __init__( self.config = config self.federated_round = Connection.DEFAULT_FEDERATED_ROUND - self.latitude = None - self.longitude = None self.loop = asyncio.get_event_loop() self.read_task = None self.process_task = None @@ -112,32 +110,6 @@ def get_tunnel_status(self): def update_round(self, federated_round): self.federated_round = federated_round - def update_geolocation(self, latitude, longitude): - self.latitude = latitude - self.longitude = longitude - self.config.participant["mobility_args"]["neighbors_distance"][self.addr] = self.compute_distance_myself() - - def get_geolocation(self): - if self.latitude is None or self.longitude is None: - raise ValueError("Geo-location not set for this neighbor") - return self.latitude, self.longitude - - def get_neighbor_distance(self): - if self.addr not in self.config.participant["mobility_args"]["neighbors_distance"]: - return None - return self.config.participant["mobility_args"]["neighbors_distance"][self.addr] - - def compute_distance(self, latitude, longitude): - distance_m = distance.distance((self.latitude, self.longitude), (latitude, longitude)).m - return distance_m - - def compute_distance_myself(self): - distance_m = self.compute_distance( - self.config.participant["mobility_args"]["latitude"], - self.config.participant["mobility_args"]["longitude"], - ) - return distance_m - def get_ready(self): return True if self.federated_round != Connection.DEFAULT_FEDERATED_ROUND else False diff --git a/nebula/core/network/externalconnection/nebuladiscoveryservice.py 
b/nebula/core/network/externalconnection/nebuladiscoveryservice.py index 037c9b535..0b4742832 100644 --- a/nebula/core/network/externalconnection/nebuladiscoveryservice.py +++ b/nebula/core/network/externalconnection/nebuladiscoveryservice.py @@ -129,9 +129,9 @@ async def start(self): logging.info("[NebulaBeacon]: Starting sending pressence beacon") self.running = True while self.running: - await self.send_beacon() await asyncio.sleep(self.interval) - + await self.send_beacon() + async def stop(self): logging.info("[NebulaBeacon]: Stop existance beacon") self.running = False diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 95227b717..f341ec857 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -5,7 +5,6 @@ from nebula.addons.functions import print_msg_box from nebula.core.situationalawareness.awareness.sanetwork.sanetwork import SANetwork from nebula.core.situationalawareness.awareness.satraining.satraining import SATraining -from nebula.core.situationalawareness.awareness.GPS.gpsmodule import factory_gpsmodule from nebula.core.utils.locker import Locker if TYPE_CHECKING: @@ -34,7 +33,6 @@ def __init__( self._situational_awareness_trainning = SATraining(self, "hybrid", "fastreboot") self._restructure_process_lock = Locker(name="restructure_process_lock") self._restructure_cooldown = 0 - self._gpsmodule = factory_gpsmodule("nebula", self, self._addr) @property def nm(self): @@ -48,13 +46,9 @@ def san(self): def cm(self): return self.nm.engine.cm - @property - def gps(self): - return self._gpsmodule async def init(self): #if not self.is_additional_participant(): - await self.gps.start() await self.san.init() def is_additional_participant(self): @@ -69,21 +63,8 @@ async def get_geoloc(self): return (latitude, longitude) async def mobility_actions(self): - await self.verify_gps_service() await 
self.san.module_actions() - """ ############################### - # GPS SERVICE # - ############################### - """ - - async def calculate_distance(self, other_latitude, other_longitude): - self_lat, self_long = await self.get_geoloc() - return await self.gps.calculate_distance(self_lat, self_long, other_latitude, other_longitude) - - async def verify_gps_service(self): - if not await self.gps.is_running(): - await self.gps.start() """ ############################### # REESTRUCTURE TOPOLOGY # diff --git a/nebula/core/situationalawareness/awareness/satraining/satraining.py b/nebula/core/situationalawareness/awareness/satraining/satraining.py index aa0139b39..ac0f58a9f 100644 --- a/nebula/core/situationalawareness/awareness/satraining/satraining.py +++ b/nebula/core/situationalawareness/awareness/satraining/satraining.py @@ -22,5 +22,5 @@ def __init__( title="Training SA module", ) self._sam = sam - self._trainning_policy = factory_training_policy(training_policy) + #self._trainning_policy = factory_training_policy(training_policy) self._weight_strategies = weight_strategies \ No newline at end of file diff --git a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py index f49a379c1..ca05d022e 100644 --- a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -7,7 +7,6 @@ from nebula.core.situationalawareness.fastreboot import FastReboot from nebula.core.situationalawareness.modelhandlers.modelhandler import factory_ModelHandler from nebula.core.situationalawareness.momentum import Momentum -#from nebula.core.situationalawareness.neighborpolicies.neighborpolicy import factory_NeighborPolicy from nebula.core.situationalawareness.awareness.samodule import SAModule from nebula.core.utils.locker import Locker @@ -112,9 +111,6 @@ async def set_configs(self): async def get_geoloc(self): return await self.sam.get_geoloc() - async def calculate_distance(self, latitude, 
longitude): - return await self.sam.calculate_distance(latitude, longitude) - async def mobility_actions(self): await self.sam.mobility_actions() diff --git a/nebula/frontend/config/participant.json.example b/nebula/frontend/config/participant.json.example index db3ec0d34..3e2b916b9 100755 --- a/nebula/frontend/config/participant.json.example +++ b/nebula/frontend/config/participant.json.example @@ -41,7 +41,7 @@ "addr": "", "neighbors": "", "interface": "eth0", - "simulation": false, + "simulation": true, "bandwidth": "5Gbps", "delay": "0ms", "delay-distro": "0ms", From b24468434603e74d9697a6dbc1c9f81fd77e0b6a Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 27 Feb 2025 15:36:53 +0100 Subject: [PATCH 120/233] feat aggregation event --- nebula/core/aggregation/aggregator.py | 30 ++++++++----- nebula/core/eventmanager.py | 44 ++++++++++++++++++- nebula/core/network/communications.py | 16 +------ .../frontend/config/participant.json.example | 2 +- 4 files changed, 63 insertions(+), 29 deletions(-) diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index f395f140f..f433d1a73 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -4,16 +4,31 @@ from functools import partial from nebula.core.utils.locker import Locker from nebula.core.aggregation.updatehandlers.updatehandler import factory_update_handler +from nebula.core.eventmanager import NodeEvent from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.engine import Engine +class AggregationEvent(NodeEvent): + def __init__(self, updates : dict, expected_nodes : set, missing_nodes : set): + self._updates = updates + self._expected_nodes = expected_nodes + self._missing_nodes = missing_nodes + + def __str__(self): + return "Aggregation Ready" + + async def get_event_data(self) -> tuple[dict, set, set]: + return (self._updates, self._expected_nodes, self._missing_nodes) + + async def is_concurrent(self) -> bool: + 
return False + class AggregatorException(Exception): pass - def create_target_aggregator(config, engine): from nebula.core.aggregation.fedavg import FedAvg from nebula.core.aggregation.krum import Krum @@ -94,16 +109,6 @@ def get_nodes_pending_models_to_aggregate(self): def set_waiting_global_update(self): self._waiting_global_update = True - # async def reset(self): - # await self._add_model_lock.acquire_async() - # self._federation_nodes.clear() - # self._pending_models_to_aggregate.clear() - # try: - # await self._aggregation_done_lock.release_async() - # except: - # pass - # await self._add_model_lock.release_async() - async def get_aggregation(self): try: timeout = self.config.participant["aggregator_args"]["aggregation_timeout"] @@ -154,7 +159,8 @@ async def get_aggregation(self): ) await self.cm.send_message_to_neighbors(message) - updates = await self.engine.apply_weight_strategy(updates) + agg_event = AggregationEvent(updates, self._federation_nodes, missing_nodes) + await self.engine.event_manager.publish_nodeevent(agg_event) aggregated_result = self.run_aggregation(updates) return aggregated_result diff --git a/nebula/core/eventmanager.py b/nebula/core/eventmanager.py index 2566fdf4b..e59f69e21 100755 --- a/nebula/core/eventmanager.py +++ b/nebula/core/eventmanager.py @@ -11,12 +11,24 @@ class AddonEvent(ABC): @abstractmethod async def get_event_data(self): pass + +class NodeEvent(ABC): + @abstractmethod + async def get_event_data(self): + pass + + @abstractmethod + async def is_concurrent(self): + pass +#TODO pto unico de llamada class EventManager: def __init__(self, verbose=False): self._subscribers: dict[tuple[str, str], list] = {} self._addons_events_subs : dict [AddonEvent, list] = {} self._addons_event_lock = Locker("addons_event_lock", async_lock=True) + self._node_events_subs : dict [NodeEvent, list] = {} + self._node_events_lock = Locker("node_events_lock", async_lock=True) self._verbose = verbose def subscribe(self, event_type: tuple[str, 
str], callback: callable): @@ -69,5 +81,35 @@ async def publish_addonevent(self, addonevent: AddonEvent): callback(addonevent) if self._verbose: logging.info(f"EventManager | Triggering callback for event type: {event_type.__name__}") except Exception as e: - logging.exception(f"EventManager | Error in callback for AddonEvent {event_type.__name__}: {e}") + logging.exception(f"EventManager | Error in callback for AddonEvent {event_type.__name__}: {e}") + + + async def subscribe_nodeevent(self, nodeEventType: type[NodeEvent], callback: callable): + """Register a callback for a specific type of AddonEvent.""" + async with self._node_events_lock: + if nodeEventType not in self._node_events_subs: + self._node_events_subs[nodeEventType] = [] + self._node_events_subs[nodeEventType].append(callback) + logging.info(f"EventManager | Subscribed callback for NodeEvent type: {nodeEventType.__name__}") + + async def publish_nodeevent(self, nodeevent: NodeEvent): + """Trigger all callbacks registered for a specific type of AddonEvent.""" + if self._verbose: logging.info(f"Publishing NodeEvent: {nodeevent}") + async with self._node_events_lock: + event_type = type(nodeevent) + if event_type not in self._node_events_subs: + if self._verbose: logging.error(f"EventManager | No subscribers for NodeEvent type: {event_type.__name__}") + return + for callback in self._node_events_subs[event_type]: + try: + if asyncio.iscoroutinefunction(callback) or inspect.iscoroutine(callback): + if await nodeevent.is_concurrent(): + asyncio.create_task(callback(nodeevent)) + else: + await callback(nodeevent) + else: + callback(nodeevent) + if self._verbose: logging.info(f"EventManager | Triggering callback for event type: {event_type.__name__}") + except Exception as e: + logging.exception(f"EventManager | Error in callback for NodeEvent {event_type.__name__}: {e}") diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 84349e8b7..56f7ffe77 100755 --- 
a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -261,21 +261,7 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr # OTHER FUNCTIONALITIES # ############################## """ - - # async def update_geolocalization(self, geoloc : dict): - # async with self.get_connections_lock(): - # #logging.info("Update geolocs to simulate network conditions") - # for source in geoloc.keys(): - # latitude, longitude = geoloc[source] - # #logging.info(f"Update geolocs for source: {source}, geoloc: ({latitude},{longitude})") - # if source in self.connections: - # self.connections[source].update_geolocation(latitude, longitude) - # else: # When not connected to device yet - # #logging.info(f"Update conditions for not already connected source: {source})") - # if self.config.participant["network_args"]["simulation"]: - # distance = await self.engine.calculate_distance(latitude, longitude) - # await self.ns.set_network_conditions(source, distance) - + def get_connections_lock(self): return self.connections_lock diff --git a/nebula/frontend/config/participant.json.example b/nebula/frontend/config/participant.json.example index 3e2b916b9..db3ec0d34 100755 --- a/nebula/frontend/config/participant.json.example +++ b/nebula/frontend/config/participant.json.example @@ -41,7 +41,7 @@ "addr": "", "neighbors": "", "interface": "eth0", - "simulation": true, + "simulation": false, "bandwidth": "5Gbps", "delay": "0ms", "delay-distro": "0ms", From 8f957c1c7fef947d301ca8528bf9dbebb0290ce9 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 28 Feb 2025 17:21:38 +0100 Subject: [PATCH 121/233] feature event system integrated --- nebula/addons/GPS/gpsmodule.py | 14 +-- nebula/addons/GPS/nebulagps.py | 14 +-- nebula/addons/mobility.py | 7 +- .../nebulanetworksimulator.py | 13 +- .../networksimulation/networksimulator.py | 4 +- nebula/core/addonmanager.py | 6 +- nebula/core/aggregation/aggregator.py | 27 +---- 
nebula/core/engine.py | 36 +++--- nebula/core/eventmanager.py | 113 +++++++++++------- nebula/core/nebulaevents.py | 47 ++++++++ nebula/core/network/communications.py | 10 +- nebula/core/network/messages.py | 7 +- .../awareness/samodule.py | 10 +- .../awareness/sanetwork/sanetwork.py | 1 + .../awareness/satraining/satraining.py | 21 +++- .../trainingpolicy/bpstrainingpolicy.py | 11 +- .../trainingpolicy/htstrainingpolicy.py | 12 +- .../trainingpolicy/qdstrainingpolicy.py | 55 ++++++++- .../trainingpolicy/sostrainingpolicy.py | 12 +- .../trainingpolicy/trainingpolicy.py | 11 +- 20 files changed, 286 insertions(+), 145 deletions(-) create mode 100644 nebula/core/nebulaevents.py diff --git a/nebula/addons/GPS/gpsmodule.py b/nebula/addons/GPS/gpsmodule.py index 676af7dda..41e80663a 100644 --- a/nebula/addons/GPS/gpsmodule.py +++ b/nebula/addons/GPS/gpsmodule.py @@ -20,20 +20,10 @@ async def is_running(self): async def calculate_distance(self, self_lat, self_long, other_lat, other_long): pass -class GPSEvent(AddonEvent): - def __init__(self, distances : dict): - self.distances = distances - - def __str__(self): - return "GPSEvent" - - async def get_event_data(self) -> dict: - return self.distances.copy() - class GPSModuleException(Exception): pass -def factory_gpsmodule(gps_module, config, event_manager, addr, update_interval: float = 5.0, verbose=False) -> GPSModule: +def factory_gpsmodule(gps_module, config, addr, update_interval: float = 5.0, verbose=False) -> GPSModule: from nebula.addons.GPS.nebulagps import NebulaGPS GPS_SERVICES = { @@ -43,6 +33,6 @@ def factory_gpsmodule(gps_module, config, event_manager, addr, update_interval: gps_module = GPS_SERVICES.get(gps_module, NebulaGPS) if gps_module: - return gps_module(config, event_manager, addr, update_interval, verbose) + return gps_module(config, addr, update_interval, verbose) else: raise GPSModuleException(f"GPS Module {gps_module} not found") \ No newline at end of file diff --git 
a/nebula/addons/GPS/nebulagps.py b/nebula/addons/GPS/nebulagps.py index e59ecfc89..0e5759e9d 100644 --- a/nebula/addons/GPS/nebulagps.py +++ b/nebula/addons/GPS/nebulagps.py @@ -1,6 +1,7 @@ import asyncio import logging -from nebula.addons.GPS.gpsmodule import GPSModule, GPSEvent +from nebula.addons.GPS.gpsmodule import GPSModule +from nebula.core.nebulaevents import GPSEvent import socket from nebula.core.utils.locker import Locker from geopy import distance @@ -11,21 +12,16 @@ class NebulaGPS(GPSModule): BROADCAST_PORT = 50001 # Poort used for GPS INTERFACE = "eth2" # Interface to avoid network conditions - def __init__(self, config, event_manager : EventManager, addr, update_interval: float = 5.0, verbose=False): + def __init__(self, config, addr, update_interval: float = 5.0, verbose=False): self._config = config self._addr = addr - self._event_manager = event_manager self.update_interval = update_interval # Frecuencia de emisiΓ³n self.running = False self._node_locations = {} # Diccionario para almacenar ubicaciones de nodos self._broadcast_socket = None self._nodes_location_lock = Locker("nodes_location_lock", async_lock=True) self._verbose = verbose - - @property - def em(self): - return self._event_manager - + async def start(self): """Inicia el servicio de GPS, enviando y recibiendo ubicaciones.""" logging.info("Starting NebulaGPS service...") @@ -102,5 +98,5 @@ async def _notify_geolocs(self): dist = await self.calculate_distance(self_lat, self_long, lat, long) distances[addr] = (dist,(lat, long)) gpsevent = GPSEvent(distances) - asyncio.create_task(self.em.publish_addonevent(gpsevent)) + asyncio.create_task(EventManager.get_instance().publish_addonevent(gpsevent)) \ No newline at end of file diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index 72f5ffdd7..a52dfec7f 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -4,7 +4,7 @@ import random import time from nebula.core.eventmanager import EventManager -from 
nebula.addons.GPS.gpsmodule import GPSEvent +from nebula.core.nebulaevents import GPSEvent from nebula.core.utils.locker import Locker from nebula.addons.functions import print_msg_box from typing import TYPE_CHECKING @@ -13,7 +13,7 @@ class Mobility: - def __init__(self, config, cm: "CommunicationsManager", event_manager: EventManager): + def __init__(self, config, cm: "CommunicationsManager"): """ Initializes the mobility module with specified configuration and communication manager. @@ -66,7 +66,6 @@ def __init__(self, config, cm: "CommunicationsManager", event_manager: EventMana # Logging box with mobility information mobility_msg = f"Mobility: {self.mobility}\nMobility type: {self.mobility_type}\nRadius federation: {self.radius_federation}\nScheme mobility: {self.scheme_mobility}\nEach {self.round_frequency} rounds" print_msg_box(msg=mobility_msg, indent=2, title="Mobility information") - self._event_manager = event_manager self._nodes_distances = {} self._nodes_distances_lock = Locker("nodes_distances_lock", async_lock=True) @@ -98,7 +97,7 @@ async def start(self): asyncio.Task: An asyncio Task object representing the scheduled `run_mobility` operation. 
""" - await self._event_manager.subscribe_addonevent(GPSEvent, self.update_nodes_distances) + await EventManager.get_instance().subscribe_addonevent(GPSEvent, self.update_nodes_distances) task = asyncio.create_task(self.run_mobility()) return task diff --git a/nebula/addons/networksimulation/nebulanetworksimulator.py b/nebula/addons/networksimulation/nebulanetworksimulator.py index 152e83b65..baf350806 100644 --- a/nebula/addons/networksimulation/nebulanetworksimulator.py +++ b/nebula/addons/networksimulation/nebulanetworksimulator.py @@ -4,7 +4,7 @@ from nebula.addons.networksimulation.networksimulator import NetworkSimulator from nebula.core.utils.locker import Locker from nebula.core.eventmanager import EventManager -from nebula.addons.GPS.gpsmodule import GPSEvent +from nebula.core.nebulaevents import GPSEvent from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.network.communications import CommunicationsManager @@ -18,8 +18,7 @@ class NebulaNS(NetworkSimulator): } IP_MULTICAST = "239.255.255.250" - def __init__(self, event_manager : EventManager, communication_manager: "CommunicationsManager", changing_interval, interface, verbose=False): - self._event_manager = event_manager + def __init__(self, communication_manager: "CommunicationsManager", changing_interval, interface, verbose=False): self._cm = communication_manager self._refresh_interval = changing_interval self._node_interface = interface @@ -28,18 +27,14 @@ def __init__(self, event_manager : EventManager, communication_manager: "Communi self._network_conditions_lock = Locker("network_conditions_lock", async_lock=True) self._current_network_conditions = {} self._running = False - - @property - def em(self): - return self._event_manager - + async def start(self): logging.info("🌐 Nebula Network Simulator starting...") self._running = True grace_time = self._cm.config.participant["mobility_args"]["grace_time_mobility"] # if self._verbose: logging.info(f"Waiting {grace_time}s to start 
applying network conditions based on distances between devices") # await asyncio.sleep(grace_time) - await self.em.subscribe_addonevent(GPSEvent, self._change_network_conditions_based_on_distances) + await EventManager.get_instance().subscribe_addonevent(GPSEvent, self._change_network_conditions_based_on_distances) async def stop(self): self._running = False diff --git a/nebula/addons/networksimulation/networksimulator.py b/nebula/addons/networksimulation/networksimulator.py index d468b2bc5..ab63378c8 100644 --- a/nebula/addons/networksimulation/networksimulator.py +++ b/nebula/addons/networksimulation/networksimulator.py @@ -27,7 +27,7 @@ def clear_network_conditions(self, interface): class NetworkSimulatorException(Exception): pass -def factory_network_simulator(net_sim, event_manager, communication_manager, changing_interval, interface, verbose) -> NetworkSimulator: +def factory_network_simulator(net_sim, communication_manager, changing_interval, interface, verbose) -> NetworkSimulator: from nebula.addons.networksimulation.nebulanetworksimulator import NebulaNS SIMULATION_SERVICES = { @@ -37,6 +37,6 @@ def factory_network_simulator(net_sim, event_manager, communication_manager, cha net_serv = SIMULATION_SERVICES.get(net_sim, NebulaNS) if net_serv: - return net_serv(event_manager, communication_manager, changing_interval, interface, verbose) + return net_serv(communication_manager, changing_interval, interface, verbose) else: raise NetworkSimulatorException(f"Network Simulator {net_sim} not found") \ No newline at end of file diff --git a/nebula/core/addonmanager.py b/nebula/core/addonmanager.py index 3401fbf90..79fdca82d 100644 --- a/nebula/core/addonmanager.py +++ b/nebula/core/addonmanager.py @@ -17,14 +17,14 @@ def __init__(self, engine : "Engine", config): async def deploy_additional_services(self): print_msg_box(msg="Deploying Additional Services\n(='.'=)", indent=2, title="Addons Manager") if self._config.participant["mobility_args"]["mobility"]: - 
mobility = Mobility(self._config, self._engine.cm, self._engine.event_manager) + mobility = Mobility(self._config, self._engine.cm) self._addons.append(mobility) if self._config.participant["network_args"]["simulation"]: refresh_conditions_interval = 5 - network_simulation = factory_network_simulator("nebula", self._engine.event_manager, self._engine.cm, refresh_conditions_interval, "eth0", verbose=False) + network_simulation = factory_network_simulator("nebula", self._engine.cm, refresh_conditions_interval, "eth0", verbose=True) self._addons.append(network_simulation) update_interval = 5 - gps = factory_gpsmodule("nebula", self._config, self._engine.event_manager, self._engine.addr, update_interval, verbose=False) + gps = factory_gpsmodule("nebula", self._config, self._engine.addr, update_interval, verbose=True) self._addons.append(gps) for add in self._addons: diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index f433d1a73..887b3151b 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -4,28 +4,13 @@ from functools import partial from nebula.core.utils.locker import Locker from nebula.core.aggregation.updatehandlers.updatehandler import factory_update_handler -from nebula.core.eventmanager import NodeEvent +from nebula.core.eventmanager import EventManager +from nebula.core.nebulaevents import AggregationEvent from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.engine import Engine - -class AggregationEvent(NodeEvent): - def __init__(self, updates : dict, expected_nodes : set, missing_nodes : set): - self._updates = updates - self._expected_nodes = expected_nodes - self._missing_nodes = missing_nodes - - def __str__(self): - return "Aggregation Ready" - - async def get_event_data(self) -> tuple[dict, set, set]: - return (self._updates, self._expected_nodes, self._missing_nodes) - - async def is_concurrent(self) -> bool: - return False - class 
AggregatorException(Exception): pass @@ -143,9 +128,7 @@ async def get_aggregation(self): await self.us.stop_notifying_updates() updates = await self.us.get_round_updates() - missing_nodes = await self.us.get_round_missing_nodes() - if missing_nodes: logging.info(f"πŸ”„ get_aggregation | Aggregation incomplete, missing models from: {missing_nodes}") else: @@ -154,13 +137,11 @@ async def get_aggregation(self): logging.info( f"πŸ”„ Broadcasting MODELS_INCLUDED for round {self.engine.get_round()}" ) - message = self.cm.create_message( - "federation", "federation_models_included", [str(arg) for arg in [self.engine.get_round()]] - ) + message = self.cm.create_message("federation", "federation_models_included", [str(arg) for arg in [self.engine.get_round()]]) await self.cm.send_message_to_neighbors(message) agg_event = AggregationEvent(updates, self._federation_nodes, missing_nodes) - await self.engine.event_manager.publish_nodeevent(agg_event) + await EventManager.get_instance().publish_node_event(agg_event) aggregated_result = self.run_aggregation(updates) return aggregated_result diff --git a/nebula/core/engine.py b/nebula/core/engine.py index ea111f179..325e48707 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -146,6 +146,9 @@ def __init__( self.trainning_in_progress_lock = Locker(name="trainning_in_progress_lock", async_lock=True) + event_manager = EventManager.get_instance() + event_manager._initialize(verbose=True) + # Mobility setup self._node_manager = None self.mobility = self.config.participant["mobility_args"]["mobility"] @@ -160,15 +163,6 @@ def __init__( engine=self, ) - self._event_manager = EventManager(verbose=True) - - logging.info("Registering callbacks for MessageEvents...") - self.register_message_events_callbacks() - - # Additional callbacks not registered automatically - self.register_message_callback(("model", "initialization"), "model_initialization_callback") - self.register_message_callback(("model", "update"), 
"model_update_callback") - self._addon_manager = AddondManager(self, self.config) @property @@ -179,10 +173,6 @@ def cm(self): def reporter(self): return self._reporter - @property - def event_manager(self): - return self._event_manager - @property def aggregator(self): return self._aggregator @@ -252,6 +242,14 @@ def set_round(self, new_round): self.round = new_round self.trainer.set_current_round(new_round) + async def init_message_callbacks(self): + logging.info("Registering callbacks for MessageEvents...") + await self.register_message_events_callbacks() + + # Additional callbacks not registered automatically + await self.register_message_callback(("model", "initialization"), "model_initialization_callback") + await self.register_message_callback(("model", "update"), "model_update_callback") + """ ############################## # MODEL CALLBACKS # ############################## @@ -552,7 +550,7 @@ async def _link_disconnect_from_callback(self, source, message): ############################## """ - def register_message_events_callbacks(self): + async def register_message_events_callbacks(self): me_dict = self.cm.get_messages_events() message_events = [ (message_name, message_action) @@ -566,16 +564,13 @@ def register_message_events_callbacks(self): method = getattr(self, callback_name, None) if callable(method): - self.event_manager.subscribe((event_type, action), method) + await EventManager.get_instance().subscribe((event_type, action), method) - def register_message_callback(self, message_event: tuple[str, str], callback: str): + async def register_message_callback(self, message_event: tuple[str, str], callback: str): event_type, action = message_event method = getattr(self, callback, None) if callable(method): - self.event_manager.subscribe((event_type, action), method) - - async def trigger_event(self, message_event): - await self.event_manager.publish(message_event) + await EventManager.get_instance().subscribe((event_type, action), method) async def 
get_geoloc(self): return await self.nm.get_geoloc() @@ -669,6 +664,7 @@ async def create_trainer_module(self): logging.info("Started trainer module...") async def start_communications(self): + await self.init_message_callbacks() logging.info(f"Neighbors: {self.config.participant['network_args']['neighbors']}") logging.info( f"πŸ’€ Cold start time: {self.config.participant['misc_args']['grace_time_connection']} seconds before connecting to the network" diff --git a/nebula/core/eventmanager.py b/nebula/core/eventmanager.py index e59f69e21..23fed4004 100755 --- a/nebula/core/eventmanager.py +++ b/nebula/core/eventmanager.py @@ -21,28 +21,54 @@ async def get_event_data(self): async def is_concurrent(self): pass -#TODO pto unico de llamada + class EventManager: - def __init__(self, verbose=False): + _instance = None + _lock = Locker("event_manager") # Para evitar condiciones de carrera en entornos multihilo + + def __new__(cls, *args, **kwargs): + """ImplementaciΓ³n del patrΓ³n Singleton.""" + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialize(*args, **kwargs) + return cls._instance + + def _initialize(self, verbose=False): + """Inicializa la instancia ΓΊnica (solo se ejecuta una vez).""" + if hasattr(self, "_initialized"): # Evita reinicializaciΓ³n + return self._subscribers: dict[tuple[str, str], list] = {} - self._addons_events_subs : dict [AddonEvent, list] = {} + self._message_events_lock = Locker("message_events_lock", async_lock=True) + self._addons_events_subs: dict[type, list] = {} self._addons_event_lock = Locker("addons_event_lock", async_lock=True) - self._node_events_subs : dict [NodeEvent, list] = {} + self._node_events_subs: dict[type, list] = {} self._node_events_lock = Locker("node_events_lock", async_lock=True) self._verbose = verbose + self._initialized = True # Marca que ya se inicializΓ³ + + @staticmethod + def get_instance(): + """MΓ©todo estΓ‘tico para obtener la instancia ΓΊnica.""" + if 
EventManager._instance is None: + EventManager() + return EventManager._instance - def subscribe(self, event_type: tuple[str, str], callback: callable): + async def subscribe(self, event_type: tuple[str, str], callback: callable): """Register a callback for a specific event type.""" - if event_type not in self._subscribers: - self._subscribers[event_type] = [] - self._subscribers[event_type].append(callback) + async with self._message_events_lock: + if event_type not in self._subscribers: + self._subscribers[event_type] = [] + self._subscribers[event_type].append(callback) logging.info(f"EventManager | Subscribed callback for event: {event_type}") async def publish(self, message_event: MessageEvent): """Trigger all callbacks registered for a specific event type.""" - if self._verbose: logging.info(f"Publishing MessageEvent: {message_event.message_type}") - event_type = message_event.message_type - if event_type not in self._subscribers: + if self._verbose or True: logging.info(f"Publishing MessageEvent: {message_event.message_type}") + async with self._message_events_lock: + event_type = message_event.message_type + callbacks = self._subscribers.get(event_type, []) + if not callbacks: logging.error(f"EventManager | No subscribers for event: {event_type}") return @@ -66,25 +92,27 @@ async def subscribe_addonevent(self, addonEventType: type[AddonEvent], callback: async def publish_addonevent(self, addonevent: AddonEvent): """Trigger all callbacks registered for a specific type of AddonEvent.""" - if self._verbose: logging.info(f"Publishing AddonEvent: {addonevent}") + if self._verbose or True: logging.info(f"Publishing AddonEvent: {addonevent}") async with self._addons_event_lock: event_type = type(addonevent) - if event_type not in self._addons_events_subs: - logging.error(f"EventManager | No subscribers for AddonEvent type: {event_type.__name__}") - return + callbacks = self._addons_events_subs.get(event_type, []) - for callback in 
self._addons_events_subs[event_type]: - try: - if asyncio.iscoroutinefunction(callback) or inspect.iscoroutine(callback): - await callback(addonevent) - else: - callback(addonevent) - if self._verbose: logging.info(f"EventManager | Triggering callback for event type: {event_type.__name__}") - except Exception as e: - logging.exception(f"EventManager | Error in callback for AddonEvent {event_type.__name__}: {e}") + if not callbacks: + logging.error(f"EventManager | No subscribers for AddonEvent type: {event_type.__name__}") + return + + for callback in self._addons_events_subs[event_type]: + try: + if asyncio.iscoroutinefunction(callback) or inspect.iscoroutine(callback): + await callback(addonevent) + else: + callback(addonevent) + if self._verbose: logging.info(f"EventManager | Triggering callback for event type: {event_type.__name__}") + except Exception as e: + logging.exception(f"EventManager | Error in callback for AddonEvent {event_type.__name__}: {e}") - async def subscribe_nodeevent(self, nodeEventType: type[NodeEvent], callback: callable): + async def subscribe_node_event(self, nodeEventType: type[NodeEvent], callback: callable): """Register a callback for a specific type of AddonEvent.""" async with self._node_events_lock: if nodeEventType not in self._node_events_subs: @@ -92,24 +120,27 @@ async def subscribe_nodeevent(self, nodeEventType: type[NodeEvent], callback: ca self._node_events_subs[nodeEventType].append(callback) logging.info(f"EventManager | Subscribed callback for NodeEvent type: {nodeEventType.__name__}") - async def publish_nodeevent(self, nodeevent: NodeEvent): + async def publish_node_event(self, nodeevent: NodeEvent): """Trigger all callbacks registered for a specific type of AddonEvent.""" - if self._verbose: logging.info(f"Publishing NodeEvent: {nodeevent}") + if self._verbose or True: logging.info(f"Publishing NodeEvent: {nodeevent}") async with self._node_events_lock: event_type = type(nodeevent) - if event_type not in 
self._node_events_subs: - if self._verbose: logging.error(f"EventManager | No subscribers for NodeEvent type: {event_type.__name__}") - return + callbacks = self._node_events_subs.get(event_type, []) # Extraer la lista de callbacks - for callback in self._node_events_subs[event_type]: - try: - if asyncio.iscoroutinefunction(callback) or inspect.iscoroutine(callback): - if await nodeevent.is_concurrent(): - asyncio.create_task(callback(nodeevent)) - else: - await callback(nodeevent) + if not callbacks: + if self._verbose: + logging.error(f"EventManager | No subscribers for NodeEvent type: {event_type.__name__}") + return + + for callback in self._node_events_subs[event_type]: + try: + if asyncio.iscoroutinefunction(callback) or inspect.iscoroutine(callback): + if await nodeevent.is_concurrent(): + asyncio.create_task(callback(nodeevent)) else: - callback(nodeevent) - if self._verbose: logging.info(f"EventManager | Triggering callback for event type: {event_type.__name__}") - except Exception as e: - logging.exception(f"EventManager | Error in callback for NodeEvent {event_type.__name__}: {e}") + await callback(nodeevent) + else: + callback(nodeevent) + if self._verbose: logging.info(f"EventManager | Triggering callback for event type: {event_type.__name__}") + except Exception as e: + logging.exception(f"EventManager | Error in callback for NodeEvent {event_type.__name__}: {e}") diff --git a/nebula/core/nebulaevents.py b/nebula/core/nebulaevents.py new file mode 100644 index 000000000..5d66c4dda --- /dev/null +++ b/nebula/core/nebulaevents.py @@ -0,0 +1,47 @@ +from abc import ABC, abstractmethod +import asyncio + +class AddonEvent(ABC): + @abstractmethod + async def get_event_data(self): + pass + +class NodeEvent(ABC): + @abstractmethod + async def get_event_data(self): + pass + + @abstractmethod + async def is_concurrent(self): + pass + +class MessageEvent: + def __init__(self, message_type, source, message): + self.source = source + self.message_type = 
message_type + self.message = message + +class AggregationEvent(NodeEvent): + def __init__(self, updates : dict, expected_nodes : set, missing_nodes : set): + self._updates = updates + self._expected_nodes = expected_nodes + self._missing_nodes = missing_nodes + + def __str__(self): + return "Aggregation Ready" + + async def get_event_data(self) -> tuple[dict, set, set]: + return (self._updates, self._expected_nodes, self._missing_nodes) + + async def is_concurrent(self) -> bool: + return False + +class GPSEvent(AddonEvent): + def __init__(self, distances : dict): + self.distances = distances + + def __str__(self): + return "GPSEvent" + + async def get_event_data(self) -> dict: + return self.distances.copy() \ No newline at end of file diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 56f7ffe77..3138e16ee 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -12,9 +12,11 @@ from nebula.core.network.discoverer import Discoverer from nebula.core.network.externalconnection.externalconnectionservice import factory_connection_service from nebula.core.network.forwarder import Forwarder -from nebula.core.network.messages import MessageEvent, MessagesManager +from nebula.core.network.messages import MessagesManager +from nebula.core.nebulaevents import MessageEvent from nebula.core.network.propagator import Propagator from nebula.core.utils.locker import Locker +from nebula.core.eventmanager import EventManager if TYPE_CHECKING: from nebula.core.engine import Engine @@ -141,16 +143,16 @@ async def forward_message(self, data, addr_from): await self.forwarder.forward(data, addr_from=addr_from) async def handle_message(self, message_event): - await self.engine.trigger_event(message_event) + asyncio.create_task(EventManager.get_instance().publish(message_event)) async def handle_model_message(self, source, message): logging.info(f"πŸ€– handle_model_message | Received model from {source} 
with round {message.round}") if message.round == -1: model_init_event = MessageEvent(("model", "initialization"), source, message) - await self.engine.trigger_event(model_init_event) + asyncio.create_task(EventManager.get_instance().publish(model_init_event)) else: model_updt_event = MessageEvent(("model", "update"), source, message) - await self.engine.trigger_event(model_updt_event) + asyncio.create_task(EventManager.get_instance().publish(model_updt_event)) def create_message(self, message_type: str, action: str = "", *args, **kwargs): return self.mm.create_message(message_type, action, *args, **kwargs) diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index ff75cee32..0464bb705 100644 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -5,6 +5,7 @@ from nebula.core.network.actions import factory_message_action, get_action_name_from_value, get_actions_names from nebula.core.pb import nebula_pb2 +from nebula.core.nebulaevents import MessageEvent if TYPE_CHECKING: from nebula.core.network.communications import CommunicationsManager @@ -194,9 +195,3 @@ def create_message(self, message_type: str, action: str = "", *args, **kwargs): data = message_wrapper.SerializeToString() return data - -class MessageEvent: - def __init__(self, message_type, source, message): - self.source = source - self.message_type = message_type - self.message = message diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index f341ec857..1dfbb32ab 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -30,7 +30,7 @@ def __init__( self._topology = topology self._node_manager: NodeManager = nodemanager self._situational_awareness_network = SANetwork(self, self.cm, self._addr, self._topology) - self._situational_awareness_trainning = SATraining(self, "hybrid", "fastreboot") + 
self._situational_awareness_training = SATraining(self, self._addr, "qds", "fastreboot") self._restructure_process_lock = Locker(name="restructure_process_lock") self._restructure_cooldown = 0 @@ -41,6 +41,10 @@ def nm(self): @property def san(self): return self._situational_awareness_network + + @property + def sat(self): + return self._situational_awareness_training @property def cm(self): @@ -50,6 +54,7 @@ def cm(self): async def init(self): #if not self.is_additional_participant(): await self.san.init() + await self.sat.init() def is_additional_participant(self): return self.nm.is_additional_participant() @@ -63,7 +68,8 @@ async def get_geoloc(self): return (latitude, longitude) async def mobility_actions(self): - await self.san.module_actions() + await self.san.module_actions() + await self.sat.module_actions() """ ############################### diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index c0e032f99..071a52504 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -65,6 +65,7 @@ async def init(self): ]) async def module_actions(self): + logging.info("SA Network evaluating current scenario") await self._check_external_connection_service_status() await self._analize_topology_robustness() diff --git a/nebula/core/situationalawareness/awareness/satraining/satraining.py b/nebula/core/situationalawareness/awareness/satraining/satraining.py index ac0f58a9f..8ee28367d 100644 --- a/nebula/core/situationalawareness/awareness/satraining/satraining.py +++ b/nebula/core/situationalawareness/awareness/satraining/satraining.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.situationalawareness.awareness.samodule import SAModule + from nebula.core.eventmanager import EventManager RESTRUCTURE_COOLDOWN = 5 @@ -13,6 +14,7 @@ class 
SATraining(): def __init__( self, sam: "SAModule", + addr, training_policy, weight_strategies ): @@ -22,5 +24,20 @@ def __init__( title="Training SA module", ) self._sam = sam - #self._trainning_policy = factory_training_policy(training_policy) - self._weight_strategies = weight_strategies \ No newline at end of file + config = {} + config["addr"] = addr + self._trainning_policy = factory_training_policy(training_policy, config) + self._weight_strategies = weight_strategies + + @property + def tp(self): + return self._trainning_policy + + async def init(self): + config = {} + config["nodes"] = set(self._sam.get_nodes_known(neighbors_only=True)) + await self.tp.init(config) + + async def module_actions(self): + logging.info("SA Trainng evaluating current scenario") + await self.tp.evaluate() diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py index 9a0b34b86..24769c11a 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py @@ -2,5 +2,14 @@ class BPSTrainingPolicy(TrainingPolicy): - def __init__(): + def __init__(self, config=None): + pass + + async def init(self, config): + pass + + async def update_neighbors(self, node, remove=False): + pass + + async def evaluate(self): pass \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/htstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/htstrainingpolicy.py index 522e5b783..41dcd4f64 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/htstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/htstrainingpolicy.py @@ -1,6 +1,16 @@ from 
nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy +# "Hybrid Training Strategy" (HTS) class HTSTrainingPolicy(TrainingPolicy): - def __init__(): + def __init__(self, config): + pass + + async def init(self, config): + pass + + async def update_neighbors(self, node, remove=False): + pass + + async def evaluate(self): pass \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py index d5a2685c4..4a451298b 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py @@ -1,6 +1,57 @@ from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy +import asyncio +from nebula.core.utils.helper import cosine_metric +from nebula.core.utils.locker import Locker +from collections import deque +import logging +from nebula.core.eventmanager import EventManager +from nebula.core.nebulaevents import AggregationEvent +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from nebula.core.eventmanager import EventManager +# "Quality-Driven Selection" (QDS) class QDSTrainingPolicy(TrainingPolicy): + MAX_HISTORIC_SIZE = 10 + SIMILARITY_THRESHOLD = 0.8 + + def __init__(self, config : dict): + self._addr = config["addr"] + self._nodes : dict[str, deque] = {} + self._nodes_lock = Locker(name="nodes_lock", async_lock=True) + + + async def init(self, config): + async with self._nodes_lock: + nodes = config["nodes"] + self._nodes : dict[str, deque] = {node_id: deque(maxlen=self.MAX_HISTORIC_SIZE) for node_id in nodes} + await EventManager.get_instance().subscribe_node_event(AggregationEvent, self.process_aggregation_event) + + async def update_neighbors(self, node, remove=False): + async 
with self._nodes_lock: + if remove: + self._nodes.pop(node, None) + else: + if not node in self._nodes: + self._nodes.update({node : deque(maxlen=self.MAX_HISTORIC_SIZE)}) + + async def process_aggregation_event(self, agg_ev : AggregationEvent): + logging.info("Processing aggregation event") + (updates, expected_nodes, missing_nodes) = await agg_ev.get_event_data() + self_updt = updates[self._addr] + async with self._nodes_lock: + for addr, updt in updates.items(): + if addr == self._addr: continue + if not addr in self._nodes.keys(): continue + (model,_) = updt + (self_model, _) = self_updt + cos_sim = cosine_metric(self_model, model, similarity=True) + self._nodes[addr].append(cos_sim) - def __init__(): - pass \ No newline at end of file + async def evaluate(self): + async with self._nodes_lock: + for node in self._nodes: + if self._nodes[node]: + last_sim = self._nodes[node][-1] + if self._nodes[node][-1] < self.SIMILARITY_THRESHOLD: + logging.info(f"Node: {node} got a similarity value of: {last_sim} under threshold: {self.SIMILARITY_THRESHOLD}") \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py index 6664fba75..eca6464d0 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py @@ -1,6 +1,16 @@ from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy +# "Speed-Oriented Selection" (SOS) class SOSTrainingPolicy(TrainingPolicy): - def __init__(): + def __init__(self, config): + pass + + async def init(self, config): + pass + + async def update_neighbors(self, node, remove=False): + pass + + async def evaluate(self): pass \ No newline at end of file diff --git 
a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py index 5498b3285..91f726e31 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py @@ -1,8 +1,13 @@ from abc import ABC, abstractmethod from typing import Type + class TrainingPolicy(ABC): + @abstractmethod + async def init(self, config): + pass + @abstractmethod async def update_neighbors(self, node, remove=False): pass @@ -12,7 +17,7 @@ async def evaluate(self): pass -def factory_training_policy(topology) -> TrainingPolicy: +def factory_training_policy(training_policy, config) -> TrainingPolicy: from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.bpstrainingpolicy import BPSTrainingPolicy from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.qdstrainingpolicy import QDSTrainingPolicy from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.sostrainingpolicy import SOSTrainingPolicy @@ -25,5 +30,5 @@ def factory_training_policy(topology) -> TrainingPolicy: "hts": HTSTrainingPolicy, # "Hybrid Training Strategy" (HTS) } - cs = options.get(topology, BPSTrainingPolicy) - return cs() \ No newline at end of file + cs = options.get(training_policy, BPSTrainingPolicy) + return cs(config) \ No newline at end of file From ced5f9bc579185b7b62ef009661146c1e36ee109 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 3 Mar 2025 13:52:09 +0100 Subject: [PATCH 122/233] feature QDS tp - update received event --- nebula/addons/mobility.py | 7 +- .../nebulanetworksimulator.py | 2 +- nebula/core/addonmanager.py | 6 +- nebula/core/aggregation/aggregator.py | 9 +- .../updatehandlers/dflupdatehandler.py | 15 ++- .../updatehandlers/updatehandler.py | 4 + nebula/core/engine.py | 77 ++++----------- 
nebula/core/eventmanager.py | 10 +- nebula/core/nebulaevents.py | 97 ++++++++++++++++++- .../awareness/samodule.py | 2 +- .../awareness/sanetwork/sanetwork.py | 2 +- .../awareness/satraining/satraining.py | 17 +++- .../trainingpolicy/bpstrainingpolicy.py | 2 +- .../trainingpolicy/qdstrainingpolicy.py | 81 +++++++++++++--- .../core/situationalawareness/nodemanager.py | 8 -- 15 files changed, 232 insertions(+), 107 deletions(-) diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index a52dfec7f..787c4d7db 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -13,7 +13,7 @@ class Mobility: - def __init__(self, config, cm: "CommunicationsManager"): + def __init__(self, config, cm: "CommunicationsManager", verbose=False): """ Initializes the mobility module with specified configuration and communication manager. @@ -68,6 +68,7 @@ def __init__(self, config, cm: "CommunicationsManager"): print_msg_box(msg=mobility_msg, indent=2, title="Mobility information") self._nodes_distances = {} self._nodes_distances_lock = Locker("nodes_distances_lock", async_lock=True) + self._verbose = verbose @property def round(self): @@ -158,7 +159,7 @@ async def change_geo_location_random_strategy(self, latitude, longitude): - The calculated radius is converted from meters to degrees based on an approximate conversion factor (1 degree is approximately 111 kilometers). 
""" - logging.info("πŸ“ Changing geo location randomly") + if self._verbose: logging.info("πŸ“ Changing geo location randomly") # radius_in_degrees = self.radius_federation / 111000 max_radius_in_degrees = self.max_movement_random_strategy / 111000 radius = random.uniform(0, max_radius_in_degrees) # noqa: S311 @@ -239,7 +240,7 @@ async def set_geo_location(self, latitude, longitude): self.config.participant["mobility_args"]["latitude"] = latitude self.config.participant["mobility_args"]["longitude"] = longitude - logging.info(f"πŸ“ New geo location: {latitude}, {longitude}") + if self._verbose: logging.info(f"πŸ“ New geo location: {latitude}, {longitude}") async def change_geo_location(self): """ diff --git a/nebula/addons/networksimulation/nebulanetworksimulator.py b/nebula/addons/networksimulation/nebulanetworksimulator.py index baf350806..a85b57bf0 100644 --- a/nebula/addons/networksimulation/nebulanetworksimulator.py +++ b/nebula/addons/networksimulation/nebulanetworksimulator.py @@ -14,7 +14,7 @@ class NebulaNS(NetworkSimulator): 100: {"bandwidth": "5Gbps", "delay": "5ms"}, 200: {"bandwidth": "2Gbps", "delay": "50ms"}, 300: {"bandwidth": "100Mbps", "delay": "200ms"}, - float("inf"): {"bandwidth": "10Mbps", "delay": "10000ms"}, + float("inf"): {"bandwidth": "10Mbps", "delay": "100000ms"}, } IP_MULTICAST = "239.255.255.250" diff --git a/nebula/core/addonmanager.py b/nebula/core/addonmanager.py index 79fdca82d..1416398d3 100644 --- a/nebula/core/addonmanager.py +++ b/nebula/core/addonmanager.py @@ -17,14 +17,14 @@ def __init__(self, engine : "Engine", config): async def deploy_additional_services(self): print_msg_box(msg="Deploying Additional Services\n(='.'=)", indent=2, title="Addons Manager") if self._config.participant["mobility_args"]["mobility"]: - mobility = Mobility(self._config, self._engine.cm) + mobility = Mobility(self._config, self._engine.cm, verbose=False) self._addons.append(mobility) if self._config.participant["network_args"]["simulation"]: 
refresh_conditions_interval = 5 - network_simulation = factory_network_simulator("nebula", self._engine.cm, refresh_conditions_interval, "eth0", verbose=True) + network_simulation = factory_network_simulator("nebula", self._engine.cm, refresh_conditions_interval, "eth0", verbose=False) self._addons.append(network_simulation) update_interval = 5 - gps = factory_gpsmodule("nebula", self._config, self._engine.addr, update_interval, verbose=True) + gps = factory_gpsmodule("nebula", self._config, self._engine.addr, update_interval, verbose=False) self._addons.append(gps) for add in self._addons: diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 887b3151b..0adc4c240 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -69,6 +69,9 @@ def run_aggregation(self, models): if len(models) == 0: logging.error("Trying to aggregate models when there are no models") return None + + async def init(self): + await self.us.init() async def update_federation_nodes(self, federation_nodes: set): await self.us.round_expected_updates(federation_nodes=federation_nodes) @@ -82,12 +85,6 @@ async def update_federation_nodes(self, federation_nodes: set): else: raise Exception("It is not possible to set nodes to aggregate when the aggregation is running.") - async def update_received_from_source(self, model, weight, source, round, local=False): - await self.us.storage_update(model, weight, source, round, local=False) - - async def notify_federation_nodes_removed(self, federation_node, remove=False): - await self.us.notify_federation_update(federation_node, remove=remove) - def get_nodes_pending_models_to_aggregate(self): return self._federation_nodes diff --git a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py index 9ba0a280b..b1684b754 100644 --- a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py +++ 
b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py @@ -5,6 +5,8 @@ from nebula.core.utils.locker import Locker import time from nebula.core.aggregation.updatehandlers.updatehandler import UpdateHandler +from nebula.core.nebulaevents import UpdateNeighborEvent, UpdateReceivedEvent +from nebula.core.eventmanager import EventManager from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -50,7 +52,12 @@ def us(self): @property def agg(self): - return self._aggregator + return self._aggregator + + async def init(self): + await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.notify_federation_update) + await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self.storage_update) + async def round_expected_updates(self, federation_nodes: set): await self._update_federation_lock.acquire_async() @@ -88,8 +95,9 @@ async def _check_updates_already_received(self): logging.info(f"Update already received from source: {se} | ({len(self._sources_received)}/{len(self._sources_expected)}) Updates received") self._sources_received.add(se) - async def storage_update(self, model, weight, source, round, local=False): + async def storage_update(self, updt_received_event : UpdateReceivedEvent): time_received = time.time() + (model, weight, source, round, _) = await updt_received_event.get_event_data() if source in self._sources_expected: updt = Update(model, weight, source, round, time_received) await self._updates_storage_lock.acquire_async() @@ -137,7 +145,8 @@ async def get_round_updates(self): await self._updates_storage_lock.release_async() return updates - async def notify_federation_update(self, source, remove=False): + async def notify_federation_update(self, updt_nei_event : UpdateNeighborEvent): + source, remove = await updt_nei_event.get_event_data() if not remove: if self._round_updates_lock.locked(): logging.info(f"Source: {source} will be count next round") diff --git 
a/nebula/core/aggregation/updatehandlers/updatehandler.py b/nebula/core/aggregation/updatehandlers/updatehandler.py index 74e64f634..ca9d05ea1 100644 --- a/nebula/core/aggregation/updatehandlers/updatehandler.py +++ b/nebula/core/aggregation/updatehandlers/updatehandler.py @@ -11,6 +11,10 @@ class UpdateHandler(ABC): ensuring they are properly stored, retrieved, and processed during the aggregation process. """ + @abstractmethod + async def init(): + raise NotImplementedError + @abstractmethod async def round_expected_updates(self, federation_nodes: set): """ diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 325e48707..614c57faa 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -9,6 +9,7 @@ from nebula.addons.reporter import Reporter from nebula.core.aggregation.aggregator import create_aggregator, create_target_aggregator from nebula.core.eventmanager import EventManager +from nebula.core.nebulaevents import UpdateNeighborEvent, UpdateReceivedEvent from nebula.core.network.communications import CommunicationsManager from nebula.core.situationalawareness.nodemanager import NodeManager from nebula.core.utils.locker import Locker @@ -146,8 +147,7 @@ def __init__( self.trainning_in_progress_lock = Locker(name="trainning_in_progress_lock", async_lock=True) - event_manager = EventManager.get_instance() - event_manager._initialize(verbose=True) + event_manager = EventManager.get_instance(verbose=False) # Mobility setup self._node_manager = None @@ -280,7 +280,8 @@ async def model_update_callback(self, source, message): logging.info("πŸ€– handle_model_message | There are no defined federation nodes") return decoded_model = self.trainer.deserialize_model(message.parameters) - await self.aggregator.update_received_from_source(decoded_model, message.weight, source, message.round) + updt_received_event = UpdateReceivedEvent(decoded_model, message.weight, source, message.round) + await 
EventManager.get_instance().publish_node_event(updt_received_event) """ ############################## # General callbacks # @@ -590,21 +591,14 @@ async def _aditional_node_start(self): logging.info("Creating trainer service to start the federation process..") asyncio.create_task(self._start_learning_late()) - async def set_pushed_done(self, rounds_push): - await self.nm.set_rounds_pushed(rounds_push) - - async def apply_weight_strategy(self, pending_models): - if self.mobility: - await self.nm.apply_weight_strategy(pending_models) - return pending_models - else: - return pending_models - async def update_neighbors(self, removed_neighbor_addr, neighbors, remove=False): if self.mobility: self.federation_nodes = neighbors await self.nm.update_neighbors(removed_neighbor_addr, remove=remove) - await self.aggregator.notify_federation_nodes_removed(removed_neighbor_addr, remove=remove) + updt_nei_event = UpdateNeighborEvent(removed_neighbor_addr, remove) + asyncio.create_task(EventManager.get_instance().publish_node_event(updt_nei_event)) + + #await self.aggregator.notify_federation_nodes_removed(removed_neighbor_addr, remove=remove) async def update_model_learning_rate(self, new_lr): await self.trainning_in_progress_lock.acquire_async() @@ -665,6 +659,7 @@ async def create_trainer_module(self): async def start_communications(self): await self.init_message_callbacks() + await self.aggregator.init() logging.info(f"Neighbors: {self.config.participant['network_args']['neighbors']}") logging.info( f"πŸ’€ Cold start time: {self.config.participant['misc_args']['grace_time_connection']} seconds before connecting to the network" @@ -997,20 +992,9 @@ async def _extended_learning_cycle(self): await self.trainer.train() await self.trainning_in_progress_lock.release_async() - # await self.aggregator.include_model_in_buffer( - # self.trainer.get_model_parameters(), - # self.trainer.get_model_weight(), - # source=self.addr, - # round=self.round, - # ) - - await 
self.aggregator.update_received_from_source( - self.trainer.get_model_parameters(), - self.trainer.get_model_weight(), - source=self.addr, - round=self.round, - ) - + self_update_event = UpdateReceivedEvent(self.trainer.get_model_parameters(), self.trainer.get_model_weight(), self.addr, self.round) + await EventManager.get_instance().publish_node_event(self_update_event) + await self.cm.propagator.propagate("stable") await self._waiting_model_updates() @@ -1035,21 +1019,9 @@ def __init__( async def _extended_learning_cycle(self): # Define the functionality of the server node await self.trainer.test() - - # In the first round, the server node doest take into account the initial model parameters for the aggregation - # await self.aggregator.include_model_in_buffer( - # self.trainer.get_model_parameters(), - # self.trainer.BYPASS_MODEL_WEIGHT, - # source=self.addr, - # round=self.round, - # ) - - await self.aggregator.update_received_from_source( - self.trainer.get_model_parameters(), - self.trainer.BYPASS_MODEL_WEIGHT, - source=self.addr, - round=self.round, - ) + + self_update_event = UpdateReceivedEvent(self.trainer.get_model_parameters(), self.trainer.BYPASS_MODEL_WEIGHT, self.addr, self.round) + await EventManager.get_instance().publish_node_event(self_update_event) await self._waiting_model_updates() await self.cm.propagator.propagate("stable") @@ -1079,22 +1051,9 @@ async def _extended_learning_cycle(self): await self.trainer.test() await self.trainer.train() - - # await self.aggregator.include_model_in_buffer( - # self.trainer.get_model_parameters(), - # self.trainer.get_model_weight(), - # source=self.addr, - # round=self.round, - # local=True, - # ) - - await self.aggregator.update_received_from_source( - self.trainer.get_model_parameters(), - self.trainer.get_model_weight(), - source=self.addr, - round=self.round, - local=True, - ) + + self_update_event = UpdateReceivedEvent(self.trainer.get_model_parameters(), self.trainer.get_model_weight(), self.addr, 
self.round, local=True) + await EventManager.get_instance().publish_node_event(self_update_event) await self.cm.propagator.propagate("stable") await self._waiting_model_updates() diff --git a/nebula/core/eventmanager.py b/nebula/core/eventmanager.py index 23fed4004..c2d385747 100755 --- a/nebula/core/eventmanager.py +++ b/nebula/core/eventmanager.py @@ -48,10 +48,10 @@ def _initialize(self, verbose=False): self._initialized = True # Marca que ya se inicializΓ³ @staticmethod - def get_instance(): + def get_instance(verbose=False): """MΓ©todo estΓ‘tico para obtener la instancia ΓΊnica.""" if EventManager._instance is None: - EventManager() + EventManager(verbose=verbose) return EventManager._instance async def subscribe(self, event_type: tuple[str, str], callback: callable): @@ -64,7 +64,7 @@ async def subscribe(self, event_type: tuple[str, str], callback: callable): async def publish(self, message_event: MessageEvent): """Trigger all callbacks registered for a specific event type.""" - if self._verbose or True: logging.info(f"Publishing MessageEvent: {message_event.message_type}") + if self._verbose: logging.info(f"Publishing MessageEvent: {message_event.message_type}") async with self._message_events_lock: event_type = message_event.message_type callbacks = self._subscribers.get(event_type, []) @@ -92,7 +92,7 @@ async def subscribe_addonevent(self, addonEventType: type[AddonEvent], callback: async def publish_addonevent(self, addonevent: AddonEvent): """Trigger all callbacks registered for a specific type of AddonEvent.""" - if self._verbose or True: logging.info(f"Publishing AddonEvent: {addonevent}") + if self._verbose: logging.info(f"Publishing AddonEvent: {addonevent}") async with self._addons_event_lock: event_type = type(addonevent) callbacks = self._addons_events_subs.get(event_type, []) @@ -122,7 +122,7 @@ async def subscribe_node_event(self, nodeEventType: type[NodeEvent], callback: c async def publish_node_event(self, nodeevent: NodeEvent): """Trigger all 
callbacks registered for a specific type of AddonEvent.""" - if self._verbose or True: logging.info(f"Publishing NodeEvent: {nodeevent}") + if self._verbose: logging.info(f"Publishing NodeEvent: {nodeevent}") async with self._node_events_lock: event_type = type(nodeevent) callbacks = self._node_events_subs.get(event_type, []) # Extraer la lista de callbacks diff --git a/nebula/core/nebulaevents.py b/nebula/core/nebulaevents.py index 5d66c4dda..75ef463c3 100644 --- a/nebula/core/nebulaevents.py +++ b/nebula/core/nebulaevents.py @@ -15,14 +15,27 @@ async def get_event_data(self): async def is_concurrent(self): pass -class MessageEvent: +class MessageEvent(): def __init__(self, message_type, source, message): self.source = source self.message_type = message_type self.message = message +""" ############################## + # NODE EVENTS # + ############################## +""" + + class AggregationEvent(NodeEvent): def __init__(self, updates : dict, expected_nodes : set, missing_nodes : set): + """Event triggered when model aggregation is ready. + + Args: + updates (dict): Dictionary containing model updates. + expected_nodes (set): Set of nodes expected to participate in aggregation. + missing_nodes (set): Set of nodes that did not send their update. + """ self._updates = updates self._expected_nodes = expected_nodes self._missing_nodes = missing_nodes @@ -31,10 +44,90 @@ def __str__(self): return "Aggregation Ready" async def get_event_data(self) -> tuple[dict, set, set]: + """Retrieves the aggregation event data. + + Returns: + tuple[dict, set, set]: + - updates (dict): Model updates. + - expected_nodes (set): Expected nodes. + - missing_nodes (set): Missing nodes. + """ return (self._updates, self._expected_nodes, self._missing_nodes) async def is_concurrent(self) -> bool: - return False + return False + +class UpdateNeighborEvent(NodeEvent): + def __init__(self, node_addr, removed=False): + """Event triggered when a neighboring node is updated. 
+ + Args: + node_addr (str): Address of the neighboring node. + removed (bool, optional): Indicates whether the node was removed. + Defaults to False. + """ + self._node_addr = node_addr + self._removed = removed + + def __str__(self): + return f"Node addr: {self._node_addr}, removed: {self._removed}" + + async def get_event_data(self) -> tuple[str, bool]: + """Retrieves the neighbor update event data. + + Returns: + tuple[str, bool]: + - node_addr (str): Address of the neighboring node. + - removed (bool): Whether the node was removed. + """ + return (self._node_addr, self._removed) + + async def is_concurrent(self) -> bool: + return False + +class UpdateReceivedEvent(NodeEvent): + def __init__(self, decoded_model, weight, source, round, local=False): + """ + Initializes an UpdateReceivedEvent. + + Args: + decoded_model (Any): The received model update. + weight (float): The weight associated with the received update. + source (str): The identifier or address of the node that sent the update. + round (int): The round number in which the update was received. + local (bool): Local update + """ + self._source = source + self._round = round + self._model = decoded_model + self._weight = weight + self._local = local + + def __str__(self): + return f"Update received from source: {self._source}, round: {self._round}" + + async def get_event_data(self) -> tuple[str, bool]: + """ + Retrieves the event data. + + Returns: + tuple[Any, float, str, int, bool]: A tuple containing: + - The received model update. + - The weight associated with the update. + - The source node identifier. + - The round number of the update. 
+ - If the update is local + """ + return (self._model, self._weight, self._source, self._round, self._local) + + async def is_concurrent(self) -> bool: + return False + + +""" ############################## + # ADDON EVENTS # + ############################## +""" class GPSEvent(AddonEvent): def __init__(self, distances : dict): diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 1dfbb32ab..84f4d8f88 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -30,7 +30,7 @@ def __init__( self._topology = topology self._node_manager: NodeManager = nodemanager self._situational_awareness_network = SANetwork(self, self.cm, self._addr, self._topology) - self._situational_awareness_training = SATraining(self, self._addr, "qds", "fastreboot") + self._situational_awareness_training = SATraining(self, self._addr, "qds", "fastreboot", verbose=True) self._restructure_process_lock = Locker(name="restructure_process_lock") self._restructure_cooldown = 0 diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index 071a52504..0dddcbbfa 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -122,7 +122,7 @@ async def experiment_finish(self): async def beacon_received(self, addr, geoloc): latitude, longitude = geoloc self.meet_node(addr) - logging.info(f"Beacon received SANetwork, source: {addr}, geolocalization: {latitude},{longitude}") + #logging.info(f"Beacon received SANetwork, source: {addr}, geolocalization: {latitude},{longitude}") """ ############################### # REESTRUCTURE TOPOLOGY # diff --git a/nebula/core/situationalawareness/awareness/satraining/satraining.py 
b/nebula/core/situationalawareness/awareness/satraining/satraining.py index 8ee28367d..543f2188d 100644 --- a/nebula/core/situationalawareness/awareness/satraining/satraining.py +++ b/nebula/core/situationalawareness/awareness/satraining/satraining.py @@ -16,7 +16,8 @@ def __init__( sam: "SAModule", addr, training_policy, - weight_strategies + weight_strategies, + verbose ): print_msg_box( msg=f"Starting Training SA\nTraining policy: {training_policy}\nWeight strategies: {weight_strategies}", @@ -26,18 +27,28 @@ def __init__( self._sam = sam config = {} config["addr"] = addr + self._verbose = verbose + config["verbose"] = verbose self._trainning_policy = factory_training_policy(training_policy, config) self._weight_strategies = weight_strategies + @property + def sam(self): + return self._sam + @property def tp(self): return self._trainning_policy async def init(self): config = {} - config["nodes"] = set(self._sam.get_nodes_known(neighbors_only=True)) + config["nodes"] = set(self._sam.get_nodes_known(neighbors_only=True)) await self.tp.init(config) async def module_actions(self): logging.info("SA Trainng evaluating current scenario") - await self.tp.evaluate() + nodes = await self.tp.evaluate() + if nodes: + for n in nodes: + pass + #asyncio.create_task(self.sam.cm.disconnect(n, forced=True)) diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py index 24769c11a..15fa664ec 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py @@ -12,4 +12,4 @@ async def update_neighbors(self, node, remove=False): pass async def evaluate(self): - pass \ No newline at end of file + return None \ No newline at end of file diff --git 
a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py index 4a451298b..2bc99d1d5 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py @@ -6,6 +6,7 @@ import logging from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import AggregationEvent +import random from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.eventmanager import EventManager @@ -14,17 +15,24 @@ class QDSTrainingPolicy(TrainingPolicy): MAX_HISTORIC_SIZE = 10 SIMILARITY_THRESHOLD = 0.8 + INACTIVE_THRESHOLD = 3 + GRACE_ROUNDS = 10 + CHECK_COOLDOWN = 10 def __init__(self, config : dict): self._addr = config["addr"] - self._nodes : dict[str, deque] = {} + self._verbose = config["verbose"] + self._nodes : dict[str, tuple[deque, int]] = {} self._nodes_lock = Locker(name="nodes_lock", async_lock=True) + self._round_missing_nodes = set() + self._grace_rounds = self.GRACE_ROUNDS + self._last_check = 0 async def init(self, config): async with self._nodes_lock: nodes = config["nodes"] - self._nodes : dict[str, deque] = {node_id: deque(maxlen=self.MAX_HISTORIC_SIZE) for node_id in nodes} + self._nodes : dict[str, tuple[deque, int]] = {node_id: (deque(maxlen=self.MAX_HISTORIC_SIZE), 0) for node_id in nodes} await EventManager.get_instance().subscribe_node_event(AggregationEvent, self.process_aggregation_event) async def update_neighbors(self, node, remove=False): @@ -33,25 +41,76 @@ async def update_neighbors(self, node, remove=False): self._nodes.pop(node, None) else: if not node in self._nodes: - self._nodes.update({node : deque(maxlen=self.MAX_HISTORIC_SIZE)}) + self._nodes.update({node : (deque(maxlen=self.MAX_HISTORIC_SIZE), 0)}) async def process_aggregation_event(self, agg_ev : 
AggregationEvent): - logging.info("Processing aggregation event") + if self._verbose: logging.info("Processing aggregation event") (updates, expected_nodes, missing_nodes) = await agg_ev.get_event_data() + self._round_missing_nodes = missing_nodes self_updt = updates[self._addr] async with self._nodes_lock: for addr, updt in updates.items(): if addr == self._addr: continue if not addr in self._nodes.keys(): continue + + for node in self._nodes.keys(): # Update inactive counters + deque_history, missed_count = self._nodes[node] + if not node in missing_nodes: + self._nodes[node] = (deque_history, 0) # Reset inactive counter + else: + if self._verbose: logging.info(f"Node inactivity counter increased for: {node}") + self._nodes[node] = (deque_history, missed_count + 1) # Inactive rounds counter +1 + (model,_) = updt (self_model, _) = self_updt cos_sim = cosine_metric(self_model, model, similarity=True) - self._nodes[addr].append(cos_sim) + self._nodes[addr][0].append(cos_sim) - async def evaluate(self): + async def _get_nodes(self): async with self._nodes_lock: - for node in self._nodes: - if self._nodes[node]: - last_sim = self._nodes[node][-1] - if self._nodes[node][-1] < self.SIMILARITY_THRESHOLD: - logging.info(f"Node: {node} got a similarity value of: {last_sim} under threshold: {self.SIMILARITY_THRESHOLD}") \ No newline at end of file + nodes = self._nodes.copy() + return nodes + + async def evaluate(self): + if self._grace_rounds: # Grace rounds + self._grace_rounds -= 1 + if self._verbose: logging.info("Grace time hasnt finished...") + return None + + result = set() + if self._last_check == 0: + self._last_check = 0 + nodes = await self._get_nodes() + redundant_nodes = set() + inactive_nodes = set() + for node in nodes: + if nodes[node][0]: + last_sim = nodes[node][0][-1] + inactivity_counter = nodes[node][1] + if inactivity_counter >= self.INACTIVE_THRESHOLD: + inactive_nodes.add(node) + if self._verbose: logging.info(f"Node: {node} hadn't participated in 
any of the last {self.INACTIVE_THRESHOLD} rounds") + else: + if self._verbose: logging.info(f"Node: {node} inactivity counter: {inactivity_counter}") + + if node not in self._round_missing_nodes: + if last_sim < self.SIMILARITY_THRESHOLD: + if self._verbose: logging.info(f"Node: {node} got a similarity value of: {last_sim} under threshold: {self.SIMILARITY_THRESHOLD}") + else: + if self._verbose: logging.info(f"Node: {node} got a redundant model, cossine simmilarity: {last_sim} over threshold: {self.SIMILARITY_THRESHOLD}") + redundant_nodes.add(node) + + if self._verbose: logging.info(f"Inactive nodes on aggregations: {inactive_nodes}") + if self._verbose: logging.info(f"Redundant nodes on aggregations: {redundant_nodes}") + if inactive_nodes: + result = result.union(inactive_nodes) + if len(redundant_nodes) > 1: + discard_nodes = set(random.sample(list(redundant_nodes), int(len(redundant_nodes)/2))) + if self._verbose: logging.info(f"Discarded redundant nodes: {discard_nodes}") + result = result.union(discard_nodes) + else: + if self._verbose: logging.info(f"Evaluation is on cooldown... 
| {self.CHECK_COOLDOWN - self._last_check} rounds remaining") + + self._last_check = (self._last_check + 1) % self.CHECK_COOLDOWN + + return result \ No newline at end of file diff --git a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py index ca05d022e..33c922ba9 100644 --- a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -81,10 +81,6 @@ def sam(self): def fast_reboot_on(self): return self._fast_reboot_status - async def set_rounds_pushed(self, rp): - if self.fast_reboot_on(): - self.fr.set_rounds_pushed(rp) - def is_additional_participant(self): return self._aditional_participant @@ -134,10 +130,6 @@ async def register_late_neighbor(self, addr, joinning_federation=False): if self.fast_reboot_on(): await self.fr.add_fastReboot_addr(addr) - async def apply_weight_strategy(self, updates: dict): - if self.fast_reboot_on(): - await self.fr.apply_weight_strategy(updates) - """ ############################## # CONNECTIONS # From 270a56afe1d109def0522f9034b825184af8fa4d Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 4 Mar 2025 13:25:03 +0100 Subject: [PATCH 123/233] feature round start event --- nebula/core/addonmanager.py | 2 + nebula/core/engine.py | 7 +- nebula/core/eventmanager.py | 18 +++ nebula/core/nebulaevents.py | 11 ++ .../trainingpolicy/qdstrainingpolicy.py | 12 +- .../trainingpolicy/sostrainingpolicy.py | 106 +++++++++++++++++- 6 files changed, 143 insertions(+), 13 deletions(-) diff --git a/nebula/core/addonmanager.py b/nebula/core/addonmanager.py index 1416398d3..95651f568 100644 --- a/nebula/core/addonmanager.py +++ b/nebula/core/addonmanager.py @@ -19,10 +19,12 @@ async def deploy_additional_services(self): if self._config.participant["mobility_args"]["mobility"]: mobility = Mobility(self._config, self._engine.cm, verbose=False) self._addons.append(mobility) + if self._config.participant["network_args"]["simulation"]: 
refresh_conditions_interval = 5 network_simulation = factory_network_simulator("nebula", self._engine.cm, refresh_conditions_interval, "eth0", verbose=False) self._addons.append(network_simulation) + update_interval = 5 gps = factory_gpsmodule("nebula", self._config, self._engine.addr, update_interval, verbose=False) self._addons.append(gps) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 614c57faa..8082f941e 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -1,7 +1,7 @@ import asyncio import logging import os - +import time import docker from nebula.addons.attacks.attacks import create_attack @@ -9,7 +9,7 @@ from nebula.addons.reporter import Reporter from nebula.core.aggregation.aggregator import create_aggregator, create_target_aggregator from nebula.core.eventmanager import EventManager -from nebula.core.nebulaevents import UpdateNeighborEvent, UpdateReceivedEvent +from nebula.core.nebulaevents import UpdateNeighborEvent, UpdateReceivedEvent, RoundStartEvent from nebula.core.network.communications import CommunicationsManager from nebula.core.situationalawareness.nodemanager import NodeManager from nebula.core.utils.locker import Locker @@ -817,6 +817,9 @@ def learning_cycle_finished(self): async def _learning_cycle(self): while self.round is not None and self.round < self.total_rounds: + current_time = time.time() + rse = RoundStartEvent(self.round, current_time) + EventManager.get_instance().publish_node_event(rse) print_msg_box( msg=f"Round {self.round} of {self.total_rounds} started.", indent=2, diff --git a/nebula/core/eventmanager.py b/nebula/core/eventmanager.py index c2d385747..2f9c6daca 100755 --- a/nebula/core/eventmanager.py +++ b/nebula/core/eventmanager.py @@ -144,3 +144,21 @@ async def publish_node_event(self, nodeevent: NodeEvent): if self._verbose: logging.info(f"EventManager | Triggering callback for event type: {event_type.__name__}") except Exception as e: logging.exception(f"EventManager | Error in callback 
for NodeEvent {event_type.__name__}: {e}") + + async def unsubscribe_event(self, event_type, callback): + """Unsubscribe a callback from a given event type (MessageEvent, AddonEvent, or NodeEvent).""" + if isinstance(event_type, tuple): # MessageEvent + async with self._message_events_lock: + if event_type in self._subscribers and callback in self._subscribers[event_type]: + self._subscribers[event_type].remove(callback) + logging.info(f"EventManager | Unsubscribed callback for MessageEvent: {event_type}") + elif issubclass(event_type, AddonEvent): # AddonEvent + async with self._addons_event_lock: + if event_type in self._addons_events_subs and callback in self._addons_events_subs[event_type]: + self._addons_events_subs[event_type].remove(callback) + logging.info(f"EventManager | Unsubscribed callback for AddonEvent: {event_type.__name__}") + elif issubclass(event_type, NodeEvent): # NodeEvent + async with self._node_events_lock: + if event_type in self._node_events_subs and callback in self._node_events_subs[event_type]: + self._node_events_subs[event_type].remove(callback) + logging.info(f"EventManager | Unsubscribed callback for NodeEvent: {event_type.__name__}") diff --git a/nebula/core/nebulaevents.py b/nebula/core/nebulaevents.py index 75ef463c3..6d143dc49 100644 --- a/nebula/core/nebulaevents.py +++ b/nebula/core/nebulaevents.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod import asyncio +import time class AddonEvent(ABC): @abstractmethod @@ -26,6 +27,16 @@ def __init__(self, message_type, source, message): ############################## """ +class RoundStartEvent(NodeEvent): + def __init__(self, round): + self._round_start_time = time.time() + self._round = round + + async def get_event_data(self): + return (self._round, self._round_Start_time) + + async def is_concurrent(self): + return False class AggregationEvent(NodeEvent): def __init__(self, updates : dict, expected_nodes : set, missing_nodes : set): diff --git 
a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py index 2bc99d1d5..891aeca02 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py @@ -7,9 +7,6 @@ from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import AggregationEvent import random -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from nebula.core.eventmanager import EventManager # "Quality-Driven Selection" (QDS) class QDSTrainingPolicy(TrainingPolicy): @@ -79,7 +76,6 @@ async def evaluate(self): result = set() if self._last_check == 0: - self._last_check = 0 nodes = await self._get_nodes() redundant_nodes = set() inactive_nodes = set() @@ -98,14 +94,16 @@ async def evaluate(self): if self._verbose: logging.info(f"Node: {node} got a similarity value of: {last_sim} under threshold: {self.SIMILARITY_THRESHOLD}") else: if self._verbose: logging.info(f"Node: {node} got a redundant model, cossine simmilarity: {last_sim} over threshold: {self.SIMILARITY_THRESHOLD}") - redundant_nodes.add(node) + redundant_nodes.add((node, last_sim)) if self._verbose: logging.info(f"Inactive nodes on aggregations: {inactive_nodes}") if self._verbose: logging.info(f"Redundant nodes on aggregations: {redundant_nodes}") if inactive_nodes: - result = result.union(inactive_nodes) + result = result.union(inactive_nodes) if len(redundant_nodes) > 1: - discard_nodes = set(random.sample(list(redundant_nodes), int(len(redundant_nodes)/2))) + sorted_redundant_nodes = sorted(redundant_nodes, key=lambda x: x[1]) + n_discarded = int(len(redundant_nodes)/2) + discard_nodes = sorted_redundant_nodes[-n_discarded:] if self._verbose: logging.info(f"Discarded redundant nodes: {discard_nodes}") result = result.union(discard_nodes) else: diff --git 
a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py index eca6464d0..344728ff5 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py @@ -1,16 +1,114 @@ from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy +from nebula.core.utils.locker import Locker +from collections import deque +import logging +from nebula.core.eventmanager import EventManager +from nebula.core.nebulaevents import UpdateReceivedEvent, AggregationEvent, RoundStartEvent +import time +import asyncio # "Speed-Oriented Selection" (SOS) class SOSTrainingPolicy(TrainingPolicy): + MAX_HISTORIC_SIZE = 10 + INACTIVE_THRESHOLD = 3 + GRACE_ROUNDS = 10 + CHECK_COOLDOWN = 10 def __init__(self, config): - pass + self._addr = config["addr"] + self._verbose = config["verbose"] + self._nodes : dict[str, tuple[deque, int, float, float]] = {} # _nodes estructura: {node_id: (deque updates epr round, inactivity, time gap between updates, time since last aggregation)} + + self._nodes_lock = Locker(name="nodes_lock", async_lock=True) + self._round_missing_nodes = set() + self._grace_rounds = self.GRACE_ROUNDS + self._last_check = 0 + self._internal_rounds_done = 0 + self._last_aggregation_time = None async def init(self, config): - pass + async with self._nodes_lock: + nodes = config["nodes"] + self._nodes = {node_id: (deque(maxlen=self.MAX_HISTORIC_SIZE), 0, float('inf'), float('inf')) for node_id in nodes} + await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self._process_update_received_event) + await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) + await 
EventManager.get_instance().subscribe_node_event(RoundStartEvent, self._process_first_round_start) + + + async def _get_nodes(self): + async with self._nodes_lock: + nodes = self._nodes.copy() + return nodes + + async def _process_first_round_start(self, rse : RoundStartEvent): + if not self._last_aggregation_time: + (_, start_time) = await rse.get_event_data() + self._last_aggregation_time = start_time + asyncio.create_task(EventManager.get_instance().unsubscribe_event(RoundStartEvent, self._process_first_round_start)) + + async def _process_aggregation_event(self, are : AggregationEvent): + self._last_aggregation_time = time.time() + if self._verbose: logging.info("Processing aggregation event") + self._internal_rounds_done += 1 + (_, expected_nodes, missing_nodes) = await are.get_event_data() + + async with self._nodes_lock: + for node in expected_nodes: + if node in self._nodes: + history, missed_count, _, _ = self._nodes[node] + history.append((self._internal_rounds_done, 0)) + self._nodes[node] = (history, 0 if node not in missing_nodes else missed_count + 1, float('inf'), float('inf')) + + + async def _process_update_received_event(self, ure : UpdateReceivedEvent): + time_received = time.time() + if self._verbose: logging.info("Processing Update Received event") + (_, _, source, _, _) = await ure.get_event_data() + async with self._nodes_lock: + if source not in self._nodes: + return + + history, missed_count, first_update_time, last_update_time = self._nodes[source] + + if history and history[-1][0] == self._internal_rounds_done: + num_updates = history[-1][1] + 1 + history[-1] = (self._internal_rounds_done, num_updates) + else: + history.append((self._internal_rounds_done, 1)) + + if first_update_time == float('inf'): + if self._last_aggregation_time: + first_update_time = time_received - self._last_aggregation_time + else: + first_update_time = 0 + + if last_update_time == float('inf'): + last_update_time = first_update_time + else: + last_update_time = 
time_received - last_update_time + + self._nodes[source] = (history, missed_count, first_update_time, last_update_time) async def update_neighbors(self, node, remove=False): - pass + async with self._nodes_lock: + if remove: + self._nodes.pop(node, None) + else: + if not node in self._nodes: + self._nodes.update({node : (deque(maxlen=self.MAX_HISTORIC_SIZE), 0, float('inf'), float('inf'))}) async def evaluate(self): - pass \ No newline at end of file + if self._grace_rounds: # Grace rounds + self._grace_rounds -= 1 + if self._verbose: logging.info("Grace time hasnt finished...") + return None + + result = set() + if self._last_check == 0: + pass + else: + if self._verbose: logging.info(f"Evaluation is on cooldown... | {self.CHECK_COOLDOWN - self._last_check} rounds remaining") + + self._last_check = (self._last_check + 1) % self.CHECK_COOLDOWN + + return result \ No newline at end of file From dabc965bbc3fae9cf53ee0af1e6a6332ebc8d96b Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 4 Mar 2025 17:40:40 +0100 Subject: [PATCH 124/233] feature speed oriented selection --- nebula/core/engine.py | 2 +- nebula/core/eventmanager.py | 28 +++------ nebula/core/nebulaevents.py | 16 +++-- .../awareness/samodule.py | 2 +- .../trainingpolicy/sostrainingpolicy.py | 61 +++++++++++-------- 5 files changed, 55 insertions(+), 54 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 8082f941e..e16c8eb20 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -819,7 +819,7 @@ async def _learning_cycle(self): while self.round is not None and self.round < self.total_rounds: current_time = time.time() rse = RoundStartEvent(self.round, current_time) - EventManager.get_instance().publish_node_event(rse) + await EventManager.get_instance().publish_node_event(rse) print_msg_box( msg=f"Round {self.round} of {self.total_rounds} started.", indent=2, diff --git a/nebula/core/eventmanager.py b/nebula/core/eventmanager.py index 2f9c6daca..6eef472f0 100755 
--- a/nebula/core/eventmanager.py +++ b/nebula/core/eventmanager.py @@ -6,21 +6,7 @@ from abc import ABC, abstractmethod from nebula.core.network.messages import MessageEvent from nebula.core.utils.locker import Locker - -class AddonEvent(ABC): - @abstractmethod - async def get_event_data(self): - pass - -class NodeEvent(ABC): - @abstractmethod - async def get_event_data(self): - pass - - @abstractmethod - async def is_concurrent(self): - pass - +from nebula.core.nebulaevents import AddonEvent, NodeEvent class EventManager: _instance = None @@ -74,11 +60,11 @@ async def publish(self, message_event: MessageEvent): for callback in self._subscribers[event_type]: try: + if self._verbose: logging.info(f"EventManager | Triggering callback for event: {event_type}, from source: {message_event.source}") if asyncio.iscoroutinefunction(callback) or inspect.iscoroutine(callback): await callback(message_event.source, message_event.message) else: - callback(message_event.source, message_event.message) - if self._verbose: logging.info(f"EventManager | Triggering callback for event: {event_type}, from source: {message_event.source}") + callback(message_event.source, message_event.message) except Exception as e: logging.exception(f"EventManager | Error in callback for event {event_type}: {e}") @@ -103,11 +89,11 @@ async def publish_addonevent(self, addonevent: AddonEvent): for callback in self._addons_events_subs[event_type]: try: + if self._verbose: logging.info(f"EventManager | Triggering callback for event type: {event_type.__name__}") if asyncio.iscoroutinefunction(callback) or inspect.iscoroutine(callback): await callback(addonevent) else: - callback(addonevent) - if self._verbose: logging.info(f"EventManager | Triggering callback for event type: {event_type.__name__}") + callback(addonevent) except Exception as e: logging.exception(f"EventManager | Error in callback for AddonEvent {event_type.__name__}: {e}") @@ -134,14 +120,14 @@ async def publish_node_event(self, nodeevent: 
NodeEvent): for callback in self._node_events_subs[event_type]: try: + if self._verbose: logging.info(f"EventManager | Triggering callback for event type: {event_type.__name__}") if asyncio.iscoroutinefunction(callback) or inspect.iscoroutine(callback): if await nodeevent.is_concurrent(): asyncio.create_task(callback(nodeevent)) else: await callback(nodeevent) else: - callback(nodeevent) - if self._verbose: logging.info(f"EventManager | Triggering callback for event type: {event_type.__name__}") + callback(nodeevent) except Exception as e: logging.exception(f"EventManager | Error in callback for NodeEvent {event_type.__name__}: {e}") diff --git a/nebula/core/nebulaevents.py b/nebula/core/nebulaevents.py index 6d143dc49..164413fc2 100644 --- a/nebula/core/nebulaevents.py +++ b/nebula/core/nebulaevents.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod import asyncio -import time class AddonEvent(ABC): @abstractmethod @@ -28,12 +27,21 @@ def __init__(self, message_type, source, message): """ class RoundStartEvent(NodeEvent): - def __init__(self, round): - self._round_start_time = time.time() + def __init__(self, round, start_time): + """Event triggered when round is going to start. + + Args: + round (int): Round number. + start_time (time): Current time when round is going to start. 
+ """ + self._round_start_time = start_time self._round = round + def __str__(self): + return "Round starting" + async def get_event_data(self): - return (self._round, self._round_Start_time) + return (self._round, self._round_start_time) async def is_concurrent(self): return False diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 84f4d8f88..6256d8162 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -30,7 +30,7 @@ def __init__( self._topology = topology self._node_manager: NodeManager = nodemanager self._situational_awareness_network = SANetwork(self, self.cm, self._addr, self._topology) - self._situational_awareness_training = SATraining(self, self._addr, "qds", "fastreboot", verbose=True) + self._situational_awareness_training = SATraining(self, self._addr, "sos", "fastreboot", verbose=True) self._restructure_process_lock = Locker(name="restructure_process_lock") self._restructure_cooldown = 0 diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py index 344728ff5..a33feb4dc 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py @@ -11,8 +11,8 @@ class SOSTrainingPolicy(TrainingPolicy): MAX_HISTORIC_SIZE = 10 INACTIVE_THRESHOLD = 3 - GRACE_ROUNDS = 10 - CHECK_COOLDOWN = 10 + GRACE_ROUNDS = 0 + CHECK_COOLDOWN = 1 def __init__(self, config): self._addr = config["addr"] @@ -20,10 +20,9 @@ def __init__(self, config): self._nodes : dict[str, tuple[deque, int, float, float]] = {} # _nodes estructura: {node_id: (deque updates epr round, inactivity, time gap between updates, time since last aggregation)} self._nodes_lock = 
Locker(name="nodes_lock", async_lock=True) - self._round_missing_nodes = set() self._grace_rounds = self.GRACE_ROUNDS self._last_check = 0 - self._internal_rounds_done = 0 + self._internal_rounds_done = -1 self._last_aggregation_time = None async def init(self, config): @@ -31,8 +30,8 @@ async def init(self, config): nodes = config["nodes"] self._nodes = {node_id: (deque(maxlen=self.MAX_HISTORIC_SIZE), 0, float('inf'), float('inf')) for node_id in nodes} await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self._process_update_received_event) - await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) await EventManager.get_instance().subscribe_node_event(RoundStartEvent, self._process_first_round_start) + await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) async def _get_nodes(self): @@ -41,53 +40,55 @@ async def _get_nodes(self): return nodes async def _process_first_round_start(self, rse : RoundStartEvent): + if self._verbose: logging.info("Processing round start event") if not self._last_aggregation_time: + if self._verbose: logging.info("First round start timing assigment") (_, start_time) = await rse.get_event_data() self._last_aggregation_time = start_time - asyncio.create_task(EventManager.get_instance().unsubscribe_event(RoundStartEvent, self._process_first_round_start)) + self._internal_rounds_done += 1 async def _process_aggregation_event(self, are : AggregationEvent): self._last_aggregation_time = time.time() if self._verbose: logging.info("Processing aggregation event") - self._internal_rounds_done += 1 (_, expected_nodes, missing_nodes) = await are.get_event_data() async with self._nodes_lock: for node in expected_nodes: if node in self._nodes: - history, missed_count, _, _ = self._nodes[node] - history.append((self._internal_rounds_done, 0)) - self._nodes[node] = (history, 0 if node not in missing_nodes else missed_count + 1, 
float('inf'), float('inf')) + history, missed_count, gap_btween_updts, time_since_agg = self._nodes[node] + self._nodes[node] = (history, 0 if node not in missing_nodes else missed_count + 1, gap_btween_updts, time_since_agg) async def _process_update_received_event(self, ure : UpdateReceivedEvent): time_received = time.time() if self._verbose: logging.info("Processing Update Received event") (_, _, source, _, _) = await ure.get_event_data() + async with self._nodes_lock: if source not in self._nodes: return - history, missed_count, first_update_time, last_update_time = self._nodes[source] + history, missed_count, first_update_time, last_update_time = self._nodes[source] - if history and history[-1][0] == self._internal_rounds_done: - num_updates = history[-1][1] + 1 - history[-1] = (self._internal_rounds_done, num_updates) - else: - history.append((self._internal_rounds_done, 1)) - - if first_update_time == float('inf'): - if self._last_aggregation_time: - first_update_time = time_received - self._last_aggregation_time + if history and history[-1][0] == self._internal_rounds_done: + num_updates = history[-1][1] + 1 + history[-1] = (self._internal_rounds_done, num_updates) else: - first_update_time = 0 + history.append((self._internal_rounds_done, 1)) - if last_update_time == float('inf'): - last_update_time = first_update_time - else: - last_update_time = time_received - last_update_time + if first_update_time == float('inf'): + if self._last_aggregation_time: + first_update_time = time_received - self._last_aggregation_time + else: + first_update_time = 0 - self._nodes[source] = (history, missed_count, first_update_time, last_update_time) + #TODO el error estΓ‘ aquΓ­ hay q comprobar con respecto a self_aggregation_time + if last_update_time == float('inf'): + last_update_time = first_update_time + else: + last_update_time = time_received - last_update_time + + self._nodes[source] = (history, missed_count, first_update_time, last_update_time) async def 
update_neighbors(self, node, remove=False): async with self._nodes_lock: @@ -105,7 +106,13 @@ async def evaluate(self): result = set() if self._last_check == 0: - pass + nodes = await self._get_nodes() + for node in nodes.keys(): + logging.info(f"Internal rounds done: {self._internal_rounds_done}") + logging.info(f"Node: {node}, {nodes[node][0]}") + updates_received = {x[1] for x in nodes[node][0] if x[0] == self._internal_rounds_done} + logging.info(f"Node: {node}, Updates received this round: {updates_received}, last gap: {nodes[node][2]}, time since last agg: {nodes[node][3]}") + else: if self._verbose: logging.info(f"Evaluation is on cooldown... | {self.CHECK_COOLDOWN - self._last_check} rounds remaining") From 719cefd32c811f4b44cd90e95aa3a7f06c7c704d Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 5 Mar 2025 14:45:00 +0100 Subject: [PATCH 125/233] feature CFL implementation for udpate storage --- nebula/core/aggregation/aggregator.py | 20 +- .../updatehandlers/cflupdatehandler.py | 509 ++++++------------ .../updatehandlers/dflupdatehandler.py | 4 +- .../updatehandlers/sdflupdatehandler.py | 35 ++ .../updatehandlers/updatehandler.py | 9 +- nebula/core/engine.py | 32 +- nebula/core/network/communications.py | 11 - .../trainingpolicy/sostrainingpolicy.py | 43 +- 8 files changed, 264 insertions(+), 399 deletions(-) create mode 100644 nebula/core/aggregation/updatehandlers/sdflupdatehandler.py diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 0adc4c240..4f31505f3 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -41,7 +41,6 @@ def __init__(self, config=None, engine=None): self._addr = config.participant["network_args"]["addr"] logging.info(f"[{self.__class__.__name__}] Starting Aggregator") self._federation_nodes = set() - self._waiting_global_update = False self._pending_models_to_aggregate = {} self._pending_models_to_aggregate_lock = 
Locker(name="pending_models_to_aggregate_lock", async_lock=True) self._aggregation_done_lock = Locker(name="aggregation_done_lock", async_lock=True) @@ -55,11 +54,7 @@ def __str__(self): def __repr__(self): return self.__str__() - - @property - def cm(self): - return self.engine.cm - + @property def us(self): return self._update_storage @@ -71,7 +66,7 @@ def run_aggregation(self, models): return None async def init(self): - await self.us.init() + await self.us.init(self.config) async def update_federation_nodes(self, federation_nodes: set): await self.us.round_expected_updates(federation_nodes=federation_nodes) @@ -88,9 +83,6 @@ async def update_federation_nodes(self, federation_nodes: set): def get_nodes_pending_models_to_aggregate(self): return self._federation_nodes - def set_waiting_global_update(self): - self._waiting_global_update = True - async def get_aggregation(self): try: timeout = self.config.participant["aggregator_args"]["aggregation_timeout"] @@ -130,13 +122,7 @@ async def get_aggregation(self): logging.info(f"πŸ”„ get_aggregation | Aggregation incomplete, missing models from: {missing_nodes}") else: logging.info("πŸ”„ get_aggregation | All models accounted for, proceeding with aggregation.") - - logging.info( - f"πŸ”„ Broadcasting MODELS_INCLUDED for round {self.engine.get_round()}" - ) - message = self.cm.create_message("federation", "federation_models_included", [str(arg) for arg in [self.engine.get_round()]]) - await self.cm.send_message_to_neighbors(message) - + agg_event = AggregationEvent(updates, self._federation_nodes, missing_nodes) await EventManager.get_instance().publish_node_event(agg_event) aggregated_result = self.run_aggregation(updates) diff --git a/nebula/core/aggregation/updatehandlers/cflupdatehandler.py b/nebula/core/aggregation/updatehandlers/cflupdatehandler.py index 5440bf8a9..69f6c6b1f 100644 --- a/nebula/core/aggregation/updatehandlers/cflupdatehandler.py +++ b/nebula/core/aggregation/updatehandlers/cflupdatehandler.py @@ 
-2,362 +2,191 @@ import logging from nebula.core.utils.locker import Locker from nebula.core.aggregation.updatehandlers.updatehandler import UpdateHandler +from collections import deque +from typing import Dict, Tuple, Deque +from nebula.core.nebulaevents import UpdateNeighborEvent, UpdateReceivedEvent +from nebula.core.eventmanager import EventManager +import time from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.aggregation.aggregator import Aggregator +class Update(): + def __init__(self, model, weight, source, round, time_received): + self.model = model + self.weight = weight + self.source = source + self.round = round + self.time_received = time_received + + def __eq__(self, other): + return self.round == other.round + +MAX_UPDATE_BUFFER_SIZE = 1 + class CFLUpdateHandler(UpdateHandler): def __init__( self, aggregator, - addr + addr, + buffersize = MAX_UPDATE_BUFFER_SIZE ): - pass + self._addr = addr + self._aggregator: Aggregator = aggregator + self._buffersize = buffersize + self._updates_storage: Dict[str, Deque[Update]] = {} + self._updates_storage_lock = Locker(name="updates_storage_lock", async_lock=True) + self._sources_expected = set() + self._sources_received = set() + self._round_updates_lock = Locker(name="round_updates_lock", async_lock=True) # se coge cuando se empieza a comprobar si estan todas las updates + self._update_federation_lock = Locker(name="update_federation_lock", async_lock=True) + self._notification_sent_lock = Locker(name="notification_sent_lock", async_lock=True) + self._notification = False + self._missing_ones = set() + self._role = "" + - async def round_expected_updates(self, federation_nodes: set): - raise NotImplementedError + @property + def us(self): + return self._updates_storage - async def storage_update(self, model, weight, source, round, local=False): - raise NotImplementedError + @property + def agg(self): + return self._aggregator + + async def init(self, config): + self._role = 
config.participant["device_args"]["role"] + await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.notify_federation_update) + await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self.storage_update) + + async def round_expected_updates(self, federation_nodes: set): + await self._update_federation_lock.acquire_async() + await self._updates_storage_lock.acquire_async() + self._sources_expected = federation_nodes.copy() + self._sources_received.clear() + + # Initialize new nodes + for fn in federation_nodes: + if fn not in self.us: + self.us[fn] = (deque(maxlen=self._buffersize)) + + # Clear removed nodes + removed_nodes = [node for node in self._updates_storage.keys() if node not in federation_nodes] + for rn in removed_nodes: + del self._updates_storage[rn] + + await self._updates_storage_lock.release_async() + await self._update_federation_lock.release_async() + + # Lock to check if all updates received + if self._round_updates_lock.locked(): + self._round_updates_lock.release_async() + + self._notification = False + + async def storage_update(self, updt_received_event : UpdateReceivedEvent): + time_received = time.time() + (model, weight, source, round, _) = await updt_received_event.get_event_data() + + if source in self._sources_expected: + updt = Update(model, weight, source, round, time_received) + await self._updates_storage_lock.acquire_async() + if updt in self.us[source]: + logging.info(f"Discard | Alerady received update from source: {source} for round: {round}") + else: + self.us[source].append(updt) + logging.info(f"Storage Update | source={source} | round={round} | weight={weight} | federation nodes: {self._sources_expected}") + + self._sources_received.add(source) + updates_left = self._sources_expected.difference(self._sources_received) + logging.info(f"Updates received ({len(self._sources_received)}/{len(self._sources_expected)}) | Missing nodes: {updates_left}") + if 
self._round_updates_lock.locked() and not updates_left: + all_rec = await self._all_updates_received() + if all_rec: + await self._notify() + await self._updates_storage_lock.release_async() + else: + if not source in self._sources_received: + logging.info(f"Discard update | source: {source} not in expected updates for this Round") async def get_round_updates(self) -> dict[str, tuple[object, float]]: - raise NotImplementedError - - async def notify_federation_update(self, source, remove=False): - raise NotImplementedError + await self._updates_storage_lock.acquire_async() + updates_missing = self._sources_expected.difference(self._sources_received) + if updates_missing: + self._missing_ones = updates_missing + logging.info(f"Missing updates from sources: {updates_missing}") + updates = {} + for sr in self._sources_received: + if self._role == "trainer" and len(self._sources_received) > 1: # if trainer node ignore self updt if has received udpate from server + if sr == self._addr: + continue + source_historic = self.us[sr] + updt: Update = None + updt = source_historic[-1] # Get last update received + updates[sr] = (updt.model, updt.weight) + await self._updates_storage_lock.release_async() + return updates + + async def notify_federation_update(self, updt_nei_event : UpdateNeighborEvent): + source, remove = await updt_nei_event.get_event_data() + if not remove: + if self._round_updates_lock.locked(): + logging.info(f"Source: {source} will be count next round") + else: + await self._update_source(source, remove) + else: + if not source in self._sources_received: # Not received update from this source yet + await self._update_source(source, remove=True) + await self._all_updates_received() # Verify if discarding node aggregation could be done + else: + logging.info(f"Already received update from: {source}, it will be discarded next round") + + async def _update_source(self, source, remove=False): + logging.info(f"πŸ”„ Update | remove: {remove} | source: {source}") + 
await self._updates_storage_lock.acquire_async() + if remove: + self._sources_expected.discard(source) + else: + self.us[source] = (deque(maxlen=self._buffersize)) + self._sources_expected.add(source) + logging.info(f"federation nodes expected this round: {self._sources_expected}") + await self._updates_storage_lock.release_async() - async def get_round_missing_nodes(self) -> set[str]: - raise NotImplementedError + async def get_round_missing_nodes(self): + return self._missing_ones async def notify_if_all_updates_received(self): - raise NotImplementedError - + logging.info("Set notification when all expected updates received") + await self._round_updates_lock.acquire_async() + await self._updates_storage_lock.acquire_async() + all_received = await self._all_updates_received() + await self._updates_storage_lock.release_async() + if all_received: + await self._notify() + async def stop_notifying_updates(self): - raise NotImplementedError - + if self._round_updates_lock.locked(): + logging.info("Stop notification updates") + await self._round_updates_lock.release_async() -# def get_nodes_pending_models_to_aggregate(self): - # return {node for key in self._pending_models_to_aggregate.keys() for node in key.split()} - - # async def _handle_global_update(self, model, source): - # logging.info(f"πŸ”„ _handle_global_update | source={source}") - # logging.info( - # f"πŸ”„ _handle_global_update | Received a model from {source}. Overwriting __models with the aggregated model." 
- # ) - # self._pending_models_to_aggregate.clear() - # self._pending_models_to_aggregate = {source: (model, 1)} - # self._waiting_global_update = False - # await self._add_model_lock.release_async() - # await self._aggregation_done_lock.release_async() - - # async def _add_pending_model(self, model, weight, source): - # if len(self._federation_nodes) <= len(self.get_nodes_pending_models_to_aggregate()): - # logging.info("πŸ”„ _add_pending_model | Ignoring model...") - # await self._add_model_lock.release_async() - # return None - - # if source not in self._federation_nodes: - # logging.info(f"πŸ”„ _add_pending_model | Can't add a model from ({source}), which is not in the federation.") - # await self._add_model_lock.release_async() - # return None - - # elif source not in self.get_nodes_pending_models_to_aggregate(): - # logging.info( - # "πŸ”„ _add_pending_model | Node is not in the aggregation buffer --> Include model in the aggregation buffer." - # ) - # self._pending_models_to_aggregate.update({source: (model, weight)}) - - # logging.info( - # f"πŸ”„ _add_pending_model | Model added in aggregation buffer ({len(self.get_nodes_pending_models_to_aggregate())!s}/{len(self._federation_nodes)!s}) | Pending nodes: {self._federation_nodes - self.get_nodes_pending_models_to_aggregate()}" - # ) - - # # Check if _future_models_to_aggregate has models in the current round to include in the aggregation buffer - # if self.engine.get_round() in self._future_models_to_aggregate: - # logging.info( - # f"πŸ”„ _add_pending_model | Including next models in the aggregation buffer for round {self.engine.get_round()}" - # ) - # for future_model in self._future_models_to_aggregate[self.engine.get_round()]: - # if future_model is None: - # continue - # future_model, future_weight, future_source = future_model - # if ( - # future_source in self._federation_nodes - # and future_source not in self.get_nodes_pending_models_to_aggregate() - # ): - # 
self._pending_models_to_aggregate.update({future_source: (future_model, future_weight)}) - # logging.info( - # f"πŸ”„ _add_pending_model | Next model added in aggregation buffer ({len(self.get_nodes_pending_models_to_aggregate())!s}/{len(self._federation_nodes)!s}) | Pending nodes: {self._federation_nodes - self.get_nodes_pending_models_to_aggregate()}" - # ) - # del self._future_models_to_aggregate[self.engine.get_round()] - - # for future_round in list(self._future_models_to_aggregate.keys()): - # if future_round < self.engine.get_round(): - # del self._future_models_to_aggregate[future_round] - - # if len(self.get_nodes_pending_models_to_aggregate()) >= len(self._federation_nodes): - # logging.info("πŸ”„ _add_pending_model | All models were added in the aggregation buffer. Run aggregation...") - # await self._aggregation_done_lock.release_async() - - # await self._add_model_lock.release_async() - # return self.get_nodes_pending_models_to_aggregate() - - # async def include_model_in_buffer(self, model, weight, source=None, round=None, local=False): - # await self._add_model_lock.acquire_async() - # logging.info( - # f"πŸ”„ include_model_in_buffer | source={source} | round={round} | weight={weight} |--| __models={self._pending_models_to_aggregate.keys()} | federation_nodes={self._federation_nodes} | pending_models_to_aggregate={self.get_nodes_pending_models_to_aggregate()}" - # ) - # if model is None: - # logging.info("πŸ”„ include_model_in_buffer | Ignoring model bad formed...") - # await self._add_model_lock.release_async() - # return - - # if round == -1: - # # Be sure that the model message is not from the initialization round (round = -1) - # logging.info("πŸ”„ include_model_in_buffer | Ignoring model with round -1") - # await self._add_model_lock.release_async() - # return - - # if self._waiting_global_update and not local: - # await self._handle_global_update(model, source) - # return - - # await self._add_pending_model(model, weight, source) - - # if 
len(self.get_nodes_pending_models_to_aggregate()) >= len(self._federation_nodes): - # logging.info( - # f"πŸ”„ include_model_in_buffer | Broadcasting MODELS_INCLUDED for round {self.engine.get_round()}" - # ) - # message = self.cm.create_message( - # "federation", "federation_models_included", [str(arg) for arg in [self.engine.get_round()]] - # ) - # await self.cm.send_message_to_neighbors(message) - - # return - - # async def get_aggregation(self): - # try: - # timeout = self.config.participant["aggregator_args"]["aggregation_timeout"] - # logging.info(f"Aggregation timeout: {timeout} starts...") - # await self.us.notify_if_all_updates_received() - # lock_task = asyncio.create_task(self._aggregation_done_lock.acquire_async(timeout=timeout)) - # skip_task = asyncio.create_task(self._aggregation_waiting_skip.wait()) - # done, pending = await asyncio.wait( - # [lock_task, skip_task], - # return_when=asyncio.FIRST_COMPLETED, - # ) - # lock_acquired = lock_task in done - # if skip_task in done: - # logging.info("Skipping aggregation timeout, updates received before grace time") - # self._aggregation_waiting_skip.clear() - # if not lock_acquired: - # lock_task.cancel() - # try: - # await lock_task # Clean cancel - # except asyncio.CancelledError: - # pass - - # except TimeoutError: - # logging.exception("πŸ”„ get_aggregation | Timeout reached for aggregation") - # except asyncio.CancelledError: - # logging.exception("πŸ”„ get_aggregation | Lock acquisition was cancelled") - # except Exception as e: - # logging.exception(f"πŸ”„ get_aggregation | Error acquiring lock: {e}") - # finally: - # if lock_acquired: - # await self._aggregation_done_lock.release_async() - - # await self.us.stop_notifying_updates() - # updates = await self.us.get_round_updates() - - # missing_nodes = await self.us.get_round_missing_nodes() - - # if missing_nodes: - # logging.info(f"πŸ”„ get_aggregation | Aggregation incomplete, missing models from: {missing_nodes}") - # else: - # logging.info("πŸ”„ 
get_aggregation | All models accounted for, proceeding with aggregation.") - - # logging.info( - # f"πŸ”„ Broadcasting MODELS_INCLUDED for round {self.engine.get_round()}" - # ) - # message = self.cm.create_message( - # "federation", "federation_models_included", [str(arg) for arg in [self.engine.get_round()]] - # ) - # await self.cm.send_message_to_neighbors(message) - - # if self._waiting_global_update and len(self._pending_models_to_aggregate) == 1: - # logging.info( - # "πŸ”„ get_aggregation | Received an global model. Overwriting my model with the aggregated model." - # ) - # aggregated_model = next(iter(self._pending_models_to_aggregate.values()))[0] - # self._pending_models_to_aggregate.clear() - # return aggregated_model - - # unique_nodes_involved = set(node for key in self._pending_models_to_aggregate for node in key.split()) - - # if len(unique_nodes_involved) != len(self._federation_nodes): - # missing_nodes = self._federation_nodes - unique_nodes_involved - # logging.info(f"πŸ”„ get_aggregation | Aggregation incomplete, missing models from: {missing_nodes}") - # else: - # logging.info("πŸ”„ get_aggregation | All models accounted for, proceeding with aggregation.") - - # self._pending_models_to_aggregate = await self.engine.apply_weight_strategy(self._pending_models_to_aggregate) - # aggregated_result = self.run_aggregation(self._pending_models_to_aggregate) - # self._pending_models_to_aggregate.clear() - - # updates = await self.engine.apply_weight_strategy(updates) - # aggregated_result = self.run_aggregation(updates) - # return aggregated_result - - # async def include_next_model_in_buffer(self, model, weight, source=None, round=None): - # logging.info(f"πŸ”„ include_next_model_in_buffer | source={source} | round={round} | weight={weight}") - # if round not in self._future_models_to_aggregate: - # self._future_models_to_aggregate[round] = [] - # decoded_model = self.engine.trainer.deserialize_model(model) - # await 
self._add_next_model_lock.acquire_async() - # self._future_models_to_aggregate[round].append((decoded_model, weight, source)) - # await self._add_next_model_lock.release_async() - - # # Verify if we are waiting an update that maybe we wont received - # if self._aggregation_done_lock.locked(): - # pending_nodes: set = self._federation_nodes - self.get_nodes_pending_models_to_aggregate() - # if pending_nodes: - # for f_round, future_updates in self._future_models_to_aggregate.items(): - # for _, _, source in future_updates: - # if source in pending_nodes: - # # logging.info(f"Waiting update from source: {source}, but future update storaged for round: {f_round}") - # pending_nodes.discard(source) - - # if not pending_nodes: - # logging.info("Received advanced updates for all sources missing this round") - # await self._aggregation_done_lock.release_async() - - - # def verify_push_done(self, current_round): - # current_round = self.engine.get_round() - # if self.engine.get_synchronizing_rounds(): - # logging.info("Verifying if round push is done") - # if self._end_round_push <= current_round: - # logging.info("Push done...") - # self.engine.set_synchronizing_rounds(False) - # self._end_round_push = 0 - # if len(self._future_models_to_aggregate.items()) < 2: - # logging.info("Device is sinchronized") - # self.engine.update_sinchronized_status(True) - # else: - # logging.info("Device is not sinchronized yet | more actions required...") - - # async def aggregation_push_available(self): - # """ - # If the node is not sinchronized with the federation, it may be possible to make a push - # and try to catch the federation asap. 
- # """ - # # TODO verify if an already sinchronized node gets desinchronized - # current_round = self.engine.get_round() - # self.verify_push_done(current_round) - - # await self._push_strategy_lock.acquire_async() - - # logging.info( - # f"❗️ synchronized status: {self.engine.get_sinchronized_status()} | Analizing if an aggregation push is available..." - # ) - # if ( - # not self.engine.get_sinchronized_status() - # and not self.engine.get_trainning_in_progress_lock().locked() - # and not self.engine.get_synchronizing_rounds() - # ): - # n_fed_nodes = len(self._federation_nodes) - # further_round = current_round - # logging.info( - # f" Pending models: {len(self.get_nodes_pending_models_to_aggregate())} | federation: {n_fed_nodes}" - # ) - # if len(self.get_nodes_pending_models_to_aggregate()) < n_fed_nodes: - # n_fed_nodes -= 1 - # for f_round, fm in self._future_models_to_aggregate.items(): - # # future_models dont count self node - # if (f_round - current_round) > 1 or len(fm) == n_fed_nodes: - # further_round = f_round - # push = self.engine.get_push_acceleration() - # if push == "slow": - # logging.info("❗️ SLOW push selected") - # logging.info( - # f"❗️ Federation is at least {(f_round - current_round)} rounds ahead, Pushing slow..." 
- # ) - # await self.engine.set_pushed_done(further_round - current_round) - # self.engine.update_sinchronized_status(False) - # self.engine.set_synchronizing_rounds(True) - # self._end_round_push = further_round - # self._aggregation_waiting_skip.set() - # await self._push_strategy_lock.release_async() - # return - - # if further_round != current_round and push == "fast": - # logging.info("❗️ FAST push selected") - # logging.info(f"❗️ FUTURE round: {further_round} is available, Pushing fast...") - - # if further_round == (current_round + 1): - # logging.info(f"πŸ”„ Rounds jumped: {1}...") - # await self.engine.set_pushed_done(further_round - current_round) - # self.engine.update_sinchronized_status(False) - # self.engine.set_synchronizing_rounds(True) - # self._end_round_push = further_round - # self._aggregation_waiting_skip.set() - # await self._push_strategy_lock.release_async() - # return - - # logging.info(f"πŸ”„ Number of rounds jumped: {further_round - current_round}...") - # own_update = self._pending_models_to_aggregate.get(self.engine.get_addr()) - # while own_update == None: - # own_update = self._pending_models_to_aggregate.get(self.engine.get_addr()) - # asyncio.sleep(1) - # (model, weight) = own_update - - # # Getting locks to avoid concurrency issues - # await self._add_model_lock.acquire_async() - # await self._add_next_model_lock.acquire_async() - - # # Remove all pendings updates and add own_update - # self._pending_models_to_aggregate.clear() - # self._pending_models_to_aggregate.update({self.engine.get_addr(): (model, weight)}) - - # # Add to pendings the future round updates - # for future_update in self._future_models_to_aggregate[further_round]: - # (decoded_model, weight, source) = future_update - # self._pending_models_to_aggregate.update({source: (decoded_model, weight)}) - - # # Clear all rounds that are going to be jumped - # self._future_models_to_aggregate = { - # key: value for key, value in self._future_models_to_aggregate.items() 
if key > further_round - # } - - # self.engine.update_sinchronized_status(False) - # self.engine.set_synchronizing_rounds(True) - # await self.engine.set_pushed_done(further_round - current_round) - # self._end_round_push = further_round - # self.engine.set_round(further_round) - # await self._add_model_lock.release_async() - # await self._add_next_model_lock.release_async() - # await self._push_strategy_lock.release_async() - # self._aggregation_waiting_skip.set() - # return - - # else: - # if len(self._future_models_to_aggregate.items()) < 2: - # logging.info("Info | No future rounds available, device is up to date...") - # self.engine.update_sinchronized_status(True) - # self.engine.set_synchronizing_rounds(False) - # else: - # logging.info("No rounds can be pushed...") - # await self._push_strategy_lock.release_async() - # else: - # logging.info( - # f"All models updates are received | models number: {len(self.get_nodes_pending_models_to_aggregate())}" - # ) - # await self._push_strategy_lock.release_async() - # else: - # if not self.engine.get_sinchronized_status(): - # if self.engine.get_trainning_in_progress_lock().locked(): - # logging.info("❗️ Cannot analize push | Trainning in progress") - # elif self.engine.get_synchronizing_rounds(): - # logging.info("❗️ Cannot analize push | Already pushing rounds") - # await self._push_strategy_lock.release_async() + async def _notify(self): + await self._notification_sent_lock.acquire_async() + if self._notification: + await self._notification_sent_lock.release_async() + return + self._notification = True + await self.stop_notifying_updates() + await self._notification_sent_lock.release_async() + logging.info("πŸ”„ Notifying aggregator to release aggregation") + await self.agg.notify_all_updates_received() + + async def _all_updates_received(self): + updates_left = self._sources_expected.difference(self._sources_received) + all_received = False + if len(updates_left) == 0: + logging.info("All updates have been 
received this round") + await self._round_updates_lock.release_async() + all_received = True + return all_received + \ No newline at end of file diff --git a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py index b1684b754..203cf1c87 100644 --- a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py +++ b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py @@ -29,7 +29,7 @@ class DFLUpdateHandler(UpdateHandler): def __init__( self, aggregator, - addr, + addr, buffersize=MAX_UPDATE_BUFFER_SIZE ): self._addr = addr @@ -54,7 +54,7 @@ def us(self): def agg(self): return self._aggregator - async def init(self): + async def init(self, config=None): await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.notify_federation_update) await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self.storage_update) diff --git a/nebula/core/aggregation/updatehandlers/sdflupdatehandler.py b/nebula/core/aggregation/updatehandlers/sdflupdatehandler.py new file mode 100644 index 000000000..5969ffaac --- /dev/null +++ b/nebula/core/aggregation/updatehandlers/sdflupdatehandler.py @@ -0,0 +1,35 @@ +from nebula.core.aggregation.updatehandlers.updatehandler import UpdateHandler +from nebula.core.nebulaevents import UpdateNeighborEvent, UpdateReceivedEvent +import asyncio + +class SFDLUpdateHandler(UpdateHandler): + def __init__( + self, + aggregator, + addr, + ): + pass + + async def init(): + raise NotImplementedError + + async def round_expected_updates(self, federation_nodes: set): + raise NotImplementedError + + async def storage_update(self, updt_received_event : UpdateReceivedEvent): + raise NotImplementedError + + async def get_round_updates(self) -> dict[str, tuple[object, float]]: + raise NotImplementedError + + async def notify_federation_update(self, updt_nei_event : UpdateNeighborEvent): + raise NotImplementedError + + async def 
get_round_missing_nodes(self) -> set[str]: + raise NotImplementedError + + async def notify_if_all_updates_received(self): + raise NotImplementedError + + async def stop_notifying_updates(self): + raise NotImplementedError \ No newline at end of file diff --git a/nebula/core/aggregation/updatehandlers/updatehandler.py b/nebula/core/aggregation/updatehandlers/updatehandler.py index ca9d05ea1..35295d160 100644 --- a/nebula/core/aggregation/updatehandlers/updatehandler.py +++ b/nebula/core/aggregation/updatehandlers/updatehandler.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from nebula.core.nebulaevents import UpdateNeighborEvent, UpdateReceivedEvent class UpdateHandlerException(Exception): pass @@ -12,7 +13,7 @@ class UpdateHandler(ABC): """ @abstractmethod - async def init(): + async def init(self, config : dict): raise NotImplementedError @abstractmethod @@ -30,7 +31,7 @@ async def round_expected_updates(self, federation_nodes: set): raise NotImplementedError @abstractmethod - async def storage_update(self, model, weight, source, round, local=False): + async def storage_update(self, updt_received_event : UpdateReceivedEvent): """ Stores an update from a source in the update storage. @@ -61,7 +62,7 @@ async def get_round_updates(self) -> dict[str, tuple[object, float]]: raise NotImplementedError @abstractmethod - async def notify_federation_update(self, source, remove=False): + async def notify_federation_update(self, updt_nei_event : UpdateNeighborEvent): """ Notifies the system of a change in the federation regarding a specific source. 
@@ -104,10 +105,12 @@ async def stop_notifying_updates(self): def factory_update_handler(updt_handler, aggregator, addr) -> UpdateHandler: from nebula.core.aggregation.updatehandlers.dflupdatehandler import DFLUpdateHandler from nebula.core.aggregation.updatehandlers.cflupdatehandler import CFLUpdateHandler + from nebula.core.aggregation.updatehandlers.sdflupdatehandler import SFDLUpdateHandler UPDATE_HANDLERS = { "DFL": DFLUpdateHandler, "CFL": CFLUpdateHandler, + "SDFL": SFDLUpdateHandler } update_handler = UPDATE_HANDLERS.get(updt_handler, None) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index e16c8eb20..b1361482d 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -9,7 +9,7 @@ from nebula.addons.reporter import Reporter from nebula.core.aggregation.aggregator import create_aggregator, create_target_aggregator from nebula.core.eventmanager import EventManager -from nebula.core.nebulaevents import UpdateNeighborEvent, UpdateReceivedEvent, RoundStartEvent +from nebula.core.nebulaevents import UpdateNeighborEvent, UpdateReceivedEvent, RoundStartEvent, AggregationEvent from nebula.core.network.communications import CommunicationsManager from nebula.core.situationalawareness.nodemanager import NodeManager from nebula.core.utils.locker import Locker @@ -242,14 +242,6 @@ def set_round(self, new_round): self.round = new_round self.trainer.set_current_round(new_round) - async def init_message_callbacks(self): - logging.info("Registering callbacks for MessageEvents...") - await self.register_message_events_callbacks() - - # Additional callbacks not registered automatically - await self.register_message_callback(("model", "initialization"), "model_initialization_callback") - await self.register_message_callback(("model", "update"), "model_update_callback") - """ ############################## # MODEL CALLBACKS # ############################## @@ -551,6 +543,17 @@ async def _link_disconnect_from_callback(self, source, message): 
############################## """ + async def register_events_callbacks(self): + await self.init_message_callbacks() + await EventManager.get_instance().subscribe_node_event(AggregationEvent, self.broadcast_models_include) + + async def init_message_callbacks(self): + logging.info("Registering callbacks for MessageEvents...") + await self.register_message_events_callbacks() + # Additional callbacks not registered automatically + await self.register_message_callback(("model", "initialization"), "model_initialization_callback") + await self.register_message_callback(("model", "update"), "model_update_callback") + async def register_message_events_callbacks(self): me_dict = self.cm.get_messages_events() message_events = [ @@ -598,7 +601,10 @@ async def update_neighbors(self, removed_neighbor_addr, neighbors, remove=False) updt_nei_event = UpdateNeighborEvent(removed_neighbor_addr, remove) asyncio.create_task(EventManager.get_instance().publish_node_event(updt_nei_event)) - #await self.aggregator.notify_federation_nodes_removed(removed_neighbor_addr, remove=remove) + async def broadcast_models_include(self, age : AggregationEvent): + logging.info(f"πŸ”„ Broadcasting MODELS_INCLUDED for round {self.get_round()}") + message = self.cm.create_message("federation", "federation_models_included", [str(arg) for arg in [self.get_round()]]) + asyncio.create_task(self.cm.send_message_to_neighbors(message)) async def update_model_learning_rate(self, new_lr): await self.trainning_in_progress_lock.acquire_async() @@ -658,7 +664,7 @@ async def create_trainer_module(self): logging.info("Started trainer module...") async def start_communications(self): - await self.init_message_callbacks() + await self.register_events_callbacks() await self.aggregator.init() logging.info(f"Neighbors: {self.config.participant['network_args']['neighbors']}") logging.info( @@ -850,7 +856,7 @@ async def _learning_cycle(self): ) # Set current round in config (send to the controller) await 
self.get_round_lock().release_async() - await self.nm.experiment_finish() + if self.mobility: await self.nm.experiment_finish() # End of the learning cycle self.trainer.on_learning_cycle_end() await self.trainer.test() @@ -1050,7 +1056,6 @@ def __init__( async def _extended_learning_cycle(self): # Define the functionality of the trainer node logging.info("Waiting global update | Assign _waiting_global_update = True") - self.aggregator.set_waiting_global_update() await self.trainer.test() await self.trainer.train() @@ -1082,5 +1087,4 @@ def __init__( async def _extended_learning_cycle(self): # Define the functionality of the idle node logging.info("Waiting global update | Assign _waiting_global_update = True") - self.aggregator.set_waiting_global_update() await self._waiting_model_updates() diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 3138e16ee..5e8c1415c 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -102,18 +102,10 @@ def forwarder(self): def propagator(self): return self._propagator - # @property - # def mobility(self): - # return self._mobility - @property def ecs(self): return self._external_connection_service - # @property - # def ns(self): - # return self._network_simulator - @property def bl(self): return self._blacklist @@ -452,16 +444,13 @@ async def network_wait(self): async def deploy_additional_services(self): logging.info("🌐 Deploying additional services...") - # self._generate_network_conditions() await self._forwarder.start() if self.config.participant["mobility_args"]["mobility"]: if self.config.participant["network_args"]["simulation"]: pass - #await self.ns.start() # await self._discoverer.start() # await self._health.start() self._propagator.start() - #await self._mobility.start() async def include_received_message_hash(self, hash_message): diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py 
b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py index a33feb4dc..4c398fbba 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py @@ -7,17 +7,26 @@ import time import asyncio +class TimeStamp(): + def __init__(self, time_received = None, time_since_last_event = None): + self.tr = time_received + self.tsle = time_since_last_event + + def reset(self): + self.tr = None + self.tsle = None + # "Speed-Oriented Selection" (SOS) class SOSTrainingPolicy(TrainingPolicy): MAX_HISTORIC_SIZE = 10 INACTIVE_THRESHOLD = 3 GRACE_ROUNDS = 0 CHECK_COOLDOWN = 1 - + def __init__(self, config): self._addr = config["addr"] self._verbose = config["verbose"] - self._nodes : dict[str, tuple[deque, int, float, float]] = {} # _nodes estructura: {node_id: (deque updates epr round, inactivity, time gap between updates, time since last aggregation)} + self._nodes : dict[str, tuple[deque, int, deque[TimeStamp], TimeStamp]] = {} # _nodes estructura: {node_id: (deque updates epr round, inactivity, time gaps between updates, time since last aggregation)} self._nodes_lock = Locker(name="nodes_lock", async_lock=True) self._grace_rounds = self.GRACE_ROUNDS @@ -28,7 +37,7 @@ def __init__(self, config): async def init(self, config): async with self._nodes_lock: nodes = config["nodes"] - self._nodes = {node_id: (deque(maxlen=self.MAX_HISTORIC_SIZE), 0, float('inf'), float('inf')) for node_id in nodes} + self._nodes = {node_id: (deque(maxlen=self.MAX_HISTORIC_SIZE), 0, deque(maxlen=self.MAX_HISTORIC_SIZE), TimeStamp()) for node_id in nodes} await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self._process_update_received_event) await EventManager.get_instance().subscribe_node_event(RoundStartEvent, self._process_first_round_start) await 
EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) @@ -60,6 +69,7 @@ async def _process_aggregation_event(self, are : AggregationEvent): async def _process_update_received_event(self, ure : UpdateReceivedEvent): + #TODO rehacer con timestamp time_received = time.time() if self._verbose: logging.info("Processing Update Received event") (_, _, source, _, _) = await ure.get_event_data() @@ -68,7 +78,7 @@ async def _process_update_received_event(self, ure : UpdateReceivedEvent): if source not in self._nodes: return - history, missed_count, first_update_time, last_update_time = self._nodes[source] + history, missed_count, time_between_updts_historic, last_update_time = self._nodes[source] if history and history[-1][0] == self._internal_rounds_done: num_updates = history[-1][1] + 1 @@ -76,19 +86,26 @@ async def _process_update_received_event(self, ure : UpdateReceivedEvent): else: history.append((self._internal_rounds_done, 1)) - if first_update_time == float('inf'): + if time_between_updts == float("inf"): if self._last_aggregation_time: - first_update_time = time_received - self._last_aggregation_time + time_between_updts = time_received - self._last_aggregation_time else: - first_update_time = 0 - - #TODO el error estΓ‘ aquΓ­ hay q comprobar con respecto a self_aggregation_time + time_between_updts = 0 + else: + pass + #time_between_updts = + + #la cuestion es: + # en cada ronda se actualiza el tiempo de inicio desde la agregaciΓ³n + # por lo tanto habria que resetear los tiempos? 
seria un procedimiento sin memoria + + if last_update_time == float('inf'): - last_update_time = first_update_time + last_update_time = time_between_updts else: last_update_time = time_received - last_update_time - self._nodes[source] = (history, missed_count, first_update_time, last_update_time) + self._nodes[source] = (history, missed_count, time_between_updts, last_update_time) async def update_neighbors(self, node, remove=False): async with self._nodes_lock: @@ -118,4 +135,6 @@ async def evaluate(self): self._last_check = (self._last_check + 1) % self.CHECK_COOLDOWN - return result \ No newline at end of file + return result + + \ No newline at end of file From 0bd99b2282f6a59456a5ff0ab2561be95b8836fe Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 6 Mar 2025 10:04:15 +0100 Subject: [PATCH 126/233] feature beacon received event --- nebula/core/engine.py | 2 +- nebula/core/nebulaevents.py | 38 ++++++++++++++++++- nebula/core/network/communications.py | 7 +--- .../externalconnectionservice.py | 4 -- .../nebuladiscoveryservice.py | 19 +++------- .../awareness/samodule.py | 2 +- .../awareness/sanetwork/sanetwork.py | 11 ++++-- 7 files changed, 53 insertions(+), 30 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index b1361482d..87e9b2843 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -147,7 +147,7 @@ def __init__( self.trainning_in_progress_lock = Locker(name="trainning_in_progress_lock", async_lock=True) - event_manager = EventManager.get_instance(verbose=False) + event_manager = EventManager.get_instance(verbose=True) # Mobility setup self._node_manager = None diff --git a/nebula/core/nebulaevents.py b/nebula/core/nebulaevents.py index 164413fc2..a98c6dbc5 100644 --- a/nebula/core/nebulaevents.py +++ b/nebula/core/nebulaevents.py @@ -41,6 +41,13 @@ def __str__(self): return "Round starting" async def get_event_data(self): + """Retrieves the round start event data. 
+ + Returns: + tuple[int, float]: + -round (int): Round number. + -start_time (time): Current time when round is going to start. + """ return (self._round, self._round_start_time) async def is_concurrent(self): @@ -125,7 +132,7 @@ def __init__(self, decoded_model, weight, source, round, local=False): def __str__(self): return f"Update received from source: {self._source}, round: {self._round}" - async def get_event_data(self) -> tuple[str, bool]: + async def get_event_data(self) -> tuple[object, int, str, int, bool]: """ Retrieves the event data. @@ -142,6 +149,35 @@ async def get_event_data(self) -> tuple[str, bool]: async def is_concurrent(self) -> bool: return False +class BeaconRecievedEvent(NodeEvent): + def __init__(self, source, geoloc): + """ + Initializes an BeaconRecievedEvent. + + Args: + source (str): The received beacon source. + geoloc (tuple): The geolocalzition associated with the received beacon source. + """ + self._source = source + self._geoloc = geoloc + + def __str__(self): + return "Beacon recieved" + + async def get_event_data(self) -> tuple[str, tuple[float, float]]: + """ + Retrieves the event data. + + Returns: + tuple[str, tuple[float, float]]: A tuple containing: + - The beacon's source. + - the device geolocalization (latitude, longitude). 
+ """ + return (self._source, self._geoloc) + + async def is_concurrent(self) -> bool: + return True + """ ############################## # ADDON EVENTS # diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 5e8c1415c..45f64876a 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -201,9 +201,6 @@ async def start_beacon(self): async def stop_beacon(self): await self.ecs.stop_beacon() - async def subscribe_beacon_listener(self, listener): - await self.ecs.subscribe_beacon_listener(listener) - async def modify_beacon_frequency(self, frequency): await self.ecs.modify_beacon_frequency(frequency) @@ -445,9 +442,7 @@ async def network_wait(self): async def deploy_additional_services(self): logging.info("🌐 Deploying additional services...") await self._forwarder.start() - if self.config.participant["mobility_args"]["mobility"]: - if self.config.participant["network_args"]["simulation"]: - pass + # await self._discoverer.start() # await self._health.start() self._propagator.start() diff --git a/nebula/core/network/externalconnection/externalconnectionservice.py b/nebula/core/network/externalconnection/externalconnectionservice.py index 1e2241bb2..cd356a324 100644 --- a/nebula/core/network/externalconnection/externalconnectionservice.py +++ b/nebula/core/network/externalconnection/externalconnectionservice.py @@ -31,10 +31,6 @@ async def stop_beacon(self): async def modify_beacon_frequency(self, frequency): pass - @abstractmethod - async def subscribe_beacon_listener(self, listener): - pass - class ExternalConnectionServiceException(Exception): pass diff --git a/nebula/core/network/externalconnection/nebuladiscoveryservice.py b/nebula/core/network/externalconnection/nebuladiscoveryservice.py index 0b4742832..e48c5c1d4 100644 --- a/nebula/core/network/externalconnection/nebuladiscoveryservice.py +++ b/nebula/core/network/externalconnection/nebuladiscoveryservice.py @@ -4,6 +4,8 @@ import 
struct from nebula.core.network.externalconnection.externalconnectionservice import ExternalConnectionService from nebula.core.utils.locker import Locker +from nebula.core.nebulaevents import BeaconRecievedEvent +from nebula.core.eventmanager import EventManager from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -168,8 +170,6 @@ def __init__(self, cm: "CommunicationsManager", addr): self.client : NebulaClientProtocol = None self.beacon : NebulaBeacon = NebulaBeacon(self, self.addr) self.running = False - self._beacon_listeners_lock = Locker(name="beacon_listeners_lock", async_lock=True) - self._beacon_listeners = [] @property def cm(self): @@ -235,15 +235,8 @@ def response_received(self, data, addr): if addr not in self.nodes_found: logging.info(f"Device address received: {addr}") self.nodes_found.add(addr) - - async def subscribe_beacon_listener(self, listener : callable): - await self._beacon_listeners_lock.acquire_async() - logging.info("Registering beacon listener...") - self._beacon_listeners.append(listener) - await self._beacon_listeners_lock.release_async() - + async def notify_beacon_received(self, addr, geoloc): - await self._beacon_listeners_lock.acquire_async() - for bec_listener in self._beacon_listeners: - await bec_listener(addr, geoloc) - await self._beacon_listeners_lock.release_async() \ No newline at end of file + beacon_event = BeaconRecievedEvent(addr, geoloc) + asyncio.create_task(EventManager.get_instance().publish_node_event(beacon_event)) + \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 6256d8162..84f4d8f88 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -30,7 +30,7 @@ def __init__( self._topology = topology self._node_manager: NodeManager = nodemanager self._situational_awareness_network = SANetwork(self, self.cm, self._addr, self._topology) - 
self._situational_awareness_training = SATraining(self, self._addr, "sos", "fastreboot", verbose=True) + self._situational_awareness_training = SATraining(self, self._addr, "qds", "fastreboot", verbose=True) self._restructure_process_lock = Locker(name="restructure_process_lock") self._restructure_cooldown = 0 diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index 0dddcbbfa..94ea4af42 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -3,6 +3,8 @@ from nebula.core.utils.locker import Locker from nebula.core.situationalawareness.awareness.sanetwork.neighborpolicies.neighborpolicy import factory_NeighborPolicy from nebula.addons.functions import print_msg_box +from nebula.core.nebulaevents import BeaconRecievedEvent +from nebula.core.eventmanager import EventManager from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.network.communications import CommunicationsManager @@ -49,7 +51,7 @@ async def init(self): if not self.sam.is_additional_participant(): logging.info("Deploying External Connection Service") await self.cm.start_external_connection_service() - await self.cm.subscribe_beacon_listener(self.beacon_received) + await EventManager.get_instance().subscribe_node_event(BeaconRecievedEvent, self.beacon_received) await self.cm.start_beacon() else: logging.info("Deploying External Connection Service | No running") @@ -113,16 +115,17 @@ async def _check_external_connection_service_status(self): if not await self.cm.is_external_connection_service_running(): logging.info("πŸ”„ External Service not running | Starting service...") await self.cm.init_external_connection_service() - await self.cm.subscribe_beacon_listener(self.beacon_received) + await EventManager.get_instance().subscribe_node_event(BeaconRecievedEvent, self.beacon_received) await 
self.cm.start_beacon() async def experiment_finish(self): await self.cm.stop_external_connection_service() - async def beacon_received(self, addr, geoloc): + async def beacon_received(self, beacon_recieved_event : BeaconRecievedEvent): + addr, geoloc = await beacon_recieved_event.get_event_data() latitude, longitude = geoloc self.meet_node(addr) - #logging.info(f"Beacon received SANetwork, source: {addr}, geolocalization: {latitude},{longitude}") + logging.info(f"Beacon received SANetwork, source: {addr}, geolocalization: {latitude},{longitude}") """ ############################### # REESTRUCTURE TOPOLOGY # From 836fbad3d70cdb3110e0d233f273b17f78c36647 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 7 Mar 2025 17:20:06 +0100 Subject: [PATCH 127/233] feature sos sa strategy --- nebula/core/addonmanager.py | 2 +- nebula/core/engine.py | 114 +--------------- nebula/core/nebulaevents.py | 27 ++++ .../awareness/samodule.py | 13 +- .../awareness/sanetwork/sanetwork.py | 2 +- .../trainingpolicy/sostrainingpolicy.py | 128 ++++++++++++++---- .../core/situationalawareness/nodemanager.py | 5 +- 7 files changed, 143 insertions(+), 148 deletions(-) diff --git a/nebula/core/addonmanager.py b/nebula/core/addonmanager.py index 95651f568..fd712114d 100644 --- a/nebula/core/addonmanager.py +++ b/nebula/core/addonmanager.py @@ -15,7 +15,7 @@ def __init__(self, engine : "Engine", config): self._addons = [] async def deploy_additional_services(self): - print_msg_box(msg="Deploying Additional Services\n(='.'=)", indent=2, title="Addons Manager") + print_msg_box(msg="Deploying Additional Services", indent=2, title="Addons Manager") if self._config.participant["mobility_args"]["mobility"]: mobility = Mobility(self._config, self._engine.cm, verbose=False) self._addons.append(mobility) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 87e9b2843..d6da6100c 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -9,7 +9,7 @@ from nebula.addons.reporter 
import Reporter from nebula.core.aggregation.aggregator import create_aggregator, create_target_aggregator from nebula.core.eventmanager import EventManager -from nebula.core.nebulaevents import UpdateNeighborEvent, UpdateReceivedEvent, RoundStartEvent, AggregationEvent +from nebula.core.nebulaevents import UpdateNeighborEvent, UpdateReceivedEvent, RoundStartEvent, AggregationEvent, RoundEndEvent from nebula.core.network.communications import CommunicationsManager from nebula.core.situationalawareness.nodemanager import NodeManager from nebula.core.utils.locker import Locker @@ -147,7 +147,7 @@ def __init__( self.trainning_in_progress_lock = Locker(name="trainning_in_progress_lock", async_lock=True) - event_manager = EventManager.get_instance(verbose=True) + event_manager = EventManager.get_instance(verbose=False) # Mobility setup self._node_manager = None @@ -341,18 +341,6 @@ async def _federation_federation_start_callback(self, source, message): logging.info(f"πŸ“ handle_federation_message | Trigger | Received start federation message from {source}") await self.create_trainer_module() - async def _federation_reputation_callback(self, source, message): - malicious_nodes = message.arguments # List of malicious nodes - if self.with_reputation: - if len(malicious_nodes) > 0 and not self._is_malicious: - if self.is_dynamic_topology: - await self._disrupt_connection_using_reputation(malicious_nodes) - if self.is_dynamic_aggregation and self.aggregator != self.target_aggregation: - await self._dynamic_aggregator( - self.aggregator.get_nodes_pending_models_to_aggregate(), - malicious_nodes, - ) - async def _federation_federation_models_included_callback(self, source, message): logging.info(f"πŸ“ handle_federation_message | Trigger | Received aggregation finished message from {source}") try: @@ -756,50 +744,6 @@ async def _start_learning(self): if await self.learning_cycle_lock.locked_async(): await self.learning_cycle_lock.release_async() - async def 
_disrupt_connection_using_reputation(self, malicious_nodes): - malicious_nodes = list(set(malicious_nodes) & set(self.get_current_connections())) - logging.info(f"Disrupting connection with malicious nodes at round {self.round}") - logging.info(f"Removing {malicious_nodes} from {self.get_current_connections()}") - logging.info(f"Current connections before aggregation at round {self.round}: {self.get_current_connections()}") - for malicious_node in malicious_nodes: - if (self.get_name() != malicious_node) and (malicious_node not in self._secure_neighbors): - await self.cm.disconnect(malicious_node) - logging.info(f"Current connections after aggregation at round {self.round}: {self.get_current_connections()}") - - await self._connect_with_benign(malicious_nodes) - - async def _connect_with_benign(self, malicious_nodes): - lower_threshold = 1 - higher_threshold = len(self.federation_nodes) - 1 - if higher_threshold < lower_threshold: - higher_threshold = lower_threshold - - benign_nodes = [i for i in self.federation_nodes if i not in malicious_nodes] - logging.info(f"_reputation_callback benign_nodes at round {self.round}: {benign_nodes}") - if len(self.get_current_connections()) <= lower_threshold: - for node in benign_nodes: - if len(self.get_current_connections()) <= higher_threshold and self.get_name() != node: - connected = await self.cm.connect(node) - if connected: - logging.info(f"Connect new connection with at round {self.round}: {connected}") - - async def _dynamic_aggregator(self, aggregated_models_weights, malicious_nodes): - logging.info(f"malicious detected at round {self.round}, change aggergation protocol!") - if self.aggregator != self.target_aggregation: - logging.info(f"Current aggregator is: {self.aggregator}") - self.aggregator = self.target_aggregation - await self.aggregator.update_federation_nodes(self.federation_nodes) - - for subnodes in aggregated_models_weights.keys(): - sublist = subnodes.split() - (submodel, weights) = 
aggregated_models_weights[subnodes] - for node in sublist: - if node not in malicious_nodes: - await self.aggregator.include_model_in_buffer( - submodel, weights, source=self.get_name(), round=self.round - ) - logging.info(f"Current aggregator is: {self.aggregator}") - async def _waiting_model_updates(self): logging.info(f"πŸ’€ Waiting convergence in round {self.round}.") params = await self.aggregator.get_aggregation() @@ -840,7 +784,10 @@ async def _learning_cycle(self): logging.info(f"[Role {self.role}] Starting learning cycle...") await self.aggregator.update_federation_nodes(self.federation_nodes) await self._extended_learning_cycle() - await self._additional_mobility_actions() + + current_time = time.time() + ree = RoundEndEvent(self.round, current_time) + await EventManager.get_instance().publish_node_event(ree) await self.get_round_lock().acquire_async() print_msg_box( @@ -893,54 +840,7 @@ async def _extended_learning_cycle(self): """ pass - async def _additional_mobility_actions(self): - if not self.mobility: - return - logging.info("πŸ”„ Starting additional mobility actions...") - await self.nm.mobility_actions() - - def reputation_calculation(self, aggregated_models_weights): - cossim_threshold = 0.5 - loss_threshold = 0.5 - - current_models = {} - for subnodes in aggregated_models_weights.keys(): - sublist = subnodes.split() - submodel = aggregated_models_weights[subnodes][0] - for node in sublist: - current_models[node] = submodel - - malicious_nodes = [] - reputation_score = {} - local_model = self.trainer.get_model_parameters() - untrusted_nodes = list(current_models.keys()) - logging.info(f"reputation_calculation untrusted_nodes at round {self.round}: {untrusted_nodes}") - - for untrusted_node in untrusted_nodes: - logging.info(f"reputation_calculation untrusted_node at round {self.round}: {untrusted_node}") - logging.info(f"reputation_calculation self.get_name() at round {self.round}: {self.get_name()}") - if untrusted_node != self.get_name(): - 
untrusted_model = current_models[untrusted_node] - cossim = cosine_metric(local_model, untrusted_model, similarity=True) - logging.info(f"reputation_calculation cossim at round {self.round}: {untrusted_node}: {cossim}") - self.trainer._logger.log_data({f"Reputation/cossim_{untrusted_node}": cossim}, step=self.round) - - avg_loss = self.trainer.validate_neighbour_model(untrusted_model) - logging.info(f"reputation_calculation avg_loss at round {self.round} {untrusted_node}: {avg_loss}") - self.trainer._logger.log_data({f"Reputation/avg_loss_{untrusted_node}": avg_loss}, step=self.round) - reputation_score[untrusted_node] = (cossim, avg_loss) - - if cossim < cossim_threshold or avg_loss > loss_threshold: - malicious_nodes.append(untrusted_node) - else: - self._secure_neighbors.append(untrusted_node) - - return malicious_nodes, reputation_score - - async def send_reputation(self, malicious_nodes): - logging.info(f"Sending REPUTATION to the rest of the topology: {malicious_nodes}") - message = self.cm.create_message("federation", "reputation", arguments=[str(arg) for arg in (malicious_nodes)]) - await self.cm.send_message_to_neighbors(message) + class MaliciousNode(Engine): diff --git a/nebula/core/nebulaevents.py b/nebula/core/nebulaevents.py index a98c6dbc5..42b2bd6a0 100644 --- a/nebula/core/nebulaevents.py +++ b/nebula/core/nebulaevents.py @@ -52,6 +52,33 @@ async def get_event_data(self): async def is_concurrent(self): return False + +class RoundEndEvent(NodeEvent): + def __init__(self, round, end_time): + """Event triggered when round is going to start. + + Args: + round (int): Round number. + end_time (time): Current time when round has ended. + """ + self._round_end_time = end_time + self._round = round + + def __str__(self): + return "Round ending" + + async def get_event_data(self): + """Retrieves the round start event data. + + Returns: + tuple[int, float]: + -round (int): Round number. + -end_time (time): Current time when round has ended. 
+ """ + return (self._round, self._round_end_time) + + async def is_concurrent(self): + return False class AggregationEvent(NodeEvent): def __init__(self, updates : dict, expected_nodes : set, missing_nodes : set): diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 84f4d8f88..d5bae85e6 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -1,12 +1,13 @@ import asyncio import logging -from typing import TYPE_CHECKING - from nebula.addons.functions import print_msg_box from nebula.core.situationalawareness.awareness.sanetwork.sanetwork import SANetwork from nebula.core.situationalawareness.awareness.satraining.satraining import SATraining from nebula.core.utils.locker import Locker +from nebula.core.nebulaevents import RoundEndEvent +from nebula.core.eventmanager import EventManager +from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.situationalawareness.nodemanager import NodeManager @@ -30,7 +31,7 @@ def __init__( self._topology = topology self._node_manager: NodeManager = nodemanager self._situational_awareness_network = SANetwork(self, self.cm, self._addr, self._topology) - self._situational_awareness_training = SATraining(self, self._addr, "qds", "fastreboot", verbose=True) + self._situational_awareness_training = SATraining(self, self._addr, "sos", "fastreboot", verbose=True) self._restructure_process_lock = Locker(name="restructure_process_lock") self._restructure_cooldown = 0 @@ -52,9 +53,10 @@ def cm(self): async def init(self): - #if not self.is_additional_participant(): + await EventManager.get_instance().subscribe_node_event(RoundEndEvent, self._mobility_actions) await self.san.init() await self.sat.init() + def is_additional_participant(self): return self.nm.is_additional_participant() @@ -67,7 +69,8 @@ async def get_geoloc(self): longitude = 
self.nm.config.participant["mobility_args"]["longitude"] return (latitude, longitude) - async def mobility_actions(self): + async def _mobility_actions(self, ree : RoundEndEvent): + logging.info("πŸ”„ Starting additional mobility actions...") await self.san.module_actions() await self.sat.module_actions() diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index 94ea4af42..0c53678a3 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -125,7 +125,7 @@ async def beacon_received(self, beacon_recieved_event : BeaconRecievedEvent): addr, geoloc = await beacon_recieved_event.get_event_data() latitude, longitude = geoloc self.meet_node(addr) - logging.info(f"Beacon received SANetwork, source: {addr}, geolocalization: {latitude},{longitude}") + #logging.info(f"Beacon received SANetwork, source: {addr}, geolocalization: {latitude},{longitude}") """ ############################### # REESTRUCTURE TOPOLOGY # diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py index 4c398fbba..e01218efb 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py @@ -11,6 +11,19 @@ class TimeStamp(): def __init__(self, time_received = None, time_since_last_event = None): self.tr = time_received self.tsle = time_since_last_event + + def __sub__(self, other): + if not isinstance(other, TimeStamp): + raise TypeError("Subtraction is only supported between TimeStamp instances") + if self.tr is None or other.tr is None: + raise ValueError("Cannot subtract TimeStamp instances with undefined 'tr' values") + return self.tr - other.tr 
+ + def __str__(self): + return f"{self.tsle}s" + + def is_empty(self): + return self.tr == None def reset(self): self.tr = None @@ -20,8 +33,12 @@ def reset(self): class SOSTrainingPolicy(TrainingPolicy): MAX_HISTORIC_SIZE = 10 INACTIVE_THRESHOLD = 3 - GRACE_ROUNDS = 0 + GRACE_ROUNDS = 1 CHECK_COOLDOWN = 1 + W_UPDATE_FREQ = 0.4 # Update frequency weight + W_UPDATE_LATENCY = 0.3 # update latency weight + W_AGG_WAITING = 0.2 # time waited since start waiting for aggregation until update is received weight + W_INACTIVITY_PEN = 0.1 # inactivity penalty weight def __init__(self, config): self._addr = config["addr"] @@ -39,16 +56,15 @@ async def init(self, config): nodes = config["nodes"] self._nodes = {node_id: (deque(maxlen=self.MAX_HISTORIC_SIZE), 0, deque(maxlen=self.MAX_HISTORIC_SIZE), TimeStamp()) for node_id in nodes} await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self._process_update_received_event) - await EventManager.get_instance().subscribe_node_event(RoundStartEvent, self._process_first_round_start) + await EventManager.get_instance().subscribe_node_event(RoundStartEvent, self._process_round_start) await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) - async def _get_nodes(self): async with self._nodes_lock: nodes = self._nodes.copy() return nodes - async def _process_first_round_start(self, rse : RoundStartEvent): + async def _process_round_start(self, rse : RoundStartEvent): if self._verbose: logging.info("Processing round start event") if not self._last_aggregation_time: if self._verbose: logging.info("First round start timing assigment") @@ -67,9 +83,7 @@ async def _process_aggregation_event(self, are : AggregationEvent): history, missed_count, gap_btween_updts, time_since_agg = self._nodes[node] self._nodes[node] = (history, 0 if node not in missing_nodes else missed_count + 1, gap_btween_updts, time_since_agg) - async def _process_update_received_event(self, ure : 
UpdateReceivedEvent): - #TODO rehacer con timestamp time_received = time.time() if self._verbose: logging.info("Processing Update Received event") (_, _, source, _, _) = await ure.get_event_data() @@ -86,26 +100,17 @@ async def _process_update_received_event(self, ure : UpdateReceivedEvent): else: history.append((self._internal_rounds_done, 1)) - if time_between_updts == float("inf"): - if self._last_aggregation_time: - time_between_updts = time_received - self._last_aggregation_time - else: - time_between_updts = 0 - else: - pass - #time_between_updts = - - #la cuestion es: - # en cada ronda se actualiza el tiempo de inicio desde la agregaciΓ³n - # por lo tanto habria que resetear los tiempos? seria un procedimiento sin memoria - - - if last_update_time == float('inf'): - last_update_time = time_between_updts + if not len(time_between_updts_historic): + time_between_updts_historic.append(TimeStamp(time_received, None)) else: - last_update_time = time_received - last_update_time + ts = TimeStamp(time_received) + ts.tsle = ts - time_between_updts_historic[-1] + time_between_updts_historic.append(ts) + + last_update_time.tr = time_received + last_update_time.tsle = time_received - self._last_aggregation_time - self._nodes[source] = (history, missed_count, time_between_updts, last_update_time) + self._nodes[source] = (history, missed_count, time_between_updts_historic, last_update_time) async def update_neighbors(self, node, remove=False): async with self._nodes_lock: @@ -116,6 +121,7 @@ async def update_neighbors(self, node, remove=False): self._nodes.update({node : (deque(maxlen=self.MAX_HISTORIC_SIZE), 0, float('inf'), float('inf'))}) async def evaluate(self): + if self._verbose: logging.info("Evaluating using speed-driven strategy") if self._grace_rounds: # Grace rounds self._grace_rounds -= 1 if self._verbose: logging.info("Grace time hasnt finished...") @@ -125,15 +131,77 @@ async def evaluate(self): if self._last_check == 0: nodes = await self._get_nodes() for 
node in nodes.keys(): - logging.info(f"Internal rounds done: {self._internal_rounds_done}") - logging.info(f"Node: {node}, {nodes[node][0]}") - updates_received = {x[1] for x in nodes[node][0] if x[0] == self._internal_rounds_done} - logging.info(f"Node: {node}, Updates received this round: {updates_received}, last gap: {nodes[node][2]}, time since last agg: {nodes[node][3]}") - + pass + # logging.info(f"Internal rounds done: {self._internal_rounds_done}") + # logging.info(f"Node: {node}, {nodes[node][0]}") + # updates_received = {x[1] for x in nodes[node][0] if x[0] == self._internal_rounds_done} + # logging.info(f"Node: {node} | Updates received this round: {updates_received}") + # logging.info(f"Time waited since last aggregation event {nodes[node][3]}") else: if self._verbose: logging.info(f"Evaluation is on cooldown... | {self.CHECK_COOLDOWN - self._last_check} rounds remaining") - self._last_check = (self._last_check + 1) % self.CHECK_COOLDOWN + # Extraer valores mΓ‘ximos y mΓ­nimos para normalizaciΓ³n + max_updates = max( + ( + max((x[1] for x in nodes[n][0] if x[0] == self._internal_rounds_done), default=0) + for n in nodes + ), + default=1 + ) + + min_latency = min( + ( + sum(t.tsle for t in nodes[n][2] if t.tsle is not None and t.tsle != float('inf')) / len(nodes[n][2]) + if any(t.tsle is not None and t.tsle != float('inf') for t in nodes[n][2]) + else float('inf') + for n in nodes + ), + default=1 + ) + + min_wait_time = min( + ( + nodes[n][3].tsle if nodes[n][3] and nodes[n][3].tsle is not None else float('inf') + for n in nodes + ), + default=1 + ) + + scores = {} + + for node, (history, missed_count, time_between_updts_historic, last_update_time) in nodes.items(): + # 1. Frecuencia de updates normalizada + updates_received = max((x[1] for x in history if x[0] == self._internal_rounds_done), default=0) + F_updt_freq = updates_received / max_updates if max_updates > 0 else 0 + + # 2. 
Latencia media entre updates normalizada + valid_latencies = [t.tsle for t in time_between_updts_historic if t.tsle is not None and t.tsle != float('inf')] + avg_latency = sum(valid_latencies) / len(valid_latencies) if valid_latencies else float('inf') + F_updt_latency = min_latency / avg_latency if avg_latency > 0 and avg_latency != float('inf') else 0 + + # 3. Tiempo desde ΓΊltima agregaciΓ³n normalizado + wait_time = last_update_time.tsle if last_update_time.tsle is not None else float('inf') + F_agg_waiting = min_wait_time / wait_time if wait_time > 0 else 0 + + # 4. PenalizaciΓ³n por inactividad + P_n = 1 / (1 + missed_count) # PenalizaciΓ³n inversamente proporcional + + # Calcular puntuaciΓ³n final + score = ( + (self.W_UPDATE_FREQ * F_updt_freq) + + (self.W_UPDATE_LATENCY * F_updt_latency) + + (self.W_AGG_WAITING * F_agg_waiting) + + (self.W_INACTIVITY_PEN * P_n) + ) + scores[node] = score + + # Ordenar nodos por puntuaciΓ³n descendente + sorted_nodes = sorted(scores.items(), key=lambda x: x[1], reverse=True) + + if self._verbose: + for node, score in sorted_nodes: + logging.info(f"Node: {node} | Score: {score:.3f}") + self._last_check = (self._last_check + 1) % self.CHECK_COOLDOWN return result diff --git a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py index 33c922ba9..feeca4c00 100644 --- a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -106,10 +106,7 @@ async def set_configs(self): async def get_geoloc(self): return await self.sam.get_geoloc() - - async def mobility_actions(self): - await self.sam.mobility_actions() - + async def experiment_finish(self): await self.sam.experiment_finish() From ef5033baed30886495283d694a31cc925128e493 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 10 Mar 2025 13:36:20 +0100 Subject: [PATCH 128/233] opt sat sos --- .../attacks/communications/delayerattack.py | 4 +- nebula/addons/mobility.py | 4 +- 
nebula/core/engine.py | 5 ++ .../trainingpolicy/sostrainingpolicy.py | 66 ++++++++++++------- 4 files changed, 51 insertions(+), 28 deletions(-) diff --git a/nebula/addons/attacks/communications/delayerattack.py b/nebula/addons/attacks/communications/delayerattack.py index 6b0ab9c67..f8b79731a 100644 --- a/nebula/addons/attacks/communications/delayerattack.py +++ b/nebula/addons/attacks/communications/delayerattack.py @@ -22,8 +22,8 @@ def __init__(self, engine, attack_params: dict): self.delay = int(attack_params["delay"]) round_start = int(attack_params["round_start_attack"]) round_stop = int(attack_params["round_stop_attack"]) - self.target_percentage = 50#int(attack_params["target_percentage"]) - self.selection_interval = 1#int(attack_params["selection_interval"]) + self.target_percentage = 100#int(attack_params["target_percentage"]) + self.selection_interval = None#int(attack_params["selection_interval"]) except KeyError as e: raise ValueError(f"Missing required attack parameter: {e}") except ValueError: diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index 787c4d7db..5f7f0d063 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -195,7 +195,7 @@ async def change_geo_location_nearest_neighbor_strategy( coordinates to determine the direction of movement. - The conversion from meters to degrees is based on approximate geographical conversion factors. 
""" - logging.info("πŸ“ Changing geo location towards the nearest neighbor") + if self._verbose: logging.info("πŸ“ Changing geo location towards the nearest neighbor") scale_factor = min(1, self.max_movement_nearest_strategy / distance) # Calcular el Γ‘ngulo hacia el vecino angle = math.atan2(neighbor_longitude - longitude, neighbor_latitude - latitude) @@ -281,7 +281,7 @@ async def change_geo_location(self): addr, dist, (lat, long) = selected_neighbor if dist > self.max_initiate_approximation: # If the distance is too big, we move towards the neighbor - logging.info(f"Moving towards nearest neighbor: {addr}") + if self._verbose: logging.info(f"Moving towards nearest neighbor: {addr}") await self.change_geo_location_nearest_neighbor_strategy( dist, latitude, diff --git a/nebula/core/engine.py b/nebula/core/engine.py index d6da6100c..ad045e1ca 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -767,9 +767,14 @@ def learning_cycle_finished(self): async def _learning_cycle(self): while self.round is not None and self.round < self.total_rounds: + if self.addr.split()[0][-1] == "5": + logging.info("### sleeping time ###") + time.sleep(30) + current_time = time.time() rse = RoundStartEvent(self.round, current_time) await EventManager.get_instance().publish_node_event(rse) + print_msg_box( msg=f"Round {self.round} of {self.total_rounds} started.", indent=2, diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py index e01218efb..6c8e1c3c5 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py @@ -19,6 +19,13 @@ def __sub__(self, other): raise ValueError("Cannot subtract TimeStamp instances with undefined 'tr' values") return self.tr - other.tr + def __add__(self, other): + if 
not isinstance(other, TimeStamp): + raise TypeError("Subtraction is only supported between TimeStamp instances") + if self.tsle is None or other.tsle is None: + raise ValueError("Cannot subtract TimeStamp instances with undefined 'tsle' values") + return self.tsle + other.tsle + def __str__(self): return f"{self.tsle}s" @@ -32,18 +39,19 @@ def reset(self): # "Speed-Oriented Selection" (SOS) class SOSTrainingPolicy(TrainingPolicy): MAX_HISTORIC_SIZE = 10 + SCORE_THRESHOLD = 0.7 INACTIVE_THRESHOLD = 3 GRACE_ROUNDS = 1 CHECK_COOLDOWN = 1 - W_UPDATE_FREQ = 0.4 # Update frequency weight - W_UPDATE_LATENCY = 0.3 # update latency weight - W_AGG_WAITING = 0.2 # time waited since start waiting for aggregation until update is received weight - W_INACTIVITY_PEN = 0.1 # inactivity penalty weight + W_UPDATE_FREQ = 0.25 # Update frequency weight + W_UPDATE_LATENCY = 0.05 # update latency weight + W_AGG_WAITING = 0.6 # time waited since start waiting for aggregation until update is received weight + W_INACTIVITY_PEN = 0.1 # inactivity penalty weight def __init__(self, config): self._addr = config["addr"] self._verbose = config["verbose"] - self._nodes : dict[str, tuple[deque, int, deque[TimeStamp], TimeStamp]] = {} # _nodes estructura: {node_id: (deque updates epr round, inactivity, time gaps between updates, time since last aggregation)} + self._nodes : dict[str, tuple[deque, int, deque[TimeStamp], deque[TimeStamp]]] = {} self._nodes_lock = Locker(name="nodes_lock", async_lock=True) self._grace_rounds = self.GRACE_ROUNDS @@ -54,7 +62,14 @@ def __init__(self, config): async def init(self, config): async with self._nodes_lock: nodes = config["nodes"] - self._nodes = {node_id: (deque(maxlen=self.MAX_HISTORIC_SIZE), 0, deque(maxlen=self.MAX_HISTORIC_SIZE), TimeStamp()) for node_id in nodes} + self._nodes = { + node_id: ( + deque(maxlen=self.MAX_HISTORIC_SIZE), # updates per round, + 0, # inactivity + deque(maxlen=self.MAX_HISTORIC_SIZE), # time gaps between updates + 
deque(maxlen=self.MAX_HISTORIC_SIZE) # times since last aggregation + ) for node_id in nodes + } await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self._process_update_received_event) await EventManager.get_instance().subscribe_node_event(RoundStartEvent, self._process_round_start) await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) @@ -92,7 +107,7 @@ async def _process_update_received_event(self, ure : UpdateReceivedEvent): if source not in self._nodes: return - history, missed_count, time_between_updts_historic, last_update_time = self._nodes[source] + history, missed_count, time_between_updts_historic, last_update_times = self._nodes[source] if history and history[-1][0] == self._internal_rounds_done: num_updates = history[-1][1] + 1 @@ -107,10 +122,10 @@ async def _process_update_received_event(self, ure : UpdateReceivedEvent): ts.tsle = ts - time_between_updts_historic[-1] time_between_updts_historic.append(ts) - last_update_time.tr = time_received - last_update_time.tsle = time_received - self._last_aggregation_time + lut = TimeStamp(time_received, time_received - self._last_aggregation_time) + last_update_times.append(lut) - self._nodes[source] = (history, missed_count, time_between_updts_historic, last_update_time) + self._nodes[source] = (history, missed_count, time_between_updts_historic, last_update_times) async def update_neighbors(self, node, remove=False): async with self._nodes_lock: @@ -121,7 +136,7 @@ async def update_neighbors(self, node, remove=False): self._nodes.update({node : (deque(maxlen=self.MAX_HISTORIC_SIZE), 0, float('inf'), float('inf'))}) async def evaluate(self): - if self._verbose: logging.info("Evaluating using speed-driven strategy") + if self._verbose: logging.info("Evaluating using speed-oriented strategy") if self._grace_rounds: # Grace rounds self._grace_rounds -= 1 if self._verbose: logging.info("Grace time hasnt finished...") @@ -131,12 +146,10 @@ 
async def evaluate(self): if self._last_check == 0: nodes = await self._get_nodes() for node in nodes.keys(): - pass - # logging.info(f"Internal rounds done: {self._internal_rounds_done}") - # logging.info(f"Node: {node}, {nodes[node][0]}") - # updates_received = {x[1] for x in nodes[node][0] if x[0] == self._internal_rounds_done} - # logging.info(f"Node: {node} | Updates received this round: {updates_received}") - # logging.info(f"Time waited since last aggregation event {nodes[node][3]}") + #logging.info(f"Node: {node}, {nodes[node][0]}") + updates_received = {x[1] for x in nodes[node][0] if x[0] == self._internal_rounds_done} + if self._verbose: logging.info(f"Node: {node} | Updates received this round: {updates_received}") + if self._verbose: logging.info(f"Time waited since last aggregation event {nodes[node][3][-1].tsle:.3f}") else: if self._verbose: logging.info(f"Evaluation is on cooldown... | {self.CHECK_COOLDOWN - self._last_check} rounds remaining") @@ -161,15 +174,16 @@ async def evaluate(self): min_wait_time = min( ( - nodes[n][3].tsle if nodes[n][3] and nodes[n][3].tsle is not None else float('inf') + sum(t.tsle for t in nodes[n][3]) / len(nodes[n][3]) if nodes[n][3] else float('inf') for n in nodes ), default=1 ) + if self._verbose: logging.info(f"max updates: {max_updates} | mean min latency: {min_latency:.3f} | mean min wait time: {min_wait_time:.3f}") scores = {} - for node, (history, missed_count, time_between_updts_historic, last_update_time) in nodes.items(): + for node, (history, missed_count, time_between_updts_historic, last_wait_times) in nodes.items(): # 1. 
Frecuencia de updates normalizada updates_received = max((x[1] for x in history if x[0] == self._internal_rounds_done), default=0) F_updt_freq = updates_received / max_updates if max_updates > 0 else 0 @@ -179,9 +193,9 @@ async def evaluate(self): avg_latency = sum(valid_latencies) / len(valid_latencies) if valid_latencies else float('inf') F_updt_latency = min_latency / avg_latency if avg_latency > 0 and avg_latency != float('inf') else 0 - # 3. Tiempo desde ΓΊltima agregaciΓ³n normalizado - wait_time = last_update_time.tsle if last_update_time.tsle is not None else float('inf') - F_agg_waiting = min_wait_time / wait_time if wait_time > 0 else 0 + # 3. Tiempo medio desde ΓΊltima agregaciΓ³n normalizado + avg_wait_time = sum(t.tsle for t in last_wait_times) / len(last_wait_times) if last_wait_times else float('inf') + F_agg_waiting = min_wait_time / avg_wait_time if avg_wait_time > 0 else 0 # 4. PenalizaciΓ³n por inactividad P_n = 1 / (1 + missed_count) # PenalizaciΓ³n inversamente proporcional @@ -197,11 +211,15 @@ async def evaluate(self): # Ordenar nodos por puntuaciΓ³n descendente sorted_nodes = sorted(scores.items(), key=lambda x: x[1], reverse=True) + nodes_below_th = [x for x in sorted_nodes if x[1] < self.SCORE_THRESHOLD] if self._verbose: for node, score in sorted_nodes: - logging.info(f"Node: {node} | Score: {score:.3f}") - self._last_check = (self._last_check + 1) % self.CHECK_COOLDOWN + if self._verbose: logging.info(f"Node: {node} | Score: {score:.3f}") + + if self._verbose: logging.info(f"Nodes below threshold: {nodes_below_th}") + + self._last_check = (self._last_check + 1) % self.CHECK_COOLDOWN return result From 1655705bab0cb65d8080b428fda18c2dc94c343b Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 11 Mar 2025 14:31:27 +0100 Subject: [PATCH 129/233] feature sat hts --- nebula/core/engine.py | 26 ++------------- nebula/core/nebulaevents.py | 26 ++++++++++++++- .../awareness/samodule.py | 2 +- .../awareness/sanetwork/sanetwork.py | 23 
+++++++++++-- .../awareness/satraining/satraining.py | 4 +-- .../trainingpolicy/htstrainingpolicy.py | 31 ++++++++++++++++-- .../trainingpolicy/qdstrainingpolicy.py | 7 ++-- .../trainingpolicy/sostrainingpolicy.py | 14 ++++---- .../satraining/weightstrategy}/fastreboot.py | 0 .../core/situationalawareness/nodemanager.py | 32 ++----------------- 10 files changed, 95 insertions(+), 70 deletions(-) rename nebula/core/situationalawareness/{ => awareness/satraining/weightstrategy}/fastreboot.py (100%) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index ad045e1ca..4ffd1090a 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -218,25 +218,6 @@ def get_trainning_in_progress_lock(self): def get_round_lock(self): return self.round_lock - def get_sinchronized_status(self): - with self.sinchronized_status_lock: - return True - return self._sinchronized_status - - def get_synchronizing_rounds(self): - return False - return self.nm.get_syncrhonizing_rounds() - - def update_sinchronized_status(self, status): - with self.sinchronized_status_lock: - logging.info(f"Update | synchronized status from: {self._sinchronized_status} to {status}") - self._sinchronized_status = status - - def set_synchronizing_rounds(self, status): - if self.mobility: - logging.info(f"Set sinchronizing rounds: {status}") - self.nm.set_synchronizing_rounds(status) - def set_round(self, new_round): logging.info(f"πŸ€– Update round count | from: {self.round} | to round: {new_round}") self.round = new_round @@ -574,7 +555,6 @@ async def get_geoloc(self): """ async def _aditional_node_start(self): - self.update_sinchronized_status(False) logging.info(f"Aditional node | {self.addr} | going to stablish connection with federation") await self.nm.start_late_connection_process() # continue .. 
@@ -767,9 +747,9 @@ def learning_cycle_finished(self): async def _learning_cycle(self): while self.round is not None and self.round < self.total_rounds: - if self.addr.split()[0][-1] == "5": - logging.info("### sleeping time ###") - time.sleep(30) + # if self.addr.split()[0][-1] == "5": + # logging.info("### sleeping time ###") + # time.sleep(30) current_time = time.time() rse = RoundStartEvent(self.round, current_time) diff --git a/nebula/core/nebulaevents.py b/nebula/core/nebulaevents.py index 42b2bd6a0..89a33809d 100644 --- a/nebula/core/nebulaevents.py +++ b/nebula/core/nebulaevents.py @@ -136,7 +136,31 @@ async def get_event_data(self) -> tuple[str, bool]: return (self._node_addr, self._removed) async def is_concurrent(self) -> bool: - return False + return False + +class NodeFoundEvent(NodeEvent): + def __init__(self, node_addr): + """Event triggered when a new node is found. + + Args: + node_addr (str): Address of the neighboring node. + """ + self._node_addr = node_addr + + def __str__(self): + return f"Node addr: {self._node_addr} found" + + async def get_event_data(self) -> tuple[str, bool]: + """Retrieves the node found event data. + + Returns: + tuple[str, bool]: + - node_addr (str): Address of the node found. 
+ """ + return self._node_addr + + async def is_concurrent(self) -> bool: + return True class UpdateReceivedEvent(NodeEvent): def __init__(self, decoded_model, weight, source, round, local=False): diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index d5bae85e6..cc4f0f9bf 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -31,7 +31,7 @@ def __init__( self._topology = topology self._node_manager: NodeManager = nodemanager self._situational_awareness_network = SANetwork(self, self.cm, self._addr, self._topology) - self._situational_awareness_training = SATraining(self, self._addr, "sos", "fastreboot", verbose=True) + self._situational_awareness_training = SATraining(self, self._addr, "qds", "fastreboot", verbose=True) self._restructure_process_lock = Locker(name="restructure_process_lock") self._restructure_cooldown = 0 diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index 0c53678a3..f05c8d9cf 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -5,6 +5,7 @@ from nebula.addons.functions import print_msg_box from nebula.core.nebulaevents import BeaconRecievedEvent from nebula.core.eventmanager import EventManager +from nebula.core.nebulaevents import NodeFoundEvent, UpdateNeighborEvent from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.network.communications import CommunicationsManager @@ -19,7 +20,8 @@ def __init__( communication_manager: "CommunicationsManager", addr, topology, - strict_topology=True + strict_topology=True, + verbose = False ): print_msg_box( msg=f"Starting Network SA\nTopology: {topology}\nStrict: {strict_topology}", @@ -34,6 +36,7 @@ def __init__( self._neighbor_policy = 
factory_NeighborPolicy(topology) self._restructure_process_lock = Locker(name="restructure_process_lock") self._restructure_cooldown = 0 + self._verbose = verbose @property def sam(self): @@ -66,6 +69,9 @@ async def init(self): self, ]) + await EventManager.get_instance().subscribe_node_event(NodeFoundEvent, self.process_node_found_event) + await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.process_update_neighbor_event) + async def module_actions(self): logging.info("SA Network evaluating current scenario") await self._check_external_connection_service_status() @@ -76,6 +82,17 @@ async def module_actions(self): # NEIGHBOR POLICY # ############################### """ + + async def process_node_found_event(self, nfe : NodeFoundEvent): + node_addr = await nfe.get_event_data() + if self._verbose: logging.info(f"Processing Node Found Event, node addr: {node_addr}") + self.np.meet_node(node_addr) + + async def process_update_neighbor_event(self, une : UpdateNeighborEvent): + node_addr, removed = await une.get_event_data() + if self._verbose: logging.info(f"Processing Update Neighbor Event, node addr: {node_addr}, remove: {removed}") + self.np.update_neighbors(node_addr, removed) + async def register_node(self, node, neighbor=False, remove=False): if not neighbor: self.meet_node(node) @@ -124,7 +141,9 @@ async def experiment_finish(self): async def beacon_received(self, beacon_recieved_event : BeaconRecievedEvent): addr, geoloc = await beacon_recieved_event.get_event_data() latitude, longitude = geoloc - self.meet_node(addr) + nfe = NodeFoundEvent(addr) + asyncio.create_task(EventManager.get_instance().publish_node_event(nfe)) + #self.meet_node(addr) #logging.info(f"Beacon received SANetwork, source: {addr}, geolocalization: {latitude},{longitude}") """ ############################### diff --git a/nebula/core/situationalawareness/awareness/satraining/satraining.py b/nebula/core/situationalawareness/awareness/satraining/satraining.py index 
543f2188d..4548b70bd 100644 --- a/nebula/core/situationalawareness/awareness/satraining/satraining.py +++ b/nebula/core/situationalawareness/awareness/satraining/satraining.py @@ -50,5 +50,5 @@ async def module_actions(self): nodes = await self.tp.evaluate() if nodes: for n in nodes: - pass - #asyncio.create_task(self.sam.cm.disconnect(n, forced=True)) + # pass + asyncio.create_task(self.sam.cm.disconnect(n[0], forced=True)) diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/htstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/htstrainingpolicy.py index 41dcd4f64..4bfbd39c4 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/htstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/htstrainingpolicy.py @@ -1,16 +1,41 @@ from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy +from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import factory_training_policy +from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy +import logging # "Hybrid Training Strategy" (HTS) class HTSTrainingPolicy(TrainingPolicy): + TRAINING_POLICY = { + "qds", + "sos", + } def __init__(self, config): - pass + self._addr = config["addr"] + self._verbose = config["verbose"] + self._training_policies : set[TrainingPolicy] = set() + self._training_policies.add([factory_training_policy(x, config) for x in self.TRAINING_POLICY]) + + def __str__(self): + return "HTS" + + @property + def tps(self): + return self._training_policies async def init(self, config): - pass + for tp in self.tps: + await tp.init(config) async def update_neighbors(self, node, remove=False): pass async def evaluate(self): - pass \ No newline at end of file + nodes_to_remove = dict() + for tp in self.tps: + nodes_to_remove[tp] = await 
tp.evaluate() + + for tp, nodes in nodes_to_remove.items(): + logging.info(f"Training Policy: {tp}, nodes to remove: {nodes}") + + return None \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py index 891aeca02..cc036bd5c 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py @@ -6,7 +6,6 @@ import logging from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import AggregationEvent -import random # "Quality-Driven Selection" (QDS) class QDSTrainingPolicy(TrainingPolicy): @@ -25,6 +24,8 @@ def __init__(self, config : dict): self._grace_rounds = self.GRACE_ROUNDS self._last_check = 0 + def __str__(self): + return "QDS" async def init(self, config): async with self._nodes_lock: @@ -57,7 +58,7 @@ async def process_aggregation_event(self, agg_ev : AggregationEvent): else: if self._verbose: logging.info(f"Node inactivity counter increased for: {node}") self._nodes[node] = (deque_history, missed_count + 1) # Inactive rounds counter +1 - + #TODO hacerlo solo para los q no se estΓ‘ utilizando la ultima update guardada (model,_) = updt (self_model, _) = self_updt cos_sim = cosine_metric(self_model, model, similarity=True) @@ -103,7 +104,7 @@ async def evaluate(self): if len(redundant_nodes) > 1: sorted_redundant_nodes = sorted(redundant_nodes, key=lambda x: x[1]) n_discarded = int(len(redundant_nodes)/2) - discard_nodes = sorted_redundant_nodes[-n_discarded:] + discard_nodes = sorted_redundant_nodes[:n_discarded] if self._verbose: logging.info(f"Discarded redundant nodes: {discard_nodes}") result = result.union(discard_nodes) else: diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py 
b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py index 6c8e1c3c5..cd851a98c 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py @@ -44,7 +44,7 @@ class SOSTrainingPolicy(TrainingPolicy): GRACE_ROUNDS = 1 CHECK_COOLDOWN = 1 W_UPDATE_FREQ = 0.25 # Update frequency weight - W_UPDATE_LATENCY = 0.05 # update latency weight + W_UPDATE_LATENCY = 0.15 # update latency weight W_AGG_WAITING = 0.6 # time waited since start waiting for aggregation until update is received weight W_INACTIVITY_PEN = 0.1 # inactivity penalty weight @@ -58,6 +58,9 @@ def __init__(self, config): self._last_check = 0 self._internal_rounds_done = -1 self._last_aggregation_time = None + + def __str__(self): + return "SOS" async def init(self, config): async with self._nodes_lock: @@ -142,7 +145,6 @@ async def evaluate(self): if self._verbose: logging.info("Grace time hasnt finished...") return None - result = set() if self._last_check == 0: nodes = await self._get_nodes() for node in nodes.keys(): @@ -198,14 +200,14 @@ async def evaluate(self): F_agg_waiting = min_wait_time / avg_wait_time if avg_wait_time > 0 else 0 # 4. 
PenalizaciΓ³n por inactividad - P_n = 1 / (1 + missed_count) # PenalizaciΓ³n inversamente proporcional + P_n = missed_count*self.W_INACTIVITY_PEN # PenalizaciΓ³n inversamente proporcional # Calcular puntuaciΓ³n final score = ( (self.W_UPDATE_FREQ * F_updt_freq) + (self.W_UPDATE_LATENCY * F_updt_latency) + - (self.W_AGG_WAITING * F_agg_waiting) + - (self.W_INACTIVITY_PEN * P_n) + (self.W_AGG_WAITING * F_agg_waiting) - + P_n ) scores[node] = score @@ -221,6 +223,6 @@ async def evaluate(self): self._last_check = (self._last_check + 1) % self.CHECK_COOLDOWN - return result + return nodes_below_th \ No newline at end of file diff --git a/nebula/core/situationalawareness/fastreboot.py b/nebula/core/situationalawareness/awareness/satraining/weightstrategy/fastreboot.py similarity index 100% rename from nebula/core/situationalawareness/fastreboot.py rename to nebula/core/situationalawareness/awareness/satraining/weightstrategy/fastreboot.py diff --git a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py index feeca4c00..ad16e94eb 100644 --- a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -4,7 +4,6 @@ from nebula.addons.functions import print_msg_box from nebula.core.situationalawareness.candidateselection.candidateselector import factory_CandidateSelector -from nebula.core.situationalawareness.fastreboot import FastReboot from nebula.core.situationalawareness.modelhandlers.modelhandler import factory_ModelHandler from nebula.core.situationalawareness.momentum import Momentum from nebula.core.situationalawareness.awareness.samodule import SAModule @@ -23,8 +22,6 @@ def __init__( topology, model_handler, engine: "Engine", - fastreboot=False, - momentum=False, ): self._aditional_participant = aditional_participant self.topology = topology @@ -47,9 +44,6 @@ def __init__( self.discarded_offers_addr_lock = Locker(name="discarded_offers_addr_lock") self.discarded_offers_addr 
= [] - self._fast_reboot_status = fastreboot - self._momemtum_status = momentum - self._desc_done = False #TODO remove self._situational_awareness_module = SAModule(self, self.engine.addr, topology) @@ -58,10 +52,6 @@ def __init__( def engine(self): return self._engine - # @property - # def neighbor_policy(self): - # return self._neighbor_policy - @property def candidate_selector(self): return self._candidate_selector @@ -70,17 +60,10 @@ def candidate_selector(self): def model_handler(self): return self._model_handler - @property - def fr(self): - return self._fastreboot - @property def sam(self): return self._situational_awareness_module - def fast_reboot_on(self): - return self._fast_reboot_status - def is_additional_participant(self): return self._aditional_participant @@ -100,10 +83,7 @@ async def set_configs(self): logging.info("Building candidate selector configuration..") self.candidate_selector.set_config([0, 0.5, 0.5]) # self.engine.trainer.get_loss(), self.config.participant["molibity_args"]["weight_distance"], self.config.participant["molibity_args"]["weight_het"] - - if self._fast_reboot_status: - self._fastreboot = FastReboot(self) - + async def get_geoloc(self): return await self.sam.get_geoloc() @@ -116,16 +96,12 @@ async def experiment_finish(self): ############################## """ - async def update_learning_rate(self, new_lr): - await self.engine.update_model_learning_rate(new_lr) - async def register_late_neighbor(self, addr, joinning_federation=False): logging.info(f"Registering | late neighbor: {addr}, joining: {joinning_federation}") self.sam.meet_node(addr) await self.update_neighbors(addr) if joinning_federation: - if self.fast_reboot_on(): - await self.fr.add_fastReboot_addr(addr) + pass """ ############################## @@ -187,12 +163,10 @@ def get_actions(self): return self.sam.get_actions() async def update_neighbors(self, node, remove=False): - #logging.info(f"Update neighbor | node addr: {node} | remove: {remove}") await 
self._update_neighbors_lock.acquire_async() self.sam.update_neighbors(node, remove) if remove: - if self._fast_reboot_status: - self.fr.discard_fastreboot_for(node) + pass else: self.sam.meet_node(node) self._remove_pending_confirmation_from(node) From b5c1dab544c01aa70b5b804b0b99d03dbad855b6 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 12 Mar 2025 11:13:42 +0100 Subject: [PATCH 130/233] fix error offers accepted after stopped lt process --- nebula/core/engine.py | 329 +++++++++--------- nebula/core/network/communications.py | 2 +- .../awareness/satraining/satraining.py | 4 +- .../core/situationalawareness/nodemanager.py | 201 ++++++++++- 4 files changed, 360 insertions(+), 176 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 4ffd1090a..3a1f407d6 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -261,7 +261,6 @@ async def model_update_callback(self, source, message): ############################## """ - # TODO llevar a communications async def _discovery_discover_callback(self, source, message): logging.info( f"πŸ” handle_discovery_message | Trigger | Received discovery message from {source} (network propagation)" @@ -342,170 +341,170 @@ async def _federation_federation_models_included_callback(self, source, message) finally: await self.cm.get_connections_lock().release_async() - """ ############################## - # Mobility callbacks # - ############################## - """ - - async def _connection_late_connect_callback(self, source, message): - logging.info(f"πŸ”— handle_connection_message | Trigger | Received late connect message from {source}") - # Verify if it's a confirmation message from a previous late connection message sent to source - if await self.nm.waiting_confirmation_from(source): - await self.nm.confirmation_received(source, confirmation=True) - return - - if not self.get_initialization_status(): - logging.info("❗️ Connection refused | Device not initialized yet...") - return - - if 
self.nm.accept_connection(source, joining=True): - logging.info(f"πŸ”— handle_connection_message | Late connection accepted | source: {source}") - await self.cm.connect(source, direct=True) - - # Verify conenction is accepted - conf_msg = self.cm.create_message("connection", "late_connect") - await self.cm.send_message(source, conf_msg) - await self.nm.register_late_neighbor(source, joinning_federation=True) - - ct_actions, df_actions = self.nm.get_actions() - if len(ct_actions): - cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) - await self.cm.send_message(source, cnt_msg) - - if len(df_actions): - df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) - await self.cm.send_message(source, df_msg) - - else: - logging.info(f"❗️ Late connection NOT accepted | source: {source}") - - async def _connection_restructure_callback(self, source, message): - logging.info(f"πŸ”— handle_connection_message | Trigger | Received restructure message from {source}") - # Verify if it's a confirmation message from a previous restructure connection message sent to source - if await self.nm.waiting_confirmation_from(source): - await self.nm.confirmation_received(source, confirmation=True) - return - - if not self.get_initialization_status(): - logging.info("❗️ Connection refused | Device not initialized yet...") - return - - if self.nm.accept_connection(source, joining=False): - logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection accepted from {source}") - await self.cm.connect(source, direct=True) - - conf_msg = self.cm.create_message("connection", "restructure") - - await self.cm.send_message(source, conf_msg) - - ct_actions, df_actions = self.nm.get_actions() - if len(ct_actions): - cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) - await self.cm.send_message(source, cnt_msg) - - if len(df_actions): - df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) - await 
self.cm.send_message(source, df_msg) - - await self.nm.register_late_neighbor(source, joinning_federation=False) - else: - logging.info(f"❗️ handle_connection_message | Trigger | restructure connection denied from {source}") - await asyncio.sleep(1) - # await self.cm.disconnect(source, mutual_disconnection=True) - - async def _discover_discover_join_callback(self, source, message): - logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_join message from {source} ") - if len(self.get_federation_nodes()) > 0: - await self.trainning_in_progress_lock.acquire_async() - model, rounds, round = ( - await self.cm.propagator.get_model_information(source, "stable") - if self.get_round() > 0 - else await self.cm.propagator.get_model_information(source, "initialization") - ) - await self.trainning_in_progress_lock.release_async() - if round != -1: - epochs = self.config.participant["training_args"]["epochs"] - msg = self.cm.create_message( - "offer", - "offer_model", - len(self.get_federation_nodes()), - 0, - parameters=model, - rounds=rounds, - round=round, - epochs=epochs, - ) - await self.cm.send_offer_model(source, msg) - else: - logging.info("Discover join received before federation is running..") - # starter node is going to send info to the new node - else: - logging.info(f"πŸ”— Dissmissing discover join from {source} | no active connections at the moment") - - async def _discover_discover_nodes_callback(self, source, message): - logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_node message from {source} ") - # self.nm.meet_node(source) - if len(self.get_federation_nodes()) > 0: - # msg = self.cm.mm.generate_offer_message(nebula_pb2.OfferMessage.Action.OFFER_METRIC, len(self.get_federation_nodes()), self.trainer.get_current_loss()) - msg = self.cm.create_message( - "offer", - "offer_metric", - n_neighbors=len(self.get_federation_nodes()), - loss=self.trainer.get_current_loss(), - ) - await self.cm.send_message(source, msg) - 
else: - logging.info(f"πŸ”— Dissmissing discover nodes from {source} | no active connections at the moment") - - async def _offer_offer_model_callback(self, source, message): - logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_model message from {source}") - self.nm.meet_node(source) - if self.nm.still_waiting_for_candidates(): - try: - model_compressed = message.parameters - if self.nm.accept_model_offer( - source, - model_compressed, - message.rounds, - message.round, - message.epochs, - message.n_neighbors, - message.loss, - ): - logging.info(f"πŸ”§ Model accepted from offer | source: {source}") - else: - logging.info(f"❗️ Model offer discarded | source: {source}") - self.nm.add_to_discarded_offers(source) - except RuntimeError: - logging.info(f"❗️ Error proccesing offer model from {source}") - else: - logging.info( - f"❗️ handfle_offer_message | NOT accepting offers | restructure: {self.nm.get_restructure_process_lock().locked()} | waiting candidates: {self.nm.still_waiting_for_candidates()}" - ) - self.nm.add_to_discarded_offers(source) - - async def _offer_offer_metric_callback(self, source, message): - logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_metric message from {source}") - self.nm.meet_node(source) - if self.nm.still_waiting_for_candidates(): - n_neighbors = message.n_neighbors - loss = message.loss - self.nm.add_candidate(source, n_neighbors, loss) - - async def _link_connect_to_callback(self, source, message): - logging.info(f"πŸ”— handle_link_message | Trigger | Received connect_to message from {source}") - addrs = message.addrs - for addr in addrs.split(): - # await self.cm.connect(addr, direct=True) - # self.nm.update_neighbors(addr) - self.nm.meet_node(addr) - - async def _link_disconnect_from_callback(self, source, message): - logging.info(f"πŸ”— handle_link_message | Trigger | Received disconnect_from message from {source}") - addrs = message.addrs - for addr in addrs.split(): - await 
self.cm.disconnect(source, mutual_disconnection=False) - await self.nm.update_neighbors(addr, remove=True) + # """ ############################## + # # Mobility callbacks # + # ############################## + # """ + + # async def _connection_late_connect_callback(self, source, message): + # logging.info(f"πŸ”— handle_connection_message | Trigger | Received late connect message from {source}") + # # Verify if it's a confirmation message from a previous late connection message sent to source + # if await self.nm.waiting_confirmation_from(source): + # await self.nm.confirmation_received(source, confirmation=True) + # return + + # if not self.get_initialization_status(): + # logging.info("❗️ Connection refused | Device not initialized yet...") + # return + + # if self.nm.accept_connection(source, joining=True): + # logging.info(f"πŸ”— handle_connection_message | Late connection accepted | source: {source}") + # await self.cm.connect(source, direct=True) + + # # Verify conenction is accepted + # conf_msg = self.cm.create_message("connection", "late_connect") + # await self.cm.send_message(source, conf_msg) + # await self.nm.register_late_neighbor(source, joinning_federation=True) + + # ct_actions, df_actions = self.nm.get_actions() + # if len(ct_actions): + # cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) + # await self.cm.send_message(source, cnt_msg) + + # if len(df_actions): + # df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) + # await self.cm.send_message(source, df_msg) + + # else: + # logging.info(f"❗️ Late connection NOT accepted | source: {source}") + + # async def _connection_restructure_callback(self, source, message): + # logging.info(f"πŸ”— handle_connection_message | Trigger | Received restructure message from {source}") + # # Verify if it's a confirmation message from a previous restructure connection message sent to source + # if await self.nm.waiting_confirmation_from(source): + # await 
self.nm.confirmation_received(source, confirmation=True) + # return + + # if not self.get_initialization_status(): + # logging.info("❗️ Connection refused | Device not initialized yet...") + # return + + # if self.nm.accept_connection(source, joining=False): + # logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection accepted from {source}") + # await self.cm.connect(source, direct=True) + + # conf_msg = self.cm.create_message("connection", "restructure") + + # await self.cm.send_message(source, conf_msg) + + # ct_actions, df_actions = self.nm.get_actions() + # if len(ct_actions): + # cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) + # await self.cm.send_message(source, cnt_msg) + + # if len(df_actions): + # df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) + # await self.cm.send_message(source, df_msg) + + # await self.nm.register_late_neighbor(source, joinning_federation=False) + # else: + # logging.info(f"❗️ handle_connection_message | Trigger | restructure connection denied from {source}") + # await asyncio.sleep(1) + # # await self.cm.disconnect(source, mutual_disconnection=True) + + # async def _discover_discover_join_callback(self, source, message): + # logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_join message from {source} ") + # if len(self.get_federation_nodes()) > 0: + # await self.trainning_in_progress_lock.acquire_async() + # model, rounds, round = ( + # await self.cm.propagator.get_model_information(source, "stable") + # if self.get_round() > 0 + # else await self.cm.propagator.get_model_information(source, "initialization") + # ) + # await self.trainning_in_progress_lock.release_async() + # if round != -1: + # epochs = self.config.participant["training_args"]["epochs"] + # msg = self.cm.create_message( + # "offer", + # "offer_model", + # len(self.get_federation_nodes()), + # 0, + # parameters=model, + # rounds=rounds, + # round=round, + # 
epochs=epochs, + # ) + # await self.cm.send_offer_model(source, msg) + # else: + # logging.info("Discover join received before federation is running..") + # # starter node is going to send info to the new node + # else: + # logging.info(f"πŸ”— Dissmissing discover join from {source} | no active connections at the moment") + + # async def _discover_discover_nodes_callback(self, source, message): + # logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_node message from {source} ") + # # self.nm.meet_node(source) + # if len(self.get_federation_nodes()) > 0: + # # msg = self.cm.mm.generate_offer_message(nebula_pb2.OfferMessage.Action.OFFER_METRIC, len(self.get_federation_nodes()), self.trainer.get_current_loss()) + # msg = self.cm.create_message( + # "offer", + # "offer_metric", + # n_neighbors=len(self.get_federation_nodes()), + # loss=self.trainer.get_current_loss(), + # ) + # await self.cm.send_message(source, msg) + # else: + # logging.info(f"πŸ”— Dissmissing discover nodes from {source} | no active connections at the moment") + + # async def _offer_offer_model_callback(self, source, message): + # logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_model message from {source}") + # self.nm.meet_node(source) + # if self.nm.still_waiting_for_candidates(): + # try: + # model_compressed = message.parameters + # if self.nm.accept_model_offer( + # source, + # model_compressed, + # message.rounds, + # message.round, + # message.epochs, + # message.n_neighbors, + # message.loss, + # ): + # logging.info(f"πŸ”§ Model accepted from offer | source: {source}") + # else: + # logging.info(f"❗️ Model offer discarded | source: {source}") + # self.nm.add_to_discarded_offers(source) + # except RuntimeError: + # logging.info(f"❗️ Error proccesing offer model from {source}") + # else: + # logging.info( + # f"❗️ handfle_offer_message | NOT accepting offers | restructure: {self.nm.get_restructure_process_lock().locked()} | waiting candidates: 
{self.nm.still_waiting_for_candidates()}" + # ) + # self.nm.add_to_discarded_offers(source) + + # async def _offer_offer_metric_callback(self, source, message): + # logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_metric message from {source}") + # self.nm.meet_node(source) + # if self.nm.still_waiting_for_candidates(): + # n_neighbors = message.n_neighbors + # loss = message.loss + # self.nm.add_candidate(source, n_neighbors, loss) + + # async def _link_connect_to_callback(self, source, message): + # logging.info(f"πŸ”— handle_link_message | Trigger | Received connect_to message from {source}") + # addrs = message.addrs + # for addr in addrs.split(): + # # await self.cm.connect(addr, direct=True) + # # self.nm.update_neighbors(addr) + # self.nm.meet_node(addr) + + # async def _link_disconnect_from_callback(self, source, message): + # logging.info(f"πŸ”— handle_link_message | Trigger | Received disconnect_from message from {source}") + # addrs = message.addrs + # for addr in addrs.split(): + # await self.cm.disconnect(source, mutual_disconnection=False) + # await self.nm.update_neighbors(addr, remove=True) """ ############################## # REGISTERING CALLBACKS # diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 45f64876a..f3d9c334d 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -241,7 +241,7 @@ async def stablish_connection_to_federation(self, msg_type="discover_join", addr logging.info(f"Connections verified after searching: {current_connections}") for addr in addrs: - logging.info(f"Sending {msg_type} to ---> {addr}") + logging.info(f"Sending {msg_type} to addr: {addr}") asyncio.create_task(self.send_message(addr, msg)) await asyncio.sleep(1) discovers_sent += 1 diff --git a/nebula/core/situationalawareness/awareness/satraining/satraining.py b/nebula/core/situationalawareness/awareness/satraining/satraining.py index 4548b70bd..fa63ee6ca 100644 
--- a/nebula/core/situationalawareness/awareness/satraining/satraining.py +++ b/nebula/core/situationalawareness/awareness/satraining/satraining.py @@ -50,5 +50,5 @@ async def module_actions(self): nodes = await self.tp.evaluate() if nodes: for n in nodes: - # pass - asyncio.create_task(self.sam.cm.disconnect(n[0], forced=True)) + pass + #asyncio.create_task(self.sam.cm.disconnect(n[0], forced=True)) diff --git a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py index ad16e94eb..48631bd58 100644 --- a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -8,6 +8,8 @@ from nebula.core.situationalawareness.momentum import Momentum from nebula.core.situationalawareness.awareness.samodule import SAModule from nebula.core.utils.locker import Locker +from nebula.core.eventmanager import EventManager +from nebula.core.nebulaevents import UpdateNeighborEvent, NodeFoundEvent if TYPE_CHECKING: from nebula.core.engine import Engine @@ -51,6 +53,10 @@ def __init__( @property def engine(self): return self._engine + + @property + def cm(self): + return self._engine.cm @property def candidate_selector(self): @@ -79,6 +85,7 @@ async def set_configs(self): - self weight distance - self weight hetereogeneity """ + await self.register_message_events_callbacks() await self.sam.init() logging.info("Building candidate selector configuration..") self.candidate_selector.set_config([0, 0.5, 0.5]) @@ -98,7 +105,7 @@ async def experiment_finish(self): async def register_late_neighbor(self, addr, joinning_federation=False): logging.info(f"Registering | late neighbor: {addr}, joining: {joinning_federation}") - self.sam.meet_node(addr) + await self.meet_node(addr) await self.update_neighbors(addr) if joinning_federation: pass @@ -116,7 +123,7 @@ def accept_connection(self, source, joining=False): return self.sam.accept_connection(source, joining) def still_waiting_for_candidates(self): - return 
not self.accept_candidates_lock.locked() + return not self.accept_candidates_lock.locked() and self.late_connection_process_lock.locked() async def add_pending_connection_confirmation(self, addr): await self._update_neighbors_lock.acquire_async() @@ -168,12 +175,13 @@ async def update_neighbors(self, node, remove=False): if remove: pass else: - self.sam.meet_node(node) + await self.meet_node(node) self._remove_pending_confirmation_from(node) await self._update_neighbors_lock.release_async() - def meet_node(self, node): - self.sam.meet_node(node) + async def meet_node(self, node): + nfe = NodeFoundEvent(node) + await EventManager.get_instance().publish_node_event(nfe) def get_nodes_known(self, neighbors_too=False): return self.sam.get_nodes_known(neighbors_too) @@ -264,9 +272,9 @@ async def start_late_connection_process(self, connected=False, msg_type="discove self.accept_candidates_lock.release() self.late_connection_process_lock.release() self.candidate_selector.remove_candidates() - if not self._desc_done: #TODO remove - self._desc_done = True - asyncio.create_task(self.sam.san.stop_connections_with_federation()) + # if not self._desc_done: #TODO remove + # self._desc_done = True + # asyncio.create_task(self.sam.san.stop_connections_with_federation()) # if no candidates, repeat process else: logging.info("❗️ No Candidates found...") @@ -275,6 +283,183 @@ async def start_late_connection_process(self, connected=False, msg_type="discove if not connected: logging.info("❗️ repeating process...") await self.start_late_connection_process(connected, msg_type, addrs_known) + + + """ ############################## + # Mobility callbacks # + ############################## + """ + + async def register_message_events_callbacks(self): + me_dict = self.cm.get_messages_events() + message_events = [ + (message_name, message_action) + for (message_name, message_actions) in me_dict.items() + for message_action in message_actions + ] + for event_type, action in message_events: + 
callback_name = f"_{event_type}_{action}_callback" + method = getattr(self, callback_name, None) + + if callable(method): + await EventManager.get_instance().subscribe((event_type, action), method) + + async def _connection_late_connect_callback(self, source, message): + logging.info(f"πŸ”— handle_connection_message | Trigger | Received late connect message from {source}") + # Verify if it's a confirmation message from a previous late connection message sent to source + if await self.waiting_confirmation_from(source): + await self.confirmation_received(source, confirmation=True) + return + + if not self.engine.get_initialization_status(): + logging.info("❗️ Connection refused | Device not initialized yet...") + return + + if self.accept_connection(source, joining=True): + logging.info(f"πŸ”— handle_connection_message | Late connection accepted | source: {source}") + await self.cm.connect(source, direct=True) + + # Verify conenction is accepted + conf_msg = self.cm.create_message("connection", "late_connect") + await self.cm.send_message(source, conf_msg) + await self.register_late_neighbor(source, joinning_federation=True) + + ct_actions, df_actions = self.get_actions() + if len(ct_actions): + cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) + await self.cm.send_message(source, cnt_msg) + + if len(df_actions): + df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) + await self.cm.send_message(source, df_msg) + + else: + logging.info(f"❗️ Late connection NOT accepted | source: {source}") + + async def _connection_restructure_callback(self, source, message): + logging.info(f"πŸ”— handle_connection_message | Trigger | Received restructure message from {source}") + # Verify if it's a confirmation message from a previous restructure connection message sent to source + if await self.waiting_confirmation_from(source): + await self.confirmation_received(source, confirmation=True) + return + + if not 
self.engine.get_initialization_status(): + logging.info("❗️ Connection refused | Device not initialized yet...") + return + + if self.accept_connection(source, joining=False): + logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection accepted from {source}") + await self.cm.connect(source, direct=True) + + conf_msg = self.cm.create_message("connection", "restructure") + + await self.cm.send_message(source, conf_msg) + + ct_actions, df_actions = self.get_actions() + if len(ct_actions): + cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) + await self.cm.send_message(source, cnt_msg) + + if len(df_actions): + df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) + await self.cm.send_message(source, df_msg) + + await self.register_late_neighbor(source, joinning_federation=False) + else: + logging.info(f"❗️ handle_connection_message | Trigger | restructure connection denied from {source}") + + async def _discover_discover_join_callback(self, source, message): + logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_join message from {source} ") + if len(self.engine.get_federation_nodes()) > 0: + await self.engine.trainning_in_progress_lock.acquire_async() + model, rounds, round = ( + await self.cm.propagator.get_model_information(source, "stable") + if self.engine.get_round() > 0 + else await self.cm.propagator.get_model_information(source, "initialization") + ) + await self.engine.trainning_in_progress_lock.release_async() + if round != -1: + epochs = self.config.participant["training_args"]["epochs"] + msg = self.cm.create_message( + "offer", + "offer_model", + len(self.engine.get_federation_nodes()), + 0, + parameters=model, + rounds=rounds, + round=round, + epochs=epochs, + ) + await self.cm.send_offer_model(source, msg) + else: + logging.info("Discover join received before federation is running..") + # starter node is going to send info to the new node + else: + 
logging.info(f"πŸ”— Dissmissing discover join from {source} | no active connections at the moment") + + async def _discover_discover_nodes_callback(self, source, message): + logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_node message from {source} ") + # self.nm.meet_node(source) + if len(self.engine.get_federation_nodes()) > 0: + msg = self.cm.create_message( + "offer", + "offer_metric", + n_neighbors=len(self.engine.get_federation_nodes()), + loss=self.engine.trainer.get_current_loss(), + ) + await self.cm.send_message(source, msg) + else: + logging.info(f"πŸ”— Dissmissing discover nodes from {source} | no active connections at the moment") + + async def _offer_offer_model_callback(self, source, message): + logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_model message from {source}") + await self.meet_node(source) + if self.still_waiting_for_candidates(): + try: + model_compressed = message.parameters + if self.accept_model_offer( + source, + model_compressed, + message.rounds, + message.round, + message.epochs, + message.n_neighbors, + message.loss, + ): + logging.info(f"πŸ”§ Model accepted from offer | source: {source}") + else: + logging.info(f"❗️ Model offer discarded | source: {source}") + self.add_to_discarded_offers(source) + except RuntimeError: + logging.info(f"❗️ Error proccesing offer model from {source}") + else: + logging.info( + f"❗️ handfle_offer_message | NOT accepting offers | restructure: {self.get_restructure_process_lock().locked()} | waiting candidates: {self.still_waiting_for_candidates()}" + ) + self.add_to_discarded_offers(source) + + async def _offer_offer_metric_callback(self, source, message): + logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_metric message from {source}") + await self.meet_node(source) + if self.still_waiting_for_candidates(): + n_neighbors = message.n_neighbors + loss = message.loss + self.add_candidate(source, n_neighbors, loss) + + async def 
_link_connect_to_callback(self, source, message): + logging.info(f"πŸ”— handle_link_message | Trigger | Received connect_to message from {source}") + addrs = message.addrs + for addr in addrs.split(): + # await self.cm.connect(addr, direct=True) + # self.nm.update_neighbors(addr) + await self.meet_node(addr) + + async def _link_disconnect_from_callback(self, source, message): + logging.info(f"πŸ”— handle_link_message | Trigger | Received disconnect_from message from {source}") + addrs = message.addrs + for addr in addrs.split(): + await self.cm.disconnect(source, mutual_disconnection=False) + await self.update_neighbors(addr, remove=True) From 0b2cdc44f8d5e5ed00d5df16f80c4fcec4a8eb36 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 13 Mar 2025 10:59:37 +0100 Subject: [PATCH 131/233] opt engine --- nebula/core/engine.py | 165 ------------------ .../awareness/sanetwork/sanetwork.py | 2 +- .../awareness/satraining/satraining.py | 4 +- .../trainingpolicy/qdstrainingpolicy.py | 4 +- 4 files changed, 5 insertions(+), 170 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 3a1f407d6..303e3a5c1 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -341,171 +341,6 @@ async def _federation_federation_models_included_callback(self, source, message) finally: await self.cm.get_connections_lock().release_async() - # """ ############################## - # # Mobility callbacks # - # ############################## - # """ - - # async def _connection_late_connect_callback(self, source, message): - # logging.info(f"πŸ”— handle_connection_message | Trigger | Received late connect message from {source}") - # # Verify if it's a confirmation message from a previous late connection message sent to source - # if await self.nm.waiting_confirmation_from(source): - # await self.nm.confirmation_received(source, confirmation=True) - # return - - # if not self.get_initialization_status(): - # logging.info("❗️ Connection refused | Device not 
initialized yet...") - # return - - # if self.nm.accept_connection(source, joining=True): - # logging.info(f"πŸ”— handle_connection_message | Late connection accepted | source: {source}") - # await self.cm.connect(source, direct=True) - - # # Verify conenction is accepted - # conf_msg = self.cm.create_message("connection", "late_connect") - # await self.cm.send_message(source, conf_msg) - # await self.nm.register_late_neighbor(source, joinning_federation=True) - - # ct_actions, df_actions = self.nm.get_actions() - # if len(ct_actions): - # cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) - # await self.cm.send_message(source, cnt_msg) - - # if len(df_actions): - # df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) - # await self.cm.send_message(source, df_msg) - - # else: - # logging.info(f"❗️ Late connection NOT accepted | source: {source}") - - # async def _connection_restructure_callback(self, source, message): - # logging.info(f"πŸ”— handle_connection_message | Trigger | Received restructure message from {source}") - # # Verify if it's a confirmation message from a previous restructure connection message sent to source - # if await self.nm.waiting_confirmation_from(source): - # await self.nm.confirmation_received(source, confirmation=True) - # return - - # if not self.get_initialization_status(): - # logging.info("❗️ Connection refused | Device not initialized yet...") - # return - - # if self.nm.accept_connection(source, joining=False): - # logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection accepted from {source}") - # await self.cm.connect(source, direct=True) - - # conf_msg = self.cm.create_message("connection", "restructure") - - # await self.cm.send_message(source, conf_msg) - - # ct_actions, df_actions = self.nm.get_actions() - # if len(ct_actions): - # cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) - # await self.cm.send_message(source, cnt_msg) - - 
# if len(df_actions): - # df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) - # await self.cm.send_message(source, df_msg) - - # await self.nm.register_late_neighbor(source, joinning_federation=False) - # else: - # logging.info(f"❗️ handle_connection_message | Trigger | restructure connection denied from {source}") - # await asyncio.sleep(1) - # # await self.cm.disconnect(source, mutual_disconnection=True) - - # async def _discover_discover_join_callback(self, source, message): - # logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_join message from {source} ") - # if len(self.get_federation_nodes()) > 0: - # await self.trainning_in_progress_lock.acquire_async() - # model, rounds, round = ( - # await self.cm.propagator.get_model_information(source, "stable") - # if self.get_round() > 0 - # else await self.cm.propagator.get_model_information(source, "initialization") - # ) - # await self.trainning_in_progress_lock.release_async() - # if round != -1: - # epochs = self.config.participant["training_args"]["epochs"] - # msg = self.cm.create_message( - # "offer", - # "offer_model", - # len(self.get_federation_nodes()), - # 0, - # parameters=model, - # rounds=rounds, - # round=round, - # epochs=epochs, - # ) - # await self.cm.send_offer_model(source, msg) - # else: - # logging.info("Discover join received before federation is running..") - # # starter node is going to send info to the new node - # else: - # logging.info(f"πŸ”— Dissmissing discover join from {source} | no active connections at the moment") - - # async def _discover_discover_nodes_callback(self, source, message): - # logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_node message from {source} ") - # # self.nm.meet_node(source) - # if len(self.get_federation_nodes()) > 0: - # # msg = self.cm.mm.generate_offer_message(nebula_pb2.OfferMessage.Action.OFFER_METRIC, len(self.get_federation_nodes()), self.trainer.get_current_loss()) - # msg 
= self.cm.create_message( - # "offer", - # "offer_metric", - # n_neighbors=len(self.get_federation_nodes()), - # loss=self.trainer.get_current_loss(), - # ) - # await self.cm.send_message(source, msg) - # else: - # logging.info(f"πŸ”— Dissmissing discover nodes from {source} | no active connections at the moment") - - # async def _offer_offer_model_callback(self, source, message): - # logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_model message from {source}") - # self.nm.meet_node(source) - # if self.nm.still_waiting_for_candidates(): - # try: - # model_compressed = message.parameters - # if self.nm.accept_model_offer( - # source, - # model_compressed, - # message.rounds, - # message.round, - # message.epochs, - # message.n_neighbors, - # message.loss, - # ): - # logging.info(f"πŸ”§ Model accepted from offer | source: {source}") - # else: - # logging.info(f"❗️ Model offer discarded | source: {source}") - # self.nm.add_to_discarded_offers(source) - # except RuntimeError: - # logging.info(f"❗️ Error proccesing offer model from {source}") - # else: - # logging.info( - # f"❗️ handfle_offer_message | NOT accepting offers | restructure: {self.nm.get_restructure_process_lock().locked()} | waiting candidates: {self.nm.still_waiting_for_candidates()}" - # ) - # self.nm.add_to_discarded_offers(source) - - # async def _offer_offer_metric_callback(self, source, message): - # logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_metric message from {source}") - # self.nm.meet_node(source) - # if self.nm.still_waiting_for_candidates(): - # n_neighbors = message.n_neighbors - # loss = message.loss - # self.nm.add_candidate(source, n_neighbors, loss) - - # async def _link_connect_to_callback(self, source, message): - # logging.info(f"πŸ”— handle_link_message | Trigger | Received connect_to message from {source}") - # addrs = message.addrs - # for addr in addrs.split(): - # # await self.cm.connect(addr, direct=True) - # # 
self.nm.update_neighbors(addr) - # self.nm.meet_node(addr) - - # async def _link_disconnect_from_callback(self, source, message): - # logging.info(f"πŸ”— handle_link_message | Trigger | Received disconnect_from message from {source}") - # addrs = message.addrs - # for addr in addrs.split(): - # await self.cm.disconnect(source, mutual_disconnection=False) - # await self.nm.update_neighbors(addr, remove=True) - """ ############################## # REGISTERING CALLBACKS # ############################## diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index f05c8d9cf..6d7a84f22 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -168,7 +168,7 @@ async def _analize_topology_robustness(self): if not self._restructure_process_lock.locked(): if not await self.neighbors_left(): logging.info("No Neighbors left | reconnecting with Federation") - await self.reconnect_to_federation() + #await self.reconnect_to_federation() elif self.np.need_more_neighbors() and self._restructure_available(): logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") self._update_restructure_cooldown() diff --git a/nebula/core/situationalawareness/awareness/satraining/satraining.py b/nebula/core/situationalawareness/awareness/satraining/satraining.py index fa63ee6ca..4548b70bd 100644 --- a/nebula/core/situationalawareness/awareness/satraining/satraining.py +++ b/nebula/core/situationalawareness/awareness/satraining/satraining.py @@ -50,5 +50,5 @@ async def module_actions(self): nodes = await self.tp.evaluate() if nodes: for n in nodes: - pass - #asyncio.create_task(self.sam.cm.disconnect(n[0], forced=True)) + # pass + asyncio.create_task(self.sam.cm.disconnect(n[0], forced=True)) diff --git 
a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py index cc036bd5c..fdbd5487f 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py @@ -12,8 +12,8 @@ class QDSTrainingPolicy(TrainingPolicy): MAX_HISTORIC_SIZE = 10 SIMILARITY_THRESHOLD = 0.8 INACTIVE_THRESHOLD = 3 - GRACE_ROUNDS = 10 - CHECK_COOLDOWN = 10 + GRACE_ROUNDS = 0 + CHECK_COOLDOWN = 50 def __init__(self, config : dict): self._addr = config["addr"] From 6b3fe4967f70f986d9ff7ab2b0b1ac955af2c014 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 13 Mar 2025 16:15:27 +0100 Subject: [PATCH 132/233] fix evaluation before aggregation --- .../awareness/satraining/satraining.py | 2 +- .../satraining/trainingpolicy/bpstrainingpolicy.py | 3 +++ .../satraining/trainingpolicy/htstrainingpolicy.py | 5 ++++- .../satraining/trainingpolicy/qdstrainingpolicy.py | 9 ++++++++- .../satraining/trainingpolicy/sostrainingpolicy.py | 3 +++ .../satraining/trainingpolicy/trainingpolicy.py | 2 +- 6 files changed, 20 insertions(+), 4 deletions(-) diff --git a/nebula/core/situationalawareness/awareness/satraining/satraining.py b/nebula/core/situationalawareness/awareness/satraining/satraining.py index 4548b70bd..7ad4d3b53 100644 --- a/nebula/core/situationalawareness/awareness/satraining/satraining.py +++ b/nebula/core/situationalawareness/awareness/satraining/satraining.py @@ -47,7 +47,7 @@ async def init(self): async def module_actions(self): logging.info("SA Trainng evaluating current scenario") - nodes = await self.tp.evaluate() + nodes = await self.tp.get_evaluation_results() if nodes: for n in nodes: # pass diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py 
b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py index 15fa664ec..ee41b2c1a 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py @@ -12,4 +12,7 @@ async def update_neighbors(self, node, remove=False): pass async def evaluate(self): + return None + + async def get_evaluation_results(self): return None \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/htstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/htstrainingpolicy.py index 4bfbd39c4..f36384b2b 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/htstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/htstrainingpolicy.py @@ -28,8 +28,11 @@ async def init(self, config): await tp.init(config) async def update_neighbors(self, node, remove=False): - pass + return None + async def get_evaluation_results(self): + pass + async def evaluate(self): nodes_to_remove = dict() for tp in self.tps: diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py index fdbd5487f..15fc2fb56 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py @@ -23,6 +23,7 @@ def __init__(self, config : dict): self._round_missing_nodes = set() self._grace_rounds = self.GRACE_ROUNDS self._last_check = 0 + self._evaluation_results = set() def __str__(self): return "QDS" @@ -63,6 +64,7 @@ async def process_aggregation_event(self, agg_ev : AggregationEvent): (self_model, _) = self_updt cos_sim = cosine_metric(self_model, 
model, similarity=True) self._nodes[addr][0].append(cos_sim) + self._evaluation_results = await self.evaluate() async def _get_nodes(self): async with self._nodes_lock: @@ -74,6 +76,8 @@ async def evaluate(self): self._grace_rounds -= 1 if self._verbose: logging.info("Grace time hasnt finished...") return None + + if self._verbose: logging.info("Evaluation in process") result = set() if self._last_check == 0: @@ -112,4 +116,7 @@ async def evaluate(self): self._last_check = (self._last_check + 1) % self.CHECK_COOLDOWN - return result \ No newline at end of file + return result + + async def get_evaluation_results(self): + return self._evaluation_results.copy() \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py index cd851a98c..e18172a48 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py @@ -137,6 +137,9 @@ async def update_neighbors(self, node, remove=False): else: if not node in self._nodes: self._nodes.update({node : (deque(maxlen=self.MAX_HISTORIC_SIZE), 0, float('inf'), float('inf'))}) + + async def get_evaluation_results(self): + return None async def evaluate(self): if self._verbose: logging.info("Evaluating using speed-oriented strategy") diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py index 91f726e31..919b9374e 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py @@ -13,7 +13,7 @@ async def update_neighbors(self, node, remove=False): pass @abstractmethod - async def 
evaluate(self): + async def get_evaluation_results(self): pass From 02111ebfe330576a82b31bc7cb5f43b805615755 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 13 Mar 2025 17:37:13 +0100 Subject: [PATCH 133/233] fix additional participants datasets --- nebula/scenarios.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nebula/scenarios.py b/nebula/scenarios.py index 9937f9820..022064c31 100644 --- a/nebula/scenarios.py +++ b/nebula/scenarios.py @@ -579,7 +579,7 @@ def load_configurations_and_start_nodes(self, additional_participants=None, sche if dataset_name == "MNIST": dataset = MNISTDataset( num_classes=10, - partitions_number=self.n_nodes, + partitions_number=self.n_nodes+additional_nodes, iid=self.scenario.iid, partition=self.scenario.partition_selection, partition_parameter=self.scenario.partition_parameter, @@ -589,7 +589,7 @@ def load_configurations_and_start_nodes(self, additional_participants=None, sche elif dataset_name == "FashionMNIST": dataset = FashionMNISTDataset( num_classes=10, - partitions_number=self.n_nodes, + partitions_number=self.n_nodes+additional_nodes, iid=self.scenario.iid, partition=self.scenario.partition_selection, partition_parameter=self.scenario.partition_parameter, @@ -599,7 +599,7 @@ def load_configurations_and_start_nodes(self, additional_participants=None, sche elif dataset_name == "EMNIST": dataset = EMNISTDataset( num_classes=10, - partitions_number=self.n_nodes, + partitions_number=self.n_nodes+additional_nodes, iid=self.scenario.iid, partition=self.scenario.partition_selection, partition_parameter=self.scenario.partition_parameter, @@ -609,7 +609,7 @@ def load_configurations_and_start_nodes(self, additional_participants=None, sche elif dataset_name == "CIFAR10": dataset = CIFAR10Dataset( num_classes=10, - partitions_number=self.n_nodes, + partitions_number=self.n_nodes+additional_nodes, iid=self.scenario.iid, partition=self.scenario.partition_selection, 
partition_parameter=self.scenario.partition_parameter, @@ -619,7 +619,7 @@ def load_configurations_and_start_nodes(self, additional_participants=None, sche elif dataset_name == "CIFAR100": dataset = CIFAR100Dataset( num_classes=100, - partitions_number=self.n_nodes, + partitions_number=self.n_nodes+additional_nodes, iid=self.scenario.iid, partition=self.scenario.partition_selection, partition_parameter=self.scenario.partition_parameter, From 10578b35d511cb2b2bed923ede6d754b85f9fec9 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 14 Mar 2025 13:01:17 +0100 Subject: [PATCH 134/233] opt nebuladataset factory --- nebula/core/datasets/nebuladataset.py | 21 ++++++++- nebula/scenarios.py | 63 +++++---------------------- 2 files changed, 31 insertions(+), 53 deletions(-) diff --git a/nebula/core/datasets/nebuladataset.py b/nebula/core/datasets/nebuladataset.py index be84b4849..abc7e0ead 100755 --- a/nebula/core/datasets/nebuladataset.py +++ b/nebula/core/datasets/nebuladataset.py @@ -349,7 +349,7 @@ def get_local_test_indices_map(self): train_labels = np.array([self.train_set.targets[idx] for idx in self.train_indices_map[participant_id]]) indices = np.where(np.isin(test_targets, train_labels))[0].tolist() local_test_indices_map[participant_id] = indices - logging.info(f"Participant {participant_id} | Local test indices: {indices}") + #logging.info(f"Participant {participant_id} | Local test indices: {indices}") return local_test_indices_map except Exception as e: logging.exception(f"Error in get_local_test_indices_map: {e}") @@ -1010,3 +1010,22 @@ def plot_all_data_distribution(self, phase, dataset, partitions_map): path_to_save = f"{self.config_dir}/all_data_distribution_CIRCLES_{'iid' if self.iid else 'non_iid'}{'_' + self.partition if not self.iid else ''}_{phase}.pdf" plt.savefig(path_to_save, dpi=300, bbox_inches="tight") plt.close() + +def factory_nebuladataset(dataset, **config) -> NebulaDataset: + from nebula.core.datasets.cifar10.cifar10 import 
CIFAR10Dataset + from nebula.core.datasets.cifar100.cifar100 import CIFAR100Dataset + from nebula.core.datasets.emnist.emnist import EMNISTDataset + from nebula.core.datasets.fashionmnist.fashionmnist import FashionMNISTDataset + from nebula.core.datasets.mnist.mnist import MNISTDataset + logging.info(f"cosas: {config}") + options = { + "MNIST": MNISTDataset, + "FashionMNIST": FashionMNISTDataset, + "EMNIST": EMNISTDataset, + "CIFAR10": CIFAR10Dataset, + "CIFAR100": CIFAR100Dataset, + } + + cs = options.get(dataset, None) + if not cs: raise ValueError(f"Dataset {dataset} not supported") + return cs(**config) diff --git a/nebula/scenarios.py b/nebula/scenarios.py index 022064c31..d70530a45 100644 --- a/nebula/scenarios.py +++ b/nebula/scenarios.py @@ -23,6 +23,7 @@ from nebula.core.datasets.mnist.mnist import MNISTDataset from nebula.core.utils.certificate import generate_ca_certificate, generate_certificate from nebula.utils import DockerUtils, FileUtils +from nebula.core.datasets.nebuladataset import factory_nebuladataset # Definition of a scenario @@ -576,58 +577,16 @@ def load_configurations_and_start_nodes(self, additional_participants=None, sche # Splitting dataset dataset_name = self.scenario.dataset dataset = None - if dataset_name == "MNIST": - dataset = MNISTDataset( - num_classes=10, - partitions_number=self.n_nodes+additional_nodes, - iid=self.scenario.iid, - partition=self.scenario.partition_selection, - partition_parameter=self.scenario.partition_parameter, - seed=42, - config_dir=self.config_dir, - ) - elif dataset_name == "FashionMNIST": - dataset = FashionMNISTDataset( - num_classes=10, - partitions_number=self.n_nodes+additional_nodes, - iid=self.scenario.iid, - partition=self.scenario.partition_selection, - partition_parameter=self.scenario.partition_parameter, - seed=42, - config_dir=self.config_dir, - ) - elif dataset_name == "EMNIST": - dataset = EMNISTDataset( - num_classes=10, - partitions_number=self.n_nodes+additional_nodes, - 
iid=self.scenario.iid, - partition=self.scenario.partition_selection, - partition_parameter=self.scenario.partition_parameter, - seed=42, - config_dir=self.config_dir, - ) - elif dataset_name == "CIFAR10": - dataset = CIFAR10Dataset( - num_classes=10, - partitions_number=self.n_nodes+additional_nodes, - iid=self.scenario.iid, - partition=self.scenario.partition_selection, - partition_parameter=self.scenario.partition_parameter, - seed=42, - config_dir=self.config_dir, - ) - elif dataset_name == "CIFAR100": - dataset = CIFAR100Dataset( - num_classes=100, - partitions_number=self.n_nodes+additional_nodes, - iid=self.scenario.iid, - partition=self.scenario.partition_selection, - partition_parameter=self.scenario.partition_parameter, - seed=42, - config_dir=self.config_dir, - ) - else: - raise ValueError(f"Dataset {dataset_name} not supported") + dataset = factory_nebuladataset( + dataset_name, + num_classes=10, + partitions_number=self.n_nodes+additional_nodes, + iid=self.scenario.iid, + partition=self.scenario.partition_selection, + partition_parameter=self.scenario.partition_parameter, + seed=42, + config_dir=self.config_dir, + ) logging.info(f"Splitting {dataset_name} dataset...") dataset.initialize_dataset() From 540771bc591b3d7bf14c217f0a1c8fc99eec1081 Mon Sep 17 00:00:00 2001 From: FerTV Date: Fri, 14 Mar 2025 15:53:04 +0100 Subject: [PATCH 135/233] saving wating scenarios in the database --- nebula/core/datasets/nebuladataset.py | 2 +- nebula/frontend/app.py | 42 ++- nebula/frontend/database.py | 371 ++++++++++++++++++---- nebula/frontend/templates/dashboard.html | 7 +- nebula/frontend/templates/deployment.html | 12 +- nebula/scenarios.py | 12 +- 6 files changed, 360 insertions(+), 86 deletions(-) diff --git a/nebula/core/datasets/nebuladataset.py b/nebula/core/datasets/nebuladataset.py index be84b4849..a16ed8811 100755 --- a/nebula/core/datasets/nebuladataset.py +++ b/nebula/core/datasets/nebuladataset.py @@ -349,7 +349,7 @@ def 
get_local_test_indices_map(self): train_labels = np.array([self.train_set.targets[idx] for idx in self.train_indices_map[participant_id]]) indices = np.where(np.isin(test_targets, train_labels))[0].tolist() local_test_indices_map[participant_id] = indices - logging.info(f"Participant {participant_id} | Local test indices: {indices}") + # logging.info(f"Participant {participant_id} | Local test indices: {indices}") return local_test_indices_map except Exception as e: logging.exception(f"Error in get_local_test_indices_map: {e}") diff --git a/nebula/frontend/app.py b/nebula/frontend/app.py index 8b24df97c..db21f0ad4 100755 --- a/nebula/frontend/app.py +++ b/nebula/frontend/app.py @@ -676,8 +676,8 @@ async def nebula_dashboard_monitor(scenario_name: str, request: Request, session "scenario_status": scenario[5], "nodes_table": list(nodes_table), "scenario_name": scenario[0], - "scenario_title": scenario[3], - "scenario_description": scenario[4], + "title": scenario[3], + "description": scenario[4], }) else: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) @@ -698,8 +698,8 @@ async def nebula_dashboard_monitor(scenario_name: str, request: Request, session "scenario_status": scenario[5], "nodes_table": [], "scenario_name": scenario[0], - "scenario_title": scenario[3], - "scenario_description": scenario[4], + "title": scenario[3], + "description": scenario[4], }) else: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) @@ -1331,19 +1331,13 @@ async def run_scenario(scenario_data, role, user): scenarioManagement = ScenarioManagement(scenario_data, user) scenario_update_record( - scenario_name=scenarioManagement.scenario_name, - username=user, + name=scenarioManagement.scenario_name, start_time=scenarioManagement.start_date_scenario, end_time="", + scenario=scenarioManagement.scenario, status="running", - title=scenario_data["scenario_title"], - description=scenario_data["scenario_description"], - network_subnet=scenario_data["network_subnet"], - 
model=scenario_data["model"], - dataset=scenario_data["dataset"], - rounds=scenario_data["rounds"], role=role, - gpu_id=json.dumps(scenario_data["gpu_id"]), + username=user ) # Run the actual scenario @@ -1370,12 +1364,32 @@ async def run_scenario(scenario_data, role, user): # Deploy the list of scenarios async def run_scenarios(role, user): + from nebula.scenarios import Scenario + try: user_data = user_data_store[user] + scenario_pos = 0 + created_time = datetime.datetime.now().strftime('%d_%m_%Y_%H_%M_%S') + for scenario_data in user_data.scenarios_list: + scenario_data["gpu_id"] = [] + scenario = Scenario.from_dict(scenario_data) + + scenario_update_record( + name=f"nebula_{scenario.federation}_{created_time}_{scenario_pos}", + start_time="", + end_time="", + scenario=scenario, + status="waiting", + role=role, + username=user + ) + + scenario_pos+=1 + for scenario_data in user_data.scenarios_list: user_data.finish_scenario_event.clear() - logging.info(f"Running scenario {scenario_data['scenario_title']}") + logging.info(f"Running scenario {scenario_data['title']}") await run_scenario(scenario_data, role, user) # Waits till the scenario is completed while not user_data.finish_scenario_event.is_set() and not user_data.stop_all_scenarios_event.is_set(): diff --git a/nebula/frontend/database.py b/nebula/frontend/database.py index d43d0149d..aa40ef377 100755 --- a/nebula/frontend/database.py +++ b/nebula/frontend/database.py @@ -1,5 +1,6 @@ import asyncio import datetime +import json import logging import sqlite3 @@ -109,11 +110,48 @@ async def initialize_databases(): end_time TEXT, title TEXT, description TEXT, - status TEXT, - network_subnet TEXT, - model TEXT, + deployment TEXT, + federation TEXT, + topology TEXT, + nodes TEXT, + nodes_graph TEXT, + n_nodes TEXT, + matrix TEXT, + random_topology_probability TEXT, dataset TEXT, + iid TEXT, + partition_selection TEXT, + partition_parameter TEXT, + model TEXT, + agg_algorithm TEXT, rounds TEXT, + logginglevel TEXT, 
+ report_status_data_queue TEXT, + accelerator TEXT, + network_subnet TEXT, + network_gateway TEXT, + epochs TEXT, + attacks TEXT, + poisoned_node_percent TEXT, + poisoned_sample_percent TEXT, + poisoned_noise_percent TEXT, + attack_params TEXT, + with_reputation TEXT, + is_dynamic_topology TEXT, + is_dynamic_aggregation TEXT, + target_aggregation TEXT, + random_geo TEXT, + latitude TEXT, + longitude TEXT, + mobility TEXT, + mobility_type TEXT, + radius_federation TEXT, + scheme_mobility TEXT, + round_frequency TEXT, + mobile_participants_percent TEXT, + additional_participants TEXT, + schema_additional_participants TEXT, + status TEXT, role TEXT, username TEXT, gpu_id TEXT @@ -126,14 +164,51 @@ async def initialize_databases(): "end_time": "TEXT", "title": "TEXT", "description": "TEXT", - "status": "TEXT", - "network_subnet": "TEXT", - "model": "TEXT", + "deployment": "TEXT", + "federation": "TEXT", + "topology": "TEXT", + "nodes": "TEXT", + "nodes_graph": "TEXT", + "n_nodes": "TEXT", + "matrix": "TEXT", + "random_topology_probability": "TEXT", "dataset": "TEXT", + "iid": "TEXT", + "partition_selection": "TEXT", + "partition_parameter": "TEXT", + "model": "TEXT", + "agg_algorithm": "TEXT", "rounds": "TEXT", - "role": "TEXT", - "username": "TEXT", + "logginglevel": "TEXT", + "report_status_data_queue": "TEXT", + "accelerator": "TEXT", "gpu_id": "TEXT", + "network_subnet": "TEXT", + "network_gateway": "TEXT", + "epochs": "TEXT", + "attacks": "TEXT", + "poisoned_node_percent": "TEXT", + "poisoned_sample_percent": "TEXT", + "poisoned_noise_percent": "TEXT", + "attack_params": "TEXT", + "with_reputation": "TEXT", + "is_dynamic_topology": "TEXT", + "is_dynamic_aggregation": "TEXT", + "target_aggregation": "TEXT", + "random_geo": "TEXT", + "latitude": "TEXT", + "longitude": "TEXT", + "mobility": "TEXT", + "mobility_type": "TEXT", + "radius_federation": "TEXT", + "scheme_mobility": "TEXT", + "round_frequency": "TEXT", + "mobile_participants_percent": "TEXT", + 
"additional_participants": "TEXT", + "schema_additional_participants": "TEXT", + "status": "TEXT", + "role": "TEXT", + "username": "TEXT" } await ensure_columns(conn, "scenarios", desired_columns) @@ -410,25 +485,41 @@ def get_all_scenarios_and_check_completed(username, role, sort_by="start_time"): if role == "admin": if sort_by == "start_time": command = """ - SELECT * FROM scenarios - ORDER BY strftime('%Y-%m-%d %H:%M:%S', substr(start_time, 7, 4) || '-' || substr(start_time, 4, 2) || '-' || substr(start_time, 1, 2) || ' ' || substr(start_time, 12, 8)); + SELECT name, username, title, start_time, model, dataset, rounds, status FROM scenarios + ORDER BY + CASE + WHEN start_time IS NULL OR start_time = '' THEN 1 + ELSE 0 + END, + strftime( + '%Y-%m-%d %H:%M:%S', + substr(start_time, 7, 4) || '-' || substr(start_time, 4, 2) || '-' || substr(start_time, 1, 2) || ' ' || substr(start_time, 12, 8) + ); """ _c.execute(command) else: - command = "SELECT * FROM scenarios ORDER BY ?;" + command = "SELECT name, username, title, start_time, model, dataset, rounds, status FROM scenarios ORDER BY ?;" _c.execute(command, (sort_by,)) # _c.execute(command) result = _c.fetchall() else: if sort_by == "start_time": command = """ - SELECT * FROM scenarios + SELECT name, username, title, start_time, model, dataset, rounds, status FROM scenarios WHERE username = ? - ORDER BY strftime('%Y-%m-%d %H:%M:%S', substr(start_time, 7, 4) || '-' || substr(start_time, 4, 2) || '-' || substr(start_time, 1, 2) || ' ' || substr(start_time, 12, 8)); + ORDER BY + CASE + WHEN start_time IS NULL OR start_time = '' THEN 1 + ELSE 0 + END, + strftime( + '%Y-%m-%d %H:%M:%S', + substr(start_time, 7, 4) || '-' || substr(start_time, 4, 2) || '-' || substr(start_time, 1, 2) || ' ' || substr(start_time, 12, 8) + ); """ _c.execute(command, (username,)) else: - command = "SELECT * FROM scenarios WHERE username = ? 
ORDER BY ?;" + command = "SELECT name, username, title, start_time, model, dataset, rounds, status FROM scenarios WHERE username = ? ORDER BY ?;" _c.execute( command, ( @@ -449,67 +540,237 @@ def get_all_scenarios_and_check_completed(username, role, sort_by="start_time"): def scenario_update_record( - scenario_name, - username, + name, start_time, end_time, - title, - description, + scenario, status, - network_subnet, - model, - dataset, - rounds, role, - gpu_id, + username ): _conn = sqlite3.connect(scenario_db_file_location) _c = _conn.cursor() - command = "SELECT * FROM scenarios WHERE name = ?;" - _c.execute(command, (scenario_name,)) + select_command = "SELECT * FROM scenarios WHERE name = ?;" + _c.execute(select_command, (name,)) result = _c.fetchone() if result is None: - # Create a new record - _c.execute( - "INSERT INTO scenarios VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - ( - scenario_name, + insert_command = """ + INSERT INTO scenarios ( + name, start_time, end_time, title, description, - status, - network_subnet, - model, + deployment, + federation, + topology, + nodes, + nodes_graph, + n_nodes, + matrix, + random_topology_probability, dataset, + iid, + partition_selection, + partition_parameter, + model, + agg_algorithm, rounds, - role, - username, + logginglevel, + report_status_data_queue, + accelerator, gpu_id, - ), - ) - else: - # Update the record - command = "UPDATE scenarios SET start_time = ?, end_time = ?, title = ?, description = ?, status = ?, network_subnet = ?, model = ?, dataset = ?, rounds = ?, role = ? 
WHERE name = ?;" - _c.execute( - command, - ( - start_time, - end_time, - title, - description, - status, network_subnet, - model, - dataset, - rounds, + network_gateway, + epochs, + attacks, + poisoned_node_percent, + poisoned_sample_percent, + poisoned_noise_percent, + attack_params, + with_reputation, + is_dynamic_topology, + is_dynamic_aggregation, + target_aggregation, + random_geo, + latitude, + longitude, + mobility, + mobility_type, + radius_federation, + scheme_mobility, + round_frequency, + mobile_participants_percent, + additional_participants, + schema_additional_participants, + status, role, - scenario_name, - ), - ) - + username + ) VALUES ( + ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? + ); + """ + _c.execute(insert_command, ( + name, + start_time, + end_time, + scenario.title, + scenario.description, + scenario.deployment, + scenario.federation, + scenario.topology, + json.dumps(scenario.nodes), + json.dumps(scenario.nodes_graph), + scenario.n_nodes, + json.dumps(scenario.matrix), + scenario.random_topology_probability, + scenario.dataset, + scenario.iid, + scenario.partition_selection, + scenario.partition_parameter, + scenario.model, + scenario.agg_algorithm, + scenario.rounds, + scenario.logginglevel, + scenario.report_status_data_queue, + scenario.accelerator, + json.dumps(scenario.gpu_id), + scenario.network_subnet, + scenario.network_gateway, + scenario.epochs, + scenario.attacks, + scenario.poisoned_node_percent, + scenario.poisoned_sample_percent, + scenario.poisoned_noise_percent, + json.dumps(scenario.attack_params), + scenario.with_reputation, + scenario.is_dynamic_topology, + scenario.is_dynamic_aggregation, + scenario.target_aggregation, + scenario.random_geo, + scenario.latitude, + scenario.longitude, + scenario.mobility, + scenario.mobility_type, + scenario.radius_federation, + scenario.scheme_mobility, + scenario.round_frequency, 
+ scenario.mobile_participants_percent, + json.dumps(scenario.additional_participants), + scenario.schema_additional_participants, + status, + role, + username + )) + else: + update_command = """ + UPDATE scenarios SET + start_time = ?, + end_time = ?, + title = ?, + description = ?, + deployment = ?, + federation = ?, + topology = ?, + nodes = ?, + nodes_graph = ?, + n_nodes = ?, + matrix = ?, + random_topology_probability = ?, + dataset = ?, + iid = ?, + partition_selection = ?, + partition_parameter = ?, + model = ?, + agg_algorithm = ?, + rounds = ?, + logginglevel = ?, + report_status_data_queue = ?, + accelerator = ?, + gpu_id = ?, + network_subnet = ?, + network_gateway = ?, + epochs = ?, + attacks = ?, + poisoned_node_percent = ?, + poisoned_sample_percent = ?, + poisoned_noise_percent = ?, + attack_params = ?, + with_reputation = ?, + is_dynamic_topology = ?, + is_dynamic_aggregation = ?, + target_aggregation = ?, + random_geo = ?, + latitude = ?, + longitude = ?, + mobility = ?, + mobility_type = ?, + radius_federation = ?, + scheme_mobility = ?, + round_frequency = ?, + mobile_participants_percent = ?, + additional_participants = ?, + schema_additional_participants = ?, + status = ?, + role = ?, + username = ? 
+ WHERE name = ?; + """ + _c.execute(update_command, ( + start_time, + end_time, + scenario.title, + scenario.description, + scenario.deployment, + scenario.federation, + scenario.topology, + json.dumps(scenario.nodes), + json.dumps(scenario.nodes_graph), + scenario.n_nodes, + json.dumps(scenario.matrix), + scenario.random_topology_probability, + scenario.dataset, + scenario.iid, + scenario.partition_selection, + scenario.partition_parameter, + scenario.model, + scenario.agg_algorithm, + scenario.rounds, + scenario.logginglevel, + scenario.report_status_data_queue, + scenario.accelerator, + json.dumps(scenario.gpu_id), + scenario.network_subnet, + scenario.network_gateway, + scenario.epochs, + scenario.attacks, + scenario.poisoned_node_percent, + scenario.poisoned_sample_percent, + scenario.poisoned_noise_percent, + json.dumps(scenario.attack_params), + scenario.with_reputation, + scenario.is_dynamic_topology, + scenario.is_dynamic_aggregation, + scenario.target_aggregation, + scenario.random_geo, + scenario.latitude, + scenario.longitude, + scenario.mobility, + scenario.mobility_type, + scenario.radius_federation, + scenario.scheme_mobility, + scenario.round_frequency, + scenario.mobile_participants_percent, + json.dumps(scenario.additional_participants), + scenario.schema_additional_participants, + status, + role, + username, + name + )) + _conn.commit() _conn.close() diff --git a/nebula/frontend/templates/dashboard.html b/nebula/frontend/templates/dashboard.html index 0f1eb0230..a7e73d211 100755 --- a/nebula/frontend/templates/dashboard.html +++ b/nebula/frontend/templates/dashboard.html @@ -117,8 +117,7 @@

Scenarios in the database

Action - {% for name, start_time, end_time, title, description, status, network_subnet, model, dataset, - rounds, role, username, gpu_id in scenarios %} + {% for name, username, title, start_time, model, dataset, rounds, status in scenarios %} {% if user_role == "admin" %} {{ username|lower }} @@ -130,8 +129,8 @@

Scenarios in the database

{{ rounds }} {% if status == "running" %} Running - {% elif status == "completed" %} - Completed + {% elif status == "waiting" %} + Waiting {% else %} Finished {% endif %} diff --git a/nebula/frontend/templates/deployment.html b/nebula/frontend/templates/deployment.html index a36dfefe8..115b411e7 100755 --- a/nebula/frontend/templates/deployment.html +++ b/nebula/frontend/templates/deployment.html @@ -1088,8 +1088,8 @@
Schema of deployment
var data = {} // Step 1 - data["scenario_title"] = document.getElementById("scenario-title").value - data["scenario_description"] = document.getElementById("scenario-description").value + data["title"] = document.getElementById("scenario-title").value + data["description"] = document.getElementById("scenario-description").value // Step 2 if (document.getElementById("process-radio").checked) { data["deployment"] = "process" @@ -1177,8 +1177,8 @@
Schema of deployment
document.getElementById("mode-btn").click(); } // Step 1 - document.getElementById("scenario-title").value = data["scenario_title"]; - document.getElementById("scenario-description").value = data["scenario_description"]; + document.getElementById("scenario-title").value = data["title"]; + document.getElementById("scenario-description").value = data["description"]; // Step 2 // Read deployment from data and set the specific radio button if (data["deployment"] == "process") { @@ -2188,8 +2188,8 @@
Schema of deployment
scenarioStorage.replaceScenario(); scenarioStorage.scenariosList.forEach((scenario, index) => { - if (!scenario.scenario_title){ - scenarioStorage.scenariosList[index].scenario_title = "empty"; + if (!scenario.title){ + scenarioStorage.scenariosList[index].title = "empty"; } }); var data = scenarioStorage.scenariosList; diff --git a/nebula/scenarios.py b/nebula/scenarios.py index 1718c2c80..36ddeab63 100644 --- a/nebula/scenarios.py +++ b/nebula/scenarios.py @@ -29,8 +29,8 @@ class Scenario: def __init__( self, - scenario_title, - scenario_description, + title, + description, deployment, federation, topology, @@ -78,8 +78,8 @@ def __init__( Initialize the scenario. Args: - scenario_title (str): Title of the scenario. - scenario_description (str): Description of the scenario. + title (str): Title of the scenario. + description (str): Description of the scenario. deployment (str): Type of deployment (e.g., 'docker', 'process'). federation (str): Type of federation. topology (str): Network topology. @@ -125,8 +125,8 @@ def __init__( schema_additional_participants (str): Schema for additional participants. random_topology_probability (float): Probability for random topology. 
""" - self.scenario_title = scenario_title - self.scenario_description = scenario_description + self.title = title + self.description = description self.deployment = deployment self.federation = federation self.topology = topology From 3d92944d74efd6a1323a8fc88916c9c4cd6ce869 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 14 Mar 2025 16:47:43 +0100 Subject: [PATCH 136/233] feature hybrid datasets --- nebula/core/datasets/nebuladataset.py | 45 +++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/nebula/core/datasets/nebuladataset.py b/nebula/core/datasets/nebuladataset.py index abc7e0ead..3e96bca5c 100755 --- a/nebula/core/datasets/nebuladataset.py +++ b/nebula/core/datasets/nebuladataset.py @@ -9,6 +9,7 @@ import seaborn as sns from sklearn.manifold import TSNE from torch.utils.data import Dataset +from sklearn.model_selection import train_test_split matplotlib.use("Agg") plt.switch_backend("Agg") @@ -243,7 +244,7 @@ def __init__( partitions_number=1, batch_size=32, num_workers=4, - iid=True, + iid="IID", partition="dirichlet", partition_parameter=0.5, seed=42, @@ -303,11 +304,43 @@ def data_partitioning(self, plot=False): f"Partitioning data for {self.__class__.__name__} | Partitions: {self.partitions_number} | IID: {self.iid} | Partition: {self.partition} | Partition parameter: {self.partition_parameter}" ) - self.train_indices_map = ( - self.generate_iid_map(self.train_set) - if self.iid - else self.generate_non_iid_map(self.train_set, self.partition, self.partition_parameter) - ) + size_s1 = 0.5 + index = 0 + targets = set() + data = [] + sample, target = self.train_set.__getitem__(index) + while sample != None and target != None: + data.append(sample) + targets.add(target) + index += 1 + try: + sample, target = self.train_set.__getitem__(index) + except Exception: + pass + data = np.array(data) + logging.info(f"longitud del dataset: {len(data)}") + + + # DivisiΓ³n estratificada en dos subconjuntos de tamaΓ±o 
variable + X_s1, X_s2, y_s1, y_s2 = train_test_split(data, targets, test_size=(1 - size_s1), stratify=targets, random_state=42) + # Ver la distribuciΓ³n de clases en cada subconjunto + logging.info(f"S1 - {np.bincount(y_s1)} | S2 - {np.bincount(y_s2)}") + + self.iid = "IID" + if self.iid == "IID": + self.train_indices_map = self.generate_iid_map(self.train_set) + elif self.iid == "Non-IID": + self.train_indices_map = self.generate_non_iid_map(self.train_set, self.partition, self.partition_parameter) + else: + pass + + self.iid = True + + # self.train_indices_map = ( + # self.generate_iid_map(self.train_set) + # if self.iid + # else self.generate_non_iid_map(self.train_set, self.partition, self.partition_parameter) + # ) self.test_indices_map = self.get_test_indices_map() self.local_test_indices_map = self.get_local_test_indices_map() From 448b5ff409fc19eec808dced2230754cba8933e3 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 14 Mar 2025 18:02:28 +0100 Subject: [PATCH 137/233] feat datasets 'n' splitted --- nebula/core/datasets/nebuladataset.py | 55 +++++++++++++++------------ 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/nebula/core/datasets/nebuladataset.py b/nebula/core/datasets/nebuladataset.py index 3e96bca5c..b378c99b0 100755 --- a/nebula/core/datasets/nebuladataset.py +++ b/nebula/core/datasets/nebuladataset.py @@ -247,6 +247,8 @@ def __init__( iid="IID", partition="dirichlet", partition_parameter=0.5, + nsplits_percentages = [], + nsplits_iid = [], seed=42, config_dir=None, ): @@ -259,6 +261,8 @@ def __init__( self.partition_parameter = partition_parameter self.seed = seed self.config_dir = config_dir + self._nsplits_percentages = nsplits_percentages + self._nsplits_iid = nsplits_iid logging.info( f"Dataset {self.__class__.__name__} initialized | Partitions: {self.partitions_number} | IID: {self.iid} | Partition: {self.partition} | Partition parameter: {self.partition_parameter}" @@ -304,36 +308,38 @@ def data_partitioning(self, 
plot=False): f"Partitioning data for {self.__class__.__name__} | Partitions: {self.partitions_number} | IID: {self.iid} | Partition: {self.partition} | Partition parameter: {self.partition_parameter}" ) - size_s1 = 0.5 - index = 0 - targets = set() - data = [] - sample, target = self.train_set.__getitem__(index) - while sample != None and target != None: - data.append(sample) - targets.add(target) - index += 1 - try: - sample, target = self.train_set.__getitem__(index) - except Exception: - pass - data = np.array(data) - logging.info(f"longitud del dataset: {len(data)}") - - - # DivisiΓ³n estratificada en dos subconjuntos de tamaΓ±o variable - X_s1, X_s2, y_s1, y_s2 = train_test_split(data, targets, test_size=(1 - size_s1), stratify=targets, random_state=42) - # Ver la distribuciΓ³n de clases en cada subconjunto - logging.info(f"S1 - {np.bincount(y_s1)} | S2 - {np.bincount(y_s2)}") - self.iid = "IID" if self.iid == "IID": self.train_indices_map = self.generate_iid_map(self.train_set) elif self.iid == "Non-IID": self.train_indices_map = self.generate_non_iid_map(self.train_set, self.partition, self.partition_parameter) else: - pass - + index = 0 + data = [] + targets = [] + sample, target = self.train_set.__getitem__(index) + while sample != None and target != None: + data.append(sample) + targets.append(target) + index += 1 + try: + sample, target = self.train_set.__getitem__(index) + except Exception: + break + data = np.array(data) + targets = np.array(targets) + logging.info(f"number of samples on dataset: {len(data)}, targets: {targets}") + + subsets = [] + subset_to_split = data + for i in range(0, len(self._nsplits_percentages)-1): + #TODO cuadrar los porcentajes sucesivos + size_s = self._nsplits_percentages[i] + X_s1, X_s2, y_s1, y_s2 = train_test_split(subset_to_split, targets, test_size=(1 - size_s), stratify=targets, random_state=42) + logging.info(f"S1 - {np.bincount(y_s1)} | S2 - {np.bincount(y_s2)}") + subsets.append[y_s1] + subset_to_split = y_s2 + 
self.iid = True # self.train_indices_map = ( @@ -1050,7 +1056,6 @@ def factory_nebuladataset(dataset, **config) -> NebulaDataset: from nebula.core.datasets.emnist.emnist import EMNISTDataset from nebula.core.datasets.fashionmnist.fashionmnist import FashionMNISTDataset from nebula.core.datasets.mnist.mnist import MNISTDataset - logging.info(f"cosas: {config}") options = { "MNIST": MNISTDataset, "FashionMNIST": FashionMNISTDataset, From 88b281ef6b0930503fc88556f6a1688146d383ad Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 17 Mar 2025 11:18:23 +0100 Subject: [PATCH 138/233] feature split dataset IID subsets --- nebula/core/datasets/mnist/mnist.py | 4 +- nebula/core/datasets/nebuladataset.py | 122 ++++++++++++++++++-------- 2 files changed, 85 insertions(+), 41 deletions(-) diff --git a/nebula/core/datasets/mnist/mnist.py b/nebula/core/datasets/mnist/mnist.py index 67449bfdb..96cfb3439 100755 --- a/nebula/core/datasets/mnist/mnist.py +++ b/nebula/core/datasets/mnist/mnist.py @@ -83,9 +83,9 @@ def generate_non_iid_map(self, dataset, partition="dirichlet", partition_paramet return partitions_map - def generate_iid_map(self, dataset, partition="balancediid", partition_parameter=2): + def generate_iid_map(self, dataset, partition="balancediid", partition_parameter=2, num_clients=None): if partition == "balancediid": - partitions_map = self.balanced_iid_partition(dataset) + partitions_map = self.balanced_iid_partition(dataset, num_clients) elif partition == "unbalancediid": partitions_map = self.unbalanced_iid_partition(dataset, imbalance_factor=partition_parameter) else: diff --git a/nebula/core/datasets/nebuladataset.py b/nebula/core/datasets/nebuladataset.py index b378c99b0..44993627d 100755 --- a/nebula/core/datasets/nebuladataset.py +++ b/nebula/core/datasets/nebuladataset.py @@ -10,6 +10,7 @@ from sklearn.manifold import TSNE from torch.utils.data import Dataset from sklearn.model_selection import train_test_split +from types import SimpleNamespace 
matplotlib.use("Agg") plt.switch_backend("Agg") @@ -247,8 +248,8 @@ def __init__( iid="IID", partition="dirichlet", partition_parameter=0.5, - nsplits_percentages = [], - nsplits_iid = [], + nsplits_percentages = [0.50, 0.25, 0.25], + nsplits_iid = ["IID", "IID", "IID"], seed=42, config_dir=None, ): @@ -308,45 +309,15 @@ def data_partitioning(self, plot=False): f"Partitioning data for {self.__class__.__name__} | Partitions: {self.partitions_number} | IID: {self.iid} | Partition: {self.partition} | Partition parameter: {self.partition_parameter}" ) - self.iid = "IID" + self.iid = "hybrid" #TODO REMOVE if self.iid == "IID": self.train_indices_map = self.generate_iid_map(self.train_set) elif self.iid == "Non-IID": self.train_indices_map = self.generate_non_iid_map(self.train_set, self.partition, self.partition_parameter) else: - index = 0 - data = [] - targets = [] - sample, target = self.train_set.__getitem__(index) - while sample != None and target != None: - data.append(sample) - targets.append(target) - index += 1 - try: - sample, target = self.train_set.__getitem__(index) - except Exception: - break - data = np.array(data) - targets = np.array(targets) - logging.info(f"number of samples on dataset: {len(data)}, targets: {targets}") - - subsets = [] - subset_to_split = data - for i in range(0, len(self._nsplits_percentages)-1): - #TODO cuadrar los porcentajes sucesivos - size_s = self._nsplits_percentages[i] - X_s1, X_s2, y_s1, y_s2 = train_test_split(subset_to_split, targets, test_size=(1 - size_s), stratify=targets, random_state=42) - logging.info(f"S1 - {np.bincount(y_s1)} | S2 - {np.bincount(y_s2)}") - subsets.append[y_s1] - subset_to_split = y_s2 - - self.iid = True - - # self.train_indices_map = ( - # self.generate_iid_map(self.train_set) - # if self.iid - # else self.generate_non_iid_map(self.train_set, self.partition, self.partition_parameter) - # ) + self.train_indices_map = self.generate_hybrid_map() + + self.iid = True #TODO REMOVE self.test_indices_map 
= self.get_test_indices_map() self.local_test_indices_map = self.get_local_test_indices_map() @@ -453,12 +424,85 @@ def generate_non_iid_map(self, dataset, partition="dirichlet", plot=False): pass @abstractmethod - def generate_iid_map(self, dataset, plot=False): + def generate_iid_map(self, dataset, plot=False, num_clients=None): """ Create an iid map of the dataset. """ pass + def generate_hybrid_map(self): + index = 0 + data = [] + targets = [] + sample, target = self.train_set.__getitem__(index) + while sample != None and target != None: + data.append(sample) + targets.append(target) + index += 1 + try: + sample, target = self.train_set.__getitem__(index) + except Exception: + break + data = np.array(data) + targets = np.array(targets) + logging.info(f"number of samples on dataset: {len(data)}, targets: {targets}") + + remaining_size = 1.0 + subsets = [] + subset_to_split, targets_to_split = data, targets + + participants = [i for i in range(self.partitions_number)] + num_participants = len(participants) + grouped_participants = [] + start_idx = 0 + + for i, size in enumerate(self._nsplits_percentages[:-1]): # Last one doesnt required split + relative_size = size / remaining_size # Relative size to remaining dataset + logging.info(f"size: {size}, relative size: {relative_size}, remaining size: {remaining_size}") + X_s1, X_s2, y_s1, y_s2 = train_test_split( + subset_to_split, targets_to_split, + test_size=(1 - relative_size), + stratify=targets_to_split, + random_state=42 + ) + + logging.info(f"Subset {i+1}: {len(X_s1)} samples") + subsets.append((X_s1, y_s1)) # Saving subsets + + num_in_group = round(size * num_participants) + grouped_participants.append(participants[start_idx:start_idx + num_in_group]) + + # Update to next iteration + subset_to_split, targets_to_split = X_s2, y_s2 + remaining_size -= size + + # Saving last subset + subsets.append((subset_to_split, targets_to_split)) + grouped_participants.append(participants[start_idx:]) + + 
logging.info(f"Subset {len(subsets)}: {len(subset_to_split)} samples") + + for i, (_, y_subset) in enumerate(subsets): + logging.info(f"Subset {i+1} - {np.bincount(y_subset)}") + + + general_map = {} + for i, subset in enumerate(subsets): + data_mapped = dict() + dataset_wrapped = SimpleNamespace(data=subset[0], targets=subset[1]) + + if self._nsplits_iid[i] == "IID": + subset_map = self.generate_iid_map(dataset_wrapped, num_clients=len(grouped_participants[i])) + for j, real_id in enumerate(grouped_participants[i]): # Mapping subset map generated to real clients IDs + data_mapped[real_id] = subset_map[j] + + else: + #TODO nonIID case + pass + + general_map.update(data_mapped) + return general_map + def plot_data_distribution(self, phase, dataset, partitions_map): """ Plot the data distribution of the dataset. @@ -751,7 +795,7 @@ def homo_partition(self, dataset): return net_dataidx_map - def balanced_iid_partition(self, dataset): + def balanced_iid_partition(self, dataset, num_clients=None): """ Partition the dataset into balanced and IID (Independent and Identically Distributed) subsets for each client. @@ -777,7 +821,7 @@ def balanced_iid_partition(self, dataset): federated_data = balanced_iid_partition(my_dataset) # This creates federated data subsets with equal class distributions. 
""" - num_clients = self.partitions_number + num_clients = self.partitions_number if not num_clients else num_clients clients_data = {i: [] for i in range(num_clients)} # Get the labels from the dataset From 907a81a678e81f9d453c7ebd6f643c2f8da0378a Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 17 Mar 2025 15:11:50 +0100 Subject: [PATCH 139/233] wip hybrid datasets --- nebula/core/datasets/mnist/mnist.py | 8 ++-- nebula/core/datasets/nebuladataset.py | 53 +++++++++++++++++---------- 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/nebula/core/datasets/mnist/mnist.py b/nebula/core/datasets/mnist/mnist.py index 96cfb3439..f2e68f552 100755 --- a/nebula/core/datasets/mnist/mnist.py +++ b/nebula/core/datasets/mnist/mnist.py @@ -73,9 +73,9 @@ def load_mnist_dataset(self, train=True): download=True, ) - def generate_non_iid_map(self, dataset, partition="dirichlet", partition_parameter=0.5): + def generate_non_iid_map(self, dataset, partition="dirichlet", partition_parameter=0.2, num_clients=None): if partition == "dirichlet": - partitions_map = self.dirichlet_partition(dataset, alpha=partition_parameter) + partitions_map = self.dirichlet_partition(dataset, alpha=partition_parameter, num_clients=num_clients) elif partition == "percent": partitions_map = self.percentage_partition(dataset, percentage=partition_parameter) else: @@ -85,9 +85,9 @@ def generate_non_iid_map(self, dataset, partition="dirichlet", partition_paramet def generate_iid_map(self, dataset, partition="balancediid", partition_parameter=2, num_clients=None): if partition == "balancediid": - partitions_map = self.balanced_iid_partition(dataset, num_clients) + partitions_map = self.balanced_iid_partition(dataset, num_clients=num_clients) elif partition == "unbalancediid": - partitions_map = self.unbalanced_iid_partition(dataset, imbalance_factor=partition_parameter) + partitions_map = self.unbalanced_iid_partition(dataset, imbalance_factor=partition_parameter, num_clients=num_clients) 
else: raise ValueError(f"Partition {partition} is not supported for IID map") diff --git a/nebula/core/datasets/nebuladataset.py b/nebula/core/datasets/nebuladataset.py index 44993627d..f8264a940 100755 --- a/nebula/core/datasets/nebuladataset.py +++ b/nebula/core/datasets/nebuladataset.py @@ -11,6 +11,7 @@ from torch.utils.data import Dataset from sklearn.model_selection import train_test_split from types import SimpleNamespace +import copy matplotlib.use("Agg") plt.switch_backend("Agg") @@ -248,8 +249,8 @@ def __init__( iid="IID", partition="dirichlet", partition_parameter=0.5, - nsplits_percentages = [0.50, 0.25, 0.25], - nsplits_iid = ["IID", "IID", "IID"], + nsplits_percentages = [0.50, 0.50], + nsplits_iid = ["IID", "Non-IID"], seed=42, config_dir=None, ): @@ -309,7 +310,7 @@ def data_partitioning(self, plot=False): f"Partitioning data for {self.__class__.__name__} | Partitions: {self.partitions_number} | IID: {self.iid} | Partition: {self.partition} | Partition parameter: {self.partition_parameter}" ) - self.iid = "hybrid" #TODO REMOVE + self.iid = "a" #TODO REMOVE if self.iid == "IID": self.train_indices_map = self.generate_iid_map(self.train_set) elif self.iid == "Non-IID": @@ -317,7 +318,7 @@ def data_partitioning(self, plot=False): else: self.train_indices_map = self.generate_hybrid_map() - self.iid = True #TODO REMOVE + self.iid = False #TODO REMOVE self.test_indices_map = self.get_test_indices_map() self.local_test_indices_map = self.get_local_test_indices_map() @@ -475,6 +476,7 @@ def generate_hybrid_map(self): # Update to next iteration subset_to_split, targets_to_split = X_s2, y_s2 remaining_size -= size + start_idx += num_in_group # Saving last subset subsets.append((subset_to_split, targets_to_split)) @@ -489,16 +491,20 @@ def generate_hybrid_map(self): general_map = {} for i, subset in enumerate(subsets): data_mapped = dict() - dataset_wrapped = SimpleNamespace(data=subset[0], targets=subset[1]) + subset_copy = copy.deepcopy(subset) + 
dataset_wrapped = SimpleNamespace(data=subset_copy[0], targets=subset_copy[1]) if self._nsplits_iid[i] == "IID": + logging.info(f"Generating dataset subset IID for participants: {grouped_participants[i]}, num_clients:{len(grouped_participants[i])}") subset_map = self.generate_iid_map(dataset_wrapped, num_clients=len(grouped_participants[i])) for j, real_id in enumerate(grouped_participants[i]): # Mapping subset map generated to real clients IDs data_mapped[real_id] = subset_map[j] else: - #TODO nonIID case - pass + logging.info(f"Generating dataset subset Non-IID for participants: {grouped_participants[i]}, num_clients:{len(grouped_participants[i])}") + subset_map = self.generate_non_iid_map(dataset_wrapped, num_clients=len(grouped_participants[i])) + for j, real_id in enumerate(grouped_participants[i]): # Mapping subset map generated to real clients IDs + data_mapped[real_id] = subset_map[j] general_map.update(data_mapped) return general_map @@ -589,7 +595,8 @@ def visualize_tsne(self, dataset): def dirichlet_partition( self, dataset: Any, - alpha: float = 0.5, + alpha: float = 0.2, + num_clients=None, min_samples_size: int = 50, balanced: bool = False, max_iter: int = 100, @@ -620,9 +627,13 @@ def dirichlet_partition( partitions : dict[int, list[int]] Dictionary mapping each client index to a list of sample indices. """ + num_clients = self.partitions_number if not num_clients else num_clients + logging.info(f"Generating Dirichlet Partitioning, alpha: {alpha}, num_clients: {num_clients}") + # Extract targets and unique labels. y_data = self._get_targets(dataset) unique_labels = np.unique(y_data) + logging.info(f"Unique labels in dataset: {unique_labels}") # For each class, get a shuffled list of indices. class_indices = {} @@ -633,28 +644,32 @@ def dirichlet_partition( class_indices[label] = idx # Prepare container for client indices. 
- indices_per_partition = [[] for _ in range(self.partitions_number)] + indices_per_partition = [[] for _ in range(num_clients)] - def allocate_for_label(label_idx: np.ndarray, rng: np.random.Generator) -> np.ndarray: + def allocate_for_label(label_idx: np.ndarray, rng: np.random.Generator, n_clients) -> np.ndarray: num_label_samples = len(label_idx) + logging.info(f"number of samples allocating {num_label_samples}") if balanced: - proportions = np.full(self.partitions_number, 1.0 / self.partitions_number) + proportions = np.full(n_clients, 1.0 / n_clients) else: - proportions = rng.dirichlet([alpha] * self.partitions_number) + proportions = rng.dirichlet([alpha] * n_clients) + logging.info(f"Dirichlet proportions: {proportions}") sample_counts = (proportions * num_label_samples).astype(int) remainder = num_label_samples - sample_counts.sum() if remainder > 0: - extra_indices = rng.choice(self.partitions_number, size=remainder, replace=False) + extra_indices = rng.choice(n_clients, size=remainder, replace=False) for idx in extra_indices: sample_counts[idx] += 1 + logging.info(f"Samples allocated per client: {sample_counts}") return sample_counts for iteration in range(1, max_iter + 1): rng = np.random.default_rng(self.seed + iteration) - temp_indices_per_partition = [[] for _ in range(self.partitions_number)] + temp_indices_per_partition = [[] for _ in range(num_clients)] for label in unique_labels: label_idx = class_indices[label] - counts = allocate_for_label(label_idx, rng) + logging.info(f"Calculating samples distribution for label: {label}") + counts = allocate_for_label(label_idx, rng, num_clients) start = 0 for client_idx, count in enumerate(counts): end = start + count @@ -665,10 +680,10 @@ def allocate_for_label(label_idx: np.ndarray, rng: np.random.Generator) -> np.nd if min(client_sizes) >= min_samples_size: indices_per_partition = temp_indices_per_partition if verbose: - print(f"Partition successful at iteration {iteration}. 
Client sizes: {client_sizes}") + logging.info(f"Partition successful at iteration {iteration}. Client sizes: {client_sizes}") break if verbose: - print(f"Iteration {iteration}: client sizes {client_sizes}") + logging.info(f"Iteration {iteration}: client sizes {client_sizes}") else: raise ValueError( @@ -852,7 +867,7 @@ def balanced_iid_partition(self, dataset, num_clients=None): return clients_data - def unbalanced_iid_partition(self, dataset, imbalance_factor=2): + def unbalanced_iid_partition(self, dataset, imbalance_factor=2, num_clients=None): """ Partition the dataset into multiple IID (Independent and Identically Distributed) subsets with different size. @@ -883,7 +898,7 @@ def unbalanced_iid_partition(self, dataset, imbalance_factor=2): # This creates federated data subsets with varying number of samples based on # an imbalance factor of 2. """ - num_clients = self.partitions_number + num_clients = self.partitions_number if not num_clients else num_clients clients_data = {i: [] for i in range(num_clients)} # Get the labels from the dataset From 9031ebdb506ce7bd4c7824d21fcc195345599c44 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 18 Mar 2025 19:01:49 +0100 Subject: [PATCH 140/233] fix dirichlet subset generated --- nebula/core/datasets/nebuladataset.py | 100 +++++++++++++++++--------- 1 file changed, 68 insertions(+), 32 deletions(-) diff --git a/nebula/core/datasets/nebuladataset.py b/nebula/core/datasets/nebuladataset.py index f8264a940..91b097b5e 100755 --- a/nebula/core/datasets/nebuladataset.py +++ b/nebula/core/datasets/nebuladataset.py @@ -249,8 +249,8 @@ def __init__( iid="IID", partition="dirichlet", partition_parameter=0.5, - nsplits_percentages = [0.50, 0.50], - nsplits_iid = ["IID", "Non-IID"], + nsplits_percentages = [0.5, 0.5], + nsplits_iid = ["Non-IID", "Non-IID"], seed=42, config_dir=None, ): @@ -265,6 +265,7 @@ def __init__( self.config_dir = config_dir self._nsplits_percentages = nsplits_percentages self._nsplits_iid = 
nsplits_iid + self._targets_reales = None logging.info( f"Dataset {self.__class__.__name__} initialized | Partitions: {self.partitions_number} | IID: {self.iid} | Partition: {self.partition} | Partition parameter: {self.partition_parameter}" @@ -446,53 +447,70 @@ def generate_hybrid_map(self): break data = np.array(data) targets = np.array(targets) + self._targets_reales = targets.copy() #TODO remove logging.info(f"number of samples on dataset: {len(data)}, targets: {targets}") remaining_size = 1.0 subsets = [] - subset_to_split, targets_to_split = data, targets - + subset_to_split, targets_to_split = copy.deepcopy(data), copy.deepcopy(targets) + participants = [i for i in range(self.partitions_number)] num_participants = len(participants) grouped_participants = [] start_idx = 0 + + or_indices = np.arange(len(data)) - for i, size in enumerate(self._nsplits_percentages[:-1]): # Last one doesnt required split - relative_size = size / remaining_size # Relative size to remaining dataset + # Inicializar las estructuras que se dividirΓ‘n en cada iteraciΓ³n + subset_to_split, targets_to_split, indices_to_split = copy.deepcopy(data), copy.deepcopy(targets), copy.deepcopy(or_indices) + + for i, size in enumerate(self._nsplits_percentages[:-1]): # Last one doesn't require split + relative_size = size / remaining_size # TamaΓ±o relativo respecto al conjunto restante logging.info(f"size: {size}, relative size: {relative_size}, remaining size: {remaining_size}") - X_s1, X_s2, y_s1, y_s2 = train_test_split( - subset_to_split, targets_to_split, + + # Dividir manteniendo referencias originales + x_s1, x_s2, y_s1, y_s2, idx_s1, idx_s2 = train_test_split( + subset_to_split, targets_to_split, indices_to_split, test_size=(1 - relative_size), stratify=targets_to_split, random_state=42 ) - - logging.info(f"Subset {i+1}: {len(X_s1)} samples") - subsets.append((X_s1, y_s1)) # Saving subsets - + + # Guardar los datos y etiquetas originales asociados a los Γ­ndices seleccionados + 
original_X_s1, original_y_s1 = data[idx_s1], targets[idx_s1] + + logging.info(f"Subset {i+1}: {len(original_X_s1)} samples") + + # Guardar subset con referencia a los datos originales + subsets.append((original_X_s1, original_y_s1, idx_s1)) + num_in_group = round(size * num_participants) grouped_participants.append(participants[start_idx:start_idx + num_in_group]) - # Update to next iteration - subset_to_split, targets_to_split = X_s2, y_s2 + # Actualizar para la siguiente iteraciΓ³n + subset_to_split, targets_to_split, indices_to_split = data[idx_s2], targets[idx_s2], idx_s2 remaining_size -= size start_idx += num_in_group - # Saving last subset - subsets.append((subset_to_split, targets_to_split)) + # Guardar el ΓΊltimo subset con sus Γ­ndices originales + original_X_s2, original_y_s2 = data[indices_to_split], targets[indices_to_split] + subsets.append((original_X_s2, original_y_s2, indices_to_split)) grouped_participants.append(participants[start_idx:]) - logging.info(f"Subset {len(subsets)}: {len(subset_to_split)} samples") - - for i, (_, y_subset) in enumerate(subsets): - logging.info(f"Subset {i+1} - {np.bincount(y_subset)}") - - + for i, (_, ysubset, _) in enumerate(subsets): + logging.info(f"Subset {i+1} - {np.bincount(ysubset)}") + general_map = {} for i, subset in enumerate(subsets): data_mapped = dict() - subset_copy = copy.deepcopy(subset) - dataset_wrapped = SimpleNamespace(data=subset_copy[0], targets=subset_copy[1]) + real_indexes = subset[2] + subset_real_data = data[real_indexes] + subset_real_targets = targets[real_indexes] + + logging.info(f"comprobacion de subset {np.array_equal(subset[0], subset_real_data)}") + logging.info(f"comprobacion de subset {np.array_equal(subset[1], subset_real_targets)}") + + dataset_wrapped = SimpleNamespace(data=subset_real_data, targets=subset_real_targets, real_indexes=real_indexes) if self._nsplits_iid[i] == "IID": logging.info(f"Generating dataset subset IID for participants: {grouped_participants[i]}, 
num_clients:{len(grouped_participants[i])}") @@ -502,11 +520,15 @@ def generate_hybrid_map(self): else: logging.info(f"Generating dataset subset Non-IID for participants: {grouped_participants[i]}, num_clients:{len(grouped_participants[i])}") - subset_map = self.generate_non_iid_map(dataset_wrapped, num_clients=len(grouped_participants[i])) + subset_map = self.generate_non_iid_map(dataset_wrapped, num_clients=len(grouped_participants[i])) + for id, indxs in subset_map.items(): + logging.info(f"Prev | Participant id: {id}, num samples: {len(indxs)}, targets: {np.bincount(targets[indxs])}") for j, real_id in enumerate(grouped_participants[i]): # Mapping subset map generated to real clients IDs data_mapped[real_id] = subset_map[j] - + general_map.update(data_mapped) + for id, indexes in general_map.items(): + logging.info(f"After | Participant id: {id}, num samples: {len(indexes)}, targets: {np.bincount(targets[indexes])}") return general_map def plot_data_distribution(self, phase, dataset, partitions_map): @@ -631,15 +653,28 @@ def dirichlet_partition( logging.info(f"Generating Dirichlet Partitioning, alpha: {alpha}, num_clients: {num_clients}") # Extract targets and unique labels. - y_data = self._get_targets(dataset) - unique_labels = np.unique(y_data) + if not num_clients: + y_data = self._get_targets(dataset) + unique_labels = np.unique(y_data) + else: + logging.info("Extracting dataset partition targets...") + y_data = dataset.targets + logging.info(f"{y_data}") + unique_labels = np.unique(y_data) logging.info(f"Unique labels in dataset: {unique_labels}") # For each class, get a shuffled list of indices. 
class_indices = {} base_rng = np.random.default_rng(self.seed) for label in unique_labels: - idx = np.where(y_data == label)[0] + if not num_clients: + idx = np.where(y_data == label)[0] + else: + ri = dataset.real_indexes + idx = np.where(self._targets_reales[ri] == label)[0] + idx = ri[idx] + logging.info(f"prueba antes: {self._targets_reales[idx]}") + base_rng.shuffle(idx) class_indices[label] = idx @@ -653,7 +688,6 @@ def allocate_for_label(label_idx: np.ndarray, rng: np.random.Generator, n_client proportions = np.full(n_clients, 1.0 / n_clients) else: proportions = rng.dirichlet([alpha] * n_clients) - logging.info(f"Dirichlet proportions: {proportions}") sample_counts = (proportions * num_label_samples).astype(int) remainder = num_label_samples - sample_counts.sum() if remainder > 0: @@ -668,12 +702,16 @@ def allocate_for_label(label_idx: np.ndarray, rng: np.random.Generator, n_client temp_indices_per_partition = [[] for _ in range(num_clients)] for label in unique_labels: label_idx = class_indices[label] + # logging.info(f"prueba a saco: {self._prueba[label_idx]}") logging.info(f"Calculating samples distribution for label: {label}") counts = allocate_for_label(label_idx, rng, num_clients) start = 0 for client_idx, count in enumerate(counts): + # logging.info(f"start: {start}, count:{count}, end: {start+count}") end = start + count temp_indices_per_partition[client_idx].extend(label_idx[start:end]) + # logging.info(f"Assigned samples for: {client_idx}, label: {label}, number of samples: {len(temp_indices_per_partition[client_idx])}") + logging.info(f"comprobacion de conteo: {np.bincount(self._targets_reales[temp_indices_per_partition[client_idx]])}") start = end client_sizes = [len(indices) for indices in temp_indices_per_partition] @@ -691,9 +729,7 @@ def allocate_for_label(label_idx: np.ndarray, rng: np.random.Generator, n_client ) initial_partition = {i: indices for i, indices in enumerate(indices_per_partition)} - final_partition = 
self.postprocess_partition(initial_partition, y_data) - return final_partition @staticmethod From e99eb8f12d5901c036b4859d038877eb6b498074 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 20 Mar 2025 13:04:22 +0100 Subject: [PATCH 141/233] feature unbalanced IID hybrid datasets --- nebula/core/datasets/mnist/mnist.py | 4 +- nebula/core/datasets/nebuladataset.py | 85 ++++++++++++++++----------- 2 files changed, 54 insertions(+), 35 deletions(-) diff --git a/nebula/core/datasets/mnist/mnist.py b/nebula/core/datasets/mnist/mnist.py index f2e68f552..4fdf66542 100755 --- a/nebula/core/datasets/mnist/mnist.py +++ b/nebula/core/datasets/mnist/mnist.py @@ -73,11 +73,11 @@ def load_mnist_dataset(self, train=True): download=True, ) - def generate_non_iid_map(self, dataset, partition="dirichlet", partition_parameter=0.2, num_clients=None): + def generate_non_iid_map(self, dataset, partition="dirichlet", partition_parameter=0.5, num_clients=None): if partition == "dirichlet": partitions_map = self.dirichlet_partition(dataset, alpha=partition_parameter, num_clients=num_clients) elif partition == "percent": - partitions_map = self.percentage_partition(dataset, percentage=partition_parameter) + partitions_map = self.percentage_partition(dataset, percentage=partition_parameter, num_clients=num_clients) else: raise ValueError(f"Partition {partition} is not supported for Non-IID map") diff --git a/nebula/core/datasets/nebuladataset.py b/nebula/core/datasets/nebuladataset.py index 91b097b5e..43b47ef0d 100755 --- a/nebula/core/datasets/nebuladataset.py +++ b/nebula/core/datasets/nebuladataset.py @@ -249,8 +249,10 @@ def __init__( iid="IID", partition="dirichlet", partition_parameter=0.5, - nsplits_percentages = [0.5, 0.5], - nsplits_iid = ["Non-IID", "Non-IID"], + nsplits_percentages = [0.5, 0.25, 0.25], + nsplits_iid = ["Non-IID", "IID", "Non-IID"], + npartitions = ["dirichlet", "balancediid", "dirichlet"], + npartitions_parameter =[50, 2, 0.5], seed=42, config_dir=None, ): 
@@ -265,6 +267,8 @@ def __init__( self.config_dir = config_dir self._nsplits_percentages = nsplits_percentages self._nsplits_iid = nsplits_iid + self._npartitions = npartitions + self._npartitions_parameter = npartitions_parameter self._targets_reales = None logging.info( @@ -506,29 +510,24 @@ def generate_hybrid_map(self): real_indexes = subset[2] subset_real_data = data[real_indexes] subset_real_targets = targets[real_indexes] - - logging.info(f"comprobacion de subset {np.array_equal(subset[0], subset_real_data)}") - logging.info(f"comprobacion de subset {np.array_equal(subset[1], subset_real_targets)}") - + dataset_wrapped = SimpleNamespace(data=subset_real_data, targets=subset_real_targets, real_indexes=real_indexes) if self._nsplits_iid[i] == "IID": - logging.info(f"Generating dataset subset IID for participants: {grouped_participants[i]}, num_clients:{len(grouped_participants[i])}") - subset_map = self.generate_iid_map(dataset_wrapped, num_clients=len(grouped_participants[i])) + logging.info(f"Generating dataset subset IID for participants: {grouped_participants[i]}, num_clients: {len(grouped_participants[i])}") + subset_map = self.generate_iid_map(dataset_wrapped, self._npartitions[i], self._npartitions_parameter[i], num_clients=len(grouped_participants[i])) for j, real_id in enumerate(grouped_participants[i]): # Mapping subset map generated to real clients IDs data_mapped[real_id] = subset_map[j] else: - logging.info(f"Generating dataset subset Non-IID for participants: {grouped_participants[i]}, num_clients:{len(grouped_participants[i])}") - subset_map = self.generate_non_iid_map(dataset_wrapped, num_clients=len(grouped_participants[i])) - for id, indxs in subset_map.items(): - logging.info(f"Prev | Participant id: {id}, num samples: {len(indxs)}, targets: {np.bincount(targets[indxs])}") + logging.info(f"Generating dataset subset Non-IID for participants: {grouped_participants[i]}, num_clients: {len(grouped_participants[i])}") + subset_map = 
self.generate_non_iid_map(dataset_wrapped, self._npartitions[i], self._npartitions_parameter[i], num_clients=len(grouped_participants[i])) for j, real_id in enumerate(grouped_participants[i]): # Mapping subset map generated to real clients IDs data_mapped[real_id] = subset_map[j] general_map.update(data_mapped) for id, indexes in general_map.items(): - logging.info(f"After | Participant id: {id}, num samples: {len(indexes)}, targets: {np.bincount(targets[indexes])}") + logging.info(f" Participant id: {id}, num samples: {len(indexes)}, targets: {np.bincount(targets[indexes])}") return general_map def plot_data_distribution(self, phase, dataset, partitions_map): @@ -622,7 +621,7 @@ def dirichlet_partition( min_samples_size: int = 50, balanced: bool = False, max_iter: int = 100, - verbose: bool = True, + verbose: bool = False, ) -> dict[int, list[int]]: """ Partition the dataset among clients using a Dirichlet distribution. @@ -657,11 +656,11 @@ def dirichlet_partition( y_data = self._get_targets(dataset) unique_labels = np.unique(y_data) else: - logging.info("Extracting dataset partition targets...") + if verbose: logging.info("Extracting dataset partition targets...") + # For hybrid dataset scenarios y_data = dataset.targets - logging.info(f"{y_data}") unique_labels = np.unique(y_data) - logging.info(f"Unique labels in dataset: {unique_labels}") + if verbose: logging.info(f"Unique labels in dataset: {unique_labels}") # For each class, get a shuffled list of indices. 
class_indices = {} @@ -673,7 +672,7 @@ def dirichlet_partition( ri = dataset.real_indexes idx = np.where(self._targets_reales[ri] == label)[0] idx = ri[idx] - logging.info(f"prueba antes: {self._targets_reales[idx]}") + # logging.info(f"attempting: {self._targets_reales[idx]}") base_rng.shuffle(idx) class_indices[label] = idx @@ -683,7 +682,7 @@ def dirichlet_partition( def allocate_for_label(label_idx: np.ndarray, rng: np.random.Generator, n_clients) -> np.ndarray: num_label_samples = len(label_idx) - logging.info(f"number of samples allocating {num_label_samples}") + if verbose: logging.info(f"number of samples allocating {num_label_samples}") if balanced: proportions = np.full(n_clients, 1.0 / n_clients) else: @@ -694,7 +693,7 @@ def allocate_for_label(label_idx: np.ndarray, rng: np.random.Generator, n_client extra_indices = rng.choice(n_clients, size=remainder, replace=False) for idx in extra_indices: sample_counts[idx] += 1 - logging.info(f"Samples allocated per client: {sample_counts}") + if verbose: logging.info(f"Samples allocated per client: {sample_counts}") return sample_counts for iteration in range(1, max_iter + 1): @@ -702,16 +701,13 @@ def allocate_for_label(label_idx: np.ndarray, rng: np.random.Generator, n_client temp_indices_per_partition = [[] for _ in range(num_clients)] for label in unique_labels: label_idx = class_indices[label] - # logging.info(f"prueba a saco: {self._prueba[label_idx]}") - logging.info(f"Calculating samples distribution for label: {label}") + if verbose: logging.info(f"Calculating samples distribution for label: {label}") counts = allocate_for_label(label_idx, rng, num_clients) start = 0 for client_idx, count in enumerate(counts): - # logging.info(f"start: {start}, count:{count}, end: {start+count}") end = start + count temp_indices_per_partition[client_idx].extend(label_idx[start:end]) - # logging.info(f"Assigned samples for: {client_idx}, label: {label}, number of samples: {len(temp_indices_per_partition[client_idx])}") - 
logging.info(f"comprobacion de conteo: {np.bincount(self._targets_reales[temp_indices_per_partition[client_idx]])}") + if verbose: logging.info(f"Counting check: {np.bincount(self._targets_reales[temp_indices_per_partition[client_idx]])}") start = end client_sizes = [len(indices) for indices in temp_indices_per_partition] @@ -729,7 +725,7 @@ def allocate_for_label(label_idx: np.ndarray, rng: np.random.Generator, n_client ) initial_partition = {i: indices for i, indices in enumerate(indices_per_partition)} - final_partition = self.postprocess_partition(initial_partition, y_data) + final_partition = initial_partition #self.postprocess_partition(initial_partition, y_data) return final_partition @staticmethod @@ -872,6 +868,7 @@ def balanced_iid_partition(self, dataset, num_clients=None): federated_data = balanced_iid_partition(my_dataset) # This creates federated data subsets with equal class distributions. """ + logging.info("Generating balanced IID partition") num_clients = self.partitions_number if not num_clients else num_clients clients_data = {i: [] for i in range(num_clients)} @@ -888,8 +885,13 @@ def balanced_iid_partition(self, dataset, num_clients=None): min_count = label_counts[min_label] for label in range(self.num_classes): - # Get the indices of the same label samples - label_indices = np.where(labels == label)[0] + if not num_clients: + label_indices = np.where(labels == label)[0] + else: # For hybrid dataset scenarios + ri = dataset.real_indexes + label_indices = np.where(self._targets_reales[ri] == label)[0] + label_indices = ri[label_indices] + np.random.seed(self.seed) np.random.shuffle(label_indices) @@ -934,12 +936,18 @@ def unbalanced_iid_partition(self, dataset, imbalance_factor=2, num_clients=None # This creates federated data subsets with varying number of samples based on # an imbalance factor of 2. 
""" + logging.info("Generating unbalanced IID partition") num_clients = self.partitions_number if not num_clients else num_clients clients_data = {i: [] for i in range(num_clients)} # Get the labels from the dataset - labels = np.array([dataset.targets[idx] for idx in range(len(dataset))]) + if not num_clients: + labels = np.array([dataset.targets[idx] for idx in range(len(dataset))]) + else: + labels = np.array(self._targets_reales[dataset.real_indexes]) + label_counts = np.bincount(labels) + logging.info(f"label_counts: {label_counts}") min_label = label_counts.argmin() min_count = label_counts[min_label] @@ -954,7 +962,13 @@ def unbalanced_iid_partition(self, dataset, imbalance_factor=2, num_clients=None for label in range(self.num_classes): # Get the indices of the same label samples - label_indices = np.where(labels == label)[0] + if not num_clients: + label_indices = np.where(labels == label)[0] + else: # For hybrid dataset scenarios + ri = dataset.real_indexes + label_indices = np.where(self._targets_reales[ri] == label)[0] + label_indices = ri[label_indices] + np.random.seed(self.seed) np.random.shuffle(label_indices) @@ -967,7 +981,7 @@ def unbalanced_iid_partition(self, dataset, imbalance_factor=2, num_clients=None return clients_data - def percentage_partition(self, dataset, percentage=20): + def percentage_partition(self, dataset, percentage=20, num_clients=None): """ Partition a dataset into multiple subsets with a specified level of non-IID-ness. 
@@ -1007,11 +1021,16 @@ def percentage_partition(self, dataset, percentage=20): y_train = np.asarray(dataset.targets) num_classes = self.num_classes - num_subsets = self.partitions_number + num_subsets = self.partitions_number if not num_clients else num_clients + + #TODO class_indices = {i: np.where(y_train == i)[0] for i in range(num_classes)} # Get the labels from the dataset - labels = np.array([dataset.targets[idx] for idx in range(len(dataset))]) + if not num_clients: + labels = np.array([dataset.targets[idx] for idx in range(len(dataset))]) + else: + labels = np.array(self._targets_reales[dataset.real_indexes]) label_counts = np.bincount(labels) min_label = label_counts.argmin() From 62adab9430d3648b038c5c706d2e8df8c9ffaf81 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 21 Mar 2025 10:38:35 +0100 Subject: [PATCH 142/233] feature cifar10 hybrid data partitioning --- nebula/core/datasets/cifar10/cifar10.py | 12 +++---- nebula/core/datasets/nebuladataset.py | 42 ++++++++++++++----------- nebula/core/engine.py | 3 -- nebula/core/network/communications.py | 2 +- 4 files changed, 30 insertions(+), 29 deletions(-) diff --git a/nebula/core/datasets/cifar10/cifar10.py b/nebula/core/datasets/cifar10/cifar10.py index 527ad2380..4837cd11f 100755 --- a/nebula/core/datasets/cifar10/cifar10.py +++ b/nebula/core/datasets/cifar10/cifar10.py @@ -78,21 +78,21 @@ def load_cifar10_dataset(self, train=True): download=True, ) - def generate_non_iid_map(self, dataset, partition="dirichlet", partition_parameter=0.5): + def generate_non_iid_map(self, dataset, partition="dirichlet", partition_parameter=0.5, num_clients=None): if partition == "dirichlet": - partitions_map = self.dirichlet_partition(dataset, alpha=partition_parameter) + partitions_map = self.dirichlet_partition(dataset, alpha=partition_parameter, n_clients=num_clients) elif partition == "percent": - partitions_map = self.percentage_partition(dataset, percentage=partition_parameter) + partitions_map = 
self.percentage_partition(dataset, percentage=partition_parameter, n_clients=num_clients) else: raise ValueError(f"Partition {partition} is not supported for Non-IID map") return partitions_map - def generate_iid_map(self, dataset, partition="balancediid", partition_parameter=2): + def generate_iid_map(self, dataset, partition="balancediid", partition_parameter=2, num_clients=None): if partition == "balancediid": - partitions_map = self.balanced_iid_partition(dataset) + partitions_map = self.balanced_iid_partition(dataset, n_clients=num_clients) elif partition == "unbalancediid": - partitions_map = self.unbalanced_iid_partition(dataset, imbalance_factor=partition_parameter) + partitions_map = self.unbalanced_iid_partition(dataset, imbalance_factor=partition_parameter, n_clients=num_clients) else: raise ValueError(f"Partition {partition} is not supported for IID map") diff --git a/nebula/core/datasets/nebuladataset.py b/nebula/core/datasets/nebuladataset.py index 43b47ef0d..ca4e817b6 100755 --- a/nebula/core/datasets/nebuladataset.py +++ b/nebula/core/datasets/nebuladataset.py @@ -252,7 +252,7 @@ def __init__( nsplits_percentages = [0.5, 0.25, 0.25], nsplits_iid = ["Non-IID", "IID", "Non-IID"], npartitions = ["dirichlet", "balancediid", "dirichlet"], - npartitions_parameter =[50, 2, 0.5], + npartitions_parameter =[0.1, 2, 0.5], seed=42, config_dir=None, ): @@ -315,11 +315,11 @@ def data_partitioning(self, plot=False): f"Partitioning data for {self.__class__.__name__} | Partitions: {self.partitions_number} | IID: {self.iid} | Partition: {self.partition} | Partition parameter: {self.partition_parameter}" ) - self.iid = "a" #TODO REMOVE + self.iid = "Non-IID" #TODO REMOVE if self.iid == "IID": self.train_indices_map = self.generate_iid_map(self.train_set) elif self.iid == "Non-IID": - self.train_indices_map = self.generate_non_iid_map(self.train_set, self.partition, self.partition_parameter) + self.train_indices_map = self.generate_non_iid_map(self.train_set, 
partition=self.partition, partition_parameter=self.partition_parameter) else: self.train_indices_map = self.generate_hybrid_map() @@ -617,7 +617,7 @@ def dirichlet_partition( self, dataset: Any, alpha: float = 0.2, - num_clients=None, + n_clients=None, min_samples_size: int = 50, balanced: bool = False, max_iter: int = 100, @@ -648,11 +648,11 @@ def dirichlet_partition( partitions : dict[int, list[int]] Dictionary mapping each client index to a list of sample indices. """ - num_clients = self.partitions_number if not num_clients else num_clients + num_clients = self.partitions_number if not n_clients else n_clients logging.info(f"Generating Dirichlet Partitioning, alpha: {alpha}, num_clients: {num_clients}") # Extract targets and unique labels. - if not num_clients: + if not n_clients: y_data = self._get_targets(dataset) unique_labels = np.unique(y_data) else: @@ -666,7 +666,7 @@ def dirichlet_partition( class_indices = {} base_rng = np.random.default_rng(self.seed) for label in unique_labels: - if not num_clients: + if not n_clients: idx = np.where(y_data == label)[0] else: ri = dataset.real_indexes @@ -842,7 +842,7 @@ def homo_partition(self, dataset): return net_dataidx_map - def balanced_iid_partition(self, dataset, num_clients=None): + def balanced_iid_partition(self, dataset, n_clients=None): """ Partition the dataset into balanced and IID (Independent and Identically Distributed) subsets for each client. @@ -869,7 +869,7 @@ def balanced_iid_partition(self, dataset, num_clients=None): # This creates federated data subsets with equal class distributions. 
""" logging.info("Generating balanced IID partition") - num_clients = self.partitions_number if not num_clients else num_clients + num_clients = self.partitions_number if not n_clients else n_clients clients_data = {i: [] for i in range(num_clients)} # Get the labels from the dataset @@ -885,7 +885,7 @@ def balanced_iid_partition(self, dataset, num_clients=None): min_count = label_counts[min_label] for label in range(self.num_classes): - if not num_clients: + if not n_clients: label_indices = np.where(labels == label)[0] else: # For hybrid dataset scenarios ri = dataset.real_indexes @@ -905,7 +905,7 @@ def balanced_iid_partition(self, dataset, num_clients=None): return clients_data - def unbalanced_iid_partition(self, dataset, imbalance_factor=2, num_clients=None): + def unbalanced_iid_partition(self, dataset, imbalance_factor=2, n_clients=None): """ Partition the dataset into multiple IID (Independent and Identically Distributed) subsets with different size. @@ -937,11 +937,11 @@ def unbalanced_iid_partition(self, dataset, imbalance_factor=2, num_clients=None # an imbalance factor of 2. 
""" logging.info("Generating unbalanced IID partition") - num_clients = self.partitions_number if not num_clients else num_clients + num_clients = self.partitions_number if not n_clients else n_clients clients_data = {i: [] for i in range(num_clients)} # Get the labels from the dataset - if not num_clients: + if not n_clients: labels = np.array([dataset.targets[idx] for idx in range(len(dataset))]) else: labels = np.array(self._targets_reales[dataset.real_indexes]) @@ -962,7 +962,7 @@ def unbalanced_iid_partition(self, dataset, imbalance_factor=2, num_clients=None for label in range(self.num_classes): # Get the indices of the same label samples - if not num_clients: + if not n_clients: label_indices = np.where(labels == label)[0] else: # For hybrid dataset scenarios ri = dataset.real_indexes @@ -981,7 +981,7 @@ def unbalanced_iid_partition(self, dataset, imbalance_factor=2, num_clients=None return clients_data - def percentage_partition(self, dataset, percentage=20, num_clients=None): + def percentage_partition(self, dataset, percentage=20, n_clients=None): """ Partition a dataset into multiple subsets with a specified level of non-IID-ness. 
@@ -1021,13 +1021,17 @@ def percentage_partition(self, dataset, percentage=20, num_clients=None): y_train = np.asarray(dataset.targets) num_classes = self.num_classes - num_subsets = self.partitions_number if not num_clients else num_clients + num_subsets = self.partitions_number if not n_clients else n_clients - #TODO - class_indices = {i: np.where(y_train == i)[0] for i in range(num_classes)} + if not n_clients: + class_indices = {i: np.where(y_train == i)[0] for i in range(num_classes)} + else: + #TODO adapt, bad right now + ri = dataset.real_indexes + class_indices = {i: "" for i in range(num_classes)} # Get the labels from the dataset - if not num_clients: + if not n_clients: labels = np.array([dataset.targets[idx] for idx in range(len(dataset))]) else: labels = np.array(self._targets_reales[dataset.real_indexes]) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 6ed61fe44..b198d92dd 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -588,9 +588,6 @@ def learning_cycle_finished(self): async def _learning_cycle(self): while self.round is not None and self.round < self.total_rounds: - # if self.addr.split()[0][-1] == "5": - # logging.info("### sleeping time ###") - # time.sleep(30) current_time = time.time() rse = RoundStartEvent(self.round, current_time) diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index bb3e7bcf6..d11b296ef 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -709,7 +709,7 @@ async def disconnect(self, dest_addr, mutual_disconnection=True, forced=False): if mutual_disconnection: await self.connections[dest_addr].send(data=self.create_message("connection", "disconnect")) await asyncio.sleep(1) - self.connections[dest_addr].stop() + await self.connections[dest_addr].stop() except Exception as e: logging.exception(f"❗️ Error while disconnecting {dest_addr}: {e!s}") if dest_addr in self.connections: From 
de554720e1f00016b9b6b40a4f955ac6114e6285 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 21 Mar 2025 22:17:42 +0100 Subject: [PATCH 143/233] fix mnist error --- nebula/core/datasets/mnist/mnist.py | 8 ++++---- nebula/core/datasets/nebuladataset.py | 2 +- .../awareness/satraining/satraining.py | 4 ++-- .../satraining/trainingpolicy/qdstrainingpolicy.py | 7 ++++--- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/nebula/core/datasets/mnist/mnist.py b/nebula/core/datasets/mnist/mnist.py index 4fdf66542..5e7379174 100755 --- a/nebula/core/datasets/mnist/mnist.py +++ b/nebula/core/datasets/mnist/mnist.py @@ -75,9 +75,9 @@ def load_mnist_dataset(self, train=True): def generate_non_iid_map(self, dataset, partition="dirichlet", partition_parameter=0.5, num_clients=None): if partition == "dirichlet": - partitions_map = self.dirichlet_partition(dataset, alpha=partition_parameter, num_clients=num_clients) + partitions_map = self.dirichlet_partition(dataset, alpha=partition_parameter, n_clients=num_clients) elif partition == "percent": - partitions_map = self.percentage_partition(dataset, percentage=partition_parameter, num_clients=num_clients) + partitions_map = self.percentage_partition(dataset, percentage=partition_parameter, n_clients=num_clients) else: raise ValueError(f"Partition {partition} is not supported for Non-IID map") @@ -85,9 +85,9 @@ def generate_non_iid_map(self, dataset, partition="dirichlet", partition_paramet def generate_iid_map(self, dataset, partition="balancediid", partition_parameter=2, num_clients=None): if partition == "balancediid": - partitions_map = self.balanced_iid_partition(dataset, num_clients=num_clients) + partitions_map = self.balanced_iid_partition(dataset, n_clients=num_clients) elif partition == "unbalancediid": - partitions_map = self.unbalanced_iid_partition(dataset, imbalance_factor=partition_parameter, num_clients=num_clients) + partitions_map = self.unbalanced_iid_partition(dataset, 
imbalance_factor=partition_parameter, n_clients=num_clients) else: raise ValueError(f"Partition {partition} is not supported for IID map") diff --git a/nebula/core/datasets/nebuladataset.py b/nebula/core/datasets/nebuladataset.py index ca4e817b6..0222b701e 100755 --- a/nebula/core/datasets/nebuladataset.py +++ b/nebula/core/datasets/nebuladataset.py @@ -315,7 +315,7 @@ def data_partitioning(self, plot=False): f"Partitioning data for {self.__class__.__name__} | Partitions: {self.partitions_number} | IID: {self.iid} | Partition: {self.partition} | Partition parameter: {self.partition_parameter}" ) - self.iid = "Non-IID" #TODO REMOVE + self.iid = "a" #TODO REMOVE if self.iid == "IID": self.train_indices_map = self.generate_iid_map(self.train_set) elif self.iid == "Non-IID": diff --git a/nebula/core/situationalawareness/awareness/satraining/satraining.py b/nebula/core/situationalawareness/awareness/satraining/satraining.py index 7ad4d3b53..b40f84a65 100644 --- a/nebula/core/situationalawareness/awareness/satraining/satraining.py +++ b/nebula/core/situationalawareness/awareness/satraining/satraining.py @@ -50,5 +50,5 @@ async def module_actions(self): nodes = await self.tp.get_evaluation_results() if nodes: for n in nodes: - # pass - asyncio.create_task(self.sam.cm.disconnect(n[0], forced=True)) + pass + #asyncio.create_task(self.sam.cm.disconnect(n[0], forced=True)) diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py index 15fc2fb56..40b9725a6 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py @@ -6,6 +6,7 @@ import logging from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import AggregationEvent +import math # "Quality-Driven Selection" (QDS) class 
QDSTrainingPolicy(TrainingPolicy): @@ -65,7 +66,7 @@ async def process_aggregation_event(self, agg_ev : AggregationEvent): cos_sim = cosine_metric(self_model, model, similarity=True) self._nodes[addr][0].append(cos_sim) self._evaluation_results = await self.evaluate() - + async def _get_nodes(self): async with self._nodes_lock: nodes = self._nodes.copy() @@ -105,9 +106,9 @@ async def evaluate(self): if self._verbose: logging.info(f"Redundant nodes on aggregations: {redundant_nodes}") if inactive_nodes: result = result.union(inactive_nodes) - if len(redundant_nodes) > 1: + if len(redundant_nodes): sorted_redundant_nodes = sorted(redundant_nodes, key=lambda x: x[1]) - n_discarded = int(len(redundant_nodes)/2) + n_discarded = math.ceil((len(redundant_nodes)/2)) discard_nodes = sorted_redundant_nodes[:n_discarded] if self._verbose: logging.info(f"Discarded redundant nodes: {discard_nodes}") result = result.union(discard_nodes) From 4a9e25d8f2a8417b3ccbdc4b93f8aaee51d7de60 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 24 Mar 2025 11:25:54 +0100 Subject: [PATCH 144/233] fix dflupdatehandler & qds --- .../aggregation/updatehandlers/dflupdatehandler.py | 3 +++ .../satraining/trainingpolicy/qdstrainingpolicy.py | 14 +++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py index 203cf1c87..247760b1a 100644 --- a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py +++ b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py @@ -127,6 +127,9 @@ async def get_round_updates(self): if updates_missing: self._missing_ones = updates_missing logging.info(f"Missing updates from sources: {updates_missing}") + else: + self._missing_ones.clear() + self._nodes_using_historic.clear() updates = {} for sr in self._sources_received: diff --git 
a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py index 40b9725a6..f294f6d7a 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py @@ -53,13 +53,13 @@ async def process_aggregation_event(self, agg_ev : AggregationEvent): if addr == self._addr: continue if not addr in self._nodes.keys(): continue - for node in self._nodes.keys(): # Update inactive counters - deque_history, missed_count = self._nodes[node] - if not node in missing_nodes: - self._nodes[node] = (deque_history, 0) # Reset inactive counter - else: - if self._verbose: logging.info(f"Node inactivity counter increased for: {node}") - self._nodes[node] = (deque_history, missed_count + 1) # Inactive rounds counter +1 + deque_history, missed_count = self._nodes[addr] + if addr in missing_nodes: + if self._verbose: logging.info(f"Node inactivity counter increased for: {addr}") + self._nodes[addr] = (deque_history, missed_count + 1) # Inactive rounds counter +1 + else: + self._nodes[addr] = (deque_history, 0) # Reset inactive counter + #TODO hacerlo solo para los q no se estΓ‘ utilizando la ultima update guardada (model,_) = updt (self_model, _) = self_updt From 09e5fbe90cb917baf2e6f7ccd3e05dab88fd4f41 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 25 Mar 2025 07:58:41 +0100 Subject: [PATCH 145/233] updt --- .../situationalawareness/awareness/satraining/satraining.py | 3 +-- .../awareness/satraining/trainingpolicy/qdstrainingpolicy.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/nebula/core/situationalawareness/awareness/satraining/satraining.py b/nebula/core/situationalawareness/awareness/satraining/satraining.py index b40f84a65..d96f9e6d3 100644 --- 
a/nebula/core/situationalawareness/awareness/satraining/satraining.py +++ b/nebula/core/situationalawareness/awareness/satraining/satraining.py @@ -50,5 +50,4 @@ async def module_actions(self): nodes = await self.tp.get_evaluation_results() if nodes: for n in nodes: - pass - #asyncio.create_task(self.sam.cm.disconnect(n[0], forced=True)) + asyncio.create_task(self.sam.cm.disconnect(n[0], forced=True)) diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py index f294f6d7a..77d03ce21 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py @@ -11,7 +11,7 @@ # "Quality-Driven Selection" (QDS) class QDSTrainingPolicy(TrainingPolicy): MAX_HISTORIC_SIZE = 10 - SIMILARITY_THRESHOLD = 0.8 + SIMILARITY_THRESHOLD = 0.73 INACTIVE_THRESHOLD = 3 GRACE_ROUNDS = 0 CHECK_COOLDOWN = 50 From 4ff9eb334c929e2520d1c07e22ada361383c21bf Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 26 Mar 2025 12:48:51 +0100 Subject: [PATCH 146/233] opt space --- nebula/core/engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index b198d92dd..824a83a2c 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -251,8 +251,7 @@ async def model_initialization_callback(self, source, message): except RuntimeError: pass except RuntimeError: - pass - + pass async def model_update_callback(self, source, message): logging.info(f"πŸ€– handle_model_message | Received model update from {source} with round {message.round}") if not self.get_federation_ready_lock().locked() and len(self.get_federation_nodes()) == 0: From 6ae66c7ad898a165a0934c637bee8a9648179ccc Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 27 Mar 2025 12:23:17 +0100 
Subject: [PATCH 147/233] feature nebula plugin loader --- nebula/core/pluginloader.py | 170 ++++++++++++++++++++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 nebula/core/pluginloader.py diff --git a/nebula/core/pluginloader.py b/nebula/core/pluginloader.py new file mode 100644 index 000000000..a239b9916 --- /dev/null +++ b/nebula/core/pluginloader.py @@ -0,0 +1,170 @@ +import logging +import asyncio +from nebula.addons.functions import print_msg_box +from nebula.core.utils.locker import Locker +from abc import ABC, abstractmethod +import importlib.util +import os + +""" +It is an example of the .json configuration file structure required +for the NebulaPluginLoader to load all plugins defined for the current +scenario. + +JSON Configuration Example: +--------------------------- +{ + "plugins": ["reputation", "trust"], + + "reputation": { + "threshold": 0.75, + "decay_factor": 0.05, + "history_length": 100 + }, + + "trust": { + "initial_trust": 0.5, + "trust_update_rate": 0.1, + "penalty_factor": 0.2 + } +} + +Plugin Directory Structure: +--------------------------- +All plugins must follow a standardized directory and naming convention +to be correctly detected and loaded by NebulaPluginLoader. + +Base path where plugins should be located: + /nebula/nebula/core/ + +Each plugin should be placed in its own directory inside the base path, +with a filename matching the plugin name (lowercase) and a class matching +the plugin name (capitalized). + +Example for the "reputation" plugin: + /nebula/nebula/core/reputation/reputation.py + +Example for the "trust" plugin: + /nebula/nebula/core/trust/trust.py + +Plugin Class Naming Convention: +------------------------------- +Each plugin must define a class with the same name as the plugin but with +the first letter capitalized. This class must inherit from `NebulaPlugin` +and implement the `initialize_plugin()` method. + +Plugins receive their configuration as a dictionary when instantiated. 
+ +Example for "reputation": +----------------------------------------------------- +File: /nebula/nebula/core/reputation/reputation.py + +from nebula.nebula.core.plugin_loader import NebulaPlugin + +class Reputation(NebulaPlugin): + def __init__(self, config: dict): + self.threshold = config.get("threshold", 0.75) + self.decay_factor = config.get("decay_factor", 0.05) + self.history_length = config.get("history_length", 100) + + async def initialize_plugin(self): + # Initialization logic here + pass +----------------------------------------------------- + +Example for "trust": +----------------------------------------------------- +File: /nebula/nebula/core/trust/trust.py + +from nebula.nebula.core.plugin_loader import NebulaPlugin + +class Trust(NebulaPlugin): + def __init__(self, config: dict): + self.initial_trust = config.get("initial_trust", 0.5) + self.trust_update_rate = config.get("trust_update_rate", 0.1) + self.penalty_factor = config.get("penalty_factor", 0.2) + + async def initialize_plugin(self): + # Initialization logic here + pass +----------------------------------------------------- + +Important Notes: +--------------- +- The plugin class name **must match the plugin name in the JSON but capitalized**. +- The plugin module filename **must be in lowercase**. +- The plugin class must inherit from `NebulaPlugin` and implement the `initialize_plugin()` method. +- Each plugin receives its configuration as a **dictionary** when instantiated. +- The `NebulaPluginLoader` dynamically loads each plugin and passes its respective configuration from the JSON. +""" + + +class NebulaPlugin(ABC): + @abstractmethod + async def initialize_plugin(self): + """Method to be implemented by all plugins. 
+ It should handle any necessary initialization logic.""" + raise NotImplementedError + + +class NebulaPluginLoader: + _instance = None + _lock = Locker("_nebula_pluging_loader_lock", async_lock=False) + + def __new__(cls, config_json=None, base_path="/nebula/nebula/core"): + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self, config_json=None, base_path="/nebula/nebula/core"): + """Initializes the plugin loader with the given configuration JSON and base path.""" + if self._initialized: + return + + self._config_json = config_json or {} + self._base_path = base_path + self._plugins : dict[str, NebulaPlugin] = {} + self._initialized = True + self._verbose = False + + def load_plugins(self): + """Dynamically loads the plugins defined in the JSON configuration.""" + if not self._config_json: + raise ValueError("No Configuration file provided. [JSON] file required.") + + plugin_names = self._config_json.get("plugins", []) + for name in plugin_names: + class_name = name.capitalize() + module_path = os.path.join(self._base_path, name) + module_file = os.path.join(module_path, f"{name}.py") + + if os.path.exists(module_file): + module = self._load_plugin(class_name, module_file, self._config_json.get(name, {})) + if module: + self._plugins[name] = module + else: + logging.error(f"⚠️ Plugin {name} not found on {module_file}") + + def _load_plugin(self, class_name, module_file, config): + """Loads a plugin dynamically and initializes it with its configuration.""" + spec = importlib.util.spec_from_file_location(class_name, module_file) + if spec and spec.loader: + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + if hasattr(module, class_name): # Verify if class exists + return getattr(module, class_name)(config) # Create and instance using plugin config + else: + logging.error(f"⚠️ Cannot create {class_name} plugin, class not 
found on {module_file}") + return None + + async def initialize_plugins(self): + """Calls the asynchronous initialization method of each loaded plugin.""" + for plugin_name, plugin in self._plugins.items(): + if self._verbose: logging.info(f"Initializing plugin name:{plugin_name}") + await plugin.initialize_plugin() + + def get_plugin(self, name): + """Returns an instance of the plugin if it has been loaded.""" + return self._plugins.get(name, None) \ No newline at end of file From dccd82ffe48ce50aedce8c053da501b22a06d1e4 Mon Sep 17 00:00:00 2001 From: FerTV Date: Fri, 28 Mar 2025 10:42:25 +0100 Subject: [PATCH 148/233] refactor(controller): migrate user modifications from front-end to controller Moved the logic for user modifications from the front-end to the controller to enhance separation of concerns and simplify maintenance. --- nebula/controller.py | 176 ++++++++++++++++++++++++++- nebula/frontend/app.py | 229 +++++++++++++++++++++++++----------- nebula/frontend/database.py | 24 +++- nebula/scenarios.py | 2 +- 4 files changed, 353 insertions(+), 78 deletions(-) diff --git a/nebula/controller.py b/nebula/controller.py index 5c266b56c..47705591f 100755 --- a/nebula/controller.py +++ b/nebula/controller.py @@ -9,12 +9,13 @@ import sys import threading import time +from typing import Annotated import docker import psutil import uvicorn from dotenv import load_dotenv -from fastapi import FastAPI +from fastapi import Body, FastAPI, status, HTTPException, Path from watchdog.events import PatternMatchingEventHandler from watchdog.observers import Observer @@ -40,7 +41,6 @@ def format(self, record): # Initialize FastAPI app outside the Controller class app = FastAPI() - # Define endpoints outside the Controller class @app.get("/") async def read_root(): @@ -144,6 +144,174 @@ async def get_available_gpu(): pass +@app.get("/scenarios/{user}/{role}") +async def get_scenarios( + user: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=50, + 
description="Valid username" + ) + ], + role: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=50, + description="Valid role" + ) + ] +): + from nebula.frontend.database import get_all_scenarios_and_check_completed, get_running_scenario + + try: + scenarios = get_all_scenarios_and_check_completed(username=user, role=role) + scenario_running = get_running_scenario() + except Exception as e: + logging.error(f"Error obtaining scenarios: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return {"scenarios": scenarios, "scenario_running": scenario_running} + + +@app.get("/users/list") +async def list_users_controller(all_info: bool = False): + """ + Controller endpoint to retrieve the list of users. + If all_info is True, returns the complete information converted into dictionaries. + """ + from nebula.frontend.database import list_users + + try: + user_list = list_users(all_info) + if all_info: + # Convert each sqlite3.Row to a dictionary so that it is JSON serializable. + user_list = [dict(user) for user in user_list] + return {"users": user_list} + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error retrieving users: {e}" + ) + + +@app.post("/user/add") +async def add_user_controller( + user: str = Body(...), + password: str = Body(...), + role: str = Body(...) +): + """ + Controller endpoint that inserts a new user into the database. + + Parameters: + - user: The username for the new user. + - password: The user's password. + - role: The role assigned to the new user. + + Returns a success message if the user is added, or an HTTP error if an exception occurs. 
+ """ + from nebula.frontend.database import add_user + + try: + add_user(user, password, role) + return {"detail": "User added successfully"} + except Exception as e: + logging.error(f"Error adding user: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error adding user: {e}" + ) + + +@app.post("/user/delete") +async def add_user_controller( + user: str = Body(..., embed=True) +): + """ + Controller endpoint that inserts a new user into the database. + + Parameters: + - user: The username for the new user. + + Returns a success message if the user is deleted, or an HTTP error if an exception occurs. + """ + from nebula.frontend.database import delete_user_from_db + + try: + delete_user_from_db(user) + return {"detail": "User deleted successfully"} + except Exception as e: + logging.error(f"Error deleting user: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error deleting user: {e}" + ) + + +@app.post("/user/update") +async def add_user_controller( + user: str = Body(...), + password: str = Body(...), + role: str = Body(...) +): + """ + Controller endpoint that modifies a user of the database. + + Parameters: + - user: The username of the user. + - password: The user's password. + - role: The role of the user. + + Returns a success message if the user is updated, or an HTTP error if an exception occurs. + """ + from nebula.frontend.database import update_user + + try: + update_user(user, password, role) + return {"detail": "User updated successfully"} + except Exception as e: + logging.error(f"Error updating user: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error updating user: {e}" + ) + + +@app.post("/user/verify") +async def add_user_controller( + user: str = Body(...), + password: str = Body(...) +): + """ + Controller endpoint that verifies if it's a valid user. + + Parameters: + - user: The username of the user. 
+ - password: The user's password. + + Returns a success message if the user is verified, or an HTTP error if an exception occurs. + """ + from nebula.frontend.database import list_users, verify, get_user_info + + try: + user_submitted = user.upper() + if (user_submitted in list_users()) and verify(user_submitted, password): + user_info = get_user_info(user_submitted) + return {"user": user_submitted, "role": user_info[2]} + else: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + except Exception as e: + logging.error(f"Error verifying user: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error verifying user: {e}" + ) + + class NebulaEventHandler(PatternMatchingEventHandler): """ NebulaEventHandler handles file system events for .sh scripts. @@ -450,6 +618,10 @@ def start(self): app_thread.start() logging.info(f"NEBULA Controller is running at port {self.controller_port}") + from nebula.frontend.database import initialize_databases + + asyncio.run(initialize_databases(self.databases_dir)) + if self.production: self.run_waf() logging.info(f"NEBULA WAF is running at port {self.waf_port}") diff --git a/nebula/frontend/app.py b/nebula/frontend/app.py index db21f0ad4..8496b5514 100755 --- a/nebula/frontend/app.py +++ b/nebula/frontend/app.py @@ -14,6 +14,7 @@ import aiohttp import requests from dotenv import load_dotenv +from yarl import URL sys.path.append(os.path.dirname(os.path.abspath(__file__))) sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..")) @@ -224,21 +225,6 @@ def get_session(request: Request) -> dict: return request.session -def set_default_user(): - username = os.environ.get("NEBULA_DEFAULT_USER", "admin") - password = os.environ.get("NEBULA_DEFAULT_PASSWORD", "admin") - if not list_users(): - add_user(username, password, "admin") - if not verify_hash_algorithm(username): - update_user(username, password, "admin") - - -@app.on_event("startup") -async def 
startup_event(): - await initialize_databases() - set_default_user() - - class UserData: def __init__(self): self.nodes_registration = {} @@ -279,6 +265,80 @@ async def custom_http_exception_handler(request: Request, exc: StarletteHTTPExce return await request.app.default_exception_handler(request, exc) +async def get_available_gpus(): + url = f"http://{settings.controller_host}:{settings.controller_port}/available_gpus" + async with aiohttp.ClientSession() as session, session.get(url) as response: + if response.status == 200: + try: + return await response.json() + except Exception as e: + return {"error": f"Failed to parse JSON: {e}"} + else: + return None + + +async def get_least_memory_gpu(): + url = f"http://{settings.controller_host}:{settings.controller_port}/least_memory_gpu" + async with aiohttp.ClientSession() as session, session.get(url) as response: + if response.status == 200: + try: + return await response.json() + except Exception as e: + return {"error": f"Failed to parse JSON: {e}"} + else: + return None + + +async def get_scenarios(user, role): + try: + base_url = URL(f"http://{settings.controller_host}:{settings.controller_port}/scenarios") + url = base_url / user / role + logging.info(f"Requesting URL: {url}") + + async with aiohttp.ClientSession() as session: + async with session.get(url, timeout=10) as response: + if response.status == 200: + try: + return await response.json() + except Exception as e: + logging.error(f"Error parsing JSON: {e}") + return {"error": f"Error parsing JSON: {e}"} + else: + logging.error(f"Request error: {response.status}") + return {"error": f"Request error: {response.status}"} + except Exception as e: + logging.error(f"Unexpected error in call_get_scenarios: {e}") + return {"error": f"Unexpected error: {e}"} + + +async def list_users(allinfo=True): + """ + Retrieves the list of users by calling the controller endpoint. + + Parameters: + - all_info (bool): If True, retrieves detailed information for each user. 
+ + Returns: + - A list of users, as provided by the controller. + """ + controller_url = f"http://{settings.controller_host}:{settings.controller_port}/users/list?all_info={allinfo}" + try: + async with aiohttp.ClientSession() as client: + async with client.get(controller_url, timeout=10) as resp: + if resp.status != 200: + raise HTTPException( + status_code=resp.status, + detail="Error retrieving user list from controller" + ) + data = await resp.json() + user_list = data["users"] + except Exception as e: + logging.error(f"Error calling controller endpoint: {e}") + raise HTTPException(status_code=500, detail=f"Controller request failed: {e}") + + return user_list + + @app.get("/", response_class=HTMLResponse) async def index(): return RedirectResponse(url="/platform") @@ -313,11 +373,12 @@ async def nebula_dashboard_private(request: Request, scenario_name: str, session @app.get("/platform/admin", response_class=HTMLResponse) async def nebula_admin(request: Request, session: dict = Depends(get_session)): if session.get("role") == "admin": - user_list = list_users(all_info=True) + user_list = await list_users() + user_table = zip( range(1, len(user_list) + 1), - [user[0] for user in user_list], - [user[2] for user in user_list], + [user["user"] for user in user_list], + [user["role"] for user in user_list], strict=False, ) return templates.TemplateResponse("admin.html", {"request": request, "users": user_table}) @@ -379,14 +440,22 @@ async def nebula_login( user: str = Form(...), password: str = Form(...), ): - user_submitted = user.upper() - if (user_submitted in list_users()) and verify(user_submitted, password): - user_info = get_user_info(user_submitted) - session["user"] = user_submitted - session["role"] = user_info[2] - return JSONResponse({"message": "Login successful"}, status_code=200) - else: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + controller_url = f"http://{settings.controller_host}:{settings.controller_port}/user/verify" + 
payload = {"user": user, "password": password} + try: + async with aiohttp.ClientSession() as client: + async with client.post(controller_url, json=payload, timeout=10) as resp: + if resp.status == 200: + # Successful response from the controller. + data = await resp.json() + session["user"] = data.get("user") + session["role"] = data.get("role") + return JSONResponse({"message": "Login successful"}, status_code=200) + else: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + except Exception as e: + logging.error(f"Error calling controller endpoint: {e}") + raise HTTPException(status_code=500, detail=f"Controller request failed: {e}") @app.get("/platform/logout") @@ -403,8 +472,22 @@ async def nebula_delete_user(user: str, request: Request, session: dict = Depend if user == session["user"]: # Current user can't delete himself. raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - delete_user_from_db(user) - return RedirectResponse(url="/platform/admin") + controller_url = f"http://{settings.controller_host}:{settings.controller_port}/user/delete" + payload = {"user": user} + + try: + async with aiohttp.ClientSession() as client: + async with client.post(controller_url, json=payload, timeout=10) as resp: + if resp.status == 200: + # Successful response from the controller. + return RedirectResponse(url="/platform/admin") + else: + error_text = await resp.text() + logging.error(f"Controller error: {error_text}") + raise HTTPException(status_code=resp.status, detail=error_text) + except Exception as e: + logging.error(f"Error calling controller endpoint: {e}") + raise HTTPException(status_code=500, detail=f"Controller request failed: {e}") else: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) @@ -417,15 +500,33 @@ async def nebula_add_user( password: str = Form(...), role: str = Form(...), ): - if session.get("role") == "admin": # only Admin should be able to add user. 
- user_list = list_users(all_info=True) - if user.upper() in user_list or " " in user or "'" in user or '"' in user: - return RedirectResponse(url="/platform/admin", status_code=status.HTTP_303_SEE_OTHER) - else: - add_user(user, password, role) - return RedirectResponse(url="/platform/admin", status_code=status.HTTP_303_SEE_OTHER) - else: + # Only admin users can add new users. + if session.get("role") != "admin": raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + + # Basic validation on the user value before calling the controller. + user_list = await list_users() + + if user.upper() in user_list or " " in user or "'" in user or '"' in user: + return RedirectResponse(url="/platform/admin", status_code=status.HTTP_303_SEE_OTHER) + + # Call the controller's endpoint to add the user. + controller_url = f"http://{settings.controller_host}:{settings.controller_port}/user/add" + payload = {"user": user, "password": password, "role": role} + + try: + async with aiohttp.ClientSession() as client: + async with client.post(controller_url, json=payload, timeout=10) as resp: + if resp.status == 200: + # Successful response from the controller. 
+ return RedirectResponse(url="/platform/admin", status_code=status.HTTP_303_SEE_OTHER) + else: + error_text = await resp.text() + logging.error(f"Controller error: {error_text}") + raise HTTPException(status_code=resp.status, detail=error_text) + except Exception as e: + logging.error(f"Error calling controller endpoint: {e}") + raise HTTPException(status_code=500, detail=f"Controller request failed: {e}") @app.post("/platform/user/update") @@ -438,8 +539,23 @@ async def nebula_update_user( ): if "user" not in session or session["role"] != "admin": return RedirectResponse(url="/platform", status_code=status.HTTP_302_FOUND) - update_user(user, password, role) - return RedirectResponse(url="/platform/admin", status_code=status.HTTP_302_FOUND) + + controller_url = f"http://{settings.controller_host}:{settings.controller_port}/user/update" + payload = {"user": user, "password": password, "role": role} + + try: + async with aiohttp.ClientSession() as client: + async with client.post(controller_url, json=payload, timeout=10) as resp: + if resp.status == 200: + # Successful response from the controller. 
+ return RedirectResponse(url="/platform/admin", status_code=status.HTTP_302_FOUND) + else: + error_text = await resp.text() + logging.error(f"Controller error: {error_text}") + raise HTTPException(status_code=resp.status, detail=error_text) + except Exception as e: + logging.error(f"Error calling controller endpoint: {e}") + raise HTTPException(status_code=500, detail=f"Controller request failed: {e}") @app.get("/platform/api/dashboard/runningscenario", response_class=JSONResponse) @@ -465,30 +581,6 @@ async def get_host_resources(): return None -async def get_available_gpus(): - url = f"http://{settings.controller_host}:{settings.controller_port}/available_gpus" - async with aiohttp.ClientSession() as session, session.get(url) as response: - if response.status == 200: - try: - return await response.json() - except Exception as e: - return {"error": f"Failed to parse JSON: {e}"} - else: - return None - - -async def get_least_memory_gpu(): - url = f"http://{settings.controller_host}:{settings.controller_port}/least_memory_gpu" - async with aiohttp.ClientSession() as session, session.get(url) as response: - if response.status == 200: - try: - return await response.json() - except Exception as e: - return {"error": f"Failed to parse JSON: {e}"} - else: - return None - - async def check_enough_resources(): resources = await get_host_resources() @@ -557,10 +649,10 @@ async def monitor_resources(): @app.get("/platform/dashboard", response_class=HTMLResponse) async def nebula_dashboard(request: Request, session: dict = Depends(get_session)): if "user" in session: - scenarios = get_all_scenarios_and_check_completed( - username=session["user"], role=session["role"] - ) # Get all scenarios after checking if they are completed - scenario_running = get_running_scenario() + response = await get_scenarios(session["user"], session["role"]) + scenarios = response.get("scenarios") + scenario_running = response.get("scenario_running") + if session["user"] not in user_data_store: 
user_data_store[session["user"]] = UserData() @@ -875,10 +967,7 @@ async def nebula_monitor_log(scenario_name: str, id: str): raise HTTPException(status_code=404, detail="Log file not found") -@app.get( - "/platform/dashboard/{scenario_name}/node/{id}/infolog/{number}", - response_class=PlainTextResponse, -) +@app.get("/platform/dashboard/{scenario_name}/node/{id}/infolog/{number}", response_class=PlainTextResponse,) async def nebula_monitor_log_x(scenario_name: str, id: str, number: int): logs = FileUtils.check_path(settings.log_dir, os.path.join(scenario_name, f"participant_{id}.log")) if os.path.exists(logs): diff --git a/nebula/frontend/database.py b/nebula/frontend/database.py index aa40ef377..2019544cb 100755 --- a/nebula/frontend/database.py +++ b/nebula/frontend/database.py @@ -2,15 +2,16 @@ import datetime import json import logging +import os import sqlite3 import aiosqlite from argon2 import PasswordHasher -user_db_file_location = "databases/users.db" -node_db_file_location = "databases/nodes.db" -scenario_db_file_location = "databases/scenarios.db" -notes_db_file_location = "databases/notes.db" +user_db_file_location = "users.db" +node_db_file_location = "nodes.db" +scenario_db_file_location = "scenarios.db" +notes_db_file_location = "notes.db" _node_lock = asyncio.Lock() @@ -45,7 +46,14 @@ async def ensure_columns(conn, table_name, desired_columns): await conn.commit() -async def initialize_databases(): +async def initialize_databases(databases_dir): + global user_db_file_location, node_db_file_location, scenario_db_file_location, notes_db_file_location + + user_db_file_location = os.path.join(databases_dir, user_db_file_location) + node_db_file_location = os.path.join(databases_dir, node_db_file_location) + scenario_db_file_location = os.path.join(databases_dir, scenario_db_file_location) + notes_db_file_location = os.path.join(databases_dir, notes_db_file_location) + await setup_database(user_db_file_location) await 
setup_database(node_db_file_location) await setup_database(scenario_db_file_location) @@ -224,6 +232,12 @@ async def initialize_databases(): desired_columns = {"scenario": "TEXT PRIMARY KEY", "scenario_notes": "TEXT"} await ensure_columns(conn, "notes", desired_columns) + # Add default user + if not list_users(): + add_user("admin", "admin", "admin") + if not verify_hash_algorithm("admin"): + update_user("admin", "admin", "admin") + def list_users(all_info=False): with sqlite3.connect(user_db_file_location) as conn: diff --git a/nebula/scenarios.py b/nebula/scenarios.py index 36ddeab63..add9a90c7 100644 --- a/nebula/scenarios.py +++ b/nebula/scenarios.py @@ -254,7 +254,7 @@ def from_dict(cls, data): class ScenarioManagement: def __init__(self, scenario, user=None): # Current scenario - self.scenario = Scenario.from_dict(scenario) + self.scenario = scenario # Uid of the user self.user = user # Scenario management settings From 9766c3efaa5cbe088a1d2645bab106948588dbba Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 2 Apr 2025 13:30:37 +0200 Subject: [PATCH 149/233] feature SA command --- .../communications/communicationattack.py | 7 +- .../attacks/communications/delayerattack.py | 3 +- .../attacks/communications/floodingattack.py | 3 +- nebula/addons/mobility.py | 7 +- .../nebulanetworksimulator.py | 9 ++- .../networksimulation/networksimulator.py | 4 +- nebula/addons/reporter.py | 7 +- nebula/core/addonmanager.py | 4 +- nebula/core/datasets/nebuladataset.py | 3 +- nebula/core/engine.py | 23 +----- nebula/core/models/nebulamodel.py | 8 -- nebula/core/network/communications.py | 53 +++++++++++-- nebula/core/network/connection.py | 6 +- nebula/core/network/discoverer.py | 7 +- .../externalconnectionservice.py | 4 +- .../nebuladiscoveryservice.py | 5 +- nebula/core/network/forwarder.py | 7 +- nebula/core/network/health.py | 5 +- nebula/core/network/messages.py | 7 +- nebula/core/network/propagator.py | 20 +++-- .../awareness/commands/connectivitycommand.py | 
18 +++++ .../awareness/commands/sacommand.py | 78 +++++++++++++++++++ .../awareness/samodule.py | 5 +- .../awareness/sanetwork/sanetwork.py | 4 +- .../core/situationalawareness/nodemanager.py | 18 ++--- 25 files changed, 227 insertions(+), 88 deletions(-) create mode 100644 nebula/core/situationalawareness/awareness/commands/connectivitycommand.py create mode 100644 nebula/core/situationalawareness/awareness/commands/sacommand.py diff --git a/nebula/addons/attacks/communications/communicationattack.py b/nebula/addons/attacks/communications/communicationattack.py index 6552e9e52..712ab4ce9 100644 --- a/nebula/addons/attacks/communications/communicationattack.py +++ b/nebula/addons/attacks/communications/communicationattack.py @@ -4,6 +4,7 @@ from abc import abstractmethod from nebula.addons.attacks.attacks import Attack +from nebula.core.network.communications import CommunicationsManager class CommunicationAttack(Attack): @@ -46,17 +47,17 @@ async def select_targets(self): if self.selection_interval: if self.last_selection_round % self.selection_interval == 0: logging.info("Recalculating targets...") - all_nodes = await self.engine.cm.get_addrs_current_connections(only_direct=True) + all_nodes = await CommunicationsManager.get_instance().get_addrs_current_connections(only_direct=True) num_targets = max(1, int(len(all_nodes) * (self.selectivity_percentage / 100))) self.targets = set(random.sample(list(all_nodes), num_targets)) elif not self.targets: logging.info("Calculating targets...") - all_nodes = await self.engine.cm.get_addrs_current_connections(only_direct=True) + all_nodes = await CommunicationsManager.get_instance().get_addrs_current_connections(only_direct=True) num_targets = max(1, int(len(all_nodes) * (self.selectivity_percentage / 100))) self.targets = set(random.sample(list(all_nodes), num_targets)) else: logging.info("All neighbors selected as targets") - self.targets = await self.engine.cm.get_addrs_current_connections(only_direct=True) + 
self.targets = CommunicationsManager.get_instance().get_addrs_current_connections(only_direct=True) logging.info(f"Selected {self.selectivity_percentage}% targets from neighbors: {self.targets}") self.last_selection_round += 1 diff --git a/nebula/addons/attacks/communications/delayerattack.py b/nebula/addons/attacks/communications/delayerattack.py index afe611fde..bdc5dfdc2 100644 --- a/nebula/addons/attacks/communications/delayerattack.py +++ b/nebula/addons/attacks/communications/delayerattack.py @@ -3,6 +3,7 @@ from functools import wraps from nebula.addons.attacks.communications.communicationattack import CommunicationAttack +from nebula.core.network.communications import CommunicationsManager class DelayerAttack(CommunicationAttack): @@ -32,7 +33,7 @@ def __init__(self, engine, attack_params: dict): super().__init__( engine, - engine._cm, + CommunicationsManager.get_instance(), "send_model", round_start, round_stop, diff --git a/nebula/addons/attacks/communications/floodingattack.py b/nebula/addons/attacks/communications/floodingattack.py index 0d9c070fc..fa308fe27 100644 --- a/nebula/addons/attacks/communications/floodingattack.py +++ b/nebula/addons/attacks/communications/floodingattack.py @@ -2,6 +2,7 @@ from functools import wraps from nebula.addons.attacks.communications.communicationattack import CommunicationAttack +from nebula.core.network.communications import CommunicationsManager class FloodingAttack(CommunicationAttack): @@ -33,7 +34,7 @@ def __init__(self, engine, attack_params: dict): super().__init__( engine, - engine._cm, + CommunicationsManager.get_instance(), "send_message", round_start, round_stop, diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index 5f7f0d063..ef82b09e1 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -13,7 +13,7 @@ class Mobility: - def __init__(self, config, cm: "CommunicationsManager", verbose=False): + def __init__(self, config, verbose=False): """ Initializes the mobility 
module with specified configuration and communication manager. @@ -50,7 +50,6 @@ def __init__(self, config, cm: "CommunicationsManager", verbose=False): """ logging.info("Starting mobility module...") self.config = config - self.cm = cm self.grace_time = self.config.participant["mobility_args"]["grace_time_mobility"] self.period = self.config.participant["mobility_args"]["change_geo_interval"] self.mobility = self.config.participant["mobility_args"]["mobility"] @@ -70,6 +69,10 @@ def __init__(self, config, cm: "CommunicationsManager", verbose=False): self._nodes_distances_lock = Locker("nodes_distances_lock", async_lock=True) self._verbose = verbose + @property + def cm(self): + return CommunicationsManager.get_instance() + @property def round(self): """ diff --git a/nebula/addons/networksimulation/nebulanetworksimulator.py b/nebula/addons/networksimulation/nebulanetworksimulator.py index a85b57bf0..05e25d6b7 100644 --- a/nebula/addons/networksimulation/nebulanetworksimulator.py +++ b/nebula/addons/networksimulation/nebulanetworksimulator.py @@ -18,8 +18,7 @@ class NebulaNS(NetworkSimulator): } IP_MULTICAST = "239.255.255.250" - def __init__(self, communication_manager: "CommunicationsManager", changing_interval, interface, verbose=False): - self._cm = communication_manager + def __init__(self, changing_interval, interface, verbose=False): self._refresh_interval = changing_interval self._node_interface = interface self._verbose = verbose @@ -27,11 +26,15 @@ def __init__(self, communication_manager: "CommunicationsManager", changing_inte self._network_conditions_lock = Locker("network_conditions_lock", async_lock=True) self._current_network_conditions = {} self._running = False + + @property + def cm(self): + return CommunicationsManager.get_instance() async def start(self): logging.info("🌐 Nebula Network Simulator starting...") self._running = True - grace_time = self._cm.config.participant["mobility_args"]["grace_time_mobility"] + grace_time = 
self.cm.config.participant["mobility_args"]["grace_time_mobility"] # if self._verbose: logging.info(f"Waiting {grace_time}s to start applying network conditions based on distances between devices") # await asyncio.sleep(grace_time) await EventManager.get_instance().subscribe_addonevent(GPSEvent, self._change_network_conditions_based_on_distances) diff --git a/nebula/addons/networksimulation/networksimulator.py b/nebula/addons/networksimulation/networksimulator.py index ab63378c8..3246712bb 100644 --- a/nebula/addons/networksimulation/networksimulator.py +++ b/nebula/addons/networksimulation/networksimulator.py @@ -27,7 +27,7 @@ def clear_network_conditions(self, interface): class NetworkSimulatorException(Exception): pass -def factory_network_simulator(net_sim, communication_manager, changing_interval, interface, verbose) -> NetworkSimulator: +def factory_network_simulator(net_sim, changing_interval, interface, verbose) -> NetworkSimulator: from nebula.addons.networksimulation.nebulanetworksimulator import NebulaNS SIMULATION_SERVICES = { @@ -37,6 +37,6 @@ def factory_network_simulator(net_sim, communication_manager, changing_interval, net_serv = SIMULATION_SERVICES.get(net_sim, NebulaNS) if net_serv: - return net_serv(communication_manager, changing_interval, interface, verbose) + return net_serv(changing_interval, interface, verbose) else: raise NetworkSimulatorException(f"Network Simulator {net_sim} not found") \ No newline at end of file diff --git a/nebula/addons/reporter.py b/nebula/addons/reporter.py index 7570c359e..b4ea2cb65 100755 --- a/nebula/addons/reporter.py +++ b/nebula/addons/reporter.py @@ -14,7 +14,7 @@ class Reporter: - def __init__(self, config, trainer, cm: "CommunicationsManager"): + def __init__(self, config, trainer): """ Initializes the reporter module for sending periodic updates to a dashboard controller. 
@@ -50,7 +50,6 @@ def __init__(self, config, trainer, cm: "CommunicationsManager"): logging.info("Starting reporter module") self.config = config self.trainer = trainer - self.cm = cm self.frequency = self.config.participant["reporter_args"]["report_frequency"] self.grace_time = self.config.participant["reporter_args"]["grace_time_reporter"] self.data_queue = asyncio.Queue() @@ -68,6 +67,10 @@ def __init__(self, config, trainer, cm: "CommunicationsManager"): self.acc_packets_sent = 0 self.acc_packets_recv = 0 + @property + def cm(self): + return CommunicationsManager.get_instance() + async def enqueue_data(self, name, value): """ Asynchronously enqueues data for reporting. diff --git a/nebula/core/addonmanager.py b/nebula/core/addonmanager.py index fd712114d..b729ed477 100644 --- a/nebula/core/addonmanager.py +++ b/nebula/core/addonmanager.py @@ -17,12 +17,12 @@ def __init__(self, engine : "Engine", config): async def deploy_additional_services(self): print_msg_box(msg="Deploying Additional Services", indent=2, title="Addons Manager") if self._config.participant["mobility_args"]["mobility"]: - mobility = Mobility(self._config, self._engine.cm, verbose=False) + mobility = Mobility(self._config, verbose=False) self._addons.append(mobility) if self._config.participant["network_args"]["simulation"]: refresh_conditions_interval = 5 - network_simulation = factory_network_simulator("nebula", self._engine.cm, refresh_conditions_interval, "eth0", verbose=False) + network_simulation = factory_network_simulator("nebula", refresh_conditions_interval, "eth0", verbose=False) self._addons.append(network_simulation) update_interval = 5 diff --git a/nebula/core/datasets/nebuladataset.py b/nebula/core/datasets/nebuladataset.py index 0222b701e..ec8c0b675 100755 --- a/nebula/core/datasets/nebuladataset.py +++ b/nebula/core/datasets/nebuladataset.py @@ -315,7 +315,8 @@ def data_partitioning(self, plot=False): f"Partitioning data for {self.__class__.__name__} | Partitions: 
{self.partitions_number} | IID: {self.iid} | Partition: {self.partition} | Partition parameter: {self.partition_parameter}" ) - self.iid = "a" #TODO REMOVE + #self.iid = "a" #TODO REMOVE + logging.info(f"Scenario with data distribution: {self.iid}") if self.iid == "IID": self.train_indices_map = self.generate_iid_map(self.train_set) elif self.iid == "Non-IID": diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 824a83a2c..befe40b66 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -142,11 +142,10 @@ def __init__( self.config.reload_config_file() - self._cm = CommunicationsManager(engine=self) + cm = CommunicationsManager(engine=self) # Set the communication manager in the model (send messages from there) - self.trainer.model.set_communication_manager(self._cm) - self._reporter = Reporter(config=self.config, trainer=self.trainer, cm=self.cm) + self._reporter = Reporter(config=self.config, trainer=self.trainer) self._sinchronized_status = True self.sinchronized_status_lock = Locker(name="sinchronized_status_lock") @@ -173,7 +172,7 @@ def __init__( @property def cm(self): - return self._cm + return CommunicationsManager.get_instance() @property def reporter(self): @@ -474,26 +473,12 @@ async def create_trainer_module(self): async def start_communications(self): await self.register_events_callbacks() await self.aggregator.init() - logging.info(f"Neighbors: {self.config.participant['network_args']['neighbors']}") - logging.info( - f"πŸ’€ Cold start time: {self.config.participant['misc_args']['grace_time_connection']} seconds before connecting to the network" - ) - await asyncio.sleep(self.config.participant["misc_args"]["grace_time_connection"]) - await self.cm.start() initial_neighbors = self.config.participant["network_args"]["neighbors"].split() - for i in initial_neighbors: - addr = f"{i.split(':')[0]}:{i.split(':')[1]}" - await self.cm.connect(addr, direct=True) - await asyncio.sleep(1) - while not 
self.cm.verify_connections(initial_neighbors): - await asyncio.sleep(1) - current_connections = await self.cm.get_addrs_current_connections() - logging.info(f"Connections verified: {current_connections}") + await self.cm.start_communications(initial_neighbors) if self.mobility: logging.info("Building NodeManager configurations...") await self.nm.set_configs() await self._reporter.start() - await self.cm.deploy_additional_services() await self._addon_manager.deploy_additional_services() await asyncio.sleep(self.config.participant["misc_args"]["grace_time_connection"] // 2) diff --git a/nebula/core/models/nebulamodel.py b/nebula/core/models/nebulamodel.py index 8e6eb6cfc..070616a95 100755 --- a/nebula/core/models/nebulamodel.py +++ b/nebula/core/models/nebulamodel.py @@ -200,14 +200,6 @@ def __init__( self._current_loss = -1 self._optimizer = None - def set_communication_manager(self, communication_manager): - self.communication_manager = communication_manager - - def get_communication_manager(self): - if self.communication_manager is None: - raise ValueError("Communication manager not set.") - return self.communication_manager - @abstractmethod def forward(self, x): """Forward pass of the model.""" diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index d11b296ef..19d581b45 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -24,7 +24,26 @@ class CommunicationsManager: + _instance = None + _lock = Locker("communications_manager_lock", async_lock=False) + + def __new__(cls, engine: "Engine"): + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + @classmethod + def get_instance(cls): + """Onbtain CommunicationsManager instance""" + if cls._instance is None: + raise ValueError("CommunicationsManager has not been initialized yet.") + return cls._instance + def __init__(self, engine: "Engine"): + if hasattr(self, '_initialized') 
and self._initialized: + return # Avoid reinicialization + logging.info("🌐 Initializing Communications Manager") self._engine = engine self.addr = engine.get_addr() @@ -47,16 +66,16 @@ def __init__(self, engine: "Engine"): self.outgoing_connections = {} self.ready_connections = set() - self._mm = MessagesManager(addr=self.addr, config=self.config, cm=self) + self._mm = MessagesManager(addr=self.addr, config=self.config) self.received_messages_hashes = collections.deque( maxlen=self.config.participant["message_args"]["max_local_messages"] ) self.receive_messages_lock = Locker(name="receive_messages_lock", async_lock=True) - self._discoverer = Discoverer(addr=self.addr, config=self.config, cm=self) - # self._health = Health(addr=self.addr, config=self.config, cm=self) - self._forwarder = Forwarder(config=self.config, cm=self) - self._propagator = Propagator(cm=self) + self._discoverer = Discoverer(addr=self.addr, config=self.config) + # self._health = Health(addr=self.addr, config=self.config) + self._forwarder = Forwarder(config=self.config) + self._propagator = Propagator() # List of connections to reconnect {addr: addr, tries: 0} self.connections_reconnect = [] @@ -71,7 +90,9 @@ def __init__(self, engine: "Engine"): self._blacklist = BlackList() # Connection service to communicate with external devices - self._external_connection_service = factory_connection_service("nebula", self, self.addr) + self._external_connection_service = factory_connection_service("nebula", self.addr) + + self._initialized = True @property def engine(self): @@ -120,6 +141,24 @@ async def check_federation_ready(self): async def add_ready_connection(self, addr): self.ready_connections.add(addr) + async def start_communications(self, initial_neighbors): + logging.info(f"Neighbors: {self.config.participant['network_args']['neighbors']}") + logging.info( + f"πŸ’€ Cold start time: {self.config.participant['misc_args']['grace_time_connection']} seconds before connecting to the network" + ) + 
await asyncio.sleep(self.config.participant["misc_args"]["grace_time_connection"]) + await self.start() + for i in initial_neighbors: + addr = f"{i.split(':')[0]}:{i.split(':')[1]}" + await self.connect(addr, direct=True) + await asyncio.sleep(1) + while not self.verify_connections(initial_neighbors): + await asyncio.sleep(1) + current_connections = await self.get_addrs_current_connections() + logging.info(f"Connections verified: {current_connections}") + await self.deploy_additional_services() + + """ ############################## # PROCESSING MESSAGES # ############################## @@ -363,7 +402,6 @@ async def process_connection(reader, writer): logging.info(f"πŸ”— [incoming] Creating new connection with {addr} (id {connected_node_id})") await writer.drain() connection = Connection( - self, reader, writer, connected_node_id, @@ -612,7 +650,6 @@ async def process_establish_connection(addr, direct, reconnect): f"πŸ”— [outgoing] Creating new connection with {host}:{port} (id {connected_node_id})" ) connection = Connection( - self, reader, writer, connected_node_id, diff --git a/nebula/core/network/connection.py b/nebula/core/network/connection.py index 8ea31e62b..bfe2b407d 100755 --- a/nebula/core/network/connection.py +++ b/nebula/core/network/connection.py @@ -32,7 +32,6 @@ class Connection: def __init__( self, - cm: "CommunicationsManager", reader, writer, id, @@ -44,7 +43,6 @@ def __init__( config=None, prio="MEDIUM", ): - self.cm = cm self.reader = reader self.writer = writer self.id = str(id) @@ -93,6 +91,10 @@ def __repr__(self): def __del__(self): self.stop() + @property + def cm(self): + return CommunicationsManager.get_instance() + def get_addr(self): return self.addr diff --git a/nebula/core/network/discoverer.py b/nebula/core/network/discoverer.py index b04d08af6..9df7c970e 100755 --- a/nebula/core/network/discoverer.py +++ b/nebula/core/network/discoverer.py @@ -9,15 +9,18 @@ class Discoverer: - def __init__(self, addr, config, cm: 
"CommunicationsManager"): + def __init__(self, addr, config): print_msg_box(msg="Starting discoverer module...", indent=2, title="Discoverer module") self.addr = addr self.config = config - self.cm = cm self.grace_time = self.config.participant["discoverer_args"]["grace_time_discovery"] self.period = self.config.participant["discoverer_args"]["discovery_frequency"] self.interval = self.config.participant["discoverer_args"]["discovery_interval"] + @property + def cm(self): + return CommunicationsManager.get_instance() + async def start(self): asyncio.create_task(self.run_discover()) diff --git a/nebula/core/network/externalconnection/externalconnectionservice.py b/nebula/core/network/externalconnection/externalconnectionservice.py index cd356a324..1f30c7e30 100644 --- a/nebula/core/network/externalconnection/externalconnectionservice.py +++ b/nebula/core/network/externalconnection/externalconnectionservice.py @@ -34,7 +34,7 @@ async def modify_beacon_frequency(self, frequency): class ExternalConnectionServiceException(Exception): pass -def factory_connection_service(con_serv, cm, addr) -> ExternalConnectionService: +def factory_connection_service(con_serv, addr) -> ExternalConnectionService: from nebula.core.network.externalconnection.nebuladiscoveryservice import NebulaConnectionService CONNECTION_SERVICES = { @@ -44,6 +44,6 @@ def factory_connection_service(con_serv, cm, addr) -> ExternalConnectionService: con_serv = CONNECTION_SERVICES.get(con_serv, NebulaConnectionService) if con_serv: - return con_serv(cm, addr) + return con_serv(addr) else: raise ExternalConnectionServiceException(f"Connection Service {con_serv} not found") \ No newline at end of file diff --git a/nebula/core/network/externalconnection/nebuladiscoveryservice.py b/nebula/core/network/externalconnection/nebuladiscoveryservice.py index e48c5c1d4..6285487ed 100644 --- a/nebula/core/network/externalconnection/nebuladiscoveryservice.py +++ 
b/nebula/core/network/externalconnection/nebuladiscoveryservice.py @@ -162,8 +162,7 @@ async def send_beacon(self): logging.error(f"Error sending beacon: {e}") class NebulaConnectionService(ExternalConnectionService): - def __init__(self, cm: "CommunicationsManager", addr): - self._cm = cm + def __init__(self, addr): self.nodes_found = set() self.addr = addr self.server : NebulaServerProtocol = None @@ -173,7 +172,7 @@ def __init__(self, cm: "CommunicationsManager", addr): @property def cm(self): - return self._cm + return CommunicationsManager.get_instance() async def start(self): loop = asyncio.get_running_loop() diff --git a/nebula/core/network/forwarder.py b/nebula/core/network/forwarder.py index dc31436e0..d7febd4df 100755 --- a/nebula/core/network/forwarder.py +++ b/nebula/core/network/forwarder.py @@ -11,10 +11,9 @@ class Forwarder: - def __init__(self, config, cm: "CommunicationsManager"): + def __init__(self, config): print_msg_box(msg="Starting forwarder module...", indent=2, title="Forwarder module") self.config = config - self.cm = cm self.pending_messages = asyncio.Queue() self.pending_messages_lock = Locker("pending_messages_lock", verbose=False, async_lock=True) @@ -22,6 +21,10 @@ def __init__(self, config, cm: "CommunicationsManager"): self.number_forwarded_messages = self.config.participant["forwarder_args"]["number_forwarded_messages"] self.messages_interval = self.config.participant["forwarder_args"]["forward_messages_interval"] + @property + def cm(self): + return CommunicationsManager.get_instance() + async def start(self): asyncio.create_task(self.run_forwarder()) diff --git a/nebula/core/network/health.py b/nebula/core/network/health.py index 5c9e9da1f..c17075e8c 100755 --- a/nebula/core/network/health.py +++ b/nebula/core/network/health.py @@ -14,12 +14,15 @@ def __init__(self, addr, config, cm: "CommunicationsManager"): print_msg_box(msg="Starting health module...", indent=2, title="Health module") self.addr = addr self.config = config - 
self.cm = cm self.period = self.config.participant["health_args"]["health_interval"] self.alive_interval = self.config.participant["health_args"]["send_alive_interval"] self.check_alive_interval = self.config.participant["health_args"]["check_alive_interval"] self.timeout = self.config.participant["health_args"]["alive_timeout"] + @property + def cm(self): + return CommunicationsManager.get_instance() + async def start(self): asyncio.create_task(self.run_send_alive()) asyncio.create_task(self.run_check_alive()) diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index 3ce3c3c62..aef57f4e2 100644 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -12,13 +12,16 @@ class MessagesManager: - def __init__(self, addr, config, cm: "CommunicationsManager"): + def __init__(self, addr, config): self.addr = addr self.config = config - self.cm = cm self._message_templates = {} self._define_message_templates() + @property + def cm(self): + return CommunicationsManager.get_instance() + def _define_message_templates(self): # Dictionary that maps message types to their required parameters and default values self._message_templates = { diff --git a/nebula/core/network/propagator.py b/nebula/core/network/propagator.py index 1907c03f5..7fccd3a3b 100755 --- a/nebula/core/network/propagator.py +++ b/nebula/core/network/propagator.py @@ -10,8 +10,8 @@ from nebula.config.config import Config from nebula.core.aggregation.aggregator import Aggregator from nebula.core.engine import Engine - from nebula.core.network.communications import CommunicationsManager from nebula.core.training.lightning import Lightning + from nebula.core.network.communications import CommunicationsManager class PropagationStrategy(ABC): @@ -63,11 +63,17 @@ def prepare_model_payload(self, node: str) -> tuple[Any, float] | None: class Propagator: - def __init__(self, cm: "CommunicationsManager"): - self.engine: Engine = cm.engine - self.config: Config = 
cm.get_config() - self.addr = cm.get_addr() - self.cm: CommunicationsManager = cm + def __init__(self): + pass + + @property + def cm(self): + return CommunicationsManager.get_instance() + + def start(self): + self.engine: Engine = CommunicationsManager.get_instance().engine + self.config: Config = CommunicationsManager.get_instance().get_config() + self.addr = CommunicationsManager.get_instance().get_addr() self.aggregator: Aggregator = self.engine.aggregator self.trainer: Lightning = self.engine._trainer @@ -83,8 +89,6 @@ def __init__(self, cm: "CommunicationsManager"): "initialization": InitialModelPropagation(self.aggregator, self.trainer, self.engine), "stable": StableModelPropagation(self.aggregator, self.trainer, self.engine), } - - def start(self): print_msg_box( msg="Starting propagator functionality...\nModel propagation through the network", indent=2, diff --git a/nebula/core/situationalawareness/awareness/commands/connectivitycommand.py b/nebula/core/situationalawareness/awareness/commands/connectivitycommand.py new file mode 100644 index 000000000..39dd399f7 --- /dev/null +++ b/nebula/core/situationalawareness/awareness/commands/connectivitycommand.py @@ -0,0 +1,18 @@ +import asyncio +import logging +from nebula.core.utils.locker import Locker +from nebula.core.situationalawareness.awareness.commands.sacommand import SACommand, SACommandAction, SACOmmandPRIO, SACommandType + +class ConnectivityCommand(SACommand): + """Commands related to connectivity.""" + def __init__( + self, + action: SACommandAction, + target: str, + priority: SACOmmandPRIO = SACOmmandPRIO.MEDIUM.name, + paralelelizable = False + ): + super().__init__(SACommandType.CONNECTIVITY, action, target, priority, paralelelizable) + + async def execute(self): + return await super().execute() diff --git a/nebula/core/situationalawareness/awareness/commands/sacommand.py b/nebula/core/situationalawareness/awareness/commands/sacommand.py new file mode 100644 index 000000000..365f7c067 --- 
/dev/null +++ b/nebula/core/situationalawareness/awareness/commands/sacommand.py @@ -0,0 +1,78 @@ +from nebula.core.utils.locker import Locker +from abc import ABC, abstractmethod +from enum import Enum + +# ----------------------------------------------- +# ENUMS for types of SACommands +# ----------------------------------------------- +class SACommandType(Enum): + CONNECTIVITY = "Connectivity" + AGGREGATION = "Aggregation" + +# ----------------------------------------------- +# ENUM for available actions +# ----------------------------------------------- +class SACommandAction(Enum): + DISCONNECT = "disconnect" + RECONNECT = "reconnect" + SEARCH_CONNECTIONS = "search_connections" + ADJUST_WEIGHT = "adjust_weight" + DISCARD_UPDATE = "discard_update" + +class SACOmmandPRIO(Enum): + HIGH = 10 + MEDIUM = 5 + LOW = 1 + + """ ############################### + # SA COMMAND CLASS # + ############################### + """ + +class SACommand: + """Base class for Situational Awareness module commands.""" + def __init__( + self, + command_type: SACommandType, + action: SACommandAction, + target: str, + priority: SACOmmandPRIO = SACOmmandPRIO.MEDIUM, + paralelelizable = False + ): + self._command_type = command_type + self._action = action + self._target = target # Could be a node, parameter, etc. 
+ self._priority = priority + self._parallelizable = paralelelizable + + @abstractmethod + async def execute(self): + raise NotImplementedError + + def is_parallelizable(self): + return self._parallelizable + + def conflicts_with(self, other: "SACommand") -> bool: + """Determines if two commands conflict with each other.""" + if self._target == other._target: + conflict_pairs = [ + {SACommandAction.DISCONNECT, SACommandAction.RECONNECT}, + {SACommandAction.ADJUST_WEIGHT, SACommandAction.DISCARD_UPDATE} + ] + return {self._action, other._action} in conflict_pairs + return False + + def __repr__(self): + return (f"{self.__class__.__name__}(Type={self._command_type.value}, " + f"Action={self._action.value}, Target={self._target}, Priority={self._priority})") + + +def factory_sa_command(sacommand_type, *config) -> SACommand: + from nebula.core.situationalawareness.awareness.commands.connectivitycommand import ConnectivityCommand + + options = { + "connectivity": ConnectivityCommand, + } + + cs = options.get(sacommand_type, None) + return cs(*config) diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index cc4f0f9bf..cfef3cb8f 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.situationalawareness.nodemanager import NodeManager + from nebula.core.network.communications import CommunicationsManager RESTRUCTURE_COOLDOWN = 5 @@ -30,7 +31,7 @@ def __init__( self._addr = addr self._topology = topology self._node_manager: NodeManager = nodemanager - self._situational_awareness_network = SANetwork(self, self.cm, self._addr, self._topology) + self._situational_awareness_network = SANetwork(self, self._addr, self._topology) self._situational_awareness_training = SATraining(self, self._addr, "qds", "fastreboot", verbose=True) 
self._restructure_process_lock = Locker(name="restructure_process_lock") self._restructure_cooldown = 0 @@ -49,7 +50,7 @@ def sat(self): @property def cm(self): - return self.nm.engine.cm + return CommunicationsManager.get_instance() async def init(self): diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index 6d7a84f22..d9e05067c 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -17,7 +17,6 @@ class SANetwork(): def __init__( self, sam: "SAModule", - communication_manager: "CommunicationsManager", addr, topology, strict_topology=True, @@ -29,7 +28,6 @@ def __init__( title="Network SA module", ) self._sam = sam - self._cm = communication_manager self._addr = addr self._topology = topology self._strict_topology = strict_topology @@ -44,7 +42,7 @@ def sam(self): @property def cm(self): - return self._cm + return CommunicationsManager.get_instance() @property def np(self): diff --git a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py index 48631bd58..940de40b3 100644 --- a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -5,11 +5,11 @@ from nebula.addons.functions import print_msg_box from nebula.core.situationalawareness.candidateselection.candidateselector import factory_CandidateSelector from nebula.core.situationalawareness.modelhandlers.modelhandler import factory_ModelHandler -from nebula.core.situationalawareness.momentum import Momentum from nebula.core.situationalawareness.awareness.samodule import SAModule from nebula.core.utils.locker import Locker from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import UpdateNeighborEvent, NodeFoundEvent +from nebula.core.network.communications import CommunicationsManager if TYPE_CHECKING: 
from nebula.core.engine import Engine @@ -56,7 +56,7 @@ def engine(self): @property def cm(self): - return self._engine.cm + return CommunicationsManager.get_instance() @property def candidate_selector(self): @@ -153,7 +153,7 @@ async def waiting_confirmation_from(self, addr): async def confirmation_received(self, addr, confirmation=False): logging.info(f" Update | connection confirmation received from: {addr} | confirmation: {confirmation}") if confirmation: - await self.engine.cm.connect(addr, direct=True) + await self.cm.connect(addr, direct=True) await self.update_neighbors(addr) else: self._remove_pending_confirmation_from(addr) @@ -210,12 +210,12 @@ async def stop_not_selected_connections(self): if len(self.discarded_offers_addr) > 0: self.discarded_offers_addr = set( self.discarded_offers_addr - ) - await self.engine.cm.get_addrs_current_connections(only_direct=True, myself=False) + ) - await self.cm.get_addrs_current_connections(only_direct=True, myself=False) logging.info( f"Interrupting connections | discarded offers | nodes discarded: {self.discarded_offers_addr}" ) for addr in self.discarded_offers_addr: - await self.engine.cm.disconnect(addr, mutual_disconnection=True) + await self.cm.disconnect(addr, mutual_disconnection=True) await asyncio.sleep(1) self.discarded_offers_addr = [] except asyncio.CancelledError: @@ -238,7 +238,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove await self.clear_pending_confirmations() # find federation and send discover - connections_stablished = await self.engine.cm.stablish_connection_to_federation(msg_type, addrs_known) + connections_stablished = await self.cm.stablish_connection_to_federation(msg_type, addrs_known) # wait offer #TODO actualizar con la informacion de latencias @@ -254,9 +254,9 @@ async def start_late_connection_process(self, connected=False, msg_type="discove logging.info("Candidates found to connect to...") # create message to send to candidates selected if not 
connected: - msg = self.engine.cm.create_message("connection", "late_connect") + msg = self.cm.create_message("connection", "late_connect") else: - msg = self.engine.cm.create_message("connection", "restructure") + msg = self.cm.create_message("connection", "restructure") best_candidates = self.candidate_selector.select_candidates() logging.info(f"Candidates | {[addr for addr, _, _ in best_candidates]}") @@ -264,7 +264,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove try: for addr, _, _ in best_candidates: await self.add_pending_connection_confirmation(addr) - await self.engine.cm.send_message(addr, msg) + await self.cm.send_message(addr, msg) await asyncio.sleep(1) except asyncio.CancelledError: await self.update_neighbors(addr, remove=True) From db182ad6e6456f31c2ecb51583de599e331b1247 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 2 Apr 2025 15:57:30 +0200 Subject: [PATCH 150/233] fix communciation manager importatition --- nebula/addons/mobility.py | 4 +--- .../addons/networksimulation/nebulanetworksimulator.py | 4 +--- nebula/addons/reporter.py | 1 + nebula/core/network/communications.py | 3 ++- nebula/core/network/connection.py | 1 + nebula/core/network/discoverer.py | 7 +------ .../network/externalconnection/nebuladiscoveryservice.py | 4 +--- nebula/core/network/forwarder.py | 7 +------ nebula/core/network/health.py | 9 ++------- nebula/core/network/messages.py | 5 +---- nebula/core/network/propagator.py | 8 ++++---- .../awareness/commands/connectivitycommand.py | 4 ++-- .../situationalawareness/awareness/commands/sacommand.py | 4 ++-- nebula/core/situationalawareness/awareness/samodule.py | 4 +++- .../awareness/sanetwork/sanetwork.py | 5 +++-- 15 files changed, 26 insertions(+), 44 deletions(-) diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index ef82b09e1..9e49948fd 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -7,9 +7,6 @@ from nebula.core.nebulaevents import 
GPSEvent from nebula.core.utils.locker import Locker from nebula.addons.functions import print_msg_box -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from nebula.core.network.communications import CommunicationsManager class Mobility: @@ -71,6 +68,7 @@ def __init__(self, config, verbose=False): @property def cm(self): + from nebula.core.network.communications import CommunicationsManager return CommunicationsManager.get_instance() @property diff --git a/nebula/addons/networksimulation/nebulanetworksimulator.py b/nebula/addons/networksimulation/nebulanetworksimulator.py index 05e25d6b7..ee8023f5e 100644 --- a/nebula/addons/networksimulation/nebulanetworksimulator.py +++ b/nebula/addons/networksimulation/nebulanetworksimulator.py @@ -5,9 +5,6 @@ from nebula.core.utils.locker import Locker from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import GPSEvent -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from nebula.core.network.communications import CommunicationsManager class NebulaNS(NetworkSimulator): NETWORK_CONDITIONS = { @@ -29,6 +26,7 @@ def __init__(self, changing_interval, interface, verbose=False): @property def cm(self): + from nebula.core.network.communications import CommunicationsManager return CommunicationsManager.get_instance() async def start(self): diff --git a/nebula/addons/reporter.py b/nebula/addons/reporter.py index b4ea2cb65..ef59f0360 100755 --- a/nebula/addons/reporter.py +++ b/nebula/addons/reporter.py @@ -69,6 +69,7 @@ def __init__(self, config, trainer): @property def cm(self): + from nebula.core.network.communications import CommunicationsManager return CommunicationsManager.get_instance() async def enqueue_data(self, name, value): diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 19d581b45..2c3391833 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -35,7 +35,7 @@ def __new__(cls, engine: "Engine"): 
@classmethod def get_instance(cls): - """Onbtain CommunicationsManager instance""" + """Obtain CommunicationsManager instance""" if cls._instance is None: raise ValueError("CommunicationsManager has not been initialized yet.") return cls._instance @@ -93,6 +93,7 @@ def __init__(self, engine: "Engine"): self._external_connection_service = factory_connection_service("nebula", self.addr) self._initialized = True + logging.info("Communication Manager initialized completed") @property def engine(self): diff --git a/nebula/core/network/connection.py b/nebula/core/network/connection.py index bfe2b407d..767002526 100755 --- a/nebula/core/network/connection.py +++ b/nebula/core/network/connection.py @@ -93,6 +93,7 @@ def __del__(self): @property def cm(self): + from nebula.core.network.communications import CommunicationsManager return CommunicationsManager.get_instance() def get_addr(self): diff --git a/nebula/core/network/discoverer.py b/nebula/core/network/discoverer.py index 9df7c970e..b6d646ff4 100755 --- a/nebula/core/network/discoverer.py +++ b/nebula/core/network/discoverer.py @@ -1,13 +1,7 @@ import asyncio import logging -from typing import TYPE_CHECKING - from nebula.addons.functions import print_msg_box -if TYPE_CHECKING: - from nebula.core.network.communications import CommunicationsManager - - class Discoverer: def __init__(self, addr, config): print_msg_box(msg="Starting discoverer module...", indent=2, title="Discoverer module") @@ -19,6 +13,7 @@ def __init__(self, addr, config): @property def cm(self): + from nebula.core.network.communications import CommunicationsManager return CommunicationsManager.get_instance() async def start(self): diff --git a/nebula/core/network/externalconnection/nebuladiscoveryservice.py b/nebula/core/network/externalconnection/nebuladiscoveryservice.py index 6285487ed..939caa804 100644 --- a/nebula/core/network/externalconnection/nebuladiscoveryservice.py +++ b/nebula/core/network/externalconnection/nebuladiscoveryservice.py @@ 
-7,9 +7,6 @@ from nebula.core.nebulaevents import BeaconRecievedEvent from nebula.core.eventmanager import EventManager -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from nebula.core.network.communications import CommunicationsManager class NebulaServerProtocol(asyncio.DatagramProtocol): BCAST_IP = '239.255.255.250' @@ -172,6 +169,7 @@ def __init__(self, addr): @property def cm(self): + from nebula.core.network.communications import CommunicationsManager return CommunicationsManager.get_instance() async def start(self): diff --git a/nebula/core/network/forwarder.py b/nebula/core/network/forwarder.py index d7febd4df..380b7d3c2 100755 --- a/nebula/core/network/forwarder.py +++ b/nebula/core/network/forwarder.py @@ -1,15 +1,9 @@ import asyncio import logging import time -from typing import TYPE_CHECKING - from nebula.addons.functions import print_msg_box from nebula.core.utils.locker import Locker -if TYPE_CHECKING: - from nebula.core.network.communications import CommunicationsManager - - class Forwarder: def __init__(self, config): print_msg_box(msg="Starting forwarder module...", indent=2, title="Forwarder module") @@ -23,6 +17,7 @@ def __init__(self, config): @property def cm(self): + from nebula.core.network.communications import CommunicationsManager return CommunicationsManager.get_instance() async def start(self): diff --git a/nebula/core/network/health.py b/nebula/core/network/health.py index c17075e8c..9669b888f 100755 --- a/nebula/core/network/health.py +++ b/nebula/core/network/health.py @@ -1,16 +1,10 @@ import asyncio import logging import time -from typing import TYPE_CHECKING - from nebula.addons.functions import print_msg_box -if TYPE_CHECKING: - from nebula.core.network.communications import CommunicationsManager - - class Health: - def __init__(self, addr, config, cm: "CommunicationsManager"): + def __init__(self, addr, config): print_msg_box(msg="Starting health module...", indent=2, title="Health module") self.addr = addr self.config = 
config @@ -21,6 +15,7 @@ def __init__(self, addr, config, cm: "CommunicationsManager"): @property def cm(self): + from nebula.core.network.communications import CommunicationsManager return CommunicationsManager.get_instance() async def start(self): diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index aef57f4e2..f23674ad5 100644 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -7,10 +7,6 @@ from nebula.core.network.actions import factory_message_action, get_action_name_from_value, get_actions_names from nebula.core.pb import nebula_pb2 -if TYPE_CHECKING: - from nebula.core.network.communications import CommunicationsManager - - class MessagesManager: def __init__(self, addr, config): self.addr = addr @@ -20,6 +16,7 @@ def __init__(self, addr, config): @property def cm(self): + from nebula.core.network.communications import CommunicationsManager return CommunicationsManager.get_instance() def _define_message_templates(self): diff --git a/nebula/core/network/propagator.py b/nebula/core/network/propagator.py index 7fccd3a3b..91f050061 100755 --- a/nebula/core/network/propagator.py +++ b/nebula/core/network/propagator.py @@ -11,7 +11,6 @@ from nebula.core.aggregation.aggregator import Aggregator from nebula.core.engine import Engine from nebula.core.training.lightning import Lightning - from nebula.core.network.communications import CommunicationsManager class PropagationStrategy(ABC): @@ -68,12 +67,13 @@ def __init__(self): @property def cm(self): + from nebula.core.network.communications import CommunicationsManager return CommunicationsManager.get_instance() def start(self): - self.engine: Engine = CommunicationsManager.get_instance().engine - self.config: Config = CommunicationsManager.get_instance().get_config() - self.addr = CommunicationsManager.get_instance().get_addr() + self.engine: Engine = self.cm.engine + self.config: Config = self.cm.get_config() + self.addr = self.cm.get_addr() 
self.aggregator: Aggregator = self.engine.aggregator self.trainer: Lightning = self.engine._trainer diff --git a/nebula/core/situationalawareness/awareness/commands/connectivitycommand.py b/nebula/core/situationalawareness/awareness/commands/connectivitycommand.py index 39dd399f7..0520b6269 100644 --- a/nebula/core/situationalawareness/awareness/commands/connectivitycommand.py +++ b/nebula/core/situationalawareness/awareness/commands/connectivitycommand.py @@ -10,9 +10,9 @@ def __init__( action: SACommandAction, target: str, priority: SACOmmandPRIO = SACOmmandPRIO.MEDIUM.name, - paralelelizable = False + parallelizable = False ): - super().__init__(SACommandType.CONNECTIVITY, action, target, priority, paralelelizable) + super().__init__(SACommandType.CONNECTIVITY, action, target, priority, parallelizable) async def execute(self): return await super().execute() diff --git a/nebula/core/situationalawareness/awareness/commands/sacommand.py b/nebula/core/situationalawareness/awareness/commands/sacommand.py index 365f7c067..0c38fed9a 100644 --- a/nebula/core/situationalawareness/awareness/commands/sacommand.py +++ b/nebula/core/situationalawareness/awareness/commands/sacommand.py @@ -37,13 +37,13 @@ def __init__( action: SACommandAction, target: str, priority: SACOmmandPRIO = SACOmmandPRIO.MEDIUM, - paralelelizable = False + parallelizable = False ): self._command_type = command_type self._action = action self._target = target # Could be a node, parameter, etc. 
self._priority = priority - self._parallelizable = paralelelizable + self._parallelizable = parallelizable @abstractmethod async def execute(self): diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index cfef3cb8f..108fceb61 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -7,10 +7,12 @@ from nebula.core.nebulaevents import RoundEndEvent from nebula.core.eventmanager import EventManager +from nebula.core.network.communications import CommunicationsManager + from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.situationalawareness.nodemanager import NodeManager - from nebula.core.network.communications import CommunicationsManager + RESTRUCTURE_COOLDOWN = 5 diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index d9e05067c..968643f5f 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -6,9 +6,10 @@ from nebula.core.nebulaevents import BeaconRecievedEvent from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import NodeFoundEvent, UpdateNeighborEvent +from nebula.core.network.communications import CommunicationsManager + from typing import TYPE_CHECKING if TYPE_CHECKING: - from nebula.core.network.communications import CommunicationsManager from nebula.core.situationalawareness.awareness.samodule import SAModule RESTRUCTURE_COOLDOWN = 5 @@ -42,7 +43,7 @@ def sam(self): @property def cm(self): - return CommunicationsManager.get_instance() + return CommunicationsManager.get_instance() @property def np(self): From 2babfd4e5945765047e245a170a2b24602f6a9fd Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 2 Apr 2025 17:09:32 +0200 Subject: [PATCH 151/233] feature 
connectivity commands --- .../awareness/commands/connectivitycommand.py | 18 ---- .../awareness/commands/sacommand.py | 85 ++++++++++++++++--- 2 files changed, 72 insertions(+), 31 deletions(-) delete mode 100644 nebula/core/situationalawareness/awareness/commands/connectivitycommand.py diff --git a/nebula/core/situationalawareness/awareness/commands/connectivitycommand.py b/nebula/core/situationalawareness/awareness/commands/connectivitycommand.py deleted file mode 100644 index 0520b6269..000000000 --- a/nebula/core/situationalawareness/awareness/commands/connectivitycommand.py +++ /dev/null @@ -1,18 +0,0 @@ -import asyncio -import logging -from nebula.core.utils.locker import Locker -from nebula.core.situationalawareness.awareness.commands.sacommand import SACommand, SACommandAction, SACOmmandPRIO, SACommandType - -class ConnectivityCommand(SACommand): - """Commands related to connectivity.""" - def __init__( - self, - action: SACommandAction, - target: str, - priority: SACOmmandPRIO = SACOmmandPRIO.MEDIUM.name, - parallelizable = False - ): - super().__init__(SACommandType.CONNECTIVITY, action, target, priority, parallelizable) - - async def execute(self): - return await super().execute() diff --git a/nebula/core/situationalawareness/awareness/commands/sacommand.py b/nebula/core/situationalawareness/awareness/commands/sacommand.py index 0c38fed9a..ee6b63d63 100644 --- a/nebula/core/situationalawareness/awareness/commands/sacommand.py +++ b/nebula/core/situationalawareness/awareness/commands/sacommand.py @@ -1,6 +1,7 @@ from nebula.core.utils.locker import Locker from abc import ABC, abstractmethod from enum import Enum +import asyncio # ----------------------------------------------- # ENUMS for types of SACommands @@ -19,7 +20,7 @@ class SACommandAction(Enum): ADJUST_WEIGHT = "adjust_weight" DISCARD_UPDATE = "discard_update" -class SACOmmandPRIO(Enum): +class SACommandPRIO(Enum): HIGH = 10 MEDIUM = 5 LOW = 1 @@ -35,8 +36,8 @@ def __init__( self, command_type: 
SACommandType, action: SACommandAction, - target: str, - priority: SACOmmandPRIO = SACOmmandPRIO.MEDIUM, + target, + priority: SACommandPRIO = SACommandPRIO.MEDIUM, parallelizable = False ): self._command_type = command_type @@ -49,30 +50,88 @@ def __init__( async def execute(self): raise NotImplementedError + @abstractmethod + async def conflicts_with(self, other: "SACommand") -> bool: + raise NotImplementedError + def is_parallelizable(self): return self._parallelizable - def conflicts_with(self, other: "SACommand") -> bool: + def __repr__(self): + return (f"{self.__class__.__name__}(Type={self._command_type.value}, " + f"Action={self._action.value}, Target={self._target}, Priority={self._priority})") + + """ ############################### + # SA COMMAND SUBCLASS # + ############################### + """ +class ConnectivityCommand(SACommand): + """Commands related to connectivity.""" + def __init__( + self, + action: SACommandAction, + target: str, + priority: SACommandPRIO = SACommandPRIO.MEDIUM, + parallelizable = False, + action_function = None, + *args + ): + super().__init__(SACommandType.CONNECTIVITY, action, target, priority, parallelizable) + self._action_function = action_function + self._args = args + + async def execute(self): + """Executes the assigned action function with the given parameters.""" + if self._action_function: + if asyncio.iscoroutinefunction(self._action_function): + await self._action_function(*self._args) + else: + self._action_function(*self._args) + + def conflicts_with(self, other: "ConnectivityCommand") -> bool: """Determines if two commands conflict with each other.""" if self._target == other._target: conflict_pairs = [ - {SACommandAction.DISCONNECT, SACommandAction.RECONNECT}, - {SACommandAction.ADJUST_WEIGHT, SACommandAction.DISCARD_UPDATE} + {SACommandAction.DISCONNECT, SACommandAction.RECONNECT} ] return {self._action, other._action} in conflict_pairs - return False + return False - def __repr__(self): - return 
(f"{self.__class__.__name__}(Type={self._command_type.value}, " - f"Action={self._action.value}, Target={self._target}, Priority={self._priority})") - +class AggregationCommand(SACommand): + """Commands related to data aggregation.""" + def __init__( + self, + action: SACommandAction, + target: dict, + priority: SACommandPRIO = SACommandPRIO.MEDIUM, + parallelizable = False, + ): + super().__init__(SACommandType.CONNECTIVITY, action, target, priority, parallelizable) + + async def execute(self): + return self._target -def factory_sa_command(sacommand_type, *config) -> SACommand: - from nebula.core.situationalawareness.awareness.commands.connectivitycommand import ConnectivityCommand + def conflicts_with(self, other: "AggregationCommand") -> bool: + """Determines if two commands conflict with each other.""" + if self._target == other._target: + conflict_pairs = [ + {SACommandAction.DISCONNECT, SACommandAction.RECONNECT} + ] + return {self._action, other._action} in conflict_pairs + return False + """ ############################### + # SA COMMAND FACTORY # + ############################### + """ + +def factory_sa_command(sacommand_type, *config) -> SACommand: options = { "connectivity": ConnectivityCommand, + "aggregation": AggregationCommand, } cs = options.get(sacommand_type, None) return cs(*config) + + From 05f2b4b6b8e959dfbf25793229f41be4de764702 Mon Sep 17 00:00:00 2001 From: FerTV Date: Wed, 2 Apr 2025 17:23:08 +0200 Subject: [PATCH 152/233] refactor previous created endpoints and scenarios endpoints added --- nebula/controller.py | 96 ++++++++++++++- nebula/frontend/app.py | 259 +++++++++++++++++++---------------------- 2 files changed, 213 insertions(+), 142 deletions(-) diff --git a/nebula/controller.py b/nebula/controller.py index 47705591f..fdb79f380 100755 --- a/nebula/controller.py +++ b/nebula/controller.py @@ -143,6 +143,23 @@ async def get_available_gpu(): except Exception: # noqa: S110 pass +@app.post("/scenarios/remove") +async def 
remove_scenario( + scenario_name: str = Body(..., embed=True) +): + """ + Controller endpoint to remove a scenario. + """ + from nebula.frontend.database import remove_scenario_by_name + + try: + remove_scenario_by_name(scenario_name) + except Exception as e: + logging.error(f"Error removing scenario {scenario_name}: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return {"message": f"Scenario {scenario_name} removed successfully"} + @app.get("/scenarios/{user}/{role}") async def get_scenarios( @@ -177,7 +194,82 @@ async def get_scenarios( return {"scenarios": scenarios, "scenario_running": scenario_running} -@app.get("/users/list") +@app.get("/scenarios/running") +async def get_running_scenario(get_all: bool = False): + """ + Controller endpoint to retrieve the running scenario. + """ + from nebula.frontend.database import get_running_scenario + + try: + return get_running_scenario(get_all=get_all) + except Exception as e: + logging.error(f"Error obtaining running scenario: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + +@app.get("/scenarios/check") +async def check_scenario(role: str, scenario_name: str): + """ + Controller endpoint to check if a scenario is allowed for a specific role. 
+ """ + from nebula.frontend.database import check_scenario_with_role + + try: + allowed = check_scenario_with_role(role, scenario_name) + return {"allowed": allowed} + except Exception as e: + logging.error(f"Error checking scenario with role: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + +@app.get("/scenarios/{scenario_name}") +async def get_scenario_by_name( + scenario_name: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=50, + description="Valid scenario name" + ) + ] +): + from nebula.frontend.database import get_scenario_by_name + + try: + scenario = get_scenario_by_name(scenario_name) + except Exception as e: + logging.error(f"Error obtaining scenario {scenario_name}: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return {"scenario": scenario} + + +@app.get("/scenarios/user/{scenario_name}") +async def get_user_by_scenario_name( + scenario_name: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=50, + description="Valid scenario name" + ) + ] +): + from nebula.frontend.database import get_user_by_scenario_name + + try: + user = get_user_by_scenario_name(scenario_name) + except Exception as e: + logging.error(f"Error obtaining user {user}: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return {"user": user} + + +@app.get("/user/list") async def list_users_controller(all_info: bool = False): """ Controller endpoint to retrieve the list of users. 
@@ -228,7 +320,7 @@ async def add_user_controller( @app.post("/user/delete") -async def add_user_controller( +async def remove_user_controller( user: str = Body(..., embed=True) ): """ diff --git a/nebula/frontend/app.py b/nebula/frontend/app.py index 8496b5514..8156fbff7 100755 --- a/nebula/frontend/app.py +++ b/nebula/frontend/app.py @@ -92,21 +92,21 @@ class Settings: from starlette.middleware.sessions import SessionMiddleware from nebula.frontend.database import ( - add_user, - check_scenario_with_role, + # add_user, + # check_scenario_with_role, delete_user_from_db, get_all_scenarios_and_check_completed, get_notes, get_running_scenario, - get_scenario_by_name, - get_user_by_scenario_name, + # get_scenario_by_name, + # get_user_by_scenario_name, get_user_info, initialize_databases, - list_nodes_by_scenario_name, - list_users, + # list_nodes_by_scenario_name, + # list_users, remove_nodes_by_scenario_name, remove_note, - remove_scenario_by_name, + # remove_scenario_by_name, save_notes, scenario_set_all_status_to_finished, scenario_set_status_to_finished, @@ -265,50 +265,72 @@ async def custom_http_exception_handler(request: Request, exc: StarletteHTTPExce return await request.app.default_exception_handler(request, exc) +async def get(url): + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status == 200: + return await response.json() + else: + raise HTTPException(status_code=response.status, detail="Error fetching data") + + +async def post(url, data=None): + async with aiohttp.ClientSession() as session: + async with session.post(url, json=data) as response: + if response.status == 200: + return await response.json() + else: + raise HTTPException(status_code=response.status, detail="Error posting data") + + async def get_available_gpus(): url = f"http://{settings.controller_host}:{settings.controller_port}/available_gpus" - async with aiohttp.ClientSession() as session, session.get(url) as response: - if 
response.status == 200: - try: - return await response.json() - except Exception as e: - return {"error": f"Failed to parse JSON: {e}"} - else: - return None + return await get(url) async def get_least_memory_gpu(): url = f"http://{settings.controller_host}:{settings.controller_port}/least_memory_gpu" - async with aiohttp.ClientSession() as session, session.get(url) as response: - if response.status == 200: - try: - return await response.json() - except Exception as e: - return {"error": f"Failed to parse JSON: {e}"} - else: - return None - - + return await get(url) + + async def get_scenarios(user, role): - try: - base_url = URL(f"http://{settings.controller_host}:{settings.controller_port}/scenarios") - url = base_url / user / role - logging.info(f"Requesting URL: {url}") - - async with aiohttp.ClientSession() as session: - async with session.get(url, timeout=10) as response: - if response.status == 200: - try: - return await response.json() - except Exception as e: - logging.error(f"Error parsing JSON: {e}") - return {"error": f"Error parsing JSON: {e}"} - else: - logging.error(f"Request error: {response.status}") - return {"error": f"Request error: {response.status}"} - except Exception as e: - logging.error(f"Unexpected error in call_get_scenarios: {e}") - return {"error": f"Unexpected error: {e}"} + url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/{user}/{role}" + return await get(url) + + +async def remove_scenario_by_name(scenario_name): + url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/remove/{scenario_name}" + data = {"scenario_name": scenario_name} + await post(url, data) + + +async def check_scenario_with_role(session, scenario_name): + url = ( + f"http://{settings.controller_host}:{settings.controller_port}" + f"/scenarios/check?role={session['role']}&scenario_name={scenario_name}" + ) + check_data = await get(url) + return check_data.get("allowed", False) + + +async def 
get_scenario_by_name(scenario_name): + url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/{scenario_name}" + return await get(url) + + +async def get_user_by_scenario_name(scenario_name): + url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/user/{scenario_name}" + return await get(url) + + +async def get_running_scenarios(get_all=False): + url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/running?get_all={get_all}" + return await get(url) + + +async def list_nodes_by_scenario_name(scenario_name): + url = f"http://{settings.controller_host}:{settings.controller_port}/nodes/{scenario_name}" + return await get(url) async def list_users(allinfo=True): @@ -321,24 +343,37 @@ async def list_users(allinfo=True): Returns: - A list of users, as provided by the controller. """ - controller_url = f"http://{settings.controller_host}:{settings.controller_port}/users/list?all_info={allinfo}" - try: - async with aiohttp.ClientSession() as client: - async with client.get(controller_url, timeout=10) as resp: - if resp.status != 200: - raise HTTPException( - status_code=resp.status, - detail="Error retrieving user list from controller" - ) - data = await resp.json() - user_list = data["users"] - except Exception as e: - logging.error(f"Error calling controller endpoint: {e}") - raise HTTPException(status_code=500, detail=f"Controller request failed: {e}") + url = f"http://{settings.controller_host}:{settings.controller_port}/user/list?all_info={allinfo}" + data = await get(url) + user_list = data["users"] return user_list +async def add_user(user, password, role): + url = f"http://{settings.controller_host}:{settings.controller_port}/user/add" + data = {"user": user, "password": password, "role": role} + await post(url, data) + + +async def update_user(user, password, role): + url = f"http://{settings.controller_host}:{settings.controller_port}/user/update" + data = {"user": user, "password": 
password, "role": role} + await post(url, data) + + +async def delete_user(user): + url = f"http://{settings.controller_host}:{settings.controller_port}/user/delete" + data = {"user": user} + await post(url, data) + + +async def verify_user(user, password): + url = f"http://{settings.controller_host}:{settings.controller_port}/user/verify" + data = {"user": user, "password": password} + return await post(url, data) + + @app.get("/", response_class=HTMLResponse) async def index(): return RedirectResponse(url="/platform") @@ -440,22 +475,10 @@ async def nebula_login( user: str = Form(...), password: str = Form(...), ): - controller_url = f"http://{settings.controller_host}:{settings.controller_port}/user/verify" - payload = {"user": user, "password": password} - try: - async with aiohttp.ClientSession() as client: - async with client.post(controller_url, json=payload, timeout=10) as resp: - if resp.status == 200: - # Successful response from the controller. - data = await resp.json() - session["user"] = data.get("user") - session["role"] = data.get("role") - return JSONResponse({"message": "Login successful"}, status_code=200) - else: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - except Exception as e: - logging.error(f"Error calling controller endpoint: {e}") - raise HTTPException(status_code=500, detail=f"Controller request failed: {e}") + data = await verify_user(user, password) + session["user"] = data.get("user") + session["role"] = data.get("role") + return JSONResponse({"message": "Login successful"}, status_code=200) @app.get("/platform/logout") @@ -472,24 +495,8 @@ async def nebula_delete_user(user: str, request: Request, session: dict = Depend if user == session["user"]: # Current user can't delete himself. 
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - controller_url = f"http://{settings.controller_host}:{settings.controller_port}/user/delete" - payload = {"user": user} - - try: - async with aiohttp.ClientSession() as client: - async with client.post(controller_url, json=payload, timeout=10) as resp: - if resp.status == 200: - # Successful response from the controller. - return RedirectResponse(url="/platform/admin") - else: - error_text = await resp.text() - logging.error(f"Controller error: {error_text}") - raise HTTPException(status_code=resp.status, detail=error_text) - except Exception as e: - logging.error(f"Error calling controller endpoint: {e}") - raise HTTPException(status_code=500, detail=f"Controller request failed: {e}") - else: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + await delete_user(user) + return RedirectResponse(url="/platform/admin") @app.post("/platform/user/add") @@ -510,23 +517,9 @@ async def nebula_add_user( if user.upper() in user_list or " " in user or "'" in user or '"' in user: return RedirectResponse(url="/platform/admin", status_code=status.HTTP_303_SEE_OTHER) - # Call the controller's endpoint to add the user. - controller_url = f"http://{settings.controller_host}:{settings.controller_port}/user/add" - payload = {"user": user, "password": password, "role": role} - - try: - async with aiohttp.ClientSession() as client: - async with client.post(controller_url, json=payload, timeout=10) as resp: - if resp.status == 200: - # Successful response from the controller. 
- return RedirectResponse(url="/platform/admin", status_code=status.HTTP_303_SEE_OTHER) - else: - error_text = await resp.text() - logging.error(f"Controller error: {error_text}") - raise HTTPException(status_code=resp.status, detail=error_text) - except Exception as e: - logging.error(f"Error calling controller endpoint: {e}") - raise HTTPException(status_code=500, detail=f"Controller request failed: {e}") + # Call the controller's endpoint to add the user. + await add_user(user, password, role) + return RedirectResponse(url="/platform/admin", status_code=status.HTTP_303_SEE_OTHER) @app.post("/platform/user/update") @@ -539,28 +532,14 @@ async def nebula_update_user( ): if "user" not in session or session["role"] != "admin": return RedirectResponse(url="/platform", status_code=status.HTTP_302_FOUND) - - controller_url = f"http://{settings.controller_host}:{settings.controller_port}/user/update" - payload = {"user": user, "password": password, "role": role} - - try: - async with aiohttp.ClientSession() as client: - async with client.post(controller_url, json=payload, timeout=10) as resp: - if resp.status == 200: - # Successful response from the controller. 
- return RedirectResponse(url="/platform/admin", status_code=status.HTTP_302_FOUND) - else: - error_text = await resp.text() - logging.error(f"Controller error: {error_text}") - raise HTTPException(status_code=resp.status, detail=error_text) - except Exception as e: - logging.error(f"Error calling controller endpoint: {e}") - raise HTTPException(status_code=500, detail=f"Controller request failed: {e}") + + await update_user(user, password, role) + return RedirectResponse(url="/platform/admin", status_code=status.HTTP_302_FOUND) @app.get("/platform/api/dashboard/runningscenario", response_class=JSONResponse) async def nebula_dashboard_runningscenario(): - scenario_running = get_running_scenario() + scenario_running = await get_running_scenarios() if scenario_running: scenario_running_as_dict = dict(scenario_running) scenario_running_as_dict["scenario_status"] = "running" @@ -615,7 +594,7 @@ async def monitor_resources(): while True: enough_resources = await check_enough_resources() if not enough_resources: - running_scenarios = get_running_scenario(get_all=True) + running_scenarios = await get_running_scenarios(get_all=True) if running_scenarios: last_running_scenario = running_scenarios.pop() running_scenario_as_dict = dict(last_running_scenario) @@ -703,9 +682,9 @@ async def nebula_dashboard(request: Request, session: dict = Depends(get_session @app.get("/platform/api/dashboard/{scenario_name}/monitor", response_class=JSONResponse) @app.get("/platform/dashboard/{scenario_name}/monitor", response_class=HTMLResponse) async def nebula_dashboard_monitor(scenario_name: str, request: Request, session: dict = Depends(get_session)): - scenario = get_scenario_by_name(scenario_name) + scenario = await get_scenario_by_name(scenario_name) if scenario: - nodes_list = list_nodes_by_scenario_name(scenario_name) + nodes_list = await list_nodes_by_scenario_name(scenario_name) if nodes_list: nodes_config = [] nodes_status = [] @@ -1039,7 +1018,7 @@ async def nebula_stop_scenario( 
session: dict = Depends(get_session), ): if "user" in session: - user = get_user_by_scenario_name(scenario_name) + user = await get_user_by_scenario_name(scenario_name) user_data = user_data_store[user] if session["role"] == "demo": @@ -1061,7 +1040,7 @@ async def nebula_stop_scenario( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) -def remove_scenario(scenario_name=None, user=None): +async def remove_scenario(scenario_name=None, user=None): from nebula.scenarios import ScenarioManagement user_data = user_data_store[user] @@ -1071,7 +1050,7 @@ def remove_scenario(scenario_name=None, user=None): # Remove registered nodes and conditions user_data.nodes_registration.pop(scenario_name, None) remove_nodes_by_scenario_name(scenario_name) - remove_scenario_by_name(scenario_name) + await remove_scenario_by_name(scenario_name) remove_note(scenario_name) ScenarioManagement.remove_files_by_scenario(scenario_name) @@ -1086,7 +1065,7 @@ async def nebula_relaunch_scenario( if session["role"] == "demo": raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) elif session["role"] == "user": - if not check_scenario_with_role(session["role"], scenario_name): + if not await check_scenario_with_role(session["role"], scenario_name): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) scenario_path = FileUtils.check_path(settings.config_dir, os.path.join(scenario_name, "scenario.json")) @@ -1114,9 +1093,9 @@ async def nebula_remove_scenario(scenario_name: str, session: dict = Depends(get if session["role"] == "demo": raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) elif session["role"] == "user": - if not check_scenario_with_role(session["role"], scenario_name): + if not await check_scenario_with_role(session["role"], scenario_name): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - remove_scenario(scenario_name, session["user"]) + await remove_scenario(scenario_name, session["user"]) return 
RedirectResponse(url="/platform/dashboard") else: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) @@ -1243,7 +1222,7 @@ async def nebula_dashboard_download_logs_metrics( @app.get("/platform/dashboard/deployment/", response_class=HTMLResponse) async def nebula_dashboard_deployment(request: Request, session: dict = Depends(get_session)): - scenario_running = get_running_scenario() + scenario_running = await get_running_scenarios() return templates.TemplateResponse( "deployment.html", { @@ -1333,13 +1312,13 @@ def mobility_assign(nodes, mobile_participants_percent): # Recieve a stopped node @app.post("/platform/dashboard/{scenario_name}/node/done") async def node_stopped(scenario_name: str, request: Request): - user = get_user_by_scenario_name(scenario_name) + user = await get_user_by_scenario_name(scenario_name) user_data = user_data_store[user] if request.headers.get("content-type") == "application/json": data = await request.json() user_data.nodes_finished.append(data["idx"]) - nodes_list = list_nodes_by_scenario_name(scenario_name) + nodes_list = await list_nodes_by_scenario_name(scenario_name) finished = True # Check if all the nodes of the scenario have finished the experiment for node in nodes_list: @@ -1371,7 +1350,7 @@ async def assign_available_gpu(scenario_data, role): available_system_gpus = response.get("available_gpus", None) if response is not None else None if available_system_gpus: - running_scenarios = get_running_scenario(get_all=True) + running_scenarios = get_running_scenarios(get_all=True) # Obtain currently used gpus if running_scenarios: running_gpus = [] From 0f7581e776b290f0c3659b0556aa8fef14ead26f Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 3 Apr 2025 13:34:47 +0200 Subject: [PATCH 153/233] feature suggestionbuffer --- nebula/core/nebulaevents.py | 4 ++ .../awareness/{commands => }/sacommand.py | 26 ++++--- .../awareness/samodule.py | 1 + .../awareness/samoduleagent.py | 20 ++++++ 
.../trainingpolicy/trainingpolicy.py | 4 +- .../awareness/suggestionbuffer.py | 68 +++++++++++++++++++ 6 files changed, 112 insertions(+), 11 deletions(-) rename nebula/core/situationalawareness/awareness/{commands => }/sacommand.py (87%) create mode 100644 nebula/core/situationalawareness/awareness/samoduleagent.py create mode 100644 nebula/core/situationalawareness/awareness/suggestionbuffer.py diff --git a/nebula/core/nebulaevents.py b/nebula/core/nebulaevents.py index 89a33809d..a3a6d54dd 100644 --- a/nebula/core/nebulaevents.py +++ b/nebula/core/nebulaevents.py @@ -95,6 +95,10 @@ def __init__(self, updates : dict, expected_nodes : set, missing_nodes : set): def __str__(self): return "Aggregation Ready" + + def update_updates(self, new_updates: dict): + """Allows an external module to update the updates dictionary.""" + self._updates = new_updates async def get_event_data(self) -> tuple[dict, set, set]: """Retrieves the aggregation event data. diff --git a/nebula/core/situationalawareness/awareness/commands/sacommand.py b/nebula/core/situationalawareness/awareness/sacommand.py similarity index 87% rename from nebula/core/situationalawareness/awareness/commands/sacommand.py rename to nebula/core/situationalawareness/awareness/sacommand.py index ee6b63d63..82154b0e3 100644 --- a/nebula/core/situationalawareness/awareness/commands/sacommand.py +++ b/nebula/core/situationalawareness/awareness/sacommand.py @@ -1,5 +1,5 @@ from nebula.core.utils.locker import Locker -from abc import ABC, abstractmethod +from abc import abstractmethod from enum import Enum import asyncio @@ -17,6 +17,7 @@ class SACommandAction(Enum): DISCONNECT = "disconnect" RECONNECT = "reconnect" SEARCH_CONNECTIONS = "search_connections" + MAINTAIN_CONNECTIONS = "maintain_connections" ADJUST_WEIGHT = "adjust_weight" DISCARD_UPDATE = "discard_update" @@ -59,7 +60,7 @@ def is_parallelizable(self): def __repr__(self): return (f"{self.__class__.__name__}(Type={self._command_type.value}, " - 
f"Action={self._action.value}, Target={self._target}, Priority={self._priority})") + f"Action={self._action.value}, Target={self._target}, Priority={self._priority.value})") """ ############################### # SA COMMAND SUBCLASS # @@ -92,7 +93,8 @@ def conflicts_with(self, other: "ConnectivityCommand") -> bool: """Determines if two commands conflict with each other.""" if self._target == other._target: conflict_pairs = [ - {SACommandAction.DISCONNECT, SACommandAction.RECONNECT} + {SACommandAction.DISCONNECT, SACommandAction.RECONNECT}, + {SACommandAction.DISCONNECT, SACommandAction.MAINTAIN_CONNECTIONS} ] return {self._action, other._action} in conflict_pairs return False @@ -113,12 +115,18 @@ async def execute(self): def conflicts_with(self, other: "AggregationCommand") -> bool: """Determines if two commands conflict with each other.""" - if self._target == other._target: - conflict_pairs = [ - {SACommandAction.DISCONNECT, SACommandAction.RECONNECT} - ] - return {self._action, other._action} in conflict_pairs - return False + topologic_conflict = False + weight_conflict = False + + if set(self._target.keys()) != set(other._target.keys()): + topologic_conflict = True + + weight_conflict = any( + abs(self._target[node][1] - other._target[node][1]) > 0 + for node in self._target.keys() if node in other._target.keys() + ) + + return weight_conflict and topologic_conflict """ ############################### # SA COMMAND FACTORY # diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 108fceb61..bf34a23d6 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -1,6 +1,7 @@ import asyncio import logging from nebula.addons.functions import print_msg_box +from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer from nebula.core.situationalawareness.awareness.sanetwork.sanetwork import 
SANetwork from nebula.core.situationalawareness.awareness.satraining.satraining import SATraining from nebula.core.utils.locker import Locker diff --git a/nebula/core/situationalawareness/awareness/samoduleagent.py b/nebula/core/situationalawareness/awareness/samoduleagent.py new file mode 100644 index 000000000..e23cc3669 --- /dev/null +++ b/nebula/core/situationalawareness/awareness/samoduleagent.py @@ -0,0 +1,20 @@ +from abc import abstractmethod, ABC +from nebula.core.situationalawareness.awareness.sacommand import SACommand + +class SAModuleAgent(ABC): +#TODO hacer el diccionario de tipos de agentes para luego tenerlo en cuenta en el arbitraje + @abstractmethod + async def get_agent(self): + raise NotImplementedError + + @abstractmethod + async def register_sa_agent(self): + raise NotImplementedError + + @abstractmethod + async def suggest_action(self, sac : SACommand): + raise NotImplementedError + + @abstractmethod + async def notify_all_suggestions_done(self, sac : SACommand): + raise NotImplementedError \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py index 919b9374e..df77da1cd 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod -from typing import Type +from nebula.core.situationalawareness.awareness.samoduleagent import SAModuleAgent -class TrainingPolicy(ABC): +class TrainingPolicy(ABC, SAModuleAgent): @abstractmethod async def init(self, config): diff --git a/nebula/core/situationalawareness/awareness/suggestionbuffer.py b/nebula/core/situationalawareness/awareness/suggestionbuffer.py new file mode 100644 index 000000000..fa64ba3c1 --- /dev/null +++ 
b/nebula/core/situationalawareness/awareness/suggestionbuffer.py @@ -0,0 +1,68 @@ +from nebula.core.utils.locker import Locker +from nebula.utils import logging +import asyncio +from nebula.core.situationalawareness.awareness.samoduleagent import SAModuleAgent +from nebula.core.situationalawareness.awareness.sacommand import SACommand +from nebula.core.nebulaevents import NodeEvent, RoundEndEvent, AggregationEvent +from collections import defaultdict + +class SuggestionBuffer(): + _instance = None + _lock = Locker("initialize_sb_lock", async_lock=False) + + def __new__(cls, verbose): + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + @classmethod + def get_instance(cls): + """Obtain SuggestionBuffer instance""" + if cls._instance is None: + raise ValueError("SuggestionBuffer has not been initialized yet.") + return cls._instance + + def __init__(self, verbose): + """Initializes the suggestion buffer with thread-safe synchronization.""" + self._verbose = verbose + self._buffer : dict[NodeEvent, list[SACommand]] = defaultdict(list) # {event: [suggestion]} + self._suggestion_buffer_lock = Locker("suggestion_buffer_lock", async_lock=True) + self._expected_agents = defaultdict(set) # {event: {agents}} + self._expected_agents_lock = Locker("expected_agents_lock", async_lock=True) + self._event_notifications : dict[SAModuleAgent, asyncio.Event] = {} + + async def register_event_agents(self, event_type, agent: SAModuleAgent): + """Registers expected agents for a given event.""" + async with self._expected_agents_lock: + self._expected_agents[event_type].add(agent) + if event_type not in self._event_notifications: + self._event_notifications[agent] = asyncio.Event() + + async def register_suggestion(self, event_type, agent: SAModuleAgent, suggestion: SACommand): + """Registers a suggestion from an agent for a specific event.""" + async with self._suggestion_buffer_lock: + self._buffer[event_type].append((agent, 
suggestion)) + + async def notify_all_suggestions_done_for_agent(self, saa : SAModuleAgent): + async with self._expected_agents_lock: + try: + self._event_notifications[saa].set() + except: + if self._verbose: logging.error(f"SAModuleAgent: {saa} not found on notifications awaited") + + async def get_suggestions(self, event_type): + """Retrieves all suggestions registered for a given event.""" + async with self._suggestion_buffer_lock: + return self._buffer.get(event_type, []) + + async def clear_suggestions(self, event_type): + """Clears all suggestions stored for a given event.""" + async with self._lock: + if event_type in self._buffer: + del self._buffer[event_type] + del self._expected_agents[event_type] + + async def clear_sa_agent(self, saa : SAModuleAgent): + async with self._expected_agents_lock: + pass From 982d0feddb13644bf695397f87ebeb8e40421533 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 3 Apr 2025 17:34:36 +0200 Subject: [PATCH 154/233] fix suggestionbuffer --- .../awareness/sacommand.py | 27 ++++++--- .../awareness/samodule.py | 10 +++- .../awareness/samoduleagent.py | 4 +- .../awareness/suggestionbuffer.py | 55 +++++++++++++++---- 4 files changed, 73 insertions(+), 23 deletions(-) diff --git a/nebula/core/situationalawareness/awareness/sacommand.py b/nebula/core/situationalawareness/awareness/sacommand.py index 82154b0e3..4bef2500c 100644 --- a/nebula/core/situationalawareness/awareness/sacommand.py +++ b/nebula/core/situationalawareness/awareness/sacommand.py @@ -1,18 +1,15 @@ -from nebula.core.utils.locker import Locker from abc import abstractmethod from enum import Enum import asyncio +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from nebula.core.situationalawareness.awareness.samoduleagent import SAModuleAgent -# ----------------------------------------------- -# ENUMS for types of SACommands -# ----------------------------------------------- class SACommandType(Enum): CONNECTIVITY = "Connectivity" AGGREGATION = 
"Aggregation" -# ----------------------------------------------- -# ENUM for available actions -# ----------------------------------------------- +#TODO separar por tipos de commands class SACommandAction(Enum): DISCONNECT = "disconnect" RECONNECT = "reconnect" @@ -26,6 +23,11 @@ class SACommandPRIO(Enum): MEDIUM = 5 LOW = 1 +class SACommandState(Enum): + PENDING = "pending" + DISCARDED = "discarded" + EXECUTED = "executed" + """ ############################### # SA COMMAND CLASS # ############################### @@ -36,16 +38,19 @@ class SACommand: def __init__( self, command_type: SACommandType, - action: SACommandAction, + action: SACommandAction, + owner: "SAModuleAgent", target, priority: SACommandPRIO = SACommandPRIO.MEDIUM, parallelizable = False ): self._command_type = command_type self._action = action + self._owner = owner self._target = target # Could be a node, parameter, etc. self._priority = priority self._parallelizable = parallelizable + self._state = SACommandState.PENDING @abstractmethod async def execute(self): @@ -55,6 +60,12 @@ async def execute(self): async def conflicts_with(self, other: "SACommand") -> bool: raise NotImplementedError + def get_owner(self): + return self._owner.get_agent() + + def update_command_state(self, sacs : SACommandState): + self._state = sacs + def is_parallelizable(self): return self._parallelizable diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index bf34a23d6..4b4d92eb8 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -7,6 +7,7 @@ from nebula.core.utils.locker import Locker from nebula.core.nebulaevents import RoundEndEvent from nebula.core.eventmanager import EventManager +from nebula.core.nebulaevents import RoundEndEvent, AggregationEvent from nebula.core.network.communications import CommunicationsManager @@ -38,6 +39,9 @@ def __init__( 
self._situational_awareness_training = SATraining(self, self._addr, "qds", "fastreboot", verbose=True) self._restructure_process_lock = Locker(name="restructure_process_lock") self._restructure_cooldown = 0 + self._arbitrator_notification = asyncio.Event() + self._suggestion_buffer = SuggestionBuffer(self._arbitrator_notification, verbose=True) + self._communciation_manager = CommunicationsManager.get_instance() @property def nm(self): @@ -53,7 +57,11 @@ def sat(self): @property def cm(self): - return CommunicationsManager.get_instance() + return self._communciation_manager + + @property + def sb(self): + return self._suggestion_buffer async def init(self): diff --git a/nebula/core/situationalawareness/awareness/samoduleagent.py b/nebula/core/situationalawareness/awareness/samoduleagent.py index e23cc3669..6c106a252 100644 --- a/nebula/core/situationalawareness/awareness/samoduleagent.py +++ b/nebula/core/situationalawareness/awareness/samoduleagent.py @@ -2,9 +2,9 @@ from nebula.core.situationalawareness.awareness.sacommand import SACommand class SAModuleAgent(ABC): -#TODO hacer el diccionario de tipos de agentes para luego tenerlo en cuenta en el arbitraje + @abstractmethod - async def get_agent(self): + async def get_agent(self) -> str: raise NotImplementedError @abstractmethod diff --git a/nebula/core/situationalawareness/awareness/suggestionbuffer.py b/nebula/core/situationalawareness/awareness/suggestionbuffer.py index fa64ba3c1..cbff48604 100644 --- a/nebula/core/situationalawareness/awareness/suggestionbuffer.py +++ b/nebula/core/situationalawareness/awareness/suggestionbuffer.py @@ -10,7 +10,7 @@ class SuggestionBuffer(): _instance = None _lock = Locker("initialize_sb_lock", async_lock=False) - def __new__(cls, verbose): + def __new__(cls, arbitrator_notification, verbose): with cls._lock: if cls._instance is None: cls._instance = super().__new__(cls) @@ -23,15 +23,18 @@ def get_instance(cls): raise ValueError("SuggestionBuffer has not been initialized 
yet.") return cls._instance - def __init__(self, verbose): + def __init__(self, arbitrator_notification : asyncio.Event, verbose): """Initializes the suggestion buffer with thread-safe synchronization.""" + self._arbitrator_notification = arbitrator_notification + self._arbitrator_notification_lock = Locker("arbitrator_notification_lock", async_lock=True) self._verbose = verbose - self._buffer : dict[NodeEvent, list[SACommand]] = defaultdict(list) # {event: [suggestion]} + self._buffer : dict[NodeEvent, list[SACommand]] = defaultdict(list) # {event: [suggestion]} self._suggestion_buffer_lock = Locker("suggestion_buffer_lock", async_lock=True) - self._expected_agents = defaultdict(set) # {event: {agents}} + self._expected_agents = defaultdict(set) # {event: {agents}} self._expected_agents_lock = Locker("expected_agents_lock", async_lock=True) - self._event_notifications : dict[SAModuleAgent, asyncio.Event] = {} - + self._event_notifications : dict[SAModuleAgent, asyncio.Event] = {} + self._event_waited = None + async def register_event_agents(self, event_type, agent: SAModuleAgent): """Registers expected agents for a given event.""" async with self._expected_agents_lock: @@ -44,17 +47,48 @@ async def register_suggestion(self, event_type, agent: SAModuleAgent, suggestion async with self._suggestion_buffer_lock: self._buffer[event_type].append((agent, suggestion)) - async def notify_all_suggestions_done_for_agent(self, saa : SAModuleAgent): + async def set_event_waited(self, event_type): + """Registers event to be waited""" + if not self._event_waited: + if self._verbose: logging.info(f"Set notification when all suggestiones are being received for event: {event_type}") + self._event_waited = event_type + + #TODO maybe should define dict using events as keys to collect notifications for agents per events + async def notify_all_suggestions_done_for_agent(self, saa : SAModuleAgent, event_type): + """SA Agent notification that has registered all the suggestions for 
event_type""" async with self._expected_agents_lock: try: self._event_notifications[saa].set() + if self._verbose: logging.info(f"SA Agent: {saa} notifies all suggestions registered for event: {event_type}") + await self._notify_arbitrator(event_type) except: if self._verbose: logging.error(f"SAModuleAgent: {saa} not found on notifications awaited") - + + async def _notify_arbitrator(self, event_type): + """Checking if is should notify arbitrator that all suggestions for event_type have been received""" + if event_type != self._event_waited: + return + + async with self._arbitrator_notification_lock: + async with self._expected_agents_lock: + expected_agents = self._expected_agents.get(event_type, []) # Get the expected agents for this event type + # Check if all expected agents have sent their notifications + all_received = all(self._event_notifications[agent].is_set() for agent in expected_agents if agent in self._event_notifications) + if all_received: + self._arbitrator_notification.set() + self._event_waited = None + await self._reset_notifications_for_agents(expected_agents) + + async def _reset_notifications_for_agents(self, agents): + """Reset notifications for SA Agents""" + for agent in agents: + self._event_notifications[agent].clear() + async def get_suggestions(self, event_type): """Retrieves all suggestions registered for a given event.""" async with self._suggestion_buffer_lock: - return self._buffer.get(event_type, []) + async with self._expected_agents_lock: + return self._buffer.get(event_type, []) async def clear_suggestions(self, event_type): """Clears all suggestions stored for a given event.""" @@ -63,6 +97,3 @@ async def clear_suggestions(self, event_type): del self._buffer[event_type] del self._expected_agents[event_type] - async def clear_sa_agent(self, saa : SAModuleAgent): - async with self._expected_agents_lock: - pass From f18ef812c2a7b1255cc947ebf0d7d72c55a29687 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 3 Apr 2025 
18:00:00 +0200 Subject: [PATCH 155/233] daile update --- nebula/core/situationalawareness/awareness/sacommand.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nebula/core/situationalawareness/awareness/sacommand.py b/nebula/core/situationalawareness/awareness/sacommand.py index 4bef2500c..519aa1318 100644 --- a/nebula/core/situationalawareness/awareness/sacommand.py +++ b/nebula/core/situationalawareness/awareness/sacommand.py @@ -32,7 +32,7 @@ class SACommandState(Enum): # SA COMMAND CLASS # ############################### """ - +#TODO aΓ±adir estados ene xecute class SACommand: """Base class for Situational Awareness module commands.""" def __init__( From f824224222c1ce9f2733d37bb141e62fbba94aa3 Mon Sep 17 00:00:00 2001 From: FerTV Date: Thu, 3 Apr 2025 20:00:22 +0200 Subject: [PATCH 156/233] fix endpoints for scenarios --- nebula/controller.py | 205 +++++++++++++- nebula/frontend/app.py | 338 +++++++++++++---------- nebula/frontend/templates/dashboard.html | 60 ++-- nebula/scenarios.py | 2 +- 4 files changed, 417 insertions(+), 188 deletions(-) diff --git a/nebula/controller.py b/nebula/controller.py index fdb79f380..d999a6181 100755 --- a/nebula/controller.py +++ b/nebula/controller.py @@ -22,7 +22,7 @@ from nebula.addons.env import check_environment from nebula.config.config import Config from nebula.config.mender import Mender -from nebula.scenarios import ScenarioManagement +from nebula.scenarios import Scenario, ScenarioManagement from nebula.tests import main as deploy_tests from nebula.utils import DockerUtils, SocketUtils @@ -143,6 +143,7 @@ async def get_available_gpu(): except Exception: # noqa: S110 pass + @app.post("/scenarios/remove") async def remove_scenario( scenario_name: str = Body(..., embed=True) @@ -194,6 +195,53 @@ async def get_scenarios( return {"scenarios": scenarios, "scenario_running": scenario_running} +@app.post("/scenarios/update") +async def update_scenario( + scenario_name: str = Body(..., embed=True), + 
start_time: str = Body(..., embed=True), + end_time: str = Body(..., embed=True), + scenario: dict = Body(..., embed=True), + status: str = Body(..., embed=True), + role: str = Body(..., embed=True), + username: str = Body(..., embed=True) +): + """ + Controller endpoint to update a scenario. + """ + from nebula.frontend.database import scenario_update_record + + try: + scenario = Scenario.from_dict(scenario) + scenario_update_record(scenario_name, start_time, end_time, scenario, status, role, username) + except Exception as e: + logging.error(f"Error updating scenario {scenario_name}: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return {"message": f"Scenario {scenario_name} updated successfully"} + + +@app.post("/scenarios/set_status_to_finished") +async def set_scenario_status_to_finished( + scenario_name: str = Body(..., embed=True), + all: bool = Body(False, embed=True) +): + """ + Controller endpoint to set the status of a scenario to finished. + """ + from nebula.frontend.database import scenario_set_status_to_finished, scenario_set_all_status_to_finished + + try: + if all: + scenario_set_all_status_to_finished() + else: + scenario_set_status_to_finished(scenario_name) + except Exception as e: + logging.error(f"Error setting scenario {scenario_name} to finished: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return {"message": f"Scenario {scenario_name} status set to finished successfully"} + + @app.get("/scenarios/running") async def get_running_scenario(get_all: bool = False): """ @@ -208,6 +256,94 @@ async def get_running_scenario(get_all: bool = False): raise HTTPException(status_code=500, detail="Internal server error") +@app.get("/nodes/{scenario_name}") +async def list_nodes_by_scenario_name( + scenario_name: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=50, + description="Valid scenario name" + ) + ] +): + """ + Controller endpoint to retrieve 
nodes by scenario name. + """ + from nebula.frontend.database import list_nodes_by_scenario_name + + try: + nodes = list_nodes_by_scenario_name(scenario_name) + except Exception as e: + logging.error(f"Error obtaining nodes: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return {"nodes": nodes} + + +@app.post("/nodes/update/") +async def update_nodes( + node_uid: str = Body(..., embed=True), + node_idx: str = Body(..., embed=True), + node_ip: str = Body(..., embed=True), + node_port: str = Body(..., embed=True), + node_role: int = Body(..., embed=True), + node_neighbors: str = Body(..., embed=True), + node_latitude: str = Body(..., embed=True), + node_longitude: str = Body(..., embed=True), + node_timestamp: str = Body(..., embed=True), + node_federation: str = Body(..., embed=True), + node_round_number: str = Body(..., embed=True), + node_scenario_name: str = Body(..., embed=True), + node_run_hash: str = Body(..., embed=True) +): + """ + Controller endpoint to update nodes. + """ + from nebula.frontend.database import update_node_record + + try: + update_node_record( + node_uid, + node_idx, + node_ip, + node_port, + node_role, + node_neighbors, + node_latitude, + node_longitude, + node_timestamp, + node_federation, + node_round_number, + node_scenario_name, + node_run_hash, + ) + except Exception as e: + logging.error(f"Error updating nodes: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return {"message": "Nodes updated successfully"} + + +@app.post("/nodes/remove") +async def remove_nodes_by_scenario_name( + scenario_name: str = Body(..., embed=True) +): + """ + Controller endpoint to remove nodes by scenario name. 
+ """ + from nebula.frontend.database import remove_nodes_by_scenario_name + + try: + remove_nodes_by_scenario_name(scenario_name) + except Exception as e: + logging.error(f"Error removing nodes: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return {"message": f"Nodes for scenario {scenario_name} removed successfully"} + + @app.get("/scenarios/check") async def check_scenario(role: str, scenario_name: str): """ @@ -235,6 +371,7 @@ async def get_scenario_by_name( ) ] ): + logging.info("[FER] controller") from nebula.frontend.database import get_scenario_by_name try: @@ -246,8 +383,8 @@ async def get_scenario_by_name( return {"scenario": scenario} -@app.get("/scenarios/user/{scenario_name}") -async def get_user_by_scenario_name( +@app.get("/notes") +async def get_notes_by_scenario_name( scenario_name: Annotated[ str, Path( @@ -258,15 +395,46 @@ async def get_user_by_scenario_name( ) ] ): - from nebula.frontend.database import get_user_by_scenario_name + from nebula.frontend.database import get_notes try: - user = get_user_by_scenario_name(scenario_name) + notes = get_notes(scenario_name) except Exception as e: - logging.error(f"Error obtaining user {user}: {e}") + logging.error(f"Error obtaining notes {notes}: {e}") raise HTTPException(status_code=500, detail="Internal server error") - return {"user": user} + return {"notes": notes} + + +@app.post("/notes/update") +async def update_notes_by_scenario_name( + scenario_name: str = Body(..., embed=True), + notes: str = Body(..., embed=True) +): + from nebula.frontend.database import save_notes + + try: + save_notes(scenario_name, notes) + except Exception as e: + logging.error(f"Error updating notes: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return {"message": f"Notes for scenario {scenario_name} updated successfully"} + + +@app.post("/notes/remove") +async def remove_notes_by_scenario_name( + scenario_name: str = Body(..., embed=True) +): + from 
nebula.frontend.database import remove_note + + try: + remove_note(scenario_name) + except Exception as e: + logging.error(f"Error removing notes: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return {"message": f"Notes for scenario {scenario_name} removed successfully"} @app.get("/user/list") @@ -290,6 +458,29 @@ async def list_users_controller(all_info: bool = False): ) +@app.get("/user/{scenario_name}") +async def get_user_by_scenario_name( + scenario_name: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=50, + description="Valid scenario name" + ) + ] +): + from nebula.frontend.database import get_user_by_scenario_name + + try: + user = get_user_by_scenario_name(scenario_name) + except Exception as e: + logging.error(f"Error obtaining user {user}: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return user + + @app.post("/user/add") async def add_user_controller( user: str = Body(...), diff --git a/nebula/frontend/app.py b/nebula/frontend/app.py index 8156fbff7..6c2e1433f 100755 --- a/nebula/frontend/app.py +++ b/nebula/frontend/app.py @@ -91,31 +91,6 @@ class Settings: from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.sessions import SessionMiddleware -from nebula.frontend.database import ( - # add_user, - # check_scenario_with_role, - delete_user_from_db, - get_all_scenarios_and_check_completed, - get_notes, - get_running_scenario, - # get_scenario_by_name, - # get_user_by_scenario_name, - get_user_info, - initialize_databases, - # list_nodes_by_scenario_name, - # list_users, - remove_nodes_by_scenario_name, - remove_note, - # remove_scenario_by_name, - save_notes, - scenario_set_all_status_to_finished, - scenario_set_status_to_finished, - scenario_update_record, - update_node_record, - update_user, - verify, - verify_hash_algorithm, -) from nebula.utils import DockerUtils, FileUtils logging.info(f"πŸš€ 
Starting Nebula Frontend on port {settings.port}") @@ -240,9 +215,9 @@ def __init__(self): # Detect CTRL+C from parent process -def signal_handler(signal, frame): +async def signal_handler(signal, frame): logging.info("You pressed Ctrl+C [frontend]!") - scenario_set_all_status_to_finished() + asyncio.get_event_loop().create_task(scenario_set_status_to_finished(all=True)) sys.exit(0) @@ -265,16 +240,17 @@ async def custom_http_exception_handler(request: Request, exc: StarletteHTTPExce return await request.app.default_exception_handler(request, exc) -async def get(url): +async def controller_get(url): async with aiohttp.ClientSession() as session: async with session.get(url) as response: if response.status == 200: + logging.info(f"[FER] GET request to {url} succeeded") return await response.json() else: raise HTTPException(status_code=response.status, detail="Error fetching data") -async def post(url, data=None): +async def controller_post(url, data=None): async with aiohttp.ClientSession() as session: async with session.post(url, json=data) as response: if response.status == 200: @@ -285,23 +261,35 @@ async def post(url, data=None): async def get_available_gpus(): url = f"http://{settings.controller_host}:{settings.controller_port}/available_gpus" - return await get(url) + return await controller_get(url) async def get_least_memory_gpu(): url = f"http://{settings.controller_host}:{settings.controller_port}/least_memory_gpu" - return await get(url) + return await controller_get(url) async def get_scenarios(user, role): url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/{user}/{role}" - return await get(url) + return await controller_get(url) + + +async def scenario_update_record(scenario_name, start_time, end_time, scenario, status, role, username): + url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/update" + data = {"scenario_name": scenario_name, "start_time": start_time, "end_time": end_time, "scenario": 
scenario, "status": status, "role": role, "username": username} + await controller_post(url, data) + + +async def scenario_set_status_to_finished(scenario_name, all=False): + url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/set_status_to_finished" + data = {"scenario_name": scenario_name, "all": all} + await controller_post(url, data) async def remove_scenario_by_name(scenario_name): - url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/remove/{scenario_name}" + url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/remove" data = {"scenario_name": scenario_name} - await post(url, data) + await controller_post(url, data) async def check_scenario_with_role(session, scenario_name): @@ -309,28 +297,66 @@ async def check_scenario_with_role(session, scenario_name): f"http://{settings.controller_host}:{settings.controller_port}" f"/scenarios/check?role={session['role']}&scenario_name={scenario_name}" ) - check_data = await get(url) + check_data = await controller_get(url) return check_data.get("allowed", False) async def get_scenario_by_name(scenario_name): url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/{scenario_name}" - return await get(url) - - -async def get_user_by_scenario_name(scenario_name): - url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/user/{scenario_name}" - return await get(url) + return await controller_get(url) async def get_running_scenarios(get_all=False): url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/running?get_all={get_all}" - return await get(url) + return await controller_get(url) async def list_nodes_by_scenario_name(scenario_name): url = f"http://{settings.controller_host}:{settings.controller_port}/nodes/{scenario_name}" - return await get(url) + return await controller_get(url) + + +async def update_node_record(uid, idx, ip, port, role, neighbors, latitude, longitude, 
timestamp, federation, round_number, scenario_name, run_hash): + url = f"http://{settings.controller_host}:{settings.controller_port}/nodes/update" + data = { + "uid": uid, + "idx": idx, + "ip": ip, + "port": port, + "role": role, + "neighbors": neighbors, + "latitude": latitude, + "longitude": longitude, + "timestamp": timestamp, + "federation": federation, + "round": round_number, + "scenario_name": scenario_name, + "run_hash": run_hash, + } + await controller_post(url, data) + + +async def remove_nodes_by_scenario_name(scenario_name): + url = f"http://{settings.controller_host}:{settings.controller_port}/nodes/remove" + data = {"scenario_name": scenario_name} + await controller_post(url, data) + + +async def get_notes(scenario_name): + url = f"http://{settings.controller_host}:{settings.controller_port}/notes/{scenario_name}" + return await controller_get(url) + + +async def save_notes(scenario_name, notes): + url = f"http://{settings.controller_host}:{settings.controller_port}/notes/save" + data = {"scenario_name": scenario_name, "notes": notes} + await controller_post(url, data) + + +async def remove_note(scenario_name): + url = f"http://{settings.controller_host}:{settings.controller_port}/notes/remove" + data = {"scenario_name": scenario_name} + await controller_post(url, data) async def list_users(allinfo=True): @@ -344,34 +370,39 @@ async def list_users(allinfo=True): - A list of users, as provided by the controller. 
""" url = f"http://{settings.controller_host}:{settings.controller_port}/user/list?all_info={allinfo}" - data = await get(url) + data = await controller_get(url) user_list = data["users"] return user_list +async def get_user_by_scenario_name(scenario_name): + url = f"http://{settings.controller_host}:{settings.controller_port}/user/{scenario_name}" + return await controller_get(url) + + async def add_user(user, password, role): url = f"http://{settings.controller_host}:{settings.controller_port}/user/add" data = {"user": user, "password": password, "role": role} - await post(url, data) + await controller_post(url, data) async def update_user(user, password, role): url = f"http://{settings.controller_host}:{settings.controller_port}/user/update" data = {"user": user, "password": password, "role": role} - await post(url, data) + await controller_post(url, data) async def delete_user(user): url = f"http://{settings.controller_host}:{settings.controller_port}/user/delete" data = {"user": user} - await post(url, data) + await controller_post(url, data) async def verify_user(user, password): url = f"http://{settings.controller_host}:{settings.controller_port}/user/verify" data = {"user": user, "password": password} - return await post(url, data) + return await controller_post(url, data) @app.get("/", response_class=HTMLResponse) @@ -427,7 +458,7 @@ async def save_note_for_scenario(scenario_name: str, request: Request, session: data = await request.json() notes = data["notes"] try: - save_notes(scenario_name, notes) + await save_notes(scenario_name, notes) return JSONResponse({"status": "success"}) except Exception as e: logging.exception(e) @@ -441,7 +472,7 @@ async def save_note_for_scenario(scenario_name: str, request: Request, session: @app.get("/platform/dashboard/{scenario_name}/notes") async def get_notes_for_scenario(scenario_name: str): - notes_record = get_notes(scenario_name) + notes_record = await get_notes(scenario_name) if notes_record: notes_data = 
dict(zip(notes_record.keys(), notes_record, strict=False)) return JSONResponse({"status": "success", "notes": notes_data["scenario_notes"]}) @@ -642,7 +673,7 @@ async def nebula_dashboard(request: Request, session: dict = Depends(get_session bool_completed = False if scenario_running: - bool_completed = scenario_running[6] == "completed" + bool_completed = scenario_running["status"] == "completed" if scenarios: if request.url.path == "/platform/dashboard": return templates.TemplateResponse( @@ -987,7 +1018,7 @@ async def nebula_monitor_image(scenario_name: str): raise HTTPException(status_code=404, detail="Topology image not found") -def stop_scenario(scenario_name, user): +async def stop_scenario(scenario_name, user): from nebula.scenarios import ScenarioManagement ScenarioManagement.stop_participants(scenario_name) @@ -996,18 +1027,18 @@ def stop_scenario(scenario_name, user): f"{(os.environ.get('NEBULA_CONTROLLER_NAME'))}_{str(user).lower()}-nebula-net-scenario" ) ScenarioManagement.stop_blockchain() - scenario_set_status_to_finished(scenario_name) + await scenario_set_status_to_finished(scenario_name) # Generate statistics for the scenario path = FileUtils.check_path(settings.log_dir, scenario_name) ScenarioManagement.generate_statistics(path) -def stop_all_scenarios(): - from nebula.scenarios import ScenarioManagement +# def stop_all_scenarios(): +# from nebula.scenarios import ScenarioManagement - ScenarioManagement.stop_participants() - ScenarioManagement.stop_blockchain() - scenario_set_all_status_to_finished() +# ScenarioManagement.stop_participants() +# ScenarioManagement.stop_blockchain() +# scenario_set_all_status_to_finished() @app.get("/platform/dashboard/{scenario_name}/stop/{stop_all}") @@ -1030,11 +1061,11 @@ async def nebula_stop_scenario( user_data.stop_all_scenarios_event.set() user_data.scenarios_list_length = 0 user_data.scenarios_finished = 0 - stop_scenario(scenario_name, user) + await stop_scenario(scenario_name, user) else: 
user_data.finish_scenario_event.set() user_data.scenarios_list_length -= 1 - stop_scenario(scenario_name, user) + await stop_scenario(scenario_name, user) return RedirectResponse(url="/platform/dashboard") else: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) @@ -1042,17 +1073,25 @@ async def nebula_stop_scenario( async def remove_scenario(scenario_name=None, user=None): from nebula.scenarios import ScenarioManagement + logging.info(f"[FER] remove_scenario {scenario_name} {user}") user_data = user_data_store[user] + logging.info("[FER] user_data") + if settings.advanced_analytics: logging.info("Advanced analytics enabled") # Remove registered nodes and conditions user_data.nodes_registration.pop(scenario_name, None) - remove_nodes_by_scenario_name(scenario_name) + logging.info("[FER] user_data_nodes_pop") + await remove_nodes_by_scenario_name(scenario_name) + logging.info("[FER] remove_nodes_by_scenario_name") await remove_scenario_by_name(scenario_name) - remove_note(scenario_name) + logging.info("[FER] remove_scenario_by__name") + await remove_note(scenario_name) + logging.info("[FER] remove_note") ScenarioManagement.remove_files_by_scenario(scenario_name) + logging.info("[FER] remove files") @app.get("/platform/dashboard/{scenario_name}/relaunch") @@ -1234,79 +1273,79 @@ async def nebula_dashboard_deployment(request: Request, session: dict = Depends( ) -def attack_node_assign( - nodes, - federation, - attack, - poisoned_node_percent, - poisoned_sample_percent, - poisoned_noise_percent, -): - """Identify which nodes will be attacked""" - import math - import random - - attack_matrix = [] - n_nodes = len(nodes) - if n_nodes == 0: - return attack_matrix - - nodes_index = [] - # Get the nodes index - if federation == "DFL": - nodes_index = list(nodes.keys()) - else: - for node in nodes: - if nodes[node]["role"] != "server": - nodes_index.append(node) - - n_nodes = len(nodes_index) - # Number of attacked nodes, round up - num_attacked = 
int(math.ceil(poisoned_node_percent / 100 * n_nodes)) - if num_attacked > n_nodes: - num_attacked = n_nodes - - # Get the index of attacked nodes - attacked_nodes = random.sample(nodes_index, num_attacked) - - # Assign the role of each node - for node in nodes: - node_att = "No Attack" - attack_sample_persent = 0 - poisoned_ratio = 0 - if (node in attacked_nodes) or (nodes[node]["malicious"]): - node_att = attack - attack_sample_persent = poisoned_sample_percent / 100 - poisoned_ratio = poisoned_noise_percent / 100 - nodes[node]["attacks"] = node_att - nodes[node]["poisoned_sample_percent"] = attack_sample_persent - nodes[node]["poisoned_ratio"] = poisoned_ratio - attack_matrix.append([node, node_att, attack_sample_persent, poisoned_ratio]) - return nodes, attack_matrix +# def attack_node_assign( +# nodes, +# federation, +# attack, +# poisoned_node_percent, +# poisoned_sample_percent, +# poisoned_noise_percent, +# ): +# """Identify which nodes will be attacked""" +# import math +# import random + +# attack_matrix = [] +# n_nodes = len(nodes) +# if n_nodes == 0: +# return attack_matrix + +# nodes_index = [] +# # Get the nodes index +# if federation == "DFL": +# nodes_index = list(nodes.keys()) +# else: +# for node in nodes: +# if nodes[node]["role"] != "server": +# nodes_index.append(node) + +# n_nodes = len(nodes_index) +# # Number of attacked nodes, round up +# num_attacked = int(math.ceil(poisoned_node_percent / 100 * n_nodes)) +# if num_attacked > n_nodes: +# num_attacked = n_nodes + +# # Get the index of attacked nodes +# attacked_nodes = random.sample(nodes_index, num_attacked) + +# # Assign the role of each node +# for node in nodes: +# node_att = "No Attack" +# attack_sample_persent = 0 +# poisoned_ratio = 0 +# if (node in attacked_nodes) or (nodes[node]["malicious"]): +# node_att = attack +# attack_sample_persent = poisoned_sample_percent / 100 +# poisoned_ratio = poisoned_noise_percent / 100 +# nodes[node]["attacks"] = node_att +# 
nodes[node]["poisoned_sample_percent"] = attack_sample_persent +# nodes[node]["poisoned_ratio"] = poisoned_ratio +# attack_matrix.append([node, node_att, attack_sample_persent, poisoned_ratio]) +# return nodes, attack_matrix import math -def mobility_assign(nodes, mobile_participants_percent): - """Assign mobility to nodes""" - import random +# def mobility_assign(nodes, mobile_participants_percent): +# """Assign mobility to nodes""" +# import random - # Number of mobile nodes, round down - num_mobile = math.floor(mobile_participants_percent / 100 * len(nodes)) - if num_mobile > len(nodes): - num_mobile = len(nodes) +# # Number of mobile nodes, round down +# num_mobile = math.floor(mobile_participants_percent / 100 * len(nodes)) +# if num_mobile > len(nodes): +# num_mobile = len(nodes) - # Get the index of mobile nodes - mobile_nodes = random.sample(list(nodes.keys()), num_mobile) +# # Get the index of mobile nodes +# mobile_nodes = random.sample(list(nodes.keys()), num_mobile) - # Assign the role of each node - for node in nodes: - node_mob = False - if node in mobile_nodes: - node_mob = True - nodes[node]["mobility"] = node_mob - return nodes +# # Assign the role of each node +# for node in nodes: +# node_mob = False +# if node in mobile_nodes: +# node_mob = True +# nodes[node]["mobility"] = node_mob +# return nodes # Recieve a stopped node @@ -1326,7 +1365,7 @@ async def node_stopped(scenario_name: str, request: Request): finished = False if finished: - stop_scenario(scenario_name, user) + await stop_scenario(scenario_name, user) user_data.nodes_finished.clear() user_data.finish_scenario_event.set() return JSONResponse( @@ -1398,11 +1437,11 @@ async def run_scenario(scenario_data, role, user): # Manager for the actual scenario scenarioManagement = ScenarioManagement(scenario_data, user) - scenario_update_record( - name=scenarioManagement.scenario_name, + await scenario_update_record( + scenario_name=scenarioManagement.scenario_name, 
start_time=scenarioManagement.start_date_scenario, end_time="", - scenario=scenarioManagement.scenario, + scenario=scenario_data, status="running", role=role, username=user @@ -1432,28 +1471,27 @@ async def run_scenario(scenario_data, role, user): # Deploy the list of scenarios async def run_scenarios(role, user): - from nebula.scenarios import Scenario - try: user_data = user_data_store[user] - scenario_pos = 0 - created_time = datetime.datetime.now().strftime('%d_%m_%Y_%H_%M_%S') - for scenario_data in user_data.scenarios_list: - scenario_data["gpu_id"] = [] - scenario = Scenario.from_dict(scenario_data) - - scenario_update_record( - name=f"nebula_{scenario.federation}_{created_time}_{scenario_pos}", - start_time="", - end_time="", - scenario=scenario, - status="waiting", - role=role, - username=user - ) - - scenario_pos+=1 + # scenario_pos = 0 + # created_time = datetime.datetime.now().strftime('%d_%m_%Y_%H_%M_%S') + + # for scenario_data in user_data.scenarios_list: + # scenario_data["gpu_id"] = [] + # federation = scenario_data["federation"] + + # await scenario_update_record( + # scenario_name=f"nebula_{federation}_{created_time}_{scenario_pos}", + # start_time="", + # end_time="", + # scenario=scenario_data, + # status="waiting", + # role=role, + # username=user + # ) + + # scenario_pos+=1 for scenario_data in user_data.scenarios_list: user_data.finish_scenario_event.clear() diff --git a/nebula/frontend/templates/dashboard.html b/nebula/frontend/templates/dashboard.html index a7e73d211..db02d38bd 100755 --- a/nebula/frontend/templates/dashboard.html +++ b/nebula/frontend/templates/dashboard.html @@ -64,17 +64,17 @@

There is a scenario Scenarios queue {{ scenarios_finished }}/{{ scenarios_list_length }}

{% if scenarios_finished != scenarios_list_length %} - Stop scenario queue {% endif %} {% endif %}
-
Scenario name: {{ scenario_running[0] }}
-
Scenario title: {{ scenario_running[3] }}
-
Scenario description: {{ scenario_running[4] }}
-
Scenario start time: {{ scenario_running[1] }}
+
Scenario name: {{ scenario_running.name }}
+
Scenario title: {{ scenario_running.title }}
+
Scenario description: {{ scenario_running.description }}
+
Scenario start time: {{ scenario_running.start_time }}
Deploy new scenario @@ -117,59 +117,59 @@

Scenarios in the database

Action - {% for name, username, title, start_time, model, dataset, rounds, status in scenarios %} + {% for scenario in scenarios %} {% if user_role == "admin" %} - {{ username|lower }} + {{ scenario.username|lower }} {% endif %} - {{ title }} - {{ start_time }} - {{ model }} - {{ dataset }} - {{ rounds }} - {% if status == "running" %} + {{ scenario.title }} + {{ scenario.start_time }} + {{ scenario.model }} + {{ scenario.dataset }} + {{ scenario.rounds }} + {% if scenario.status == "running" %} Running - {% elif status == "waiting" %} + {% elif scenario.status == "waiting" %} Waiting {% else %} Finished {% endif %} - Monitor - Real-time metrics - - - {% if status == "running" %} - Stop scenario - {% elif status == "completed" %} - Stop scenario - Stop scenario queue {% else %} - - {% endif %} - + - - - + - diff --git a/nebula/scenarios.py b/nebula/scenarios.py index add9a90c7..36ddeab63 100644 --- a/nebula/scenarios.py +++ b/nebula/scenarios.py @@ -254,7 +254,7 @@ def from_dict(cls, data): class ScenarioManagement: def __init__(self, scenario, user=None): # Current scenario - self.scenario = scenario + self.scenario = Scenario.from_dict(scenario) # Uid of the user self.user = user # Scenario management settings From 4cc209128df5625f45dff28cde21d1519c40c104 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 4 Apr 2025 11:23:10 +0200 Subject: [PATCH 157/233] feature integrated suggestions system --- nebula/core/datasets/nebuladataset.py | 10 +++---- nebula/core/engine.py | 4 +-- .../awareness/sacommand.py | 4 ++- .../awareness/samoduleagent.py | 2 +- .../awareness/satraining/satraining.py | 6 ++-- .../trainingpolicy/qdstrainingpolicy.py | 29 ++++++++++++++++++- .../trainingpolicy/trainingpolicy.py | 1 - .../awareness/suggestionbuffer.py | 7 +++-- 8 files changed, 46 insertions(+), 17 deletions(-) diff --git a/nebula/core/datasets/nebuladataset.py b/nebula/core/datasets/nebuladataset.py index ec8c0b675..2df69858e 100755 --- 
a/nebula/core/datasets/nebuladataset.py +++ b/nebula/core/datasets/nebuladataset.py @@ -249,10 +249,10 @@ def __init__( iid="IID", partition="dirichlet", partition_parameter=0.5, - nsplits_percentages = [0.5, 0.25, 0.25], - nsplits_iid = ["Non-IID", "IID", "Non-IID"], - npartitions = ["dirichlet", "balancediid", "dirichlet"], - npartitions_parameter =[0.1, 2, 0.5], + nsplits_percentages = [1.0], + nsplits_iid = ["Non-IID"], + npartitions = ["dirichlet"], + npartitions_parameter =[0.1], seed=42, config_dir=None, ): @@ -315,7 +315,7 @@ def data_partitioning(self, plot=False): f"Partitioning data for {self.__class__.__name__} | Partitions: {self.partitions_number} | IID: {self.iid} | Partition: {self.partition} | Partition parameter: {self.partition_parameter}" ) - #self.iid = "a" #TODO REMOVE + #self.iid = "a" #TODO REMOVE modificar para q sea string y no boolean el input del front logging.info(f"Scenario with data distribution: {self.iid}") if self.iid == "IID": self.train_indices_map = self.generate_iid_map(self.train_set) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index befe40b66..8f22d65b4 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -142,7 +142,7 @@ def __init__( self.config.reload_config_file() - cm = CommunicationsManager(engine=self) + self._cm = CommunicationsManager(engine=self) # Set the communication manager in the model (send messages from there) self._reporter = Reporter(config=self.config, trainer=self.trainer) @@ -172,7 +172,7 @@ def __init__( @property def cm(self): - return CommunicationsManager.get_instance() + self._cm @property def reporter(self): diff --git a/nebula/core/situationalawareness/awareness/sacommand.py b/nebula/core/situationalawareness/awareness/sacommand.py index 519aa1318..79aafcbb0 100644 --- a/nebula/core/situationalawareness/awareness/sacommand.py +++ b/nebula/core/situationalawareness/awareness/sacommand.py @@ -32,7 +32,7 @@ class SACommandState(Enum): # SA COMMAND CLASS # 
############################### """ -#TODO aΓ±adir estados ene xecute + class SACommand: """Base class for Situational Awareness module commands.""" def __init__( @@ -94,6 +94,7 @@ def __init__( async def execute(self): """Executes the assigned action function with the given parameters.""" + self.update_command_state(SACommandState.EXECUTED) if self._action_function: if asyncio.iscoroutinefunction(self._action_function): await self._action_function(*self._args) @@ -122,6 +123,7 @@ def __init__( super().__init__(SACommandType.CONNECTIVITY, action, target, priority, parallelizable) async def execute(self): + self.update_command_state(SACommandState.EXECUTED) return self._target def conflicts_with(self, other: "AggregationCommand") -> bool: diff --git a/nebula/core/situationalawareness/awareness/samoduleagent.py b/nebula/core/situationalawareness/awareness/samoduleagent.py index 6c106a252..4b3bbcb75 100644 --- a/nebula/core/situationalawareness/awareness/samoduleagent.py +++ b/nebula/core/situationalawareness/awareness/samoduleagent.py @@ -16,5 +16,5 @@ async def suggest_action(self, sac : SACommand): raise NotImplementedError @abstractmethod - async def notify_all_suggestions_done(self, sac : SACommand): + async def notify_all_suggestions_done(self, event_type): raise NotImplementedError \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/satraining.py b/nebula/core/situationalawareness/awareness/satraining/satraining.py index d96f9e6d3..5e4bb175d 100644 --- a/nebula/core/situationalawareness/awareness/satraining/satraining.py +++ b/nebula/core/situationalawareness/awareness/satraining/satraining.py @@ -47,7 +47,5 @@ async def init(self): async def module_actions(self): logging.info("SA Trainng evaluating current scenario") - nodes = await self.tp.get_evaluation_results() - if nodes: - for n in nodes: - asyncio.create_task(self.sam.cm.disconnect(n[0], forced=True)) + asyncio.create_task(self.tp.get_evaluation_results()) + 
diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py index 77d03ce21..771238338 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py @@ -6,6 +6,9 @@ import logging from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import AggregationEvent +from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer +from nebula.core.situationalawareness.awareness.sacommand import SACommand, ConnectivityCommand, SACommandAction, SACommandPRIO +from nebula.core.network.communications import CommunicationsManager import math # "Quality-Driven Selection" (QDS) @@ -34,6 +37,7 @@ async def init(self, config): nodes = config["nodes"] self._nodes : dict[str, tuple[deque, int]] = {node_id: (deque(maxlen=self.MAX_HISTORIC_SIZE), 0) for node_id in nodes} await EventManager.get_instance().subscribe_node_event(AggregationEvent, self.process_aggregation_event) + await self.register_sa_agent() async def update_neighbors(self, node, remove=False): async with self._nodes_lock: @@ -120,4 +124,27 @@ async def evaluate(self): return result async def get_evaluation_results(self): - return self._evaluation_results.copy() \ No newline at end of file + for node_discarded in self._evaluation_results: + args = (node_discarded, False, True) + sac = ConnectivityCommand( + SACommandAction.DISCONNECT, + node_discarded, + SACommandPRIO.MEDIUM, + True, + CommunicationsManager.get_instance().disconnect, + *args + ) + await self.suggest_action(sac) + await self.notify_all_suggestions_done(AggregationEvent) + + async def get_agent(self) -> str: + return "QDS_training_policy" + + async def register_sa_agent(self): + await 
SuggestionBuffer.get_instance().register_event_agents(AggregationEvent, self) + + async def suggest_action(self, sac : SACommand): + await SuggestionBuffer.get_instance().register_suggestion(AggregationEvent, self, sac) + + async def notify_all_suggestions_done(self, event_type): + await SuggestionBuffer.get_instance().notify_all_suggestions_done_for_agent(self, event_type) \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py index df77da1cd..dddabf779 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod from nebula.core.situationalawareness.awareness.samoduleagent import SAModuleAgent - class TrainingPolicy(ABC, SAModuleAgent): @abstractmethod diff --git a/nebula/core/situationalawareness/awareness/suggestionbuffer.py b/nebula/core/situationalawareness/awareness/suggestionbuffer.py index cbff48604..b890ff6cd 100644 --- a/nebula/core/situationalawareness/awareness/suggestionbuffer.py +++ b/nebula/core/situationalawareness/awareness/suggestionbuffer.py @@ -38,6 +38,7 @@ def __init__(self, arbitrator_notification : asyncio.Event, verbose): async def register_event_agents(self, event_type, agent: SAModuleAgent): """Registers expected agents for a given event.""" async with self._expected_agents_lock: + if self._verbose: logging.info(f"Registering SA Agent: {agent.get_agent()} for event: {event_type}") self._expected_agents[event_type].add(agent) if event_type not in self._event_notifications: self._event_notifications[agent] = asyncio.Event() @@ -45,6 +46,7 @@ async def register_event_agents(self, event_type, agent: SAModuleAgent): async def register_suggestion(self, event_type, agent: SAModuleAgent, suggestion: 
SACommand): """Registers a suggestion from an agent for a specific event.""" async with self._suggestion_buffer_lock: + if self._verbose: logging.info(f"Registering Suggestion from SA Agent: {agent.get_agent()} for event: {event_type}") self._buffer[event_type].append((agent, suggestion)) async def set_event_waited(self, event_type): @@ -59,10 +61,10 @@ async def notify_all_suggestions_done_for_agent(self, saa : SAModuleAgent, event async with self._expected_agents_lock: try: self._event_notifications[saa].set() - if self._verbose: logging.info(f"SA Agent: {saa} notifies all suggestions registered for event: {event_type}") + if self._verbose: logging.info(f"SA Agent: {saa.get_agent()} notifies all suggestions registered for event: {event_type}") await self._notify_arbitrator(event_type) except: - if self._verbose: logging.error(f"SAModuleAgent: {saa} not found on notifications awaited") + if self._verbose: logging.error(f"SAModuleAgent: {saa.get_agent()} not found on notifications awaited") async def _notify_arbitrator(self, event_type): """Checking if is should notify arbitrator that all suggestions for event_type have been received""" @@ -88,6 +90,7 @@ async def get_suggestions(self, event_type): """Retrieves all suggestions registered for a given event.""" async with self._suggestion_buffer_lock: async with self._expected_agents_lock: + if self._verbose: logging.info(f"Retrieving all sugestions for event: {event_type}") return self._buffer.get(event_type, []) async def clear_suggestions(self, event_type): From df24dfdbb93020009d26f43000c7f8259d9429e0 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 4 Apr 2025 15:54:07 +0200 Subject: [PATCH 158/233] fix error interface --- nebula/core/engine.py | 3 ++- nebula/core/network/communications.py | 2 +- .../awareness/satraining/trainingpolicy/trainingpolicy.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 8f22d65b4..eafbb266b 100644 --- 
a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -143,6 +143,7 @@ def __init__( self.config.reload_config_file() self._cm = CommunicationsManager(engine=self) + # = CommunicationsManager.get_instance() # Set the communication manager in the model (send messages from there) self._reporter = Reporter(config=self.config, trainer=self.trainer) @@ -172,7 +173,7 @@ def __init__( @property def cm(self): - self._cm + return self._cm @property def reporter(self): diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 2c3391833..ccaf015b2 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -93,7 +93,7 @@ def __init__(self, engine: "Engine"): self._external_connection_service = factory_connection_service("nebula", self.addr) self._initialized = True - logging.info("Communication Manager initialized completed") + logging.info("Communication Manager initialization completed") @property def engine(self): diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py index dddabf779..6616b9581 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from nebula.core.situationalawareness.awareness.samoduleagent import SAModuleAgent -class TrainingPolicy(ABC, SAModuleAgent): +class TrainingPolicy(SAModuleAgent): @abstractmethod async def init(self, config): From 044d517fbad81a66a633fbb7e4c01ff709e7a21c Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sat, 5 Apr 2025 09:54:14 +0200 Subject: [PATCH 159/233] opt suggestion buffer --- nebula/core/engine.py | 2 - nebula/core/nebulaevents.py | 16 +++- .../awareness/samodule.py | 3 - .../awareness/sanetwork/sanetwork.py | 6 +- 
.../awareness/suggestionbuffer.py | 75 +++++++++++-------- .../core/situationalawareness/nodemanager.py | 5 +- 6 files changed, 65 insertions(+), 42 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index eafbb266b..cfabdf9f0 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -611,8 +611,6 @@ async def _learning_cycle(self): ) # Set current round in config (send to the controller) await self.get_round_lock().release_async() - if self.mobility: - await self.nm.experiment_finish() # End of the learning cycle self.trainer.on_learning_cycle_end() await self.trainer.test() diff --git a/nebula/core/nebulaevents.py b/nebula/core/nebulaevents.py index a3a6d54dd..e50aa6e37 100644 --- a/nebula/core/nebulaevents.py +++ b/nebula/core/nebulaevents.py @@ -78,7 +78,21 @@ async def get_event_data(self): return (self._round, self._round_end_time) async def is_concurrent(self): - return False + return False + +class ExperimentFinishEvent(NodeEvent): + def __init__(self): + """Event triggered when experiment is going to finish.""" + pass + + def __str__(self): + return "Experiment finished" + + async def get_event_data(self): + pass + + async def is_concurrent(self): + return False class AggregationEvent(NodeEvent): def __init__(self, updates : dict, expected_nodes : set, missing_nodes : set): diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 4b4d92eb8..d42deab65 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -73,9 +73,6 @@ async def init(self): def is_additional_participant(self): return self.nm.is_additional_participant() - async def experiment_finish(self): - await self.san.experiment_finish() - async def get_geoloc(self): latitude = self.nm.config.participant["mobility_args"]["latitude"] longitude = self.nm.config.participant["mobility_args"]["longitude"] diff --git 
a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index 968643f5f..a7e9a30d6 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -5,7 +5,7 @@ from nebula.addons.functions import print_msg_box from nebula.core.nebulaevents import BeaconRecievedEvent from nebula.core.eventmanager import EventManager -from nebula.core.nebulaevents import NodeFoundEvent, UpdateNeighborEvent +from nebula.core.nebulaevents import NodeFoundEvent, UpdateNeighborEvent, ExperimentFinishEvent from nebula.core.network.communications import CommunicationsManager from typing import TYPE_CHECKING @@ -36,6 +36,7 @@ def __init__( self._restructure_process_lock = Locker(name="restructure_process_lock") self._restructure_cooldown = 0 self._verbose = verbose + self._cm = CommunicationsManager.get_instance() @property def sam(self): @@ -43,7 +44,7 @@ def sam(self): @property def cm(self): - return CommunicationsManager.get_instance() + return self._cm @property def np(self): @@ -54,6 +55,7 @@ async def init(self): logging.info("Deploying External Connection Service") await self.cm.start_external_connection_service() await EventManager.get_instance().subscribe_node_event(BeaconRecievedEvent, self.beacon_received) + await EventManager.get_instance().subscribe_node_event(ExperimentFinishEvent,self.experiment_finish) await self.cm.start_beacon() else: logging.info("Deploying External Connection Service | No running") diff --git a/nebula/core/situationalawareness/awareness/suggestionbuffer.py b/nebula/core/situationalawareness/awareness/suggestionbuffer.py index b890ff6cd..c564d1ed6 100644 --- a/nebula/core/situationalawareness/awareness/suggestionbuffer.py +++ b/nebula/core/situationalawareness/awareness/suggestionbuffer.py @@ -32,21 +32,25 @@ def __init__(self, arbitrator_notification : asyncio.Event, verbose): 
self._suggestion_buffer_lock = Locker("suggestion_buffer_lock", async_lock=True) self._expected_agents = defaultdict(set) # {event: {agents}} self._expected_agents_lock = Locker("expected_agents_lock", async_lock=True) - self._event_notifications : dict[SAModuleAgent, asyncio.Event] = {} + self._event_notifications : dict[NodeEvent, list[tuple[SAModuleAgent, asyncio.Event]]] = {} self._event_waited = None async def register_event_agents(self, event_type, agent: SAModuleAgent): """Registers expected agents for a given event.""" async with self._expected_agents_lock: - if self._verbose: logging.info(f"Registering SA Agent: {agent.get_agent()} for event: {event_type}") - self._expected_agents[event_type].add(agent) + if self._verbose: + logging.info(f"Registering SA Agent: {await agent.get_agent()} for event: {event_type}") if event_type not in self._event_notifications: - self._event_notifications[agent] = asyncio.Event() + self._event_notifications[event_type] = [] + self._expected_agents[event_type].add(agent) + existing_agents = {a for a, _ in self._event_notifications[event_type]} + if agent not in existing_agents: + self._event_notifications[event_type].append((agent, asyncio.Event())) async def register_suggestion(self, event_type, agent: SAModuleAgent, suggestion: SACommand): """Registers a suggestion from an agent for a specific event.""" async with self._suggestion_buffer_lock: - if self._verbose: logging.info(f"Registering Suggestion from SA Agent: {agent.get_agent()} for event: {event_type}") + if self._verbose: logging.info(f"Registering Suggestion from SA Agent: {await agent.get_agent()} for event: {event_type}") self._buffer[event_type].append((agent, suggestion)) async def set_event_waited(self, event_type): @@ -55,48 +59,59 @@ async def set_event_waited(self, event_type): if self._verbose: logging.info(f"Set notification when all suggestiones are being received for event: {event_type}") self._event_waited = event_type - #TODO maybe should define dict 
using events as keys to collect notifications for agents per events async def notify_all_suggestions_done_for_agent(self, saa : SAModuleAgent, event_type): - """SA Agent notification that has registered all the suggestions for event_type""" + """SA Agent notification that has registered all the suggestions for event_type.""" async with self._expected_agents_lock: - try: - self._event_notifications[saa].set() - if self._verbose: logging.info(f"SA Agent: {saa.get_agent()} notifies all suggestions registered for event: {event_type}") - await self._notify_arbitrator(event_type) - except: - if self._verbose: logging.error(f"SAModuleAgent: {saa.get_agent()} not found on notifications awaited") + agent_found = False + for agent, event in self._event_notifications.get(event_type, []): + if agent == saa: + event.set() + agent_found = True + if self._verbose: + logging.info(f"SA Agent: {await saa.get_agent()} notifies all suggestions registered for event: {event_type}") + break + if not agent_found and self._verbose: + logging.error(f"SAModuleAgent: {await saa.get_agent()} not found on notifications awaited for event {event_type}") + await self._notify_arbitrator(event_type) async def _notify_arbitrator(self, event_type): - """Checking if is should notify arbitrator that all suggestions for event_type have been received""" + """Checks whether to notify the arbitrator that all suggestions for event_type are received.""" if event_type != self._event_waited: return - + async with self._arbitrator_notification_lock: async with self._expected_agents_lock: - expected_agents = self._expected_agents.get(event_type, []) # Get the expected agents for this event type - # Check if all expected agents have sent their notifications - all_received = all(self._event_notifications[agent].is_set() for agent in expected_agents if agent in self._event_notifications) + expected_agents = self._expected_agents.get(event_type, []) + notifications = self._event_notifications.get(event_type, list()) 
+ + agent_event_map = {a: e for a, e in notifications} + all_received = all( + agent in agent_event_map and agent_event_map[agent].is_set() + for agent in expected_agents + ) + if all_received: self._arbitrator_notification.set() self._event_waited = None - await self._reset_notifications_for_agents(expected_agents) + await self._reset_notifications_for_agents(event_type, expected_agents) - async def _reset_notifications_for_agents(self, agents): - """Reset notifications for SA Agents""" - for agent in agents: - self._event_notifications[agent].clear() + async def _reset_notifications_for_agents(self, event_type, agents): + """Reset notifications for SA Agents for the given event.""" + notifications = self._event_notifications.get(event_type, set()) + for agent, event in notifications: + if agent in agents: + event.clear() async def get_suggestions(self, event_type): """Retrieves all suggestions registered for a given event.""" async with self._suggestion_buffer_lock: async with self._expected_agents_lock: if self._verbose: logging.info(f"Retrieving all sugestions for event: {event_type}") - return self._buffer.get(event_type, []) + suggestions = self._buffer.get(event_type, []).copy() + await self._clear_suggestions(event_type) + return suggestions - async def clear_suggestions(self, event_type): - """Clears all suggestions stored for a given event.""" - async with self._lock: - if event_type in self._buffer: - del self._buffer[event_type] - del self._expected_agents[event_type] + async def _clear_suggestions(self, event_type): + """Clears all suggestions and metadata stored for a given event.""" + self._buffer[event_type].clear() diff --git a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py index 940de40b3..a9ab641f1 100644 --- a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -93,10 +93,7 @@ async def set_configs(self): async def get_geoloc(self): return 
await self.sam.get_geoloc() - - async def experiment_finish(self): - await self.sam.experiment_finish() - + """ ############################## # WEIGHT STRATEGIES # From 144c3f3bbdf32008179aaf0d66f993f4f794797b Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sat, 5 Apr 2025 15:24:13 +0200 Subject: [PATCH 160/233] feat integrated more events --- nebula/addons/mobility.py | 6 +- nebula/core/engine.py | 19 ++--- nebula/core/nebulaevents.py | 19 ++++- nebula/core/network/communications.py | 3 - .../nebuladiscoveryservice.py | 12 ++- .../awareness/samodule.py | 13 --- .../awareness/sanetwork/sanetwork.py | 1 - .../core/situationalawareness/nodemanager.py | 80 ++++++++----------- nebula/node.py | 3 +- 9 files changed, 71 insertions(+), 85 deletions(-) diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index 9e49948fd..a0196f97c 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -4,7 +4,7 @@ import random import time from nebula.core.eventmanager import EventManager -from nebula.core.nebulaevents import GPSEvent +from nebula.core.nebulaevents import GPSEvent, ChangeLocationEvent from nebula.core.utils.locker import Locker from nebula.addons.functions import print_msg_box @@ -240,8 +240,10 @@ async def set_geo_location(self, latitude, longitude): longitude = self.config.participant["mobility_args"]["longitude"] self.config.participant["mobility_args"]["latitude"] = latitude - self.config.participant["mobility_args"]["longitude"] = longitude + self.config.participant["mobility_args"]["longitude"] = latitude if self._verbose: logging.info(f"πŸ“ New geo location: {latitude}, {longitude}") + cle = ChangeLocationEvent(latitude, latitude) + asyncio.create_task(EventManager.get_instance().publish_addonevent(cle)) async def change_geo_location(self): """ diff --git a/nebula/core/engine.py b/nebula/core/engine.py index cfabdf9f0..53010ea0c 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -167,6 +167,7 @@ def __init__( 
topology, model_handler, engine=self, + verbose=True ) self._addon_manager = AddondManager(self, self.config) @@ -309,11 +310,6 @@ async def _connection_connect_callback(self, source, message): async def _connection_disconnect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received disconnection message from {source}") - if self.mobility: - if await self.nm.waiting_confirmation_from(source): - await self.nm.confirmation_received(source, confirmation=False) - # if source in await self.cm.get_all_addrs_current_connections(only_direct=True): - await self.nm.update_neighbors(source, remove=True) await self.cm.disconnect(source, mutual_disconnection=False) async def _federation_federation_ready_callback(self, source, message): @@ -384,9 +380,6 @@ async def register_message_callback(self, message_event: tuple[str, str], callba if callable(method): await EventManager.get_instance().subscribe((event_type, action), method) - async def get_geoloc(self): - return await self.nm.get_geoloc() - """ ############################## # ENGINE FUNCTIONALITY # ############################## @@ -403,7 +396,6 @@ async def _aditional_node_start(self): async def update_neighbors(self, removed_neighbor_addr, neighbors, remove=False): if self.mobility: self.federation_nodes = neighbors - await self.nm.update_neighbors(removed_neighbor_addr, remove=remove) updt_nei_event = UpdateNeighborEvent(removed_neighbor_addr, remove) asyncio.create_task(EventManager.get_instance().publish_node_event(updt_nei_event)) @@ -473,15 +465,16 @@ async def create_trainer_module(self): async def start_communications(self): await self.register_events_callbacks() - await self.aggregator.init() initial_neighbors = self.config.participant["network_args"]["neighbors"].split() await self.cm.start_communications(initial_neighbors) + await asyncio.sleep(self.config.participant["misc_args"]["grace_time_connection"] // 2) + + async def deploy_components(self): + await 
self.aggregator.init() if self.mobility: - logging.info("Building NodeManager configurations...") await self.nm.set_configs() await self._reporter.start() - await self._addon_manager.deploy_additional_services() - await asyncio.sleep(self.config.participant["misc_args"]["grace_time_connection"] // 2) + await self._addon_manager.deploy_additional_services() async def deploy_federation(self): await self.federation_ready_lock.acquire_async() diff --git a/nebula/core/nebulaevents.py b/nebula/core/nebulaevents.py index e50aa6e37..697e11306 100644 --- a/nebula/core/nebulaevents.py +++ b/nebula/core/nebulaevents.py @@ -129,7 +129,7 @@ async def is_concurrent(self) -> bool: return False class UpdateNeighborEvent(NodeEvent): - def __init__(self, node_addr, removed=False): + def __init__(self, node_addr, removed=False, joining=False): """Event triggered when a neighboring node is updated. Args: @@ -139,6 +139,7 @@ def __init__(self, node_addr, removed=False): """ self._node_addr = node_addr self._removed = removed + self._joining_federation = joining def __str__(self): return f"Node addr: {self._node_addr}, removed: {self._removed}" @@ -156,6 +157,9 @@ async def get_event_data(self) -> tuple[str, bool]: async def is_concurrent(self) -> bool: return False + def is_joining_federation(self): + return self._joining_federation + class NodeFoundEvent(NodeEvent): def __init__(self, node_addr): """Event triggered when a new node is found. 
@@ -261,4 +265,15 @@ def __str__(self): return "GPSEvent" async def get_event_data(self) -> dict: - return self.distances.copy() \ No newline at end of file + return self.distances.copy() + +class ChangeLocationEvent(AddonEvent): + def __init__(self, latitude, longitude): + self.latitude = latitude + self.longitude = longitude + + def __str__(self): + return "ChangeLocationEvent" + + async def get_event_data(self): + return ( self.latitude, self.longitude) \ No newline at end of file diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index ccaf015b2..ca65b15bb 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -216,9 +216,6 @@ async def clear_restrictions(self): ############################### """ - async def get_geoloc(self): - return await self.engine.get_geoloc() - async def start_external_connection_service(self, run_service=True): if self.ecs == None: self._external_connection_service = factory_connection_service(self, self.addr) diff --git a/nebula/core/network/externalconnection/nebuladiscoveryservice.py b/nebula/core/network/externalconnection/nebuladiscoveryservice.py index 939caa804..dfadae9d4 100644 --- a/nebula/core/network/externalconnection/nebuladiscoveryservice.py +++ b/nebula/core/network/externalconnection/nebuladiscoveryservice.py @@ -4,7 +4,7 @@ import struct from nebula.core.network.externalconnection.externalconnectionservice import ExternalConnectionService from nebula.core.utils.locker import Locker -from nebula.core.nebulaevents import BeaconRecievedEvent +from nebula.core.nebulaevents import BeaconRecievedEvent, ChangeLocationEvent from nebula.core.eventmanager import EventManager @@ -123,14 +123,22 @@ def __init__(self, nebula_service, addr, interval=20): self.addr = addr self.interval = interval # Intervalo de envΓ­o en segundos self.running = False + self._latitude = None + self._longitude = None async def start(self): logging.info("[NebulaBeacon]: 
Starting sending pressence beacon") self.running = True + await EventManager.get_instance().subscribe_addonevent(ChangeLocationEvent, self._proces_change_location_event) while self.running: await asyncio.sleep(self.interval) await self.send_beacon() + async def _proces_change_location_event(self, cle: ChangeLocationEvent): + lat, long = await cle.get_event_data() + logging.info(f"Location changed to: ({lat},{long})") + self._latitude, self._longitude = lat, long + async def stop(self): logging.info("[NebulaBeacon]: Stop existance beacon") self.running = False @@ -140,7 +148,7 @@ async def modify_beacon_frequency(self, frequency): self.interval = frequency async def send_beacon(self): - latitude, longitude = await self.nebula_service.cm.get_geoloc() + latitude, longitude = self._latitude, self._longitude try: message = ("NOTIFY * HTTP/1.1\r\n" "HOST: 239.255.255.250:1900\r\n" diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index d42deab65..8a9d07809 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -73,11 +73,6 @@ async def init(self): def is_additional_participant(self): return self.nm.is_additional_participant() - async def get_geoloc(self): - latitude = self.nm.config.participant["mobility_args"]["latitude"] - longitude = self.nm.config.participant["mobility_args"]["longitude"] - return (latitude, longitude) - async def _mobility_actions(self, ree : RoundEndEvent): logging.info("πŸ”„ Starting additional mobility actions...") await self.san.module_actions() @@ -100,14 +95,6 @@ def get_restructure_process_lock(self): async def register_node(self, node, neighbor=False, remove=False): await self.san.register_node(self, node, neighbor, remove) - def meet_node(self, node): - self.san.meet_node(node) - - def update_neighbors(self, node, remove=False): - self.san.update_neighbors(node, remove) - if not remove: - 
self.san.meet_node(node) - def get_nodes_known(self, neighbors_too=False, neighbors_only=False): return self.san.get_nodes_known(neighbors_too, neighbors_only) diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index a7e9a30d6..ec0349713 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -144,7 +144,6 @@ async def beacon_received(self, beacon_recieved_event : BeaconRecievedEvent): latitude, longitude = geoloc nfe = NodeFoundEvent(addr) asyncio.create_task(EventManager.get_instance().publish_node_event(nfe)) - #self.meet_node(addr) #logging.info(f"Beacon received SANetwork, source: {addr}, geolocalization: {latitude},{longitude}") """ ############################### diff --git a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py index a9ab641f1..bb442f786 100644 --- a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -24,6 +24,7 @@ def __init__( topology, model_handler, engine: "Engine", + verbose=False ): self._aditional_participant = aditional_participant self.topology = topology @@ -49,6 +50,7 @@ def __init__( self._desc_done = False #TODO remove self._situational_awareness_module = SAModule(self, self.engine.addr, topology) + self._verbose = verbose @property def engine(self): @@ -85,28 +87,14 @@ async def set_configs(self): - self weight distance - self weight hetereogeneity """ + logging.info("Building NodeManager configurations...") await self.register_message_events_callbacks() await self.sam.init() + await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.update_neighbors) logging.info("Building candidate selector configuration..") self.candidate_selector.set_config([0, 0.5, 0.5]) # self.engine.trainer.get_loss(), 
self.config.participant["molibity_args"]["weight_distance"], self.config.participant["molibity_args"]["weight_het"] - - async def get_geoloc(self): - return await self.sam.get_geoloc() - - """ - ############################## - # WEIGHT STRATEGIES # - ############################## - """ - - async def register_late_neighbor(self, addr, joinning_federation=False): - logging.info(f"Registering | late neighbor: {addr}, joining: {joinning_federation}") - await self.meet_node(addr) - await self.update_neighbors(addr) - if joinning_federation: - pass - + """ ############################## # CONNECTIONS # @@ -147,33 +135,32 @@ async def waiting_confirmation_from(self, addr): await self.pending_confirmation_from_nodes_lock.release_async() return found - async def confirmation_received(self, addr, confirmation=False): - logging.info(f" Update | connection confirmation received from: {addr} | confirmation: {confirmation}") - if confirmation: - await self.cm.connect(addr, direct=True) - await self.update_neighbors(addr) - else: - self._remove_pending_confirmation_from(addr) - + async def confirmation_received(self, addr, joining=False): + logging.info(f" Update | connection confirmation received from: {addr} | joining federation: {joining}") + await self.cm.connect(addr, direct=True) + await self._remove_pending_confirmation_from(addr) + une = UpdateNeighborEvent(addr, joining=joining) + await EventManager.get_instance().publish_node_event(une) + def add_to_discarded_offers(self, addr_discarded): self.discarded_offers_addr_lock.acquire() self.discarded_offers_addr.append(addr_discarded) self.discarded_offers_addr_lock.release() - def need_more_neighbors(self): - return self.sam.need_more_neighbors() - def get_actions(self): return self.sam.get_actions() - async def update_neighbors(self, node, remove=False): + async def register_late_neighbor(self, addr, joinning_federation=False): + if self._verbose: logging.info(f"Registering | late neighbor: {addr}, joining: 
{joinning_federation}") + une = UpdateNeighborEvent(addr, joining=joinning_federation) + await EventManager.get_instance().publish_node_event(une) + + async def update_neighbors(self, une : UpdateNeighborEvent): + node, remove = await une.get_event_data() await self._update_neighbors_lock.acquire_async() - self.sam.update_neighbors(node, remove) - if remove: - pass - else: + if not remove: await self.meet_node(node) - self._remove_pending_confirmation_from(node) + await self._remove_pending_confirmation_from(node) await self._update_neighbors_lock.release_async() async def meet_node(self, node): @@ -185,7 +172,7 @@ def get_nodes_known(self, neighbors_too=False): def accept_model_offer(self, source, decoded_model, rounds, round, epochs, n_neighbors, loss): if not self.accept_candidates_lock.locked(): - logging.info(f"πŸ”„ Processing offer from {source}...") + if self._verbose: logging.info(f"πŸ”„ Processing offer from {source}...") model_accepted = self.model_handler.accept_model(decoded_model) self.model_handler.set_config(config=(rounds, round, epochs, self)) if model_accepted: @@ -208,7 +195,7 @@ async def stop_not_selected_connections(self): self.discarded_offers_addr = set( self.discarded_offers_addr ) - await self.cm.get_addrs_current_connections(only_direct=True, myself=False) - logging.info( + if self._verbose: logging.info( f"Interrupting connections | discarded offers | nodes discarded: {self.discarded_offers_addr}" ) for addr in self.discarded_offers_addr: @@ -239,16 +226,16 @@ async def start_late_connection_process(self, connected=False, msg_type="discove # wait offer #TODO actualizar con la informacion de latencias - logging.info(f"Connections stablish after finding federation: {connections_stablished}") + if self._verbose: logging.info(f"Connections stablish after finding federation: {connections_stablished}") if connections_stablished: - logging.info(f"Waiting: {self.recieve_offer_timer}s to receive offers from federation") + if self._verbose: 
logging.info(f"Waiting: {self.recieve_offer_timer}s to receive offers from federation") await asyncio.sleep(self.recieve_offer_timer) # acquire lock to not accept late candidates self.accept_candidates_lock.acquire() if self.candidate_selector.any_candidate(): - logging.info("Candidates found to connect to...") + if self._verbose: logging.info("Candidates found to connect to...") # create message to send to candidates selected if not connected: msg = self.cm.create_message("connection", "late_connect") @@ -256,7 +243,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove msg = self.cm.create_message("connection", "restructure") best_candidates = self.candidate_selector.select_candidates() - logging.info(f"Candidates | {[addr for addr, _, _ in best_candidates]}") + if self._verbose: logging.info(f"Candidates | {[addr for addr, _, _ in best_candidates]}") #TODO candidates not choosen --> disconnect try: for addr, _, _ in best_candidates: @@ -265,7 +252,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove await asyncio.sleep(1) except asyncio.CancelledError: await self.update_neighbors(addr, remove=True) - logging.info("Error during stablishment") + if self._verbose: logging.info("Error during stablishment") self.accept_candidates_lock.release() self.late_connection_process_lock.release() self.candidate_selector.remove_candidates() @@ -274,11 +261,11 @@ async def start_late_connection_process(self, connected=False, msg_type="discove # asyncio.create_task(self.sam.san.stop_connections_with_federation()) # if no candidates, repeat process else: - logging.info("❗️ No Candidates found...") + if self._verbose: logging.info("❗️ No Candidates found...") self.accept_candidates_lock.release() self.late_connection_process_lock.release() if not connected: - logging.info("❗️ repeating process...") + if self._verbose: logging.info("❗️ repeating process...") await self.start_late_connection_process(connected, msg_type, 
addrs_known) @@ -305,7 +292,7 @@ async def _connection_late_connect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received late connect message from {source}") # Verify if it's a confirmation message from a previous late connection message sent to source if await self.waiting_confirmation_from(source): - await self.confirmation_received(source, confirmation=True) + await self.confirmation_received(source, joining=True) return if not self.engine.get_initialization_status(): @@ -337,7 +324,7 @@ async def _connection_restructure_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received restructure message from {source}") # Verify if it's a confirmation message from a previous restructure connection message sent to source if await self.waiting_confirmation_from(source): - await self.confirmation_received(source, confirmation=True) + await self.confirmation_received(source) return if not self.engine.get_initialization_status(): @@ -396,7 +383,6 @@ async def _discover_discover_join_callback(self, source, message): async def _discover_discover_nodes_callback(self, source, message): logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_node message from {source} ") - # self.nm.meet_node(source) if len(self.engine.get_federation_nodes()) > 0: msg = self.cm.create_message( "offer", @@ -447,8 +433,6 @@ async def _link_connect_to_callback(self, source, message): logging.info(f"πŸ”— handle_link_message | Trigger | Received connect_to message from {source}") addrs = message.addrs for addr in addrs.split(): - # await self.cm.connect(addr, direct=True) - # self.nm.update_neighbors(addr) await self.meet_node(addr) async def _link_disconnect_from_callback(self, source, message): diff --git a/nebula/node.py b/nebula/node.py index 63ae03f8e..5091f090a 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -199,6 +199,7 @@ def randomize_value(value, variability): security=False, ) 
await node.start_communications() + await node.deploy_components() await node.deploy_federation() # If it is an additional node, it should wait until additional_node_round to connect to the network @@ -208,7 +209,7 @@ def randomize_value(value, variability): logging.info("Waiting time to start finding federation") # time.sleep(150) - await asyncio.sleep(150) + await asyncio.sleep(250) # time.sleep(6000) # DEBUG purposes # import requests From d56b361e3927330d7149b9010b57cb58b91582a2 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sun, 6 Apr 2025 19:10:33 +0200 Subject: [PATCH 161/233] refactor communication manager property --- nebula/addons/mobility.py | 5 +- .../nebulanetworksimulator.py | 5 +- nebula/addons/reporter.py | 9 ++- nebula/core/engine.py | 2 - nebula/core/network/communications.py | 57 +++++++++++-------- nebula/core/network/connection.py | 9 ++- nebula/core/network/discoverer.py | 9 ++- .../nebuladiscoveryservice.py | 11 +++- nebula/core/network/forwarder.py | 9 ++- nebula/core/network/health.py | 9 ++- nebula/core/network/messages.py | 9 ++- nebula/core/network/propagator.py | 27 +++++---- .../trainingpolicy/qdstrainingpolicy.py | 10 ++-- .../core/situationalawareness/nodemanager.py | 10 +++- 14 files changed, 120 insertions(+), 61 deletions(-) diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index a0196f97c..f65784633 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -3,6 +3,8 @@ import math import random import time +from functools import cached_property +from nebula.core.network.communications import CommunicationsManager from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import GPSEvent, ChangeLocationEvent from nebula.core.utils.locker import Locker @@ -66,9 +68,8 @@ def __init__(self, config, verbose=False): self._nodes_distances_lock = Locker("nodes_distances_lock", async_lock=True) self._verbose = verbose - @property + @cached_property def cm(self): - from 
nebula.core.network.communications import CommunicationsManager return CommunicationsManager.get_instance() @property diff --git a/nebula/addons/networksimulation/nebulanetworksimulator.py b/nebula/addons/networksimulation/nebulanetworksimulator.py index ee8023f5e..0a6b4acf9 100644 --- a/nebula/addons/networksimulation/nebulanetworksimulator.py +++ b/nebula/addons/networksimulation/nebulanetworksimulator.py @@ -1,6 +1,8 @@ import asyncio import subprocess import logging +from functools import cached_property +from nebula.core.network.communications import CommunicationsManager from nebula.addons.networksimulation.networksimulator import NetworkSimulator from nebula.core.utils.locker import Locker from nebula.core.eventmanager import EventManager @@ -24,9 +26,8 @@ def __init__(self, changing_interval, interface, verbose=False): self._current_network_conditions = {} self._running = False - @property + @cached_property def cm(self): - from nebula.core.network.communications import CommunicationsManager return CommunicationsManager.get_instance() async def start(self): diff --git a/nebula/addons/reporter.py b/nebula/addons/reporter.py index ef59f0360..a4b878168 100755 --- a/nebula/addons/reporter.py +++ b/nebula/addons/reporter.py @@ -48,6 +48,7 @@ def __init__(self, config, trainer): - Initializes both current and accumulated metrics for traffic monitoring. 
""" logging.info("Starting reporter module") + self._cm = None self.config = config self.trainer = trainer self.frequency = self.config.participant["reporter_args"]["report_frequency"] @@ -69,8 +70,12 @@ def __init__(self, config, trainer): @property def cm(self): - from nebula.core.network.communications import CommunicationsManager - return CommunicationsManager.get_instance() + if not self._cm: + from nebula.core.network.communications import CommunicationsManager + self._cm = CommunicationsManager.get_instance() + return self._cm + else: + return self._cm async def enqueue_data(self, name, value): """ diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 53010ea0c..f5714265b 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -487,7 +487,6 @@ async def deploy_federation(self): while not await self.cm.check_federation_ready(): await asyncio.sleep(1) logging.info("Sending FEDERATION_START to neighbors...") - # message = self.cm.mm.generate_federation_message(nebula_pb2.FederationMessage.Action.FEDERATION_START) message = self.cm.create_message("federation", "federation_start") await self.cm.send_message_to_neighbors(message) await self.get_federation_ready_lock().release_async() @@ -498,7 +497,6 @@ async def deploy_federation(self): else: logging.info("Sending FEDERATION_READY to neighbors...") - # message = self.cm.mm.generate_federation_message(nebula_pb2.FederationMessage.Action.FEDERATION_READY) message = self.cm.create_message("federation", "federation_ready") await self.cm.send_message_to_neighbors(message) logging.info("πŸ’€ Waiting until receiving the start signal from the start node") diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index ca65b15bb..0e21c9772 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -510,32 +510,43 @@ async def send_message_to_neighbors(self, message, neighbors=None, interval=0): if interval > 0: await 
asyncio.sleep(interval) - async def send_message(self, dest_addr, message): - try: - conn = self.connections[dest_addr] - await conn.send(data=message) - except Exception as e: - logging.exception(f"❗️ Cannot send message {message} to {dest_addr}. Error: {e!s}") - await self.disconnect(dest_addr, mutual_disconnection=False) - - async def send_model(self, dest_addr, round, serialized_model, weight=1): - async with self.semaphore_send_model: + async def send_message(self, dest_addr, message, is_compressed=False): + if not is_compressed: try: - conn = self.connections.get(dest_addr) - if conn is None: - logging.info(f"❗️ Connection with {dest_addr} not found") - return - logging.info( - f"Sending model to {dest_addr} with round {round}: weight={weight} |Β size={sys.getsizeof(serialized_model) / (1024** 2) if serialized_model is not None else 0} MB" - ) - # message = self.mm.generate_model_message(round, serialized_model, weight) - parameters = serialized_model - message = self.create_message("model", "", round, parameters, weight) - await conn.send(data=message, is_compressed=True) - logging.info(f"Model sent to {dest_addr} with round {round}") + conn = self.connections[dest_addr] + await conn.send(data=message) except Exception as e: - logging.exception(f"❗️ Cannot send model to {dest_addr}: {e!s}") + logging.exception(f"❗️ Cannot send message {message} to {dest_addr}. 
Error: {e!s}") await self.disconnect(dest_addr, mutual_disconnection=False) + else: + async with self.semaphore_send_model: + try: + conn = self.connections.get(dest_addr) + if conn is None: + logging.info(f"❗️ Connection with {dest_addr} not found") + return + await conn.send(data=message, is_compressed=True) + except Exception as e: + logging.exception(f"❗️ Cannot send model to {dest_addr}: {e!s}") + await self.disconnect(dest_addr, mutual_disconnection=False) + + # async def send_model(self, dest_addr, round, serialized_model, weight=1): + # async with self.semaphore_send_model: + # try: + # conn = self.connections.get(dest_addr) + # if conn is None: + # logging.info(f"❗️ Connection with {dest_addr} not found") + # return + # logging.info( + # f"Sending model to {dest_addr} with round {round}: weight={weight} |Β size={sys.getsizeof(serialized_model) / (1024** 2) if serialized_model is not None else 0} MB" + # ) + # parameters = serialized_model + # message = self.create_message("model", "", round, parameters, weight) + # await conn.send(data=message, is_compressed=True) + # logging.info(f"Model sent to {dest_addr} with round {round}") + # except Exception as e: + # logging.exception(f"❗️ Cannot send model to {dest_addr}: {e!s}") + # await self.disconnect(dest_addr, mutual_disconnection=False) async def send_offer_model(self, dest_addr, offer_message): async with self.semaphore_send_model: diff --git a/nebula/core/network/connection.py b/nebula/core/network/connection.py index 767002526..de6edacc5 100755 --- a/nebula/core/network/connection.py +++ b/nebula/core/network/connection.py @@ -54,6 +54,7 @@ def __init__( self.last_active = time.time() self.compression = compression self.config = config + self._cm = None self.federated_round = Connection.DEFAULT_FEDERATED_ROUND self.loop = asyncio.get_event_loop() @@ -93,8 +94,12 @@ def __del__(self): @property def cm(self): - from nebula.core.network.communications import CommunicationsManager - return 
CommunicationsManager.get_instance() + if not self._cm: + from nebula.core.network.communications import CommunicationsManager + self._cm = CommunicationsManager.get_instance() + return self._cm + else: + return self._cm def get_addr(self): return self.addr diff --git a/nebula/core/network/discoverer.py b/nebula/core/network/discoverer.py index b6d646ff4..4b098ce7f 100755 --- a/nebula/core/network/discoverer.py +++ b/nebula/core/network/discoverer.py @@ -7,14 +7,19 @@ def __init__(self, addr, config): print_msg_box(msg="Starting discoverer module...", indent=2, title="Discoverer module") self.addr = addr self.config = config + self._cm = None self.grace_time = self.config.participant["discoverer_args"]["grace_time_discovery"] self.period = self.config.participant["discoverer_args"]["discovery_frequency"] self.interval = self.config.participant["discoverer_args"]["discovery_interval"] @property def cm(self): - from nebula.core.network.communications import CommunicationsManager - return CommunicationsManager.get_instance() + if not self._cm: + from nebula.core.network.communications import CommunicationsManager + self._cm = CommunicationsManager.get_instance() + return self._cm + else: + return self._cm async def start(self): asyncio.create_task(self.run_discover()) diff --git a/nebula/core/network/externalconnection/nebuladiscoveryservice.py b/nebula/core/network/externalconnection/nebuladiscoveryservice.py index dfadae9d4..61421810a 100644 --- a/nebula/core/network/externalconnection/nebuladiscoveryservice.py +++ b/nebula/core/network/externalconnection/nebuladiscoveryservice.py @@ -136,7 +136,7 @@ async def start(self): async def _proces_change_location_event(self, cle: ChangeLocationEvent): lat, long = await cle.get_event_data() - logging.info(f"Location changed to: ({lat},{long})") + #logging.info(f"Location changed to: ({lat},{long})") self._latitude, self._longitude = lat, long async def stop(self): @@ -170,6 +170,7 @@ class 
NebulaConnectionService(ExternalConnectionService): def __init__(self, addr): self.nodes_found = set() self.addr = addr + self._cm = None self.server : NebulaServerProtocol = None self.client : NebulaClientProtocol = None self.beacon : NebulaBeacon = NebulaBeacon(self, self.addr) @@ -177,8 +178,12 @@ def __init__(self, addr): @property def cm(self): - from nebula.core.network.communications import CommunicationsManager - return CommunicationsManager.get_instance() + if not self._cm: + from nebula.core.network.communications import CommunicationsManager + self._cm = CommunicationsManager.get_instance() + return self._cm + else: + return self._cm async def start(self): loop = asyncio.get_running_loop() diff --git a/nebula/core/network/forwarder.py b/nebula/core/network/forwarder.py index 380b7d3c2..a12f32b39 100755 --- a/nebula/core/network/forwarder.py +++ b/nebula/core/network/forwarder.py @@ -8,6 +8,7 @@ class Forwarder: def __init__(self, config): print_msg_box(msg="Starting forwarder module...", indent=2, title="Forwarder module") self.config = config + self._cm = None self.pending_messages = asyncio.Queue() self.pending_messages_lock = Locker("pending_messages_lock", verbose=False, async_lock=True) @@ -17,8 +18,12 @@ def __init__(self, config): @property def cm(self): - from nebula.core.network.communications import CommunicationsManager - return CommunicationsManager.get_instance() + if not self._cm: + from nebula.core.network.communications import CommunicationsManager + self._cm = CommunicationsManager.get_instance() + return self._cm + else: + return self._cm async def start(self): asyncio.create_task(self.run_forwarder()) diff --git a/nebula/core/network/health.py b/nebula/core/network/health.py index 9669b888f..30a2d7cbf 100755 --- a/nebula/core/network/health.py +++ b/nebula/core/network/health.py @@ -8,6 +8,7 @@ def __init__(self, addr, config): print_msg_box(msg="Starting health module...", indent=2, title="Health module") self.addr = addr self.config 
= config + self._cm = None self.period = self.config.participant["health_args"]["health_interval"] self.alive_interval = self.config.participant["health_args"]["send_alive_interval"] self.check_alive_interval = self.config.participant["health_args"]["check_alive_interval"] @@ -15,8 +16,12 @@ def __init__(self, addr, config): @property def cm(self): - from nebula.core.network.communications import CommunicationsManager - return CommunicationsManager.get_instance() + if not self._cm: + from nebula.core.network.communications import CommunicationsManager + self._cm = CommunicationsManager.get_instance() + return self._cm + else: + return self._cm async def start(self): asyncio.create_task(self.run_send_alive()) diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index f23674ad5..758890475 100644 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -11,13 +11,18 @@ class MessagesManager: def __init__(self, addr, config): self.addr = addr self.config = config + self._cm = None self._message_templates = {} self._define_message_templates() @property def cm(self): - from nebula.core.network.communications import CommunicationsManager - return CommunicationsManager.get_instance() + if not self._cm: + from nebula.core.network.communications import CommunicationsManager + self._cm = CommunicationsManager.get_instance() + return self._cm + else: + return self._cm def _define_message_templates(self): # Dictionary that maps message types to their required parameters and default values diff --git a/nebula/core/network/propagator.py b/nebula/core/network/propagator.py index 91f050061..998c6a356 100755 --- a/nebula/core/network/propagator.py +++ b/nebula/core/network/propagator.py @@ -1,4 +1,5 @@ import asyncio +import sys import logging from abc import ABC, abstractmethod from collections import deque @@ -11,8 +12,7 @@ from nebula.core.aggregation.aggregator import Aggregator from nebula.core.engine import Engine from 
nebula.core.training.lightning import Lightning - - + class PropagationStrategy(ABC): @abstractmethod def is_node_eligible(self, node: str) -> bool: @@ -63,12 +63,16 @@ def prepare_model_payload(self, node: str) -> tuple[Any, float] | None: class Propagator: def __init__(self): - pass + self._cm = None @property def cm(self): - from nebula.core.network.communications import CommunicationsManager - return CommunicationsManager.get_instance() + if not self._cm: + from nebula.core.network.communications import CommunicationsManager + self._cm = CommunicationsManager.get_instance() + return self._cm + else: + return self._cm def start(self): self.engine: Engine = self.cm.engine @@ -163,12 +167,15 @@ async def propagate(self, strategy_id: str): serialized_model = None round_number = -1 if strategy_id == "initialization" else self.get_round() - + parameters = serialized_model + message = self.cm.create_message("model", "", round_number, parameters, weight) for neighbor_addr in eligible_neighbors: - asyncio.create_task(self.cm.send_model(neighbor_addr, round_number, serialized_model, weight)) - - # if len(self.aggregator.get_nodes_pending_models_to_aggregate()) >= len(self.aggregator._federation_nodes): - # return False + logging.info( + f"Sending model to {neighbor_addr} with round {self.get_round()}: weight={weight} |Β size={sys.getsizeof(serialized_model) / (1024** 2) if serialized_model is not None else 0} MB" + ) + asyncio.create_task(self.cm.send_message(neighbor_addr, message, is_compressed=True)) + #asyncio.create_task(self.cm.send_model(neighbor_addr, round_number, serialized_model, weight)) + await asyncio.sleep(self.interval) return True diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py index 771238338..f19ef72ff 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py +++ 
b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py @@ -5,7 +5,7 @@ from collections import deque import logging from nebula.core.eventmanager import EventManager -from nebula.core.nebulaevents import AggregationEvent +from nebula.core.nebulaevents import AggregationEvent, UpdateNeighborEvent from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer from nebula.core.situationalawareness.awareness.sacommand import SACommand, ConnectivityCommand, SACommandAction, SACommandPRIO from nebula.core.network.communications import CommunicationsManager @@ -17,7 +17,7 @@ class QDSTrainingPolicy(TrainingPolicy): SIMILARITY_THRESHOLD = 0.73 INACTIVE_THRESHOLD = 3 GRACE_ROUNDS = 0 - CHECK_COOLDOWN = 50 + CHECK_COOLDOWN = 1 def __init__(self, config : dict): self._addr = config["addr"] @@ -37,9 +37,11 @@ async def init(self, config): nodes = config["nodes"] self._nodes : dict[str, tuple[deque, int]] = {node_id: (deque(maxlen=self.MAX_HISTORIC_SIZE), 0) for node_id in nodes} await EventManager.get_instance().subscribe_node_event(AggregationEvent, self.process_aggregation_event) + await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.update_neighbors) await self.register_sa_agent() - async def update_neighbors(self, node, remove=False): + async def update_neighbors(self, une: UpdateNeighborEvent): + node, remove = await une.get_event_data() async with self._nodes_lock: if remove: self._nodes.pop(node, None) @@ -138,7 +140,7 @@ async def get_evaluation_results(self): await self.notify_all_suggestions_done(AggregationEvent) async def get_agent(self) -> str: - return "QDS_training_policy" + return "SATraining" async def register_sa_agent(self): await SuggestionBuffer.get_instance().register_event_agents(AggregationEvent, self) diff --git a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py index bb442f786..d5e87a6b7 100644 --- 
a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -10,6 +10,7 @@ from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import UpdateNeighborEvent, NodeFoundEvent from nebula.core.network.communications import CommunicationsManager +from functools import cached_property if TYPE_CHECKING: from nebula.core.engine import Engine @@ -33,6 +34,7 @@ def __init__( ) logging.info("🌐 Initializing Node Manager") self._engine = engine + self._cm = None self.config = engine.get_config() logging.info("Initializing Candidate Selector") self._candidate_selector = factory_CandidateSelector(self.topology) @@ -56,7 +58,7 @@ def __init__( def engine(self): return self._engine - @property + @cached_property def cm(self): return CommunicationsManager.get_instance() @@ -374,7 +376,8 @@ async def _discover_discover_join_callback(self, source, message): round=round, epochs=epochs, ) - await self.cm.send_offer_model(source, msg) + logging.info(f"Sending offer model to {source}") + await self.cm.send_message(source, msg, is_compressed=True) else: logging.info("Discover join received before federation is running..") # starter node is going to send info to the new node @@ -388,8 +391,9 @@ async def _discover_discover_nodes_callback(self, source, message): "offer", "offer_metric", n_neighbors=len(self.engine.get_federation_nodes()), - loss=self.engine.trainer.get_current_loss(), + loss=0 #self.engine.trainer.get_current_loss(), ) + logging.info(f"Sending offer metric to {source}") await self.cm.send_message(source, msg) else: logging.info(f"πŸ”— Dissmissing discover nodes from {source} | no active connections at the moment") From f78efdd0016d41454de65634156df88759d15b5b Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Sun, 6 Apr 2025 20:03:14 +0200 Subject: [PATCH 162/233] fix early updates received before starting learning --- .../aggregation/updatehandlers/dflupdatehandler.py | 10 +++++++--- 
nebula/core/eventmanager.py | 2 +- nebula/node.py | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py index 247760b1a..fde16d044 100644 --- a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py +++ b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py @@ -91,9 +91,13 @@ async def _check_updates_already_received(self): for se in self._sources_expected: (last_updt, node_storage) = self._updates_storage[se] if len(node_storage): - if last_updt != node_storage[-1]: - logging.info(f"Update already received from source: {se} | ({len(self._sources_received)}/{len(self._sources_expected)}) Updates received") - self._sources_received.add(se) + try: + if (last_updt and node_storage[-1] and last_updt != node_storage[-1]) or (node_storage[-1] and not last_updt): + self._sources_received.add(se) + logging.info(f"Update already received from source: {se} | ({len(self._sources_received)}/{len(self._sources_expected)}) Updates received") + + except: + logging.error(f"ERROR: source expected: {se} | last_update None: {(True if not last_updt else False)}, last update storaged None: {(True if not node_storage[-1] else False)}") async def storage_update(self, updt_received_event : UpdateReceivedEvent): time_received = time.time() diff --git a/nebula/core/eventmanager.py b/nebula/core/eventmanager.py index 6eef472f0..d448fa875 100755 --- a/nebula/core/eventmanager.py +++ b/nebula/core/eventmanager.py @@ -84,7 +84,7 @@ async def publish_addonevent(self, addonevent: AddonEvent): callbacks = self._addons_events_subs.get(event_type, []) if not callbacks: - logging.error(f"EventManager | No subscribers for AddonEvent type: {event_type.__name__}") + if self._verbose: logging.error(f"EventManager | No subscribers for AddonEvent type: {event_type.__name__}") return for callback in self._addons_events_subs[event_type]: diff --git a/nebula/node.py 
b/nebula/node.py index 5091f090a..e0393f6d8 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -209,7 +209,7 @@ def randomize_value(value, variability): logging.info("Waiting time to start finding federation") # time.sleep(150) - await asyncio.sleep(250) + await asyncio.sleep(150) # time.sleep(6000) # DEBUG purposes # import requests From c642d5e4d509296420a1c37cf3f53b0b97c76c78 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 7 Apr 2025 12:06:12 +0200 Subject: [PATCH 163/233] feature training policies as SA Agents --- nebula/core/eventmanager.py | 7 +- nebula/core/network/communications.py | 26 +++--- .../awareness/sacommand.py | 16 +++- .../awareness/samodule.py | 16 ++-- .../awareness/sanetwork/sanetwork.py | 83 +++++++++++++------ .../awareness/satraining/satraining.py | 7 +- .../trainingpolicy/bpstrainingpolicy.py | 32 +++++-- .../trainingpolicy/qdstrainingpolicy.py | 36 ++++---- .../trainingpolicy/sostrainingpolicy.py | 42 ++++++++-- .../trainingpolicy/trainingpolicy.py | 4 - .../satraining/weightstrategy}/momentum.py | 0 11 files changed, 182 insertions(+), 87 deletions(-) rename nebula/core/situationalawareness/{ => awareness/satraining/weightstrategy}/momentum.py (100%) diff --git a/nebula/core/eventmanager.py b/nebula/core/eventmanager.py index d448fa875..6f4f71056 100755 --- a/nebula/core/eventmanager.py +++ b/nebula/core/eventmanager.py @@ -7,6 +7,7 @@ from nebula.core.network.messages import MessageEvent from nebula.core.utils.locker import Locker from nebula.core.nebulaevents import AddonEvent, NodeEvent +from typing import Callable class EventManager: _instance = None @@ -40,7 +41,7 @@ def get_instance(verbose=False): EventManager(verbose=verbose) return EventManager._instance - async def subscribe(self, event_type: tuple[str, str], callback: callable): + async def subscribe(self, event_type: tuple[str, str], callback: Callable): """Register a callback for a specific event type.""" async with self._message_events_lock: if event_type not in 
self._subscribers: @@ -68,7 +69,7 @@ async def publish(self, message_event: MessageEvent): except Exception as e: logging.exception(f"EventManager | Error in callback for event {event_type}: {e}") - async def subscribe_addonevent(self, addonEventType: type[AddonEvent], callback: callable): + async def subscribe_addonevent(self, addonEventType: type[AddonEvent], callback: Callable): """Register a callback for a specific type of AddonEvent.""" async with self._addons_event_lock: if addonEventType not in self._addons_events_subs: @@ -98,7 +99,7 @@ async def publish_addonevent(self, addonevent: AddonEvent): logging.exception(f"EventManager | Error in callback for AddonEvent {event_type.__name__}: {e}") - async def subscribe_node_event(self, nodeEventType: type[NodeEvent], callback: callable): + async def subscribe_node_event(self, nodeEventType: type[NodeEvent], callback: Callable): """Register a callback for a specific type of AddonEvent.""" async with self._node_events_lock: if nodeEventType not in self._node_events_subs: diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 0e21c9772..76337fcfb 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -548,19 +548,19 @@ async def send_message(self, dest_addr, message, is_compressed=False): # logging.exception(f"❗️ Cannot send model to {dest_addr}: {e!s}") # await self.disconnect(dest_addr, mutual_disconnection=False) - async def send_offer_model(self, dest_addr, offer_message): - async with self.semaphore_send_model: - try: - conn = self.connections.get(dest_addr) - if conn is None: - logging.info(f"❗️ Connection with {dest_addr} not found") - return - logging.info(f"Sending offer model to {dest_addr}") - await conn.send(data=offer_message, is_compressed=True) - logging.info(f"Offer_Model sent to {dest_addr}") - except Exception as e: - logging.exception(f"❗️ Cannot send model to {dest_addr}: {e!s}") - await self.disconnect(dest_addr, 
mutual_disconnection=False) + # async def send_offer_model(self, dest_addr, offer_message): + # async with self.semaphore_send_model: + # try: + # conn = self.connections.get(dest_addr) + # if conn is None: + # logging.info(f"❗️ Connection with {dest_addr} not found") + # return + # logging.info(f"Sending offer model to {dest_addr}") + # await conn.send(data=offer_message, is_compressed=True) + # logging.info(f"Offer_Model sent to {dest_addr}") + # except Exception as e: + # logging.exception(f"❗️ Cannot send model to {dest_addr}: {e!s}") + # await self.disconnect(dest_addr, mutual_disconnection=False) async def establish_connection(self, addr, direct=True, reconnect=False): logging.info(f"πŸ”— [outgoing] Establishing connection with {addr} (direct: {direct})") diff --git a/nebula/core/situationalawareness/awareness/sacommand.py b/nebula/core/situationalawareness/awareness/sacommand.py index 79aafcbb0..9fedc582b 100644 --- a/nebula/core/situationalawareness/awareness/sacommand.py +++ b/nebula/core/situationalawareness/awareness/sacommand.py @@ -51,6 +51,7 @@ def __init__( self._priority = priority self._parallelizable = parallelizable self._state = SACommandState.PENDING + self._state_future = asyncio.get_event_loop().create_future() @abstractmethod async def execute(self): @@ -60,11 +61,19 @@ async def execute(self): async def conflicts_with(self, other: "SACommand") -> bool: raise NotImplementedError + async def discard_command(self): + await self._update_command_state(SACommandState.DISCARDED) + def get_owner(self): return self._owner.get_agent() - def update_command_state(self, sacs : SACommandState): + async def _update_command_state(self, sacs : SACommandState): self._state = sacs + if not self._state_future.done(): + self._state_future.set_result(sacs) + + def get_state_future(self): + return self._state_future def is_parallelizable(self): return self._parallelizable @@ -73,6 +82,7 @@ def __repr__(self): return 
(f"{self.__class__.__name__}(Type={self._command_type.value}, " f"Action={self._action.value}, Target={self._target}, Priority={self._priority.value})") + """ ############################### # SA COMMAND SUBCLASS # ############################### @@ -94,7 +104,7 @@ def __init__( async def execute(self): """Executes the assigned action function with the given parameters.""" - self.update_command_state(SACommandState.EXECUTED) + await self._update_command_state(SACommandState.EXECUTED) if self._action_function: if asyncio.iscoroutinefunction(self._action_function): await self._action_function(*self._args) @@ -123,7 +133,7 @@ def __init__( super().__init__(SACommandType.CONNECTIVITY, action, target, priority, parallelizable) async def execute(self): - self.update_command_state(SACommandState.EXECUTED) + await self._update_command_state(SACommandState.EXECUTED) return self._target def conflicts_with(self, other: "AggregationCommand") -> bool: diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 8a9d07809..46f01d2bd 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -1,3 +1,4 @@ +from abc import abstractmethod, ABC import asyncio import logging from nebula.addons.functions import print_msg_box @@ -18,6 +19,14 @@ RESTRUCTURE_COOLDOWN = 5 +class SAMComponent(ABC): + @abstractmethod + async def init(self): + raise NotImplementedError + @abstractmethod + async def sa_component_actions(self): + raise NotImplementedError + class SAModule: def __init__( @@ -75,8 +84,8 @@ def is_additional_participant(self): async def _mobility_actions(self, ree : RoundEndEvent): logging.info("πŸ”„ Starting additional mobility actions...") - await self.san.module_actions() - await self.sat.module_actions() + await self.san.sa_component_actions() + await self.sat.sa_component_actions() """ ############################### @@ -92,9 +101,6 @@ 
def get_restructure_process_lock(self): ############################### """ - async def register_node(self, node, neighbor=False, remove=False): - await self.san.register_node(self, node, neighbor, remove) - def get_nodes_known(self, neighbors_too=False, neighbors_only=False): return self.san.get_nodes_known(neighbors_too, neighbors_only) diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index ec0349713..8f7d4e7ab 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -1,20 +1,25 @@ import asyncio import logging from nebula.core.utils.locker import Locker +from typing import Callable from nebula.core.situationalawareness.awareness.sanetwork.neighborpolicies.neighborpolicy import factory_NeighborPolicy from nebula.addons.functions import print_msg_box from nebula.core.nebulaevents import BeaconRecievedEvent from nebula.core.eventmanager import EventManager -from nebula.core.nebulaevents import NodeFoundEvent, UpdateNeighborEvent, ExperimentFinishEvent +from nebula.core.nebulaevents import NodeFoundEvent, UpdateNeighborEvent, ExperimentFinishEvent, RoundEndEvent from nebula.core.network.communications import CommunicationsManager +from nebula.core.situationalawareness.awareness.samodule import SAMComponent +from nebula.core.situationalawareness.awareness.samoduleagent import SAModuleAgent +from nebula.core.situationalawareness.awareness.sacommand import SACommand, SACommandAction, SACommandPRIO, SACommandState +from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.situationalawareness.awareness.samodule import SAModule -RESTRUCTURE_COOLDOWN = 5 - -class SANetwork(): +RESTRUCTURE_COOLDOWN = 5 + +class SANetwork(SAMComponent): def __init__( self, sam: "SAModule", @@ -37,6 
+42,7 @@ def __init__( self._restructure_cooldown = 0 self._verbose = verbose self._cm = CommunicationsManager.get_instance() + self._sa_network_agent = SANetworkAgent(self) @property def sam(self): @@ -50,6 +56,10 @@ def cm(self): def np(self): return self._neighbor_policy + @property + def sana(self): + return self._sa_network_agent + async def init(self): if not self.sam.is_additional_participant(): logging.info("Deploying External Connection Service") @@ -72,8 +82,9 @@ async def init(self): await EventManager.get_instance().subscribe_node_event(NodeFoundEvent, self.process_node_found_event) await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.process_update_neighbor_event) + await self.sana.register_sa_agent() - async def module_actions(self): + async def sa_component_actions(self): logging.info("SA Network evaluating current scenario") await self._check_external_connection_service_status() await self._analize_topology_robustness() @@ -94,21 +105,10 @@ async def process_update_neighbor_event(self, une : UpdateNeighborEvent): if self._verbose: logging.info(f"Processing Update Neighbor Event, node addr: {node_addr}, remove: {removed}") self.np.update_neighbors(node_addr, removed) - async def register_node(self, node, neighbor=False, remove=False): - if not neighbor: - self.meet_node(node) - else: - self.update_neighbors(node, remove) - def meet_node(self, node): if node != self._addr: self.np.meet_node(node) - def update_neighbors(self, node, remove=False): - self.np.update_neighbors(node, remove) - if not remove: - self.np.meet_node(node) - def get_nodes_known(self, neighbors_too=False, neighbors_only=False): return self.np.get_nodes_known(neighbors_too, neighbors_only) @@ -157,7 +157,7 @@ def _update_restructure_cooldown(self): def _restructure_available(self): if self._restructure_cooldown: - logging.info("Reestructure on cooldown") + if self._verbose: logging.info("Reestructure on cooldown") return self._restructure_cooldown == 0 
def get_restructure_process_lock(self): @@ -167,22 +167,22 @@ async def _analize_topology_robustness(self): logging.info("πŸ”„ Analizing node network robustness...") if not self._restructure_process_lock.locked(): if not await self.neighbors_left(): - logging.info("No Neighbors left | reconnecting with Federation") + if self._verbose: logging.info("No Neighbors left | reconnecting with Federation") #await self.reconnect_to_federation() elif self.np.need_more_neighbors() and self._restructure_available(): - logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") + if self._verbose: logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") self._update_restructure_cooldown() possible_neighbors = self.np.get_nodes_known(neighbors_too=False) possible_neighbors = await self.cm.apply_restrictions(possible_neighbors) if not possible_neighbors: - logging.info("All possible neighbors using nodes known are restricted...") + if self._verbose: logging.info("All possible neighbors using nodes known are restricted...") else: pass # asyncio.create_task(self.upgrade_connection_robustness(possible_neighbors)) else: - logging.info("Sufficient Robustness | no actions required") + if self._verbose: logging.info("Sufficient Robustness | no actions required") else: - logging.info("❗️ Reestructure/Reconnecting process already running...") + if self._verbose: logging.info("❗️ Reestructure/Reconnecting process already running...") async def reconnect_to_federation(self): self._restructure_process_lock.acquire() @@ -190,12 +190,12 @@ async def reconnect_to_federation(self): await asyncio.sleep(120) # If we got some refs, try to reconnect to them if len(self.np.get_nodes_known()) > 0: - logging.info("Reconnecting | Addrs availables") + if self._verbose: logging.info("Reconnecting | Addrs availables") await self.sam.nm.start_late_connection_process( connected=False, msg_type="discover_nodes", 
addrs_known=self.np.get_nodes_known() ) else: - logging.info("Reconnecting | NO Addrs availables") + if self._verbose: logging.info("Reconnecting | NO Addrs availables") await self.sam.nm.start_late_connection_process(connected=False, msg_type="discover_nodes") self._restructure_process_lock.release() @@ -204,12 +204,12 @@ async def upgrade_connection_robustness(self, possible_neighbors): # addrs_to_connect = self.neighbor_policy.get_nodes_known(neighbors_too=False) # If we got some refs, try to connect to them if len(possible_neighbors) > 0: - logging.info(f"Reestructuring | Addrs availables | addr list: {possible_neighbors}") + if self._verbose: logging.info(f"Reestructuring | Addrs availables | addr list: {possible_neighbors}") await self.sam.nm.start_late_connection_process( connected=True, msg_type="discover_nodes", addrs_known=possible_neighbors ) else: - logging.info("Reestructuring | NO Addrs availables") + if self._verbose: logging.info("Reestructuring | NO Addrs availables") await self.sam.nm.start_late_connection_process(connected=True, msg_type="discover_nodes") self._restructure_process_lock.release() @@ -220,4 +220,33 @@ async def stop_connections_with_federation(self): for n in neighbors: await self.cm.add_to_blacklist(n) for n in neighbors: - await self.cm.disconnect(n, mutual_disconnection=False, forced=True) \ No newline at end of file + await self.cm.disconnect(n, mutual_disconnection=False, forced=True) + + """ ############################### + # SA NETWORK AGENT # + ############################### + """ + +class SANetworkAgent(SAModuleAgent): + + def __init__(self, sanetwork : SANetwork): + self._san = sanetwork + + async def get_agent(self) -> str: + return "SANetwork_MainNetworkAgent" + + async def register_sa_agent(self): + await SuggestionBuffer.get_instance().register_event_agents(RoundEndEvent, self) + + async def suggest_action(self, sac : SACommand): + await SuggestionBuffer.get_instance().register_suggestion(RoundEndEvent, self, sac) + 
+ async def notify_all_suggestions_done(self, event_type): + await SuggestionBuffer.get_instance().notify_all_suggestions_done_for_agent(self, event_type) + + async def create_and_suggest_action(self, saca: SACommandAction, function : Callable, *args): + if saca == SACommandAction.MAINTAIN_CONNECTIONS: + pass + elif saca == SACommandAction.SEARCH_CONNECTIONS: + pass + \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/satraining.py b/nebula/core/situationalawareness/awareness/satraining/satraining.py index 5e4bb175d..6b5c17dd6 100644 --- a/nebula/core/situationalawareness/awareness/satraining/satraining.py +++ b/nebula/core/situationalawareness/awareness/satraining/satraining.py @@ -2,15 +2,16 @@ import logging from nebula.core.utils.locker import Locker from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import factory_training_policy +from nebula.core.situationalawareness.awareness.samodule import SAMComponent from nebula.addons.functions import print_msg_box from typing import TYPE_CHECKING if TYPE_CHECKING: - from nebula.core.situationalawareness.awareness.samodule import SAModule + from nebula.core.situationalawareness.awareness.samodule import SAModule, SAMComponent from nebula.core.eventmanager import EventManager RESTRUCTURE_COOLDOWN = 5 -class SATraining(): +class SATraining(SAMComponent): def __init__( self, sam: "SAModule", @@ -45,7 +46,7 @@ async def init(self): config["nodes"] = set(self._sam.get_nodes_known(neighbors_only=True)) await self.tp.init(config) - async def module_actions(self): + async def sa_component_actions(self): logging.info("SA Trainng evaluating current scenario") asyncio.create_task(self.tp.get_evaluation_results()) diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py index ee41b2c1a..60d0110a4 100644 --- 
a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py @@ -1,4 +1,7 @@ from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy +from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer +from nebula.core.situationalawareness.awareness.sacommand import SACommand, ConnectivityCommand, SACommandAction, SACommandPRIO +from nebula.core.nebulaevents import RoundEndEvent class BPSTrainingPolicy(TrainingPolicy): @@ -6,13 +9,28 @@ def __init__(self, config=None): pass async def init(self, config): - pass + await self.register_sa_agent() - async def update_neighbors(self, node, remove=False): - pass + async def get_evaluation_results(self): + sac = ConnectivityCommand( + SACommandAction.MAINTAIN_CONNECTIONS, + "", + SACommandPRIO.LOW, + False, + None, + None + ) + await self.suggest_action(sac) + await self.notify_all_suggestions_done(RoundEndEvent) - async def evaluate(self): - return None + async def get_agent(self) -> str: + return "SATraining_BPSTP" + + async def register_sa_agent(self): + await SuggestionBuffer.get_instance().register_event_agents(RoundEndEvent, self) - async def get_evaluation_results(self): - return None \ No newline at end of file + async def suggest_action(self, sac : SACommand): + await SuggestionBuffer.get_instance().register_suggestion(RoundEndEvent, self, sac) + + async def notify_all_suggestions_done(self, event_type): + await SuggestionBuffer.get_instance().notify_all_suggestions_done_for_agent(self, event_type) \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py index f19ef72ff..7192ab1d2 100644 --- 
a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py @@ -5,7 +5,7 @@ from collections import deque import logging from nebula.core.eventmanager import EventManager -from nebula.core.nebulaevents import AggregationEvent, UpdateNeighborEvent +from nebula.core.nebulaevents import AggregationEvent, UpdateNeighborEvent, RoundEndEvent from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer from nebula.core.situationalawareness.awareness.sacommand import SACommand, ConnectivityCommand, SACommandAction, SACommandPRIO from nebula.core.network.communications import CommunicationsManager @@ -27,6 +27,7 @@ def __init__(self, config : dict): self._round_missing_nodes = set() self._grace_rounds = self.GRACE_ROUNDS self._last_check = 0 + self._check_done = False self._evaluation_results = set() def __str__(self): @@ -88,6 +89,7 @@ async def evaluate(self): result = set() if self._last_check == 0: + self._check_done = True nodes = await self._get_nodes() redundant_nodes = set() inactive_nodes = set() @@ -120,33 +122,35 @@ async def evaluate(self): result = result.union(discard_nodes) else: if self._verbose: logging.info(f"Evaluation is on cooldown... 
| {self.CHECK_COOLDOWN - self._last_check} rounds remaining") + self._check_done = False self._last_check = (self._last_check + 1) % self.CHECK_COOLDOWN return result async def get_evaluation_results(self): - for node_discarded in self._evaluation_results: - args = (node_discarded, False, True) - sac = ConnectivityCommand( - SACommandAction.DISCONNECT, - node_discarded, - SACommandPRIO.MEDIUM, - True, - CommunicationsManager.get_instance().disconnect, - *args - ) - await self.suggest_action(sac) - await self.notify_all_suggestions_done(AggregationEvent) + if self._check_done: + for node_discarded in self._evaluation_results: + args = (node_discarded, False, True) + sac = ConnectivityCommand( + SACommandAction.DISCONNECT, + node_discarded, + SACommandPRIO.MEDIUM, + False, + CommunicationsManager.get_instance().disconnect, + *args + ) + await self.suggest_action(sac) + await self.notify_all_suggestions_done(RoundEndEvent) async def get_agent(self) -> str: - return "SATraining" + return "SATraining_QDSTP" async def register_sa_agent(self): - await SuggestionBuffer.get_instance().register_event_agents(AggregationEvent, self) + await SuggestionBuffer.get_instance().register_event_agents(RoundEndEvent, self) async def suggest_action(self, sac : SACommand): - await SuggestionBuffer.get_instance().register_suggestion(AggregationEvent, self, sac) + await SuggestionBuffer.get_instance().register_suggestion(RoundEndEvent, self, sac) async def notify_all_suggestions_done(self, event_type): await SuggestionBuffer.get_instance().notify_all_suggestions_done_for_agent(self, event_type) \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py index e18172a48..239859040 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py +++ 
b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py @@ -3,7 +3,10 @@ from collections import deque import logging from nebula.core.eventmanager import EventManager -from nebula.core.nebulaevents import UpdateReceivedEvent, AggregationEvent, RoundStartEvent +from nebula.core.nebulaevents import UpdateReceivedEvent, AggregationEvent, RoundStartEvent, UpdateNeighborEvent, RoundEndEvent +from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer +from nebula.core.situationalawareness.awareness.sacommand import SACommand, ConnectivityCommand, SACommandAction, SACommandPRIO +from nebula.core.network.communications import CommunicationsManager import time import asyncio @@ -76,6 +79,8 @@ async def init(self, config): await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self._process_update_received_event) await EventManager.get_instance().subscribe_node_event(RoundStartEvent, self._process_round_start) await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) + await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.update_neighbors) + await self.register_sa_agent() async def _get_nodes(self): async with self._nodes_lock: @@ -130,7 +135,8 @@ async def _process_update_received_event(self, ure : UpdateReceivedEvent): self._nodes[source] = (history, missed_count, time_between_updts_historic, last_update_times) - async def update_neighbors(self, node, remove=False): + async def update_neighbors(self, une : UpdateNeighborEvent): + node, remove = await une.get_event_data() async with self._nodes_lock: if remove: self._nodes.pop(node, None) @@ -138,10 +144,7 @@ async def update_neighbors(self, node, remove=False): if not node in self._nodes: self._nodes.update({node : (deque(maxlen=self.MAX_HISTORIC_SIZE), 0, float('inf'), float('inf'))}) - async def get_evaluation_results(self): - return None - - async def 
evaluate(self): + async def _evaluate(self): if self._verbose: logging.info("Evaluating using speed-oriented strategy") if self._grace_rounds: # Grace rounds self._grace_rounds -= 1 @@ -228,4 +231,31 @@ async def evaluate(self): return nodes_below_th + async def get_evaluation_results(self): + nodes_to_discard = await self._evaluate() + for node_discarded in nodes_to_discard: + args = (node_discarded, False, True) + sac = ConnectivityCommand( + SACommandAction.DISCONNECT, + node_discarded, + SACommandPRIO.MEDIUM, + False, + CommunicationsManager.get_instance().disconnect, + *args + ) + await self.suggest_action(sac) + await self.notify_all_suggestions_done(RoundEndEvent) + + async def get_agent(self) -> str: + return "SATraining_SOSTP" + + async def register_sa_agent(self): + await SuggestionBuffer.get_instance().register_event_agents(RoundEndEvent, self) + + async def suggest_action(self, sac : SACommand): + await SuggestionBuffer.get_instance().register_suggestion(RoundEndEvent, self, sac) + + async def notify_all_suggestions_done(self, event_type): + await SuggestionBuffer.get_instance().notify_all_suggestions_done_for_agent(self, event_type) + \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py index 6616b9581..5e343a14c 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py @@ -7,10 +7,6 @@ class TrainingPolicy(SAModuleAgent): async def init(self, config): pass - @abstractmethod - async def update_neighbors(self, node, remove=False): - pass - @abstractmethod async def get_evaluation_results(self): pass diff --git a/nebula/core/situationalawareness/momentum.py b/nebula/core/situationalawareness/awareness/satraining/weightstrategy/momentum.py similarity index 100% 
rename from nebula/core/situationalawareness/momentum.py rename to nebula/core/situationalawareness/awareness/satraining/weightstrategy/momentum.py From a60336bbe76277fc899182ff4ba1c028dc0ebe5d Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 7 Apr 2025 12:38:38 +0200 Subject: [PATCH 164/233] fix owner missing SACommands --- .../awareness/sacommand.py | 16 ++++++++---- .../awareness/samodule.py | 4 +-- .../awareness/sanetwork/sanetwork.py | 26 ++++++++++++++++--- .../trainingpolicy/bpstrainingpolicy.py | 8 +++--- .../trainingpolicy/qdstrainingpolicy.py | 18 +++++++------ .../trainingpolicy/sostrainingpolicy.py | 8 +++--- 6 files changed, 56 insertions(+), 24 deletions(-) diff --git a/nebula/core/situationalawareness/awareness/sacommand.py b/nebula/core/situationalawareness/awareness/sacommand.py index 9fedc582b..533a10569 100644 --- a/nebula/core/situationalawareness/awareness/sacommand.py +++ b/nebula/core/situationalawareness/awareness/sacommand.py @@ -91,14 +91,15 @@ class ConnectivityCommand(SACommand): """Commands related to connectivity.""" def __init__( self, - action: SACommandAction, + action: SACommandAction, + owner : SAModuleAgent, target: str, priority: SACommandPRIO = SACommandPRIO.MEDIUM, parallelizable = False, action_function = None, *args ): - super().__init__(SACommandType.CONNECTIVITY, action, target, priority, parallelizable) + super().__init__(SACommandType.CONNECTIVITY, action, owner, target, priority, parallelizable) self._action_function = action_function self._args = args @@ -111,12 +112,14 @@ async def execute(self): else: self._action_function(*self._args) + #TODO repasar def conflicts_with(self, other: "ConnectivityCommand") -> bool: """Determines if two commands conflict with each other.""" if self._target == other._target: conflict_pairs = [ {SACommandAction.DISCONNECT, SACommandAction.RECONNECT}, - {SACommandAction.DISCONNECT, SACommandAction.MAINTAIN_CONNECTIONS} + {SACommandAction.DISCONNECT, 
SACommandAction.MAINTAIN_CONNECTIONS}, + {SACommandAction.DISCONNECT, SACommandAction.SEARCH_CONNECTIONS} ] return {self._action, other._action} in conflict_pairs return False @@ -125,12 +128,13 @@ class AggregationCommand(SACommand): """Commands related to data aggregation.""" def __init__( self, - action: SACommandAction, + action: SACommandAction, + owner : SAModuleAgent, target: dict, priority: SACommandPRIO = SACommandPRIO.MEDIUM, parallelizable = False, ): - super().__init__(SACommandType.CONNECTIVITY, action, target, priority, parallelizable) + super().__init__(SACommandType.CONNECTIVITY, action, owner, target, priority, parallelizable) async def execute(self): await self._update_command_state(SACommandState.EXECUTED) @@ -163,6 +167,8 @@ def factory_sa_command(sacommand_type, *config) -> SACommand: } cs = options.get(sacommand_type, None) + if cs is None: + raise ValueError(f"Unknown SACommand type: {sacommand_type}") return cs(*config) diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 46f01d2bd..89a31c4db 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -84,8 +84,8 @@ def is_additional_participant(self): async def _mobility_actions(self, ree : RoundEndEvent): logging.info("πŸ”„ Starting additional mobility actions...") - await self.san.sa_component_actions() - await self.sat.sa_component_actions() + asyncio.create_task(self.san.sa_component_actions()) + asyncio.create_task(self.sat.sa_component_actions()) """ ############################### diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index 8f7d4e7ab..01a8231b9 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -10,7 +10,7 @@ from 
nebula.core.network.communications import CommunicationsManager from nebula.core.situationalawareness.awareness.samodule import SAMComponent from nebula.core.situationalawareness.awareness.samoduleagent import SAModuleAgent -from nebula.core.situationalawareness.awareness.sacommand import SACommand, SACommandAction, SACommandPRIO, SACommandState +from nebula.core.situationalawareness.awareness.sacommand import SACommand, SACommandAction, SACommandPRIO, SACommandState, factory_sa_command from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer from typing import TYPE_CHECKING @@ -246,7 +246,27 @@ async def notify_all_suggestions_done(self, event_type): async def create_and_suggest_action(self, saca: SACommandAction, function : Callable, *args): if saca == SACommandAction.MAINTAIN_CONNECTIONS: - pass + sac = factory_sa_command( + "connectivity", + SACommandAction.MAINTAIN_CONNECTIONS, + "", + SACommandPRIO.MEDIUM, + False, + function, + None + ) + await self.suggest_action(sac) + await self.notify_all_suggestions_done(RoundEndEvent) elif saca == SACommandAction.SEARCH_CONNECTIONS: - pass + sac = factory_sa_command( + "connectivity", + SACommandAction.MAINTAIN_CONNECTIONS, + "", + SACommandPRIO.MEDIUM, + True, + function, + *args + ) + await self.suggest_action(sac) + await self.notify_all_suggestions_done(RoundEndEvent) \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py index 60d0110a4..576cae626 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py @@ -1,6 +1,6 @@ from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy from 
nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer -from nebula.core.situationalawareness.awareness.sacommand import SACommand, ConnectivityCommand, SACommandAction, SACommandPRIO +from nebula.core.situationalawareness.awareness.sacommand import SACommand, factory_sa_command, SACommandAction, SACommandPRIO from nebula.core.nebulaevents import RoundEndEvent class BPSTrainingPolicy(TrainingPolicy): @@ -12,8 +12,10 @@ async def init(self, config): await self.register_sa_agent() async def get_evaluation_results(self): - sac = ConnectivityCommand( - SACommandAction.MAINTAIN_CONNECTIONS, + sac = factory_sa_command( + "connectivity", + SACommandAction.MAINTAIN_CONNECTIONS, + self, "", SACommandPRIO.LOW, False, diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py index 7192ab1d2..56c7009f6 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py @@ -7,7 +7,7 @@ from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import AggregationEvent, UpdateNeighborEvent, RoundEndEvent from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer -from nebula.core.situationalawareness.awareness.sacommand import SACommand, ConnectivityCommand, SACommandAction, SACommandPRIO +from nebula.core.situationalawareness.awareness.sacommand import SACommand, SACommandAction, SACommandPRIO, factory_sa_command from nebula.core.network.communications import CommunicationsManager import math @@ -132,13 +132,15 @@ async def get_evaluation_results(self): if self._check_done: for node_discarded in self._evaluation_results: args = (node_discarded, False, True) - sac = ConnectivityCommand( - SACommandAction.DISCONNECT, - node_discarded, - 
SACommandPRIO.MEDIUM, - False, - CommunicationsManager.get_instance().disconnect, - *args + sac = factory_sa_command( + "connectivity", + SACommandAction.DISCONNECT, + self, + node_discarded, + SACommandPRIO.MEDIUM, + False, + CommunicationsManager.get_instance().disconnect, + *args ) await self.suggest_action(sac) await self.notify_all_suggestions_done(RoundEndEvent) diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py index 239859040..2e355e54c 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py @@ -5,7 +5,7 @@ from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import UpdateReceivedEvent, AggregationEvent, RoundStartEvent, UpdateNeighborEvent, RoundEndEvent from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer -from nebula.core.situationalawareness.awareness.sacommand import SACommand, ConnectivityCommand, SACommandAction, SACommandPRIO +from nebula.core.situationalawareness.awareness.sacommand import SACommand, SACommandAction, SACommandPRIO, factory_sa_command from nebula.core.network.communications import CommunicationsManager import time import asyncio @@ -235,8 +235,10 @@ async def get_evaluation_results(self): nodes_to_discard = await self._evaluate() for node_discarded in nodes_to_discard: args = (node_discarded, False, True) - sac = ConnectivityCommand( - SACommandAction.DISCONNECT, + sac = factory_sa_command( + "connectivity", + SACommandAction.DISCONNECT, + self, node_discarded, SACommandPRIO.MEDIUM, False, From dcf415d5589034f1b01fbc305a85e8c18c36bbc9 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 8 Apr 2025 11:27:21 +0200 Subject: [PATCH 165/233] morning update --- 
.../updatehandlers/dflupdatehandler.py | 1 - .../awareness/sacommand.py | 11 +++++--- .../awareness/samodule.py | 18 +++++++------ .../awareness/sanetwork/sanetwork.py | 27 +++++++++++++++---- 4 files changed, 39 insertions(+), 18 deletions(-) diff --git a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py index fde16d044..9e5241ca4 100644 --- a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py +++ b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py @@ -95,7 +95,6 @@ async def _check_updates_already_received(self): if (last_updt and node_storage[-1] and last_updt != node_storage[-1]) or (node_storage[-1] and not last_updt): self._sources_received.add(se) logging.info(f"Update already received from source: {se} | ({len(self._sources_received)}/{len(self._sources_expected)}) Updates received") - except: logging.error(f"ERROR: source expected: {se} | last_update None: {(True if not last_updt else False)}, last update storaged None: {(True if not node_storage[-1] else False)}") diff --git a/nebula/core/situationalawareness/awareness/sacommand.py b/nebula/core/situationalawareness/awareness/sacommand.py index 533a10569..97fbe3ef0 100644 --- a/nebula/core/situationalawareness/awareness/sacommand.py +++ b/nebula/core/situationalawareness/awareness/sacommand.py @@ -92,7 +92,7 @@ class ConnectivityCommand(SACommand): def __init__( self, action: SACommandAction, - owner : SAModuleAgent, + owner : "SAModuleAgent", target: str, priority: SACommandPRIO = SACommandPRIO.MEDIUM, parallelizable = False, @@ -112,24 +112,27 @@ async def execute(self): else: self._action_function(*self._args) - #TODO repasar def conflicts_with(self, other: "ConnectivityCommand") -> bool: """Determines if two commands conflict with each other.""" if self._target == other._target: + conflict_pairs = [ + {SACommandAction.DISCONNECT, SACommandAction.DISCONNECT}, + ] + return {self._action, other._action} in 
conflict_pairs + else: conflict_pairs = [ {SACommandAction.DISCONNECT, SACommandAction.RECONNECT}, {SACommandAction.DISCONNECT, SACommandAction.MAINTAIN_CONNECTIONS}, {SACommandAction.DISCONNECT, SACommandAction.SEARCH_CONNECTIONS} ] return {self._action, other._action} in conflict_pairs - return False class AggregationCommand(SACommand): """Commands related to data aggregation.""" def __init__( self, action: SACommandAction, - owner : SAModuleAgent, + owner : "SAModuleAgent", target: dict, priority: SACommandPRIO = SACommandPRIO.MEDIUM, parallelizable = False, diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 89a31c4db..db0d74015 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -3,22 +3,16 @@ import logging from nebula.addons.functions import print_msg_box from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer -from nebula.core.situationalawareness.awareness.sanetwork.sanetwork import SANetwork -from nebula.core.situationalawareness.awareness.satraining.satraining import SATraining from nebula.core.utils.locker import Locker from nebula.core.nebulaevents import RoundEndEvent from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import RoundEndEvent, AggregationEvent - from nebula.core.network.communications import CommunicationsManager from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.situationalawareness.nodemanager import NodeManager - -RESTRUCTURE_COOLDOWN = 5 - class SAMComponent(ABC): @abstractmethod async def init(self): @@ -44,8 +38,8 @@ def __init__( self._addr = addr self._topology = topology self._node_manager: NodeManager = nodemanager - self._situational_awareness_network = SANetwork(self, self._addr, self._topology) - self._situational_awareness_training = SATraining(self, self._addr, "qds", "fastreboot", 
verbose=True) + self._situational_awareness_network = None + self._situational_awareness_training = None self._restructure_process_lock = Locker(name="restructure_process_lock") self._restructure_cooldown = 0 self._arbitrator_notification = asyncio.Event() @@ -74,6 +68,10 @@ def sb(self): async def init(self): + from nebula.core.situationalawareness.awareness.sanetwork.sanetwork import SANetwork + from nebula.core.situationalawareness.awareness.satraining.satraining import SATraining + self._situational_awareness_network = SANetwork(self, self._addr, self._topology, True) + self._situational_awareness_training = SATraining(self, self._addr, "qds", "fastreboot", verbose=True) await EventManager.get_instance().subscribe_node_event(RoundEndEvent, self._mobility_actions) await self.san.init() await self.sat.init() @@ -116,4 +114,8 @@ def need_more_neighbors(self): def get_actions(self): return self.san.get_actions() + """ ############################### + # ARBITRATION # + ############################### + """ diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index 01a8231b9..96ec2b900 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -167,8 +167,9 @@ async def _analize_topology_robustness(self): logging.info("πŸ”„ Analizing node network robustness...") if not self._restructure_process_lock.locked(): if not await self.neighbors_left(): - if self._verbose: logging.info("No Neighbors left | reconnecting with Federation") + if self._verbose: logging.info("No Neighbors left | reconnecting with Federation") #await self.reconnect_to_federation() + await self.sana.create_and_suggest_action(SACommandAction.RECONNECT, self.reconnect_to_federation, None) elif self.np.need_more_neighbors() and self._restructure_available(): if self._verbose: logging.info("Insufficient Robustness | 
Upgrading robustness | Searching for more connections") self._update_restructure_cooldown() @@ -178,9 +179,11 @@ async def _analize_topology_robustness(self): if self._verbose: logging.info("All possible neighbors using nodes known are restricted...") else: pass + await self.sana.create_and_suggest_action(SACommandAction.SEARCH_CONNECTIONS, self.upgrade_connection_robustness, possible_neighbors) # asyncio.create_task(self.upgrade_connection_robustness(possible_neighbors)) else: if self._verbose: logging.info("Sufficient Robustness | no actions required") + await self.sana.create_and_suggest_action(SACommandAction.MAINTAIN_CONNECTIONS) else: if self._verbose: logging.info("❗️ Reestructure/Reconnecting process already running...") @@ -244,11 +247,12 @@ async def suggest_action(self, sac : SACommand): async def notify_all_suggestions_done(self, event_type): await SuggestionBuffer.get_instance().notify_all_suggestions_done_for_agent(self, event_type) - async def create_and_suggest_action(self, saca: SACommandAction, function : Callable, *args): + async def create_and_suggest_action(self, saca: SACommandAction, function: Callable = None, *args): if saca == SACommandAction.MAINTAIN_CONNECTIONS: sac = factory_sa_command( "connectivity", - SACommandAction.MAINTAIN_CONNECTIONS, + SACommandAction.MAINTAIN_CONNECTIONS, + self, "", SACommandPRIO.MEDIUM, False, @@ -260,7 +264,8 @@ async def create_and_suggest_action(self, saca: SACommandAction, function : Call elif saca == SACommandAction.SEARCH_CONNECTIONS: sac = factory_sa_command( "connectivity", - SACommandAction.MAINTAIN_CONNECTIONS, + SACommandAction.SEARCH_CONNECTIONS, + self, "", SACommandPRIO.MEDIUM, True, @@ -269,4 +274,16 @@ async def create_and_suggest_action(self, saca: SACommandAction, function : Call ) await self.suggest_action(sac) await self.notify_all_suggestions_done(RoundEndEvent) - \ No newline at end of file + elif saca == SACommandAction.RECONNECT: + sac = factory_sa_command( + "connectivity", + 
SACommandAction.RECONNECT, + self, + "", + SACommandPRIO.MEDIUM, + True, + None, + *args + ) + await self.suggest_action(sac) + await self.notify_all_suggestions_done(RoundEndEvent) \ No newline at end of file From 17ade4096e79857b481a08c2e0ed8032e07c192d Mon Sep 17 00:00:00 2001 From: FerTV Date: Tue, 8 Apr 2025 12:02:32 +0200 Subject: [PATCH 166/233] removed unused parameters --- nebula/controller.py | 1 - nebula/frontend/app.py | 1 + nebula/frontend/database.py | 20 +------------------- nebula/scenarios.py | 12 ------------ 4 files changed, 2 insertions(+), 32 deletions(-) diff --git a/nebula/controller.py b/nebula/controller.py index d999a6181..316aeea20 100755 --- a/nebula/controller.py +++ b/nebula/controller.py @@ -371,7 +371,6 @@ async def get_scenario_by_name( ) ] ): - logging.info("[FER] controller") from nebula.frontend.database import get_scenario_by_name try: diff --git a/nebula/frontend/app.py b/nebula/frontend/app.py index 6c2e1433f..e6eef0c19 100755 --- a/nebula/frontend/app.py +++ b/nebula/frontend/app.py @@ -256,6 +256,7 @@ async def controller_post(url, data=None): if response.status == 200: return await response.json() else: + logging.info(f"[FER] POST request to {url} data {data} failed with status code {response.status}") raise HTTPException(status_code=response.status, detail="Error posting data") diff --git a/nebula/frontend/database.py b/nebula/frontend/database.py index 2019544cb..a24179776 100755 --- a/nebula/frontend/database.py +++ b/nebula/frontend/database.py @@ -145,9 +145,6 @@ async def initialize_databases(databases_dir): poisoned_noise_percent TEXT, attack_params TEXT, with_reputation TEXT, - is_dynamic_topology TEXT, - is_dynamic_aggregation TEXT, - target_aggregation TEXT, random_geo TEXT, latitude TEXT, longitude TEXT, @@ -200,9 +197,6 @@ async def initialize_databases(databases_dir): "poisoned_noise_percent": "TEXT", "attack_params": "TEXT", "with_reputation": "TEXT", - "is_dynamic_topology": "TEXT", - 
"is_dynamic_aggregation": "TEXT", - "target_aggregation": "TEXT", "random_geo": "TEXT", "latitude": "TEXT", "longitude": "TEXT", @@ -605,9 +599,6 @@ def scenario_update_record( poisoned_noise_percent, attack_params, with_reputation, - is_dynamic_topology, - is_dynamic_aggregation, - target_aggregation, random_geo, latitude, longitude, @@ -623,7 +614,7 @@ def scenario_update_record( role, username ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? + ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ); """ _c.execute(insert_command, ( @@ -660,9 +651,6 @@ def scenario_update_record( scenario.poisoned_noise_percent, json.dumps(scenario.attack_params), scenario.with_reputation, - scenario.is_dynamic_topology, - scenario.is_dynamic_aggregation, - scenario.target_aggregation, scenario.random_geo, scenario.latitude, scenario.longitude, @@ -713,9 +701,6 @@ def scenario_update_record( poisoned_noise_percent = ?, attack_params = ?, with_reputation = ?, - is_dynamic_topology = ?, - is_dynamic_aggregation = ?, - target_aggregation = ?, random_geo = ?, latitude = ?, longitude = ?, @@ -765,9 +750,6 @@ def scenario_update_record( scenario.poisoned_noise_percent, json.dumps(scenario.attack_params), scenario.with_reputation, - scenario.is_dynamic_topology, - scenario.is_dynamic_aggregation, - scenario.target_aggregation, scenario.random_geo, scenario.latitude, scenario.longitude, diff --git a/nebula/scenarios.py b/nebula/scenarios.py index 17e7755bc..788c21222 100644 --- a/nebula/scenarios.py +++ b/nebula/scenarios.py @@ -70,9 +70,6 @@ def __init__( weight_model_similarity, weight_num_messages, weight_fraction_params_changed, - # is_dynamic_topology, - # is_dynamic_aggregation, - # target_aggregation, random_geo, latitude, longitude, @@ -129,9 +126,6 @@ def __init__( 
weight_model_similarity (float): Weight of model similarity. weight_num_messages (float): Weight of number of messages. weight_fraction_params_changed (float): Weight of fraction of parameters changed. - # is_dynamic_topology (bool): Indicator if topology is dynamic. - # is_dynamic_aggregation (bool): Indicator if aggregation is dynamic. - # target_aggregation (str): Target aggregation method. random_geo (bool): Indicator if random geo is used. latitude (float): Latitude for mobility. longitude (float): Longitude for mobility. @@ -181,9 +175,6 @@ def __init__( self.weight_model_similarity = weight_model_similarity self.weight_num_messages = weight_num_messages self.weight_fraction_params_changed = weight_fraction_params_changed - # self.is_dynamic_topology = is_dynamic_topology - # self.is_dynamic_aggregation = is_dynamic_aggregation - # self.target_aggregation = target_aggregation self.random_geo = random_geo self.latitude = latitude self.longitude = longitude @@ -395,9 +386,6 @@ def __init__(self, scenario, user=None): participant_config["adversarial_args"]["attacks"] = node_config["attacks"] participant_config["adversarial_args"]["attack_params"] = node_config["attack_params"] participant_config["defense_args"]["with_reputation"] = node_config["with_reputation"] - # participant_config["defense_args"]["is_dynamic_topology"] = self.scenario.is_dynamic_topology - # participant_config["defense_args"]["is_dynamic_aggregation"] = self.scenario.is_dynamic_aggregation - # participant_config["defense_args"]["target_aggregation"] = self.scenario.target_aggregation participant_config["defense_args"]["reputation_metrics"] = self.scenario.reputation_metrics participant_config["defense_args"]["initial_reputation"] = self.scenario.initial_reputation participant_config["defense_args"]["weighting_factor"] = self.scenario.weighting_factor From d3f09f1c372bee5e04e84c31b796c7f5c59cda46 Mon Sep 17 00:00:00 2001 From: FerTV Date: Tue, 8 Apr 2025 12:07:40 +0200 Subject: [PATCH 
167/233] databases removed of frontend docker container --- nebula/controller.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nebula/controller.py b/nebula/controller.py index 316aeea20..ba7eb1b37 100755 --- a/nebula/controller.py +++ b/nebula/controller.py @@ -1143,7 +1143,7 @@ def run_frontend(self): f"{self.root_path}:/nebula", "/var/run/docker.sock:/var/run/docker.sock", f"{self.root_path}/nebula/frontend/config/nebula:/etc/nginx/sites-available/default", - f"{self.databases_dir}:/nebula/nebula/frontend/databases", + # f"{self.databases_dir}:/nebula/nebula/frontend/databases", ], extra_hosts={"host.docker.internal": "host-gateway"}, port_bindings={80: self.frontend_port, 8080: self.statistics_port}, From 8cac0b2a50b0f05397a9c9320542a0572a443619 Mon Sep 17 00:00:00 2001 From: FerTV Date: Tue, 8 Apr 2025 16:55:40 +0200 Subject: [PATCH 168/233] fix monitor page and node related endpoints --- nebula/controller.py | 16 +++++------ nebula/frontend/app.py | 64 +++++++++++++++++++++++------------------- 2 files changed, 42 insertions(+), 38 deletions(-) diff --git a/nebula/controller.py b/nebula/controller.py index ba7eb1b37..7d3d7d11d 100755 --- a/nebula/controller.py +++ b/nebula/controller.py @@ -279,22 +279,22 @@ async def list_nodes_by_scenario_name( logging.error(f"Error obtaining nodes: {e}") raise HTTPException(status_code=500, detail="Internal server error") - return {"nodes": nodes} + return nodes -@app.post("/nodes/update/") +@app.post("/nodes/update") async def update_nodes( node_uid: str = Body(..., embed=True), node_idx: str = Body(..., embed=True), node_ip: str = Body(..., embed=True), node_port: str = Body(..., embed=True), - node_role: int = Body(..., embed=True), + node_role: str = Body(..., embed=True), node_neighbors: str = Body(..., embed=True), node_latitude: str = Body(..., embed=True), node_longitude: str = Body(..., embed=True), node_timestamp: str = Body(..., embed=True), node_federation: str = Body(..., embed=True), - 
node_round_number: str = Body(..., embed=True), + node_round: str = Body(..., embed=True), node_scenario_name: str = Body(..., embed=True), node_run_hash: str = Body(..., embed=True) ): @@ -302,9 +302,8 @@ async def update_nodes( Controller endpoint to update nodes. """ from nebula.frontend.database import update_node_record - try: - update_node_record( + await update_node_record( node_uid, node_idx, node_ip, @@ -315,7 +314,7 @@ async def update_nodes( node_longitude, node_timestamp, node_federation, - node_round_number, + node_round, node_scenario_name, node_run_hash, ) @@ -379,7 +378,7 @@ async def get_scenario_by_name( logging.error(f"Error obtaining scenario {scenario_name}: {e}") raise HTTPException(status_code=500, detail="Internal server error") - return {"scenario": scenario} + return scenario @app.get("/notes") @@ -1143,7 +1142,6 @@ def run_frontend(self): f"{self.root_path}:/nebula", "/var/run/docker.sock:/var/run/docker.sock", f"{self.root_path}/nebula/frontend/config/nebula:/etc/nginx/sites-available/default", - # f"{self.databases_dir}:/nebula/nebula/frontend/databases", ], extra_hosts={"host.docker.internal": "host-gateway"}, port_bindings={80: self.frontend_port, 8080: self.statistics_port}, diff --git a/nebula/frontend/app.py b/nebula/frontend/app.py index e6eef0c19..0c9a05118 100755 --- a/nebula/frontend/app.py +++ b/nebula/frontend/app.py @@ -320,19 +320,19 @@ async def list_nodes_by_scenario_name(scenario_name): async def update_node_record(uid, idx, ip, port, role, neighbors, latitude, longitude, timestamp, federation, round_number, scenario_name, run_hash): url = f"http://{settings.controller_host}:{settings.controller_port}/nodes/update" data = { - "uid": uid, - "idx": idx, - "ip": ip, - "port": port, - "role": role, - "neighbors": neighbors, - "latitude": latitude, - "longitude": longitude, - "timestamp": timestamp, - "federation": federation, - "round": round_number, - "scenario_name": scenario_name, - "run_hash": run_hash, + "node_uid": 
uid, + "node_idx": idx, + "node_ip": ip, + "node_port": port, + "node_role": role, + "node_neighbors": neighbors, + "node_latitude": latitude, + "node_longitude": longitude, + "node_timestamp": timestamp, + "node_federation": federation, + "node_round": round_number, + "node_scenario_name": scenario_name, + "node_run_hash": run_hash, } await controller_post(url, data) @@ -717,6 +717,7 @@ async def nebula_dashboard_monitor(scenario_name: str, request: Request, session scenario = await get_scenario_by_name(scenario_name) if scenario: nodes_list = await list_nodes_by_scenario_name(scenario_name) + logging.info("[FER] nodes_list: %s", nodes_list) if nodes_list: nodes_config = [] nodes_status = [] @@ -855,21 +856,24 @@ async def nebula_update_node(scenario_name: str, request: Request): config = await request.json() timestamp = datetime.datetime.now() # Update the node in database - await update_node_record( - str(config["device_args"]["uid"]), - str(config["device_args"]["idx"]), - str(config["network_args"]["ip"]), - str(config["network_args"]["port"]), - str(config["device_args"]["role"]), - str(config["network_args"]["neighbors"]), - str(config["mobility_args"]["latitude"]), - str(config["mobility_args"]["longitude"]), - str(timestamp), - str(config["scenario_args"]["federation"]), - str(config["federation_args"]["round"]), - str(config["scenario_args"]["name"]), - str(config["tracking_args"]["run_hash"]), - ) + try: + await update_node_record( + str(config["device_args"]["uid"]), + str(config["device_args"]["idx"]), + str(config["network_args"]["ip"]), + str(config["network_args"]["port"]), + str(config["device_args"]["role"]), + str(config["network_args"]["neighbors"]), + str(config["mobility_args"]["latitude"]), + str(config["mobility_args"]["longitude"]), + str(timestamp), + str(config["scenario_args"]["federation"]), + str(config["federation_args"]["round"]), + str(config["scenario_args"]["name"]), + str(config["tracking_args"]["run_hash"]), + ) + except 
Exception as e: + logging.info("[FER] Error updating node record") neighbors_distance = config["mobility_args"]["neighbors_distance"] @@ -1325,7 +1329,7 @@ async def nebula_dashboard_deployment(request: Request, session: dict = Depends( # return nodes, attack_matrix -import math +# import math # def mobility_assign(nodes, mobile_participants_percent): @@ -1365,6 +1369,8 @@ async def node_stopped(scenario_name: str, request: Request): if str(node[1]) not in map(str, user_data.nodes_finished): finished = False + logging.info(f"[FER] Finished nodes: {user_data.nodes_finished} Finished: {finished}") + if finished: await stop_scenario(scenario_name, user) user_data.nodes_finished.clear() From ca04b128814ce92e57ef9b19f335eed7019a49c3 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 8 Apr 2025 17:11:15 +0200 Subject: [PATCH 169/233] fix samodule mediate function --- .../awareness/sacommand.py | 12 +++ .../awareness/samodule.py | 75 +++++++++++++++++-- .../awareness/sanetwork/sanetwork.py | 6 +- .../awareness/suggestionbuffer.py | 8 +- .../core/situationalawareness/nodemanager.py | 2 +- 5 files changed, 88 insertions(+), 15 deletions(-) diff --git a/nebula/core/situationalawareness/awareness/sacommand.py b/nebula/core/situationalawareness/awareness/sacommand.py index 97fbe3ef0..36c1ebbd2 100644 --- a/nebula/core/situationalawareness/awareness/sacommand.py +++ b/nebula/core/situationalawareness/awareness/sacommand.py @@ -1,6 +1,7 @@ from abc import abstractmethod from enum import Enum import asyncio +import logging from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.situationalawareness.awareness.samoduleagent import SAModuleAgent @@ -63,10 +64,19 @@ async def conflicts_with(self, other: "SACommand") -> bool: async def discard_command(self): await self._update_command_state(SACommandState.DISCARDED) + + def got_higher_priority_than(self, other_prio: SACommandPRIO): + return self._priority > other_prio + + def get_prio(self): + return self._priority 
def get_owner(self): return self._owner.get_agent() + def get_action(self) -> SACommandAction: + return self._action + async def _update_command_state(self, sacs : SACommandState): self._state = sacs if not self._state_future.done(): @@ -115,11 +125,13 @@ async def execute(self): def conflicts_with(self, other: "ConnectivityCommand") -> bool: """Determines if two commands conflict with each other.""" if self._target == other._target: + logging.info(f"Evaluation posible conflict | targets {self._target}, {other._target}") conflict_pairs = [ {SACommandAction.DISCONNECT, SACommandAction.DISCONNECT}, ] return {self._action, other._action} in conflict_pairs else: + logging.info(f"Evaluation posible conflict | actions {self._action}, {other._action}") conflict_pairs = [ {SACommandAction.DISCONNECT, SACommandAction.RECONNECT}, {SACommandAction.DISCONNECT, SACommandAction.MAINTAIN_CONNECTIONS}, diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index db0d74015..d1c0f9e96 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -3,6 +3,7 @@ import logging from nebula.addons.functions import print_msg_box from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer +from nebula.core.situationalawareness.awareness.sacommand import SACommand from nebula.core.utils.locker import Locker from nebula.core.nebulaevents import RoundEndEvent from nebula.core.eventmanager import EventManager @@ -28,6 +29,7 @@ def __init__( nodemanager, addr, topology, + verbose = False, ): print_msg_box( msg=f"Starting Situational Awareness module...", @@ -45,6 +47,7 @@ def __init__( self._arbitrator_notification = asyncio.Event() self._suggestion_buffer = SuggestionBuffer(self._arbitrator_notification, verbose=True) self._communciation_manager = CommunicationsManager.get_instance() + self._verbose = verbose @property def 
nm(self): @@ -72,7 +75,8 @@ async def init(self): from nebula.core.situationalawareness.awareness.satraining.satraining import SATraining self._situational_awareness_network = SANetwork(self, self._addr, self._topology, True) self._situational_awareness_training = SATraining(self, self._addr, "qds", "fastreboot", verbose=True) - await EventManager.get_instance().subscribe_node_event(RoundEndEvent, self._mobility_actions) + await EventManager.get_instance().subscribe_node_event(RoundEndEvent, self._process_round_end_event) + await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) await self.san.init() await self.sat.init() @@ -80,12 +84,6 @@ async def init(self): def is_additional_participant(self): return self.nm.is_additional_participant() - async def _mobility_actions(self, ree : RoundEndEvent): - logging.info("πŸ”„ Starting additional mobility actions...") - asyncio.create_task(self.san.sa_component_actions()) - asyncio.create_task(self.sat.sa_component_actions()) - - """ ############################### # REESTRUCTURE TOPOLOGY # ############################### @@ -115,7 +113,68 @@ def get_actions(self): return self.san.get_actions() """ ############################### - # ARBITRATION # + # ARBITRATION # ############################### """ + async def _tie_breaker(c1: SACommand, c2: SACommand): + return True + + async def _process_round_end_event(self, ree : RoundEndEvent): + logging.info("πŸ”„ Arbitration | Round End Event...") + asyncio.create_task(self.san.sa_component_actions()) + asyncio.create_task(self.sat.sa_component_actions()) + valid_commands = await self._mediate_suggestions(RoundEndEvent) + + # Execute SACommand selected + for cmd in valid_commands: + if cmd.is_parallelizable(): + asyncio.create_task(cmd.execute()) + else: + await cmd.execute() + + async def _process_aggregation_event(self, age : AggregationEvent): + logging.info("πŸ”„ Arbitration | Aggregation Event...") + aggregation_command: 
SACommand = (await self._mediate_suggestions(AggregationEvent))[0] + final_updates = await aggregation_command.execute() + age.update_updates(final_updates) + + async def _mediate_suggestions(self, event_type): + if self._verbose: logging.info("Waiting for all suggestions done") + await self.sb.set_event_waited(event_type) + self._arbitrator_notification.wait() + suggestions = await self.sb.get_suggestions(event_type) + self._arbitrator_notification.clear() + if self._verbose: logging.info("Starting mediation, suggestions received") + if self._verbose: logging.info(f"Number of suggestions received: {len(suggestions)}") + + valid_commands: list[SACommand] = [] + + for cmd in suggestions: + has_conflict = False + to_remove: list[SACommand] = [] + + for other in valid_commands: + if await cmd.conflicts_with(other): + if self._verbose: logging.info(f"Conflict detected | between -- {await cmd.get_owner().get_agent()} and {await other.get_owner().get_agent()} --") + if self._verbose: logging.info(f"Action in conflict ({cmd.get_action()}, {other.get_action()})") + if cmd.got_higher_priority_than(other.get_prio()): + to_remove.append(other) + elif cmd.get_prio() == other.get_prio(): + if await self._tie_breaker(cmd, other): + to_remove.append(other) + else: + has_conflict = True + break + else: + has_conflict = True + break + + if not has_conflict: + for r in to_remove: + await r.discard_command() + valid_commands.remove(r) + valid_commands.append(cmd) + + return valid_commands + diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index 96ec2b900..064afc19e 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -254,7 +254,7 @@ async def create_and_suggest_action(self, saca: SACommandAction, function: Calla SACommandAction.MAINTAIN_CONNECTIONS, self, "", - SACommandPRIO.MEDIUM, + 
SACommandPRIO.HIGH, False, function, None @@ -267,7 +267,7 @@ async def create_and_suggest_action(self, saca: SACommandAction, function: Calla SACommandAction.SEARCH_CONNECTIONS, self, "", - SACommandPRIO.MEDIUM, + SACommandPRIO.HIGH, True, function, *args @@ -280,7 +280,7 @@ async def create_and_suggest_action(self, saca: SACommandAction, function: Calla SACommandAction.RECONNECT, self, "", - SACommandPRIO.MEDIUM, + SACommandPRIO.HIGH, True, None, *args diff --git a/nebula/core/situationalawareness/awareness/suggestionbuffer.py b/nebula/core/situationalawareness/awareness/suggestionbuffer.py index c564d1ed6..52d353fc2 100644 --- a/nebula/core/situationalawareness/awareness/suggestionbuffer.py +++ b/nebula/core/situationalawareness/awareness/suggestionbuffer.py @@ -5,6 +5,7 @@ from nebula.core.situationalawareness.awareness.sacommand import SACommand from nebula.core.nebulaevents import NodeEvent, RoundEndEvent, AggregationEvent from collections import defaultdict +from typing import Type class SuggestionBuffer(): _instance = None @@ -28,11 +29,11 @@ def __init__(self, arbitrator_notification : asyncio.Event, verbose): self._arbitrator_notification = arbitrator_notification self._arbitrator_notification_lock = Locker("arbitrator_notification_lock", async_lock=True) self._verbose = verbose - self._buffer : dict[NodeEvent, list[SACommand]] = defaultdict(list) # {event: [suggestion]} + self._buffer : dict[Type[NodeEvent], list[SACommand]] = defaultdict(list) # {event: [suggestion]} self._suggestion_buffer_lock = Locker("suggestion_buffer_lock", async_lock=True) self._expected_agents = defaultdict(set) # {event: {agents}} self._expected_agents_lock = Locker("expected_agents_lock", async_lock=True) - self._event_notifications : dict[NodeEvent, list[tuple[SAModuleAgent, asyncio.Event]]] = {} + self._event_notifications : dict[Type[NodeEvent], list[tuple[SAModuleAgent, asyncio.Event]]] = {} self._event_waited = None async def register_event_agents(self, event_type, agent: 
SAModuleAgent): @@ -56,8 +57,9 @@ async def register_suggestion(self, event_type, agent: SAModuleAgent, suggestion async def set_event_waited(self, event_type): """Registers event to be waited""" if not self._event_waited: - if self._verbose: logging.info(f"Set notification when all suggestiones are being received for event: {event_type}") + if self._verbose: logging.info(f"Set notification when all suggestions are being received for event: {event_type}") self._event_waited = event_type + await self._notify_arbitrator(event_type) async def notify_all_suggestions_done_for_agent(self, saa : SAModuleAgent, event_type): """SA Agent notification that has registered all the suggestions for event_type.""" diff --git a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py index d5e87a6b7..190a6dc4b 100644 --- a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -51,7 +51,7 @@ def __init__( self._desc_done = False #TODO remove - self._situational_awareness_module = SAModule(self, self.engine.addr, topology) + self._situational_awareness_module = SAModule(self, self.engine.addr, topology, True) self._verbose = verbose @property From f6c374d6f84ebbfbd2b1f233886084a2f199539b Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 9 Apr 2025 16:32:57 +0200 Subject: [PATCH 170/233] fix sa module arbitatrion --- .../awareness/sacommand.py | 13 +++---- .../awareness/samodule.py | 36 +++++++++++-------- .../awareness/suggestionbuffer.py | 31 ++++++++-------- 3 files changed, 43 insertions(+), 37 deletions(-) diff --git a/nebula/core/situationalawareness/awareness/sacommand.py b/nebula/core/situationalawareness/awareness/sacommand.py index 36c1ebbd2..7e4929316 100644 --- a/nebula/core/situationalawareness/awareness/sacommand.py +++ b/nebula/core/situationalawareness/awareness/sacommand.py @@ -1,7 +1,6 @@ from abc import abstractmethod from enum import Enum import asyncio -import 
logging from typing import TYPE_CHECKING if TYPE_CHECKING: from nebula.core.situationalawareness.awareness.samoduleagent import SAModuleAgent @@ -66,13 +65,13 @@ async def discard_command(self): await self._update_command_state(SACommandState.DISCARDED) def got_higher_priority_than(self, other_prio: SACommandPRIO): - return self._priority > other_prio + return self._priority.value > other_prio.value def get_prio(self): return self._priority - def get_owner(self): - return self._owner.get_agent() + async def get_owner(self): + return await self._owner.get_agent() def get_action(self) -> SACommandAction: return self._action @@ -122,16 +121,14 @@ async def execute(self): else: self._action_function(*self._args) - def conflicts_with(self, other: "ConnectivityCommand") -> bool: + async def conflicts_with(self, other: "ConnectivityCommand") -> bool: """Determines if two commands conflict with each other.""" if self._target == other._target: - logging.info(f"Evaluation posible conflict | targets {self._target}, {other._target}") conflict_pairs = [ {SACommandAction.DISCONNECT, SACommandAction.DISCONNECT}, ] return {self._action, other._action} in conflict_pairs else: - logging.info(f"Evaluation posible conflict | actions {self._action}, {other._action}") conflict_pairs = [ {SACommandAction.DISCONNECT, SACommandAction.RECONNECT}, {SACommandAction.DISCONNECT, SACommandAction.MAINTAIN_CONNECTIONS}, @@ -155,7 +152,7 @@ async def execute(self): await self._update_command_state(SACommandState.EXECUTED) return self._target - def conflicts_with(self, other: "AggregationCommand") -> bool: + async def conflicts_with(self, other: "AggregationCommand") -> bool: """Determines if two commands conflict with each other.""" topologic_conflict = False weight_conflict = False diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index d1c0f9e96..41f87e4fd 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py 
+++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -69,18 +69,16 @@ def cm(self): def sb(self): return self._suggestion_buffer - async def init(self): from nebula.core.situationalawareness.awareness.sanetwork.sanetwork import SANetwork from nebula.core.situationalawareness.awareness.satraining.satraining import SATraining - self._situational_awareness_network = SANetwork(self, self._addr, self._topology, True) + self._situational_awareness_network = SANetwork(self, self._addr, self._topology, verbose=True) self._situational_awareness_training = SATraining(self, self._addr, "qds", "fastreboot", verbose=True) - await EventManager.get_instance().subscribe_node_event(RoundEndEvent, self._process_round_end_event) - await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) await self.san.init() await self.sat.init() + await EventManager.get_instance().subscribe_node_event(RoundEndEvent, self._process_round_end_event) + await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) - def is_additional_participant(self): return self.nm.is_additional_participant() @@ -124,39 +122,46 @@ async def _process_round_end_event(self, ree : RoundEndEvent): logging.info("πŸ”„ Arbitration | Round End Event...") asyncio.create_task(self.san.sa_component_actions()) asyncio.create_task(self.sat.sa_component_actions()) - valid_commands = await self._mediate_suggestions(RoundEndEvent) + valid_commands = await self._arbitatrion_suggestions(RoundEndEvent) # Execute SACommand selected for cmd in valid_commands: if cmd.is_parallelizable(): + logging.info(f"going to execute parallelizable action: {cmd.get_action()}") asyncio.create_task(cmd.execute()) else: await cmd.execute() async def _process_aggregation_event(self, age : AggregationEvent): logging.info("πŸ”„ Arbitration | Aggregation Event...") - aggregation_command: SACommand = (await self._mediate_suggestions(AggregationEvent))[0] - 
final_updates = await aggregation_command.execute() - age.update_updates(final_updates) + aggregation_command = await self._arbitatrion_suggestions(AggregationEvent) + if len(aggregation_command): + if self._verbose: logging.info(f"Aggregation event resolved. SA Agente that suggest action: {await aggregation_command[0].get_owner}") + final_updates = await aggregation_command[0].execute() + age.update_updates(final_updates) - async def _mediate_suggestions(self, event_type): + async def _arbitatrion_suggestions(self, event_type): if self._verbose: logging.info("Waiting for all suggestions done") await self.sb.set_event_waited(event_type) - self._arbitrator_notification.wait() + await self._arbitrator_notification.wait() + logging.info("waiting released") suggestions = await self.sb.get_suggestions(event_type) self._arbitrator_notification.clear() - if self._verbose: logging.info("Starting mediation, suggestions received") - if self._verbose: logging.info(f"Number of suggestions received: {len(suggestions)}") + if not len(suggestions): + if self._verbose: logging.info("No suggestions for this event | Arbitatrion not required") + return [] + + if self._verbose: logging.info(f"Starting arbitatrion | Number of suggestions received: {len(suggestions)}") valid_commands: list[SACommand] = [] - for cmd in suggestions: + for agent, cmd in suggestions: has_conflict = False to_remove: list[SACommand] = [] for other in valid_commands: if await cmd.conflicts_with(other): - if self._verbose: logging.info(f"Conflict detected | between -- {await cmd.get_owner().get_agent()} and {await other.get_owner().get_agent()} --") + if self._verbose: logging.info(f"Conflict detected between -- {await cmd.get_owner()} and {await other.get_owner()} --") if self._verbose: logging.info(f"Action in conflict ({cmd.get_action()}, {other.get_action()})") if cmd.got_higher_priority_than(other.get_prio()): to_remove.append(other) @@ -176,5 +181,6 @@ async def _mediate_suggestions(self, event_type): 
valid_commands.remove(r) valid_commands.append(cmd) + logging.info("Arbitatrion finished") return valid_commands diff --git a/nebula/core/situationalawareness/awareness/suggestionbuffer.py b/nebula/core/situationalawareness/awareness/suggestionbuffer.py index 52d353fc2..3639f875a 100644 --- a/nebula/core/situationalawareness/awareness/suggestionbuffer.py +++ b/nebula/core/situationalawareness/awareness/suggestionbuffer.py @@ -29,21 +29,24 @@ def __init__(self, arbitrator_notification : asyncio.Event, verbose): self._arbitrator_notification = arbitrator_notification self._arbitrator_notification_lock = Locker("arbitrator_notification_lock", async_lock=True) self._verbose = verbose - self._buffer : dict[Type[NodeEvent], list[SACommand]] = defaultdict(list) # {event: [suggestion]} + self._buffer : dict[Type[NodeEvent], list[tuple[SAModuleAgent, SACommand]]] = defaultdict(list) self._suggestion_buffer_lock = Locker("suggestion_buffer_lock", async_lock=True) - self._expected_agents = defaultdict(set) # {event: {agents}} + self._expected_agents: dict[Type[NodeEvent] ,list[SAModuleAgent]] = defaultdict(list) self._expected_agents_lock = Locker("expected_agents_lock", async_lock=True) - self._event_notifications : dict[Type[NodeEvent], list[tuple[SAModuleAgent, asyncio.Event]]] = {} + self._event_notifications : dict[Type[NodeEvent], list[tuple[SAModuleAgent, asyncio.Event]]] = defaultdict(list) self._event_waited = None async def register_event_agents(self, event_type, agent: SAModuleAgent): """Registers expected agents for a given event.""" async with self._expected_agents_lock: if self._verbose: - logging.info(f"Registering SA Agent: {await agent.get_agent()} for event: {event_type}") + logging.info(f"Registering SA Agent: {await agent.get_agent()} for event: {event_type. 
__name__}") + if event_type not in self._event_notifications: - self._event_notifications[event_type] = [] - self._expected_agents[event_type].add(agent) + self._event_notifications[event_type] = [] + + self._expected_agents[event_type].append(agent) + existing_agents = {a for a, _ in self._event_notifications[event_type]} if agent not in existing_agents: self._event_notifications[event_type].append((agent, asyncio.Event())) @@ -51,13 +54,13 @@ async def register_event_agents(self, event_type, agent: SAModuleAgent): async def register_suggestion(self, event_type, agent: SAModuleAgent, suggestion: SACommand): """Registers a suggestion from an agent for a specific event.""" async with self._suggestion_buffer_lock: - if self._verbose: logging.info(f"Registering Suggestion from SA Agent: {await agent.get_agent()} for event: {event_type}") + if self._verbose: logging.info(f"Registering Suggestion from SA Agent: {await agent.get_agent()} for event: {event_type. __name__}") self._buffer[event_type].append((agent, suggestion)) async def set_event_waited(self, event_type): """Registers event to be waited""" if not self._event_waited: - if self._verbose: logging.info(f"Set notification when all suggestions are being received for event: {event_type}") + if self._verbose: logging.info(f"Set notification when all suggestions are being received for event: {event_type. __name__}") self._event_waited = event_type await self._notify_arbitrator(event_type) @@ -70,11 +73,11 @@ async def notify_all_suggestions_done_for_agent(self, saa : SAModuleAgent, event event.set() agent_found = True if self._verbose: - logging.info(f"SA Agent: {await saa.get_agent()} notifies all suggestions registered for event: {event_type}") + logging.info(f"SA Agent: {await saa.get_agent()} notifies all suggestions registered for event: {event_type. 
__name__}") break if not agent_found and self._verbose: - logging.error(f"SAModuleAgent: {await saa.get_agent()} not found on notifications awaited for event {event_type}") - await self._notify_arbitrator(event_type) + logging.error(f"SAModuleAgent: {await saa.get_agent()} not found on notifications awaited for event {event_type. __name__}") + await self._notify_arbitrator(event_type) async def _notify_arbitrator(self, event_type): """Checks whether to notify the arbitrator that all suggestions for event_type are received.""" @@ -104,12 +107,12 @@ async def _reset_notifications_for_agents(self, event_type, agents): if agent in agents: event.clear() - async def get_suggestions(self, event_type): + async def get_suggestions(self, event_type) -> list[tuple[SAModuleAgent, SACommand]]: """Retrieves all suggestions registered for a given event.""" async with self._suggestion_buffer_lock: async with self._expected_agents_lock: - if self._verbose: logging.info(f"Retrieving all sugestions for event: {event_type}") - suggestions = self._buffer.get(event_type, []).copy() + suggestions = list(self._buffer.get(event_type, [])) + if self._verbose: logging.info(f"Retrieving all sugestions for event: {event_type. 
__name__}") await self._clear_suggestions(event_type) return suggestions From 64e8e0596344c32f801496811f545a08076ddbdb Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 9 Apr 2025 18:15:07 +0200 Subject: [PATCH 171/233] feature close to integrate node forgiveness --- .../situationalawareness/awareness/sacommand.py | 2 +- .../situationalawareness/awareness/samodule.py | 3 ++- .../neighborpolicies/fcneighborpolicy.py | 12 ++++++++---- .../neighborpolicies/idleneighborpolicy.py | 12 ++++++++---- .../sanetwork/neighborpolicies/neighborpolicy.py | 6 +++++- .../neighborpolicies/ringneighborpolicy.py | 11 ++++++++--- .../neighborpolicies/starneighborpolicy.py | 9 +++++++-- .../awareness/sanetwork/sanetwork.py | 15 +++++++++++---- nebula/core/situationalawareness/nodemanager.py | 1 - 9 files changed, 50 insertions(+), 21 deletions(-) diff --git a/nebula/core/situationalawareness/awareness/sacommand.py b/nebula/core/situationalawareness/awareness/sacommand.py index 7e4929316..76cab146a 100644 --- a/nebula/core/situationalawareness/awareness/sacommand.py +++ b/nebula/core/situationalawareness/awareness/sacommand.py @@ -81,7 +81,7 @@ async def _update_command_state(self, sacs : SACommandState): if not self._state_future.done(): self._state_future.set_result(sacs) - def get_state_future(self): + def get_state_future(self) -> asyncio.Future: return self._state_future def is_parallelizable(self): diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 41f87e4fd..077f4449e 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -127,9 +127,10 @@ async def _process_round_end_event(self, ree : RoundEndEvent): # Execute SACommand selected for cmd in valid_commands: if cmd.is_parallelizable(): - logging.info(f"going to execute parallelizable action: {cmd.get_action()}") + if self._verbose: logging.info(f"going to execute 
parallelizable action: {cmd.get_action()}") asyncio.create_task(cmd.execute()) else: + if self._verbose: logging.info(f"going to execute action: {cmd.get_action()}") await cmd.execute() async def _process_aggregation_event(self, age : AggregationEvent): diff --git a/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/fcneighborpolicy.py b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/fcneighborpolicy.py index 73d559b97..a8cc7d2d3 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/fcneighborpolicy.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/fcneighborpolicy.py @@ -28,7 +28,7 @@ def set_config(self, config): config[0] -> list of self neighbors config[1] -> list of nodes known on federation config[2] -> self addr - config[3] -> NodeManager reference + config[3] -> stricted_topology """ logging.info("Initializing Fully-Connected Topology Neighbor Policy") self.neighbors_lock.acquire() @@ -73,12 +73,13 @@ def get_nodes_known(self, neighbors_too=False, neighbors_only=False): self.nodes_known_lock.release() return nk - def forget_nodes(self, node, forget_all=False): + def forget_nodes(self, nodes, forget_all=False): self.nodes_known_lock.acquire() if forget_all: self.nodes_known.clear() else: - self.nodes_known.discard(node) + for node in nodes: + self.nodes_known.discard(node) self.nodes_known_lock.release() def get_actions(self): @@ -113,4 +114,7 @@ def update_neighbors(self, node, remove=False): else: self.neighbors.add(node) logging.info(f"Add neighbor | addr: {node}") - self.neighbors_lock.release() \ No newline at end of file + self.neighbors_lock.release() + + def stricted_topology_status(stricted_topology: bool): + pass \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/idleneighborpolicy.py b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/idleneighborpolicy.py index 
565e63439..8e603529b 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/idleneighborpolicy.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/idleneighborpolicy.py @@ -28,7 +28,7 @@ def set_config(self, config): config[0] -> list of self neighbors config[1] -> list of nodes known on federation config[2] -> self addr - config[3] -> NodeManager reference + config[3] -> stricted_topology """ logging.info("Initializing Random Topology Neighbor Policy") self.neighbors_lock.acquire() @@ -73,12 +73,13 @@ def get_nodes_known(self, neighbors_too=False, neighbors_only=False): self.nodes_known_lock.release() return nk - def forget_nodes(self, node, forget_all=False): + def forget_nodes(self, nodes, forget_all=False): self.nodes_known_lock.acquire() if forget_all: self.nodes_known.clear() else: - self.nodes_known.discard(node) + for node in nodes: + self.nodes_known.discard(node) self.nodes_known_lock.release() def get_actions(self): @@ -113,4 +114,7 @@ def update_neighbors(self, node, remove=False): else: self.neighbors.add(node) logging.info(f"Add neighbor | addr: {node}") - self.neighbors_lock.release() \ No newline at end of file + self.neighbors_lock.release() + + def stricted_topology_status(stricted_topology: bool): + pass \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/neighborpolicy.py b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/neighborpolicy.py index df2104c27..bd1827c51 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/neighborpolicy.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/neighborpolicy.py @@ -24,7 +24,7 @@ def meet_node(self, node): pass abstractmethod - def forget_nodes(self, node, forget_all=False): + def forget_nodes(self, nodes, forget_all=False): pass @abstractmethod @@ -35,6 +35,10 @@ def get_nodes_known(self, neighbors_too=False, 
neighbors_only=False): def update_neighbors(self, node, remove=False): pass + @abstractmethod + def stricted_topology_status(stricted_topology: bool): + pass + def factory_NeighborPolicy(topology) -> NeighborPolicy: from nebula.core.situationalawareness.awareness.sanetwork.neighborpolicies.idleneighborpolicy import IDLENeighborPolicy from nebula.core.situationalawareness.awareness.sanetwork.neighborpolicies.fcneighborpolicy import FCNeighborPolicy diff --git a/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/ringneighborpolicy.py b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/ringneighborpolicy.py index 3154aae07..bab9a87a8 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/ringneighborpolicy.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/ringneighborpolicy.py @@ -25,6 +25,7 @@ def set_config(self, config): config[0] -> list of self neighbors config[1] -> list of nodes known on federation config[2] -> self.addr + config[3] -> stricted_topology """ logging.info("Initializing Ring Topology Neighbor Policy") self.neighbors_lock.acquire() @@ -52,12 +53,13 @@ def meet_node(self, node): self.nodes_known.add(node) self.nodes_known_lock.release() - def forget_nodes(self, node, forget_all=False): + def forget_nodes(self, nodes, forget_all=False): self.nodes_known_lock.acquire() if forget_all: self.nodes_known.clear() else: - self.nodes_known.discard(node) + for node in nodes: + self.nodes_known.discard(node) self.nodes_known_lock.release() def get_nodes_known(self, neighbors_too=False, neighbors_only=False): @@ -94,4 +96,7 @@ def update_neighbors(self, node, remove=False): self.neighbors.remove(node) else: self.neighbors.add(node) - self.neighbors_lock.release() \ No newline at end of file + self.neighbors_lock.release() + + def stricted_topology_status(stricted_topology: bool): + pass \ No newline at end of file diff --git 
a/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/starneighborpolicy.py b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/starneighborpolicy.py index 1eda9ba91..884e87696 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/starneighborpolicy.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/starneighborpolicy.py @@ -23,6 +23,7 @@ def set_config(self, config): config[0] -> list of self neighbors, in this case, the star point config[1] -> list of nodes known on federation config[2] -> self.addr + config[3] -> stricted_topology """ self.neighbors_lock.acquire() self.neighbors = config[0] @@ -43,12 +44,13 @@ def meet_node(self, node): self.nodes_known.add(node) self.nodes_known_lock.release() - def forget_nodes(self, node, forget_all=False): + def forget_nodes(self, nodes, forget_all=False): self.nodes_known_lock.acquire() if forget_all: self.nodes_known.clear() else: - self.nodes_known.discard(node) + for node in nodes: + self.nodes_known.discard(node) self.nodes_known_lock.release() def get_nodes_known(self, neighbors_too=False, neighbors_only=False): @@ -77,4 +79,7 @@ def get_actions(self): return [ct_actions, df_actions] def update_neighbors(self, node, remove=False): + pass + + def stricted_topology_status(stricted_topology: bool): pass \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index 064afc19e..a1a130701 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -77,7 +77,7 @@ async def init(self): await self.cm.get_addrs_current_connections(only_direct=True, myself=False), await self.cm.get_addrs_current_connections(only_direct=False, only_undirected=False, myself=False), self._addr, - self, + self._strict_topology, ]) await 
EventManager.get_instance().subscribe_node_event(NodeFoundEvent, self.process_node_found_event) @@ -206,7 +206,7 @@ async def upgrade_connection_robustness(self, possible_neighbors): self._restructure_process_lock.acquire() # addrs_to_connect = self.neighbor_policy.get_nodes_known(neighbors_too=False) # If we got some refs, try to connect to them - if len(possible_neighbors) > 0: + if possible_neighbors and len(possible_neighbors) > 0: if self._verbose: logging.info(f"Reestructuring | Addrs availables | addr list: {possible_neighbors}") await self.sam.nm.start_late_connection_process( connected=True, msg_type="discover_nodes", addrs_known=possible_neighbors @@ -223,7 +223,10 @@ async def stop_connections_with_federation(self): for n in neighbors: await self.cm.add_to_blacklist(n) for n in neighbors: - await self.cm.disconnect(n, mutual_disconnection=False, forced=True) + await self.cm.disconnect(n, mutual_disconnection=False, forced=True) + + async def forget_nodes(self, nodes_to_forget): + self.np.forget_nodes(nodes_to_forget) """ ############################### # SA NETWORK AGENT # @@ -248,6 +251,7 @@ async def notify_all_suggestions_done(self, event_type): await SuggestionBuffer.get_instance().notify_all_suggestions_done_for_agent(self, event_type) async def create_and_suggest_action(self, saca: SACommandAction, function: Callable = None, *args): + sac = None if saca == SACommandAction.MAINTAIN_CONNECTIONS: sac = factory_sa_command( "connectivity", @@ -274,6 +278,9 @@ async def create_and_suggest_action(self, saca: SACommandAction, function: Calla ) await self.suggest_action(sac) await self.notify_all_suggestions_done(RoundEndEvent) + sa_command_state = await sac.get_state_future() + #TODO en este caso se ha tratado de hacer conexiones con los nodos conocidos en el caso de haberlos, + # habrΓ­a que lanzar una tarea que los borre en el caso de no realizarse la conexiΓ³n a ellos elif saca == SACommandAction.RECONNECT: sac = factory_sa_command( "connectivity", @@ 
-286,4 +293,4 @@ async def create_and_suggest_action(self, saca: SACommandAction, function: Calla *args ) await self.suggest_action(sac) - await self.notify_all_suggestions_done(RoundEndEvent) \ No newline at end of file + await self.notify_all_suggestions_done(RoundEndEvent) diff --git a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py index 190a6dc4b..98daf849a 100644 --- a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -246,7 +246,6 @@ async def start_late_connection_process(self, connected=False, msg_type="discove best_candidates = self.candidate_selector.select_candidates() if self._verbose: logging.info(f"Candidates | {[addr for addr, _, _ in best_candidates]}") - #TODO candidates not choosen --> disconnect try: for addr, _, _ in best_candidates: await self.add_pending_connection_confirmation(addr) From f9c56c5abaa2132a505d8adf826a916edcd6851b Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 10 Apr 2025 15:35:53 +0200 Subject: [PATCH 172/233] feature system monitoring + forget nodes --- .../nebulanetworksimulator.py | 32 +--- .../awareness/samodule.py | 10 +- .../awareness/sanetwork/sanetwork.py | 20 ++- .../trainingpolicy/bpstrainingpolicy.py | 2 +- .../trainingpolicy/qdstrainingpolicy.py | 2 +- .../trainingpolicy/sostrainingpolicy.py | 2 +- .../trainingpolicy/trainingpolicy.py | 2 +- .../awareness/{ => sautils}/sacommand.py | 6 +- .../awareness/{ => sautils}/samoduleagent.py | 2 +- .../awareness/sautils/sasystemmonitor.py | 161 ++++++++++++++++++ .../awareness/suggestionbuffer.py | 4 +- 11 files changed, 195 insertions(+), 48 deletions(-) rename nebula/core/situationalawareness/awareness/{ => sautils}/sacommand.py (97%) rename nebula/core/situationalawareness/awareness/{ => sautils}/samoduleagent.py (85%) create mode 100644 nebula/core/situationalawareness/awareness/sautils/sasystemmonitor.py diff --git 
a/nebula/addons/networksimulation/nebulanetworksimulator.py b/nebula/addons/networksimulation/nebulanetworksimulator.py index 0a6b4acf9..c74d9cc87 100644 --- a/nebula/addons/networksimulation/nebulanetworksimulator.py +++ b/nebula/addons/networksimulation/nebulanetworksimulator.py @@ -64,37 +64,7 @@ async def _change_network_conditions_based_on_distances(self, gpsevent : GPSEven logging.exception(f"πŸ“ Connection {addr} not found") except Exception: logging.exception("πŸ“ Error changing connections based on distance") - - # async def _change_network_conditions_based_on_distances(self): - # grace_time = self._cm.config.participant["mobility_args"]["grace_time_mobility"] - # if self._verbose: logging.info(f"Waiting {grace_time}s to start applying network conditions based on distances between devices") - # await asyncio.sleep(grace_time) - - # while self._running: - # await asyncio.sleep(self._refresh_interval) - # if self._verbose: logging.info("Refresh | conditions based on distances...") - # current_connections = await self._cm.get_addrs_current_connections() - # try: - # for addr in current_connections: - # distance = self._cm.connections[addr].get_neighbor_distance() - # if distance is None: - # # If the distance is not found, we skip the node - # continue - # conditions = await self._calculate_network_conditions(distance) - # # Only update the network conditions if they have changed - # if (addr not in self._current_network_conditions or self._current_network_conditions[addr] != conditions): - # addr_ip = addr.split(":")[0] - # self._set_network_condition_for_addr(self._node_interface, addr_ip, conditions["bandwidth"], conditions["delay"]) - # self._set_network_condition_for_multicast(self._node_interface, addr_ip, self.IP_MULTICAST, conditions["bandwidth"], conditions["delay"]) - # async with self._network_conditions_lock: - # self._current_network_conditions[addr] = conditions - # else: - # logging.info("network conditions havent changed since last time") - 
# except KeyError: - # logging.exception(f"πŸ“ Connection {addr} not found") - # except Exception: - # logging.exception("πŸ“ Error changing connections based on distance") - + async def set_thresholds(self, thresholds : dict): async with self._network_conditions_lock: self._network_conditions = thresholds diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 077f4449e..30555d864 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -3,12 +3,13 @@ import logging from nebula.addons.functions import print_msg_box from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer -from nebula.core.situationalawareness.awareness.sacommand import SACommand +from nebula.core.situationalawareness.awareness.sautils.sacommand import SACommand from nebula.core.utils.locker import Locker from nebula.core.nebulaevents import RoundEndEvent from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import RoundEndEvent, AggregationEvent from nebula.core.network.communications import CommunicationsManager +from nebula.core.situationalawareness.awareness.sautils.sasystemmonitor import SystemMonitor from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -48,6 +49,7 @@ def __init__( self._suggestion_buffer = SuggestionBuffer(self._arbitrator_notification, verbose=True) self._communciation_manager = CommunicationsManager.get_instance() self._verbose = verbose + self._sys_monitor = SystemMonitor() @property def nm(self): @@ -78,7 +80,7 @@ async def init(self): await self.sat.init() await EventManager.get_instance().subscribe_node_event(RoundEndEvent, self._process_round_end_event) await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) - + def is_additional_participant(self): return self.nm.is_additional_participant() @@ -127,10 +129,10 @@ 
async def _process_round_end_event(self, ree : RoundEndEvent): # Execute SACommand selected for cmd in valid_commands: if cmd.is_parallelizable(): - if self._verbose: logging.info(f"going to execute parallelizable action: {cmd.get_action()}") + if self._verbose: logging.info(f"going to execute parallelizable action: {cmd.get_action()} made by: {await cmd.get_owner()}") asyncio.create_task(cmd.execute()) else: - if self._verbose: logging.info(f"going to execute action: {cmd.get_action()}") + if self._verbose: logging.info(f"going to execute action: {cmd.get_action()} made by: {await cmd.get_owner()}") await cmd.execute() async def _process_aggregation_event(self, age : AggregationEvent): diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index a1a130701..511c270b8 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -9,8 +9,8 @@ from nebula.core.nebulaevents import NodeFoundEvent, UpdateNeighborEvent, ExperimentFinishEvent, RoundEndEvent from nebula.core.network.communications import CommunicationsManager from nebula.core.situationalawareness.awareness.samodule import SAMComponent -from nebula.core.situationalawareness.awareness.samoduleagent import SAModuleAgent -from nebula.core.situationalawareness.awareness.sacommand import SACommand, SACommandAction, SACommandPRIO, SACommandState, factory_sa_command +from nebula.core.situationalawareness.awareness.sautils.samoduleagent import SAModuleAgent +from nebula.core.situationalawareness.awareness.sautils.sacommand import SACommand, SACommandAction, SACommandPRIO, SACommandState, factory_sa_command from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer from typing import TYPE_CHECKING @@ -20,6 +20,9 @@ RESTRUCTURE_COOLDOWN = 5 class SANetwork(SAMComponent): + + NEIGHBOR_VERIFICATION_TIMEOUT = 
30 + def __init__( self, sam: "SAModule", @@ -225,6 +228,14 @@ async def stop_connections_with_federation(self): for n in neighbors: await self.cm.disconnect(n, mutual_disconnection=False, forced=True) + async def verify_neighbors_stablished(self, nodes: set): + await asyncio.sleep(self.NEIGHBOR_VERIFICATION_TIMEOUT) + nodes_to_forget = nodes.copy() + neighbors = self.np.get_nodes_known(neighbors_only=True) + if neighbors: + nodes_to_forget.difference_update(neighbors) + self.forget_nodes(nodes_to_forget) + async def forget_nodes(self, nodes_to_forget): self.np.forget_nodes(nodes_to_forget) @@ -279,8 +290,9 @@ async def create_and_suggest_action(self, saca: SACommandAction, function: Calla await self.suggest_action(sac) await self.notify_all_suggestions_done(RoundEndEvent) sa_command_state = await sac.get_state_future() - #TODO en este caso se ha tratado de hacer conexiones con los nodos conocidos en el caso de haberlos, - # habrΓ­a que lanzar una tarea que los borre en el caso de no realizarse la conexiΓ³n a ellos + if sa_command_state == SACommandState.EXECUTED: + (nodes_to_forget,) = args + asyncio.create_task(self._san.verify_neighbors_stablished(nodes_to_forget)) elif saca == SACommandAction.RECONNECT: sac = factory_sa_command( "connectivity", diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py index 576cae626..a5e00d5d9 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/bpstrainingpolicy.py @@ -1,6 +1,6 @@ from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer -from nebula.core.situationalawareness.awareness.sacommand import SACommand, 
factory_sa_command, SACommandAction, SACommandPRIO +from nebula.core.situationalawareness.awareness.sautils.sacommand import SACommand, factory_sa_command, SACommandAction, SACommandPRIO from nebula.core.nebulaevents import RoundEndEvent class BPSTrainingPolicy(TrainingPolicy): diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py index 56c7009f6..f43e1e01f 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py @@ -7,7 +7,7 @@ from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import AggregationEvent, UpdateNeighborEvent, RoundEndEvent from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer -from nebula.core.situationalawareness.awareness.sacommand import SACommand, SACommandAction, SACommandPRIO, factory_sa_command +from nebula.core.situationalawareness.awareness.sautils.sacommand import SACommand, SACommandAction, SACommandPRIO, factory_sa_command from nebula.core.network.communications import CommunicationsManager import math diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py index 2e355e54c..a71fdc08c 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py @@ -5,7 +5,7 @@ from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import UpdateReceivedEvent, AggregationEvent, RoundStartEvent, UpdateNeighborEvent, RoundEndEvent from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer -from 
nebula.core.situationalawareness.awareness.sacommand import SACommand, SACommandAction, SACommandPRIO, factory_sa_command +from nebula.core.situationalawareness.awareness.sautils.sacommand import SACommand, SACommandAction, SACommandPRIO, factory_sa_command from nebula.core.network.communications import CommunicationsManager import time import asyncio diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py index 5e343a14c..78eb948c0 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from nebula.core.situationalawareness.awareness.samoduleagent import SAModuleAgent +from nebula.core.situationalawareness.awareness.sautils.samoduleagent import SAModuleAgent class TrainingPolicy(SAModuleAgent): diff --git a/nebula/core/situationalawareness/awareness/sacommand.py b/nebula/core/situationalawareness/awareness/sautils/sacommand.py similarity index 97% rename from nebula/core/situationalawareness/awareness/sacommand.py rename to nebula/core/situationalawareness/awareness/sautils/sacommand.py index 76cab146a..a8411d599 100644 --- a/nebula/core/situationalawareness/awareness/sacommand.py +++ b/nebula/core/situationalawareness/awareness/sautils/sacommand.py @@ -3,7 +3,7 @@ import asyncio from typing import TYPE_CHECKING if TYPE_CHECKING: - from nebula.core.situationalawareness.awareness.samoduleagent import SAModuleAgent + from nebula.core.situationalawareness.awareness.sautils.samoduleagent import SAModuleAgent class SACommandType(Enum): CONNECTIVITY = "Connectivity" @@ -19,9 +19,11 @@ class SACommandAction(Enum): DISCARD_UPDATE = "discard_update" class SACommandPRIO(Enum): + CRITICAL = 20 HIGH = 10 MEDIUM = 5 - LOW = 1 + LOW = 3 + MAINTENANCE = 1 class 
SACommandState(Enum): PENDING = "pending" diff --git a/nebula/core/situationalawareness/awareness/samoduleagent.py b/nebula/core/situationalawareness/awareness/sautils/samoduleagent.py similarity index 85% rename from nebula/core/situationalawareness/awareness/samoduleagent.py rename to nebula/core/situationalawareness/awareness/sautils/samoduleagent.py index 4b3bbcb75..03878e6ef 100644 --- a/nebula/core/situationalawareness/awareness/samoduleagent.py +++ b/nebula/core/situationalawareness/awareness/sautils/samoduleagent.py @@ -1,5 +1,5 @@ from abc import abstractmethod, ABC -from nebula.core.situationalawareness.awareness.sacommand import SACommand +from nebula.core.situationalawareness.awareness.sautils.sacommand import SACommand class SAModuleAgent(ABC): diff --git a/nebula/core/situationalawareness/awareness/sautils/sasystemmonitor.py b/nebula/core/situationalawareness/awareness/sautils/sasystemmonitor.py new file mode 100644 index 000000000..11843f581 --- /dev/null +++ b/nebula/core/situationalawareness/awareness/sautils/sasystemmonitor.py @@ -0,0 +1,161 @@ +import psutil +from pynvml import nvmlInit, nvmlShutdown, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlDeviceGetUtilizationRates +from nebula.core.utils.locker import Locker +import asyncio +import platform +import subprocess +import logging + +class SystemMonitor: + _instance = None + _lock = Locker("communications_manager_lock", async_lock=False) + + def __new__(cls): + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + @classmethod + def get_instance(cls): + """Obtain SystemMonitor instance""" + if cls._instance is None: + raise ValueError("SystemMonitor has not been initialized yet.") + return cls._instance + + def __init__(self): + """Initialize the system monitor and check for GPU availability.""" + if not hasattr(self, '_initialized'): # To avoid reinitialization on subsequent calls + # Try to initialize 
NVIDIA library if available + try: + nvmlInit() + self.gpu_available = True # Flag to check if GPU is available + except Exception: + self.gpu_available = False # If not, set GPU availability to False + self._initialized = True + + async def get_cpu_usage(self): + """Returns the CPU usage percentage.""" + return psutil.cpu_percent(interval=1) + + async def get_cpu_per_core_usage(self): + """Returns the CPU usage percentage per core.""" + return psutil.cpu_percent(interval=1, percpu=True) + + async def get_memory_usage(self): + """Returns the percentage of used RAM memory.""" + memory_info = psutil.virtual_memory() + return memory_info.percent + + async def get_swap_memory_usage(self): + """Returns the percentage of used swap memory.""" + swap_info = psutil.swap_memory() + return swap_info.percent + + async def get_network_usage(self, interval=5): + """Measures network usage over a time interval and returns bandwidth percentage usage.""" + os_name = platform.system() + + # Get max bandwidth (only implemented for Linux) + if os_name == "Linux": + max_bandwidth = self._get_max_bandwidth_linux() + else: + max_bandwidth = None + + # Take first measurement + net_io_start = psutil.net_io_counters() + bytes_sent_start = net_io_start.bytes_sent + bytes_recv_start = net_io_start.bytes_recv + + # Wait for the interval + await asyncio.sleep(interval) + + # Take second measurement + net_io_end = psutil.net_io_counters() + bytes_sent_end = net_io_end.bytes_sent + bytes_recv_end = net_io_end.bytes_recv + + # Calculate bytes transferred during interval + bytes_sent = bytes_sent_end - bytes_sent_start + bytes_recv = bytes_recv_end - bytes_recv_start + total_bytes = bytes_sent + bytes_recv + + # Calculate bandwidth usage percentage + bandwidth_used_percent = self._calculate_bandwidth_usage( + total_bytes, + max_bandwidth, + interval + ) + + return { + 'interval': interval, + 'bytes_sent': bytes_sent, + 'bytes_recv': bytes_recv, + 'bandwidth_used_percent': bandwidth_used_percent, + 
'bandwidth_max': max_bandwidth + } + + #TODO catched speed to avoid reading file + def _get_max_bandwidth_linux(self, interface="eth0"): + """Reads max bandwidth from /sys/class/net/{iface}/speed (Linux only).""" + try: + with open(f"/sys/class/net/{interface}/speed", "r") as f: + speed = int(f.read().strip()) # In Mbps + return speed + except Exception as e: + print(f"Could not read max bandwidth: {e}") + return None + + def _calculate_bandwidth_usage(self, bytes_transferred, max_bandwidth_mbps, interval): + """Calculates bandwidth usage percentage over the given interval.""" + if max_bandwidth_mbps is None or interval <= 0: + return None + + try: + # Convert bytes to megabits + megabits_transferred = (bytes_transferred * 8) / (1024 * 1024) + # Calculate usage in Mbps + current_usage_mbps = megabits_transferred / interval + # Percentage of max bandwidth + usage_percentage = (current_usage_mbps / max_bandwidth_mbps) * 100 + return usage_percentage + except Exception as e: + print(f"Error calculating bandwidth usage: {e}") + return None + + async def get_gpu_usage(self): + """Returns GPU usage stats if available, otherwise returns None.""" + if not self.gpu_available: + return None # No GPU available, return None + + # If GPU is available, get the usage using pynvml + device_count = nvmlDeviceGetCount() + gpu_usage = [] + for i in range(device_count): + handle = nvmlDeviceGetHandleByIndex(i) + memory_info = nvmlDeviceGetMemoryInfo(handle) + utilization = nvmlDeviceGetUtilizationRates(handle) + gpu_usage.append({ + 'gpu': i, + 'memory_used': memory_info.used / 1024**2, # MB + 'memory_total': memory_info.total / 1024**2, # MB + 'gpu_usage': utilization.gpu + }) + return gpu_usage + + async def get_system_resources(self): + """Returns a dictionary with all system resource usage statistics.""" + resources = { + 'cpu_usage': await self.get_cpu_usage(), + 'cpu_per_core_usage': await self.get_cpu_per_core_usage(), + 'memory_usage': await self.get_memory_usage(), + 
'swap_memory_usage': await self.get_swap_memory_usage(), + 'network_usage': await self.get_network_usage(), + 'gpu_usage': await self.get_gpu_usage(), # Includes GPU usage or None if no GPU + } + return resources + + async def close(self): + """Closes the initialization of the NVIDIA library (if used).""" + if self.gpu_available: + nvmlShutdown() \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/suggestionbuffer.py b/nebula/core/situationalawareness/awareness/suggestionbuffer.py index 3639f875a..989a2ba80 100644 --- a/nebula/core/situationalawareness/awareness/suggestionbuffer.py +++ b/nebula/core/situationalawareness/awareness/suggestionbuffer.py @@ -1,8 +1,8 @@ from nebula.core.utils.locker import Locker from nebula.utils import logging import asyncio -from nebula.core.situationalawareness.awareness.samoduleagent import SAModuleAgent -from nebula.core.situationalawareness.awareness.sacommand import SACommand +from nebula.core.situationalawareness.awareness.sautils.samoduleagent import SAModuleAgent +from nebula.core.situationalawareness.awareness.sautils.sacommand import SACommand from nebula.core.nebulaevents import NodeEvent, RoundEndEvent, AggregationEvent from collections import defaultdict from typing import Type From 5a47c4a07084c1938d76a4f437d6206e12652be8 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 11 Apr 2025 13:17:20 +0200 Subject: [PATCH 173/233] feature static arbitatrion policy --- .../arbitatrionpolicies/arbitatrionpolicy.py | 24 +++++++++++ .../saarbitatrionpolicy.py | 16 +++++++ .../staticarbitatrionpolicy.py | 42 +++++++++++++++++++ .../awareness/samodule.py | 14 ++++--- .../awareness/sanetwork/sanetwork.py | 6 +-- .../trainingpolicy/qdstrainingpolicy.py | 3 +- .../awareness/sautils/sasystemmonitor.py | 2 - 7 files changed, 96 insertions(+), 11 deletions(-) create mode 100644 nebula/core/situationalawareness/awareness/arbitatrionpolicies/arbitatrionpolicy.py create mode 100644 
nebula/core/situationalawareness/awareness/arbitatrionpolicies/saarbitatrionpolicy.py create mode 100644 nebula/core/situationalawareness/awareness/arbitatrionpolicies/staticarbitatrionpolicy.py diff --git a/nebula/core/situationalawareness/awareness/arbitatrionpolicies/arbitatrionpolicy.py b/nebula/core/situationalawareness/awareness/arbitatrionpolicies/arbitatrionpolicy.py new file mode 100644 index 000000000..cd1e6db7d --- /dev/null +++ b/nebula/core/situationalawareness/awareness/arbitatrionpolicies/arbitatrionpolicy.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod +from nebula.core.situationalawareness.awareness.sautils.sacommand import SACommand + +class ArbitatrionPolicy(ABC): + + @abstractmethod + async def init(self, config): + raise NotImplementedError + + @abstractmethod + async def tie_break(self, sac1: SACommand, sac2: SACommand) -> bool: + raise NotImplementedError + +def factory_arbitatrion_policy(arbitatrion_policy, verbose) -> ArbitatrionPolicy: + from nebula.core.situationalawareness.awareness.arbitatrionpolicies.staticarbitatrionpolicy import SAP + from nebula.core.situationalawareness.awareness.arbitatrionpolicies.saarbitatrionpolicy import SAAP + + options = { + "sap": SAP, # "Static Arbitatrion Policy" (SAP) -- default value + "saap": SAAP, # "Situational Awareness Arbitatrion Policy" (SAAP) + } + + cs = options.get(arbitatrion_policy, SAP) + return cs(verbose) \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/arbitatrionpolicies/saarbitatrionpolicy.py b/nebula/core/situationalawareness/awareness/arbitatrionpolicies/saarbitatrionpolicy.py new file mode 100644 index 000000000..906abf8a7 --- /dev/null +++ b/nebula/core/situationalawareness/awareness/arbitatrionpolicies/saarbitatrionpolicy.py @@ -0,0 +1,16 @@ +import asyncio +from nebula.core.situationalawareness.awareness.arbitatrionpolicies.arbitatrionpolicy import ArbitatrionPolicy +from nebula.core.situationalawareness.awareness.sautils.sacommand 
import SACommand + +class SAAP(ArbitatrionPolicy): + def __init__(self, verbose): + pass + + async def init(self, config): + pass + + async def tie_break(self, sac1: SACommand, sac2: SACommand) -> SACommand: + """ + Tie break conflcited SA Commands + """ + pass \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/arbitatrionpolicies/staticarbitatrionpolicy.py b/nebula/core/situationalawareness/awareness/arbitatrionpolicies/staticarbitatrionpolicy.py new file mode 100644 index 000000000..decc5f9ca --- /dev/null +++ b/nebula/core/situationalawareness/awareness/arbitatrionpolicies/staticarbitatrionpolicy.py @@ -0,0 +1,42 @@ +import logging +from nebula.core.situationalawareness.awareness.sautils.sacommand import SACommand +from nebula.core.situationalawareness.awareness.arbitatrionpolicies.arbitatrionpolicy import ArbitatrionPolicy +import asyncio + + +class SAP(ArbitatrionPolicy): # Static Arbitatrion Policy + def __init__(self, verbose): + self._verbose = verbose + # Define static weights for SA Agents from SA Components + self.agent_weights = { + "SATraining": 1, + "SANetwork": 2, + "SAReputation": 3 + } + + async def init(self, config): + pass + + async def _get_agent_category(self, sa_command: SACommand) -> str: + """ + Extract agent category name. 
+ Example: "SATraining_Agent1" β†’ "SATraining" + """ + full_name = await sa_command.get_owner() + return full_name.split("_")[0] if "_" in full_name else full_name + + async def tie_break(self, sac1: SACommand, sac2: SACommand) -> bool: + """ + Tie break conflcited SA Commands + """ + if self._verbose: logging.info(f"Tie break between ({await sac1.get_owner()}, {sac1.get_action().value}) & ({await sac2.get_owner()}, {sac2.get_action().value})") + async def get_weight(cmd): + category = await self._get_agent_category(cmd) + return self.agent_weights.get(category, 0) + + if await get_weight(sac1) > await get_weight(sac2): + if self._verbose: logging.info(f"Tie break resolved, SA Command choosen ({await sac1.get_owner()}, {sac1.get_action().value})") + return True + else: + if self._verbose: logging.info(f"Tie break resolved, SA Command choosen ({await sac2.get_owner()}, {sac2.get_action().value})") + return False diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index 30555d864..f430f6e2d 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -10,6 +10,7 @@ from nebula.core.nebulaevents import RoundEndEvent, AggregationEvent from nebula.core.network.communications import CommunicationsManager from nebula.core.situationalawareness.awareness.sautils.sasystemmonitor import SystemMonitor +from nebula.core.situationalawareness.awareness.arbitatrionpolicies.arbitatrionpolicy import factory_arbitatrion_policy from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -48,8 +49,10 @@ def __init__( self._arbitrator_notification = asyncio.Event() self._suggestion_buffer = SuggestionBuffer(self._arbitrator_notification, verbose=True) self._communciation_manager = CommunicationsManager.get_instance() - self._verbose = verbose self._sys_monitor = SystemMonitor() + self._arbitatrion_policy = factory_arbitatrion_policy("sad", True) + 
self._verbose = verbose + @property def nm(self): @@ -71,6 +74,10 @@ def cm(self): def sb(self): return self._suggestion_buffer + @property + def ab(self): + return self._arbitatrion_policy + async def init(self): from nebula.core.situationalawareness.awareness.sanetwork.sanetwork import SANetwork from nebula.core.situationalawareness.awareness.satraining.satraining import SATraining @@ -117,9 +124,6 @@ def get_actions(self): ############################### """ - async def _tie_breaker(c1: SACommand, c2: SACommand): - return True - async def _process_round_end_event(self, ree : RoundEndEvent): logging.info("πŸ”„ Arbitration | Round End Event...") asyncio.create_task(self.san.sa_component_actions()) @@ -169,7 +173,7 @@ async def _arbitatrion_suggestions(self, event_type): if cmd.got_higher_priority_than(other.get_prio()): to_remove.append(other) elif cmd.get_prio() == other.get_prio(): - if await self._tie_breaker(cmd, other): + if await self.ab.tie_break(cmd, other): to_remove.append(other) else: has_conflict = True diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index 511c270b8..5c1884041 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -193,7 +193,7 @@ async def _analize_topology_robustness(self): async def reconnect_to_federation(self): self._restructure_process_lock.acquire() await self.cm.clear_restrictions() - await asyncio.sleep(120) + await asyncio.sleep(120) #TODO remove # If we got some refs, try to reconnect to them if len(self.np.get_nodes_known()) > 0: if self._verbose: logging.info("Reconnecting | Addrs availables") @@ -269,7 +269,7 @@ async def create_and_suggest_action(self, saca: SACommandAction, function: Calla SACommandAction.MAINTAIN_CONNECTIONS, self, "", - SACommandPRIO.HIGH, + SACommandPRIO.MEDIUM, False, function, None @@ -282,7 +282,7 @@ 
async def create_and_suggest_action(self, saca: SACommandAction, function: Calla SACommandAction.SEARCH_CONNECTIONS, self, "", - SACommandPRIO.HIGH, + SACommandPRIO.MEDIUM, True, function, *args diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py index f43e1e01f..d697577a9 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/qdstrainingpolicy.py @@ -17,7 +17,7 @@ class QDSTrainingPolicy(TrainingPolicy): SIMILARITY_THRESHOLD = 0.73 INACTIVE_THRESHOLD = 3 GRACE_ROUNDS = 0 - CHECK_COOLDOWN = 1 + CHECK_COOLDOWN = 10000 def __init__(self, config : dict): self._addr = config["addr"] @@ -118,6 +118,7 @@ async def evaluate(self): sorted_redundant_nodes = sorted(redundant_nodes, key=lambda x: x[1]) n_discarded = math.ceil((len(redundant_nodes)/2)) discard_nodes = sorted_redundant_nodes[:n_discarded] + discard_nodes = [node for (node,_) in discard_nodes] if self._verbose: logging.info(f"Discarded redundant nodes: {discard_nodes}") result = result.union(discard_nodes) else: diff --git a/nebula/core/situationalawareness/awareness/sautils/sasystemmonitor.py b/nebula/core/situationalawareness/awareness/sautils/sasystemmonitor.py index 11843f581..c1a71a286 100644 --- a/nebula/core/situationalawareness/awareness/sautils/sasystemmonitor.py +++ b/nebula/core/situationalawareness/awareness/sautils/sasystemmonitor.py @@ -3,8 +3,6 @@ from nebula.core.utils.locker import Locker import asyncio import platform -import subprocess -import logging class SystemMonitor: _instance = None From d4075aaf4ef8ff87cf5cc44d929815ab520960bd Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 16 Apr 2025 12:56:06 +0200 Subject: [PATCH 174/233] feature behavior reputation --- nebula/core/nebulaevents.py | 14 ++ 
nebula/core/network/actions.py | 6 + nebula/core/network/blacklist.py | 6 + nebula/core/network/messages.py | 8 +- nebula/core/pb/nebula.proto | 12 ++ .../behaviorreputation.py} | 166 +++++++++--------- .../awareness/sareputation/sareputation.py | 38 ++++ .../trainingpolicy/trainingpolicy.py | 2 - .../frontend/config/participant.json.example | 11 +- 9 files changed, 177 insertions(+), 86 deletions(-) rename nebula/core/situationalawareness/awareness/{satraining/trainingpolicy/sostrainingpolicy.py => sareputation/behaviorreputation.py} (66%) create mode 100644 nebula/core/situationalawareness/awareness/sareputation/sareputation.py diff --git a/nebula/core/nebulaevents.py b/nebula/core/nebulaevents.py index 697e11306..70c56cc1b 100644 --- a/nebula/core/nebulaevents.py +++ b/nebula/core/nebulaevents.py @@ -159,6 +159,20 @@ async def is_concurrent(self) -> bool: def is_joining_federation(self): return self._joining_federation + +class NodeBlacklistedEvent(NodeEvent): + def __init__(self, node_addr, blacklisted: bool = False): + self._node_addr = node_addr + self._blacklisted = blacklisted + + def __str__(self): + return f"Node addr: {self._node_addr} | Blacklisted: {self._blacklisted} | Recently disconnected: {not self._blacklisted}" + + async def get_event_data(self) -> tuple[str, bool]: + return (self._node_addr, self._blacklisted) + + async def is_concurrent(self): + return True class NodeFoundEvent(NodeEvent): def __init__(self, node_addr): diff --git a/nebula/core/network/actions.py b/nebula/core/network/actions.py index b8b9bb9b9..d065d6803 100644 --- a/nebula/core/network/actions.py +++ b/nebula/core/network/actions.py @@ -44,6 +44,11 @@ class OfferAction(Enum): class LinkAction(Enum): CONNECT_TO = nebula_pb2.LinkMessage.Action.CONNECT_TO DISCONNECT_FROM = nebula_pb2.LinkMessage.Action.DISCONNECT_FROM + +class ReputationAction(Enum): + SHARE_REPUTATION = nebula_pb2.LinkMessage.Action.SHARE_REPUTATION + START_TRIAL = nebula_pb2.LinkMessage.Action.START_TRIAL + 
SUBMIT_VERDICT = nebula_pb2.LinkMessage.Action.SUBMIT_VERDICT ACTION_CLASSES = { @@ -54,6 +59,7 @@ class LinkAction(Enum): "discover": DiscoverAction, "offer": OfferAction, "link": LinkAction, + "reputation": ReputationAction, } diff --git a/nebula/core/network/blacklist.py b/nebula/core/network/blacklist.py index 6f9364de2..97a03f317 100644 --- a/nebula/core/network/blacklist.py +++ b/nebula/core/network/blacklist.py @@ -2,6 +2,8 @@ import logging import time from nebula.core.utils.locker import Locker +from nebula.core.eventmanager import EventManager +from nebula.core.nebulaevents import NodeBlacklistedEvent BLACKLIST_EXPIRATION_TIME = 240 RECENTLY_DISCONNECTED_EXPIRE_TIME = 60 @@ -44,6 +46,8 @@ async def add_to_blacklist(self, addr): self._running = True asyncio.create_task(self._start_blacklist_cleaner()) await self._blacklisted_nodes_lock.release_async() + nbe = NodeBlacklistedEvent(addr, blacklisted=True) + asyncio.create_task(EventManager.get_instance().publish_node_event(nbe)) async def get_blacklist(self) -> set: bl = None @@ -117,6 +121,8 @@ async def add_recently_disconnected(self, addr): self._recently_disconnected.add(addr) self._recently_disconnected_lock.release_async() asyncio.create_task(self._remove_recently_disc(addr)) + nbe = NodeBlacklistedEvent(addr) + asyncio.create_task(EventManager.get_instance().publish_node_event(nbe)) async def clear_recently_disconected(self): self._recently_disconnected_lock.acquire_async() diff --git a/nebula/core/network/messages.py b/nebula/core/network/messages.py index 758890475..4af210572 100644 --- a/nebula/core/network/messages.py +++ b/nebula/core/network/messages.py @@ -63,7 +63,13 @@ def _define_message_templates(self): "weight": 1, }, }, - "reputation": {"parameters": ["reputation"], "defaults": {}}, + "reputation": { + "parameters": ["action", "reputation", "defendant", "verdict"], + "defaults": { + "defendant": "", + "verdict": "" + } + }, "discover": {"parameters": ["action"], "defaults": {}}, "link": 
{"parameters": ["action", "addrs"], "defaults": {}}, # Add additional message types here diff --git a/nebula/core/pb/nebula.proto b/nebula/core/pb/nebula.proto index b1c418fe4..4607f4334 100755 --- a/nebula/core/pb/nebula.proto +++ b/nebula/core/pb/nebula.proto @@ -111,6 +111,18 @@ message LinkMessage { string addrs = 2; } +message ReputationMessage { + enum Action { + SHARE_REPUTATION = 0; // Message to tell reputation state to a node + START_TRIAL = 1; // Message to start a trial + SUBMIT_VERDICT = 2; // Message to send back the verdict for a trial + } + Action action = 1; + string reputation = 2; + string defendant = 3; + string verdict = 4; +} + // Response transmits the outcome of a requested operation, including any errors. message ResponseMessage { string response = 1; // Outcome of the requested operation. diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py b/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py similarity index 66% rename from nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py rename to nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py index a71fdc08c..59f987821 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/sostrainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py @@ -1,12 +1,9 @@ -from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import TrainingPolicy from nebula.core.utils.locker import Locker from collections import deque import logging +from collections import defaultdict from nebula.core.eventmanager import EventManager -from nebula.core.nebulaevents import UpdateReceivedEvent, AggregationEvent, RoundStartEvent, UpdateNeighborEvent, RoundEndEvent -from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer -from 
nebula.core.situationalawareness.awareness.sautils.sacommand import SACommand, SACommandAction, SACommandPRIO, factory_sa_command -from nebula.core.network.communications import CommunicationsManager +from nebula.core.nebulaevents import UpdateReceivedEvent, AggregationEvent, RoundStartEvent, UpdateNeighborEvent, NodeBlacklistedEvent import time import asyncio @@ -39,13 +36,10 @@ def reset(self): self.tr = None self.tsle = None -# "Speed-Oriented Selection" (SOS) -class SOSTrainingPolicy(TrainingPolicy): +class BehaviorReputation(): MAX_HISTORIC_SIZE = 10 - SCORE_THRESHOLD = 0.7 + SCORE_THRESHOLD = 0.5 # Threshold to detect posible malicious nodes INACTIVE_THRESHOLD = 3 - GRACE_ROUNDS = 1 - CHECK_COOLDOWN = 1 W_UPDATE_FREQ = 0.25 # Update frequency weight W_UPDATE_LATENCY = 0.15 # update latency weight W_AGG_WAITING = 0.6 # time waited since start waiting for aggregation until update is received weight @@ -54,40 +48,65 @@ class SOSTrainingPolicy(TrainingPolicy): def __init__(self, config): self._addr = config["addr"] self._verbose = config["verbose"] - self._nodes : dict[str, tuple[deque, int, deque[TimeStamp], deque[TimeStamp]]] = {} - + self._nodes : dict[str, tuple[deque, int, deque[TimeStamp], deque[TimeStamp]]] = {} self._nodes_lock = Locker(name="nodes_lock", async_lock=True) - self._grace_rounds = self.GRACE_ROUNDS - self._last_check = 0 + self._historical_blacklist_activity: dict[str, tuple[int,int]] = defaultdict(tuple) + self._historical_blacklist_activity_lock = Locker(name="historical_blacklist_activity_lock", async_lock=True) + self._historical_behavior_scores: dict[str, deque[int]] = defaultdict(deque) self._internal_rounds_done = -1 self._last_aggregation_time = None + self._suspicious_nodes = set() + + @property + def hba(self): + """historical_blacklist_activity""" + return self._historical_blacklist_activity + + @property + def hbs(self): + """historical_behavior_scores""" + return self._historical_behavior_scores + + @property + def 
hba_lock(self): + return self._historical_blacklist_activity_lock def __str__(self): - return "SOS" + return "BehaviorReputation" async def init(self, config): async with self._nodes_lock: nodes = config["nodes"] self._nodes = { node_id: ( - deque(maxlen=self.MAX_HISTORIC_SIZE), # updates per round, - 0, # inactivity - deque(maxlen=self.MAX_HISTORIC_SIZE), # time gaps between updates - deque(maxlen=self.MAX_HISTORIC_SIZE) # times since last aggregation + deque(maxlen=self.MAX_HISTORIC_SIZE), # Updates per round, + 0, # Inactivity + deque(maxlen=self.MAX_HISTORIC_SIZE), # Time gaps between updates + deque(maxlen=self.MAX_HISTORIC_SIZE) # Times since last aggregation + ) for node_id in nodes + } + self.hbs = { + node_id: ( + deque(maxlen=self.MAX_HISTORIC_SIZE), # Historical scores ) for node_id in nodes } + await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self._process_update_received_event) await EventManager.get_instance().subscribe_node_event(RoundStartEvent, self._process_round_start) await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) - await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.update_neighbors) - await self.register_sa_agent() + await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self._update_neighbors) + await EventManager.get_instance().subscribe_node_event(NodeBlacklistedEvent, self._process_node_blacklisted_event) async def _get_nodes(self): async with self._nodes_lock: nodes = self._nodes.copy() return nodes - async def _process_round_start(self, rse : RoundStartEvent): + def _update_behavior_scores(self, scores: dict): + for node, score in scores.items(): + self.hbs[node].append(score) + + async def _process_round_start(self, rse: RoundStartEvent): if self._verbose: logging.info("Processing round start event") if not self._last_aggregation_time: if self._verbose: logging.info("First round start timing assigment") 
@@ -95,7 +114,7 @@ async def _process_round_start(self, rse : RoundStartEvent): self._last_aggregation_time = start_time self._internal_rounds_done += 1 - async def _process_aggregation_event(self, are : AggregationEvent): + async def _process_aggregation_event(self, are: AggregationEvent): self._last_aggregation_time = time.time() if self._verbose: logging.info("Processing aggregation event") (_, expected_nodes, missing_nodes) = await are.get_event_data() @@ -106,7 +125,7 @@ async def _process_aggregation_event(self, are : AggregationEvent): history, missed_count, gap_btween_updts, time_since_agg = self._nodes[node] self._nodes[node] = (history, 0 if node not in missing_nodes else missed_count + 1, gap_btween_updts, time_since_agg) - async def _process_update_received_event(self, ure : UpdateReceivedEvent): + async def _process_update_received_event(self, ure: UpdateReceivedEvent): time_received = time.time() if self._verbose: logging.info("Processing Update Received event") (_, _, source, _, _) = await ure.get_event_data() @@ -135,7 +154,7 @@ async def _process_update_received_event(self, ure : UpdateReceivedEvent): self._nodes[source] = (history, missed_count, time_between_updts_historic, last_update_times) - async def update_neighbors(self, une : UpdateNeighborEvent): + async def _update_neighbors(self, une: UpdateNeighborEvent): node, remove = await une.get_event_data() async with self._nodes_lock: if remove: @@ -144,24 +163,32 @@ async def update_neighbors(self, une : UpdateNeighborEvent): if not node in self._nodes: self._nodes.update({node : (deque(maxlen=self.MAX_HISTORIC_SIZE), 0, float('inf'), float('inf'))}) + async def _process_node_blacklisted_event(self, nbe: NodeBlacklistedEvent): + node_addr, blacklisted = await nbe.get_event_data() + times_bl = 0 + times_rd = 0 + async with self.hba_lock: + if self.hba[node_addr]: + bl, rd = self.hba[node_addr] + times_bl = bl + times_rd = rd + if blacklisted: # It means forced blacklisted not recently 
disconnected + times_bl += 1 + else: + times_rd += 1 + self.hba[node_addr] = (times_bl, times_rd) + async def _evaluate(self): - if self._verbose: logging.info("Evaluating using speed-oriented strategy") - if self._grace_rounds: # Grace rounds - self._grace_rounds -= 1 - if self._verbose: logging.info("Grace time hasnt finished...") - return None - - if self._last_check == 0: - nodes = await self._get_nodes() - for node in nodes.keys(): - #logging.info(f"Node: {node}, {nodes[node][0]}") - updates_received = {x[1] for x in nodes[node][0] if x[0] == self._internal_rounds_done} - if self._verbose: logging.info(f"Node: {node} | Updates received this round: {updates_received}") - if self._verbose: logging.info(f"Time waited since last aggregation event {nodes[node][3][-1].tsle:.3f}") - else: - if self._verbose: logging.info(f"Evaluation is on cooldown... | {self.CHECK_COOLDOWN - self._last_check} rounds remaining") + if self._verbose: logging.info("Evaluating Behavior Reputation, generating score...") + + nodes = await self._get_nodes() + for node in nodes.keys(): + #logging.info(f"Node: {node}, {nodes[node][0]}") + updates_received = {x[1] for x in nodes[node][0] if x[0] == self._internal_rounds_done} + if self._verbose: logging.info(f"Node: {node} | Updates received this round: {updates_received}") + if self._verbose: logging.info(f"Time waited since last aggregation event {nodes[node][3][-1].tsle:.3f}") - # Extraer valores mΓ‘ximos y mΓ­nimos para normalizaciΓ³n + # Extract max and min values for normalization max_updates = max( ( max((x[1] for x in nodes[n][0] if x[0] == self._internal_rounds_done), default=0) @@ -170,6 +197,7 @@ async def _evaluate(self): default=1 ) + # Min latency observed min_latency = min( ( sum(t.tsle for t in nodes[n][2] if t.tsle is not None and t.tsle != float('inf')) / len(nodes[n][2]) @@ -180,6 +208,7 @@ async def _evaluate(self): default=1 ) + # Min wait time observed min_wait_time = min( ( sum(t.tsle for t in nodes[n][3]) / 
len(nodes[n][3]) if nodes[n][3] else float('inf') @@ -192,23 +221,23 @@ async def _evaluate(self): scores = {} for node, (history, missed_count, time_between_updts_historic, last_wait_times) in nodes.items(): - # 1. Frecuencia de updates normalizada + # 1. Normalized frequency updates_received = max((x[1] for x in history if x[0] == self._internal_rounds_done), default=0) F_updt_freq = updates_received / max_updates if max_updates > 0 else 0 - # 2. Latencia media entre updates normalizada + # 2. Normalized mean latency between updates valid_latencies = [t.tsle for t in time_between_updts_historic if t.tsle is not None and t.tsle != float('inf')] avg_latency = sum(valid_latencies) / len(valid_latencies) if valid_latencies else float('inf') F_updt_latency = min_latency / avg_latency if avg_latency > 0 and avg_latency != float('inf') else 0 - # 3. Tiempo medio desde ΓΊltima agregaciΓ³n normalizado + # 3. Normalized time since last aggregation avg_wait_time = sum(t.tsle for t in last_wait_times) / len(last_wait_times) if last_wait_times else float('inf') F_agg_waiting = min_wait_time / avg_wait_time if avg_wait_time > 0 else 0 - # 4. PenalizaciΓ³n por inactividad - P_n = missed_count*self.W_INACTIVITY_PEN # PenalizaciΓ³n inversamente proporcional + # 4. Inactivity penalty + P_n = missed_count*self.W_INACTIVITY_PEN - # Calcular puntuaciΓ³n final + # 5. 
Final score score = ( (self.W_UPDATE_FREQ * F_updt_freq) + (self.W_UPDATE_LATENCY * F_updt_latency) + @@ -217,7 +246,9 @@ async def _evaluate(self): ) scores[node] = score - # Ordenar nodos por puntuaciΓ³n descendente + self._update_behavior_scores(scores) + + # Order nodes sorted_nodes = sorted(scores.items(), key=lambda x: x[1], reverse=True) nodes_below_th = [x for x in sorted_nodes if x[1] < self.SCORE_THRESHOLD] @@ -225,39 +256,10 @@ async def _evaluate(self): for node, score in sorted_nodes: if self._verbose: logging.info(f"Node: {node} | Score: {score:.3f}") - if self._verbose: logging.info(f"Nodes below threshold: {nodes_below_th}") - - self._last_check = (self._last_check + 1) % self.CHECK_COOLDOWN - - return nodes_below_th - - async def get_evaluation_results(self): - nodes_to_discard = await self._evaluate() - for node_discarded in nodes_to_discard: - args = (node_discarded, False, True) - sac = factory_sa_command( - "connectivity", - SACommandAction.DISCONNECT, - self, - node_discarded, - SACommandPRIO.MEDIUM, - False, - CommunicationsManager.get_instance().disconnect, - *args - ) - await self.suggest_action(sac) - await self.notify_all_suggestions_done(RoundEndEvent) - - async def get_agent(self) -> str: - return "SATraining_SOSTP" - - async def register_sa_agent(self): - await SuggestionBuffer.get_instance().register_event_agents(RoundEndEvent, self) - - async def suggest_action(self, sac : SACommand): - await SuggestionBuffer.get_instance().register_suggestion(RoundEndEvent, self, sac) + if self._verbose: logging.info(f"Nodes below threshold: {nodes_below_th}") + + # Update suspicious nodes + self._suspicious_nodes.union({n for n in nodes_below_th}) + - async def notify_all_suggestions_done(self, event_type): - await SuggestionBuffer.get_instance().notify_all_suggestions_done_for_agent(self, event_type) - \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/sareputation/sareputation.py 
b/nebula/core/situationalawareness/awareness/sareputation/sareputation.py new file mode 100644 index 000000000..1a7194c2e --- /dev/null +++ b/nebula/core/situationalawareness/awareness/sareputation/sareputation.py @@ -0,0 +1,38 @@ +from nebula.core.situationalawareness.awareness.samodule import SAMComponent +from enum import Enum + +class ReputationCategory(Enum): # Reputational thresholds + HIGH_TRUSTED = 0.9 + TRUSTED = 0.8 + RESPECTED = 0.6 + SUSPICIOUS = 0.5 + MALICIOUS = 0.0 + +class ReputationScore(): + + def __init__(self, target, reputation_score): + self._target = target + self._reputation = reputation_score + self._reputation_category = self._assign_reputation_category(reputation_score) + + def _assign_reputation_category(self, score): + for category in sorted(ReputationCategory, key=lambda c: c.value, reverse=True): + if score >= category.value: + return category + + def get_category(self): + return self._reputation_category + + def get_reputational_score(self): + return self._reputation + + def update_reputation(self, reputation_score): + self._reputation = reputation_score + self._reputation_category = self._assign_reputation_category(reputation_score) + +class SAReputation(SAMComponent): + + async def init(self): + pass + async def sa_component_actions(self): + pass \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py index 78eb948c0..ffc8592f4 100644 --- a/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py +++ b/nebula/core/situationalawareness/awareness/satraining/trainingpolicy/trainingpolicy.py @@ -15,13 +15,11 @@ async def get_evaluation_results(self): def factory_training_policy(training_policy, config) -> TrainingPolicy: from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.bpstrainingpolicy import BPSTrainingPolicy from 
nebula.core.situationalawareness.awareness.satraining.trainingpolicy.qdstrainingpolicy import QDSTrainingPolicy - from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.sostrainingpolicy import SOSTrainingPolicy from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.htstrainingpolicy import HTSTrainingPolicy options = { "bps": BPSTrainingPolicy, # "Broad-Propagation Strategy" (BPS) -- default value "qds": QDSTrainingPolicy, # "Quality-Driven Selection" (QDS) - "sos": SOSTrainingPolicy, # "Speed-Oriented Selection" (SOS) "hts": HTSTrainingPolicy, # "Hybrid Training Strategy" (HTS) } diff --git a/nebula/frontend/config/participant.json.example b/nebula/frontend/config/participant.json.example index db3ec0d34..ec5796b07 100755 --- a/nebula/frontend/config/participant.json.example +++ b/nebula/frontend/config/participant.json.example @@ -63,7 +63,6 @@ "mobility": false, "mobility_type": "topology", "topology_type": "", - "push_strategy": "slow", "radius_federation": 1000, "scheme_mobility": "random", "round_frequency": 1, @@ -147,6 +146,16 @@ "propagation_early_stop": 3, "history_size": 20 }, + "situational_awareness": { + "sa_components": { + "sanetwork": true + }, + "sanetwork": { + "strict_topology": true + } + "arbitatrion_policy": "sap", + "model_handler": "std" + }, "misc_args": { "grace_time_connection": 10, "grace_time_start_federation": 10 From c554a91cdd8e58ae9ebee92270b4e856c23f347b Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 23 Apr 2025 10:43:03 +0200 Subject: [PATCH 175/233] feature reputation messages --- nebula/core/eventmanager.py | 31 +++++-- nebula/core/pb/nebula_pb2.py | 83 ++++++++++--------- .../sareputation/behaviorreputation.py | 41 +++++++-- 3 files changed, 101 insertions(+), 54 deletions(-) diff --git a/nebula/core/eventmanager.py b/nebula/core/eventmanager.py index 6f4f71056..e4ff0d776 100755 --- a/nebula/core/eventmanager.py +++ b/nebula/core/eventmanager.py @@ -1,13 +1,10 @@ import asyncio 
import inspect import logging -from collections import defaultdict -from functools import wraps -from abc import ABC, abstractmethod from nebula.core.network.messages import MessageEvent from nebula.core.utils.locker import Locker from nebula.core.nebulaevents import AddonEvent, NodeEvent -from typing import Callable +from typing import Callable, Union class EventManager: _instance = None @@ -31,6 +28,8 @@ def _initialize(self, verbose=False): self._addons_event_lock = Locker("addons_event_lock", async_lock=True) self._node_events_subs: dict[type, list] = {} self._node_events_lock = Locker("node_events_lock", async_lock=True) + self._global_message_subscribers: list[Callable] = [] + self._global_message_subscribers_lock = Locker("global_message_subscribers_lock", async_lock=True) self._verbose = verbose self._initialized = True # Marca que ya se inicializΓ³ @@ -41,8 +40,14 @@ def get_instance(verbose=False): EventManager(verbose=verbose) return EventManager._instance - async def subscribe(self, event_type: tuple[str, str], callback: Callable): - """Register a callback for a specific event type.""" + async def subscribe(self, event_type: Union[tuple[str, str], None], callback: Callable): + """Register a callback for a message event.""" + if not event_type: + async with self._global_message_subscribers_lock: + self._global_message_subscribers.append(callback) + logging.info(f"EventManager | Subscribed callback for all message events: {event_type}") + return + async with self._message_events_lock: if event_type not in self._subscribers: self._subscribers[event_type] = [] @@ -68,6 +73,20 @@ async def publish(self, message_event: MessageEvent): callback(message_event.source, message_event.message) except Exception as e: logging.exception(f"EventManager | Error in callback for event {event_type}: {e}") + + # Global callbacks (callbacks for all message events) + async with self._global_message_subscribers_lock: + global_callbacks = self._global_message_subscribers.copy() 
+ + for global_cb in global_callbacks: + try: + if self._verbose: logging.info(f"EventManager | Triggering callback for event: {event_type}, from source: {message_event.source}") + if asyncio.iscoroutinefunction(callback) or inspect.iscoroutine(callback): + await global_cb(message_event.source, message_event.message) + else: + global_cb(message_event.source, message_event.message) + except Exception as e: + logging.exception(f"EventManager | Error in callback for event {event_type}: {e}") async def subscribe_addonevent(self, addonEventType: type[AddonEvent], callback: Callable): """Register a callback for a specific type of AddonEvent.""" diff --git a/nebula/core/pb/nebula_pb2.py b/nebula/core/pb/nebula_pb2.py index d5c476808..7aed62cca 100755 --- a/nebula/core/pb/nebula_pb2.py +++ b/nebula/core/pb/nebula_pb2.py @@ -1,12 +1,11 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: nebula.proto -# Protobuf Python Version: 4.25.3 """Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -14,45 +13,49 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cnebula.proto\x12\x06nebula\"\xf5\x03\n\x07Wrapper\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x35\n\x11\x64iscovery_message\x18\x02 \x01(\x0b\x32\x18.nebula.DiscoveryMessageH\x00\x12\x31\n\x0f\x63ontrol_message\x18\x03 \x01(\x0b\x32\x16.nebula.ControlMessageH\x00\x12\x37\n\x12\x66\x65\x64\x65ration_message\x18\x04 \x01(\x0b\x32\x19.nebula.FederationMessageH\x00\x12-\n\rmodel_message\x18\x05 \x01(\x0b\x32\x14.nebula.ModelMessageH\x00\x12\x37\n\x12\x63onnection_message\x18\x06 
\x01(\x0b\x32\x19.nebula.ConnectionMessageH\x00\x12\x33\n\x10response_message\x18\x07 \x01(\x0b\x32\x17.nebula.ResponseMessageH\x00\x12\x33\n\x10\x64iscover_message\x18\x08 \x01(\x0b\x32\x17.nebula.DiscoverMessageH\x00\x12-\n\roffer_message\x18\t \x01(\x0b\x32\x14.nebula.OfferMessageH\x00\x12+\n\x0clink_message\x18\n \x01(\x0b\x32\x13.nebula.LinkMessageH\x00\x42\t\n\x07message\"\x9e\x01\n\x10\x44iscoveryMessage\x12/\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1f.nebula.DiscoveryMessage.Action\x12\x10\n\x08latitude\x18\x02 \x01(\x02\x12\x11\n\tlongitude\x18\x03 \x01(\x02\"4\n\x06\x41\x63tion\x12\x0c\n\x08\x44ISCOVER\x10\x00\x12\x0c\n\x08REGISTER\x10\x01\x12\x0e\n\nDEREGISTER\x10\x02\"\x9a\x01\n\x0e\x43ontrolMessage\x12-\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1d.nebula.ControlMessage.Action\x12\x0b\n\x03log\x18\x02 \x01(\t\"L\n\x06\x41\x63tion\x12\t\n\x05\x41LIVE\x10\x00\x12\x0c\n\x08OVERHEAD\x10\x01\x12\x0c\n\x08MOBILITY\x10\x02\x12\x0c\n\x08RECOVERY\x10\x03\x12\r\n\tWEAK_LINK\x10\x04\"\xcd\x01\n\x11\x46\x65\x64\x65rationMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.FederationMessage.Action\x12\x11\n\targuments\x18\x02 \x03(\t\x12\r\n\x05round\x18\x03 \x01(\x05\"d\n\x06\x41\x63tion\x12\x14\n\x10\x46\x45\x44\x45RATION_START\x10\x00\x12\x0e\n\nREPUTATION\x10\x01\x12\x1e\n\x1a\x46\x45\x44\x45RATION_MODELS_INCLUDED\x10\x02\x12\x14\n\x10\x46\x45\x44\x45RATION_READY\x10\x03\"A\n\x0cModelMessage\x12\x12\n\nparameters\x18\x01 \x01(\x0c\x12\x0e\n\x06weight\x18\x02 \x01(\x03\x12\r\n\x05round\x18\x03 \x01(\x05\"\x8f\x01\n\x11\x43onnectionMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.ConnectionMessage.Action\"H\n\x06\x41\x63tion\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\x0e\n\nDISCONNECT\x10\x01\x12\x10\n\x0cLATE_CONNECT\x10\x02\x12\x0f\n\x0bRESTRUCTURE\x10\x03\"r\n\x0f\x44iscoverMessage\x12.\n\x06\x61\x63tion\x18\x01 
\x01(\x0e\x32\x1e.nebula.DiscoverMessage.Action\"/\n\x06\x41\x63tion\x12\x11\n\rDISCOVER_JOIN\x10\x00\x12\x12\n\x0e\x44ISCOVER_NODES\x10\x01\"\xce\x01\n\x0cOfferMessage\x12+\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1b.nebula.OfferMessage.Action\x12\x13\n\x0bn_neighbors\x18\x02 \x01(\x02\x12\x0c\n\x04loss\x18\x03 \x01(\x02\x12\x12\n\nparameters\x18\x04 \x01(\x0c\x12\x0e\n\x06rounds\x18\x05 \x01(\x05\x12\r\n\x05round\x18\x06 \x01(\x05\x12\x0e\n\x06\x65pochs\x18\x07 \x01(\x05\"+\n\x06\x41\x63tion\x12\x0f\n\x0bOFFER_MODEL\x10\x00\x12\x10\n\x0cOFFER_METRIC\x10\x01\"w\n\x0bLinkMessage\x12*\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1a.nebula.LinkMessage.Action\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x01(\t\"-\n\x06\x41\x63tion\x12\x0e\n\nCONNECT_TO\x10\x00\x12\x13\n\x0f\x44ISCONNECT_FROM\x10\x01\"#\n\x0fResponseMessage\x12\x10\n\x08response\x18\x01 \x01(\tb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cnebula.proto\x12\x06nebula\"\xf5\x03\n\x07Wrapper\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x35\n\x11\x64iscovery_message\x18\x02 \x01(\x0b\x32\x18.nebula.DiscoveryMessageH\x00\x12\x31\n\x0f\x63ontrol_message\x18\x03 \x01(\x0b\x32\x16.nebula.ControlMessageH\x00\x12\x37\n\x12\x66\x65\x64\x65ration_message\x18\x04 \x01(\x0b\x32\x19.nebula.FederationMessageH\x00\x12-\n\rmodel_message\x18\x05 \x01(\x0b\x32\x14.nebula.ModelMessageH\x00\x12\x37\n\x12\x63onnection_message\x18\x06 \x01(\x0b\x32\x19.nebula.ConnectionMessageH\x00\x12\x33\n\x10response_message\x18\x07 \x01(\x0b\x32\x17.nebula.ResponseMessageH\x00\x12\x33\n\x10\x64iscover_message\x18\x08 \x01(\x0b\x32\x17.nebula.DiscoverMessageH\x00\x12-\n\roffer_message\x18\t \x01(\x0b\x32\x14.nebula.OfferMessageH\x00\x12+\n\x0clink_message\x18\n \x01(\x0b\x32\x13.nebula.LinkMessageH\x00\x42\t\n\x07message\"\x9e\x01\n\x10\x44iscoveryMessage\x12/\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1f.nebula.DiscoveryMessage.Action\x12\x10\n\x08latitude\x18\x02 \x01(\x02\x12\x11\n\tlongitude\x18\x03 
\x01(\x02\"4\n\x06\x41\x63tion\x12\x0c\n\x08\x44ISCOVER\x10\x00\x12\x0c\n\x08REGISTER\x10\x01\x12\x0e\n\nDEREGISTER\x10\x02\"\x9a\x01\n\x0e\x43ontrolMessage\x12-\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1d.nebula.ControlMessage.Action\x12\x0b\n\x03log\x18\x02 \x01(\t\"L\n\x06\x41\x63tion\x12\t\n\x05\x41LIVE\x10\x00\x12\x0c\n\x08OVERHEAD\x10\x01\x12\x0c\n\x08MOBILITY\x10\x02\x12\x0c\n\x08RECOVERY\x10\x03\x12\r\n\tWEAK_LINK\x10\x04\"\xcd\x01\n\x11\x46\x65\x64\x65rationMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.FederationMessage.Action\x12\x11\n\targuments\x18\x02 \x03(\t\x12\r\n\x05round\x18\x03 \x01(\x05\"d\n\x06\x41\x63tion\x12\x14\n\x10\x46\x45\x44\x45RATION_START\x10\x00\x12\x0e\n\nREPUTATION\x10\x01\x12\x1e\n\x1a\x46\x45\x44\x45RATION_MODELS_INCLUDED\x10\x02\x12\x14\n\x10\x46\x45\x44\x45RATION_READY\x10\x03\"A\n\x0cModelMessage\x12\x12\n\nparameters\x18\x01 \x01(\x0c\x12\x0e\n\x06weight\x18\x02 \x01(\x03\x12\r\n\x05round\x18\x03 \x01(\x05\"\x8f\x01\n\x11\x43onnectionMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.ConnectionMessage.Action\"H\n\x06\x41\x63tion\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\x0e\n\nDISCONNECT\x10\x01\x12\x10\n\x0cLATE_CONNECT\x10\x02\x12\x0f\n\x0bRESTRUCTURE\x10\x03\"r\n\x0f\x44iscoverMessage\x12.\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1e.nebula.DiscoverMessage.Action\"/\n\x06\x41\x63tion\x12\x11\n\rDISCOVER_JOIN\x10\x00\x12\x12\n\x0e\x44ISCOVER_NODES\x10\x01\"\xce\x01\n\x0cOfferMessage\x12+\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1b.nebula.OfferMessage.Action\x12\x13\n\x0bn_neighbors\x18\x02 \x01(\x02\x12\x0c\n\x04loss\x18\x03 \x01(\x02\x12\x12\n\nparameters\x18\x04 \x01(\x0c\x12\x0e\n\x06rounds\x18\x05 \x01(\x05\x12\r\n\x05round\x18\x06 \x01(\x05\x12\x0e\n\x06\x65pochs\x18\x07 \x01(\x05\"+\n\x06\x41\x63tion\x12\x0f\n\x0bOFFER_MODEL\x10\x00\x12\x10\n\x0cOFFER_METRIC\x10\x01\"w\n\x0bLinkMessage\x12*\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1a.nebula.LinkMessage.Action\x12\r\n\x05\x61\x64\x64rs\x18\x02 
\x01(\t\"-\n\x06\x41\x63tion\x12\x0e\n\nCONNECT_TO\x10\x00\x12\x13\n\x0f\x44ISCONNECT_FROM\x10\x01\"\xc2\x01\n\x11ReputationMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.ReputationMessage.Action\x12\x12\n\nreputation\x18\x02 \x01(\t\x12\x11\n\tdefendant\x18\x03 \x01(\t\x12\x0f\n\x07verdict\x18\x04 \x01(\t\"C\n\x06\x41\x63tion\x12\x14\n\x10SHARE_REPUTATION\x10\x00\x12\x0f\n\x0bSTART_TRIAL\x10\x01\x12\x12\n\x0eSUBMIT_VERDICT\x10\x02\"#\n\x0fResponseMessage\x12\x10\n\x08response\x18\x01 \x01(\tb\x06proto3') -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'nebula_pb2', _globals) +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'nebula_pb2', globals()) if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None - _globals['_WRAPPER']._serialized_start=25 - _globals['_WRAPPER']._serialized_end=526 - _globals['_DISCOVERYMESSAGE']._serialized_start=529 - _globals['_DISCOVERYMESSAGE']._serialized_end=687 - _globals['_DISCOVERYMESSAGE_ACTION']._serialized_start=635 - _globals['_DISCOVERYMESSAGE_ACTION']._serialized_end=687 - _globals['_CONTROLMESSAGE']._serialized_start=690 - _globals['_CONTROLMESSAGE']._serialized_end=844 - _globals['_CONTROLMESSAGE_ACTION']._serialized_start=768 - _globals['_CONTROLMESSAGE_ACTION']._serialized_end=844 - _globals['_FEDERATIONMESSAGE']._serialized_start=847 - _globals['_FEDERATIONMESSAGE']._serialized_end=1052 - _globals['_FEDERATIONMESSAGE_ACTION']._serialized_start=952 - _globals['_FEDERATIONMESSAGE_ACTION']._serialized_end=1052 - _globals['_MODELMESSAGE']._serialized_start=1054 - _globals['_MODELMESSAGE']._serialized_end=1119 - _globals['_CONNECTIONMESSAGE']._serialized_start=1122 - _globals['_CONNECTIONMESSAGE']._serialized_end=1265 - _globals['_CONNECTIONMESSAGE_ACTION']._serialized_start=1193 - 
_globals['_CONNECTIONMESSAGE_ACTION']._serialized_end=1265 - _globals['_DISCOVERMESSAGE']._serialized_start=1267 - _globals['_DISCOVERMESSAGE']._serialized_end=1381 - _globals['_DISCOVERMESSAGE_ACTION']._serialized_start=1334 - _globals['_DISCOVERMESSAGE_ACTION']._serialized_end=1381 - _globals['_OFFERMESSAGE']._serialized_start=1384 - _globals['_OFFERMESSAGE']._serialized_end=1590 - _globals['_OFFERMESSAGE_ACTION']._serialized_start=1547 - _globals['_OFFERMESSAGE_ACTION']._serialized_end=1590 - _globals['_LINKMESSAGE']._serialized_start=1592 - _globals['_LINKMESSAGE']._serialized_end=1711 - _globals['_LINKMESSAGE_ACTION']._serialized_start=1666 - _globals['_LINKMESSAGE_ACTION']._serialized_end=1711 - _globals['_RESPONSEMESSAGE']._serialized_start=1713 - _globals['_RESPONSEMESSAGE']._serialized_end=1748 + _WRAPPER._serialized_start=25 + _WRAPPER._serialized_end=526 + _DISCOVERYMESSAGE._serialized_start=529 + _DISCOVERYMESSAGE._serialized_end=687 + _DISCOVERYMESSAGE_ACTION._serialized_start=635 + _DISCOVERYMESSAGE_ACTION._serialized_end=687 + _CONTROLMESSAGE._serialized_start=690 + _CONTROLMESSAGE._serialized_end=844 + _CONTROLMESSAGE_ACTION._serialized_start=768 + _CONTROLMESSAGE_ACTION._serialized_end=844 + _FEDERATIONMESSAGE._serialized_start=847 + _FEDERATIONMESSAGE._serialized_end=1052 + _FEDERATIONMESSAGE_ACTION._serialized_start=952 + _FEDERATIONMESSAGE_ACTION._serialized_end=1052 + _MODELMESSAGE._serialized_start=1054 + _MODELMESSAGE._serialized_end=1119 + _CONNECTIONMESSAGE._serialized_start=1122 + _CONNECTIONMESSAGE._serialized_end=1265 + _CONNECTIONMESSAGE_ACTION._serialized_start=1193 + _CONNECTIONMESSAGE_ACTION._serialized_end=1265 + _DISCOVERMESSAGE._serialized_start=1267 + _DISCOVERMESSAGE._serialized_end=1381 + _DISCOVERMESSAGE_ACTION._serialized_start=1334 + _DISCOVERMESSAGE_ACTION._serialized_end=1381 + _OFFERMESSAGE._serialized_start=1384 + _OFFERMESSAGE._serialized_end=1590 + _OFFERMESSAGE_ACTION._serialized_start=1547 + 
_OFFERMESSAGE_ACTION._serialized_end=1590 + _LINKMESSAGE._serialized_start=1592 + _LINKMESSAGE._serialized_end=1711 + _LINKMESSAGE_ACTION._serialized_start=1666 + _LINKMESSAGE_ACTION._serialized_end=1711 + _REPUTATIONMESSAGE._serialized_start=1714 + _REPUTATIONMESSAGE._serialized_end=1908 + _REPUTATIONMESSAGE_ACTION._serialized_start=1841 + _REPUTATIONMESSAGE_ACTION._serialized_end=1908 + _RESPONSEMESSAGE._serialized_start=1910 + _RESPONSEMESSAGE._serialized_end=1945 # @@protoc_insertion_point(module_scope) diff --git a/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py b/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py index 59f987821..4d25314d8 100644 --- a/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py +++ b/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py @@ -5,8 +5,14 @@ from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import UpdateReceivedEvent, AggregationEvent, RoundStartEvent, UpdateNeighborEvent, NodeBlacklistedEvent import time +from enum import Enum import asyncio +class ThreatCategory(Enum): + FLOODING = "flooding" + INACTIVITY = "inactivity" + BAD_BEHAVIOR = "bad behavior" + class TimeStamp(): def __init__(self, time_received = None, time_since_last_event = None): self.tr = time_received @@ -39,6 +45,7 @@ def reset(self): class BehaviorReputation(): MAX_HISTORIC_SIZE = 10 SCORE_THRESHOLD = 0.5 # Threshold to detect posible malicious nodes + MAX_MESSAGES_PER_ROUND = 15 # Maximun number os messages allowed for each node INACTIVE_THRESHOLD = 3 W_UPDATE_FREQ = 0.25 # Update frequency weight W_UPDATE_LATENCY = 0.15 # update latency weight @@ -48,23 +55,26 @@ class BehaviorReputation(): def __init__(self, config): self._addr = config["addr"] self._verbose = config["verbose"] - self._nodes : dict[str, tuple[deque, int, deque[TimeStamp], deque[TimeStamp]]] = {} + self._nodes : dict[str, tuple[deque, int, 
deque[TimeStamp], deque[TimeStamp]]] = {} # Updates time registration self._nodes_lock = Locker(name="nodes_lock", async_lock=True) - self._historical_blacklist_activity: dict[str, tuple[int,int]] = defaultdict(tuple) + self._historical_blacklist_activity: dict[str, tuple[int,int]] = defaultdict(tuple) # Blacklist activity self._historical_blacklist_activity_lock = Locker(name="historical_blacklist_activity_lock", async_lock=True) - self._historical_behavior_scores: dict[str, deque[int]] = defaultdict(deque) + self._historical_behavior_scores: dict[str, deque[int]] = defaultdict(deque) # Scores registration + self._messages_received_per_round: dict[str, int] = defaultdict(int) + self._messages_received_per_round_lock = Locker(name="messages_received_per_round_lock", async_lock=True) self._internal_rounds_done = -1 self._last_aggregation_time = None self._suspicious_nodes = set() + self._suspicious_nodes_lock = Locker(name="suspicious_nodes_lock", async_lock=True) @property def hba(self): - """historical_blacklist_activity""" + """historical blacklist activity""" return self._historical_blacklist_activity @property def hbs(self): - """historical_behavior_scores""" + """historical behavior scores""" return self._historical_behavior_scores @property @@ -72,7 +82,7 @@ def hba_lock(self): return self._historical_blacklist_activity_lock def __str__(self): - return "BehaviorReputation" + return "Behavior Reputation" async def init(self, config): async with self._nodes_lock: @@ -96,6 +106,7 @@ async def init(self, config): await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self._update_neighbors) await EventManager.get_instance().subscribe_node_event(NodeBlacklistedEvent, self._process_node_blacklisted_event) + await EventManager.get_instance().subscribe(None, self._process_messages_received) async def _get_nodes(self): async with 
self._nodes_lock: @@ -114,6 +125,9 @@ async def _process_round_start(self, rse: RoundStartEvent): self._last_aggregation_time = start_time self._internal_rounds_done += 1 + async with self._messages_received_per_round_lock: + self._messages_received_per_round.clear() + async def _process_aggregation_event(self, are: AggregationEvent): self._last_aggregation_time = time.time() if self._verbose: logging.info("Processing aggregation event") @@ -178,6 +192,16 @@ async def _process_node_blacklisted_event(self, nbe: NodeBlacklistedEvent): times_rd += 1 self.hba[node_addr] = (times_bl, times_rd) + async def _process_messages_received(self, source, message): + async with self._messages_received_per_round_lock: + n_messages = self._messages_received_per_round[source] + n_messages += 1 + if n_messages >= self.MAX_MESSAGES_PER_ROUND: + async with self._suspicious_nodes_lock: + self._suspicious_nodes.union({(source, ThreatCategory.FLOODING)}) + self._messages_received_per_round[source] = n_messages + + async def _evaluate(self): if self._verbose: logging.info("Evaluating Behavior Reputation, generating score...") @@ -248,7 +272,7 @@ async def _evaluate(self): self._update_behavior_scores(scores) - # Order nodes + # Final Step sorted_nodes = sorted(scores.items(), key=lambda x: x[1], reverse=True) nodes_below_th = [x for x in sorted_nodes if x[1] < self.SCORE_THRESHOLD] @@ -259,7 +283,8 @@ async def _evaluate(self): if self._verbose: logging.info(f"Nodes below threshold: {nodes_below_th}") # Update suspicious nodes - self._suspicious_nodes.union({n for n in nodes_below_th}) + async with self._suspicious_nodes_lock: + self._suspicious_nodes.union({(n, ThreatCategory.BAD_BEHAVIOR) for n in nodes_below_th}) From 7c4144bfb4abde82211f60dbd6a9086976d6ccbf Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 24 Apr 2025 13:01:01 +0200 Subject: [PATCH 176/233] feature dinamically loading sa components --- .../awareness/samodule.py | 82 ++++++++++++++----- 
.../awareness/sanetwork/sanetwork.py | 9 +- .../sareputation/behaviorreputation.py | 31 ++++++- .../candidateselection/fccandidateselector.py | 7 -- .../stdcandidateselector.py | 1 + .../core/situationalawareness/nodemanager.py | 16 ++-- .../frontend/config/participant.json.example | 6 +- 7 files changed, 106 insertions(+), 46 deletions(-) diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index f430f6e2d..d92bc2667 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -2,10 +2,11 @@ import asyncio import logging from nebula.addons.functions import print_msg_box +import importlib.util +import os from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer from nebula.core.situationalawareness.awareness.sautils.sacommand import SACommand from nebula.core.utils.locker import Locker -from nebula.core.nebulaevents import RoundEndEvent from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import RoundEndEvent, AggregationEvent from nebula.core.network.communications import CommunicationsManager @@ -26,9 +27,11 @@ async def sa_component_actions(self): class SAModule: + MODULE_PATH = "nebula/nebula/core/situationalawareness/awareness" + def __init__( self, - nodemanager, + config, addr, topology, verbose = False, @@ -39,6 +42,7 @@ def __init__( title="Situational Awareness module", ) logging.info("🌐 Initializing SAModule") + self._config = config self._addr = addr self._topology = topology self._node_manager: NodeManager = nodemanager @@ -51,31 +55,26 @@ def __init__( self._communciation_manager = CommunicationsManager.get_instance() self._sys_monitor = SystemMonitor() self._arbitatrion_policy = factory_arbitatrion_policy("sad", True) + self._sa_components: dict[str, SAMComponent] = {} self._verbose = verbose - - @property - def nm(self): - return self._node_manager - 
@property def san(self): + """Situational Awareness Network""" return self._situational_awareness_network - @property - def sat(self): - return self._situational_awareness_training - @property def cm(self): return self._communciation_manager @property def sb(self): + """Suggestion Buffer""" return self._suggestion_buffer @property def ab(self): + """Arbitatrion Policy""" return self._arbitatrion_policy async def init(self): @@ -84,12 +83,11 @@ async def init(self): self._situational_awareness_network = SANetwork(self, self._addr, self._topology, verbose=True) self._situational_awareness_training = SATraining(self, self._addr, "qds", "fastreboot", verbose=True) await self.san.init() - await self.sat.init() await EventManager.get_instance().subscribe_node_event(RoundEndEvent, self._process_round_end_event) await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) def is_additional_participant(self): - return self.nm.is_additional_participant() + return self._config.participant["mobility_args"]["additional_node"]["status"] """ ############################### # REESTRUCTURE TOPOLOGY # @@ -107,15 +105,9 @@ def get_restructure_process_lock(self): def get_nodes_known(self, neighbors_too=False, neighbors_only=False): return self.san.get_nodes_known(neighbors_too, neighbors_only) - async def neighbors_left(self): - return await self.san.neighbors_left() - def accept_connection(self, source, joining=False): return self.san.accept_connection(source, joining) - def need_more_neighbors(self): - return self.san.need_more_neighbors() - def get_actions(self): return self.san.get_actions() @@ -190,4 +182,56 @@ async def _arbitatrion_suggestions(self, event_type): logging.info("Arbitatrion finished") return valid_commands + + """ ############################### + # SA COMPONENT LOADING # + ############################### + """ + + async def loading_sa_components(self): + """Dynamically loads the SA Components defined in the JSON 
configuration.""" + sa_section = self._config.participant["situational_awareness"] + components: dict = sa_section["sa_components"] + + for component_name, is_enabled in components.items(): + if is_enabled: + component_config = sa_section[component_name] + class_name = "SA" + component_name[2:].capitalize() + module_path = os.path.join(self.MODULE_PATH, component_name) + module_file = os.path.join(module_path, f"{component_name}.py") + + if os.path.exists(module_file): + module = self._load_component(class_name, module_file, component_config) + if module: + self._sa_components[component_name] = module + else: + logging.error(f"⚠️ SA Component {component_name} not found on {module_file}") + + await self._initialize_sa_components() + await self._set_minimal_requirements() + + async def _load_component(self, class_name, component_file, config): + """Loads a SA Component dynamically and initializes it with its configuration.""" + spec = importlib.util.spec_from_file_location(class_name, component_file) + if spec and spec.loader: + component = importlib.util.module_from_spec(spec) + spec.loader.exec_module(component) + if hasattr(component, class_name): # Verify if class exists + return getattr(component, class_name)(config) # Create and instance using component config + else: + logging.error(f"⚠️ Cannot create {class_name} SA Component, class not found on {component_file}") + return None + + async def _initialize_sa_components(self): + if self._sa_components: + for sacomp in self._sa_components.values(): + await sacomp.init() + + async def _set_minimal_requirements(self): + if self._sa_components: + self._situational_awareness_network = self._sa_components["sanetwork"] + else: + raise ValueError("SA Network not found") + + diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index 5c1884041..26ec86fa2 100644 --- 
a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -49,6 +49,7 @@ def __init__( @property def sam(self): + """SA Module""" return self._sam @property @@ -57,10 +58,12 @@ def cm(self): @property def np(self): + """Neighbor Policy""" return self._neighbor_policy @property def sana(self): + """SA Network Agent""" return self._sa_network_agent async def init(self): @@ -146,8 +149,7 @@ async def beacon_received(self, beacon_recieved_event : BeaconRecievedEvent): addr, geoloc = await beacon_recieved_event.get_event_data() latitude, longitude = geoloc nfe = NodeFoundEvent(addr) - asyncio.create_task(EventManager.get_instance().publish_node_event(nfe)) - #logging.info(f"Beacon received SANetwork, source: {addr}, geolocalization: {latitude},{longitude}") + asyncio.create_task(EventManager.get_instance().publish_node_event(nfe)) """ ############################### # REESTRUCTURE TOPOLOGY # @@ -171,7 +173,6 @@ async def _analize_topology_robustness(self): if not self._restructure_process_lock.locked(): if not await self.neighbors_left(): if self._verbose: logging.info("No Neighbors left | reconnecting with Federation") - #await self.reconnect_to_federation() await self.sana.create_and_suggest_action(SACommandAction.RECONNECT, self.reconnect_to_federation, None) elif self.np.need_more_neighbors() and self._restructure_available(): if self._verbose: logging.info("Insufficient Robustness | Upgrading robustness | Searching for more connections") @@ -183,7 +184,6 @@ async def _analize_topology_robustness(self): else: pass await self.sana.create_and_suggest_action(SACommandAction.SEARCH_CONNECTIONS, self.upgrade_connection_robustness, possible_neighbors) - # asyncio.create_task(self.upgrade_connection_robustness(possible_neighbors)) else: if self._verbose: logging.info("Sufficient Robustness | no actions required") await 
self.sana.create_and_suggest_action(SACommandAction.MAINTAIN_CONNECTIONS) @@ -193,7 +193,6 @@ async def _analize_topology_robustness(self): async def reconnect_to_federation(self): self._restructure_process_lock.acquire() await self.cm.clear_restrictions() - await asyncio.sleep(120) #TODO remove # If we got some refs, try to reconnect to them if len(self.np.get_nodes_known()) > 0: if self._verbose: logging.info("Reconnecting | Addrs availables") diff --git a/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py b/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py index 4d25314d8..6047e861f 100644 --- a/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py +++ b/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py @@ -8,7 +8,7 @@ from enum import Enum import asyncio -class ThreatCategory(Enum): +class ThreatCategoryBehavior(Enum): FLOODING = "flooding" INACTIVITY = "inactivity" BAD_BEHAVIOR = "bad behavior" @@ -46,6 +46,7 @@ class BehaviorReputation(): MAX_HISTORIC_SIZE = 10 SCORE_THRESHOLD = 0.5 # Threshold to detect posible malicious nodes MAX_MESSAGES_PER_ROUND = 15 # Maximun number os messages allowed for each node + MAX_INACTIVITY_ALLOWED = 3 # Max number of consecutive inactive rounds INACTIVE_THRESHOLD = 3 W_UPDATE_FREQ = 0.25 # Update frequency weight W_UPDATE_LATENCY = 0.15 # update latency weight @@ -55,6 +56,10 @@ class BehaviorReputation(): def __init__(self, config): self._addr = config["addr"] self._verbose = config["verbose"] + # historic - register of updates per round + # inactivity - consecutive rounds inactivity + # t_between_upd - time between updates + # t_last_agg - time waited since last aggregation until receive update self._nodes : dict[str, tuple[deque, int, deque[TimeStamp], deque[TimeStamp]]] = {} # Updates time registration self._nodes_lock = Locker(name="nodes_lock", async_lock=True) self._historical_blacklist_activity: dict[str, 
tuple[int,int]] = defaultdict(tuple) # Blacklist activity @@ -108,6 +113,17 @@ async def init(self, config): await EventManager.get_instance().subscribe_node_event(NodeBlacklistedEvent, self._process_node_blacklisted_event) await EventManager.get_instance().subscribe(None, self._process_messages_received) + async def get_behavior_scores(self, historical=False): + if historical: + return self.hbs.copy() + else: + last_scores = {node: scores[-1] for node,scores in self.hbs.items()} + return last_scores + + async def get_suspicious_nodes(self): + async with self._suspicious_nodes_lock: + return self._suspicious_nodes.copy() + async def _get_nodes(self): async with self._nodes_lock: nodes = self._nodes.copy() @@ -198,10 +214,9 @@ async def _process_messages_received(self, source, message): n_messages += 1 if n_messages >= self.MAX_MESSAGES_PER_ROUND: async with self._suspicious_nodes_lock: - self._suspicious_nodes.union({(source, ThreatCategory.FLOODING)}) + self._suspicious_nodes.union({(source, ThreatCategoryBehavior.FLOODING)}) self._messages_received_per_round[source] = n_messages - async def _evaluate(self): if self._verbose: logging.info("Evaluating Behavior Reputation, generating score...") @@ -245,6 +260,14 @@ async def _evaluate(self): scores = {} for node, (history, missed_count, time_between_updts_historic, last_wait_times) in nodes.items(): + + # Check inactivity beyond max tolerance + if missed_count >= self.MAX_INACTIVITY_ALLOWED: + async with self._suspicious_nodes_lock: + self._suspicious_nodes.union(((node, ThreatCategoryBehavior.INACTIVITY))) + scores[node] = 0.0 + continue + # 1. 
Normalized frequency updates_received = max((x[1] for x in history if x[0] == self._internal_rounds_done), default=0) F_updt_freq = updates_received / max_updates if max_updates > 0 else 0 @@ -284,7 +307,7 @@ async def _evaluate(self): # Update suspicious nodes async with self._suspicious_nodes_lock: - self._suspicious_nodes.union({(n, ThreatCategory.BAD_BEHAVIOR) for n in nodes_below_th}) + self._suspicious_nodes.union({(n, ThreatCategoryBehavior.BAD_BEHAVIOR) for n in nodes_below_th}) diff --git a/nebula/core/situationalawareness/candidateselection/fccandidateselector.py b/nebula/core/situationalawareness/candidateselection/fccandidateselector.py index 3a91098d8..947eaec5a 100644 --- a/nebula/core/situationalawareness/candidateselection/fccandidateselector.py +++ b/nebula/core/situationalawareness/candidateselection/fccandidateselector.py @@ -19,15 +19,8 @@ def select_candidates(self): """ In Fully-Connected topology all candidates should be selected """ - #0145 - #listed = ["192.168.51.2:45001", "192.168.51.3:45002", "192.168.51.6:45005", "192.168.51.7:45006"] - #defined = [] self.candidates_lock.acquire() cdts = self.candidates.copy() - #for (addr,a,b) in cdts: - # if addr in listed: - # defined.append((addr,a,b)) - #cdts = defined self.candidates_lock.release() return cdts diff --git a/nebula/core/situationalawareness/candidateselection/stdcandidateselector.py b/nebula/core/situationalawareness/candidateselection/stdcandidateselector.py index 022677fc6..fd4c16398 100644 --- a/nebula/core/situationalawareness/candidateselection/stdcandidateselector.py +++ b/nebula/core/situationalawareness/candidateselection/stdcandidateselector.py @@ -20,6 +20,7 @@ def select_candidates(self): Select mean number of neighbors """ self.candidates_lock.acquire() + #TODO revisar mean_neighbors = sum(n for n, _ in self.candidates) / len(self.candidates) if self.candidates else 0 cdts = self.candidates[:mean_neighbors] self.candidates_lock.release() diff --git 
a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/nodemanager.py index 98daf849a..1371af65e 100644 --- a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/nodemanager.py @@ -19,6 +19,8 @@ class NodeManager: + OFFER_TIMEOUT = 5 + def __init__( self, aditional_participant, @@ -45,13 +47,11 @@ def __init__( self.pending_confirmation_from_nodes = set() self.pending_confirmation_from_nodes_lock = Locker(name="pending_confirmation_from_nodes_lock", async_lock=True) self.accept_candidates_lock = Locker(name="accept_candidates_lock") - self.recieve_offer_timer = 5 + self.recieve_offer_timer = self.OFFER_TIMEOUT self.discarded_offers_addr_lock = Locker(name="discarded_offers_addr_lock") self.discarded_offers_addr = [] - - self._desc_done = False #TODO remove - self._situational_awareness_module = SAModule(self, self.engine.addr, topology, True) + self._situational_awareness_module = SAModule(self, self.config, self.engine.addr, topology, True) self._verbose = verbose @property @@ -72,6 +72,7 @@ def model_handler(self): @property def sam(self): + """Situational Awareness Module""" return self._situational_awareness_module def is_additional_participant(self): @@ -227,7 +228,6 @@ async def start_late_connection_process(self, connected=False, msg_type="discove connections_stablished = await self.cm.stablish_connection_to_federation(msg_type, addrs_known) # wait offer - #TODO actualizar con la informacion de latencias if self._verbose: logging.info(f"Connections stablish after finding federation: {connections_stablished}") if connections_stablished: if self._verbose: logging.info(f"Waiting: {self.recieve_offer_timer}s to receive offers from federation") @@ -252,14 +252,12 @@ async def start_late_connection_process(self, connected=False, msg_type="discove await self.cm.send_message(addr, msg) await asyncio.sleep(1) except asyncio.CancelledError: - await self.update_neighbors(addr, remove=True) + upe = 
UpdateNeighborEvent(addr, removed=True) + asyncio.create_task(EventManager.get_instance().publish_node_event(upe)) if self._verbose: logging.info("Error during stablishment") self.accept_candidates_lock.release() self.late_connection_process_lock.release() self.candidate_selector.remove_candidates() - # if not self._desc_done: #TODO remove - # self._desc_done = True - # asyncio.create_task(self.sam.san.stop_connections_with_federation()) # if no candidates, repeat process else: if self._verbose: logging.info("❗️ No Candidates found...") diff --git a/nebula/frontend/config/participant.json.example b/nebula/frontend/config/participant.json.example index ec5796b07..993f8285d 100755 --- a/nebula/frontend/config/participant.json.example +++ b/nebula/frontend/config/participant.json.example @@ -151,8 +151,10 @@ "sanetwork": true }, "sanetwork": { - "strict_topology": true - } + "addr": "", + "strict_topology": true, + "verbose": true + }, "arbitatrion_policy": "sap", "model_handler": "std" }, From 1ffb5e220f01cd1782866f114d76993d5403218f Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Thu, 24 Apr 2025 18:31:24 +0200 Subject: [PATCH 177/233] refactor situational awareness module --- nebula/core/engine.py | 7 ++--- .../awareness/samodule.py | 8 ++--- .../awareness/sanetwork/sanetwork.py | 1 - .../satraining/weightstrategy/fastreboot.py | 4 +-- .../satraining/weightstrategy/momentum.py | 4 +-- .../candidateselection/__init__.py | 0 .../candidateselection/candidateselector.py | 8 ++--- .../candidateselection/fccandidateselector.py | 2 +- .../hetcandidateselector.py | 2 +- .../ringcandidateselector.py | 2 +- .../stdcandidateselector.py | 2 +- .../federationconnector.py} | 8 ++--- .../{ => discovery}/modelhandlers/__init__.py | 0 .../modelhandlers/aggmodelhandler.py | 2 +- .../modelhandlers/defaultmodelhandler.py | 8 ++--- .../modelhandlers/modelhandler.py | 6 ++-- .../modelhandlers/stdmodelhandler.py | 2 +- .../situationalawareness.py | 29 +++++++++++++++++++ 18 files 
changed, 59 insertions(+), 36 deletions(-) rename nebula/core/situationalawareness/{ => discovery}/candidateselection/__init__.py (100%) rename nebula/core/situationalawareness/{ => discovery}/candidateselection/candidateselector.py (61%) rename nebula/core/situationalawareness/{ => discovery}/candidateselection/fccandidateselector.py (90%) rename nebula/core/situationalawareness/{ => discovery}/candidateselection/hetcandidateselector.py (97%) rename nebula/core/situationalawareness/{ => discovery}/candidateselection/ringcandidateselector.py (90%) rename nebula/core/situationalawareness/{ => discovery}/candidateselection/stdcandidateselector.py (91%) rename nebula/core/situationalawareness/{nodemanager.py => discovery/federationconnector.py} (98%) rename nebula/core/situationalawareness/{ => discovery}/modelhandlers/__init__.py (100%) rename nebula/core/situationalawareness/{ => discovery}/modelhandlers/aggmodelhandler.py (93%) rename nebula/core/situationalawareness/{ => discovery}/modelhandlers/defaultmodelhandler.py (81%) rename nebula/core/situationalawareness/{ => discovery}/modelhandlers/modelhandler.py (66%) rename nebula/core/situationalawareness/{ => discovery}/modelhandlers/stdmodelhandler.py (93%) create mode 100644 nebula/core/situationalawareness/situationalawareness.py diff --git a/nebula/core/engine.py b/nebula/core/engine.py index f5714265b..ad53fb5eb 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -19,7 +19,8 @@ UpdateReceivedEvent, ) from nebula.core.network.communications import CommunicationsManager -from nebula.core.situationalawareness.nodemanager import NodeManager +from nebula.core.situationalawareness.discovery.federationconnector import FederationConnector +from nebula.core.situationalawareness.situationalawareness import SituationalAwareness from nebula.core.utils.locker import Locker logging.getLogger("requests").setLevel(logging.WARNING) @@ -143,8 +144,6 @@ def __init__( self.config.reload_config_file() self._cm = 
CommunicationsManager(engine=self) - # = CommunicationsManager.get_instance() - # Set the communication manager in the model (send messages from there) self._reporter = Reporter(config=self.config, trainer=self.trainer) @@ -162,7 +161,7 @@ def __init__( topology = self.config.participant["mobility_args"]["topology_type"] topology = topology.lower() model_handler = "std" # self.config.participant["mobility_args"]["model_handler"] - self._node_manager = NodeManager( + self._node_manager = FederationConnector( config.participant["mobility_args"]["additional_node"]["status"], topology, model_handler, diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/samodule.py index d92bc2667..f339aa874 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/samodule.py @@ -12,10 +12,7 @@ from nebula.core.network.communications import CommunicationsManager from nebula.core.situationalawareness.awareness.sautils.sasystemmonitor import SystemMonitor from nebula.core.situationalawareness.awareness.arbitatrionpolicies.arbitatrionpolicy import factory_arbitatrion_policy - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from nebula.core.situationalawareness.nodemanager import NodeManager +from nebula.core.situationalawareness.situationalawareness import ISAReasoner, ISADiscovery class SAMComponent(ABC): @abstractmethod @@ -26,7 +23,7 @@ async def sa_component_actions(self): raise NotImplementedError -class SAModule: +class SAModule(ISAReasoner): MODULE_PATH = "nebula/nebula/core/situationalawareness/awareness" def __init__( @@ -45,7 +42,6 @@ def __init__( self._config = config self._addr = addr self._topology = topology - self._node_manager: NodeManager = nodemanager self._situational_awareness_network = None self._situational_awareness_training = None self._restructure_process_lock = Locker(name="restructure_process_lock") diff --git 
a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index 26ec86fa2..328d51ddd 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -206,7 +206,6 @@ async def reconnect_to_federation(self): async def upgrade_connection_robustness(self, possible_neighbors): self._restructure_process_lock.acquire() - # addrs_to_connect = self.neighbor_policy.get_nodes_known(neighbors_too=False) # If we got some refs, try to connect to them if possible_neighbors and len(possible_neighbors) > 0: if self._verbose: logging.info(f"Reestructuring | Addrs availables | addr list: {possible_neighbors}") diff --git a/nebula/core/situationalawareness/awareness/satraining/weightstrategy/fastreboot.py b/nebula/core/situationalawareness/awareness/satraining/weightstrategy/fastreboot.py index d11da965f..b7d018ca1 100644 --- a/nebula/core/situationalawareness/awareness/satraining/weightstrategy/fastreboot.py +++ b/nebula/core/situationalawareness/awareness/satraining/weightstrategy/fastreboot.py @@ -4,7 +4,7 @@ from nebula.core.utils.locker import Locker if TYPE_CHECKING: - from nebula.core.situationalawareness.nodemanager import NodeManager + from nebula.core.situationalawareness.discovery.federationconnector import FederationConnector VANILLA_LEARNING_RATE = 1e-3 FR_LEARNING_RATE = 1e-3 @@ -15,7 +15,7 @@ class FastReboot: def __init__( self, - node_manager: "NodeManager", + node_manager: "FederationConnector", max_rounds_application=MAX_ROUNDS, # Max rounds to be applied FastReboot weight_modifier=DEFAULT_WEIGHT_MODIFIER, default_learning_rate=VANILLA_LEARNING_RATE, # Stable value for learning rate diff --git a/nebula/core/situationalawareness/awareness/satraining/weightstrategy/momentum.py b/nebula/core/situationalawareness/awareness/satraining/weightstrategy/momentum.py index 7bfdefbb2..2cefa1573 100644 --- 
a/nebula/core/situationalawareness/awareness/satraining/weightstrategy/momentum.py +++ b/nebula/core/situationalawareness/awareness/satraining/weightstrategy/momentum.py @@ -9,7 +9,7 @@ from nebula.core.utils.locker import Locker if TYPE_CHECKING: - from nebula.core.situationalawareness.nodemanager import NodeManager + from nebula.core.situationalawareness.discovery.federationconnector import FederationConnector SimilarityMetricType = Callable[[OrderedDict, OrderedDict, bool], float | None] MappingSimilarityType = Callable[[float, float], Annotated[float, "Value in (0, 1]"]] @@ -28,7 +28,7 @@ class Momentum: def __init__( self, - node_manager: "NodeManager", + node_manager: "FederationConnector", nodes, dispersion_penalty=True, global_priority=GLOBAL_PRIORITY, diff --git a/nebula/core/situationalawareness/candidateselection/__init__.py b/nebula/core/situationalawareness/discovery/candidateselection/__init__.py similarity index 100% rename from nebula/core/situationalawareness/candidateselection/__init__.py rename to nebula/core/situationalawareness/discovery/candidateselection/__init__.py diff --git a/nebula/core/situationalawareness/candidateselection/candidateselector.py b/nebula/core/situationalawareness/discovery/candidateselection/candidateselector.py similarity index 61% rename from nebula/core/situationalawareness/candidateselection/candidateselector.py rename to nebula/core/situationalawareness/discovery/candidateselection/candidateselector.py index 71c2b28c5..6d18ff404 100644 --- a/nebula/core/situationalawareness/candidateselection/candidateselector.py +++ b/nebula/core/situationalawareness/discovery/candidateselection/candidateselector.py @@ -24,10 +24,10 @@ def any_candidate(self): pass def factory_CandidateSelector(topology) -> CandidateSelector: - from nebula.core.situationalawareness.candidateselection.stdcandidateselector import STDandidateSelector - from nebula.core.situationalawareness.candidateselection.fccandidateselector import 
FCCandidateSelector - from nebula.core.situationalawareness.candidateselection.hetcandidateselector import HETCandidateSelector - from nebula.core.situationalawareness.candidateselection.ringcandidateselector import RINGCandidateSelector + from nebula.core.situationalawareness.discovery.candidateselection.stdcandidateselector import STDandidateSelector + from nebula.core.situationalawareness.discovery.candidateselection.fccandidateselector import FCCandidateSelector + from nebula.core.situationalawareness.discovery.candidateselection.hetcandidateselector import HETCandidateSelector + from nebula.core.situationalawareness.discovery.candidateselection.ringcandidateselector import RINGCandidateSelector options = { "ring": RINGCandidateSelector, diff --git a/nebula/core/situationalawareness/candidateselection/fccandidateselector.py b/nebula/core/situationalawareness/discovery/candidateselection/fccandidateselector.py similarity index 90% rename from nebula/core/situationalawareness/candidateselection/fccandidateselector.py rename to nebula/core/situationalawareness/discovery/candidateselection/fccandidateselector.py index 947eaec5a..1043445fc 100644 --- a/nebula/core/situationalawareness/candidateselection/fccandidateselector.py +++ b/nebula/core/situationalawareness/discovery/candidateselection/fccandidateselector.py @@ -1,4 +1,4 @@ -from nebula.core.situationalawareness.candidateselection.candidateselector import CandidateSelector +from nebula.core.situationalawareness.discovery.candidateselection.candidateselector import CandidateSelector from nebula.core.utils.locker import Locker class FCCandidateSelector(CandidateSelector): diff --git a/nebula/core/situationalawareness/candidateselection/hetcandidateselector.py b/nebula/core/situationalawareness/discovery/candidateselection/hetcandidateselector.py similarity index 97% rename from nebula/core/situationalawareness/candidateselection/hetcandidateselector.py rename to 
nebula/core/situationalawareness/discovery/candidateselection/hetcandidateselector.py index 84766581d..cf775a985 100644 --- a/nebula/core/situationalawareness/candidateselection/hetcandidateselector.py +++ b/nebula/core/situationalawareness/discovery/candidateselection/hetcandidateselector.py @@ -1,4 +1,4 @@ -from nebula.core.situationalawareness.candidateselection.candidateselector import CandidateSelector +from nebula.core.situationalawareness.discovery.candidateselection.candidateselector import CandidateSelector from nebula.core.utils.locker import Locker class HETCandidateSelector(CandidateSelector): diff --git a/nebula/core/situationalawareness/candidateselection/ringcandidateselector.py b/nebula/core/situationalawareness/discovery/candidateselection/ringcandidateselector.py similarity index 90% rename from nebula/core/situationalawareness/candidateselection/ringcandidateselector.py rename to nebula/core/situationalawareness/discovery/candidateselection/ringcandidateselector.py index 02effc281..f7b2bffef 100644 --- a/nebula/core/situationalawareness/candidateselection/ringcandidateselector.py +++ b/nebula/core/situationalawareness/discovery/candidateselection/ringcandidateselector.py @@ -1,4 +1,4 @@ -from nebula.core.situationalawareness.candidateselection.candidateselector import CandidateSelector +from nebula.core.situationalawareness.discovery.candidateselection.candidateselector import CandidateSelector from nebula.core.utils.locker import Locker class RINGCandidateSelector(CandidateSelector): diff --git a/nebula/core/situationalawareness/candidateselection/stdcandidateselector.py b/nebula/core/situationalawareness/discovery/candidateselection/stdcandidateselector.py similarity index 91% rename from nebula/core/situationalawareness/candidateselection/stdcandidateselector.py rename to nebula/core/situationalawareness/discovery/candidateselection/stdcandidateselector.py index fd4c16398..d88e156df 100644 --- 
a/nebula/core/situationalawareness/candidateselection/stdcandidateselector.py +++ b/nebula/core/situationalawareness/discovery/candidateselection/stdcandidateselector.py @@ -1,4 +1,4 @@ -from nebula.core.situationalawareness.candidateselection.candidateselector import CandidateSelector +from nebula.core.situationalawareness.discovery.candidateselection.candidateselector import CandidateSelector from nebula.core.utils.locker import Locker class STDandidateSelector(CandidateSelector): diff --git a/nebula/core/situationalawareness/nodemanager.py b/nebula/core/situationalawareness/discovery/federationconnector.py similarity index 98% rename from nebula/core/situationalawareness/nodemanager.py rename to nebula/core/situationalawareness/discovery/federationconnector.py index 1371af65e..c193147af 100644 --- a/nebula/core/situationalawareness/nodemanager.py +++ b/nebula/core/situationalawareness/discovery/federationconnector.py @@ -3,8 +3,9 @@ from typing import TYPE_CHECKING from nebula.addons.functions import print_msg_box -from nebula.core.situationalawareness.candidateselection.candidateselector import factory_CandidateSelector -from nebula.core.situationalawareness.modelhandlers.modelhandler import factory_ModelHandler +from nebula.core.situationalawareness.discovery.candidateselection.candidateselector import factory_CandidateSelector +from nebula.core.situationalawareness.discovery.modelhandlers.modelhandler import factory_ModelHandler +from nebula.core.situationalawareness.situationalawareness import ISADiscovery, ISAReasoner from nebula.core.situationalawareness.awareness.samodule import SAModule from nebula.core.utils.locker import Locker from nebula.core.eventmanager import EventManager @@ -18,7 +19,7 @@ RESTRUCTURE_COOLDOWN = 5 -class NodeManager: +class FederationConnector(ISADiscovery): OFFER_TIMEOUT = 5 def __init__( @@ -441,7 +442,6 @@ async def _link_disconnect_from_callback(self, source, message): addrs = message.addrs for addr in addrs.split(): await 
self.cm.disconnect(source, mutual_disconnection=False) - await self.update_neighbors(addr, remove=True) diff --git a/nebula/core/situationalawareness/modelhandlers/__init__.py b/nebula/core/situationalawareness/discovery/modelhandlers/__init__.py similarity index 100% rename from nebula/core/situationalawareness/modelhandlers/__init__.py rename to nebula/core/situationalawareness/discovery/modelhandlers/__init__.py diff --git a/nebula/core/situationalawareness/modelhandlers/aggmodelhandler.py b/nebula/core/situationalawareness/discovery/modelhandlers/aggmodelhandler.py similarity index 93% rename from nebula/core/situationalawareness/modelhandlers/aggmodelhandler.py rename to nebula/core/situationalawareness/discovery/modelhandlers/aggmodelhandler.py index a55c5d1cd..3e546f45a 100644 --- a/nebula/core/situationalawareness/modelhandlers/aggmodelhandler.py +++ b/nebula/core/situationalawareness/discovery/modelhandlers/aggmodelhandler.py @@ -1,4 +1,4 @@ -from nebula.core.situationalawareness.modelhandlers.modelhandler import ModelHandler +from nebula.core.situationalawareness.discovery.modelhandlers.modelhandler import ModelHandler from nebula.core.utils.locker import Locker class AGGModelHandler(ModelHandler): diff --git a/nebula/core/situationalawareness/modelhandlers/defaultmodelhandler.py b/nebula/core/situationalawareness/discovery/modelhandlers/defaultmodelhandler.py similarity index 81% rename from nebula/core/situationalawareness/modelhandlers/defaultmodelhandler.py rename to nebula/core/situationalawareness/discovery/modelhandlers/defaultmodelhandler.py index bcf850972..c0c8ced18 100644 --- a/nebula/core/situationalawareness/modelhandlers/defaultmodelhandler.py +++ b/nebula/core/situationalawareness/discovery/modelhandlers/defaultmodelhandler.py @@ -1,6 +1,6 @@ -from nebula.core.situationalawareness.modelhandlers.modelhandler import ModelHandler +from nebula.core.situationalawareness.discovery.modelhandlers.modelhandler import ModelHandler from 
nebula.core.utils.locker import Locker -from nebula.core.situationalawareness.nodemanager import NodeManager +from nebula.core.situationalawareness.discovery.federationconnector import FederationConnector import logging class DefaultModelHandler(ModelHandler): @@ -12,7 +12,7 @@ def __init__(self): self.epochs = 0 self.model_lock = Locker(name="model_lock") self.params_lock = Locker(name="param_lock") - self._nm : NodeManager = None + self._nm : FederationConnector = None def set_config(self, config): """ @@ -20,7 +20,7 @@ def set_config(self, config): config[0] -> total rounds config[1] -> current round config[2] -> epochs - config[3] -> NodeManager + config[3] -> FederationConnector """ self.params_lock.acquire() self.rounds = config[0] diff --git a/nebula/core/situationalawareness/modelhandlers/modelhandler.py b/nebula/core/situationalawareness/discovery/modelhandlers/modelhandler.py similarity index 66% rename from nebula/core/situationalawareness/modelhandlers/modelhandler.py rename to nebula/core/situationalawareness/discovery/modelhandlers/modelhandler.py index b60022c22..4af3ee557 100644 --- a/nebula/core/situationalawareness/modelhandlers/modelhandler.py +++ b/nebula/core/situationalawareness/discovery/modelhandlers/modelhandler.py @@ -20,9 +20,9 @@ def pre_process_model(self): pass def factory_ModelHandler(model_handler) -> ModelHandler: - from nebula.core.situationalawareness.modelhandlers.stdmodelhandler import STDModelHandler - from nebula.core.situationalawareness.modelhandlers.aggmodelhandler import AGGModelHandler - from nebula.core.situationalawareness.modelhandlers.defaultmodelhandler import DefaultModelHandler + from nebula.core.situationalawareness.discovery.modelhandlers.stdmodelhandler import STDModelHandler + from nebula.core.situationalawareness.discovery.modelhandlers.aggmodelhandler import AGGModelHandler + from nebula.core.situationalawareness.discovery.modelhandlers.defaultmodelhandler import DefaultModelHandler options = { "std": 
STDModelHandler, diff --git a/nebula/core/situationalawareness/modelhandlers/stdmodelhandler.py b/nebula/core/situationalawareness/discovery/modelhandlers/stdmodelhandler.py similarity index 93% rename from nebula/core/situationalawareness/modelhandlers/stdmodelhandler.py rename to nebula/core/situationalawareness/discovery/modelhandlers/stdmodelhandler.py index 24f3c59b5..3e41b1a25 100644 --- a/nebula/core/situationalawareness/modelhandlers/stdmodelhandler.py +++ b/nebula/core/situationalawareness/discovery/modelhandlers/stdmodelhandler.py @@ -1,4 +1,4 @@ -from nebula.core.situationalawareness.modelhandlers.modelhandler import ModelHandler +from nebula.core.situationalawareness.discovery.modelhandlers.modelhandler import ModelHandler from nebula.core.utils.locker import Locker diff --git a/nebula/core/situationalawareness/situationalawareness.py b/nebula/core/situationalawareness/situationalawareness.py new file mode 100644 index 000000000..72e95de95 --- /dev/null +++ b/nebula/core/situationalawareness/situationalawareness.py @@ -0,0 +1,29 @@ +from nebula.core.situationalawareness.discovery.federationconnector import FederationConnector +from nebula.core.situationalawareness.awareness.samodule import SAModule +from abc import ABC, abstractmethod + +class ISADiscovery(ABC): + @abstractmethod + async def start_late_connection_process(self, connected=False, msg_type="discover_join", addrs_known=None): + raise NotImplementedError + +class ISAReasoner(ABC): + @abstractmethod + def accept_connection(self, source, joining=False): + raise NotImplementedError + + @abstractmethod + def get_nodes_known(self, neighbors_too=False, neighbors_only=False): + raise NotImplementedError + + @abstractmethod + def get_actions(self): + raise NotImplementedError + +class SituationalAwareness(): + def __init__(self): + #self._situational_awareness_module = SAModule(self, self.config, self.engine.addr, topology, True) + pass + + async def init(self): + pass From 
73607111a80e49a8171fd8cae8787404c7122450 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 28 Apr 2025 11:45:05 +0200 Subject: [PATCH 178/233] refactor SA module --- nebula/core/engine.py | 36 +++++--------- .../awareness/sanetwork/sanetwork.py | 24 +++++----- .../awareness/{samodule.py => sareasoner.py} | 28 +++++++---- .../awareness/sareputation/sareputation.py | 2 +- .../awareness/satraining/satraining.py | 6 +-- .../discovery/federationconnector.py | 31 ++++++------ .../situationalawareness.py | 47 +++++++++++++++++-- 7 files changed, 101 insertions(+), 73 deletions(-) rename nebula/core/situationalawareness/awareness/{samodule.py => sareasoner.py} (92%) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index ad53fb5eb..8592f8e6c 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -19,7 +19,6 @@ UpdateReceivedEvent, ) from nebula.core.network.communications import CommunicationsManager -from nebula.core.situationalawareness.discovery.federationconnector import FederationConnector from nebula.core.situationalawareness.situationalawareness import SituationalAwareness from nebula.core.utils.locker import Locker @@ -153,23 +152,10 @@ def __init__( self.trainning_in_progress_lock = Locker(name="trainning_in_progress_lock", async_lock=True) event_manager = EventManager.get_instance(verbose=False) - - # Mobility setup - self._node_manager = None - self.mobility = self.config.participant["mobility_args"]["mobility"] - if self.mobility == True: - topology = self.config.participant["mobility_args"]["topology_type"] - topology = topology.lower() - model_handler = "std" # self.config.participant["mobility_args"]["model_handler"] - self._node_manager = FederationConnector( - config.participant["mobility_args"]["additional_node"]["status"], - topology, - model_handler, - engine=self, - verbose=True - ) - self._addon_manager = AddondManager(self, self.config) + + # Additional Components + self._situational_awareness = 
SituationalAwareness(self.config) @property def cm(self): @@ -189,10 +175,10 @@ def get_aggregator_type(self): @property def trainer(self): return self._trainer - + @property - def nm(self): - return self._node_manager + def sa(self): + return self._situational_awareness def get_addr(self): return self.addr @@ -386,9 +372,9 @@ async def register_message_callback(self, message_event: tuple[str, str], callba async def _aditional_node_start(self): logging.info(f"Aditional node | {self.addr} | going to stablish connection with federation") - await self.nm.start_late_connection_process() + await self.sa.start_late_connection_process() # continue .. - # asyncio.create_task(self.nm.stop_not_selected_connections()) + # asyncio.create_task(self.sa.stop_not_selected_connections()) logging.info("Creating trainer service to start the federation process..") asyncio.create_task(self._start_learning_late()) @@ -414,7 +400,7 @@ async def update_model_learning_rate(self, new_lr): async def _start_learning_late(self): await self.learning_cycle_lock.acquire_async() try: - model_serialized, rounds, round, _epochs = await self.nm.get_trainning_info() + model_serialized, rounds, round, _epochs = await self.sa.get_trainning_info() self.total_rounds = rounds epochs = _epochs await self.get_round_lock().acquire_async() @@ -470,8 +456,8 @@ async def start_communications(self): async def deploy_components(self): await self.aggregator.init() - if self.mobility: - await self.nm.set_configs() + if self.config.participant["mobility_args"]["mobility"]: + await self.sa.init() await self._reporter.start() await self._addon_manager.deploy_additional_services() diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index 328d51ddd..e541ca6e7 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -8,14 +8,14 @@ 
from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import NodeFoundEvent, UpdateNeighborEvent, ExperimentFinishEvent, RoundEndEvent from nebula.core.network.communications import CommunicationsManager -from nebula.core.situationalawareness.awareness.samodule import SAMComponent +from nebula.core.situationalawareness.awareness.sareasoner import SAMComponent from nebula.core.situationalawareness.awareness.sautils.samoduleagent import SAModuleAgent from nebula.core.situationalawareness.awareness.sautils.sacommand import SACommand, SACommandAction, SACommandPRIO, SACommandState, factory_sa_command from nebula.core.situationalawareness.awareness.suggestionbuffer import SuggestionBuffer from typing import TYPE_CHECKING if TYPE_CHECKING: - from nebula.core.situationalawareness.awareness.samodule import SAModule + from nebula.core.situationalawareness.awareness.sareasoner import SAReasoner RESTRUCTURE_COOLDOWN = 5 @@ -25,7 +25,7 @@ class SANetwork(SAMComponent): def __init__( self, - sam: "SAModule", + sar: "SAReasoner", addr, topology, strict_topology=True, @@ -36,7 +36,7 @@ def __init__( indent=2, title="Network SA module", ) - self._sam = sam + self._sar = sar self._addr = addr self._topology = topology self._strict_topology = strict_topology @@ -48,9 +48,9 @@ def __init__( self._sa_network_agent = SANetworkAgent(self) @property - def sam(self): - """SA Module""" - return self._sam + def sar(self): + """SA Reasoner""" + return self._sar @property def cm(self): @@ -67,7 +67,7 @@ def sana(self): return self._sa_network_agent async def init(self): - if not self.sam.is_additional_participant(): + if not self.sar.is_additional_participant(): logging.info("Deploying External Connection Service") await self.cm.start_external_connection_service() await EventManager.get_instance().subscribe_node_event(BeaconRecievedEvent, self.beacon_received) @@ -196,12 +196,12 @@ async def reconnect_to_federation(self): # If we got some refs, try to reconnect to 
them if len(self.np.get_nodes_known()) > 0: if self._verbose: logging.info("Reconnecting | Addrs availables") - await self.sam.nm.start_late_connection_process( + await self.sar.sad.start_late_connection_process( connected=False, msg_type="discover_nodes", addrs_known=self.np.get_nodes_known() ) else: if self._verbose: logging.info("Reconnecting | NO Addrs availables") - await self.sam.nm.start_late_connection_process(connected=False, msg_type="discover_nodes") + await self.sar.sad.start_late_connection_process(connected=False, msg_type="discover_nodes") self._restructure_process_lock.release() async def upgrade_connection_robustness(self, possible_neighbors): @@ -209,12 +209,12 @@ async def upgrade_connection_robustness(self, possible_neighbors): # If we got some refs, try to connect to them if possible_neighbors and len(possible_neighbors) > 0: if self._verbose: logging.info(f"Reestructuring | Addrs availables | addr list: {possible_neighbors}") - await self.sam.nm.start_late_connection_process( + await self.sar.sad.start_late_connection_process( connected=True, msg_type="discover_nodes", addrs_known=possible_neighbors ) else: if self._verbose: logging.info("Reestructuring | NO Addrs availables") - await self.sam.nm.start_late_connection_process(connected=True, msg_type="discover_nodes") + await self.sar.sad.start_late_connection_process(connected=True, msg_type="discover_nodes") self._restructure_process_lock.release() async def stop_connections_with_federation(self): diff --git a/nebula/core/situationalawareness/awareness/samodule.py b/nebula/core/situationalawareness/awareness/sareasoner.py similarity index 92% rename from nebula/core/situationalawareness/awareness/samodule.py rename to nebula/core/situationalawareness/awareness/sareasoner.py index f339aa874..dde27445f 100644 --- a/nebula/core/situationalawareness/awareness/samodule.py +++ b/nebula/core/situationalawareness/awareness/sareasoner.py @@ -23,7 +23,7 @@ async def sa_component_actions(self): raise 
NotImplementedError -class SAModule(ISAReasoner): +class SAReasoner(ISAReasoner): MODULE_PATH = "nebula/nebula/core/situationalawareness/awareness" def __init__( @@ -34,11 +34,11 @@ def __init__( verbose = False, ): print_msg_box( - msg=f"Starting Situational Awareness module...", + msg=f"Starting Situational Awareness Reasoner module...", indent=2, - title="Situational Awareness module", + title="SA Reasoner", ) - logging.info("🌐 Initializing SAModule") + logging.info("🌐 Initializing SAReasoner") self._config = config self._addr = addr self._topology = topology @@ -50,8 +50,9 @@ def __init__( self._suggestion_buffer = SuggestionBuffer(self._arbitrator_notification, verbose=True) self._communciation_manager = CommunicationsManager.get_instance() self._sys_monitor = SystemMonitor() - self._arbitatrion_policy = factory_arbitatrion_policy("sad", True) + self._arbitatrion_policy = factory_arbitatrion_policy("sap", True) self._sa_components: dict[str, SAMComponent] = {} + self._sa_discovery: ISADiscovery = None self._verbose = verbose @property @@ -73,11 +74,18 @@ def ab(self): """Arbitatrion Policy""" return self._arbitatrion_policy - async def init(self): + @property + def sad(self): + """SA Discovery""" + return self._sa_discovery + + async def init(self, sa_discovery: ISADiscovery): + #await self.loading_sa_components() from nebula.core.situationalawareness.awareness.sanetwork.sanetwork import SANetwork - from nebula.core.situationalawareness.awareness.satraining.satraining import SATraining + #from nebula.core.situationalawareness.awareness.satraining.satraining import SATraining self._situational_awareness_network = SANetwork(self, self._addr, self._topology, verbose=True) - self._situational_awareness_training = SATraining(self, self._addr, "qds", "fastreboot", verbose=True) + #self._situational_awareness_training = SATraining(self, self._addr, "qds", "fastreboot", verbose=True) + self._sa_discovery = sa_discovery await self.san.init() await 
EventManager.get_instance().subscribe_node_event(RoundEndEvent, self._process_round_end_event) await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) @@ -114,8 +122,8 @@ def get_actions(self): async def _process_round_end_event(self, ree : RoundEndEvent): logging.info("πŸ”„ Arbitration | Round End Event...") - asyncio.create_task(self.san.sa_component_actions()) - asyncio.create_task(self.sat.sa_component_actions()) + for sa_comp in self._sa_components.values(): + asyncio.create_task(sa_comp.sa_component_actions()) valid_commands = await self._arbitatrion_suggestions(RoundEndEvent) # Execute SACommand selected diff --git a/nebula/core/situationalawareness/awareness/sareputation/sareputation.py b/nebula/core/situationalawareness/awareness/sareputation/sareputation.py index 1a7194c2e..9015ba870 100644 --- a/nebula/core/situationalawareness/awareness/sareputation/sareputation.py +++ b/nebula/core/situationalawareness/awareness/sareputation/sareputation.py @@ -1,4 +1,4 @@ -from nebula.core.situationalawareness.awareness.samodule import SAMComponent +from nebula.core.situationalawareness.awareness.sareasoner import SAMComponent from enum import Enum class ReputationCategory(Enum): # Reputational thresholds diff --git a/nebula/core/situationalawareness/awareness/satraining/satraining.py b/nebula/core/situationalawareness/awareness/satraining/satraining.py index 6b5c17dd6..deded5a86 100644 --- a/nebula/core/situationalawareness/awareness/satraining/satraining.py +++ b/nebula/core/situationalawareness/awareness/satraining/satraining.py @@ -2,11 +2,11 @@ import logging from nebula.core.utils.locker import Locker from nebula.core.situationalawareness.awareness.satraining.trainingpolicy.trainingpolicy import factory_training_policy -from nebula.core.situationalawareness.awareness.samodule import SAMComponent +from nebula.core.situationalawareness.awareness.sareasoner import SAMComponent from nebula.addons.functions import 
print_msg_box from typing import TYPE_CHECKING if TYPE_CHECKING: - from nebula.core.situationalawareness.awareness.samodule import SAModule, SAMComponent + from nebula.core.situationalawareness.awareness.sareasoner import SAReasoner, SAMComponent from nebula.core.eventmanager import EventManager RESTRUCTURE_COOLDOWN = 5 @@ -14,7 +14,7 @@ class SATraining(SAMComponent): def __init__( self, - sam: "SAModule", + sam: "SAReasoner", addr, training_policy, weight_strategies, diff --git a/nebula/core/situationalawareness/discovery/federationconnector.py b/nebula/core/situationalawareness/discovery/federationconnector.py index c193147af..138060c68 100644 --- a/nebula/core/situationalawareness/discovery/federationconnector.py +++ b/nebula/core/situationalawareness/discovery/federationconnector.py @@ -6,7 +6,7 @@ from nebula.core.situationalawareness.discovery.candidateselection.candidateselector import factory_CandidateSelector from nebula.core.situationalawareness.discovery.modelhandlers.modelhandler import factory_ModelHandler from nebula.core.situationalawareness.situationalawareness import ISADiscovery, ISAReasoner -from nebula.core.situationalawareness.awareness.samodule import SAModule +from nebula.core.situationalawareness.awareness.sareasoner import SAReasoner from nebula.core.utils.locker import Locker from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import UpdateNeighborEvent, NodeFoundEvent @@ -33,9 +33,9 @@ def __init__( self._aditional_participant = aditional_participant self.topology = topology print_msg_box( - msg=f"Starting NodeManager module...", indent=2, title="NodeManager module" + msg=f"Starting FederationConnector module...", indent=2, title="FederationConnector module" ) - logging.info("🌐 Initializing Node Manager") + logging.info("🌐 Initializing Federation Connector") self._engine = engine self._cm = None self.config = engine.get_config() @@ -52,7 +52,7 @@ def __init__( self.discarded_offers_addr_lock = 
Locker(name="discarded_offers_addr_lock") self.discarded_offers_addr = [] - self._situational_awareness_module = SAModule(self, self.config, self.engine.addr, topology, True) + self._sa_reasoner: ISAReasoner = None self._verbose = verbose @property @@ -72,14 +72,14 @@ def model_handler(self): return self._model_handler @property - def sam(self): + def sar(self): """Situational Awareness Module""" - return self._situational_awareness_module + return self._sa_reasoner def is_additional_participant(self): return self._aditional_participant - async def set_configs(self): + async def init(self, sa_reasoner: ISAReasoner): """ model_handler config: - self total rounds @@ -91,9 +91,9 @@ async def set_configs(self): - self weight distance - self weight hetereogeneity """ - logging.info("Building NodeManager configurations...") + logging.info("Building Federation Connector configurations...") + self._sa_reasoner = sa_reasoner await self.register_message_events_callbacks() - await self.sam.init() await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.update_neighbors) logging.info("Building candidate selector configuration..") self.candidate_selector.set_config([0, 0.5, 0.5]) @@ -105,11 +105,8 @@ async def set_configs(self): ############################## """ - def get_restructure_process_lock(self): - return self.sam.get_restructure_process_lock() - def accept_connection(self, source, joining=False): - return self.sam.accept_connection(source, joining) + return self.sar.accept_connection(source, joining) def still_waiting_for_candidates(self): return not self.accept_candidates_lock.locked() and self.late_connection_process_lock.locked() @@ -117,7 +114,7 @@ def still_waiting_for_candidates(self): async def add_pending_connection_confirmation(self, addr): await self._update_neighbors_lock.acquire_async() await self.pending_confirmation_from_nodes_lock.acquire_async() - if addr not in self.sam.get_nodes_known(neighbors_only=True): + if addr not in 
self.sar.get_nodes_known(neighbors_only=True): logging.info(f" Addition | pending connection confirmation from: {addr}") self.pending_confirmation_from_nodes.add(addr) await self.pending_confirmation_from_nodes_lock.release_async() @@ -152,7 +149,7 @@ def add_to_discarded_offers(self, addr_discarded): self.discarded_offers_addr_lock.release() def get_actions(self): - return self.sam.get_actions() + return self.sar.get_actions() async def register_late_neighbor(self, addr, joinning_federation=False): if self._verbose: logging.info(f"Registering | late neighbor: {addr}, joining: {joinning_federation}") @@ -172,7 +169,7 @@ async def meet_node(self, node): await EventManager.get_instance().publish_node_event(nfe) def get_nodes_known(self, neighbors_too=False): - return self.sam.get_nodes_known(neighbors_too) + return self.sar.get_nodes_known(neighbors_too) def accept_model_offer(self, source, decoded_model, rounds, round, epochs, n_neighbors, loss): if not self.accept_candidates_lock.locked(): @@ -419,7 +416,7 @@ async def _offer_offer_model_callback(self, source, message): logging.info(f"❗️ Error proccesing offer model from {source}") else: logging.info( - f"❗️ handfle_offer_message | NOT accepting offers | restructure: {self.get_restructure_process_lock().locked()} | waiting candidates: {self.still_waiting_for_candidates()}" + f"❗️ handfle_offer_message | NOT accepting offers | waiting candidates: {self.still_waiting_for_candidates()}" ) self.add_to_discarded_offers(source) diff --git a/nebula/core/situationalawareness/situationalawareness.py b/nebula/core/situationalawareness/situationalawareness.py index 72e95de95..0616361d3 100644 --- a/nebula/core/situationalawareness/situationalawareness.py +++ b/nebula/core/situationalawareness/situationalawareness.py @@ -1,6 +1,7 @@ from nebula.core.situationalawareness.discovery.federationconnector import FederationConnector -from nebula.core.situationalawareness.awareness.samodule import SAModule +from 
nebula.core.situationalawareness.awareness.sareasoner import SAReasoner from abc import ABC, abstractmethod +from nebula.addons.functions import print_msg_box class ISADiscovery(ABC): @abstractmethod @@ -21,9 +22,45 @@ def get_actions(self): raise NotImplementedError class SituationalAwareness(): - def __init__(self): - #self._situational_awareness_module = SAModule(self, self.config, self.engine.addr, topology, True) - pass + def __init__(self, config): + print_msg_box( + msg=f"Starting Situational Awareness module...", + indent=2, + title="Situational Awareness module", + ) + self._config = config + topology = self._config.participant["mobility_args"]["topology_type"] + topology = topology.lower() + model_handler = "std" + self._federation_connector = FederationConnector( + self._config.participant["mobility_args"]["additional_node"]["status"], + topology, + model_handler, + engine=self, + verbose=True + ) + self._sareasoner = SAReasoner( + self._config, + self._config.participant["network_args"]["addr"], + topology, + verbose=True) + + @property + def fedcon(self): + """Federation Connector""" + return self._federation_connector + + @property + def sar(self): + """SA Reasoner""" + return self._sareasoner async def init(self): - pass + await self.fedcon.init(self.sar) + await self.sar.init(self.fedcon) + + async def start_late_connection_process(self): + await self.fedcon.start_late_connection_process() + + async def get_trainning_info(self): + return await self.fedcon.get_trainning_info() From 745b7312a2d85f55a139a5c79673468b63bcb8c5 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 30 Apr 2025 12:57:07 +0200 Subject: [PATCH 179/233] feature consistency reputation --- .../sareputation/behaviorreputation.py | 18 +-- .../sareputation/collaborativereputation.py | 3 + .../sareputation/consistencyreputation.py | 135 ++++++++++++++++++ .../awareness/sareputation/sareputation.py | 6 + .../situationalawareness.py | 3 +- 5 files changed, 153 insertions(+), 12 
deletions(-) create mode 100644 nebula/core/situationalawareness/awareness/sareputation/collaborativereputation.py create mode 100644 nebula/core/situationalawareness/awareness/sareputation/consistencyreputation.py diff --git a/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py b/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py index 6047e861f..3e5a03a1a 100644 --- a/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py +++ b/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py @@ -6,13 +6,9 @@ from nebula.core.nebulaevents import UpdateReceivedEvent, AggregationEvent, RoundStartEvent, UpdateNeighborEvent, NodeBlacklistedEvent import time from enum import Enum +from nebula.core.situationalawareness.awareness.sareputation.sareputation import ThreatCategory import asyncio -class ThreatCategoryBehavior(Enum): - FLOODING = "flooding" - INACTIVITY = "inactivity" - BAD_BEHAVIOR = "bad behavior" - class TimeStamp(): def __init__(self, time_received = None, time_since_last_event = None): self.tr = time_received @@ -89,9 +85,9 @@ def hba_lock(self): def __str__(self): return "Behavior Reputation" - async def init(self, config): + async def init(self, neighbor_List): async with self._nodes_lock: - nodes = config["nodes"] + nodes = neighbor_List self._nodes = { node_id: ( deque(maxlen=self.MAX_HISTORIC_SIZE), # Updates per round, @@ -113,7 +109,7 @@ async def init(self, config): await EventManager.get_instance().subscribe_node_event(NodeBlacklistedEvent, self._process_node_blacklisted_event) await EventManager.get_instance().subscribe(None, self._process_messages_received) - async def get_behavior_scores(self, historical=False): + async def get_scores(self, historical=False): if historical: return self.hbs.copy() else: @@ -144,10 +140,10 @@ async def _process_round_start(self, rse: RoundStartEvent): async with self._messages_received_per_round_lock: 
self._messages_received_per_round.clear() - async def _process_aggregation_event(self, are: AggregationEvent): + async def _process_aggregation_event(self, age: AggregationEvent): self._last_aggregation_time = time.time() if self._verbose: logging.info("Processing aggregation event") - (_, expected_nodes, missing_nodes) = await are.get_event_data() + (_, expected_nodes, missing_nodes) = await age.get_event_data() async with self._nodes_lock: for node in expected_nodes: @@ -217,7 +213,7 @@ async def _process_messages_received(self, source, message): self._suspicious_nodes.union({(source, ThreatCategoryBehavior.FLOODING)}) self._messages_received_per_round[source] = n_messages - async def _evaluate(self): + async def evaluate(self): if self._verbose: logging.info("Evaluating Behavior Reputation, generating score...") nodes = await self._get_nodes() diff --git a/nebula/core/situationalawareness/awareness/sareputation/collaborativereputation.py b/nebula/core/situationalawareness/awareness/sareputation/collaborativereputation.py new file mode 100644 index 000000000..2daae0c10 --- /dev/null +++ b/nebula/core/situationalawareness/awareness/sareputation/collaborativereputation.py @@ -0,0 +1,3 @@ + +class CollaborativeReputation(): + pass \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/sareputation/consistencyreputation.py b/nebula/core/situationalawareness/awareness/sareputation/consistencyreputation.py new file mode 100644 index 000000000..b249d031c --- /dev/null +++ b/nebula/core/situationalawareness/awareness/sareputation/consistencyreputation.py @@ -0,0 +1,135 @@ +from nebula.core.utils.locker import Locker +from collections import deque, OrderedDict +import logging +import numpy as np +from scipy.stats import linregress +from collections import defaultdict +from nebula.core.utils.helper import cosine_metric +from nebula.core.eventmanager import EventManager +from nebula.core.nebulaevents import AggregationEvent, UpdateNeighborEvent 
+ +class ConsistencyReputation(): + MAX_HISTORIC_SIZE = 20 + CONSISTENCY_THRESHOLD = 5 + SIMILARITY_THRESHOLD = 0.85 + DEFAULT_SIMILARITY_WEIGHT = 0.7 + DEFAULT_CONSISTENCY_WEIGHT = 0.3 + ADVANCED_SIMILARITY_WEIGHT = 0.6 + ADVANCED_CONSISTENCY_WEIGHT = 0.4 + + def __init__(self, config): + self._addr = config["addr"] + self._verbose = config["verbose"] + self._historical_similarities: dict[str, deque[float]] = defaultdict(deque) + self._historical_similarities_lock = Locker("historical_similarities_lock", async_lock=True) + self._consistency_scores: dict[str, deque[float]] = defaultdict(deque) + self._consistency_scores_lock = Locker("consistency_scores_lock", async_lock=True) + + @property + def hs(self): + return self._historical_similarities + + @property + def cs(self): + return self._consistency_scores + + async def init(self, neighbor_list): + async with self._historical_similarities_lock: + async with self._consistency_scores_lock: + for node in neighbor_list: + self.hs[node] = deque(maxlen=self.MAX_HISTORIC_SIZE) + self.cs[node] = deque(maxlen=self.MAX_HISTORIC_SIZE) + + await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) + await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self._process_update_neighbor_event) + + async def _process_update_neighbor_event(self, une: UpdateNeighborEvent): + node, remove = await une.get_event_data() + async with self._historical_similarities_lock: + async with self._consistency_scores_lock: + if remove: + self.hs.pop(node, None) + self.cs.pop(node, None) + else: + if not node in self.hs: + self.hs.update({node : deque(maxlen=self.MAX_HISTORIC_SIZE)}) + self.cs.update({node : deque(maxlen=self.MAX_HISTORIC_SIZE)}) + + async def _process_aggregation_event(self, age: AggregationEvent): + (updates, expected_nodes, _) = await age.get_event_data() + self_model, _ = updates[self._addr] + async with self._historical_similarities_lock: + for node in 
expected_nodes: + similarity = self._calculate_model_similarity(self_model, updates[node][0]) + self.hs[node].append(similarity) + + async def get_scores(self, historical=False): + if historical: + return self.cs.copy() + else: + last_scores = {node: scores[-1] for node,scores in self.cs.items()} + return last_scores + + def _calculate_model_similarity(self, model1: OrderedDict, model2: OrderedDict): + return cosine_metric(model1=model1, model2=model2, similarity=True) + + async def evaluate(self): + if self._verbose: logging.info("Evaluating Consistency Reputation, generating score...") + reputation_scores = {} + + async with self._historical_similarities_lock: + async with self._consistency_scores_lock: + for node, history in self.hs.items(): + if not history: + reputation_scores[node] = 0.0 + continue + + score = self._compute_score(list(history)) + self.cs[node].append(score) + reputation_scores[node] = score + + if self._verbose: + for node, score in reputation_scores.items(): + logging.info(f"Node {node} consistency score: {score:.4f}") + + return reputation_scores + + def _compute_score(self, similarities: list[float]) -> float: + if not similarities: + return 0.0 + + similarity_weight = self.DEFAULT_SIMILARITY_WEIGHT + consistency_weight = self.DEFAULT_CONSISTENCY_WEIGHT + + latest_similarity = similarities[-1] + + if latest_similarity >= self.SIMILARITY_THRESHOLD: + base_score = 1.0 + else: + # Linearly scaled score between 0 and 1 based on how close it is to the threshold + base_score = latest_similarity / self.SIMILARITY_THRESHOLD + + temporal_consistency = 0.5 # Default medium trust for very short histories + + # Temporal consistency: compute inverse variance (lower variance β†’ higher trust) + if len(similarities) >= self.CONSISTENCY_THRESHOLD: + similarity_weight = self.ADVANCED_SIMILARITY_WEIGHT + consistency_weight = self.ADVANCED_CONSISTENCY_WEIGHT + + var = np.var(similarities) + temporal_consistency = 1.0 - min(var, 1.0) + + # 2. 
Trend analysis: penalize downward trends + x = list(range(len(similarities))) + slope, _, _, _, _ = linregress(x, similarities) + + # Normalize slope to a range [-1, 1] depending on steepness + # and penalize negative slope + if slope < 0: + # For example: max penalty of 0.2 if strong negative slope + trend_penalty = min(abs(slope), 0.2) + + + # Weighted average: 70% recent similarity, 30% temporal consistency + final_score = similarity_weight * base_score + consistency_weight * temporal_consistency + return round(final_score, 4) \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/sareputation/sareputation.py b/nebula/core/situationalawareness/awareness/sareputation/sareputation.py index 9015ba870..9f73b941d 100644 --- a/nebula/core/situationalawareness/awareness/sareputation/sareputation.py +++ b/nebula/core/situationalawareness/awareness/sareputation/sareputation.py @@ -1,6 +1,12 @@ from nebula.core.situationalawareness.awareness.sareasoner import SAMComponent from enum import Enum +class ThreatCategory(Enum): + FLOODING = "flooding" + INACTIVITY = "inactivity" + BAD_BEHAVIOR = "bad behavior" + MODEL_POISSONING = "model poissoning" + class ReputationCategory(Enum): # Reputational thresholds HIGH_TRUSTED = 0.9 TRUSTED = 0.8 diff --git a/nebula/core/situationalawareness/situationalawareness.py b/nebula/core/situationalawareness/situationalawareness.py index 0616361d3..d74f03dda 100644 --- a/nebula/core/situationalawareness/situationalawareness.py +++ b/nebula/core/situationalawareness/situationalawareness.py @@ -43,7 +43,8 @@ def __init__(self, config): self._config, self._config.participant["network_args"]["addr"], topology, - verbose=True) + verbose=True + ) @property def fedcon(self): From 9e567eb3b8700d9d0eec42772008e3aec47084dd Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Fri, 2 May 2025 12:21:28 +0200 Subject: [PATCH 180/233] feature advanced consistency metrics --- .../sareputation/behaviorreputation.py | 7 +- 
.../sareputation/collaborativereputation.py | 22 ++++- .../sareputation/consistencyreputation.py | 91 ++++++++++++++----- 3 files changed, 91 insertions(+), 29 deletions(-) diff --git a/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py b/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py index 3e5a03a1a..c2e5b8082 100644 --- a/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py +++ b/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py @@ -5,7 +5,6 @@ from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import UpdateReceivedEvent, AggregationEvent, RoundStartEvent, UpdateNeighborEvent, NodeBlacklistedEvent import time -from enum import Enum from nebula.core.situationalawareness.awareness.sareputation.sareputation import ThreatCategory import asyncio @@ -210,7 +209,7 @@ async def _process_messages_received(self, source, message): n_messages += 1 if n_messages >= self.MAX_MESSAGES_PER_ROUND: async with self._suspicious_nodes_lock: - self._suspicious_nodes.union({(source, ThreatCategoryBehavior.FLOODING)}) + self._suspicious_nodes.union({(source, ThreatCategory.FLOODING)}) self._messages_received_per_round[source] = n_messages async def evaluate(self): @@ -260,7 +259,7 @@ async def evaluate(self): # Check inactivity beyond max tolerance if missed_count >= self.MAX_INACTIVITY_ALLOWED: async with self._suspicious_nodes_lock: - self._suspicious_nodes.union(((node, ThreatCategoryBehavior.INACTIVITY))) + self._suspicious_nodes.union(((node, ThreatCategory.INACTIVITY))) scores[node] = 0.0 continue @@ -303,7 +302,7 @@ async def evaluate(self): # Update suspicious nodes async with self._suspicious_nodes_lock: - self._suspicious_nodes.union({(n, ThreatCategoryBehavior.BAD_BEHAVIOR) for n in nodes_below_th}) + self._suspicious_nodes.union({(n, ThreatCategory.BAD_BEHAVIOR) for n in nodes_below_th}) diff --git 
a/nebula/core/situationalawareness/awareness/sareputation/collaborativereputation.py b/nebula/core/situationalawareness/awareness/sareputation/collaborativereputation.py index 2daae0c10..0378742f4 100644 --- a/nebula/core/situationalawareness/awareness/sareputation/collaborativereputation.py +++ b/nebula/core/situationalawareness/awareness/sareputation/collaborativereputation.py @@ -1,3 +1,23 @@ +from nebula.core.utils.locker import Locker +from collections import deque, OrderedDict +import logging +from collections import defaultdict +from nebula.core.eventmanager import EventManager +from nebula.core.nebulaevents import UpdateNeighborEvent +from nebula.core.situationalawareness.awareness.sareputation.sareputation import ThreatCategory class CollaborativeReputation(): - pass \ No newline at end of file + MAX_TRIALS_ACCEPTED = 3 + + def __init__(self): + self._trials_recently_received: int = 0 + + async def init(self): + await EventManager.get_instance().subscribe(("reputation", "share_reputation"), self._process_share_reputation_message) + await EventManager.get_instance().subscribe(("reputation", "submit_verdict"), self._process_trial_verdict_message) + + async def _process_share_reputation_message(self, source, message): + pass + + async def _process_trial_verdict_message(self, source, message): + pass \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/sareputation/consistencyreputation.py b/nebula/core/situationalawareness/awareness/sareputation/consistencyreputation.py index b249d031c..03141b4c3 100644 --- a/nebula/core/situationalawareness/awareness/sareputation/consistencyreputation.py +++ b/nebula/core/situationalawareness/awareness/sareputation/consistencyreputation.py @@ -7,15 +7,21 @@ from nebula.core.utils.helper import cosine_metric from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import AggregationEvent, UpdateNeighborEvent +from 
nebula.core.situationalawareness.awareness.sareputation.sareputation import ThreatCategory class ConsistencyReputation(): MAX_HISTORIC_SIZE = 20 - CONSISTENCY_THRESHOLD = 5 + SCORE_THRESHOLD_MALICIOUS = 0.5 # Threshold to detect posible malicious nodes + SCORE_THRESHOLD_SUSPICIOUS = 0.6 + ADVANCED_METRICS_THRESHOLD = 5 SIMILARITY_THRESHOLD = 0.85 - DEFAULT_SIMILARITY_WEIGHT = 0.7 - DEFAULT_CONSISTENCY_WEIGHT = 0.3 - ADVANCED_SIMILARITY_WEIGHT = 0.6 - ADVANCED_CONSISTENCY_WEIGHT = 0.4 + DEFAULT_SIMILARITY_WEIGHT = 0.6 # Default metrics + DEFAULT_CONSISTENCY_WEIGHT = 0.15 + DEFAULT_STABILITY_WEIGHT = 0.25 + ADVANCED_SIMILARITY_WEIGHT = 0.45 # Advanced metrics + ADVANCED_CONSISTENCY_WEIGHT = 0.2 + ADVANCED_STABILITY_WEIGHT = 0.35 + def __init__(self, config): self._addr = config["addr"] @@ -24,6 +30,8 @@ def __init__(self, config): self._historical_similarities_lock = Locker("historical_similarities_lock", async_lock=True) self._consistency_scores: dict[str, deque[float]] = defaultdict(deque) self._consistency_scores_lock = Locker("consistency_scores_lock", async_lock=True) + self._suspicious_nodes = set() + self._suspicious_nodes_lock = Locker(name="suspicious_nodes_lock", async_lock=True) @property def hs(self): @@ -42,6 +50,10 @@ async def init(self, neighbor_list): await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self._process_update_neighbor_event) + + async def get_suspicious_nodes(self): + async with self._suspicious_nodes_lock: + return self._suspicious_nodes.copy() async def _process_update_neighbor_event(self, une: UpdateNeighborEvent): node, remove = await une.get_event_data() @@ -73,6 +85,16 @@ async def get_scores(self, historical=False): def _calculate_model_similarity(self, model1: OrderedDict, model2: OrderedDict): return cosine_metric(model1=model1, model2=model2, similarity=True) + async def analize_malice(self, node, 
score): + category = None + if score <= self.SCORE_THRESHOLD_MALICIOUS: + category = (node, ThreatCategory.MODEL_POISSONING) + elif score <= self.SCORE_THRESHOLD_SUSPICIOUS: + category = (node, ThreatCategory.BAD_BEHAVIOR) + if category: + async with self._suspicious_nodes_lock: + self._suspicious_nodes.add(category) + async def evaluate(self): if self._verbose: logging.info("Evaluating Consistency Reputation, generating score...") reputation_scores = {} @@ -83,8 +105,9 @@ async def evaluate(self): if not history: reputation_scores[node] = 0.0 continue - + if self._verbose: logging.info(f"Node being evaluated: {node}") score = self._compute_score(list(history)) + if self._verbose: logging.info(f"Final score: {score}") self.cs[node].append(score) reputation_scores[node] = score @@ -97,39 +120,59 @@ async def evaluate(self): def _compute_score(self, similarities: list[float]) -> float: if not similarities: return 0.0 - - similarity_weight = self.DEFAULT_SIMILARITY_WEIGHT - consistency_weight = self.DEFAULT_CONSISTENCY_WEIGHT + # Última similitud observada latest_similarity = similarities[-1] + # Base score (cuánto se acerca a la similitud esperada) if latest_similarity >= self.SIMILARITY_THRESHOLD: base_score = 1.0 else: - # Linearly scaled score between 0 and 1 based on how close it is to the threshold base_score = latest_similarity / self.SIMILARITY_THRESHOLD - temporal_consistency = 0.5 # Default medium trust for very short histories + # Pesos por defecto + similarity_weight = self.DEFAULT_SIMILARITY_WEIGHT + consistency_weight = self.DEFAULT_CONSISTENCY_WEIGHT + stability_weight = self.DEFAULT_STABILITY_WEIGHT + + temporal_consistency = 0.5 + local_stability_score = 0.5 + trend_penalty = 0.0 - # Temporal consistency: compute inverse variance (lower variance → higher trust) - if len(similarities) >= self.CONSISTENCY_THRESHOLD: + final_score = base_score + + if len(similarities) >= self.ADVANCED_METRICS_THRESHOLD: + # Ajuste de pesos avanzados similarity_weight 
= self.ADVANCED_SIMILARITY_WEIGHT consistency_weight = self.ADVANCED_CONSISTENCY_WEIGHT - + stability_weight = self.ADVANCED_STABILITY_WEIGHT + + # --- Consistencia temporal (varianza) --- var = np.var(similarities) temporal_consistency = 1.0 - min(var, 1.0) - - # 2. Trend analysis: penalize downward trends + + # --- Tendencia (slope) --- x = list(range(len(similarities))) slope, _, _, _, _ = linregress(x, similarities) - - # Normalize slope to a range [-1, 1] depending on steepness - # and penalize negative slope if slope < 0: - # For example: max penalty of 0.2 if strong negative slope trend_penalty = min(abs(slope), 0.2) - + elif slope > 0: + # Refuerzo positivo si la tendencia es creciente + base_score += min(slope, 0.2) + + # --- Estabilidad local (diferencia entre actualizaciones consecutivas) --- + diffs = [abs(similarities[i] - similarities[i - 1]) for i in range(1, len(similarities))] + avg_fluctuation = np.mean(diffs) + local_stability_score = 1.0 - min(avg_fluctuation, 1.0) - # Weighted average: 70% recent similarity, 30% temporal consistency - final_score = similarity_weight * base_score + consistency_weight * temporal_consistency - return round(final_score, 4) \ No newline at end of file + # --- Cálculo final --- + final_score = ( + similarity_weight * base_score + + consistency_weight * temporal_consistency + + stability_weight * local_stability_score + ) + if self._verbose: logging.info(f"Similarity score: {similarity_weight * base_score} | Consistency score: {consistency_weight * temporal_consistency} | Stability score: {stability_weight * local_stability_score} | Tren penalty: {trend_penalty}") + final_score = max(0.0, final_score - trend_penalty) + + return round(final_score, 4) + From d27205850f4511a953eea9e1ecc8ad8802f4786f Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Mon, 5 May 2025 11:14:17 +0200 Subject: [PATCH 181/233] fix merge errors --- nebula/addons/mobility.py | 2 - nebula/core/datasets/nebuladataset.py | 1 - 
nebula/core/engine.py | 5 +- nebula/core/nebulaevents.py | 5 +- nebula/core/network/actions.py | 6 +- nebula/core/pb/nebula.proto | 17 --- nebula/core/pb/nebula_pb2.py | 100 +++++++++--------- .../awareness/sanetwork/sanetwork.py | 14 ++- .../awareness/sareasoner.py | 4 +- .../sareputation/behaviorreputation.py | 2 +- .../discovery/federationconnector.py | 4 +- .../situationalawareness.py | 58 +++++++--- 12 files changed, 117 insertions(+), 101 deletions(-) diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index ed0707187..b94e6b520 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -273,7 +273,6 @@ async def change_geo_location(self): random.seed(time.time() + self.config.participant["device_args"]["idx"]) latitude = float(self.config.participant["mobility_args"]["latitude"]) longitude = float(self.config.participant["mobility_args"]["longitude"]) - if True: if True: # Get neighbor closer to me async with self._nodes_distances_lock: @@ -289,7 +288,6 @@ async def change_geo_location(self): # If the distance is too big, we move towards the neighbor if self._verbose: logging.info(f"Moving towards nearest neighbor: {addr}") await self.change_geo_location_nearest_neighbor_strategy( - dist, dist, latitude, longitude, diff --git a/nebula/core/datasets/nebuladataset.py b/nebula/core/datasets/nebuladataset.py index f7da66b5b..ac0f01fff 100755 --- a/nebula/core/datasets/nebuladataset.py +++ b/nebula/core/datasets/nebuladataset.py @@ -371,7 +371,6 @@ def data_partitioning(self, plot=False): f"Partitioning data for {self.__class__.__name__} | Partitions: {self.partitions_number} | IID: {self.iid} | Partition: {self.partition} | Partition parameter: {self.partition_parameter}" ) - # self.iid = "a" #TODO REMOVE modificar para q sea string y no boolean el input del front logging.info(f"Scenario with data distribution: {self.iid}") if self.iid == "IID": self.train_indices_map = self.generate_iid_map(self.train_set) diff --git 
a/nebula/core/engine.py b/nebula/core/engine.py index 57c6077cf..fd2bf46f5 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -142,7 +142,7 @@ def __init__( self._addon_manager = AddondManager(self, self.config) # Additional Components - self._situational_awareness = SituationalAwareness(self.config) + self._situational_awareness = SituationalAwareness(self.config, self) @property def cm(self): @@ -543,9 +543,6 @@ def learning_cycle_finished(self): async def _learning_cycle(self): while self.round is not None and self.round < self.total_rounds: current_time = time.time() - rse = RoundStartEvent(self.round, current_time) - await EventManager.get_instance().publish_node_event(rse) - print_msg_box( msg=f"Round {self.round} of {self.total_rounds - 1} started (max. {self.total_rounds} rounds)", indent=2, diff --git a/nebula/core/nebulaevents.py b/nebula/core/nebulaevents.py index ba69c0500..45e41f2f8 100644 --- a/nebula/core/nebulaevents.py +++ b/nebula/core/nebulaevents.py @@ -31,7 +31,7 @@ def __init__(self, message_type, source, message): class RoundStartEvent(NodeEvent): - def __init__(self, round, start_time): + def __init__(self, round, start_time, expected_nodes): """Event triggered when round is going to start. Args: @@ -40,6 +40,7 @@ def __init__(self, round, start_time): """ self._round_start_time = start_time self._round = round + self._expected_nodes = expected_nodes def __str__(self): return "Round starting" @@ -52,7 +53,7 @@ async def get_event_data(self): -round (int): Round number. -start_time (time): Current time when round is going to start. 
""" - return (self._round, self._round_start_time) + return (self._round, self._round_start_time, self._expected_nodes) async def is_concurrent(self): return False diff --git a/nebula/core/network/actions.py b/nebula/core/network/actions.py index c0d56cf6e..0c27eaf65 100644 --- a/nebula/core/network/actions.py +++ b/nebula/core/network/actions.py @@ -47,9 +47,9 @@ class LinkAction(Enum): class ReputationAction(Enum): - SHARE_REPUTATION = nebula_pb2.LinkMessage.Action.SHARE_REPUTATION - START_TRIAL = nebula_pb2.LinkMessage.Action.START_TRIAL - SUBMIT_VERDICT = nebula_pb2.LinkMessage.Action.SUBMIT_VERDICT + SHARE_REPUTATION = nebula_pb2.ReputationMessage.Action.SHARE_REPUTATION + START_TRIAL = nebula_pb2.ReputationMessage.Action.START_TRIAL + SUBMIT_VERDICT = nebula_pb2.ReputationMessage.Action.SUBMIT_VERDICT ACTION_CLASSES = { diff --git a/nebula/core/pb/nebula.proto b/nebula/core/pb/nebula.proto index 6ef005311..8faa70281 100755 --- a/nebula/core/pb/nebula.proto +++ b/nebula/core/pb/nebula.proto @@ -91,14 +91,6 @@ message DiscoverMessage { Action action = 1; } -message DiscoverMessage { - enum Action { - DISCOVER_JOIN = 0; // Message to discover nodes on federation when i'm new - DISCOVER_NODES = 1; // Message to discover nodes on federation when i'm already in - } - Action action = 1; -} - message OfferMessage{ enum Action { OFFER_MODEL = 0; // Message to offer model info to a new node @@ -139,12 +131,3 @@ message ResponseMessage { string response = 1; // Outcome of the requested operation. 
} -message ReputationMessage { - enum Action { - SHARE = 0; - } - string node_id = 1; //Id of the node to which the reputation is sent - float score = 2; //Score reputation - int32 round = 3; //Round to send the reputation - Action action = 4; // Action type (default: SHARE) -} diff --git a/nebula/core/pb/nebula_pb2.py b/nebula/core/pb/nebula_pb2.py index f26fb978e..b8424e1ad 100644 --- a/nebula/core/pb/nebula_pb2.py +++ b/nebula/core/pb/nebula_pb2.py @@ -1,61 +1,61 @@ +# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: nebula.proto """Generated protocol buffer code.""" - +from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder - # @@protoc_insertion_point(imports) - + _sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0cnebula.proto\x12\x06nebula"\xf5\x03\n\x07Wrapper\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x35\n\x11\x64iscovery_message\x18\x02 \x01(\x0b\x32\x18.nebula.DiscoveryMessageH\x00\x12\x31\n\x0f\x63ontrol_message\x18\x03 \x01(\x0b\x32\x16.nebula.ControlMessageH\x00\x12\x37\n\x12\x66\x65\x64\x65ration_message\x18\x04 \x01(\x0b\x32\x19.nebula.FederationMessageH\x00\x12-\n\rmodel_message\x18\x05 \x01(\x0b\x32\x14.nebula.ModelMessageH\x00\x12\x37\n\x12\x63onnection_message\x18\x06 \x01(\x0b\x32\x19.nebula.ConnectionMessageH\x00\x12\x33\n\x10response_message\x18\x07 \x01(\x0b\x32\x17.nebula.ResponseMessageH\x00\x12\x33\n\x10\x64iscover_message\x18\x08 \x01(\x0b\x32\x17.nebula.DiscoverMessageH\x00\x12-\n\roffer_message\x18\t \x01(\x0b\x32\x14.nebula.OfferMessageH\x00\x12+\n\x0clink_message\x18\n \x01(\x0b\x32\x13.nebula.LinkMessageH\x00\x42\t\n\x07message"\x9e\x01\n\x10\x44iscoveryMessage\x12/\n\x06\x61\x63tion\x18\x01 
\x01(\x0e\x32\x1f.nebula.DiscoveryMessage.Action\x12\x10\n\x08latitude\x18\x02 \x01(\x02\x12\x11\n\tlongitude\x18\x03 \x01(\x02"4\n\x06\x41\x63tion\x12\x0c\n\x08\x44ISCOVER\x10\x00\x12\x0c\n\x08REGISTER\x10\x01\x12\x0e\n\nDEREGISTER\x10\x02"\x9a\x01\n\x0e\x43ontrolMessage\x12-\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1d.nebula.ControlMessage.Action\x12\x0b\n\x03log\x18\x02 \x01(\t"L\n\x06\x41\x63tion\x12\t\n\x05\x41LIVE\x10\x00\x12\x0c\n\x08OVERHEAD\x10\x01\x12\x0c\n\x08MOBILITY\x10\x02\x12\x0c\n\x08RECOVERY\x10\x03\x12\r\n\tWEAK_LINK\x10\x04"\xcd\x01\n\x11\x46\x65\x64\x65rationMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.FederationMessage.Action\x12\x11\n\targuments\x18\x02 \x03(\t\x12\r\n\x05round\x18\x03 \x01(\x05"d\n\x06\x41\x63tion\x12\x14\n\x10\x46\x45\x44\x45RATION_START\x10\x00\x12\x0e\n\nREPUTATION\x10\x01\x12\x1e\n\x1a\x46\x45\x44\x45RATION_MODELS_INCLUDED\x10\x02\x12\x14\n\x10\x46\x45\x44\x45RATION_READY\x10\x03"A\n\x0cModelMessage\x12\x12\n\nparameters\x18\x01 \x01(\x0c\x12\x0e\n\x06weight\x18\x02 \x01(\x03\x12\r\n\x05round\x18\x03 \x01(\x05"\x8f\x01\n\x11\x43onnectionMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.ConnectionMessage.Action"H\n\x06\x41\x63tion\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\x0e\n\nDISCONNECT\x10\x01\x12\x10\n\x0cLATE_CONNECT\x10\x02\x12\x0f\n\x0bRESTRUCTURE\x10\x03"r\n\x0f\x44iscoverMessage\x12.\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1e.nebula.DiscoverMessage.Action"/\n\x06\x41\x63tion\x12\x11\n\rDISCOVER_JOIN\x10\x00\x12\x12\n\x0e\x44ISCOVER_NODES\x10\x01"\xce\x01\n\x0cOfferMessage\x12+\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1b.nebula.OfferMessage.Action\x12\x13\n\x0bn_neighbors\x18\x02 \x01(\x02\x12\x0c\n\x04loss\x18\x03 \x01(\x02\x12\x12\n\nparameters\x18\x04 \x01(\x0c\x12\x0e\n\x06rounds\x18\x05 \x01(\x05\x12\r\n\x05round\x18\x06 \x01(\x05\x12\x0e\n\x06\x65pochs\x18\x07 
\x01(\x05"+\n\x06\x41\x63tion\x12\x0f\n\x0bOFFER_MODEL\x10\x00\x12\x10\n\x0cOFFER_METRIC\x10\x01"w\n\x0bLinkMessage\x12*\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1a.nebula.LinkMessage.Action\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x01(\t"-\n\x06\x41\x63tion\x12\x0e\n\nCONNECT_TO\x10\x00\x12\x13\n\x0f\x44ISCONNECT_FROM\x10\x01"\xc2\x01\n\x11ReputationMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.ReputationMessage.Action\x12\x12\n\nreputation\x18\x02 \x01(\t\x12\x11\n\tdefendant\x18\x03 \x01(\t\x12\x0f\n\x07verdict\x18\x04 \x01(\t"C\n\x06\x41\x63tion\x12\x14\n\x10SHARE_REPUTATION\x10\x00\x12\x0f\n\x0bSTART_TRIAL\x10\x01\x12\x12\n\x0eSUBMIT_VERDICT\x10\x02"#\n\x0fResponseMessage\x12\x10\n\x08response\x18\x01 \x01(\tb\x06proto3' -) - + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cnebula.proto\x12\x06nebula\"\xae\x04\n\x07Wrapper\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x35\n\x11\x64iscovery_message\x18\x02 \x01(\x0b\x32\x18.nebula.DiscoveryMessageH\x00\x12\x31\n\x0f\x63ontrol_message\x18\x03 \x01(\x0b\x32\x16.nebula.ControlMessageH\x00\x12\x37\n\x12\x66\x65\x64\x65ration_message\x18\x04 \x01(\x0b\x32\x19.nebula.FederationMessageH\x00\x12-\n\rmodel_message\x18\x05 \x01(\x0b\x32\x14.nebula.ModelMessageH\x00\x12\x37\n\x12\x63onnection_message\x18\x06 \x01(\x0b\x32\x19.nebula.ConnectionMessageH\x00\x12\x33\n\x10response_message\x18\x07 \x01(\x0b\x32\x17.nebula.ResponseMessageH\x00\x12\x37\n\x12reputation_message\x18\x08 \x01(\x0b\x32\x19.nebula.ReputationMessageH\x00\x12\x33\n\x10\x64iscover_message\x18\t \x01(\x0b\x32\x17.nebula.DiscoverMessageH\x00\x12-\n\roffer_message\x18\n \x01(\x0b\x32\x14.nebula.OfferMessageH\x00\x12+\n\x0clink_message\x18\x0b \x01(\x0b\x32\x13.nebula.LinkMessageH\x00\x42\t\n\x07message\"\x9e\x01\n\x10\x44iscoveryMessage\x12/\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1f.nebula.DiscoveryMessage.Action\x12\x10\n\x08latitude\x18\x02 \x01(\x02\x12\x11\n\tlongitude\x18\x03 
\x01(\x02\"4\n\x06\x41\x63tion\x12\x0c\n\x08\x44ISCOVER\x10\x00\x12\x0c\n\x08REGISTER\x10\x01\x12\x0e\n\nDEREGISTER\x10\x02\"\x9a\x01\n\x0e\x43ontrolMessage\x12-\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1d.nebula.ControlMessage.Action\x12\x0b\n\x03log\x18\x02 \x01(\t\"L\n\x06\x41\x63tion\x12\t\n\x05\x41LIVE\x10\x00\x12\x0c\n\x08OVERHEAD\x10\x01\x12\x0c\n\x08MOBILITY\x10\x02\x12\x0c\n\x08RECOVERY\x10\x03\x12\r\n\tWEAK_LINK\x10\x04\"\xcd\x01\n\x11\x46\x65\x64\x65rationMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.FederationMessage.Action\x12\x11\n\targuments\x18\x02 \x03(\t\x12\r\n\x05round\x18\x03 \x01(\x05\"d\n\x06\x41\x63tion\x12\x14\n\x10\x46\x45\x44\x45RATION_START\x10\x00\x12\x0e\n\nREPUTATION\x10\x01\x12\x1e\n\x1a\x46\x45\x44\x45RATION_MODELS_INCLUDED\x10\x02\x12\x14\n\x10\x46\x45\x44\x45RATION_READY\x10\x03\"A\n\x0cModelMessage\x12\x12\n\nparameters\x18\x01 \x01(\x0c\x12\x0e\n\x06weight\x18\x02 \x01(\x03\x12\r\n\x05round\x18\x03 \x01(\x05\"\x8f\x01\n\x11\x43onnectionMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.ConnectionMessage.Action\"H\n\x06\x41\x63tion\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\x0e\n\nDISCONNECT\x10\x01\x12\x10\n\x0cLATE_CONNECT\x10\x02\x12\x0f\n\x0bRESTRUCTURE\x10\x03\"\x95\x01\n\x0f\x44iscoverMessage\x12.\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1e.nebula.DiscoverMessage.Action\"R\n\x06\x41\x63tion\x12\x11\n\rDISCOVER_JOIN\x10\x00\x12\x12\n\x0e\x44ISCOVER_NODES\x10\x01\x12\x10\n\x0cLATE_CONNECT\x10\x02\x12\x0f\n\x0bRESTRUCTURE\x10\x03\"\xce\x01\n\x0cOfferMessage\x12+\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1b.nebula.OfferMessage.Action\x12\x13\n\x0bn_neighbors\x18\x02 \x01(\x02\x12\x0c\n\x04loss\x18\x03 \x01(\x02\x12\x12\n\nparameters\x18\x04 \x01(\x0c\x12\x0e\n\x06rounds\x18\x05 \x01(\x05\x12\r\n\x05round\x18\x06 \x01(\x05\x12\x0e\n\x06\x65pochs\x18\x07 \x01(\x05\"+\n\x06\x41\x63tion\x12\x0f\n\x0bOFFER_MODEL\x10\x00\x12\x10\n\x0cOFFER_METRIC\x10\x01\"w\n\x0bLinkMessage\x12*\n\x06\x61\x63tion\x18\x01 
\x01(\x0e\x32\x1a.nebula.LinkMessage.Action\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x01(\t\"-\n\x06\x41\x63tion\x12\x0e\n\nCONNECT_TO\x10\x00\x12\x13\n\x0f\x44ISCONNECT_FROM\x10\x01\"#\n\x0fResponseMessage\x12\x10\n\x08response\x18\x01 \x01(\t\"\xc2\x01\n\x11ReputationMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.ReputationMessage.Action\x12\x12\n\nreputation\x18\x02 \x01(\t\x12\x11\n\tdefendant\x18\x03 \x01(\t\x12\x0f\n\x07verdict\x18\x04 \x01(\t\"C\n\x06\x41\x63tion\x12\x14\n\x10SHARE_REPUTATION\x10\x00\x12\x0f\n\x0bSTART_TRIAL\x10\x01\x12\x12\n\x0eSUBMIT_VERDICT\x10\x02\x62\x06proto3') + _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "nebula_pb2", globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'nebula_pb2', globals()) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _WRAPPER._serialized_start = 25 - _WRAPPER._serialized_end = 526 - _DISCOVERYMESSAGE._serialized_start = 529 - _DISCOVERYMESSAGE._serialized_end = 687 - _DISCOVERYMESSAGE_ACTION._serialized_start = 635 - _DISCOVERYMESSAGE_ACTION._serialized_end = 687 - _CONTROLMESSAGE._serialized_start = 690 - _CONTROLMESSAGE._serialized_end = 844 - _CONTROLMESSAGE_ACTION._serialized_start = 768 - _CONTROLMESSAGE_ACTION._serialized_end = 844 - _FEDERATIONMESSAGE._serialized_start = 847 - _FEDERATIONMESSAGE._serialized_end = 1052 - _FEDERATIONMESSAGE_ACTION._serialized_start = 952 - _FEDERATIONMESSAGE_ACTION._serialized_end = 1052 - _MODELMESSAGE._serialized_start = 1054 - _MODELMESSAGE._serialized_end = 1119 - _CONNECTIONMESSAGE._serialized_start = 1122 - _CONNECTIONMESSAGE._serialized_end = 1265 - _CONNECTIONMESSAGE_ACTION._serialized_start = 1193 - _CONNECTIONMESSAGE_ACTION._serialized_end = 1265 - _DISCOVERMESSAGE._serialized_start = 1267 - _DISCOVERMESSAGE._serialized_end = 1381 - _DISCOVERMESSAGE_ACTION._serialized_start = 1334 - _DISCOVERMESSAGE_ACTION._serialized_end = 1381 - 
_OFFERMESSAGE._serialized_start = 1384 - _OFFERMESSAGE._serialized_end = 1590 - _OFFERMESSAGE_ACTION._serialized_start = 1547 - _OFFERMESSAGE_ACTION._serialized_end = 1590 - _LINKMESSAGE._serialized_start = 1592 - _LINKMESSAGE._serialized_end = 1711 - _LINKMESSAGE_ACTION._serialized_start = 1666 - _LINKMESSAGE_ACTION._serialized_end = 1711 - _REPUTATIONMESSAGE._serialized_start = 1714 - _REPUTATIONMESSAGE._serialized_end = 1908 - _REPUTATIONMESSAGE_ACTION._serialized_start = 1841 - _REPUTATIONMESSAGE_ACTION._serialized_end = 1908 - _RESPONSEMESSAGE._serialized_start = 1910 - _RESPONSEMESSAGE._serialized_end = 1945 + + DESCRIPTOR._options = None + _WRAPPER._serialized_start=25 + _WRAPPER._serialized_end=583 + _DISCOVERYMESSAGE._serialized_start=586 + _DISCOVERYMESSAGE._serialized_end=744 + _DISCOVERYMESSAGE_ACTION._serialized_start=692 + _DISCOVERYMESSAGE_ACTION._serialized_end=744 + _CONTROLMESSAGE._serialized_start=747 + _CONTROLMESSAGE._serialized_end=901 + _CONTROLMESSAGE_ACTION._serialized_start=825 + _CONTROLMESSAGE_ACTION._serialized_end=901 + _FEDERATIONMESSAGE._serialized_start=904 + _FEDERATIONMESSAGE._serialized_end=1109 + _FEDERATIONMESSAGE_ACTION._serialized_start=1009 + _FEDERATIONMESSAGE_ACTION._serialized_end=1109 + _MODELMESSAGE._serialized_start=1111 + _MODELMESSAGE._serialized_end=1176 + _CONNECTIONMESSAGE._serialized_start=1179 + _CONNECTIONMESSAGE._serialized_end=1322 + _CONNECTIONMESSAGE_ACTION._serialized_start=1250 + _CONNECTIONMESSAGE_ACTION._serialized_end=1322 + _DISCOVERMESSAGE._serialized_start=1325 + _DISCOVERMESSAGE._serialized_end=1474 + _DISCOVERMESSAGE_ACTION._serialized_start=1392 + _DISCOVERMESSAGE_ACTION._serialized_end=1474 + _OFFERMESSAGE._serialized_start=1477 + _OFFERMESSAGE._serialized_end=1683 + _OFFERMESSAGE_ACTION._serialized_start=1640 + _OFFERMESSAGE_ACTION._serialized_end=1683 + _LINKMESSAGE._serialized_start=1685 + _LINKMESSAGE._serialized_end=1804 + _LINKMESSAGE_ACTION._serialized_start=1759 + 
_LINKMESSAGE_ACTION._serialized_end=1804 + _RESPONSEMESSAGE._serialized_start=1806 + _RESPONSEMESSAGE._serialized_end=1841 + _REPUTATIONMESSAGE._serialized_start=1844 + _REPUTATIONMESSAGE._serialized_end=2038 + _REPUTATIONMESSAGE_ACTION._serialized_start=1971 + _REPUTATIONMESSAGE_ACTION._serialized_end=2038 # @@protoc_insertion_point(module_scope) diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index e541ca6e7..6f3aea784 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -86,8 +86,9 @@ async def init(self): self._strict_topology, ]) - await EventManager.get_instance().subscribe_node_event(NodeFoundEvent, self.process_node_found_event) - await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.process_update_neighbor_event) + await EventManager.get_instance().subscribe_node_event(NodeFoundEvent, self._process_node_found_event) + await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self._process_update_neighbor_event) + await EventManager.get_instance().subscribe_node_event(RoundEndEvent, self._process_round_end_event) await self.sana.register_sa_agent() async def sa_component_actions(self): @@ -101,15 +102,18 @@ async def sa_component_actions(self): ############################### """ - async def process_node_found_event(self, nfe : NodeFoundEvent): + async def _process_node_found_event(self, nfe : NodeFoundEvent): node_addr = await nfe.get_event_data() if self._verbose: logging.info(f"Processing Node Found Event, node addr: {node_addr}") self.np.meet_node(node_addr) - async def process_update_neighbor_event(self, une : UpdateNeighborEvent): + async def _process_update_neighbor_event(self, une : UpdateNeighborEvent): node_addr, removed = await une.get_event_data() if self._verbose: logging.info(f"Processing Update 
Neighbor Event, node addr: {node_addr}, remove: {removed}") - self.np.update_neighbors(node_addr, removed) + self.np.update_neighbors(node_addr, removed) + + async def _process_round_end_event(self, ree: RoundEndEvent): + await self._analize_topology_robustness() def meet_node(self, node): if node != self._addr: diff --git a/nebula/core/situationalawareness/awareness/sareasoner.py b/nebula/core/situationalawareness/awareness/sareasoner.py index dde27445f..37f32630a 100644 --- a/nebula/core/situationalawareness/awareness/sareasoner.py +++ b/nebula/core/situationalawareness/awareness/sareasoner.py @@ -79,13 +79,13 @@ def sad(self): """SA Discovery""" return self._sa_discovery - async def init(self, sa_discovery: ISADiscovery): + async def init(self, sa_discovery): #await self.loading_sa_components() from nebula.core.situationalawareness.awareness.sanetwork.sanetwork import SANetwork #from nebula.core.situationalawareness.awareness.satraining.satraining import SATraining self._situational_awareness_network = SANetwork(self, self._addr, self._topology, verbose=True) #self._situational_awareness_training = SATraining(self, self._addr, "qds", "fastreboot", verbose=True) - self._sa_discovery = sa_discovery + self._sa_discovery: ISADiscovery = sa_discovery await self.san.init() await EventManager.get_instance().subscribe_node_event(RoundEndEvent, self._process_round_end_event) await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) diff --git a/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py b/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py index c2e5b8082..d78f39cae 100644 --- a/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py +++ b/nebula/core/situationalawareness/awareness/sareputation/behaviorreputation.py @@ -132,7 +132,7 @@ async def _process_round_start(self, rse: RoundStartEvent): if self._verbose: logging.info("Processing round 
start event") if not self._last_aggregation_time: if self._verbose: logging.info("First round start timing assigment") - (_, start_time) = await rse.get_event_data() + (_, start_time,_) = await rse.get_event_data() self._last_aggregation_time = start_time self._internal_rounds_done += 1 diff --git a/nebula/core/situationalawareness/discovery/federationconnector.py b/nebula/core/situationalawareness/discovery/federationconnector.py index 138060c68..1676ee180 100644 --- a/nebula/core/situationalawareness/discovery/federationconnector.py +++ b/nebula/core/situationalawareness/discovery/federationconnector.py @@ -79,7 +79,7 @@ def sar(self): def is_additional_participant(self): return self._aditional_participant - async def init(self, sa_reasoner: ISAReasoner): + async def init(self, sa_reasoner): """ model_handler config: - self total rounds @@ -92,7 +92,7 @@ async def init(self, sa_reasoner: ISAReasoner): - self weight hetereogeneity """ logging.info("Building Federation Connector configurations...") - self._sa_reasoner = sa_reasoner + self._sa_reasoner: ISAReasoner = sa_reasoner await self.register_message_events_callbacks() await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.update_neighbors) logging.info("Building candidate selector configuration..") diff --git a/nebula/core/situationalawareness/situationalawareness.py b/nebula/core/situationalawareness/situationalawareness.py index d74f03dda..02b20260f 100644 --- a/nebula/core/situationalawareness/situationalawareness.py +++ b/nebula/core/situationalawareness/situationalawareness.py @@ -1,14 +1,24 @@ -from nebula.core.situationalawareness.discovery.federationconnector import FederationConnector -from nebula.core.situationalawareness.awareness.sareasoner import SAReasoner from abc import ABC, abstractmethod from nebula.addons.functions import print_msg_box class ISADiscovery(ABC): + @abstractmethod + async def init(self, sa_reasoner): + raise NotImplementedError + @abstractmethod async 
def start_late_connection_process(self, connected=False, msg_type="discover_join", addrs_known=None): raise NotImplementedError + + @abstractmethod + async def get_trainning_info(self): + raise NotImplementedError class ISAReasoner(ABC): + @abstractmethod + async def init(self, sa_discovery): + raise NotImplementedError + @abstractmethod def accept_connection(self, source, joining=False): raise NotImplementedError @@ -20,9 +30,31 @@ def get_nodes_known(self, neighbors_too=False, neighbors_only=False): @abstractmethod def get_actions(self): raise NotImplementedError + +def factory_sa_discovery(sa_discovery, additional, topology, model_handler, engine, verbose) -> ISADiscovery: + from nebula.core.situationalawareness.discovery.federationconnector import FederationConnector + DISCOVERY = { + "fedcon": FederationConnector, + } + sad = DISCOVERY.get(sa_discovery) + if sad: + return sad(additional, topology, model_handler, engine, verbose) + else: + raise Exception(f"SA Discovery service {sa_discovery} not found.") + +def factory_sa_reasoner(sa_reasoner, config, addr, topology, verbose) -> ISAReasoner: + from nebula.core.situationalawareness.awareness.sareasoner import SAReasoner + REASONER = { + "nebula_reasoner": SAReasoner, + } + sar = REASONER.get(sa_reasoner) + if sar: + return sar(config, addr, topology, verbose) + else: + raise Exception(f"SA Reasoner service {sa_reasoner} not found.") class SituationalAwareness(): - def __init__(self, config): + def __init__(self, config, engine): print_msg_box( msg=f"Starting Situational Awareness module...", indent=2, @@ -32,14 +64,16 @@ def __init__(self, config): topology = self._config.participant["mobility_args"]["topology_type"] topology = topology.lower() model_handler = "std" - self._federation_connector = FederationConnector( + self._sad = factory_sa_discovery( + "fedcon", self._config.participant["mobility_args"]["additional_node"]["status"], topology, model_handler, - engine=self, + engine=engine, verbose=True ) - 
self._sareasoner = SAReasoner( + self._sareasoner = factory_sa_reasoner( + "nebula_reasoner", self._config, self._config.participant["network_args"]["addr"], topology, @@ -47,9 +81,9 @@ def __init__(self, config): ) @property - def fedcon(self): + def sad(self): """Federation Connector""" - return self._federation_connector + return self._sad @property def sar(self): @@ -57,11 +91,11 @@ def sar(self): return self._sareasoner async def init(self): - await self.fedcon.init(self.sar) - await self.sar.init(self.fedcon) + await self.sad.init(self.sar) + await self.sar.init(self.sad) async def start_late_connection_process(self): - await self.fedcon.start_late_connection_process() + await self.sad.start_late_connection_process() async def get_trainning_info(self): - return await self.fedcon.get_trainning_info() + return await self.sad.get_trainning_info() From fdbc0d817ce3c24d9d44b6a622021e3d4066a15c Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 6 May 2025 10:27:14 +0200 Subject: [PATCH 182/233] refactor soem stuff --- nebula/core/datasets/nebuladataset.py | 13 ++++++------- nebula/core/engine.py | 12 +++--------- .../awareness/sanetwork/sanetwork.py | 4 ---- .../situationalawareness/awareness/sareasoner.py | 5 +++-- .../awareness/suggestionbuffer.py | 2 +- .../discovery/federationconnector.py | 10 +++++++--- 6 files changed, 20 insertions(+), 26 deletions(-) diff --git a/nebula/core/datasets/nebuladataset.py b/nebula/core/datasets/nebuladataset.py index ac0f01fff..e1348bee7 100755 --- a/nebula/core/datasets/nebuladataset.py +++ b/nebula/core/datasets/nebuladataset.py @@ -302,7 +302,7 @@ def __init__( partitions_number=1, batch_size=32, num_workers=4, - iid="IID", + iid=False, partition="dirichlet", partition_parameter=0.5, nsplits_percentages=[1.0], @@ -371,17 +371,16 @@ def data_partitioning(self, plot=False): f"Partitioning data for {self.__class__.__name__} | Partitions: {self.partitions_number} | IID: {self.iid} | Partition: {self.partition} | Partition 
parameter: {self.partition_parameter}" ) - logging.info(f"Scenario with data distribution: {self.iid}") - if self.iid == "IID": + logging.info(f"Scenario with data distribution IID: {self.iid}") + if self.iid: self.train_indices_map = self.generate_iid_map(self.train_set) - elif self.iid == "Non-IID": + else : self.train_indices_map = self.generate_non_iid_map( self.train_set, partition=self.partition, partition_parameter=self.partition_parameter ) - else: - self.train_indices_map = self.generate_hybrid_map() + # else: + # self.train_indices_map = self.generate_hybrid_map() - self.iid = False # TODO REMOVE self.test_indices_map = self.get_test_indices_map() self.local_test_indices_map = self.get_local_test_indices_map() diff --git a/nebula/core/engine.py b/nebula/core/engine.py index fd2bf46f5..6a70eaf29 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -283,11 +283,6 @@ async def _connection_connect_callback(self, source, message): async def _connection_disconnect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received disconnection message from {source}") - if self.mobility: - if await self.nm.waiting_confirmation_from(source): - await self.nm.confirmation_received(source, confirmation=False) - # if source in await self.cm.get_all_addrs_current_connections(only_direct=True): - await self.nm.update_neighbors(source, remove=True) await self.cm.disconnect(source, mutual_disconnection=False) async def _federation_federation_ready_callback(self, source, message): @@ -372,10 +367,9 @@ async def _aditional_node_start(self): asyncio.create_task(self._start_learning_late()) async def update_neighbors(self, removed_neighbor_addr, neighbors, remove=False): - if self.mobility: - self.federation_nodes = neighbors - updt_nei_event = UpdateNeighborEvent(removed_neighbor_addr, remove) - asyncio.create_task(EventManager.get_instance().publish_node_event(updt_nei_event)) + self.federation_nodes = neighbors + 
updt_nei_event = UpdateNeighborEvent(removed_neighbor_addr, remove) + asyncio.create_task(EventManager.get_instance().publish_node_event(updt_nei_event)) async def broadcast_models_include(self, age: AggregationEvent): logging.info(f"πŸ”„ Broadcasting MODELS_INCLUDED for round {self.get_round()}") diff --git a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py index 6f3aea784..1993df3e4 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/sanetwork.py @@ -88,7 +88,6 @@ async def init(self): await EventManager.get_instance().subscribe_node_event(NodeFoundEvent, self._process_node_found_event) await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self._process_update_neighbor_event) - await EventManager.get_instance().subscribe_node_event(RoundEndEvent, self._process_round_end_event) await self.sana.register_sa_agent() async def sa_component_actions(self): @@ -112,9 +111,6 @@ async def _process_update_neighbor_event(self, une : UpdateNeighborEvent): if self._verbose: logging.info(f"Processing Update Neighbor Event, node addr: {node_addr}, remove: {removed}") self.np.update_neighbors(node_addr, removed) - async def _process_round_end_event(self, ree: RoundEndEvent): - await self._analize_topology_robustness() - def meet_node(self, node): if node != self._addr: self.np.meet_node(node) diff --git a/nebula/core/situationalawareness/awareness/sareasoner.py b/nebula/core/situationalawareness/awareness/sareasoner.py index 37f32630a..2df531da5 100644 --- a/nebula/core/situationalawareness/awareness/sareasoner.py +++ b/nebula/core/situationalawareness/awareness/sareasoner.py @@ -83,13 +83,14 @@ async def init(self, sa_discovery): #await self.loading_sa_components() from nebula.core.situationalawareness.awareness.sanetwork.sanetwork import SANetwork #from 
nebula.core.situationalawareness.awareness.satraining.satraining import SATraining - self._situational_awareness_network = SANetwork(self, self._addr, self._topology, verbose=True) #self._situational_awareness_training = SATraining(self, self._addr, "qds", "fastreboot", verbose=True) self._sa_discovery: ISADiscovery = sa_discovery - await self.san.init() await EventManager.get_instance().subscribe_node_event(RoundEndEvent, self._process_round_end_event) await EventManager.get_instance().subscribe_node_event(AggregationEvent, self._process_aggregation_event) + self._situational_awareness_network = SANetwork(self, self._addr, self._topology, verbose=True) + await self.san.init() + def is_additional_participant(self): return self._config.participant["mobility_args"]["additional_node"]["status"] diff --git a/nebula/core/situationalawareness/awareness/suggestionbuffer.py b/nebula/core/situationalawareness/awareness/suggestionbuffer.py index 989a2ba80..369dff1f3 100644 --- a/nebula/core/situationalawareness/awareness/suggestionbuffer.py +++ b/nebula/core/situationalawareness/awareness/suggestionbuffer.py @@ -60,7 +60,7 @@ async def register_suggestion(self, event_type, agent: SAModuleAgent, suggestion async def set_event_waited(self, event_type): """Registers event to be waited""" if not self._event_waited: - if self._verbose: logging.info(f"Set notification when all suggestions are being received for event: {event_type. __name__}") + if self._verbose: logging.info(f"Set notification when all suggestions have being received for event: {event_type. 
__name__}") self._event_waited = event_type await self._notify_arbitrator(event_type) diff --git a/nebula/core/situationalawareness/discovery/federationconnector.py b/nebula/core/situationalawareness/discovery/federationconnector.py index 1676ee180..5616af419 100644 --- a/nebula/core/situationalawareness/discovery/federationconnector.py +++ b/nebula/core/situationalawareness/discovery/federationconnector.py @@ -6,7 +6,6 @@ from nebula.core.situationalawareness.discovery.candidateselection.candidateselector import factory_CandidateSelector from nebula.core.situationalawareness.discovery.modelhandlers.modelhandler import factory_ModelHandler from nebula.core.situationalawareness.situationalawareness import ISADiscovery, ISAReasoner -from nebula.core.situationalawareness.awareness.sareasoner import SAReasoner from nebula.core.utils.locker import Locker from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import UpdateNeighborEvent, NodeFoundEvent @@ -95,6 +94,7 @@ async def init(self, sa_reasoner): self._sa_reasoner: ISAReasoner = sa_reasoner await self.register_message_events_callbacks() await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.update_neighbors) + logging.info("Building candidate selector configuration..") self.candidate_selector.set_config([0, 0.5, 0.5]) # self.engine.trainer.get_loss(), self.config.participant["molibity_args"]["weight_distance"], self.config.participant["molibity_args"]["weight_het"] @@ -104,7 +104,7 @@ async def init(self, sa_reasoner): # CONNECTIONS # ############################## """ - + def accept_connection(self, source, joining=False): return self.sar.accept_connection(source, joining) @@ -284,6 +284,10 @@ async def register_message_events_callbacks(self): if callable(method): await EventManager.get_instance().subscribe((event_type, action), method) + + async def _connection_disconnect_callback(self, source, message): + if await self.waiting_confirmation_from(source): + await 
self.confirmation_received(source, confirmation=False) async def _connection_late_connect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received late connect message from {source}") @@ -438,7 +442,7 @@ async def _link_disconnect_from_callback(self, source, message): logging.info(f"πŸ”— handle_link_message | Trigger | Received disconnect_from message from {source}") addrs = message.addrs for addr in addrs.split(): - await self.cm.disconnect(source, mutual_disconnection=False) + await asyncio.create_task(self.cm.disconnect(source, mutual_disconnection=False)) From 952f5c60e56d5d34a9e8ec7a12d55cf8251945f8 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 6 May 2025 11:36:13 +0200 Subject: [PATCH 183/233] refactor private methods --- nebula/core/engine.py | 25 +++-- .../discovery/federationconnector.py | 106 +++++++++--------- 2 files changed, 67 insertions(+), 64 deletions(-) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 6a70eaf29..52535f1e4 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -99,6 +99,7 @@ def __init__( self.round = None self.total_rounds = None self.federation_nodes = set() + self._federation_nodes_lock = Locker("federation_nodes_lock", async_lock=True) self.initialized = False self.log_dir = os.path.join(config.participant["tracking_args"]["log_dir"], self.experiment_name) @@ -156,9 +157,6 @@ def reporter(self): def aggregator(self): return self._aggregator - def get_aggregator_type(self): - return type(self.aggregator) - @property def trainer(self): return self._trainer @@ -166,6 +164,9 @@ def trainer(self): @property def sa(self): return self._situational_awareness + + def get_aggregator_type(self): + return type(self.aggregator) def get_addr(self): return self.addr @@ -173,8 +174,13 @@ def get_addr(self): def get_config(self): return self.config - def get_federation_nodes(self): - return self.federation_nodes + async def get_federation_nodes(self): + async with 
self._federation_nodes_lock: + return self.federation_nodes.copy() + + async def update_federation_nodes(self, federation_nodes): + async with self._federation_nodes_lock: + self.federation_nodes = federation_nodes def get_initialization_status(self): return self.initialized @@ -228,7 +234,7 @@ async def model_initialization_callback(self, source, message): async def model_update_callback(self, source, message): logging.info(f"πŸ€– handle_model_message | Received model update from {source} with round {message.round}") - if not self.get_federation_ready_lock().locked() and len(self.get_federation_nodes()) == 0: + if not self.get_federation_ready_lock().locked() and len(await self.get_federation_nodes()) == 0: logging.info("πŸ€– handle_model_message | There are no defined federation nodes") return decoded_model = self.trainer.deserialize_model(message.parameters) @@ -362,12 +368,11 @@ async def _aditional_node_start(self): logging.info(f"Aditional node | {self.addr} | going to stablish connection with federation") await self.sa.start_late_connection_process() # continue .. 
- # asyncio.create_task(self.sa.stop_not_selected_connections()) logging.info("Creating trainer service to start the federation process..") asyncio.create_task(self._start_learning_late()) async def update_neighbors(self, removed_neighbor_addr, neighbors, remove=False): - self.federation_nodes = neighbors + await self.update_federation_nodes(neighbors) updt_nei_event = UpdateNeighborEvent(removed_neighbor_addr, remove) asyncio.create_task(EventManager.get_instance().publish_node_event(updt_nei_event)) @@ -543,8 +548,8 @@ async def _learning_cycle(self): title="Round information", ) logging.info(f"Federation nodes: {self.federation_nodes}") - self.federation_nodes = await self.cm.get_addrs_current_connections(only_direct=True, myself=True) - expected_nodes = self.federation_nodes.copy() + await self.update_federation_nodes(await self.cm.get_addrs_current_connections(only_direct=True, myself=True)) + expected_nodes = await self.get_federation_nodes() rse = RoundStartEvent(self.round, current_time, expected_nodes) await EventManager.get_instance().publish_node_event(rse) self.trainer.on_round_start() diff --git a/nebula/core/situationalawareness/discovery/federationconnector.py b/nebula/core/situationalawareness/discovery/federationconnector.py index 5616af419..5892d8e1b 100644 --- a/nebula/core/situationalawareness/discovery/federationconnector.py +++ b/nebula/core/situationalawareness/discovery/federationconnector.py @@ -16,11 +16,10 @@ from nebula.core.engine import Engine RESTRUCTURE_COOLDOWN = 5 - +OFFER_TIMEOUT = 5 class FederationConnector(ISADiscovery): - OFFER_TIMEOUT = 5 - + def __init__( self, aditional_participant, @@ -47,7 +46,7 @@ def __init__( self.pending_confirmation_from_nodes = set() self.pending_confirmation_from_nodes_lock = Locker(name="pending_confirmation_from_nodes_lock", async_lock=True) self.accept_candidates_lock = Locker(name="accept_candidates_lock") - self.recieve_offer_timer = self.OFFER_TIMEOUT + self.recieve_offer_timer = 
OFFER_TIMEOUT self.discarded_offers_addr_lock = Locker(name="discarded_offers_addr_lock") self.discarded_offers_addr = [] @@ -75,9 +74,6 @@ def sar(self): """Situational Awareness Module""" return self._sa_reasoner - def is_additional_participant(self): - return self._aditional_participant - async def init(self, sa_reasoner): """ model_handler config: @@ -92,8 +88,9 @@ async def init(self, sa_reasoner): """ logging.info("Building Federation Connector configurations...") self._sa_reasoner: ISAReasoner = sa_reasoner - await self.register_message_events_callbacks() - await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.update_neighbors) + await self._register_message_events_callbacks() + await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self._update_neighbors) + await EventManager.get_instance().subscribe(("model", "update"), self._model_update_callback) logging.info("Building candidate selector configuration..") self.candidate_selector.set_config([0, 0.5, 0.5]) @@ -105,13 +102,13 @@ async def init(self, sa_reasoner): ############################## """ - def accept_connection(self, source, joining=False): + def _accept_connection(self, source, joining=False): return self.sar.accept_connection(source, joining) - def still_waiting_for_candidates(self): + def _still_waiting_for_candidates(self): return not self.accept_candidates_lock.locked() and self.late_connection_process_lock.locked() - async def add_pending_connection_confirmation(self, addr): + async def _add_pending_connection_confirmation(self, addr): await self._update_neighbors_lock.acquire_async() await self.pending_confirmation_from_nodes_lock.acquire_async() if addr not in self.sar.get_nodes_known(neighbors_only=True): @@ -125,52 +122,49 @@ async def _remove_pending_confirmation_from(self, addr): self.pending_confirmation_from_nodes.discard(addr) await self.pending_confirmation_from_nodes_lock.release_async() - async def clear_pending_confirmations(self): 
+ async def _clear_pending_confirmations(self): await self.pending_confirmation_from_nodes_lock.acquire_async() self.pending_confirmation_from_nodes.clear() await self.pending_confirmation_from_nodes_lock.release_async() - async def waiting_confirmation_from(self, addr): + async def _waiting_confirmation_from(self, addr): await self.pending_confirmation_from_nodes_lock.acquire_async() found = addr in self.pending_confirmation_from_nodes await self.pending_confirmation_from_nodes_lock.release_async() return found - async def confirmation_received(self, addr, joining=False): + async def _confirmation_received(self, addr, joining=False): logging.info(f" Update | connection confirmation received from: {addr} | joining federation: {joining}") await self.cm.connect(addr, direct=True) await self._remove_pending_confirmation_from(addr) une = UpdateNeighborEvent(addr, joining=joining) await EventManager.get_instance().publish_node_event(une) - def add_to_discarded_offers(self, addr_discarded): + def _add_to_discarded_offers(self, addr_discarded): self.discarded_offers_addr_lock.acquire() self.discarded_offers_addr.append(addr_discarded) self.discarded_offers_addr_lock.release() - def get_actions(self): + def _get_actions(self): return self.sar.get_actions() - async def register_late_neighbor(self, addr, joinning_federation=False): + async def _register_late_neighbor(self, addr, joinning_federation=False): if self._verbose: logging.info(f"Registering | late neighbor: {addr}, joining: {joinning_federation}") une = UpdateNeighborEvent(addr, joining=joinning_federation) await EventManager.get_instance().publish_node_event(une) - async def update_neighbors(self, une : UpdateNeighborEvent): + async def _update_neighbors(self, une : UpdateNeighborEvent): node, remove = await une.get_event_data() await self._update_neighbors_lock.acquire_async() if not remove: - await self.meet_node(node) + await self._meet_node(node) await self._remove_pending_confirmation_from(node) await 
self._update_neighbors_lock.release_async() - async def meet_node(self, node): + async def _meet_node(self, node): nfe = NodeFoundEvent(node) await EventManager.get_instance().publish_node_event(nfe) - def get_nodes_known(self, neighbors_too=False): - return self.sar.get_nodes_known(neighbors_too) - def accept_model_offer(self, source, decoded_model, rounds, round, epochs, n_neighbors, loss): if not self.accept_candidates_lock.locked(): if self._verbose: logging.info(f"πŸ”„ Processing offer from {source}...") @@ -185,11 +179,11 @@ def accept_model_offer(self, source, decoded_model, rounds, round, epochs, n_nei async def get_trainning_info(self): return await self.model_handler.get_model(None) - def add_candidate(self, source, n_neighbors, loss): + def _add_candidate(self, source, n_neighbors, loss): if not self.accept_candidates_lock.locked(): self.candidate_selector.add_candidate((source, n_neighbors, loss)) - async def stop_not_selected_connections(self): + async def _stop_not_selected_connections(self): try: with self.discarded_offers_addr_lock: if len(self.discarded_offers_addr) > 0: @@ -220,7 +214,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove self.late_connection_process_lock.acquire() best_candidates = [] self.candidate_selector.remove_candidates() - await self.clear_pending_confirmations() + await self._clear_pending_confirmations() # find federation and send discover connections_stablished = await self.cm.stablish_connection_to_federation(msg_type, addrs_known) @@ -246,7 +240,7 @@ async def start_late_connection_process(self, connected=False, msg_type="discove if self._verbose: logging.info(f"Candidates | {[addr for addr, _, _ in best_candidates]}") try: for addr, _, _ in best_candidates: - await self.add_pending_connection_confirmation(addr) + await self._add_pending_connection_confirmation(addr) await self.cm.send_message(addr, msg) await asyncio.sleep(1) except asyncio.CancelledError: @@ -271,7 +265,7 @@ async def 
start_late_connection_process(self, connected=False, msg_type="discove ############################## """ - async def register_message_events_callbacks(self): + async def _register_message_events_callbacks(self): me_dict = self.cm.get_messages_events() message_events = [ (message_name, message_action) @@ -286,30 +280,34 @@ async def register_message_events_callbacks(self): await EventManager.get_instance().subscribe((event_type, action), method) async def _connection_disconnect_callback(self, source, message): - if await self.waiting_confirmation_from(source): - await self.confirmation_received(source, confirmation=False) + if await self._waiting_confirmation_from(source): + await self._confirmation_received(source, confirmation=False) + + async def _model_update_callback(self, source, message): + if await self._waiting_confirmation_from(source): + await self._confirmation_received(source, confirmation=False) async def _connection_late_connect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received late connect message from {source}") # Verify if it's a confirmation message from a previous late connection message sent to source - if await self.waiting_confirmation_from(source): - await self.confirmation_received(source, joining=True) + if await self._waiting_confirmation_from(source): + await self._confirmation_received(source, joining=True) return if not self.engine.get_initialization_status(): logging.info("❗️ Connection refused | Device not initialized yet...") return - if self.accept_connection(source, joining=True): + if self._accept_connection(source, joining=True): logging.info(f"πŸ”— handle_connection_message | Late connection accepted | source: {source}") await self.cm.connect(source, direct=True) # Verify conenction is accepted conf_msg = self.cm.create_message("connection", "late_connect") await self.cm.send_message(source, conf_msg) - await self.register_late_neighbor(source, joinning_federation=True) + await 
self._register_late_neighbor(source, joinning_federation=True) - ct_actions, df_actions = self.get_actions() + ct_actions, df_actions = self._get_actions() if len(ct_actions): cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) await self.cm.send_message(source, cnt_msg) @@ -324,15 +322,15 @@ async def _connection_late_connect_callback(self, source, message): async def _connection_restructure_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received restructure message from {source}") # Verify if it's a confirmation message from a previous restructure connection message sent to source - if await self.waiting_confirmation_from(source): - await self.confirmation_received(source) + if await self._waiting_confirmation_from(source): + await self._confirmation_received(source) return if not self.engine.get_initialization_status(): logging.info("❗️ Connection refused | Device not initialized yet...") return - if self.accept_connection(source, joining=False): + if self._accept_connection(source, joining=False): logging.info(f"πŸ”— handle_connection_message | Trigger | restructure connection accepted from {source}") await self.cm.connect(source, direct=True) @@ -340,7 +338,7 @@ async def _connection_restructure_callback(self, source, message): await self.cm.send_message(source, conf_msg) - ct_actions, df_actions = self.get_actions() + ct_actions, df_actions = self._get_actions() if len(ct_actions): cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) await self.cm.send_message(source, cnt_msg) @@ -349,13 +347,13 @@ async def _connection_restructure_callback(self, source, message): df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) await self.cm.send_message(source, df_msg) - await self.register_late_neighbor(source, joinning_federation=False) + await self._register_late_neighbor(source, joinning_federation=False) else: logging.info(f"❗️ handle_connection_message | 
Trigger | restructure connection denied from {source}") async def _discover_discover_join_callback(self, source, message): logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_join message from {source} ") - if len(self.engine.get_federation_nodes()) > 0: + if len(await self.engine.get_federation_nodes()) > 0: await self.engine.trainning_in_progress_lock.acquire_async() model, rounds, round = ( await self.cm.propagator.get_model_information(source, "stable") @@ -368,7 +366,7 @@ async def _discover_discover_join_callback(self, source, message): msg = self.cm.create_message( "offer", "offer_model", - len(self.engine.get_federation_nodes()), + len(await self.engine.get_federation_nodes()), 0, parameters=model, rounds=rounds, @@ -385,11 +383,11 @@ async def _discover_discover_join_callback(self, source, message): async def _discover_discover_nodes_callback(self, source, message): logging.info(f"πŸ” handle_discover_message | Trigger | Received discover_node message from {source} ") - if len(self.engine.get_federation_nodes()) > 0: + if len(await self.engine.get_federation_nodes()) > 0: msg = self.cm.create_message( "offer", "offer_metric", - n_neighbors=len(self.engine.get_federation_nodes()), + n_neighbors=len(await self.engine.get_federation_nodes()), loss=0 #self.engine.trainer.get_current_loss(), ) logging.info(f"Sending offer metric to {source}") @@ -399,8 +397,8 @@ async def _discover_discover_nodes_callback(self, source, message): async def _offer_offer_model_callback(self, source, message): logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_model message from {source}") - await self.meet_node(source) - if self.still_waiting_for_candidates(): + await self._meet_node(source) + if self._still_waiting_for_candidates(): try: model_compressed = message.parameters if self.accept_model_offer( @@ -415,28 +413,28 @@ async def _offer_offer_model_callback(self, source, message): logging.info(f"πŸ”§ Model accepted from offer | source: 
{source}") else: logging.info(f"❗️ Model offer discarded | source: {source}") - self.add_to_discarded_offers(source) + self._add_to_discarded_offers(source) except RuntimeError: logging.info(f"❗️ Error proccesing offer model from {source}") else: logging.info( - f"❗️ handfle_offer_message | NOT accepting offers | waiting candidates: {self.still_waiting_for_candidates()}" + f"❗️ handfle_offer_message | NOT accepting offers | waiting candidates: {self._still_waiting_for_candidates()}" ) - self.add_to_discarded_offers(source) + self._add_to_discarded_offers(source) async def _offer_offer_metric_callback(self, source, message): logging.info(f"πŸ” handle_offer_message | Trigger | Received offer_metric message from {source}") - await self.meet_node(source) - if self.still_waiting_for_candidates(): + await self._meet_node(source) + if self._still_waiting_for_candidates(): n_neighbors = message.n_neighbors loss = message.loss - self.add_candidate(source, n_neighbors, loss) + self._add_candidate(source, n_neighbors, loss) async def _link_connect_to_callback(self, source, message): logging.info(f"πŸ”— handle_link_message | Trigger | Received connect_to message from {source}") addrs = message.addrs for addr in addrs.split(): - await self.meet_node(addr) + await self._meet_node(addr) async def _link_disconnect_from_callback(self, source, message): logging.info(f"πŸ”— handle_link_message | Trigger | Received disconnect_from message from {source}") From 97bc699d7788f4767527e3543a5a0af249caff51 Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Tue, 6 May 2025 16:25:40 +0200 Subject: [PATCH 184/233] refactor pluging loader --- nebula/core/pluginloader.py | 81 ++++++++++--------- .../arbitatrionpolicies/arbitatrionpolicy.py | 2 - .../saarbitatrionpolicy.py | 16 ---- .../awareness/sareasoner.py | 10 ++- 4 files changed, 53 insertions(+), 56 deletions(-) delete mode 100644 nebula/core/situationalawareness/awareness/arbitatrionpolicies/saarbitatrionpolicy.py diff --git 
a/nebula/core/pluginloader.py b/nebula/core/pluginloader.py index a239b9916..d28a7a8c4 100644 --- a/nebula/core/pluginloader.py +++ b/nebula/core/pluginloader.py @@ -1,11 +1,3 @@ -import logging -import asyncio -from nebula.addons.functions import print_msg_box -from nebula.core.utils.locker import Locker -from abc import ABC, abstractmethod -import importlib.util -import os - """ It is an example of the .json configuration file structure required for the NebulaPluginLoader to load all plugins defined for the current @@ -31,15 +23,14 @@ Plugin Directory Structure: --------------------------- -All plugins must follow a standardized directory and naming convention +All plugins must follow a standardized directory and file naming convention to be correctly detected and loaded by NebulaPluginLoader. Base path where plugins should be located: /nebula/nebula/core/ Each plugin should be placed in its own directory inside the base path, -with a filename matching the plugin name (lowercase) and a class matching -the plugin name (capitalized). +with a filename matching the plugin name (lowercase). Example for the "reputation" plugin: /nebula/nebula/core/reputation/reputation.py @@ -47,11 +38,16 @@ Example for the "trust" plugin: /nebula/nebula/core/trust/trust.py -Plugin Class Naming Convention: -------------------------------- -Each plugin must define a class with the same name as the plugin but with -the first letter capitalized. This class must inherit from `NebulaPlugin` -and implement the `initialize_plugin()` method. +Plugin Class Requirements: +-------------------------- +Each plugin module must define at least one class that: + +- Inherits from `NebulaPlugin` +- Implements the `initialize_plugin()` method + +There are **no strict naming conventions** for the class name itself; +`NebulaPluginLoader` will automatically detect and instantiate the first +valid plugin class found in the module. Plugins receive their configuration as a dictionary when instantiated. 
@@ -61,7 +57,7 @@ from nebula.nebula.core.plugin_loader import NebulaPlugin -class Reputation(NebulaPlugin): +class ReputationModel(NebulaPlugin): def __init__(self, config: dict): self.threshold = config.get("threshold", 0.75) self.decay_factor = config.get("decay_factor", 0.05) @@ -78,7 +74,7 @@ async def initialize_plugin(self): from nebula.nebula.core.plugin_loader import NebulaPlugin -class Trust(NebulaPlugin): +class TrustSystem(NebulaPlugin): def __init__(self, config: dict): self.initial_trust = config.get("initial_trust", 0.5) self.trust_update_rate = config.get("trust_update_rate", 0.1) @@ -91,13 +87,22 @@ async def initialize_plugin(self): Important Notes: --------------- -- The plugin class name **must match the plugin name in the JSON but capitalized**. -- The plugin module filename **must be in lowercase**. -- The plugin class must inherit from `NebulaPlugin` and implement the `initialize_plugin()` method. -- Each plugin receives its configuration as a **dictionary** when instantiated. -- The `NebulaPluginLoader` dynamically loads each plugin and passes its respective configuration from the JSON. +- The plugin class name can be arbitrary but must inherit from `NebulaPlugin`. +- Only one valid plugin class is allowed per module; if more exist, only the first will be used. +- The plugin module filename must still match the plugin name (lowercase). +- The plugin class must implement `initialize_plugin()`. +- Each plugin receives its configuration as a dictionary. +- `NebulaPluginLoader` automatically detects and instantiates the appropriate plugin. 
""" +import logging +from nebula.addons.functions import print_msg_box +from nebula.core.utils.locker import Locker +from abc import ABC, abstractmethod +import importlib.util +import os +import inspect +import inspect class NebulaPlugin(ABC): @abstractmethod @@ -136,33 +141,35 @@ def load_plugins(self): plugin_names = self._config_json.get("plugins", []) for name in plugin_names: - class_name = name.capitalize() module_path = os.path.join(self._base_path, name) module_file = os.path.join(module_path, f"{name}.py") if os.path.exists(module_file): - module = self._load_plugin(class_name, module_file, self._config_json.get(name, {})) - if module: - self._plugins[name] = module + plugin_instance = self._load_plugin(module_file, self._config_json.get(name, {})) + if plugin_instance: + self._plugins[name] = plugin_instance else: - logging.error(f"⚠️ Plugin {name} not found on {module_file}") - - def _load_plugin(self, class_name, module_file, config): + logging.error(f"⚠️ Plugin {name} not found at {module_file}") + + def _load_plugin(self, module_file, config): """Loads a plugin dynamically and initializes it with its configuration.""" - spec = importlib.util.spec_from_file_location(class_name, module_file) + spec = importlib.util.spec_from_file_location("plugin_module", module_file) if spec and spec.loader: module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - if hasattr(module, class_name): # Verify if class exists - return getattr(module, class_name)(config) # Create and instance using plugin config - else: - logging.error(f"⚠️ Cannot create {class_name} plugin, class not found on {module_file}") + + for _, obj in inspect.getmembers(module, inspect.isclass): + if issubclass(obj, NebulaPlugin) and obj is not NebulaPlugin: + return obj(config) + + logging.error(f"⚠️ No valid plugin class found in {module_file}") return None - + async def initialize_plugins(self): """Calls the asynchronous initialization method of each loaded plugin.""" for 
plugin_name, plugin in self._plugins.items(): - if self._verbose: logging.info(f"Initializing plugin name:{plugin_name}") + if self._verbose: + logging.info(f"Initializing plugin name: {plugin_name}") await plugin.initialize_plugin() def get_plugin(self, name): diff --git a/nebula/core/situationalawareness/awareness/arbitatrionpolicies/arbitatrionpolicy.py b/nebula/core/situationalawareness/awareness/arbitatrionpolicies/arbitatrionpolicy.py index cd1e6db7d..659e28f00 100644 --- a/nebula/core/situationalawareness/awareness/arbitatrionpolicies/arbitatrionpolicy.py +++ b/nebula/core/situationalawareness/awareness/arbitatrionpolicies/arbitatrionpolicy.py @@ -13,11 +13,9 @@ async def tie_break(self, sac1: SACommand, sac2: SACommand) -> bool: def factory_arbitatrion_policy(arbitatrion_policy, verbose) -> ArbitatrionPolicy: from nebula.core.situationalawareness.awareness.arbitatrionpolicies.staticarbitatrionpolicy import SAP - from nebula.core.situationalawareness.awareness.arbitatrionpolicies.saarbitatrionpolicy import SAAP options = { "sap": SAP, # "Static Arbitatrion Policy" (SAP) -- default value - "saap": SAAP, # "Situational Awareness Arbitatrion Policy" (SAAP) } cs = options.get(arbitatrion_policy, SAP) diff --git a/nebula/core/situationalawareness/awareness/arbitatrionpolicies/saarbitatrionpolicy.py b/nebula/core/situationalawareness/awareness/arbitatrionpolicies/saarbitatrionpolicy.py deleted file mode 100644 index 906abf8a7..000000000 --- a/nebula/core/situationalawareness/awareness/arbitatrionpolicies/saarbitatrionpolicy.py +++ /dev/null @@ -1,16 +0,0 @@ -import asyncio -from nebula.core.situationalawareness.awareness.arbitatrionpolicies.arbitatrionpolicy import ArbitatrionPolicy -from nebula.core.situationalawareness.awareness.sautils.sacommand import SACommand - -class SAAP(ArbitatrionPolicy): - def __init__(self, verbose): - pass - - async def init(self, config): - pass - - async def tie_break(self, sac1: SACommand, sac2: SACommand) -> SACommand: - """ - Tie 
break conflcited SA Commands - """ - pass \ No newline at end of file diff --git a/nebula/core/situationalawareness/awareness/sareasoner.py b/nebula/core/situationalawareness/awareness/sareasoner.py index 2df531da5..07e7c0c74 100644 --- a/nebula/core/situationalawareness/awareness/sareasoner.py +++ b/nebula/core/situationalawareness/awareness/sareasoner.py @@ -192,6 +192,14 @@ async def _arbitatrion_suggestions(self, event_type): # SA COMPONENT LOADING # ############################### """ + def _to_pascal_case(name: str) -> str: + """Converts a snake_case or compact lowercase name into PascalCase with 'SA' prefix.""" + if name.startswith("sa_"): + name = name[3:] # remove 'sa_' prefix + elif name.startswith("sa"): + name = name[2:] # remove 'sa' prefix + parts = name.split("_") if "_" in name else [name] + return "SA" + ''.join(part.capitalize() for part in parts) async def loading_sa_components(self): """Dynamically loads the SA Components defined in the JSON configuration.""" @@ -201,7 +209,7 @@ async def loading_sa_components(self): for component_name, is_enabled in components.items(): if is_enabled: component_config = sa_section[component_name] - class_name = "SA" + component_name[2:].capitalize() + class_name = self._to_pascal_case(component_name) # ← funciΓ³n limpia module_path = os.path.join(self.MODULE_PATH, component_name) module_file = os.path.join(module_path, f"{component_name}.py") From a0a4740c13506c89bd66f3f7e4c1f4bb0e203c0c Mon Sep 17 00:00:00 2001 From: "Alejandro.A.S" Date: Wed, 7 May 2025 17:09:54 +0200 Subject: [PATCH 185/233] refactor ring topology np --- nebula/core/network/communications.py | 33 +-------------- nebula/core/pluginloader.py | 6 +-- .../neighborpolicies/ringneighborpolicy.py | 23 +++++++---- .../awareness/sareasoner.py | 8 ++-- .../discovery/federationconnector.py | 41 +++++++++++-------- .../situationalawareness.py | 8 ++-- nebula/node.py | 2 +- 7 files changed, 52 insertions(+), 69 deletions(-) diff --git 
a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 0e9782575..7a6f3e800 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -508,6 +508,7 @@ async def send_message_to_neighbors(self, message, neighbors=None, interval=0): await asyncio.sleep(interval) async def send_message(self, dest_addr, message, is_compressed=False): + logging.info(f"Sending message to addr: {dest_addr}") if not is_compressed: try: conn = self.connections[dest_addr] @@ -527,38 +528,6 @@ async def send_message(self, dest_addr, message, is_compressed=False): logging.exception(f"❗️ Cannot send model to {dest_addr}: {e!s}") await self.disconnect(dest_addr, mutual_disconnection=False) - # async def send_model(self, dest_addr, round, serialized_model, weight=1): - # async with self.semaphore_send_model: - # try: - # conn = self.connections.get(dest_addr) - # if conn is None: - # logging.info(f"❗️ Connection with {dest_addr} not found") - # return - # logging.info( - # f"Sending model to {dest_addr} with round {round}: weight={weight} |Β size={sys.getsizeof(serialized_model) / (1024** 2) if serialized_model is not None else 0} MB" - # ) - # parameters = serialized_model - # message = self.create_message("model", "", round, parameters, weight) - # await conn.send(data=message, is_compressed=True) - # logging.info(f"Model sent to {dest_addr} with round {round}") - # except Exception as e: - # logging.exception(f"❗️ Cannot send model to {dest_addr}: {e!s}") - # await self.disconnect(dest_addr, mutual_disconnection=False) - - # async def send_offer_model(self, dest_addr, offer_message): - # async with self.semaphore_send_model: - # try: - # conn = self.connections.get(dest_addr) - # if conn is None: - # logging.info(f"❗️ Connection with {dest_addr} not found") - # return - # logging.info(f"Sending offer model to {dest_addr}") - # await conn.send(data=offer_message, is_compressed=True) - # logging.info(f"Offer_Model sent to 
{dest_addr}") - # except Exception as e: - # logging.exception(f"❗️ Cannot send model to {dest_addr}: {e!s}") - # await self.disconnect(dest_addr, mutual_disconnection=False) - async def establish_connection(self, addr, direct=True, reconnect=False): logging.info(f"πŸ”— [outgoing] Establishing connection with {addr} (direct: {direct})") diff --git a/nebula/core/pluginloader.py b/nebula/core/pluginloader.py index d28a7a8c4..c99f09967 100644 --- a/nebula/core/pluginloader.py +++ b/nebula/core/pluginloader.py @@ -55,7 +55,7 @@ ----------------------------------------------------- File: /nebula/nebula/core/reputation/reputation.py -from nebula.nebula.core.plugin_loader import NebulaPlugin +from nebula.nebula.core.pluginloader import NebulaPlugin class ReputationModel(NebulaPlugin): def __init__(self, config: dict): @@ -116,14 +116,14 @@ class NebulaPluginLoader: _instance = None _lock = Locker("_nebula_pluging_loader_lock", async_lock=False) - def __new__(cls, config_json=None, base_path="/nebula/nebula/core"): + def __new__(cls, config_json=None, base_path="/nebula/nebula/core/plugins"): with cls._lock: if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance - def __init__(self, config_json=None, base_path="/nebula/nebula/core"): + def __init__(self, config_json=None, base_path="/nebula/nebula/core/plugins"): """Initializes the plugin loader with the given configuration JSON and base path.""" if self._initialized: return diff --git a/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/ringneighborpolicy.py b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/ringneighborpolicy.py index bab9a87a8..34e1858de 100644 --- a/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/ringneighborpolicy.py +++ b/nebula/core/situationalawareness/awareness/sanetwork/neighborpolicies/ringneighborpolicy.py @@ -29,6 +29,7 @@ def set_config(self, config): """ 
logging.info("Initializing Ring Topology Neighbor Policy") self.neighbors_lock.acquire() + logging.info(f"neighbors: {config[0]}") self.neighbors = config[0] self.neighbors_lock.release() for addr in config[1]: @@ -41,7 +42,7 @@ def accept_connection(self, source, joining=False): """ ac = False self.neighbors_lock.acquire() - if not joining: + if joining: ac = not source in self.neighbors else: ac = not len(self.neighbors) == self.max_neighbors @@ -63,11 +64,17 @@ def forget_nodes(self, nodes, forget_all=False): self.nodes_known_lock.release() def get_nodes_known(self, neighbors_too=False, neighbors_only=False): + if neighbors_only: + self.neighbors_lock.acquire() + no = self.neighbors.copy() + self.neighbors_lock.release() + return no + self.nodes_known_lock.acquire() nk = self.nodes_known.copy() if not neighbors_too: self.neighbors_lock.acquire() - nk = self.nodes_known - self.neighbors + nk = self.nodes_known - self.neighbors self.neighbors_lock.release() self.nodes_known_lock.release() return nk @@ -79,14 +86,14 @@ def get_actions(self): - Second one represents the same but for disconnect from LinkMessage """ self.neighbors_lock.acquire() - ct_actions = [] - df_actions = [] - if len(self.neighbors) < self.max_neighbors: + ct_actions = "" + df_actions = "" + if len(self.neighbors) <= self.max_neighbors: list_neighbors = list(self.neighbors) index = random.randint(0, len(list_neighbors)-1) - node = list_neighbors[index] - ct_actions.append(node) # connect to - df_actions.append(self.addr) # disconnect from + node = list_neighbors[index] + ct_actions = node # connect to + df_actions = self.addr # disconnect from self.neighbors_lock.release() return [ct_actions, df_actions] diff --git a/nebula/core/situationalawareness/awareness/sareasoner.py b/nebula/core/situationalawareness/awareness/sareasoner.py index 07e7c0c74..9668f8543 100644 --- a/nebula/core/situationalawareness/awareness/sareasoner.py +++ b/nebula/core/situationalawareness/awareness/sareasoner.py @@ 
-123,8 +123,10 @@ def get_actions(self): async def _process_round_end_event(self, ree : RoundEndEvent): logging.info("πŸ”„ Arbitration | Round End Event...") - for sa_comp in self._sa_components.values(): - asyncio.create_task(sa_comp.sa_component_actions()) + # TODO change when front is done + # for sa_comp in self._sa_components.values(): + # asyncio.create_task(sa_comp.sa_component_actions()) + asyncio.create_task(self.san.sa_component_actions()) valid_commands = await self._arbitatrion_suggestions(RoundEndEvent) # Execute SACommand selected @@ -204,7 +206,7 @@ def _to_pascal_case(name: str) -> str: async def loading_sa_components(self): """Dynamically loads the SA Components defined in the JSON configuration.""" sa_section = self._config.participant["situational_awareness"] - components: dict = sa_section["sa_components"] + components: dict = sa_section["sar_components"] for component_name, is_enabled in components.items(): if is_enabled: diff --git a/nebula/core/situationalawareness/discovery/federationconnector.py b/nebula/core/situationalawareness/discovery/federationconnector.py index 5892d8e1b..220c4cf0d 100644 --- a/nebula/core/situationalawareness/discovery/federationconnector.py +++ b/nebula/core/situationalawareness/discovery/federationconnector.py @@ -109,28 +109,23 @@ def _still_waiting_for_candidates(self): return not self.accept_candidates_lock.locked() and self.late_connection_process_lock.locked() async def _add_pending_connection_confirmation(self, addr): - await self._update_neighbors_lock.acquire_async() - await self.pending_confirmation_from_nodes_lock.acquire_async() - if addr not in self.sar.get_nodes_known(neighbors_only=True): - logging.info(f" Addition | pending connection confirmation from: {addr}") - self.pending_confirmation_from_nodes.add(addr) - await self.pending_confirmation_from_nodes_lock.release_async() - await self._update_neighbors_lock.release_async() + async with self._update_neighbors_lock: + async with 
self.pending_confirmation_from_nodes_lock: + if addr not in self.sar.get_nodes_known(neighbors_only=True): + logging.info(f"Addition | pending connection confirmation from: {addr}") + self.pending_confirmation_from_nodes.add(addr) async def _remove_pending_confirmation_from(self, addr): - await self.pending_confirmation_from_nodes_lock.acquire_async() - self.pending_confirmation_from_nodes.discard(addr) - await self.pending_confirmation_from_nodes_lock.release_async() + async with self.pending_confirmation_from_nodes_lock: + self.pending_confirmation_from_nodes.discard(addr) async def _clear_pending_confirmations(self): - await self.pending_confirmation_from_nodes_lock.acquire_async() - self.pending_confirmation_from_nodes.clear() - await self.pending_confirmation_from_nodes_lock.release_async() + async with self.pending_confirmation_from_nodes_lock: + self.pending_confirmation_from_nodes.clear() async def _waiting_confirmation_from(self, addr): - await self.pending_confirmation_from_nodes_lock.acquire_async() - found = addr in self.pending_confirmation_from_nodes - await self.pending_confirmation_from_nodes_lock.release_async() + async with self.pending_confirmation_from_nodes_lock: + found = addr in self.pending_confirmation_from_nodes return found async def _confirmation_received(self, addr, joining=False): @@ -305,17 +300,26 @@ async def _connection_late_connect_callback(self, source, message): # Verify conenction is accepted conf_msg = self.cm.create_message("connection", "late_connect") await self.cm.send_message(source, conf_msg) - await self._register_late_neighbor(source, joinning_federation=True) + + # SI ACTUALIZO PRINMERO SE PASA DEL NUMERO DE VECINOS TODO + ct_actions, df_actions = self._get_actions() + logging.info("voy a mostrar acciones en respuesta a late connect") if len(ct_actions): + logging.info("1 acciones") + logging.info(f"{ct_actions}") cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) await 
self.cm.send_message(source, cnt_msg) if len(df_actions): + logging.info("2 acciones") + logging.info(f"{df_actions}") df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) await self.cm.send_message(source, df_msg) + await self._register_late_neighbor(source, joinning_federation=True) + else: logging.info(f"❗️ Late connection NOT accepted | source: {source}") @@ -344,6 +348,7 @@ async def _connection_restructure_callback(self, source, message): await self.cm.send_message(source, cnt_msg) if len(df_actions): + # TODO el q se tiene q desconectar de mi no es source, es el vecino seleccionado df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) await self.cm.send_message(source, df_msg) @@ -440,7 +445,7 @@ async def _link_disconnect_from_callback(self, source, message): logging.info(f"πŸ”— handle_link_message | Trigger | Received disconnect_from message from {source}") addrs = message.addrs for addr in addrs.split(): - await asyncio.create_task(self.cm.disconnect(source, mutual_disconnection=False)) + await asyncio.create_task(self.cm.disconnect(addr, mutual_disconnection=False)) diff --git a/nebula/core/situationalawareness/situationalawareness.py b/nebula/core/situationalawareness/situationalawareness.py index 02b20260f..a9466ed90 100644 --- a/nebula/core/situationalawareness/situationalawareness.py +++ b/nebula/core/situationalawareness/situationalawareness.py @@ -34,7 +34,7 @@ def get_actions(self): def factory_sa_discovery(sa_discovery, additional, topology, model_handler, engine, verbose) -> ISADiscovery: from nebula.core.situationalawareness.discovery.federationconnector import FederationConnector DISCOVERY = { - "fedcon": FederationConnector, + "nebula": FederationConnector, } sad = DISCOVERY.get(sa_discovery) if sad: @@ -45,7 +45,7 @@ def factory_sa_discovery(sa_discovery, additional, topology, model_handler, engi def factory_sa_reasoner(sa_reasoner, config, addr, topology, verbose) -> ISAReasoner: from 
nebula.core.situationalawareness.awareness.sareasoner import SAReasoner REASONER = { - "nebula_reasoner": SAReasoner, + "nebula": SAReasoner, } sar = REASONER.get(sa_reasoner) if sar: @@ -65,7 +65,7 @@ def __init__(self, config, engine): topology = topology.lower() model_handler = "std" self._sad = factory_sa_discovery( - "fedcon", + "nebula", self._config.participant["mobility_args"]["additional_node"]["status"], topology, model_handler, @@ -73,7 +73,7 @@ def __init__(self, config, engine): verbose=True ) self._sareasoner = factory_sa_reasoner( - "nebula_reasoner", + "nebula", self._config, self._config.participant["network_args"]["addr"], topology, diff --git a/nebula/node.py b/nebula/node.py index 1e27171c3..aa0155792 100755 --- a/nebula/node.py +++ b/nebula/node.py @@ -230,7 +230,7 @@ def randomize_value(value, variability): logging.info("Waiting time to start finding federation") # time.sleep(150) - await asyncio.sleep(150) + await asyncio.sleep(120) # time.sleep(6000) # DEBUG purposes # import requests From 584e94a71d27c0af3037357cb67dbc95a6ce284b Mon Sep 17 00:00:00 2001 From: FerTV Date: Thu, 8 May 2025 11:44:55 +0200 Subject: [PATCH 186/233] feature sa frontend --- .../frontend/config/participant.json.example | 25 ++- nebula/frontend/templates/deployment.html | 169 +++++++++++++++++- nebula/node.py | 2 +- nebula/scenarios.py | 30 +++- 4 files changed, 207 insertions(+), 19 deletions(-) diff --git a/nebula/frontend/config/participant.json.example b/nebula/frontend/config/participant.json.example index 11048ba58..c37690461 100755 --- a/nebula/frontend/config/participant.json.example +++ b/nebula/frontend/config/participant.json.example @@ -147,16 +147,23 @@ "history_size": 20 }, "situational_awareness": { - "sa_components": { - "sanetwork": true + "strict_topology" : true, + "sa_discovery" : { + "candidate_selector": "random", + "model_handler": "std", + "verbose" : true }, - "sanetwork": { - "addr": "", - "strict_topology": true, - "verbose": true - }, - 
"arbitatrion_policy": "sap", - "model_handler": "std" + "sa_reasoner" : { + "arbitatrion_policy": "sap", + "verbose" : true, + "sar_components": { + "sa_network" : true + }, + "sar_network" : { + "neighbor_policy": "random", + "verbose" : true + } + } }, "misc_args": { "grace_time_connection": 10, diff --git a/nebula/frontend/templates/deployment.html b/nebula/frontend/templates/deployment.html index 123bd11e5..64d40888e 100755 --- a/nebula/frontend/templates/deployment.html +++ b/nebula/frontend/templates/deployment.html @@ -708,6 +708,10 @@
Mobility configuration
+ +
+
Situtational Awareness + + +
+
Enable/Disable SA
+
+ +
+ + + +
@@ -1080,8 +1157,14 @@
Schema of deployment
var aditional_participants = []; additionalParticipants = document.getElementById("additionalParticipants"); for (var i = 0; i < additionalParticipants.value; i++) { - aditional_participants[i] = { - "round": document.getElementById("roundsAdditionalParticipant" + i).value, + if(document.getElementById("connectionDelaySwitch").checked){ + aditional_participants[i] = { + "time": document.getElementById("connectionDelay").value, + } + } else{ + aditional_participants[i] = { + "time": document.getElementById("timeAdditionalParticipant" + i).value, + } } } @@ -1166,12 +1249,20 @@
Schema of deployment
data["longitude"] = parseFloat(document.getElementById("longitude").value) data["mobility"] = document.getElementById("mobility-btn").checked ? true : false data["mobility_type"] = document.getElementById("mobilitySelect").value + data["network_simulation"] = document.getElementById("networkSimulation").checked data["radius_federation"] = document.getElementById("radiusFederation").value data["scheme_mobility"] = document.getElementById("schemeMobilitySelect").value data["round_frequency"] = document.getElementById("roundFrequency").value data["mobile_participants_percent"] = document.getElementById("mobileParticipantsPercent").value data["additional_participants"] = aditional_participants data["schema_additional_participants"] = document.getElementById("schemaAdditionalParticipantsSelect").value + // Step 15 + data["with_sa"] = document.getElementById("situationalAwarenessSwitch").checked; + data["strict_topology"] = document.getElementById("StrictTopologySwitch").checked; + data["sad_candidate_selector"] = document.getElementById("candidate-selector-select").value; + data["sad_model_handler"] = document.getElementById("model-handler-select").value; + data["sar_arbitration_policy"] = document.getElementById("arbitration-policy-select").value; + data["sar_neighbor_policy"] = document.getElementById("neighbor-policy-select").value; return data } @@ -1313,6 +1404,7 @@
Schema of deployment
if (data["mobility"]) { document.getElementById("mobility-options").style.display = "block"; } + document.getElementById("networkSimulation").checked = data["network_simulation"]; document.getElementById("mobilitySelect").value = data["mobility_type"]; document.getElementById("radiusFederation").value = data["radius_federation"]; document.getElementById("schemeMobilitySelect").value = data["scheme_mobility"]; @@ -1326,9 +1418,18 @@
Schema of deployment
additionalParticipants.value = Object.keys(data["additional_participants"]).length; additionalParticipants.dispatchEvent(new Event('change')); for (var i = 0; i < additionalParticipants.value; i++) { - document.getElementById("roundsAdditionalParticipant" + i).value = data["additional_participants"][i]["round"]; + document.getElementById("timeAdditionalParticipant" + i).value = data["additional_participants"][i]["round"]; } } + // Step 14 + document.getElementById("situationalAwarenessSwitch").checked = data["with_sa"]; + document.getElementById("situationalAwarenessSwitch").dispatchEvent(new Event('change')); + document.getElementById("StrictTopologySwitch").checked = data["strict_topology"]; + document.getElementById("StrictTopologySwitch").dispatchEvent(new Event('change')); + document.getElementById("candidate-selector-select").value = data["sad_candidate_selector"]; + document.getElementById("model-handler-select").value = data["sad_model_handler"]; + document.getElementById("arbitration-policy-select").value = data["sar_arbitration_policy"]; + document.getElementById("neighbor-policy-select").value = data["sar_neighbor_policy"]; } catch (error) { console.log(error); } @@ -1637,6 +1738,17 @@
Schema of deployment
container: "body", delay: { show: 500, hide: 100 }, }); + + var connectionDelay= document.getElementById("connectionDelayHelpIcon"); + var popover = new bootstrap.Popover(connectionDelay, { + title: "Connection Delay", + content: "Delay in seconds for joining the federation", + trigger: "hover", + placement: "right", + html: true, + container: "body", + delay: { show: 500, hide: 100 }, + }); + +