diff --git a/README.md b/README.md index 6c79b4f91..6cf237430 100755 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@
- nebula-dfl.com | + nebula-dfl.com | nebula-dfl.eu | federatedlearning.inf.um.es
diff --git a/analysis/README.md b/analysis/README.md index 5118b091c..3f651d934 100755 --- a/analysis/README.md +++ b/analysis/README.md @@ -7,7 +7,7 @@- nebula-dfl.com | + nebula-dfl.com | nebula-dfl.eu | federatedlearning.inf.um.es
diff --git a/app/main.py b/app/main.py index a9a4733f0..3d6c300ad 100755 --- a/app/main.py +++ b/app/main.py @@ -4,8 +4,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..")) # Parent directory where is the NEBULA module import nebula -from nebula.controller import Controller -from nebula.scenarios import ScenarioManagement +from nebula.controller.controller import Controller +from nebula.controller.scenarios import ScenarioManagement argparser = argparse.ArgumentParser(description="Controller of NEBULA platform", add_help=False) @@ -54,8 +54,6 @@ help="Statistics port (default: 8080)", ) -argparser.add_argument("-t", "--test", dest="test", action="store_true", default=False, help="Run tests") - argparser.add_argument( "-st", "--stop", diff --git a/docs/_prebuilt/developerguide.md b/docs/_prebuilt/developerguide.md index 43eb0c182..b54565017 100644 --- a/docs/_prebuilt/developerguide.md +++ b/docs/_prebuilt/developerguide.md @@ -829,8 +829,8 @@ To add a new message to the application, follow these steps: FEDERATION_READY = 3; } Action action = 1; - repeated string arguments = 2; - int32 round = 3; + repeated string arguments = 2; + int32 round = 3; } ``` @@ -905,4 +905,4 @@ Note that **EventType** is the class that represents the event (not a specific i When the event is published, all subscribed listeners for that event type will be triggered. As mentioned, there are three different **publish** functions, each tied to a specific type of event. -Finally, to **create a new event**, go to the file **/core/nebulaevents.py**. Depending on the type of event you wish to implement, create a class that extends one of the three native event types. After doing this, the usage of your new event is transparent to the rest of the system, and you can use the functions described above without any issues. \ No newline at end of file +Finally, to **create a new event**, go to the file **/core/nebulaevents.py**. 
Depending on the type of event you wish to implement, create a class that extends one of the three native event types. After doing this, the usage of your new event is transparent to the rest of the system, and you can use the functions described above without any issues. diff --git a/nebula/addons/attacks/communications/communicationattack.py b/nebula/addons/attacks/communications/communicationattack.py index 6552e9e52..bacae998c 100644 --- a/nebula/addons/attacks/communications/communicationattack.py +++ b/nebula/addons/attacks/communications/communicationattack.py @@ -1,9 +1,11 @@ import logging import random +import random import types from abc import abstractmethod from nebula.addons.attacks.attacks import Attack +from nebula.core.network.communications import CommunicationsManager class CommunicationAttack(Attack): @@ -46,21 +48,23 @@ async def select_targets(self): if self.selection_interval: if self.last_selection_round % self.selection_interval == 0: logging.info("Recalculating targets...") - all_nodes = await self.engine.cm.get_addrs_current_connections(only_direct=True) + all_nodes = await CommunicationsManager.get_instance().get_addrs_current_connections(only_direct=True) num_targets = max(1, int(len(all_nodes) * (self.selectivity_percentage / 100))) self.targets = set(random.sample(list(all_nodes), num_targets)) elif not self.targets: logging.info("Calculating targets...") - all_nodes = await self.engine.cm.get_addrs_current_connections(only_direct=True) + all_nodes = await CommunicationsManager.get_instance().get_addrs_current_connections(only_direct=True) num_targets = max(1, int(len(all_nodes) * (self.selectivity_percentage / 100))) self.targets = set(random.sample(list(all_nodes), num_targets)) else: logging.info("All neighbors selected as targets") - self.targets = await self.engine.cm.get_addrs_current_connections(only_direct=True) + self.targets = await CommunicationsManager.get_instance().get_addrs_current_connections(only_direct=True) 
logging.info(f"Selected {self.selectivity_percentage}% targets from neighbors: {self.targets}") self.last_selection_round += 1 + self.last_selection_round += 1 + async def _inject_malicious_behaviour(self): """Inject malicious behavior into the target method.""" decorated_method = self.decorator(self.decorator_args)(self.original_method) diff --git a/nebula/addons/attacks/communications/delayerattack.py b/nebula/addons/attacks/communications/delayerattack.py index afe611fde..05daae735 100644 --- a/nebula/addons/attacks/communications/delayerattack.py +++ b/nebula/addons/attacks/communications/delayerattack.py @@ -3,6 +3,7 @@ from functools import wraps from nebula.addons.attacks.communications.communicationattack import CommunicationAttack +from nebula.core.network.communications import CommunicationsManager class DelayerAttack(CommunicationAttack): @@ -32,8 +33,8 @@ def __init__(self, engine, attack_params: dict): super().__init__( engine, - engine._cm, - "send_model", + CommunicationsManager.get_instance(), + "send_message", round_start, round_stop, attack_interval, @@ -43,27 +44,27 @@ def __init__(self, engine, attack_params: dict): ) def decorator(self, delay: int): - """ - Decorator that adds a delay to the execution of the original method. + """ + Decorator that adds a delay to the execution of the original method. - Args: - delay (int): The time in seconds to delay the method execution. + Args: + delay (int): The time in seconds to delay the method execution. - Returns: - function: A decorator function that wraps the target method with the delay logic. - """ + Returns: + function: A decorator function that wraps the target method with the delay logic. 
+ """ - def decorator(func): - @wraps(func) - async def wrapper(*args, **kwargs): - if len(args) > 1: - dest_addr = args[1] - if dest_addr in self.targets: - logging.info(f"[DelayerAttack] Delaying model propagation to {dest_addr} by {delay} seconds") - await asyncio.sleep(delay) - _, *new_args = args # Exclude self argument - return await func(*new_args) + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + if len(args) == 4 and args[3] == "model": + dest_addr = args[1] + if dest_addr in self.targets: + logging.info(f"[DelayerAttack] Delaying model propagation to {dest_addr} by {delay} seconds") + await asyncio.sleep(delay) + _, *new_args = args # Exclude self argument + return await func(*new_args) - return wrapper + return wrapper - return decorator + return decorator diff --git a/nebula/addons/attacks/communications/floodingattack.py b/nebula/addons/attacks/communications/floodingattack.py index 643969739..146854fa3 100644 --- a/nebula/addons/attacks/communications/floodingattack.py +++ b/nebula/addons/attacks/communications/floodingattack.py @@ -1,9 +1,8 @@ -import asyncio import logging from functools import wraps -import time from nebula.addons.attacks.communications.communicationattack import CommunicationAttack +from nebula.core.network.communications import CommunicationsManager class FloodingAttack(CommunicationAttack): @@ -35,8 +34,8 @@ def __init__(self, engine, attack_params: dict): super().__init__( engine, - engine._cm, - "send_model", + CommunicationsManager.get_instance(), + "send_message", round_start, round_stop, attack_interval, @@ -59,7 +58,7 @@ def decorator(self, flooding_factor: int): def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): - if len(args) > 1: + if len(args) == 4 and args[3] == "model": dest_addr = args[1] if dest_addr in self.targets: logging.info(f"[FloodingAttack] Flooding message to {dest_addr} by {flooding_factor} times") @@ -68,13 +67,11 @@ async def wrapper(*args, **kwargs): 
logging.info( f"[FloodingAttack] Sending duplicate {i + 1}/{flooding_factor} to {dest_addr}" ) - _, dest_addr, _, serialized_model, weight = args # Exclude self argument - new_args = [dest_addr, i, serialized_model, weight] + _, *new_args = args # Exclude self argument await func(*new_args, **kwargs) - _, dest_addr, _, serialized_model, weight = args # Exclude self argument - new_args = [dest_addr, i, serialized_model, weight] + _, *new_args = args return await func(*new_args) - + return wrapper return decorator diff --git a/nebula/addons/attacks/dataset/datasetattack.py b/nebula/addons/attacks/dataset/datasetattack.py index c740b8b91..812e09bb3 100644 --- a/nebula/addons/attacks/dataset/datasetattack.py +++ b/nebula/addons/attacks/dataset/datasetattack.py @@ -35,9 +35,11 @@ async def attack(self): """ if self.engine.round not in range(self.round_start_attack, self.round_stop_attack + 1): pass - elif self.engine.round == self.round_stop_attack: + elif self.engine.round == self.round_stop_attack: logging.info(f"[{self.__class__.__name__}] Stopping attack") - elif self.engine.round >= self.round_start_attack and ((self.engine.round - self.round_start_attack) % self.attack_interval == 0): + elif self.engine.round >= self.round_start_attack and ( + (self.engine.round - self.round_start_attack) % self.attack_interval == 0 + ): logging.info(f"[{self.__class__.__name__}] Performing attack") self.engine.trainer.datamodule.train_set = self.get_malicious_dataset() diff --git a/nebula/addons/attacks/dataset/labelflipping.py b/nebula/addons/attacks/dataset/labelflipping.py index 1116a9019..41d1e106e 100755 --- a/nebula/addons/attacks/dataset/labelflipping.py +++ b/nebula/addons/attacks/dataset/labelflipping.py @@ -10,6 +10,7 @@ import copy import logging import random + import numpy as np from nebula.addons.attacks.dataset.datasetattack import DatasetAttack diff --git a/nebula/addons/attacks/model/gllneuroninversion.py b/nebula/addons/attacks/model/gllneuroninversion.py index 
45686ad86..18cce6ddf 100644 --- a/nebula/addons/attacks/model/gllneuroninversion.py +++ b/nebula/addons/attacks/model/gllneuroninversion.py @@ -34,7 +34,7 @@ def __init__(self, engine, attack_params): raise ValueError(f"Missing required attack parameter: {e}") except ValueError: raise ValueError("Invalid value in attack_params. Ensure all values are integers.") - + super().__init__(engine, round_start, round_stop, attack_interval) def model_attack(self, received_weights): diff --git a/nebula/addons/attacks/model/modelattack.py b/nebula/addons/attacks/model/modelattack.py index 643f3d728..7f1c719bb 100644 --- a/nebula/addons/attacks/model/modelattack.py +++ b/nebula/addons/attacks/model/modelattack.py @@ -110,7 +110,9 @@ async def attack(self): elif self.engine.round == self.round_stop_attack: logging.info(f"[{self.__class__.__name__}] Stopping attack") await self._restore_original_behaviour() - elif (self.engine.round == self.round_start_attack) or ((self.engine.round - self.round_start_attack) % self.attack_interval == 0): + elif (self.engine.round == self.round_start_attack) or ( + (self.engine.round - self.round_start_attack) % self.attack_interval == 0 + ): logging.info(f"[{self.__class__.__name__}] Performing attack") await self._inject_malicious_behaviour() else: diff --git a/nebula/addons/attacks/model/swappingweights.py b/nebula/addons/attacks/model/swappingweights.py index a194ba8ae..70ac21834 100644 --- a/nebula/addons/attacks/model/swappingweights.py +++ b/nebula/addons/attacks/model/swappingweights.py @@ -40,7 +40,7 @@ def __init__(self, engine, attack_params): raise ValueError(f"Missing required attack parameter: {e}") except ValueError: raise ValueError("Invalid value in attack_params. 
Ensure all values are integers.") - + super().__init__(engine, round_start, round_stop, attack_interval) self.layer_idx = int(attack_params["layer_idx"]) diff --git a/nebula/addons/gps/nebulagps.py b/nebula/addons/gps/nebulagps.py index 208f54dbf..8a310c561 100644 --- a/nebula/addons/gps/nebulagps.py +++ b/nebula/addons/gps/nebulagps.py @@ -26,7 +26,7 @@ def __init__(self, config, addr, update_interval: float = 5.0, verbose=False): self._verbose = verbose async def start(self): - """Inicia el servicio de GPS, enviando y recibiendo ubicaciones.""" + """Starts the GPS service, sending and receiving locations.""" logging.info("Starting NebulaGPS service...") self.running = True @@ -73,7 +73,7 @@ async def _send_location_loop(self): await asyncio.sleep(self.update_interval) async def _receive_location_loop(self): - """Escucha y almacena geolocalizaciones de otros nodos.""" + """Listens to and stores geolocations from other nodes.""" while self.running: try: data, addr = await asyncio.get_running_loop().run_in_executor( @@ -88,7 +88,7 @@ async def _receive_location_loop(self): if self._verbose: logging.info(f"Received GPS from {addr[0]}: {lat}, {lon}") except Exception as e: - logging.error(f"Error receiving GPS update: {e}") + logging.exception(f"Error receiving GPS update: {e}") async def _notify_geolocs(self): while True: @@ -102,5 +102,7 @@ async def _notify_geolocs(self): for addr, (lat, long) in geolocs.items(): dist = await self.calculate_distance(self_lat, self_long, lat, long) distances[addr] = (dist, (lat, long)) + + self._config.update_nodes_distance(distances) gpsevent = GPSEvent(distances) asyncio.create_task(EventManager.get_instance().publish_addonevent(gpsevent)) diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index 41fa6624d..522161e70 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -3,19 +3,17 @@ import math import random import time -from typing import TYPE_CHECKING +from functools import cached_property 
from nebula.addons.functions import print_msg_box from nebula.core.eventmanager import EventManager -from nebula.core.nebulaevents import GPSEvent +from nebula.core.nebulaevents import ChangeLocationEvent, GPSEvent +from nebula.core.network.communications import CommunicationsManager from nebula.core.utils.locker import Locker -if TYPE_CHECKING: - from nebula.core.network.communications import CommunicationsManager - class Mobility: - def __init__(self, config, cm: "CommunicationsManager", verbose=False): + def __init__(self, config, verbose=False): """ Initializes the mobility module with specified configuration and communication manager. @@ -52,7 +50,6 @@ def __init__(self, config, cm: "CommunicationsManager", verbose=False): """ logging.info("Starting mobility module...") self.config = config - self.cm = cm self.grace_time = self.config.participant["mobility_args"]["grace_time_mobility"] self.period = self.config.participant["mobility_args"]["change_geo_interval"] self.mobility = self.config.participant["mobility_args"]["mobility"] @@ -60,10 +57,10 @@ def __init__(self, config, cm: "CommunicationsManager", verbose=False): self.radius_federation = float(self.config.participant["mobility_args"]["radius_federation"]) self.scheme_mobility = self.config.participant["mobility_args"]["scheme_mobility"] self.round_frequency = int(self.config.participant["mobility_args"]["round_frequency"]) - # Protocol to change connections based on distance - self.max_distance_with_direct_connections = 300 # meters - self.max_movement_random_strategy = 100 # meters - self.max_movement_nearest_strategy = 100 # meters + # INFO: These values may change according to the needs of the federation + self.max_distance_with_direct_connections = 150 # meters + self.max_movement_random_strategy = 50 # meters + self.max_movement_nearest_strategy = 50 # meters self.max_initiate_approximation = self.max_distance_with_direct_connections * 1.2 # Logging box with mobility information mobility_msg = 
f"Mobility: {self.mobility}\nMobility type: {self.mobility_type}\nRadius federation: {self.radius_federation}\nScheme mobility: {self.scheme_mobility}\nEach {self.round_frequency} rounds" @@ -72,6 +69,10 @@ def __init__(self, config, cm: "CommunicationsManager", verbose=False): self._nodes_distances_lock = Locker("nodes_distances_lock", async_lock=True) self._verbose = verbose + @cached_property + def cm(self): + return CommunicationsManager.get_instance() + @property def round(self): """ @@ -101,6 +102,7 @@ async def start(self): `run_mobility` operation. """ await EventManager.get_instance().subscribe_addonevent(GPSEvent, self.update_nodes_distances) + await EventManager.get_instance().subscribe_addonevent(GPSEvent, self.update_nodes_distances) task = asyncio.create_task(self.run_mobility()) return task @@ -198,7 +200,8 @@ async def change_geo_location_nearest_neighbor_strategy( coordinates to determine the direction of movement. - The conversion from meters to degrees is based on approximate geographical conversion factors. 
""" - logging.info("π Changing geo location towards the nearest neighbor") + if self._verbose: + logging.info("π Changing geo location towards the nearest neighbor") scale_factor = min(1, self.max_movement_nearest_strategy / distance) # Calcular el Γ‘ngulo hacia el vecino angle = math.atan2(neighbor_longitude - longitude, neighbor_latitude - latitude) @@ -245,6 +248,8 @@ async def set_geo_location(self, latitude, longitude): self.config.participant["mobility_args"]["longitude"] = longitude if self._verbose: logging.info(f"π New geo location: {latitude}, {longitude}") + cle = ChangeLocationEvent(latitude, longitude) + asyncio.create_task(EventManager.get_instance().publish_addonevent(cle)) async def change_geo_location(self): """ @@ -285,7 +290,8 @@ async def change_geo_location(self): addr, dist, (lat, long) = selected_neighbor if dist > self.max_initiate_approximation: # If the distance is too big, we move towards the neighbor - logging.info(f"Moving towards nearest neighbor: {addr}") + if self._verbose: + logging.info(f"Moving towards nearest neighbor: {addr}") await self.change_geo_location_nearest_neighbor_strategy( dist, latitude, diff --git a/nebula/addons/networksimulation/nebulanetworksimulator.py b/nebula/addons/networksimulation/nebulanetworksimulator.py index c78626603..22010e872 100644 --- a/nebula/addons/networksimulation/nebulanetworksimulator.py +++ b/nebula/addons/networksimulation/nebulanetworksimulator.py @@ -1,16 +1,14 @@ import asyncio import logging import subprocess -from typing import TYPE_CHECKING +from functools import cached_property from nebula.addons.networksimulation.networksimulator import NetworkSimulator from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import GPSEvent +from nebula.core.network.communications import CommunicationsManager from nebula.core.utils.locker import Locker -if TYPE_CHECKING: - from nebula.core.network.communications import CommunicationsManager - class NebulaNS(NetworkSimulator): 
NETWORK_CONDITIONS = { @@ -21,8 +19,7 @@ class NebulaNS(NetworkSimulator): } IP_MULTICAST = "239.255.255.250" - def __init__(self, communication_manager: "CommunicationsManager", changing_interval, interface, verbose=False): - self._cm = communication_manager + def __init__(self, changing_interval, interface, verbose=False): self._refresh_interval = changing_interval self._node_interface = interface self._verbose = verbose @@ -31,9 +28,16 @@ def __init__(self, communication_manager: "CommunicationsManager", changing_inte self._current_network_conditions = {} self._running = False + @cached_property + def cm(self): + return CommunicationsManager.get_instance() + async def start(self): logging.info("π Nebula Network Simulator starting...") self._running = True + grace_time = self.cm.config.participant["mobility_args"]["grace_time_mobility"] + # if self._verbose: logging.info(f"Waiting {grace_time}s to start applying network conditions based on distances between devices") + # await asyncio.sleep(grace_time) await EventManager.get_instance().subscribe_addonevent( GPSEvent, self._change_network_conditions_based_on_distances ) @@ -213,7 +217,7 @@ def extract_number(value): match = re.match(r"([\d.]+)", value) if not match: - raise ValueError(f"Invalid format: {value}") + raise ValueError(f"Formato invΓ‘lido: {value}") return float(match.group(1)) if self._verbose: @@ -224,11 +228,11 @@ def extract_number(value): thresholds = sorted(th.keys()) - # If the distance is less than the first threshold, return the best condition + # Si la distancia es menor que el primer umbral, devolver la mejor condiciΓ³n if distance < thresholds[0]: conditions = {"bandwidth": th[thresholds[0]]["bandwidth"], "delay": th[thresholds[0]]["delay"]} - # Find the section in which the distance is located. 
+ # Encontrar el tramo en el que se encuentra la distancia for i in range(len(thresholds) - 1): lower_bound = thresholds[i] upper_bound = thresholds[i + 1] @@ -241,7 +245,7 @@ def extract_number(value): lower_cond = th[lower_bound] upper_cond = th[upper_bound] - # Extract numerical values and units + # Extraer valores numΓ©ricos y unidades lower_bandwidth_value = extract_number(lower_cond["bandwidth"]) upper_bandwidth_value = extract_number(upper_cond["bandwidth"]) lower_bandwidth_unit = lower_cond["bandwidth"].replace(str(lower_bandwidth_value), "") @@ -251,22 +255,22 @@ def extract_number(value): upper_delay_value = extract_number(upper_cond["delay"]) delay_unit = lower_cond["delay"].replace(str(lower_delay_value), "") - # Calculate progress in the leg (0 to 1) + # Calcular el progreso en el tramo (0 a 1) progress = (distance - lower_bound) / (upper_bound - lower_bound) if self._verbose: logging.info(f"Progress between the bounds: {progress}") - # Linear interpolation of values + # InterpolaciΓ³n lineal de valores bandwidth_value = lower_bandwidth_value - progress * (lower_bandwidth_value - upper_bandwidth_value) delay_value = lower_delay_value + progress * (upper_delay_value - lower_delay_value) - # Reconstruct values with original units + # Reconstruir valores con unidades originales bandwidth = f"{round(bandwidth_value, 2)}{lower_bandwidth_unit}" delay = f"{round(delay_value, 2)}{delay_unit}" conditions = {"bandwidth": bandwidth, "delay": delay} - # If the distance is infinite, return the last value + # Si la distancia es infinita, devolver el ΓΊltimo valor if not conditions: conditions = {"bandwidth": th[float("inf")]["bandwidth"], "delay": th[float("inf")]["delay"]} if self._verbose: diff --git a/nebula/addons/networksimulation/networksimulator.py b/nebula/addons/networksimulation/networksimulator.py index 16bd22409..c7d70b7cb 100644 --- a/nebula/addons/networksimulation/networksimulator.py +++ b/nebula/addons/networksimulation/networksimulator.py @@ -27,9 
+27,7 @@ class NetworkSimulatorException(Exception): pass -def factory_network_simulator( - net_sim, communication_manager, changing_interval, interface, verbose -) -> NetworkSimulator: +def factory_network_simulator(net_sim, changing_interval, interface, verbose) -> NetworkSimulator: from nebula.addons.networksimulation.nebulanetworksimulator import NebulaNS SIMULATION_SERVICES = { @@ -39,6 +37,6 @@ def factory_network_simulator( net_serv = SIMULATION_SERVICES.get(net_sim, NebulaNS) if net_serv: - return net_serv(communication_manager, changing_interval, interface, verbose) + return net_serv(changing_interval, interface, verbose) else: raise NetworkSimulatorException(f"Network Simulator {net_sim} not found") diff --git a/nebula/addons/reporter.py b/nebula/addons/reporter.py index cc6215b79..0a8e1426c 100755 --- a/nebula/addons/reporter.py +++ b/nebula/addons/reporter.py @@ -10,11 +10,11 @@ import psutil if TYPE_CHECKING: - from nebula.core.network.communications import CommunicationsManager + pass class Reporter: - def __init__(self, config, trainer, cm: "CommunicationsManager"): + def __init__(self, config, trainer): """ Initializes the reporter module for sending periodic updates to a dashboard controller. @@ -48,13 +48,13 @@ def __init__(self, config, trainer, cm: "CommunicationsManager"): - Initializes both current and accumulated metrics for traffic monitoring. 
""" logging.info("Starting reporter module") + self._cm = None self.config = config self.trainer = trainer - self.cm = cm self.frequency = self.config.participant["reporter_args"]["report_frequency"] self.grace_time = self.config.participant["reporter_args"]["grace_time_reporter"] self.data_queue = asyncio.Queue() - self.url = f"http://{self.config.participant['scenario_args']['controller']}/platform/dashboard/{self.config.participant['scenario_args']['name']}/node/update" + self.url = f"http://{self.config.participant['scenario_args']['controller']}/nodes/{self.config.participant['scenario_args']['name']}/update" self.counter = 0 self.first_net_metrics = True @@ -68,6 +68,16 @@ def __init__(self, config, trainer, cm: "CommunicationsManager"): self.acc_packets_sent = 0 self.acc_packets_recv = 0 + @property + def cm(self): + if not self._cm: + from nebula.core.network.communications import CommunicationsManager + + self._cm = CommunicationsManager.get_instance() + return self._cm + else: + return self._cm + async def enqueue_data(self, name, value): """ Asynchronously enqueues data for reporting. @@ -157,7 +167,7 @@ async def report_scenario_finished(self): might be temporarily overloaded. - Logs exceptions if the connection attempt to the controller fails. 
""" - url = f"http://{self.config.participant['scenario_args']['controller']}/platform/dashboard/{self.config.participant['scenario_args']['name']}/node/done" + url = f"http://{self.config.participant['scenario_args']['controller']}/nodes/{self.config.participant['scenario_args']['name']}/done" data = json.dumps({"idx": self.config.participant["device_args"]["idx"]}) headers = { "Content-Type": "application/json", diff --git a/nebula/addons/reputation/reputation.py b/nebula/addons/reputation/reputation.py index 38b86fcfc..58af42c04 100644 --- a/nebula/addons/reputation/reputation.py +++ b/nebula/addons/reputation/reputation.py @@ -1,17 +1,15 @@ -import csv import logging -import os import random -import torch -import numpy as np import time +from datetime import datetime +from typing import TYPE_CHECKING + import numpy as np +import torch -from typing import TYPE_CHECKING from nebula.addons.functions import print_msg_box -from nebula.core.nebulaevents import RoundStartEvent, UpdateReceivedEvent, MessageEvent, AggregationEvent from nebula.core.eventmanager import EventManager -from datetime import datetime +from nebula.core.nebulaevents import AggregationEvent, RoundStartEvent, UpdateReceivedEvent from nebula.core.utils.helper import ( cosine_metric, euclidean_metric, @@ -22,8 +20,9 @@ ) if TYPE_CHECKING: - from nebula.core.engine import Engine from nebula.config.config import Config + from nebula.core.engine import Engine + class Metrics: def __init__( @@ -47,15 +46,11 @@ def __init__( self.fraction_of_params_changed = { "fraction_changed": fraction_changed, "threshold": threshold, - "round": num_round - } - - self.model_arrival_latency = { - "latency": latency, "round": num_round, - "round_received": current_round } + self.model_arrival_latency = {"latency": latency, "round": num_round, "round_received": current_round} + self.messages = [] self.similarity = [] @@ -64,10 +59,11 @@ def __init__( class Reputation: """ Class to define and manage the reputation of a 
participant in the network. - + The class handles collection of metrics, calculation of static and dynamic reputation, updating history, and communication of reputation scores to neighbors. """ + def __init__(self, engine: "Engine", config: "Config"): """ Initialize the Reputation system. @@ -102,21 +98,25 @@ def __init__(self, engine: "Engine", config: "Config"): self._log_dir = engine.log_dir self._idx = engine.idx self.connection_metrics = [] - + neighbors: str = self._config.participant["network_args"]["neighbors"] self.connection_metrics = {} for nei in neighbors.split(): self.connection_metrics[f"{nei}"] = Metrics() - + self._with_reputation = self._config.participant["defense_args"]["with_reputation"] self._reputation_metrics = self._config.participant["defense_args"]["reputation_metrics"] self._initial_reputation = float(self._config.participant["defense_args"]["initial_reputation"]) self._weighting_factor = self._config.participant["defense_args"]["weighting_factor"] - self._weight_model_arrival_latency = float(self._config.participant["defense_args"]["weight_model_arrival_latency"]) + self._weight_model_arrival_latency = float( + self._config.participant["defense_args"]["weight_model_arrival_latency"] + ) self._weight_model_similarity = float(self._config.participant["defense_args"]["weight_model_similarity"]) self._weight_num_messages = float(self._config.participant["defense_args"]["weight_num_messages"]) - self._weight_fraction_params_changed = float(self._config.participant["defense_args"]["weight_fraction_params_changed"]) - + self._weight_fraction_params_changed = float( + self._config.participant["defense_args"]["weight_fraction_params_changed"] + ) + msg = f"Reputation system: {self._with_reputation}" msg += f"\nReputation metrics: {self._reputation_metrics}" msg += f"\nInitial reputation: {self._initial_reputation}" @@ -127,12 +127,12 @@ def __init__(self, engine: "Engine", config: "Config"): msg += f"\nWeight number of messages: 
{self._weight_num_messages}" msg += f"\nWeight fraction of parameters changed: {self._weight_fraction_params_changed}" print_msg_box(msg=msg, indent=2, title="Defense information") - + @property def engine(self): """Return the engine instance.""" return self._engine - + def save_data( self, type_data, @@ -195,17 +195,19 @@ def save_data( self.connection_metrics[nei].messages = [] self.connection_metrics[nei].messages.append(combined_data["number_message"]) elif type_data == "fraction_of_params_changed": - self.connection_metrics[nei].fraction_of_params_changed.update(combined_data["fraction_of_params_changed"]) + self.connection_metrics[nei].fraction_of_params_changed.update( + combined_data["fraction_of_params_changed"] + ) elif type_data == "model_arrival_latency": self.connection_metrics[nei].model_arrival_latency.update(combined_data["model_arrival_latency"]) except Exception: logging.exception("Error saving data") - + async def setup(self): """ Setup the reputation system by subscribing to various events. - + This function enables the reputation system and subscribes to events based on active metrics. 
""" if self._with_reputation: @@ -215,17 +217,25 @@ async def setup(self): if self._reputation_metrics.get("model_similarity", False): await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self.recollect_similarity) if self._reputation_metrics.get("fraction_parameters_changed", False): - await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self.recollect_fraction_of_parameters_changed) + await EventManager.get_instance().subscribe_node_event( + UpdateReceivedEvent, self.recollect_fraction_of_parameters_changed + ) if self._reputation_metrics.get("num_messages", False): await EventManager.get_instance().subscribe(("model", "update"), self.recollect_number_message) await EventManager.get_instance().subscribe(("model", "initialization"), self.recollect_number_message) await EventManager.get_instance().subscribe(("control", "alive"), self.recollect_number_message) - await EventManager.get_instance().subscribe(("federation", "federation_models_included"), self.recollect_number_message) + await EventManager.get_instance().subscribe( + ("federation", "federation_models_included"), self.recollect_number_message + ) await EventManager.get_instance().subscribe(("reputation", "share"), self.recollect_number_message) if self._reputation_metrics.get("model_arrival_latency", False): - await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self.recollect_model_arrival_latency) - - def init_reputation(self, addr, federation_nodes=None, round_num=None, last_feedback_round=None, init_reputation=None): + await EventManager.get_instance().subscribe_node_event( + UpdateReceivedEvent, self.recollect_model_arrival_latency + ) + + def init_reputation( + self, addr, federation_nodes=None, round_num=None, last_feedback_round=None, init_reputation=None + ): """ Initialize the reputation for each federation node. @@ -272,13 +282,24 @@ def is_valid_ip(self, federation_nodes): list: A list of valid IP addresses. 
""" valid_ip = [] - for i in federation_nodes: + for i in federation_nodes: valid_ip.append(i) return valid_ip - def _calculate_static_reputation(self, addr, nei, metric_messages_number, metric_similarity, metric_fraction, metric_model_arrival_latency, - weight_messages_number, weight_similarity, weight_fraction, weight_model_arrival_latency): + def _calculate_static_reputation( + self, + addr, + nei, + metric_messages_number, + metric_similarity, + metric_fraction, + metric_model_arrival_latency, + weight_messages_number, + weight_similarity, + weight_fraction, + weight_model_arrival_latency, + ): """ Calculate the static reputation of a participant using fixed weights. @@ -343,7 +364,8 @@ async def _calculate_dynamic_reputation(self, addr, neighbors): for metric_name in self.history_data.keys(): if self._reputation_metrics.get(metric_name, False): valid_entries = [ - entry for entry in self.history_data[metric_name] + entry + for entry in self.history_data[metric_name] if entry["round"] >= self._engine.get_round() and entry.get("weight") not in [None, -1] ] @@ -358,16 +380,21 @@ async def _calculate_dynamic_reputation(self, addr, neighbors): for metric_name in self.history_data.keys(): if self._reputation_metrics.get(metric_name, False): for entry in self.history_data.get(metric_name, []): - if entry["round"] == self._engine.get_round() and entry["metric_name"] == metric_name and entry["nei"] == nei: + if ( + entry["round"] == self._engine.get_round() + and entry["metric_name"] == metric_name + and entry["nei"] == nei + ): metric_values[metric_name] = entry["metric_value"] break if all(metric_name in metric_values for metric_name in average_weights): reputation_with_weights = sum( - metric_values.get(metric_name, 0) * average_weights[metric_name] - for metric_name in average_weights + metric_values.get(metric_name, 0) * average_weights[metric_name] for metric_name in average_weights + ) + logging.info( + f"Dynamic reputation with weights for {nei} at round 
{self.engine.get_round()}: {reputation_with_weights}" ) - logging.info(f"Dynamic reputation with weights for {nei} at round {self.engine.get_round()}: {reputation_with_weights}") avg_reputation = self.save_reputation_history_in_memory(self.engine.addr, nei, reputation_with_weights) @@ -406,7 +433,7 @@ def _update_reputation_record(self, nei, reputation, data): if self.reputation[nei]["reputation"] < 0.6: self.rejected_nodes.add(nei) logging.info(f"Rejected node {nei} at round {self._engine.get_round()}") - + def calculate_weighted_values( self, avg_messages_number_message_normalized, @@ -417,7 +444,7 @@ def calculate_weighted_values( current_round, addr, nei, - reputation_metrics + reputation_metrics, ): """ Calculate the weighted values for each metric based on current measurements and historical data. @@ -434,7 +461,6 @@ def calculate_weighted_values( reputation_metrics (dict): Dictionary indicating which metrics are active. """ if current_round is not None: - normalized_weights = {} required_keys = [ "num_messages", @@ -451,7 +477,7 @@ def calculate_weighted_values( "num_messages": avg_messages_number_message_normalized, "model_similarity": similarity_reputation, "fraction_parameters_changed": fraction_score_asign, - "model_arrival_latency": avg_model_arrival_latency, + "model_arrival_latency": avg_model_arrival_latency, } active_metrics = {k: v for k, v in metrics.items() if reputation_metrics.get(k, False)} @@ -464,7 +490,7 @@ def calculate_weighted_values( "nei": nei, "metric_name": metric_name, "metric_value": current_value, - "weight": None + "weight": None, }) adjusted_weights = {} @@ -474,7 +500,11 @@ def calculate_weighted_values( for metric_name, current_value in active_metrics.items(): historical_values = history_data[metric_name] - metric_values = [entry['metric_value'] for entry in historical_values if 'metric_value' in entry and entry["metric_value"] != 0] + metric_values = [ + entry["metric_value"] + for entry in historical_values + if 
"metric_value" in entry and entry["metric_value"] != 0 + ] if metric_values: mean_value = np.mean(metric_values) @@ -487,7 +517,10 @@ def calculate_weighted_values( if all(deviation == 0.0 for deviation in desviations.values()): random_weights = [random.random() for _ in range(num_active_metrics)] total_random_weight = sum(random_weights) - normalized_weights = {metric_name: weight / total_random_weight for metric_name, weight in zip(active_metrics, random_weights)} + normalized_weights = { + metric_name: weight / total_random_weight + for metric_name, weight in zip(active_metrics, random_weights, strict=False) + } else: max_desviation = max(desviations.values()) if desviations else 1 normalized_weights = { @@ -503,7 +536,7 @@ def calculate_weighted_values( normalized_weights = {metric_name: 1 / num_active_metrics for metric_name in active_metrics} mean_deviation = np.mean(list(desviations.values())) - dynamic_min_weight = max(0.1, mean_deviation / (mean_deviation + 1)) + dynamic_min_weight = max(0.1, mean_deviation / (mean_deviation + 1)) total_adjusted_weight = 0 @@ -562,10 +595,17 @@ async def calculate_value_metrics(self, log_dir, id_node, addr, nei, metrics_act metrics_instance = self.connection_metrics.get(nei) if not metrics_instance: logging.warning(f"No metrics found for neighbor {nei}") - return avg_messages_number_message_normalized, similarity_reputation, fraction_score_asign, avg_model_arrival_latency + return ( + avg_messages_number_message_normalized, + similarity_reputation, + fraction_score_asign, + avg_model_arrival_latency, + ) if metrics_active.get("num_messages", False): - filtered_messages = [msg for msg in metrics_instance.messages if msg.get("current_round") == current_round] + filtered_messages = [ + msg for msg in metrics_instance.messages if msg.get("current_round") == current_round + ] for msg in filtered_messages: self.messages_number_message.append({ "number_message": msg.get("time"), @@ -580,7 +620,9 @@ async def 
calculate_value_metrics(self, log_dir, id_node, addr, nei, metrics_act addr, nei, messages_number_message_normalized, current_round ) if avg_messages_number_message_normalized is None and current_round > 4: - avg_messages_number_message_normalized = self.number_message_history[(addr, nei)][current_round - 1]["avg_number_message"] + avg_messages_number_message_normalized = self.number_message_history[(addr, nei)][ + current_round - 1 + ]["avg_number_message"] if metrics_active.get("fraction_parameters_changed", False): if metrics_instance.fraction_of_params_changed.get("round") == current_round: @@ -600,11 +642,7 @@ async def calculate_value_metrics(self, log_dir, id_node, addr, nei, metrics_act round_latency = metrics_instance.model_arrival_latency.get("round") latency = metrics_instance.model_arrival_latency.get("latency") messages_model_arrival_latency_normalized = self.manage_model_arrival_latency( - round_latency, - addr, - nei, - latency, - current_round + round_latency, addr, nei, latency, current_round ) if current_round >= 5 and metrics_active.get("model_similarity", False): @@ -617,7 +655,9 @@ async def calculate_value_metrics(self, log_dir, id_node, addr, nei, metrics_act addr, nei, messages_model_arrival_latency_normalized, current_round ) if avg_model_arrival_latency is None and current_round > 4: - avg_model_arrival_latency = self.model_arrival_latency_history[(addr, nei)][current_round - 1]["score"] + avg_model_arrival_latency = self.model_arrival_latency_history[(addr, nei)][current_round - 1][ + "score" + ] if self.messages_number_message is not None: messages_number_message_normalized, messages_number_message_count = self.manage_metric_number_message( @@ -627,7 +667,9 @@ async def calculate_value_metrics(self, log_dir, id_node, addr, nei, metrics_act addr, nei, messages_number_message_normalized, current_round ) if avg_messages_number_message_normalized is None and current_round > 4: - avg_messages_number_message_normalized = 
self.number_message_history[(addr, nei)][current_round - 1]["avg_number_message"] + avg_messages_number_message_normalized = self.number_message_history[(addr, nei)][ + current_round - 1 + ]["avg_number_message"] if current_round >= 5: if fraction_score_normalized > 0: @@ -640,10 +682,14 @@ async def calculate_value_metrics(self, log_dir, id_node, addr, nei, metrics_act if fraction_previous_round is not None: fraction_score_asign = fraction_score_normalized * 0.8 + fraction_previous_round * 0.2 - self.fraction_changed_history[(addr, nei, current_round)]["fraction_score"] = fraction_score_asign + self.fraction_changed_history[(addr, nei, current_round)]["fraction_score"] = ( + fraction_score_asign + ) else: fraction_score_asign = fraction_score_normalized - self.fraction_changed_history[(addr, nei, current_round)]["fraction_score"] = fraction_score_asign + self.fraction_changed_history[(addr, nei, current_round)]["fraction_score"] = ( + fraction_score_asign + ) else: fraction_previous_round = None key_previous_round = (addr, nei, current_round - 1) if current_round - 1 > 0 else None @@ -665,7 +711,7 @@ async def calculate_value_metrics(self, log_dir, id_node, addr, nei, metrics_act if fraction_neighbors_scores: fraction_score_asign = np.mean(list(fraction_neighbors_scores.values())) else: - fraction_score_asign = 0 + fraction_score_asign = 0 else: fraction_score_asign = 0 @@ -681,7 +727,12 @@ async def calculate_value_metrics(self, log_dir, id_node, addr, nei, metrics_act self.engine.total_rounds, ) - return avg_messages_number_message_normalized, similarity_reputation, fraction_score_asign, avg_model_arrival_latency + return ( + avg_messages_number_message_normalized, + similarity_reputation, + fraction_score_asign, + avg_model_arrival_latency, + ) except Exception as e: logging.exception(f"Error calculating reputation. 
Type: {type(e).__name__}") @@ -714,7 +765,9 @@ def create_graphics_to_metrics( """ if current_round is not None and current_round < total_rounds: model_arrival_latency_dict = {f"R-Model_arrival_latency_reputation/{addr}": {nei: model_arrival_latency}} - messages_number_message_count_dict = {f"R-Count_messages_number_message_reputation/{addr}": {nei: number_message_count}} + messages_number_message_count_dict = { + f"R-Count_messages_number_message_reputation/{addr}": {nei: number_message_count} + } messages_number_message_norm_dict = {f"R-number_message_reputation/{addr}": {nei: number_message_norm}} similarity_dict = {f"R-Similarity_reputation/{addr}": {nei: similarity}} fraction_dict = {f"R-Fraction_reputation/{addr}": {nei: fraction}} @@ -815,9 +868,7 @@ def analyze_anomalies( for i in range(0, round_num + 1): potential_prev_key = (addr, nei, round_num - i) if potential_prev_key in self.fraction_changed_history: - mean_fraction_prev = self.fraction_changed_history[potential_prev_key][ - "mean_fraction" - ] + mean_fraction_prev = self.fraction_changed_history[potential_prev_key]["mean_fraction"] if mean_fraction_prev is not None: prev_key = potential_prev_key break @@ -840,8 +891,16 @@ def analyze_anomalies( self.fraction_changed_history[key]["fraction_anomaly"] = fraction_anomaly self.fraction_changed_history[key]["threshold_anomaly"] = threshold_anomaly - penalization_factor_fraction = abs(current_fraction - mean_fraction_prev) / mean_fraction_prev if mean_fraction_prev != 0 else 1 - penalization_factor_threshold = abs(current_threshold - mean_threshold_prev) / mean_threshold_prev if mean_threshold_prev != 0 else 1 + penalization_factor_fraction = ( + abs(current_fraction - mean_fraction_prev) / mean_fraction_prev + if mean_fraction_prev != 0 + else 1 + ) + penalization_factor_threshold = ( + abs(current_threshold - mean_threshold_prev) / mean_threshold_prev + if mean_threshold_prev != 0 + else 1 + ) k_fraction = penalization_factor_fraction if 
penalization_factor_fraction != 0 else 1 k_threshold = penalization_factor_threshold if penalization_factor_threshold != 0 else 1 @@ -871,16 +930,20 @@ def analyze_anomalies( if current_threshold is not None and mean_threshold_prev is not None else 0 ) - + fraction_weight = 0.5 threshold_weight = 0.5 fraction_score = fraction_weight * fraction_value + threshold_weight * threshold_value self.fraction_changed_history[key]["mean_fraction"] = (current_fraction + mean_fraction_prev) / 2 - self.fraction_changed_history[key]["std_dev_fraction"] = np.sqrt(((current_fraction - mean_fraction_prev) ** 2 + std_dev_fraction_prev**2) / 2) + self.fraction_changed_history[key]["std_dev_fraction"] = np.sqrt( + ((current_fraction - mean_fraction_prev) ** 2 + std_dev_fraction_prev**2) / 2 + ) self.fraction_changed_history[key]["mean_threshold"] = (current_threshold + mean_threshold_prev) / 2 - self.fraction_changed_history[key]["std_dev_threshold"] = np.sqrt(((0.1 * (current_threshold - mean_threshold_prev) ** 2) + std_dev_threshold_prev**2) / 2) + self.fraction_changed_history[key]["std_dev_threshold"] = np.sqrt( + ((0.1 * (current_threshold - mean_threshold_prev) ** 2) + std_dev_threshold_prev**2) / 2 + ) return max(fraction_score, 0) else: @@ -888,10 +951,8 @@ def analyze_anomalies( except Exception: logging.exception("Error analyzing anomalies") return -1 - - def manage_model_arrival_latency( - self, round_num, addr, nei, latency, current_round - ): + + def manage_model_arrival_latency(self, round_num, addr, nei, latency, current_round): """ Manage the model arrival latency metric and normalize it based on historical latencies. 
@@ -1021,7 +1082,7 @@ def save_model_arrival_latency_history(self, addr, nei, model_arrival_latency, r return avg_model_arrival_latency except Exception: logging.exception("Error saving model_arrival_latency history") - + def manage_metric_number_message(self, messages_number_message, addr, nei, current_round, metric_active=True): """ Manage and normalize the number of messages metric using percentiles. @@ -1042,7 +1103,7 @@ def manage_metric_number_message(self, messages_number_message, addr, nei, curre if not metric_active: return 0.0, 0 - + previous_round = current_round current_addr_nei = (addr, nei) @@ -1093,7 +1154,7 @@ def manage_metric_number_message(self, messages_number_message, addr, nei, curre except Exception: logging.exception("Error managing metric number_message") return 0.0, 0 - + def save_number_message_history(self, addr, nei, messages_number_message_normalized, current_round): """ Save the normalized number_message history in memory and calculate a weighted average. @@ -1137,7 +1198,7 @@ def save_number_message_history(self, addr, nei, messages_number_message_normali except Exception as e: logging.exception(f"Error managing model_arrival_latency latency: {e}") return 0.0 - + def save_reputation_history_in_memory(self, addr, nei, reputation): """ Save the reputation history for a neighbor and compute an average reputation. 
@@ -1236,10 +1297,10 @@ def calculate_similarity_from_metrics(self, nei, current_round): pearson_correlation = float(metric.get("pearson_correlation", 0)) similarity_value = ( - weight_cosine * cosine + - weight_euclidean * euclidean + - weight_manhattan * manhattan + - weight_pearson * pearson_correlation + weight_cosine * cosine + + weight_euclidean * euclidean + + weight_manhattan * manhattan + + weight_pearson * pearson_correlation ) return similarity_value @@ -1264,16 +1325,19 @@ async def calculate_reputation(self, ae: AggregationEvent): history_data = self.history_data for nei in neighbors: - metric_messages_number, metric_similarity, metric_fraction, metric_model_arrival_latency = ( - await self.calculate_value_metrics( - self._log_dir, - self._idx, - self._addr, - nei, - metrics_active=self._reputation_metrics, - ) + ( + metric_messages_number, + metric_similarity, + metric_fraction, + metric_model_arrival_latency, + ) = await self.calculate_value_metrics( + self._log_dir, + self._idx, + self._addr, + nei, + metrics_active=self._reputation_metrics, ) - + if self._weighting_factor == "dynamic": self.calculate_weighted_values( metric_messages_number, @@ -1300,7 +1364,7 @@ async def calculate_reputation(self, ae: AggregationEvent): self._weight_fraction_params_changed, self._weight_model_arrival_latency, ) - + if self._weighting_factor == "dynamic" and self._engine.get_round() >= 5: await self._calculate_dynamic_reputation(self._addr, neighbors) @@ -1342,13 +1406,17 @@ async def send_reputation_to_neighbors(self, neighbors): for neighbor in neighbors_to_send: message = self._engine.cm.create_message( - "reputation", "share", node_id=nei, score=float(data["reputation"]), round=self._engine.get_round() + "reputation", + "share", + node_id=nei, + score=float(data["reputation"]), + round=self._engine.get_round(), ) await self._engine.cm.send_message(neighbor, message) logging.info( f"Sending reputation to node {nei} from node {neighbor} with reputation 
{data['reputation']}" ) - + def create_graphic_reputation(self, addr, round_num): """ Create a graphical representation of the reputation scores and log the data. @@ -1371,7 +1439,7 @@ def create_graphic_reputation(self, addr, round_num): except Exception: logging.exception("Error creating reputation graphic") - + async def update_process_aggregation(self, updates): """ Update the aggregation process by removing nodes that have been rejected. @@ -1402,7 +1470,7 @@ async def include_feedback_in_reputation(self): if self.reputation_with_all_feedback is None: logging.info("No feedback received.") return False - + updated = False for (current_node, node_ip, round_num), scores in self.reputation_with_all_feedback.items(): @@ -1414,7 +1482,10 @@ async def include_feedback_in_reputation(self): logging.info(f"No reputation for node {node_ip}") continue - if "last_feedback_round" in self.reputation[node_ip] and self.reputation[node_ip]["last_feedback_round"] >= round_num: + if ( + "last_feedback_round" in self.reputation[node_ip] + and self.reputation[node_ip]["last_feedback_round"] >= round_num + ): continue avg_feedback = sum(scores) / len(scores) @@ -1440,7 +1511,7 @@ async def include_feedback_in_reputation(self): return True else: return False - + async def on_round_start(self, rse: RoundStartEvent): """ Event handler for the start of a round. It stores the start time and updates the expected nodes. 
@@ -1467,7 +1538,7 @@ async def recollect_model_arrival_latency(self, ure: UpdateReceivedEvent): if current_round not in self.round_timing_info: self.round_timing_info[current_round] = {} - + if "model_received_time" not in self.round_timing_info[current_round]: self.round_timing_info[current_round]["model_received_time"] = {} diff --git a/nebula/config/config.py b/nebula/config/config.py index 20df01016..4b0b8df3a 100755 --- a/nebula/config/config.py +++ b/nebula/config/config.py @@ -184,6 +184,9 @@ def add_neighbor_from_config(self, addr): self.participant["network_args"]["neighbors"] += " " + addr self.participant["mobility_args"]["neighbors_distance"][addr] = None + def update_nodes_distance(self, distances: dict): + self.participant["mobility_args"]["neighbors_distance"] = {node: dist for node, (dist, _) in distances.items()} + def update_neighbors_from_config(self, current_connections, dest_addr): final_neighbors = [] for n in current_connections: diff --git a/nebula/controller.py b/nebula/controller.py deleted file mode 100755 index 5c266b56c..000000000 --- a/nebula/controller.py +++ /dev/null @@ -1,738 +0,0 @@ -import asyncio -import importlib -import json -import logging -import os -import re -import signal -import subprocess -import sys -import threading -import time - -import docker -import psutil -import uvicorn -from dotenv import load_dotenv -from fastapi import FastAPI -from watchdog.events import PatternMatchingEventHandler -from watchdog.observers import Observer - -from nebula.addons.env import check_environment -from nebula.config.config import Config -from nebula.config.mender import Mender -from nebula.scenarios import ScenarioManagement -from nebula.tests import main as deploy_tests -from nebula.utils import DockerUtils, SocketUtils - - -# Setup controller logger -class TermEscapeCodeFormatter(logging.Formatter): - def __init__(self, fmt=None, datefmt=None, style="%", validate=True): - super().__init__(fmt, datefmt, style, validate) - - def 
format(self, record): - escape_re = re.compile(r"\x1b\[[0-9;]*m") - record.msg = re.sub(escape_re, "", str(record.msg)) - return super().format(record) - - -# Initialize FastAPI app outside the Controller class -app = FastAPI() - - -# Define endpoints outside the Controller class -@app.get("/") -async def read_root(): - return {"message": "Welcome to the NEBULA Controller API"} - - -@app.get("/status") -async def get_status(): - return {"status": "NEBULA Controller API is running"} - - -@app.get("/resources") -async def get_resources(): - devices = 0 - gpu_memory_percent = [] - - # Obtain available RAM - memory_info = await asyncio.to_thread(psutil.virtual_memory) - - if importlib.util.find_spec("pynvml") is not None: - try: - import pynvml - - await asyncio.to_thread(pynvml.nvmlInit) - devices = await asyncio.to_thread(pynvml.nvmlDeviceGetCount) - - # Obtain GPU info - for i in range(devices): - handle = await asyncio.to_thread(pynvml.nvmlDeviceGetHandleByIndex, i) - memory_info_gpu = await asyncio.to_thread(pynvml.nvmlDeviceGetMemoryInfo, handle) - memory_used_percent = (memory_info_gpu.used / memory_info_gpu.total) * 100 - gpu_memory_percent.append(memory_used_percent) - - except Exception: # noqa: S110 - pass - - return { - # "cpu_percent": psutil.cpu_percent(), - "gpus": devices, - "memory_percent": memory_info.percent, - "gpu_memory_percent": gpu_memory_percent, - } - - -@app.get("/least_memory_gpu") -async def get_least_memory_gpu(): - gpu_with_least_memory_index = None - - if importlib.util.find_spec("pynvml") is not None: - max_memory_used_percent = 50 - try: - import pynvml - - await asyncio.to_thread(pynvml.nvmlInit) - devices = await asyncio.to_thread(pynvml.nvmlDeviceGetCount) - - # Obtain GPU info - for i in range(devices): - handle = await asyncio.to_thread(pynvml.nvmlDeviceGetHandleByIndex, i) - memory_info = await asyncio.to_thread(pynvml.nvmlDeviceGetMemoryInfo, handle) - memory_used_percent = (memory_info.used / memory_info.total) * 100 - - # 
Obtain GPU with less memory available - if memory_used_percent > max_memory_used_percent: - max_memory_used_percent = memory_used_percent - gpu_with_least_memory_index = i - - except Exception: # noqa: S110 - pass - - return { - "gpu_with_least_memory_index": gpu_with_least_memory_index, - } - - -@app.get("/available_gpus/") -async def get_available_gpu(): - available_gpus = [] - - if importlib.util.find_spec("pynvml") is not None: - try: - import pynvml - - await asyncio.to_thread(pynvml.nvmlInit) - devices = await asyncio.to_thread(pynvml.nvmlDeviceGetCount) - - # Obtain GPU info - for i in range(devices): - handle = await asyncio.to_thread(pynvml.nvmlDeviceGetHandleByIndex, i) - memory_info = await asyncio.to_thread(pynvml.nvmlDeviceGetMemoryInfo, handle) - memory_used_percent = (memory_info.used / memory_info.total) * 100 - - # Obtain available GPUs - if memory_used_percent < 5: - available_gpus.append(i) - - return { - "available_gpus": available_gpus, - } - except Exception: # noqa: S110 - pass - - -class NebulaEventHandler(PatternMatchingEventHandler): - """ - NebulaEventHandler handles file system events for .sh scripts. - - This class monitors the creation, modification, and deletion of .sh scripts - in a specified directory. 
- """ - - patterns = ["*.sh", "*.ps1"] - - def __init__(self): - super(NebulaEventHandler, self).__init__() - self.last_processed = {} - self.timeout_ns = 5 * 1e9 - self.processing_files = set() - self.lock = threading.Lock() - - def _should_process_event(self, src_path: str) -> bool: - current_time_ns = time.time_ns() - logging.info(f"Current time (ns): {current_time_ns}") - with self.lock: - if src_path in self.last_processed: - logging.info(f"Last processed time for {src_path}: {self.last_processed[src_path]}") - last_time = self.last_processed[src_path] - if current_time_ns - last_time < self.timeout_ns: - return False - self.last_processed[src_path] = current_time_ns - return True - - def _is_being_processed(self, src_path: str) -> bool: - with self.lock: - if src_path in self.processing_files: - logging.info(f"Skipping {src_path} as it is already being processed.") - return True - self.processing_files.add(src_path) - return False - - def _processing_done(self, src_path: str): - with self.lock: - if src_path in self.processing_files: - self.processing_files.remove(src_path) - - def verify_nodes_ports(self, src_path): - parent_dir = os.path.dirname(src_path) - base_dir = os.path.basename(parent_dir) - scenario_path = os.path.join(os.path.dirname(parent_dir), base_dir) - - try: - port_mapping = {} - new_port_start = 50000 - - participant_files = sorted( - f for f in os.listdir(scenario_path) if f.endswith(".json") and f.startswith("participant") - ) - - for filename in participant_files: - file_path = os.path.join(scenario_path, filename) - with open(file_path) as json_file: - node = json.load(json_file) - current_port = node["network_args"]["port"] - port_mapping[current_port] = SocketUtils.find_free_port(start_port=new_port_start) - logging.info( - f"Participant file: {filename} | Current port: {current_port} | New port: {port_mapping[current_port]}" - ) - new_port_start = port_mapping[current_port] + 1 - - for filename in participant_files: - file_path = 
os.path.join(scenario_path, filename) - with open(file_path) as json_file: - node = json.load(json_file) - current_port = node["network_args"]["port"] - node["network_args"]["port"] = port_mapping[current_port] - neighbors = node["network_args"]["neighbors"] - - for old_port, new_port in port_mapping.items(): - neighbors = neighbors.replace(f":{old_port}", f":{new_port}") - - node["network_args"]["neighbors"] = neighbors - - with open(file_path, "w") as f: - json.dump(node, f, indent=4) - - except Exception as e: - print(f"Error processing JSON files: {e}") - - def on_created(self, event): - """ - Handles the event when a file is created. - """ - if event.is_directory: - return - src_path = event.src_path - if not self._should_process_event(src_path): - return - if self._is_being_processed(src_path): - return - logging.info("File created: %s" % src_path) - try: - self.verify_nodes_ports(src_path) - self.run_script(src_path) - finally: - self._processing_done(src_path) - - def on_deleted(self, event): - """ - Handles the event when a file is deleted. 
- """ - if event.is_directory: - return - src_path = event.src_path - if not self._should_process_event(src_path): - return - if self._is_being_processed(src_path): - return - logging.info("File deleted: %s" % src_path) - directory_script = os.path.dirname(src_path) - pids_file = os.path.join(directory_script, "current_scenario_pids.txt") - logging.info(f"Killing processes from {pids_file}") - try: - self.kill_script_processes(pids_file) - os.remove(pids_file) - except FileNotFoundError: - logging.warning(f"{pids_file} not found.") - except Exception as e: - logging.exception(f"Error while killing processes: {e}") - finally: - self._processing_done(src_path) - - def run_script(self, script): - try: - logging.info(f"Running script: {script}") - if script.endswith(".sh"): - result = subprocess.run(["bash", script], capture_output=True, text=True) - logging.info(f"Script output:\n{result.stdout}") - if result.stderr: - logging.error(f"Script error:\n{result.stderr}") - elif script.endswith(".ps1"): - subprocess.Popen( - ["powershell", "-ExecutionPolicy", "Bypass", "-File", script], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=False, - ) - else: - logging.error("Unsupported script format.") - return - except Exception as e: - logging.exception(f"Error while running script: {e}") - - def kill_script_processes(self, pids_file): - try: - with open(pids_file) as f: - pids = f.readlines() - for pid in pids: - try: - pid = int(pid.strip()) - if psutil.pid_exists(pid): - process = psutil.Process(pid) - children = process.children(recursive=True) - logging.info(f"Forcibly killing process {pid} and {len(children)} child processes...") - for child in children: - try: - logging.info(f"Forcibly killing child process {child.pid}") - child.kill() - except psutil.NoSuchProcess: - logging.warning(f"Child process {child.pid} already terminated.") - except Exception as e: - logging.exception(f"Error while forcibly killing child process {child.pid}: {e}") - try: - 
logging.info(f"Forcibly killing main process {pid}") - process.kill() - except psutil.NoSuchProcess: - logging.warning(f"Process {pid} already terminated.") - except Exception as e: - logging.exception(f"Error while forcibly killing main process {pid}: {e}") - else: - logging.warning(f"PID {pid} does not exist.") - except ValueError: - logging.exception(f"Invalid PID value in file: {pid}") - except Exception as e: - logging.exception(f"Error while forcibly killing process {pid}: {e}") - except FileNotFoundError: - logging.exception(f"PID file not found: {pids_file}") - except Exception as e: - logging.exception(f"Error while reading PIDs from file: {e}") - - -class Controller: - def __init__(self, args): - self.scenario_name = args.scenario_name if hasattr(args, "scenario_name") else None - self.start_date_scenario = None - self.federation = args.federation if hasattr(args, "federation") else None - self.topology = args.topology if hasattr(args, "topology") else None - self.controller_port = int(args.controllerport) if hasattr(args, "controllerport") else 5000 - self.waf_port = int(args.wafport) if hasattr(args, "wafport") else 6000 - self.frontend_port = int(args.webport) if hasattr(args, "webport") else 6060 - self.grafana_port = int(args.grafanaport) if hasattr(args, "grafanaport") else 6040 - self.loki_port = int(args.lokiport) if hasattr(args, "lokiport") else 6010 - self.statistics_port = int(args.statsport) if hasattr(args, "statsport") else 8080 - self.simulation = args.simulation - self.config_dir = args.config - self.databases_dir = args.databases if hasattr(args, "databases") else "/opt/nebula" - self.test = args.test if hasattr(args, "test") else False - self.log_dir = args.logs - self.cert_dir = args.certs - self.env_path = args.env - self.production = args.production if hasattr(args, "production") else False - self.advanced_analytics = args.advanced_analytics if hasattr(args, "advanced_analytics") else False - self.matrix = args.matrix if 
hasattr(args, "matrix") else None - self.root_path = ( - args.root_path - if hasattr(args, "root_path") - else os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - ) - self.host_platform = "windows" if sys.platform == "win32" else "unix" - - # Network configuration (nodes deployment in a network) - self.network_subnet = args.network_subnet if hasattr(args, "network_subnet") else None - self.network_gateway = args.network_gateway if hasattr(args, "network_gateway") else None - - # Configure logger - self.configure_logger() - - # Check ports available - if not SocketUtils.is_port_open(self.controller_port): - self.controller_port = SocketUtils.find_free_port() - - if not SocketUtils.is_port_open(self.frontend_port): - self.frontend_port = SocketUtils.find_free_port(self.controller_port + 1) - - if not SocketUtils.is_port_open(self.statistics_port): - self.statistics_port = SocketUtils.find_free_port(self.frontend_port + 1) - - self.config = Config(entity="controller") - self.topologymanager = None - self.n_nodes = 0 - self.mender = None if self.simulation else Mender() - self.use_blockchain = args.use_blockchain if hasattr(args, "use_blockchain") else False - self.gpu_available = False - - # Reference the global app instance - self.app = app - - def configure_logger(self): - log_console_format = "[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s" - console_handler = logging.StreamHandler() - console_handler.setLevel(logging.INFO) - console_handler.setFormatter(TermEscapeCodeFormatter(log_console_format)) - console_handler_file = logging.FileHandler(os.path.join(self.log_dir, "controller.log"), mode="a") - console_handler_file.setLevel(logging.INFO) - console_handler_file.setFormatter(logging.Formatter("[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")) - logging.basicConfig( - level=logging.DEBUG, - handlers=[ - console_handler, - console_handler_file, - ], - ) - uvicorn_loggers = ["uvicorn", "uvicorn.error", "uvicorn.access"] - for logger_name in 
uvicorn_loggers: - logger = logging.getLogger(logger_name) - logger.handlers = [] # Remove existing handlers - logger.propagate = False # Prevent duplicate logs - handler = logging.FileHandler(os.path.join(self.log_dir, "controller.log"), mode="a") - handler.setFormatter(logging.Formatter("[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")) - logger.addHandler(handler) - - def start(self): - banner = """ - ββββ ββββββββββββββββββ βββ ββββββ ββββββ - βββββ ββββββββββββββββββββββ ββββββ ββββββββ - ββββββ βββββββββ βββββββββββ ββββββ ββββββββ - ββββββββββββββββ βββββββββββ ββββββ ββββββββ - βββ ββββββββββββββββββββββββββββββββββββββββββ βββ - βββ ββββββββββββββββββββ βββββββ βββββββββββ βββ - A Platform for Decentralized Federated Learning - Created by Enrique TomΓ‘s MartΓnez BeltrΓ‘n - https://github.com/CyberDataLab/nebula - """ - print("\x1b[0;36m" + banner + "\x1b[0m") - - # Load the environment variables - load_dotenv(self.env_path) - - # Save controller pid - with open(os.path.join(os.path.dirname(__file__), "controller.pid"), "w") as f: - f.write(str(os.getpid())) - - # Check information about the environment - check_environment() - - # Save the configuration in environment variables - logging.info("Saving configuration in environment variables...") - os.environ["NEBULA_ROOT"] = self.root_path - os.environ["NEBULA_LOGS_DIR"] = self.log_dir - os.environ["NEBULA_CONFIG_DIR"] = self.config_dir - os.environ["NEBULA_CERTS_DIR"] = self.cert_dir - os.environ["NEBULA_STATISTICS_PORT"] = str(self.statistics_port) - os.environ["NEBULA_ROOT_HOST"] = self.root_path - os.environ["NEBULA_HOST_PLATFORM"] = self.host_platform - - # Start the FastAPI app in a daemon thread - app_thread = threading.Thread(target=self.run_controller_api, daemon=True) - app_thread.start() - logging.info(f"NEBULA Controller is running at port {self.controller_port}") - - if self.production: - self.run_waf() - logging.info(f"NEBULA WAF is running at port {self.waf_port}") - 
logging.info(f"Grafana Dashboard is running at port {self.grafana_port}") - - if self.test: - self.run_test() - else: - self.run_frontend() - logging.info(f"NEBULA Frontend is running at http://localhost:{self.frontend_port}") - logging.info(f"NEBULA Databases created in {self.databases_dir}") - - # Watchdog for running additional scripts in the host machine (i.e. during the execution of a federation) - event_handler = NebulaEventHandler() - observer = Observer() - observer.schedule(event_handler, path=self.config_dir, recursive=True) - observer.start() - - if self.mender: - logging.info("[Mender.module] Mender module initialized") - time.sleep(2) - mender = Mender() - logging.info("[Mender.module] Getting token from Mender server: {}".format(os.getenv("MENDER_SERVER"))) - mender.renew_token() - time.sleep(2) - logging.info( - "[Mender.module] Getting devices from {} with group Cluster_Thun".format(os.getenv("MENDER_SERVER")) - ) - time.sleep(2) - devices = mender.get_devices_by_group("Cluster_Thun") - logging.info("[Mender.module] Getting a pool of devices: 5 devices") - # devices = devices[:5] - for i in self.config.participants: - logging.info( - "[Mender.module] Device {} | IP: {}".format(i["device_args"]["idx"], i["network_args"]["ip"]) - ) - logging.info("[Mender.module] \tCreating artifacts...") - logging.info("[Mender.module] \tSending NEBULA Core...") - # mender.deploy_artifact_device("my-update-2.0.mender", i['device_args']['idx']) - logging.info("[Mender.module] \tSending configuration...") - time.sleep(5) - sys.exit(0) - - logging.info("Press Ctrl+C for exit from NEBULA (global exit)") - - # Adjust signal handling inside the start method - signal.signal(signal.SIGTERM, self.signal_handler) - signal.signal(signal.SIGINT, self.signal_handler) - - try: - while True: - time.sleep(1) - except KeyboardInterrupt: - logging.info("Closing NEBULA (exiting from components)... 
Please wait") - observer.stop() - self.stop() - - observer.join() - - def signal_handler(self, sig, frame): - # Handle termination signals - logging.info("Received termination signal, shutting down...") - self.stop() - sys.exit(0) - - def run_controller_api(self): - uvicorn.run( - self.app, - host="0.0.0.0", - port=self.controller_port, - log_config=None, # Prevent Uvicorn from configuring logging - ) - - def run_waf(self): - network_name = f"{os.environ['USER']}_nebula-net-base" - base = DockerUtils.create_docker_network(network_name) - - client = docker.from_env() - - volumes_waf = ["/var/log/nginx"] - - ports_waf = [80] - - host_config_waf = client.api.create_host_config( - binds=[f"{os.environ['NEBULA_LOGS_DIR']}/waf/nginx:/var/log/nginx"], - privileged=True, - port_bindings={80: self.waf_port}, - ) - - networking_config_waf = client.api.create_networking_config({ - f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.200") - }) - - container_id_waf = client.api.create_container( - image="nebula-waf", - name=f"{os.environ['USER']}_nebula-waf", - detach=True, - volumes=volumes_waf, - host_config=host_config_waf, - networking_config=networking_config_waf, - ports=ports_waf, - ) - - client.api.start(container_id_waf) - - environment = { - "GF_SECURITY_ADMIN_PASSWORD": "admin", - "GF_USERS_ALLOW_SIGN_UP": "false", - "GF_SERVER_HTTP_PORT": "3000", - "GF_SERVER_PROTOCOL": "http", - "GF_SERVER_DOMAIN": f"localhost:{self.grafana_port}", - "GF_SERVER_ROOT_URL": f"http://localhost:{self.grafana_port}/grafana/", - "GF_SERVER_SERVE_FROM_SUB_PATH": "true", - "GF_DASHBOARDS_DEFAULT_HOME_DASHBOARD_PATH": "/var/lib/grafana/dashboards/dashboard.json", - "GF_METRICS_MAX_LIMIT_TSDB": "0", - } - - ports = [3000] - - host_config = client.api.create_host_config( - port_bindings={3000: self.grafana_port}, - ) - - networking_config = client.api.create_networking_config({ - f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.201") - }) - - 
container_id = client.api.create_container( - image="nebula-waf-grafana", - name=f"{os.environ['USER']}_nebula-waf-grafana", - detach=True, - environment=environment, - host_config=host_config, - networking_config=networking_config, - ports=ports, - ) - - client.api.start(container_id) - - command = ["-config.file=/mnt/config/loki-config.yml"] - - ports_loki = [3100] - - host_config_loki = client.api.create_host_config( - port_bindings={3100: self.loki_port}, - ) - - networking_config_loki = client.api.create_networking_config({ - f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.202") - }) - - container_id_loki = client.api.create_container( - image="nebula-waf-loki", - name=f"{os.environ['USER']}_nebula-waf-loki", - detach=True, - command=command, - host_config=host_config_loki, - networking_config=networking_config_loki, - ports=ports_loki, - ) - - client.api.start(container_id_loki) - - volumes_promtail = ["/var/log/nginx"] - - host_config_promtail = client.api.create_host_config( - binds=[ - f"{os.environ['NEBULA_LOGS_DIR']}/waf/nginx:/var/log/nginx", - ], - ) - - networking_config_promtail = client.api.create_networking_config({ - f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.203") - }) - - container_id_promtail = client.api.create_container( - image="nebula-waf-promtail", - name=f"{os.environ['USER']}_nebula-waf-promtail", - detach=True, - volumes=volumes_promtail, - host_config=host_config_promtail, - networking_config=networking_config_promtail, - ) - - client.api.start(container_id_promtail) - - def run_frontend(self): - if sys.platform == "win32": - if not os.path.exists("//./pipe/docker_Engine"): - raise Exception( - "Docker is not running, please check if Docker is running and Docker Compose is installed." - ) - else: - if not os.path.exists("/var/run/docker.sock"): - raise Exception( - "/var/run/docker.sock not found, please check if Docker is running and Docker Compose is installed." 
- ) - - try: - subprocess.check_call(["nvidia-smi"]) - self.gpu_available = True - except Exception: - logging.info("No GPU available for the frontend, nodes will be deploy in CPU mode") - - network_name = f"{os.environ['USER']}_nebula-net-base" - - # Create the Docker network - base = DockerUtils.create_docker_network(network_name) - - client = docker.from_env() - - environment = { - "NEBULA_CONTROLLER_NAME": os.environ["USER"], - "NEBULA_PRODUCTION": self.production, - "NEBULA_GPU_AVAILABLE": self.gpu_available, - "NEBULA_ADVANCED_ANALYTICS": self.advanced_analytics, - "NEBULA_FRONTEND_LOG": "/nebula/app/logs/frontend.log", - "NEBULA_LOGS_DIR": "/nebula/app/logs/", - "NEBULA_CONFIG_DIR": "/nebula/app/config/", - "NEBULA_CERTS_DIR": "/nebula/app/certs/", - "NEBULA_ENV_PATH": "/nebula/app/.env", - "NEBULA_ROOT_HOST": self.root_path, - "NEBULA_HOST_PLATFORM": self.host_platform, - "NEBULA_DEFAULT_USER": "admin", - "NEBULA_DEFAULT_PASSWORD": "admin", - "NEBULA_FRONTEND_PORT": self.frontend_port, - "NEBULA_CONTROLLER_PORT": self.controller_port, - "NEBULA_CONTROLLER_HOST": "host.docker.internal", - } - - volumes = ["/nebula", "/var/run/docker.sock", "/etc/nginx/sites-available/default"] - - ports = [80, 8080] - - host_config = client.api.create_host_config( - binds=[ - f"{self.root_path}:/nebula", - "/var/run/docker.sock:/var/run/docker.sock", - f"{self.root_path}/nebula/frontend/config/nebula:/etc/nginx/sites-available/default", - f"{self.databases_dir}:/nebula/nebula/frontend/databases", - ], - extra_hosts={"host.docker.internal": "host-gateway"}, - port_bindings={80: self.frontend_port, 8080: self.statistics_port}, - ) - - networking_config = client.api.create_networking_config({ - f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.100") - }) - - container_id = client.api.create_container( - image="nebula-frontend", - name=f"{os.environ['USER']}_nebula-frontend", - detach=True, - environment=environment, - volumes=volumes, - 
host_config=host_config, - networking_config=networking_config, - ports=ports, - ) - - client.api.start(container_id) - - def run_test(self): - deploy_tests.start() - - @staticmethod - def stop_waf(): - DockerUtils.remove_containers_by_prefix(f"{os.environ['USER']}_nebula-waf") - - @staticmethod - def stop(): - logging.info("Closing NEBULA (exiting from components)... Please wait") - DockerUtils.remove_containers_by_prefix(f"{os.environ['USER']}_") - ScenarioManagement.stop_blockchain() - ScenarioManagement.stop_participants() - Controller.stop_waf() - DockerUtils.remove_docker_networks_by_prefix(f"{os.environ['USER']}_") - controller_pid_file = os.path.join(os.path.dirname(__file__), "controller.pid") - try: - with open(controller_pid_file) as f: - pid = int(f.read()) - os.kill(pid, signal.SIGKILL) - os.remove(controller_pid_file) - except Exception as e: - logging.exception(f"Error while killing controller process: {e}") - sys.exit(0) diff --git a/nebula/controller/controller.py b/nebula/controller/controller.py new file mode 100755 index 000000000..cc0bc7826 --- /dev/null +++ b/nebula/controller/controller.py @@ -0,0 +1,1600 @@ +import asyncio +import datetime +import importlib +import json +import logging +import os +import re +import signal +import subprocess +import sys +import threading +import time +from typing import Annotated + +import aiohttp +import docker +import psutil +import uvicorn +from dotenv import load_dotenv +from fastapi import Body, FastAPI, Request, status, HTTPException, Path +from watchdog.events import PatternMatchingEventHandler +from watchdog.observers import Observer + +from nebula.addons.env import check_environment +from nebula.config.config import Config +from nebula.config.mender import Mender +from nebula.controller.scenarios import Scenario, ScenarioManagement +from nebula.utils import DockerUtils, SocketUtils + + +# Setup controller logger +class TermEscapeCodeFormatter(logging.Formatter): + """ + Custom logging formatter that 
removes ANSI terminal escape codes from log messages. + + This formatter is useful when you want to clean up log outputs by stripping out + any terminal color codes or formatting sequences before logging them to a file + or other non-terminal output. + + Attributes: + fmt (str): Format string for the log message. + datefmt (str): Format string for the date in the log message. + style (str): Formatting style (default is '%'). + validate (bool): Whether to validate the format string. + + Methods: + format(record): Strips ANSI escape codes from the log message and formats it. + """ + + def __init__(self, fmt=None, datefmt=None, style="%", validate=True): + """ + Initializes the TermEscapeCodeFormatter. + + Args: + fmt (str, optional): The format string for the log message. + datefmt (str, optional): The format string for the date. + style (str, optional): The formatting style. Defaults to '%'. + validate (bool, optional): Whether to validate the format string. Defaults to True. + """ + super().__init__(fmt, datefmt, style, validate) + + def format(self, record): + """ + Formats the specified log record, stripping out any ANSI escape codes. + + Args: + record (logging.LogRecord): The log record to be formatted. + + Returns: + str: The formatted log message with escape codes removed. + """ + escape_re = re.compile(r"\x1b\[[0-9;]*m") + record.msg = re.sub(escape_re, "", str(record.msg)) + return super().format(record) + +os.environ["NEBULA_CONTROLLER_NAME"] = os.environ["USER"] + +# Initialize FastAPI app outside the Controller class +app = FastAPI() + +# Define endpoints outside the Controller class +@app.get("/") +async def read_root(): + """ + Root endpoint of the NEBULA Controller API. + + Returns: + dict: A welcome message indicating the API is accessible. + """ + return {"message": "Welcome to the NEBULA Controller API"} + + +@app.get("/status") +async def get_status(): + """ + Check the status of the NEBULA Controller API. 
+ + Returns: + dict: A status message confirming the API is running. + """ + return {"status": "NEBULA Controller API is running"} + + +@app.get("/resources") +async def get_resources(): + """ + Get system resource usage including RAM and GPU memory usage. + + Returns: + dict: A dictionary containing: + - gpus (int): Number of GPUs detected. + - memory_percent (float): Percentage of used RAM. + - gpu_memory_percent (List[float]): List of GPU memory usage percentages. + """ + devices = 0 + gpu_memory_percent = [] + + # Obtain available RAM + memory_info = await asyncio.to_thread(psutil.virtual_memory) + + if importlib.util.find_spec("pynvml") is not None: + try: + import pynvml + + await asyncio.to_thread(pynvml.nvmlInit) + devices = await asyncio.to_thread(pynvml.nvmlDeviceGetCount) + + # Obtain GPU info + for i in range(devices): + handle = await asyncio.to_thread(pynvml.nvmlDeviceGetHandleByIndex, i) + memory_info_gpu = await asyncio.to_thread(pynvml.nvmlDeviceGetMemoryInfo, handle) + memory_used_percent = (memory_info_gpu.used / memory_info_gpu.total) * 100 + gpu_memory_percent.append(memory_used_percent) + + except Exception: # noqa: S110 + pass + + return { + # "cpu_percent": psutil.cpu_percent(), + "gpus": devices, + "memory_percent": memory_info.percent, + "gpu_memory_percent": gpu_memory_percent, + } + + +@app.get("/least_memory_gpu") +async def get_least_memory_gpu(): + """ + Identify the GPU with the highest memory usage above a threshold (50%). + + Note: + Despite the name, this function returns the GPU using the **most** + memory above 50% usage. + + Returns: + dict: A dictionary with the index of the GPU using the most memory above the threshold, + or None if no such GPU is found. 
+ """ + gpu_with_least_memory_index = None + + if importlib.util.find_spec("pynvml") is not None: + max_memory_used_percent = 50 + try: + import pynvml + + await asyncio.to_thread(pynvml.nvmlInit) + devices = await asyncio.to_thread(pynvml.nvmlDeviceGetCount) + + # Obtain GPU info + for i in range(devices): + handle = await asyncio.to_thread(pynvml.nvmlDeviceGetHandleByIndex, i) + memory_info = await asyncio.to_thread(pynvml.nvmlDeviceGetMemoryInfo, handle) + memory_used_percent = (memory_info.used / memory_info.total) * 100 + + # Obtain GPU with less memory available + if memory_used_percent > max_memory_used_percent: + max_memory_used_percent = memory_used_percent + gpu_with_least_memory_index = i + + except Exception: # noqa: S110 + pass + + return { + "gpu_with_least_memory_index": gpu_with_least_memory_index, + } + + +@app.get("/available_gpus/") +async def get_available_gpu(): + """ + Get the list of GPUs with memory usage below 5%. + + Returns: + dict: A dictionary with a list of GPU indices that are mostly free (usage < 5%). + """ + available_gpus = [] + + if importlib.util.find_spec("pynvml") is not None: + try: + import pynvml + + await asyncio.to_thread(pynvml.nvmlInit) + devices = await asyncio.to_thread(pynvml.nvmlDeviceGetCount) + + # Obtain GPU info + for i in range(devices): + handle = await asyncio.to_thread(pynvml.nvmlDeviceGetHandleByIndex, i) + memory_info = await asyncio.to_thread(pynvml.nvmlDeviceGetMemoryInfo, handle) + memory_used_percent = (memory_info.used / memory_info.total) * 100 + + # Obtain available GPUs + if memory_used_percent < 5: + available_gpus.append(i) + + return { + "available_gpus": available_gpus, + } + except Exception: # noqa: S110 + pass + + +@app.post("/scenarios/run") +async def run_scenario( + scenario_data: dict = Body(..., embed=True), + role: str = Body(..., embed=True), + user: str = Body(..., embed=True) +): + """ + Launches a new scenario based on the provided configuration. 
+ + Args: + scenario_data (dict): The complete configuration of the scenario to be executed. + role (str): The role of the user initiating the scenario. + user (str): The username of the user initiating the scenario. + + Returns: + str: The name of the scenario that was started. + """ + + import subprocess + + from nebula.controller.scenarios import ScenarioManagement + + # Manager for the actual scenario + scenarioManagement = ScenarioManagement(scenario_data, user) + + await update_scenario( + scenario_name=scenarioManagement.scenario_name, + start_time=scenarioManagement.start_date_scenario, + end_time="", + scenario=scenario_data, + status="running", + role=role, + username=user + ) + + # Run the actual scenario + try: + if scenarioManagement.scenario.mobility: + additional_participants = scenario_data["additional_participants"] + schema_additional_participants = scenario_data["schema_additional_participants"] + scenarioManagement.load_configurations_and_start_nodes( + additional_participants, schema_additional_participants + ) + else: + scenarioManagement.load_configurations_and_start_nodes() + except subprocess.CalledProcessError as e: + logging.exception(f"Error docker-compose up: {e}") + return + + return scenarioManagement.scenario_name + + +@app.post("/scenarios/remove") +async def remove_scenario( + scenario_name: str = Body(..., embed=True) +): + """ + Removes a scenario from the database by its name. + + Args: + scenario_name (str): Name of the scenario to remove. + + Returns: + dict: A message indicating successful removal. 
+ """ + from nebula.controller.database import remove_scenario_by_name + + try: + remove_scenario_by_name(scenario_name) + except Exception as e: + logging.error(f"Error removing scenario {scenario_name}: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return {"message": f"Scenario {scenario_name} removed successfully"} + + +@app.get("/scenarios/{user}/{role}") +async def get_scenarios( + user: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=50, + description="Valid username" + ) + ], + role: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=50, + description="Valid role" + ) + ] +): + """ + Retrieves all scenarios associated with a given user and role. + + Args: + user (str): Username to filter scenarios. + role (str): Role of the user (e.g., "admin"). + + Returns: + dict: A list of scenarios and the currently running scenario. + """ + from nebula.controller.database import get_all_scenarios_and_check_completed, get_running_scenario + + try: + scenarios = get_all_scenarios_and_check_completed(username=user, role=role) + if role == "admin": + scenario_running = get_running_scenario() + else: + scenario_running = get_running_scenario(username=user) + except Exception as e: + logging.error(f"Error obtaining scenarios: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return {"scenarios": scenarios, "scenario_running": scenario_running} + + +@app.post("/scenarios/update") +async def update_scenario( + scenario_name: str = Body(..., embed=True), + start_time: str = Body(..., embed=True), + end_time: str = Body(..., embed=True), + scenario: dict = Body(..., embed=True), + status: str = Body(..., embed=True), + role: str = Body(..., embed=True), + username: str = Body(..., embed=True) +): + """ + Updates the status and metadata of a scenario. + + Args: + scenario_name (str): Name of the scenario. 
+ start_time (str): Start time of the scenario. + end_time (str): End time of the scenario. + scenario (dict): Scenario configuration. + status (str): New status of the scenario (e.g., "running", "finished"). + role (str): Role associated with the scenario. + username (str): User performing the update. + + Returns: + dict: A message confirming the update. + """ + from nebula.controller.database import scenario_update_record + + try: + scenario = Scenario.from_dict(scenario) + scenario_update_record(scenario_name, start_time, end_time, scenario, status, role, username) + except Exception as e: + logging.error(f"Error updating scenario {scenario_name}: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return {"message": f"Scenario {scenario_name} updated successfully"} + + +@app.post("/scenarios/set_status_to_finished") +async def set_scenario_status_to_finished( + scenario_name: str = Body(..., embed=True), + all: bool = Body(False, embed=True) +): + """ + Sets the status of a scenario (or all scenarios) to 'finished'. + + Args: + scenario_name (str): Name of the scenario to mark as finished. + all (bool): If True, sets all scenarios to finished. + + Returns: + dict: A message confirming the operation. + """ + from nebula.controller.database import scenario_set_status_to_finished, scenario_set_all_status_to_finished + + try: + if all: + scenario_set_all_status_to_finished() + else: + scenario_set_status_to_finished(scenario_name) + except Exception as e: + logging.error(f"Error setting scenario {scenario_name} to finished: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return {"message": f"Scenario {scenario_name} status set to finished successfully"} + + +@app.get("/scenarios/running") +async def get_running_scenario(get_all: bool = False): + """ + Retrieves the currently running scenario(s). + + Args: + get_all (bool): If True, retrieves all running scenarios. 
+ + Returns: + dict or list: Running scenario(s) information. + """ + from nebula.controller.database import get_running_scenario + + try: + return get_running_scenario(get_all=get_all) + except Exception as e: + logging.error(f"Error obtaining running scenario: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + +@app.get("/scenarios/check") +async def check_scenario(role: str, scenario_name: str): + """ + Checks if a scenario is allowed for a specific role. + + Args: + role (str): Role to validate. + scenario_name (str): Name of the scenario. + + Returns: + dict: Whether the scenario is allowed for the role. + """ + from nebula.controller.database import check_scenario_with_role + + try: + allowed = check_scenario_with_role(role, scenario_name) + return {"allowed": allowed} + except Exception as e: + logging.error(f"Error checking scenario with role: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + +@app.get("/scenarios/{scenario_name}") +async def get_scenario_by_name( + scenario_name: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=50, + description="Valid scenario name" + ) + ] +): + """ + Fetches a scenario by its name. + + Args: + scenario_name (str): The name of the scenario. + + Returns: + dict: The scenario data. + """ + from nebula.controller.database import get_scenario_by_name + + try: + scenario = get_scenario_by_name(scenario_name) + except Exception as e: + logging.error(f"Error obtaining scenario {scenario_name}: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return scenario + + +@app.get("/nodes/{scenario_name}") +async def list_nodes_by_scenario_name( + scenario_name: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=50, + description="Valid scenario name" + ) + ] +): + """ + Lists all nodes associated with a specific scenario. + + Args: + scenario_name (str): Name of the scenario. 
+ + Returns: + list: List of nodes. + """ + from nebula.controller.database import list_nodes_by_scenario_name + + try: + nodes = list_nodes_by_scenario_name(scenario_name) + except Exception as e: + logging.error(f"Error obtaining nodes: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return nodes + + +@app.post("/nodes/{scenario_name}/update") +async def update_nodes( + scenario_name: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=50, + description="Valid scenario name" + ), + ], + request: Request +): + """ + Updates the configuration of a node in the database and notifies the frontend. + + Args: + scenario_name (str): The scenario to which the node belongs. + request (Request): The HTTP request containing the node data. + + Returns: + dict: Confirmation or response from the frontend. + """ + from nebula.controller.database import update_node_record + try: + config = await request.json() + timestamp = datetime.datetime.now() + # Update the node in database + await update_node_record( + str(config["device_args"]["uid"]), + str(config["device_args"]["idx"]), + str(config["network_args"]["ip"]), + str(config["network_args"]["port"]), + str(config["device_args"]["role"]), + str(config["network_args"]["neighbors"]), + str(config["mobility_args"]["latitude"]), + str(config["mobility_args"]["longitude"]), + str(timestamp), + str(config["scenario_args"]["federation"]), + str(config["federation_args"]["round"]), + str(config["scenario_args"]["name"]), + str(config["tracking_args"]["run_hash"]), + str(config["device_args"]["malicious"]), + ) + except Exception as e: + logging.error(f"Error updating nodes: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + port = os.environ["NEBULA_FRONTEND_PORT"] + url = f"http://localhost:{port}/platform/dashboard/{scenario_name}/node/update" + + config["timestamp"] = str(timestamp) + + async with aiohttp.ClientSession() as session: + async 
with session.post(url, json=config) as response: + if response.status == 200: + return await response.json() + else: + raise HTTPException(status_code=response.status, detail="Error posting data") + + return {"message": "Nodes updated successfully in the database"} + + +@app.post("/nodes/{scenario_name}/done") +async def node_done( + scenario_name: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=50, + description="Valid scenario name" + ), + ], + request: Request +): + """ + Endpoint to forward node status to the frontend. + + Receives a JSON payload and forwards it to the frontend's /node/done route + for the given scenario. + + Parameters: + - scenario_name: Name of the scenario. + - request: HTTP request with JSON body. + + Returns the response from the frontend or raises an HTTPException if it fails. + """ + port = os.environ["NEBULA_FRONTEND_PORT"] + url = f"http://localhost:{port}/platform/dashboard/{scenario_name}/node/done" + + data = await request.json() + + async with aiohttp.ClientSession() as session: + async with session.post(url, json=data) as response: + if response.status == 200: + return await response.json() + else: + raise HTTPException(status_code=response.status, detail="Error posting data") + + return {"message": "Nodes done"} + + +@app.post("/nodes/remove") +async def remove_nodes_by_scenario_name( + scenario_name: str = Body(..., embed=True) +): + """ + Endpoint to remove all nodes associated with a scenario. + + Body Parameters: + - scenario_name: Name of the scenario whose nodes should be removed. + + Returns a success message or an error if something goes wrong. 
+ """ + from nebula.controller.database import remove_nodes_by_scenario_name + + try: + remove_nodes_by_scenario_name(scenario_name) + except Exception as e: + logging.error(f"Error removing nodes: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return {"message": f"Nodes for scenario {scenario_name} removed successfully"} + + +@app.get("/notes/{scenario_name}") +async def get_notes_by_scenario_name( + scenario_name: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=50, + description="Valid scenario name" + ) + ] +): + """ + Endpoint to retrieve notes associated with a scenario. + + Path Parameters: + - scenario_name: Name of the scenario. + + Returns the notes or raises an HTTPException on error. + """ + from nebula.controller.database import get_notes + + try: + notes = get_notes(scenario_name) + except Exception as e: + logging.error(f"Error obtaining notes {notes}: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return notes + + +@app.post("/notes/update") +async def update_notes_by_scenario_name( + scenario_name: str = Body(..., embed=True), + notes: str = Body(..., embed=True) +): + """ + Endpoint to update notes for a given scenario. + + Body Parameters: + - scenario_name: Name of the scenario. + - notes: Text content to store as notes. + + Returns a success message or an error if something goes wrong. + """ + from nebula.controller.database import save_notes + + try: + save_notes(scenario_name, notes) + except Exception as e: + logging.error(f"Error updating notes: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return {"message": f"Notes for scenario {scenario_name} updated successfully"} + + +@app.post("/notes/remove") +async def remove_notes_by_scenario_name( + scenario_name: str = Body(..., embed=True) +): + """ + Endpoint to remove notes associated with a scenario. 
@app.get("/user/{scenario_name}")
async def get_user_by_scenario_name(
    scenario_name: Annotated[
        str,
        Path(
            regex="^[a-zA-Z0-9_-]+$",
            min_length=1,
            max_length=50,
            description="Valid scenario name",
        ),
    ],
):
    """
    Endpoint to retrieve the user assigned to a scenario.

    Path Parameters:
    - scenario_name: Name of the scenario (1-50 characters: letters, digits, "_" or "-").

    Returns:
    - Whatever the database helper returns for the scenario.

    Raises:
    - HTTPException (500) if the database lookup raises.
    """
    # Aliased so the database helper does not shadow this endpoint's own name.
    from nebula.controller.database import get_user_by_scenario_name as db_get_user_by_scenario_name

    try:
        user = db_get_user_by_scenario_name(scenario_name)
    except Exception as e:
        # `user` is still unbound when the lookup raises, so log the scenario
        # name instead; referencing `user` here would raise UnboundLocalError.
        logging.error(f"Error obtaining user for scenario {scenario_name}: {e}")
        raise HTTPException(status_code=500, detail="Internal server error") from e

    return user
@app.post("/user/verify")
async def add_user_controller(
    user: str = Body(...),
    password: str = Body(...)
):
    """
    Endpoint to verify user credentials.

    Body Parameters:
    - user: Username (compared in upper case).
    - password: Password.

    Returns:
    - {"user": <upper-cased username>, "role": <role>} when the credentials are valid.

    Raises:
    - HTTPException (401) when the user is unknown or the password does not verify.
    - HTTPException (500) when the database helpers raise.
    """
    # NOTE(review): this function shares the name `add_user_controller` with the
    # /user/add and /user/update handlers; the routes still register, but the
    # module-level name is rebound each time. Consider a distinct name.
    from nebula.controller.database import get_user_info, list_users, verify

    try:
        user_submitted = user.upper()
        if (user_submitted in list_users()) and verify(user_submitted, password):
            user_info = get_user_info(user_submitted)
            # Index 2 is presumably the role column -- confirm against get_user_info.
            return {"user": user_submitted, "role": user_info[2]}
    except Exception as e:
        logging.error(f"Error verifying user: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error verifying user: {e}"
        ) from e

    # Raised outside the try block so the broad handler above cannot turn a
    # credential failure into a 500 response.
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
def verify_nodes_ports(self, src_path):
    """
    Reassign the ports of every participant JSON file located next to a script.

    For each `participant*.json` file in the directory of `src_path`, a new port
    is requested from `SocketUtils.find_free_port` (searching from 50000 upward),
    and both the node's own port and the ports in its `neighbors` string are
    rewritten accordingly.

    Parameters:
    - src_path: Path of the script whose directory holds the participant files.

    Errors are logged and not propagated.
    """
    import re

    # The participant files live in the same directory as the triggering script.
    scenario_path = os.path.dirname(src_path)

    try:
        port_mapping = {}
        new_port_start = 50000

        participant_files = sorted(
            f for f in os.listdir(scenario_path) if f.endswith(".json") and f.startswith("participant")
        )

        # First pass: decide the new port of every participant.
        for filename in participant_files:
            file_path = os.path.join(scenario_path, filename)
            with open(file_path) as json_file:
                node = json.load(json_file)
            current_port = node["network_args"]["port"]
            port_mapping[current_port] = SocketUtils.find_free_port(start_port=new_port_start)
            logging.info(
                f"Participant file: {filename} | Current port: {current_port} | New port: {port_mapping[current_port]}"
            )
            new_port_start = port_mapping[current_port] + 1

        # Rewrite neighbors in a single regex pass over whole port numbers.
        # Chained str.replace calls could re-map a port that was already
        # remapped (when a new port equals another old port) and could match a
        # port that is only a prefix of a longer one (":4500" in ":45001").
        replacements = {str(old): str(new) for old, new in port_mapping.items()}
        port_pattern = re.compile(r":(\d+)")

        def _remap(match):
            port = match.group(1)
            return f":{replacements.get(port, port)}"

        # Second pass: write the new ports back to each participant file.
        for filename in participant_files:
            file_path = os.path.join(scenario_path, filename)
            with open(file_path) as json_file:
                node = json.load(json_file)
            current_port = node["network_args"]["port"]
            node["network_args"]["port"] = port_mapping[current_port]
            node["network_args"]["neighbors"] = port_pattern.sub(_remap, node["network_args"]["neighbors"])

            with open(file_path, "w") as f:
                json.dump(node, f, indent=4)

    except Exception as e:
        # Logged (with traceback) like the rest of this handler, instead of print().
        logging.exception(f"Error processing JSON files: {e}")
def run_script(self, script):
    """
    Execute a scenario script according to its file extension.

    - `.sh` scripts run through `bash`; the call blocks until the script ends
      and its stdout (and stderr, when non-empty) is logged.
    - `.ps1` scripts are launched through `powershell` without waiting.
    - Any other extension is logged as unsupported.

    Parameters:
    - script: Path of the script to run.

    Any exception is logged and not propagated.
    """
    try:
        logging.info(f"Running script: {script}")

        is_bash = script.endswith(".sh")
        is_powershell = (not is_bash) and script.endswith(".ps1")

        if not (is_bash or is_powershell):
            logging.error("Unsupported script format.")
            return

        if is_powershell:
            subprocess.Popen(
                ["powershell", "-ExecutionPolicy", "Bypass", "-File", script],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=False,
            )
            return

        completed = subprocess.run(["bash", script], capture_output=True, text=True)
        logging.info(f"Script output:\n{completed.stdout}")
        if completed.stderr:
            logging.error(f"Script error:\n{completed.stderr}")
    except Exception as e:
        logging.exception(f"Error while running script: {e}")
def __init__(self, args):
    """
    Build the controller from a parsed-arguments object.

    Attributes read unconditionally from `args`: `simulation`, `config`,
    `logs`, `certs`, `env`.

    Attributes read when present, otherwise replaced by the fallback shown:
    - scenario_name, federation, topology, matrix: None
    - controllerport: 5000, wafport: 6000, webport: 6060
    - grafanaport: 6040, lokiport: 6010, statsport: 8080
    - databases: "/opt/nebula"
    - production, advanced_analytics, use_blockchain: False
    - root_path: three directory levels above this file
    - network_subnet, network_gateway: None

    The constructor also records the host platform ("windows" or "unix"),
    configures logging, re-selects the controller, frontend and statistics
    ports through `SocketUtils` when `is_port_open` reports False, and creates
    the `Config` object (plus a `Mender` instance unless `simulation` is truthy).
    """
    self.scenario_name = getattr(args, "scenario_name", None)
    self.start_date_scenario = None
    self.federation = getattr(args, "federation", None)
    self.topology = getattr(args, "topology", None)

    # Service ports (always coerced to int, including the fallbacks)
    self.controller_port = int(getattr(args, "controllerport", 5000))
    self.waf_port = int(getattr(args, "wafport", 6000))
    self.frontend_port = int(getattr(args, "webport", 6060))
    self.grafana_port = int(getattr(args, "grafanaport", 6040))
    self.loki_port = int(getattr(args, "lokiport", 6010))
    self.statistics_port = int(getattr(args, "statsport", 8080))

    self.simulation = args.simulation
    self.config_dir = args.config
    self.databases_dir = getattr(args, "databases", "/opt/nebula")
    self.log_dir = args.logs
    self.cert_dir = args.certs
    self.env_path = args.env
    self.production = getattr(args, "production", False)
    self.advanced_analytics = getattr(args, "advanced_analytics", False)
    self.matrix = getattr(args, "matrix", None)

    if hasattr(args, "root_path"):
        self.root_path = args.root_path
    else:
        # Fall back to the directory three levels above this module
        module_path = os.path.abspath(__file__)
        self.root_path = os.path.dirname(os.path.dirname(os.path.dirname(module_path)))

    self.host_platform = "windows" if sys.platform == "win32" else "unix"

    # Network configuration (nodes deployment in a network)
    self.network_subnet = getattr(args, "network_subnet", None)
    self.network_gateway = getattr(args, "network_gateway", None)

    # Configure logger
    self.configure_logger()

    # Check ports available; each replacement search starts after the previous port
    if not SocketUtils.is_port_open(self.controller_port):
        self.controller_port = SocketUtils.find_free_port()

    if not SocketUtils.is_port_open(self.frontend_port):
        self.frontend_port = SocketUtils.find_free_port(self.controller_port + 1)

    if not SocketUtils.is_port_open(self.statistics_port):
        self.statistics_port = SocketUtils.find_free_port(self.frontend_port + 1)

    self.config = Config(entity="controller")
    self.topologymanager = None
    self.n_nodes = 0
    self.mender = None if self.simulation else Mender()
    self.use_blockchain = getattr(args, "use_blockchain", False)
    self.gpu_available = False

    # Reference the global app instance
    self.app = app
+ """ + log_console_format = "[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s" + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(TermEscapeCodeFormatter(log_console_format)) + console_handler_file = logging.FileHandler(os.path.join(self.log_dir, "controller.log"), mode="a") + console_handler_file.setLevel(logging.INFO) + console_handler_file.setFormatter(logging.Formatter("[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")) + logging.basicConfig( + level=logging.DEBUG, + handlers=[ + console_handler, + console_handler_file, + ], + ) + uvicorn_loggers = ["uvicorn", "uvicorn.error", "uvicorn.access"] + for logger_name in uvicorn_loggers: + logger = logging.getLogger(logger_name) + logger.handlers = [] # Remove existing handlers + logger.propagate = False # Prevent duplicate logs + handler = logging.FileHandler(os.path.join(self.log_dir, "controller.log"), mode="a") + handler.setFormatter(logging.Formatter("[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")) + logger.addHandler(handler) + + def start(self): + """ + Starts the NEBULA controller. + + - Displays the welcome banner. + - Loads environment variables from the `.env` file. + - Saves the process PID to 'controller.pid'. + - Checks the environment and saves configuration to environment variables. + - Launches the FastAPI app in a daemon thread. + - Initializes databases. + - In production mode, starts the WAF and logs WAF and Grafana ports. + - Runs the frontend and logs its URL. + - Starts a watchdog to monitor configuration directory changes. + - If enabled, initializes the Mender module for artifact deployment. + - Captures SIGTERM and SIGINT signals for graceful shutdown. + - Keeps the process running until termination signal or Ctrl+C. 
+ """ + banner = """ + ββββ ββββββββββββββββββ βββ ββββββ ββββββ + βββββ ββββββββββββββββββββββ ββββββ ββββββββ + ββββββ βββββββββ βββββββββββ ββββββ ββββββββ + ββββββββββββββββ βββββββββββ ββββββ ββββββββ + βββ ββββββββββββββββββββββββββββββββββββββββββ βββ + βββ ββββββββββββββββββββ βββββββ βββββββββββ βββ + A Platform for Decentralized Federated Learning + Created by Enrique TomΓ‘s MartΓnez BeltrΓ‘n + https://github.com/CyberDataLab/nebula + """ + print("\x1b[0;36m" + banner + "\x1b[0m") + + # Load the environment variables + load_dotenv(self.env_path) + + # Save controller pid + with open(os.path.join(os.path.dirname(__file__), "controller.pid"), "w") as f: + f.write(str(os.getpid())) + + # Check information about the environment + check_environment() + + # Save the configuration in environment variables + logging.info("Saving configuration in environment variables...") + os.environ["NEBULA_ROOT"] = self.root_path + os.environ["NEBULA_LOGS_DIR"] = self.log_dir + os.environ["NEBULA_CONFIG_DIR"] = self.config_dir + os.environ["NEBULA_CERTS_DIR"] = self.cert_dir + os.environ["NEBULA_ROOT_HOST"] = self.root_path + os.environ["NEBULA_HOST_PLATFORM"] = self.host_platform + os.environ["NEBULA_CONTROLLER_HOST"] = "host.docker.internal" + os.environ["NEBULA_STATISTICS_PORT"] = str(self.statistics_port) + os.environ["NEBULA_CONTROLLER_PORT"] = str(self.controller_port) + os.environ["NEBULA_FRONTEND_PORT"] = str(self.frontend_port) + + # Start the FastAPI app in a daemon thread + app_thread = threading.Thread(target=self.run_controller_api, daemon=True) + app_thread.start() + logging.info(f"NEBULA Controller is running at port {self.controller_port}") + + from nebula.controller.database import initialize_databases + + asyncio.run(initialize_databases(self.databases_dir)) + + if self.production: + self.run_waf() + logging.info(f"NEBULA WAF is running at port {self.waf_port}") + logging.info(f"Grafana Dashboard is running at port {self.grafana_port}") + + 
def signal_handler(self, sig, frame):
    """
    React to a termination signal (registered for SIGTERM and SIGINT).

    Logs the reception, calls `self.stop()` and then terminates the process
    with exit status 0.

    Parameters:
    - sig: The signal number received (unused).
    - frame: The stack frame at signal reception (unused).
    """
    logging.info("Received termination signal, shutting down...")
    self.stop()
    # Same effect as sys.exit(0), which raises SystemExit(0)
    raise SystemExit(0)
+ """ + network_name = f"{os.environ['USER']}_nebula-net-base" + base = DockerUtils.create_docker_network(network_name) + + client = docker.from_env() + + volumes_waf = ["/var/log/nginx"] + + ports_waf = [80] + + host_config_waf = client.api.create_host_config( + binds=[f"{os.environ['NEBULA_LOGS_DIR']}/waf/nginx:/var/log/nginx"], + privileged=True, + port_bindings={80: self.waf_port}, + ) + + networking_config_waf = client.api.create_networking_config({ + f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.200") + }) + + container_id_waf = client.api.create_container( + image="nebula-waf", + name=f"{os.environ['USER']}_nebula-waf", + detach=True, + volumes=volumes_waf, + host_config=host_config_waf, + networking_config=networking_config_waf, + ports=ports_waf, + ) + + client.api.start(container_id_waf) + + environment = { + "GF_SECURITY_ADMIN_PASSWORD": "admin", + "GF_USERS_ALLOW_SIGN_UP": "false", + "GF_SERVER_HTTP_PORT": "3000", + "GF_SERVER_PROTOCOL": "http", + "GF_SERVER_DOMAIN": f"localhost:{self.grafana_port}", + "GF_SERVER_ROOT_URL": f"http://localhost:{self.grafana_port}/grafana/", + "GF_SERVER_SERVE_FROM_SUB_PATH": "true", + "GF_DASHBOARDS_DEFAULT_HOME_DASHBOARD_PATH": "/var/lib/grafana/dashboards/dashboard.json", + "GF_METRICS_MAX_LIMIT_TSDB": "0", + } + + ports = [3000] + + host_config = client.api.create_host_config( + port_bindings={3000: self.grafana_port}, + ) + + networking_config = client.api.create_networking_config({ + f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.201") + }) + + container_id = client.api.create_container( + image="nebula-waf-grafana", + name=f"{os.environ['USER']}_nebula-waf-grafana", + detach=True, + environment=environment, + host_config=host_config, + networking_config=networking_config, + ports=ports, + ) + + client.api.start(container_id) + + command = ["-config.file=/mnt/config/loki-config.yml"] + + ports_loki = [3100] + + host_config_loki = client.api.create_host_config( 
def run_frontend(self):
    """
    Create and start the 'nebula-frontend' Docker container.

    Steps:
    - Verify that the Docker endpoint exists (named pipe on win32, Unix socket otherwise).
    - Run `nvidia-smi`; `self.gpu_available` becomes True only when it succeeds.
    - Create the Docker network "<USER>_nebula-net-base" through `DockerUtils`.
    - Build the container environment, bind mounts and port bindings
      (container port 80 -> frontend port, 8080 -> statistics port).
    - Create the container with the address "<base>.100" on that network and start it.

    Raises:
    - Exception when the Docker endpoint is not found.
    - KeyError when the USER environment variable is not set.
    """
    on_windows = sys.platform == "win32"
    docker_endpoint = "//./pipe/docker_Engine" if on_windows else "/var/run/docker.sock"
    if not os.path.exists(docker_endpoint):
        if on_windows:
            raise Exception(
                "Docker is not running, please check if Docker is running and Docker Compose is installed."
            )
        raise Exception(
            "/var/run/docker.sock not found, please check if Docker is running and Docker Compose is installed."
        )

    # GPU detection: any failure of nvidia-smi leaves the flag untouched
    try:
        subprocess.check_call(["nvidia-smi"])
    except Exception:
        logging.info("No GPU available for the frontend, nodes will be deploy in CPU mode")
    else:
        self.gpu_available = True

    user = os.environ["USER"]
    network_name = f"{user}_nebula-net-base"

    # Create the Docker network
    base = DockerUtils.create_docker_network(network_name)

    client = docker.from_env()

    environment = {
        "NEBULA_CONTROLLER_NAME": user,
        "NEBULA_PRODUCTION": self.production,
        "NEBULA_GPU_AVAILABLE": self.gpu_available,
        "NEBULA_ADVANCED_ANALYTICS": self.advanced_analytics,
        "NEBULA_FRONTEND_LOG": "/nebula/app/logs/frontend.log",
        "NEBULA_LOGS_DIR": "/nebula/app/logs/",
        "NEBULA_CONFIG_DIR": "/nebula/app/config/",
        "NEBULA_CERTS_DIR": "/nebula/app/certs/",
        "NEBULA_ENV_PATH": "/nebula/app/.env",
        "NEBULA_ROOT_HOST": self.root_path,
        "NEBULA_HOST_PLATFORM": self.host_platform,
        "NEBULA_DEFAULT_USER": "admin",
        "NEBULA_DEFAULT_PASSWORD": "admin",
        "NEBULA_FRONTEND_PORT": self.frontend_port,
        "NEBULA_CONTROLLER_PORT": self.controller_port,
        "NEBULA_CONTROLLER_HOST": "host.docker.internal",
    }

    # Host-side configuration: bind mounts, host alias and published ports
    bind_mounts = [
        f"{self.root_path}:/nebula",
        "/var/run/docker.sock:/var/run/docker.sock",
        f"{self.root_path}/nebula/frontend/config/nebula:/etc/nginx/sites-available/default",
    ]
    host_config = client.api.create_host_config(
        binds=bind_mounts,
        extra_hosts={"host.docker.internal": "host-gateway"},
        port_bindings={80: self.frontend_port, 8080: self.statistics_port},
    )

    endpoint_config = client.api.create_endpoint_config(ipv4_address=f"{base}.100")
    networking_config = client.api.create_networking_config({f"{network_name}": endpoint_config})

    container_id = client.api.create_container(
        image="nebula-frontend",
        name=f"{user}_nebula-frontend",
        detach=True,
        environment=environment,
        volumes=["/nebula", "/var/run/docker.sock", "/etc/nginx/sites-available/default"],
        host_config=host_config,
        networking_config=networking_config,
        ports=[80, 8080],
    )

    client.api.start(container_id)