diff --git a/nebula/addons/gps/nebulagps.py b/nebula/addons/gps/nebulagps.py index d2dcc6f98..571a5ab53 100644 --- a/nebula/addons/gps/nebulagps.py +++ b/nebula/addons/gps/nebulagps.py @@ -90,7 +90,7 @@ async def _send_location_loop(self): from nebula.core.network.communications import CommunicationsManager cm = CommunicationsManager.get_instance() - if cm.learning_finished(): + if await cm.learning_finished(): logging.info("GPS: Learning cycle finished, stopping location broadcast") break except Exception: @@ -111,7 +111,7 @@ async def _receive_location_loop(self): from nebula.core.network.communications import CommunicationsManager cm = CommunicationsManager.get_instance() - if cm.learning_finished(): + if await cm.learning_finished(): logging.info("GPS: Learning cycle finished, stopping location reception") break except Exception: @@ -139,7 +139,7 @@ async def _notify_geolocs(self): from nebula.core.network.communications import CommunicationsManager cm = CommunicationsManager.get_instance() - if cm.learning_finished(): + if await cm.learning_finished(): logging.info("GPS: Learning cycle finished, stopping geolocation notifications") break except Exception: diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index c4b9e00c3..b46f7fe88 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -77,20 +77,20 @@ def __init__(self, config, verbose=False): def cm(self): return CommunicationsManager.get_instance() - @property - def round(self): - """ - Gets the current round number from the Communications Manager. - - This property retrieves the current round number that is being managed by the - CommunicationsManager instance associated with this module. It provides an - interface to access the ongoing round of the communication process without - directly exposing the underlying method in the CommunicationsManager. - - Returns: - int: The current round number managed by the CommunicationsManager. - """ - return self.cm.get_round() + # @property + # def round(self): + # """ + # Gets the current round number from the Communications Manager. + + # This property retrieves the current round number that is being managed by the + # CommunicationsManager instance associated with this module. It provides an + # interface to access the ongoing round of the communication process without + # directly exposing the underlying method in the CommunicationsManager. + + # Returns: + # int: The current round number managed by the CommunicationsManager. + # """ + # return self.cm.get_round() async def start(self): """ diff --git a/nebula/addons/networksimulation/nebulanetworksimulator.py b/nebula/addons/networksimulation/nebulanetworksimulator.py index 2d87e9cdf..9af1f768e 100644 --- a/nebula/addons/networksimulation/nebulanetworksimulator.py +++ b/nebula/addons/networksimulation/nebulanetworksimulator.py @@ -51,41 +51,40 @@ async def is_running(self): async def _change_network_conditions_based_on_distances(self, gpsevent: GPSEvent): distances = await gpsevent.get_event_data() - while await self.is_running(): - if self._verbose: - logging.info("Refresh | conditions based on distances...") - try: - for addr, (distance, _) in distances.items(): - if distance is None: - # If the distance is not found, we skip the node - continue - conditions = await self._calculate_network_conditions(distance) - # Only update the network conditions if they have changed - if ( - addr not in self._current_network_conditions - or self._current_network_conditions[addr] != conditions - ): - addr_ip = addr.split(":")[0] - self._set_network_condition_for_addr( - self._node_interface, addr_ip, conditions["bandwidth"], conditions["delay"] - ) - self._set_network_condition_for_multicast( - self._node_interface, - addr_ip, - self.IP_MULTICAST, - conditions["bandwidth"], - conditions["delay"], - ) - async with self._network_conditions_lock: - self._current_network_conditions[addr] = conditions - else: - if self._verbose: - logging.info("network conditions havent changed since last time") - except KeyError: - logging.exception(f"πŸ“ Connection {addr} not found") - except Exception: - logging.exception("πŸ“ Error changing connections based on distance") - await asyncio.sleep(self._refresh_interval) + if self._verbose: + logging.info("Refresh | conditions based on distances...") + try: + for addr, (distance, _) in distances.items(): + if distance is None: + # If the distance is not found, we skip the node + continue + conditions = await self._calculate_network_conditions(distance) + # Only update the network conditions if they have changed + if ( + addr not in self._current_network_conditions + or self._current_network_conditions[addr] != conditions + ): + addr_ip = addr.split(":")[0] + self._set_network_condition_for_addr( + self._node_interface, addr_ip, conditions["bandwidth"], conditions["delay"] + ) + self._set_network_condition_for_multicast( + self._node_interface, + addr_ip, + self.IP_MULTICAST, + conditions["bandwidth"], + conditions["delay"], + ) + async with self._network_conditions_lock: + self._current_network_conditions[addr] = conditions + else: + if self._verbose: + logging.info("network conditions havent changed since last time") + except KeyError: + logging.exception(f"πŸ“ Connection {addr} not found") + except Exception: + logging.exception("πŸ“ Error changing connections based on distance") + await asyncio.sleep(self._refresh_interval) async def set_thresholds(self, thresholds: dict): async with self._network_conditions_lock: diff --git a/nebula/addons/reputation/reputation.py b/nebula/addons/reputation/reputation.py index 6f659c7a7..561199513 100644 --- a/nebula/addons/reputation/reputation.py +++ b/nebula/addons/reputation/reputation.py @@ -340,7 +340,7 @@ async def setup(self): ) await EventManager.get_instance().subscribe_node_event(DuplicatedMessageEvent, self.recollect_duplicated_number_message) - def init_reputation( + async def init_reputation( self, federation_nodes=None, round_num=None, last_feedback_round=None, init_reputation=None ): """ @@ -363,7 +363,7 @@ def init_reputation( logging.error("init_reputation | No valid neighbors found") return - self._initialize_neighbor_reputations(neighbors, round_num, last_feedback_round, init_reputation) + await self._initialize_neighbor_reputations(neighbors, round_num, last_feedback_round, init_reputation) def _validate_init_parameters(self, federation_nodes, round_num, init_reputation) -> bool: """Validate initialization parameters.""" @@ -379,11 +379,11 @@ def _validate_init_parameters(self, federation_nodes, round_num, init_reputation return True - def _initialize_neighbor_reputations(self, neighbors: list, round_num: int, last_feedback_round: int, init_reputation: float): + async def _initialize_neighbor_reputations(self, neighbors: list, round_num: int, last_feedback_round: int, init_reputation: float): """Initialize reputation entries for all neighbors.""" for nei in neighbors: self._create_or_update_reputation_entry(nei, round_num, last_feedback_round, init_reputation) - self.save_reputation_history_in_memory(self._addr, nei, init_reputation) + await self.save_reputation_history_in_memory(self._addr, nei, init_reputation) def _create_or_update_reputation_entry(self, nei: str, round_num: int, last_feedback_round: int, init_reputation: float): """Create or update a single reputation entry.""" @@ -418,7 +418,7 @@ def _validate_federation_nodes(self, federation_nodes) -> list: return valid_nodes - def _calculate_static_reputation( + async def _calculate_static_reputation( self, addr: str, nei: str, @@ -444,19 +444,19 @@ def _calculate_static_reputation( for metric_name in static_weights ) - logging.info(f"Static reputation for node {nei} at round {self.engine.get_round()}: {reputation_static}") + logging.info(f"Static reputation for node {nei} at round {await self.engine.get_round()}: {reputation_static}") - avg_reputation = self.save_reputation_history_in_memory(self.engine.addr, nei, reputation_static) + avg_reputation = await self.save_reputation_history_in_memory(self.engine.addr, nei, reputation_static) metrics_data = { "addr": addr, "nei": nei, - "round": self.engine.get_round(), + "round": await self.engine.get_round(), "reputation_without_feedback": avg_reputation, **{f"average_{name}": weight for name, weight in static_weights.items()} } - self._update_reputation_record(nei, avg_reputation, metrics_data) + await self._update_reputation_record(nei, avg_reputation, metrics_data) async def _calculate_dynamic_reputation(self, addr, neighbors): """ @@ -470,20 +470,20 @@ async def _calculate_dynamic_reputation(self, addr, neighbors): logging.warning("_metrics is not properly initialized") return - average_weights = self._calculate_average_weights() - self._process_neighbors_reputation(addr, neighbors, average_weights) + average_weights = await self._calculate_average_weights() + await self._process_neighbors_reputation(addr, neighbors, average_weights) - def _calculate_average_weights(self): + async def _calculate_average_weights(self): """Calculate average weights for all enabled metrics.""" average_weights = {} for metric_name in self.history_data.keys(): if self._is_metric_enabled(metric_name): - average_weights[metric_name] = self._get_metric_average_weight(metric_name) + average_weights[metric_name] = await self._get_metric_average_weight(metric_name) return average_weights - def _get_metric_average_weight(self, metric_name): + async def _get_metric_average_weight(self, metric_name): """Get the average weight for a specific metric.""" if metric_name not in self.history_data or not self.history_data[metric_name]: logging.debug(f"No history data available for metric: {metric_name}") @@ -492,7 +492,7 @@ def _get_metric_average_weight(self, metric_name): valid_entries = [ entry for entry in self.history_data[metric_name] if (entry.get("round") is not None and - entry["round"] >= self._engine.get_round() and + entry["round"] >= await self._engine.get_round() and entry.get("weight") not in [None, -1]) ] @@ -506,22 +506,22 @@ def _get_metric_average_weight(self, metric_name): logging.warning(f"Error calculating average weight for {metric_name}: {e}") return 0 - def _process_neighbors_reputation(self, addr, neighbors, average_weights): + async def _process_neighbors_reputation(self, addr, neighbors, average_weights): """Process reputation calculation for all neighbors.""" for nei in neighbors: - metric_values = self._get_neighbor_metric_values(nei) + metric_values = await self._get_neighbor_metric_values(nei) if all(metric_name in metric_values for metric_name in average_weights): - self._update_neighbor_reputation(addr, nei, metric_values, average_weights) + await self._update_neighbor_reputation(addr, nei, metric_values, average_weights) - def _get_neighbor_metric_values(self, nei): + async def _get_neighbor_metric_values(self, nei): """Get metric values for a specific neighbor in the current round.""" metric_values = {} for metric_name in self.history_data: if self._is_metric_enabled(metric_name): for entry in self.history_data.get(metric_name, []): - if (entry.get("round") == self._engine.get_round() and + if (entry.get("round") == await self._engine.get_round() and entry.get("metric_name") == metric_name and entry.get("nei") == nei): metric_values[metric_name] = entry.get("metric_value", 0) @@ -529,7 +529,7 @@ def _get_neighbor_metric_values(self, nei): return metric_values - def _update_neighbor_reputation(self, addr, nei, metric_values, average_weights): + async def _update_neighbor_reputation(self, addr, nei, metric_values, average_weights): """Update reputation for a specific neighbor.""" reputation_with_weights = sum( metric_values.get(metric_name, 0) * average_weights[metric_name] @@ -537,24 +537,24 @@ def _update_neighbor_reputation(self, addr, nei, metric_values, average_weights) ) logging.info( - f"Dynamic reputation with weights for {nei} at round {self._engine.get_round()}: {reputation_with_weights}" + f"Dynamic reputation with weights for {nei} at round {await self._engine.get_round()}: {reputation_with_weights}" ) - avg_reputation = self.save_reputation_history_in_memory(self._engine.addr, nei, reputation_with_weights) + avg_reputation = await self.save_reputation_history_in_memory(self._engine.addr, nei, reputation_with_weights) metrics_data = { "addr": addr, "nei": nei, - "round": self._engine.get_round(), + "round": await self._engine.get_round(), "reputation_without_feedback": avg_reputation, } for metric_name in metric_values: metrics_data[f"average_{metric_name}"] = average_weights[metric_name] - self._update_reputation_record(nei, avg_reputation, metrics_data) + await self._update_reputation_record(nei, avg_reputation, metrics_data) - def _update_reputation_record(self, nei: str, reputation: float, data: dict): + async def _update_reputation_record(self, nei: str, reputation: float, data: dict): """ Update the reputation record of a participant. @@ -563,7 +563,7 @@ def _update_reputation_record(self, nei: str, reputation: float, data: dict): reputation: The reputation value data: Additional data to update (currently unused) """ - current_round = self._engine.get_round() + current_round = await self._engine.get_round() if nei not in self.reputation: self.reputation[nei] = { @@ -763,7 +763,7 @@ async def calculate_value_metrics(self, addr, nei, metrics_active=None): metrics_active (dict): The active metrics. """ try: - current_round = self._engine.get_round() + current_round = await self._engine.get_round() metrics_instance = self.connection_metrics.get(nei) if not metrics_instance: @@ -1518,7 +1518,7 @@ def _calculate_weighted_average_zero(self, key: tuple, current_round: int) -> fl ) return previous_avg * self.ZERO_VALUE_DECAY_FACTOR if previous_avg is not None else 0 - def save_reputation_history_in_memory(self, addr: str, nei: str, reputation: float) -> float: + async def save_reputation_history_in_memory(self, addr: str, nei: str, reputation: float) -> float: """ Save reputation history and calculate weighted average. @@ -1532,7 +1532,7 @@ def save_reputation_history_in_memory(self, addr: str, nei: str, reputation: flo """ try: key = (addr, nei) - current_round = self._engine.get_round() + current_round = await self._engine.get_round() if key not in self.reputation_history: self.reputation_history[key] = {} @@ -1619,7 +1619,7 @@ async def calculate_reputation(self, ae: AggregationEvent): return (updates, _, _) = await ae.get_event_data() - self._log_reputation_calculation_start() + await self._log_reputation_calculation_start() neighbors = set(await self._engine._cm.get_addrs_current_connections(only_direct=True)) @@ -1629,9 +1629,9 @@ async def calculate_reputation(self, ae: AggregationEvent): await self._process_feedback() await self._finalize_reputation_calculation(updates, neighbors) - def _log_reputation_calculation_start(self): + async def _log_reputation_calculation_start(self): """Log the start of reputation calculation with relevant information.""" - current_round = self._engine.get_round() + current_round = await self._engine.get_round() logging.info(f"Calculating reputation at round {current_round}") logging.info(f"Active metrics: {self._metrics}") logging.info(f"rejected nodes at round {current_round}: {self.rejected_nodes}") @@ -1646,11 +1646,11 @@ async def _process_neighbor_metrics(self, neighbors): ) if self._weighting_factor == "dynamic": - self._process_dynamic_metrics(nei, metrics) - elif self._weighting_factor == "static" and self._engine.get_round() >= 1: - self._process_static_metrics(nei, metrics) + await self._process_dynamic_metrics(nei, metrics) + elif self._weighting_factor == "static" and await self._engine.get_round() >= 1: + await self._process_static_metrics(nei, metrics) - def _process_dynamic_metrics(self, nei, metrics): + async def _process_dynamic_metrics(self, nei, metrics): """Process metrics for dynamic weighting factor.""" (metric_messages_number, metric_similarity, metric_fraction, metric_model_arrival_latency) = metrics @@ -1660,13 +1660,13 @@ def _process_dynamic_metrics(self, nei, metrics): metric_fraction, metric_model_arrival_latency, self.history_data, - self._engine.get_round(), + await self._engine.get_round(), self._addr, nei, self._metrics, ) - def _process_static_metrics(self, nei, metrics): + async def _process_static_metrics(self, nei, metrics): """Process metrics for static weighting factor.""" (metric_messages_number, metric_similarity, metric_fraction, metric_model_arrival_latency) = metrics @@ -1676,20 +1676,20 @@ def _process_static_metrics(self, nei, metrics): "fraction_parameters_changed": metric_fraction, "model_arrival_latency": metric_model_arrival_latency, } - self._calculate_static_reputation(self._addr, nei, metric_values_dict) + await self._calculate_static_reputation(self._addr, nei, metric_values_dict) async def _calculate_reputation_by_factor(self, neighbors): """Calculate reputation based on the weighting factor.""" - if self._weighting_factor == "dynamic" and self._engine.get_round() >= 1: + if self._weighting_factor == "dynamic" and await self._engine.get_round() >= 1: await self._calculate_dynamic_reputation(self._addr, neighbors) async def _handle_initial_reputation(self): """Handle reputation initialization for the first round.""" - if self._engine.get_round() < 1 and self._enabled: + if await self._engine.get_round() < 1 and self._enabled: federation = self._engine.config.participant["network_args"]["neighbors"].split() - self.init_reputation( + await self.init_reputation( federation_nodes=federation, - round_num=self._engine.get_round(), + round_num=await self._engine.get_round(), last_feedback_round=-1, init_reputation=self._initial_reputation, ) @@ -1697,7 +1697,7 @@ async def _handle_initial_reputation(self): async def _process_feedback(self): """Process and include feedback in reputation.""" status = await self.include_feedback_in_reputation() - current_round = self._engine.get_round() + current_round = await self._engine.get_round() if status: logging.info(f"Feedback included in reputation at round {current_round}") @@ -1707,7 +1707,7 @@ async def _process_feedback(self): async def _finalize_reputation_calculation(self, updates, neighbors): """Finalize reputation calculation by creating graphics and sending data.""" if self.reputation is not None: - self.create_graphic_reputation(self._addr, self._engine.get_round()) + self.create_graphic_reputation(self._addr, await self._engine.get_round()) await self.update_process_aggregation(updates) await self.send_reputation_to_neighbors(neighbors) @@ -1725,7 +1725,7 @@ async def send_reputation_to_neighbors(self, neighbors): "share", node_id=nei, score=float(data["reputation"]), - round=self._engine.get_round(), + round=await self._engine.get_round(), ) await self._engine.cm.send_message(neighbor, message) logging.info( @@ -1763,7 +1763,7 @@ async def update_process_aggregation(self, updates): if rn in updates: updates.pop(rn) - if self.engine.get_round() >= 1: + if await self.engine.get_round() >= 1: for nei in list(updates.keys()): if nei in self.reputation: rep = self.reputation[nei].get("reputation", 0) @@ -1823,7 +1823,7 @@ async def include_feedback_in_reputation(self): self.reputation[node_ip] = { "reputation": combined_reputation, - "round": self._engine.get_round(), + "round": await self._engine.get_round(), "last_feedback_round": round_num, } updated = True @@ -1848,22 +1848,22 @@ async def on_round_start(self, rse: RoundStartEvent): async def recollect_model_arrival_latency(self, ure: UpdateReceivedEvent): (decoded_model, weight, source, round_num, local) = await ure.get_event_data() - current_round = self._engine.get_round() + current_round = await self._engine.get_round() self.round_timing_info.setdefault(round_num, {}) if round_num == current_round: - self._process_current_round(round_num, source) + await self._process_current_round(round_num, source) elif round_num > current_round: self.round_timing_info[round_num]["pending_recalculation"] = True self.round_timing_info[round_num].setdefault("pending_sources", set()).add(source) logging.info(f"Model from future round {round_num} stored, pending recalculation.") else: - self._process_past_round(round_num, source) + await self._process_past_round(round_num, source) self._recalculate_pending_latencies(current_round) - def _process_current_round(self, round_num, source): + async def _process_current_round(self, round_num, source): """ Process models that arrive in the current round. """ @@ -1885,13 +1885,13 @@ def _process_current_round(self, round_num, source): source, self._addr, num_round=round_num, - current_round=self._engine.get_round(), + current_round=await self._engine.get_round(), latency=duration, ) else: logging.info(f"Start time not yet available for round {round_num}.") - def _process_past_round(self, round_num, source): + async def _process_past_round(self, round_num, source): """ Process models that arrive in past rounds. """ @@ -1913,7 +1913,7 @@ def _process_past_round(self, round_num, source): source, self._addr, num_round=round_num, - current_round=self._engine.get_round(), + current_round=await self._engine.get_round(), latency=duration, ) else: @@ -1978,7 +1978,7 @@ async def recollect_similarity(self, ure: UpdateReceivedEvent): "timestamp": datetime.now(), "nei": nei, "round": round_num, - "current_round": self._engine.get_round(), + "current_round": await self._engine.get_round(), **similarity_values } @@ -2037,7 +2037,7 @@ def _check_similarity_threshold(self, nei: str, cosine_value: float): async def recollect_number_message(self, source, message): """Record a number message from a source.""" - self._record_message_data(source) + await self._record_message_data(source) async def recollect_duplicated_number_message(self, dme: DuplicatedMessageEvent): """Record a duplicated message event.""" @@ -2046,9 +2046,9 @@ async def recollect_duplicated_number_message(self, dme: DuplicatedMessageEvent) source = event_data[0] else: source = event_data - self._record_message_data(source) + await self._record_message_data(source) - def _record_message_data(self, source: str): + async def _record_message_data(self, source: str): """Record message data for the given source if it's not the current address.""" if source != self._addr: current_time = time.time() @@ -2058,7 +2058,7 @@ def _record_message_data(self, source: str): source, self._addr, time=current_time, - current_round=self._engine.get_round(), + current_round=await self._engine.get_round(), ) async def recollect_fraction_of_parameters_changed(self, ure: UpdateReceivedEvent): @@ -2070,7 +2070,7 @@ async def recollect_fraction_of_parameters_changed(self, ure: UpdateReceivedEven """ (decoded_model, weight, source, round_num, local) = await ure.get_event_data() - current_round = self._engine.get_round() + current_round = await self._engine.get_round() parameters_local = self._engine.trainer.get_model_parameters() prev_threshold = self._get_previous_threshold(source, current_round) @@ -2156,4 +2156,4 @@ def _store_fraction_data(self, source: str, current_round: int, data: dict): if current_round not in self.fraction_of_params_changed[source]: self.fraction_of_params_changed[source][current_round] = [] - self.fraction_of_params_changed[source][current_round].append(data) + self.fraction_of_params_changed[source][current_round].append(data) \ No newline at end of file diff --git a/nebula/addons/trustworthiness/trustworthiness.py b/nebula/addons/trustworthiness/trustworthiness.py index 3132461b7..1eaa17c6a 100644 --- a/nebula/addons/trustworthiness/trustworthiness.py +++ b/nebula/addons/trustworthiness/trustworthiness.py @@ -2,7 +2,7 @@ from nebula.addons.functions import print_msg_box from nebula.core.nebulaevents import ExperimentFinishEvent, RoundEndEvent, TestMetricsEvent from nebula.core.eventmanager import EventManager -from nebula.core.role import Role +from nebula.core.noderole import Role, ServerRoleBehavior from abc import ABC, abstractmethod from nebula.config.config import Config from nebula.core.engine import Engine @@ -119,12 +119,13 @@ async def _process_experiment_finished_event(self, efe:ExperimentFinishEvent): class TrustWorkloadServer(TrustWorkload): - def __init__(self, engine, idx, trust_files_route): + def __init__(self, engine: Engine, idx, trust_files_route): self._workload = 'aggregation' self._sample_size = 0 self._current_loss = None self._current_accuracy = None - self._start_time = engine._start_time + server_start_time: ServerRoleBehavior = engine.rb + self._start_time = server_start_time._start_time self._engine: Engine = engine self._end_time = None self._experiment_name = "" @@ -216,7 +217,7 @@ class Trustworthiness(): def __init__(self, engine: Engine, config: Config): config.reset_logging_configuration() print_msg_box( - msg=f"Name Trustworthiness Module\nRole: {engine.role.value}", + msg=f"Name Trustworthiness Module\nRole: {engine.rb.get_role_name()}", indent=2, ) self._engine = engine @@ -225,7 +226,7 @@ def __init__(self, engine: Engine, config: Config): self._experiment_name = self._config.participant["scenario_args"]["name"] self._trust_dir_files = f"/nebula/app/logs/{self._experiment_name}/trustworthiness" self._emissions_file = 'emissions.csv' - self._role: Role = engine.role + self._role: Role = engine.rb.get_role() self._idx = self._config.participant["device_args"]["idx"] self._trust_workload: TrustWorkload = self._factory_trust_workload(self._role, self._engine, self._idx, self._trust_dir_files) diff --git a/nebula/config/config.py b/nebula/config/config.py index cae3cf7f8..5ef336e3a 100755 --- a/nebula/config/config.py +++ b/nebula/config/config.py @@ -55,7 +55,7 @@ def reset_logging_configuration(self): self.__set_default_logging(mode="a") self.__set_training_logging(mode="a") - + def shutdown_logging(self): """ Properly shuts down all loggers and their handlers in the system. diff --git a/nebula/controller/controller.py b/nebula/controller/controller.py index 9eebf4ed3..a00d142d1 100755 --- a/nebula/controller/controller.py +++ b/nebula/controller/controller.py @@ -367,7 +367,7 @@ async def stop_scenario( """ from nebula.controller.scenarios import ScenarioManagement - ScenarioManagement.stop_participants(scenario_name) + # ScenarioManagement.stop_participants(scenario_name) DockerUtils.remove_containers_by_prefix(f"{os.environ.get('NEBULA_CONTROLLER_NAME')}_{username}-participant") DockerUtils.remove_docker_network( f"{(os.environ.get('NEBULA_CONTROLLER_NAME'))}_{str(username).lower()}-nebula-net-scenario" diff --git a/nebula/controller/scenarios.py b/nebula/controller/scenarios.py index 8475a26ed..acf1e98e5 100644 --- a/nebula/controller/scenarios.py +++ b/nebula/controller/scenarios.py @@ -477,11 +477,13 @@ def validate_positive_int(value, name): raise ValueError("round_start_attack must be less than round_stop_attack") node_attack_params["attacks"] = node_att + nodes[node]["malicious"] = True nodes[node]["attack_params"] = node_attack_params + nodes[node]["fake_behavior"] = nodes[node]["role"] + nodes[node]["role"] = "malicious" else: nodes[node]["attack_params"] = {"attacks": "No Attack"} - nodes[node]["malicious"] = malicious nodes[node]["reputation"] = node_reputation logging.info( @@ -693,7 +695,8 @@ def __init__(self, scenario, user=None): participant_config["device_args"]["logging"] = self.scenario.logginglevel participant_config["aggregator_args"]["algorithm"] = self.scenario.agg_algorithm # To be sure that benign nodes have no attack parameters - if node_config["malicious"]: + if node_config["role"] == "malicious": + participant_config["adversarial_args"]["fake_behavior"] = node_config["fake_behavior"] participant_config["adversarial_args"]["attack_params"] = node_config["attack_params"] else: participant_config["adversarial_args"]["attack_params"] = {"attacks": "No Attack"} diff --git a/nebula/core/aggregation/aggregator.py b/nebula/core/aggregation/aggregator.py index 0648b5b42..ff88668de 100755 --- a/nebula/core/aggregation/aggregator.py +++ b/nebula/core/aggregation/aggregator.py @@ -48,7 +48,7 @@ def run_aggregation(self, models): return None async def init(self): - await self.us.init(self.config) + await self.us.init(self.engine.rb.get_role_name(True)) async def update_federation_nodes(self, federation_nodes: set): """ @@ -108,12 +108,7 @@ async def get_aggregation(self): TimeoutError: If the aggregation lock is not acquired within the defined timeout. asyncio.CancelledError: If the aggregation lock acquisition is cancelled. Exception: For any other unexpected errors during the aggregation process. - """ - # Check if learning cycle has finished to prevent blocking - if self.engine.learning_cycle_finished(): - logging.info("πŸ”„ get_aggregation | Learning cycle has finished, skipping aggregation") - return None - + """ try: timeout = self.config.participant["aggregator_args"]["aggregation_timeout"] logging.info(f"Aggregation timeout: {timeout} starts...") @@ -147,6 +142,10 @@ async def get_aggregation(self): await self.us.stop_notifying_updates() updates = await self.us.get_round_updates() + if not updates: + logging.info(f"πŸ”„ get_aggregation | No updates has been received..resolving conflict to continue...") + updates = {self._addr: await self.engine.resolve_missing_updates()} + missing_nodes = await self.us.get_round_missing_nodes() if missing_nodes: logging.info(f"πŸ”„ get_aggregation | Aggregation incomplete, missing models from: {missing_nodes}") diff --git a/nebula/core/aggregation/updatehandlers/cflupdatehandler.py b/nebula/core/aggregation/updatehandlers/cflupdatehandler.py index aa4648849..6e66203cb 100644 --- a/nebula/core/aggregation/updatehandlers/cflupdatehandler.py +++ b/nebula/core/aggregation/updatehandlers/cflupdatehandler.py @@ -88,7 +88,7 @@ async def init(self, config): Initializes the handler with the participant configuration, and subscribes to relevant node events. """ - self._role = config.participant["device_args"]["role"] + self._role = config await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.notify_federation_update) await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self.storage_update) diff --git a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py index fc722645a..b98cbaf98 100644 --- a/nebula/core/aggregation/updatehandlers/dflupdatehandler.py +++ b/nebula/core/aggregation/updatehandlers/dflupdatehandler.py @@ -227,7 +227,9 @@ async def notify_federation_update(self, updt_nei_event: UpdateNeighborEvent): else: if source not in self._sources_received: # Not received update from this source yet await self._update_source(source, remove=True) - await self._all_updates_received() # Verify if discarding node aggregation could be done + all_rec = await self._all_updates_received() # Verify if discarding node aggregation could be done + if all_rec: + await self._notify() else: logging.info(f"Already received update from: {source}, it will be discarded next round") diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 3ee3f5105..416637a77 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -7,6 +7,7 @@ import docker +from nebula.core.noderole import factory_role_behavior, change_role_behavior, Role, RoleBehavior from nebula.addons.functions import print_msg_box from nebula.addons.reporter import Reporter from nebula.addons.reputation.reputation import Reputation @@ -20,6 +21,8 @@ RoundStartEvent, UpdateNeighborEvent, UpdateReceivedEvent, + ExperimentFinishEvent, + ModelPropagationEvent, ) from nebula.core.network.communications import CommunicationsManager from nebula.core.role import Role, factory_node_role @@ -91,19 +94,12 @@ def __init__( self.ip = config.participant["network_args"]["ip"] self.port = config.participant["network_args"]["port"] self.addr = config.participant["network_args"]["addr"] - self.role = config.participant["device_args"]["role"] - self.role: Role = factory_node_role(self.role) + self.name = config.participant["device_args"]["name"] self.client = docker.from_env() print_banner() - print_msg_box( - msg=f"Name {self.name}\nRole: {self.role.value}", - indent=2, - title="Node information", - ) - self._trainer = None self._aggregator = None self.round = None @@ -121,6 +117,16 @@ def __init__( self._secure_neighbors = [] self._is_malicious = self.config.participant["adversarial_args"]["attack_params"]["attacks"] != "No Attack" + role = config.participant["device_args"]["role"] + self._role_behavior: RoleBehavior = factory_role_behavior(role, self, config) + self._role_behavior_performance_lock = Locker("role_behavior_performance_lock", async_lock=True) + + print_msg_box( + msg=f"Name {self.name}\nRole: {self._role_behavior.get_role_name()}", + indent=2, + title="Node information", + ) + msg = f"Trainer: {self._trainer.__class__.__name__}" msg += f"\nDataset: {self.config.participant['data_args']['dataset']}" msg += f"\nIID: {self.config.participant['data_args']['iid']}" @@ -138,6 +144,7 @@ def __init__( self.federation_setup_lock = Locker(name="federation_setup_lock", async_lock=True) self.federation_ready_lock = Locker(name="federation_ready_lock", async_lock=True) self.round_lock = Locker(name="round_lock", async_lock=True) + self._round_in_process_lock = Locker("round_in_process_lock", async_lock=True) self.config.reload_config_file() self._cm = CommunicationsManager(engine=self) @@ -180,6 +187,11 @@ def aggregator(self): def trainer(self): """Trainer""" return self._trainer + + @property + def rb(self): + """Role Behavior""" + return self._role_behavior @property def sa(self): @@ -209,8 +221,10 @@ def get_initialization_status(self): def set_initialization_status(self, status): self.initialized = status - def get_round(self): - return self.round + async def get_round(self): + async with self.round_lock: + current_round = self.round + return current_round def get_federation_ready_lock(self): return self.federation_ready_lock @@ -255,12 +269,6 @@ async def model_initialization_callback(self, source, message): async def model_update_callback(self, source, message): logging.info(f"πŸ€– handle_model_message | Received model update from {source} with round {message.round}") - - # Ignore model updates if the learning cycle has finished - if self.learning_cycle_finished(): - logging.info(f"πŸ€– Ignoring model update from {source} because learning cycle has finished") - return - if not self.get_federation_ready_lock().locked() and len(await self.get_federation_nodes()) == 0: logging.info("πŸ€– handle_model_message | There are no defined federation nodes") return @@ -309,29 +317,44 @@ async def _control_alive_callback(self, source, message): async def _control_leadership_transfer_callback(self, source, message): logging.info(f"πŸ”§ handle_control_message | Trigger | Received leadership transfer message from {source}") - if self.role == Role.AGGREGATOR: - neighbors = await self.cm.get_addrs_current_connections(myself=True) - if len(neighbors) > 1: - random_neighbor = random.choice(neighbors) - message = self.cm.create_message("control", "leadership_transfer") - await self.cm.send_message(random_neighbor, message) - logging.info( - f"πŸ”§ handle_control_message | Trigger | Leadership transfer message sent to {random_neighbor}" - ) - message = self.cm.create_message("control", "leadership_transfer_ack") - await self.cm.send_message(source, message) - logging.info(f"πŸ”§ handle_control_message | Trigger | Leadership transfer ack message sent to {source}") - else: - logging.info("πŸ”§ handle_control_message | Trigger | Only one neighbor found, I am the leader") + + if await self._round_in_process_lock.locked_async(): + logging.info("Learning cycle is executing, role behavior will be modified next round") + await self.rb.set_next_role(Role.AGGREGATOR, source_to_notificate=source) else: - self.role = Role.AGGREGATOR - logging.info("πŸ”§ handle_control_message | Trigger | I am now the leader") - message = self.cm.create_message("control", "leadership_transfer_ack") - await self.cm.send_message(source, message) - logging.info(f"πŸ”§ handle_control_message | Trigger | Leadership transfer ack message sent to {source}") + try: + logging.info("Trying to modify Role behavior") + lock_task = asyncio.create_task(self._round_in_process_lock.acquire_async()) + await asyncio.wait_for(lock_task, timeout=3) + self._role_behavior = change_role_behavior(self.rb, Role.AGGREGATOR, self, self.config) + await self.rb.set_next_role(Role.AGGREGATOR) + await self.update_self_role() + await self._round_in_process_lock.release_async() + except TimeoutError: + logging.info("Learning cycle is locked, role behavior will be modified next round") + await self.rb.set_next_role(Role.AGGREGATOR, source_to_notificate=source) async def _control_leadership_transfer_ack_callback(self, source, message): logging.info(f"πŸ”§ handle_control_message | Trigger | Received leadership transfer ack message from {source}") + # No concurrence of difference ack received treated, be aware of that. + if await self._round_in_process_lock.locked_async(): + logging.info("Learning cycle is executing, role behavior will be modified next round") + await self.rb.set_next_role(Role.TRAINER) + else: + try: + lock_task = asyncio.create_task(self._round_in_process_lock.acquire_async()) + await asyncio.wait_for(lock_task, timeout=3) + + logging.info("Role behavior could be executed...") + await self.rb.set_next_role(Role.TRAINER) + await self.update_self_role() + + await self._round_in_process_lock.release_async() + + except TimeoutError: + logging.info("Learning cycle is locked, role behavior will be modified next round") + await self.rb.set_next_role(Role.TRAINER) + async def _connection_connect_callback(self, source, message): logging.info(f"πŸ”— handle_connection_message | Trigger | Received connection message from {source}") @@ -356,20 +379,15 @@ async def _federation_federation_start_callback(self, source, message): async def _federation_federation_models_included_callback(self, source, message): logging.info(f"πŸ“ handle_federation_message | Trigger | Received aggregation finished message from {source}") - - # Ignore aggregation finished messages if the learning cycle has finished - if not self.learning_cycle_finished(): - logging.info(f"πŸ“ Ignoring aggregation finished message from {source} because learning cycle has finished") - return - + current_round = await self.get_round() try: await self.cm.get_connections_lock().acquire_async() - if self.round is not None and source in self.cm.connections: + if current_round is not None and source in self.cm.connections: try: if message is not None and len(message.arguments) > 0: self.cm.connections[source].update_round(int(message.arguments[0])) if message.round in [ - self.round - 1, - self.round, + current_round - 1, + current_round, ] else None except Exception as e: logging.exception(f"Error updating round in connection: {e}") @@ -478,14 +496,10 @@ async def broadcast_models_include(self, age: AggregationEvent): Sends: federation_models_included: A message containing the round number of the aggregation. """ - # Don't broadcast if the learning cycle has finished - if not self.learning_cycle_finished(): - logging.info("πŸ”„ Not broadcasting MODELS_INCLUDED because learning cycle has finished") - return - - logging.info(f"πŸ”„ Broadcasting MODELS_INCLUDED for round {self.get_round()}") + logging.info(f"πŸ”„ Broadcasting MODELS_INCLUDED for round {await self.get_round()}") + current_round = await self.get_round() message = self.cm.create_message( - "federation", "federation_models_included", [str(arg) for arg in [self.get_round()]] + "federation", "federation_models_included", [str(arg) for arg in [current_round]] ) asyncio.create_task(self.cm.send_message_to_neighbors(message)) @@ -696,7 +710,10 @@ async def _start_learning(self): await self.get_federation_ready_lock().acquire_async() if self.config.participant["device_args"]["start"]: logging.info("Propagate initial model updates.") - await self.cm.propagator.propagate("initialization") + + mpe = ModelPropagationEvent(await self.cm.get_addrs_current_connections(only_direct=True, myself=False), "initialization") + await EventManager.get_instance().publish_node_event(mpe) + await self.get_federation_ready_lock().release_async() self.trainer.set_epochs(epochs) @@ -741,12 +758,55 @@ def print_round_information(self): title="Round information", ) - def learning_cycle_finished(self): - if not self.round or not self.total_rounds: + async def learning_cycle_finished(self): + current_round = await self.get_round() + if not current_round or not self.total_rounds: return False else: - return self.round >= self.total_rounds + return current_round >= self.total_rounds + + async def resolve_missing_updates(self): + """ + Delegates the resolution strategy for missing updates to the current role behavior. + + This function is called when the node receives no model updates from neighbors + and needs to apply a fallback strategy depending on its role (e.g., using default weights + if aggregator, or local model if trainer). + Returns: + The result of the role-specific resolution strategy. + """ + logging.info(f"Using Role behavior: {self.rb.get_role_name()} conflict resolve strategy") + return await self.rb.resolve_missing_updates() + + async def update_self_role(self): + """ + Checks whether a role update is required and performs the transition if necessary. + + If a new role has been assigned (i.e., self.rb.update_role_needed() is True), + this function updates the role behavior accordingly and notifies the source + that initiated the role transfer, if applicable. + + It logs the role change and spawns an async task to send a control message + acknowledging the update to the initiating node. + + Raises: + Any exceptions from change_role_behavior or communication logic. + """ + if await self.rb.update_role_needed(): + logging.info("Starting Role Behavior modification...") + from_role = self.rb.get_role_name() + next_role = await self.rb.get_next_role() + source_to_notificate = await self.rb.get_source_to_notificate() + self._role_behavior: RoleBehavior = change_role_behavior(self.rb, next_role, self, self.config) + to_role = self.rb.get_role_name() + logging.info(f"Role behavior changing from: {from_role} to {to_role}") + self.config.participant["device_args"]["role"] = to_role + if source_to_notificate: + logging.info(f"Sending role modification ACK to transferer: {source_to_notificate}") + message = self.cm.create_message("control", "leadership_transfer_ack") + asyncio.create_task(self.cm.send_message(source_to_notificate, message)) + async def _learning_cycle(self): """ Main asynchronous loop for executing the Federated Learning process across multiple rounds. @@ -770,56 +830,71 @@ async def _learning_cycle(self): This function blocks (awaits) until the full FL process concludes. """ while self.round is not None and self.round < self.total_rounds: - current_time = time.time() - print_msg_box( - msg=f"Round {self.round} of {self.total_rounds - 1} started (max. {self.total_rounds} rounds)", - indent=2, - title="Round information", - ) - logging.info(f"Federation nodes: {self.federation_nodes}") - await self.update_federation_nodes( - await self.cm.get_addrs_current_connections(only_direct=True, myself=True) - ) - expected_nodes = await self.get_federation_nodes() - rse = RoundStartEvent(self.round, current_time, expected_nodes) - await EventManager.get_instance().publish_node_event(rse) - self.trainer.on_round_start() - logging.info(f"Expected nodes: {expected_nodes}") - direct_connections = await self.cm.get_addrs_current_connections(only_direct=True) - undirected_connections = await self.cm.get_addrs_current_connections(only_undirected=True) - logging.info(f"Direct connections: {direct_connections} | Undirected connections: {undirected_connections}") - logging.info(f"[Role {self.role.value}] Starting learning cycle...") - await self.aggregator.update_federation_nodes(expected_nodes) - await self._extended_learning_cycle() - - current_time = time.time() - ree = RoundEndEvent(self.round, current_time) - await EventManager.get_instance().publish_node_event(ree) - - await self.get_round_lock().acquire_async() + async with self._round_in_process_lock: + current_time = time.time() + print_msg_box( + msg=f"Round {self.round} of {self.total_rounds - 1} started (max. {self.total_rounds} rounds)", + indent=2, + title="Round information", + ) + + await self.update_self_role() + + logging.info(f"Federation nodes: {self.federation_nodes}") + await self.update_federation_nodes( + await self.cm.get_addrs_current_connections(only_direct=True, myself=True) + ) + expected_nodes = await self.rb.select_nodes_to_wait() + rse = RoundStartEvent(self.round, current_time, expected_nodes) + await EventManager.get_instance().publish_node_event(rse) + self.trainer.on_round_start() + logging.info(f"Expected nodes: {expected_nodes}") + direct_connections = await self.cm.get_addrs_current_connections(only_direct=True) + undirected_connections = await self.cm.get_addrs_current_connections(only_undirected=True) + + logging.info(f"Direct connections: {direct_connections} | Undirected connections: {undirected_connections}") + logging.info(f"[Role {self.rb.get_role_name()}] Starting learning cycle...") + + await self.aggregator.update_federation_nodes(expected_nodes) + async with self._role_behavior_performance_lock: + await self.rb.extended_learning_cycle() + + current_time = time.time() + ree = RoundEndEvent(self.round, current_time) + await EventManager.get_instance().publish_node_event(ree) - print_msg_box( - msg=f"Round {self.round} of {self.total_rounds - 1} finished (max. {self.total_rounds} rounds)", - indent=2, - title="Round information", - ) + await self.get_round_lock().acquire_async() - # await self.aggregator.reset() - self.trainer.on_round_end() - self.round += 1 - self.config.participant["federation_args"]["round"] = ( - self.round - ) # Set current round in config (send to the controller) - await self.get_round_lock().release_async() + print_msg_box( + msg=f"Round {self.round} of {self.total_rounds - 1} finished (max. {self.total_rounds} rounds)", + indent=2, + title="Round information", + ) + # await self.aggregator.reset() + self.trainer.on_round_end() + self.round += 1 + self.config.participant["federation_args"]["round"] = ( + self.round + ) # Set current round in config (send to the controller) + await self.get_round_lock().release_async() # End of the learning cycle self.trainer.on_learning_cycle_end() await self.trainer.test() - + + # Shutdown protocol + await self._shutdown_protocol() + + async def _shutdown_protocol(self): + logging.info("Starting graceful shutdown process...") + + # 1.- Publish Experiment Finish Event to the last update on modules + logging.info("Publishing Experiment Finish Event...") efe = ExperimentFinishEvent() await EventManager.get_instance().publish_node_event(efe) + # 2.- Log finish message print_msg_box( msg=f"FL process has been completed successfully (max. {self.total_rounds} rounds reached)", indent=2, @@ -841,13 +916,6 @@ async def _learning_cycle(self): await self.shutdown() return - async def _extended_learning_cycle(self): - """ - This method is called in each round of the learning cycle. It is used to extend the learning cycle with additional - functionalities. The method is called in the _learning_cycle method. - """ - pass - async def shutdown(self): logging.info("🚦 Engine shutdown initiated") diff --git a/nebula/core/eventmanager.py b/nebula/core/eventmanager.py index 561f58e00..00df4843a 100755 --- a/nebula/core/eventmanager.py +++ b/nebula/core/eventmanager.py @@ -32,7 +32,7 @@ def _initialize(self, verbose=False): self._node_events_lock = Locker("node_events_lock", async_lock=True) self._global_message_subscribers: list[Callable] = [] self._global_message_subscribers_lock = Locker("global_message_subscribers_lock", async_lock=True) - self._verbose = verbose + self._verbose = False self._initialized = True @staticmethod diff --git a/nebula/core/nebulaevents.py b/nebula/core/nebulaevents.py index 8e4173b6f..f2ec08835 100644 --- a/nebula/core/nebulaevents.py +++ b/nebula/core/nebulaevents.py @@ -264,6 +264,35 @@ async def get_event_data(self) -> tuple[str, bool]: async def is_concurrent(self) -> bool: return True + +class ModelPropagationEvent(NodeEvent): + def __init__(self, eligible_neighbors, strategy): + """Event triggered when model propagation is ready. + + Args: + eligible_neighbors (set): The elegible neighbors to propagate model. + strategy (str): Strategy to propagete the model + """ + self.eligible_neighbors = eligible_neighbors + self._strategy = strategy + + def __str__(self): + return f"Model propagation event, strategy: {self._strategy}" + + async def get_event_data(self) -> tuple[set, str]: + """ + Retrieves the event data. + + Returns: + tuple[set, str]: A tuple containing: + - The elegible neighbors to propagate model. + - The propagation strategy. + """ + return (self.eligible_neighbors, self._strategy) + + async def is_concurrent(self) -> bool: + return False + class UpdateReceivedEvent(NodeEvent): diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index 514675f7d..e0b1c17a5 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -495,14 +495,14 @@ def get_addr(self): """ return self.addr - def get_round(self): + async def get_round(self): """ Retrieves the current training round number from the engine. Returns: int: The current round number in the federated learning process. """ - return self.engine.get_round() + return await self.engine.get_round() async def start(self): """ @@ -555,7 +555,7 @@ async def process_connection(reader, writer, priority="medium"): addr = writer.get_extra_info("peername") # Check if learning cycle has finished - reject new connections - if self.engine.learning_cycle_finished(): + if await self.engine.learning_cycle_finished(): logging.info(f"πŸ”— [incoming] Rejecting connection from {addr} because learning cycle has finished") writer.write(b"CONNECTION//CLOSE\n") await writer.drain() @@ -801,7 +801,7 @@ async def deploy_additional_services(self): """ logging.info("🌐 Deploying additional services...") await self._forwarder.start() - self._propagator.start() + await self._propagator.start() async def include_received_message_hash(self, hash_message, source): """ @@ -900,7 +900,7 @@ async def establish_connection(self, addr, direct=True, reconnect=False, priorit bool: True if the connection action (new or upgrade) succeeded, False otherwise. """ # Check if learning cycle has finished - don't establish new connections - if self.engine.learning_cycle_finished(): + if await self.engine.learning_cycle_finished(): logging.info(f"πŸ”— [outgoing] Not establishing connection to {addr} because learning cycle has finished") return False @@ -1129,7 +1129,8 @@ async def disconnect(self, dest_addr, mutual_disconnection=True, forced=False): # Get the connection under lock to prevent race conditions async with self.connections_lock: - if dest_addr not in self.connections: + connection_to_remove = self.connections.get(dest_addr) + if not connection_to_remove: logging.info(f"Connection {dest_addr} not found") return conn = self.connections[dest_addr] @@ -1227,8 +1228,8 @@ async def get_addrs_current_connections(self, only_direct=False, only_undirected def get_ready_connections(self): return {addr for addr, conn in self.connections.items() if conn.get_ready()} - def learning_finished(self): - return self.engine.learning_cycle_finished() + async def learning_finished(self): + return await self.engine.learning_cycle_finished() def __str__(self): return f"Connections: {[str(conn) for conn in self.connections.values()]}" diff --git a/nebula/core/network/connection.py b/nebula/core/network/connection.py index 39c105656..6ba60749b 100755 --- a/nebula/core/network/connection.py +++ b/nebula/core/network/connection.py @@ -98,6 +98,7 @@ def __init__( self._inactivity = False self._last_activity = time.time() self._activity_lock = Locker(name="activity_lock", async_lock=True) + self._activity_task = None self.EOT_CHAR = b"\x00\x00\x00\x04" self.COMPRESSION_CHAR = b"\x00\x00\x00\x01" @@ -293,11 +294,11 @@ async def reconnect(self, max_retries: int = 5, delay: int = 5) -> None: delay (int): Delay in seconds between reconnection attempts. Defaults to 5. """ if self.forced_disconnection or not self.direct: - logging.info("Not going to reconnect because this connection is not direct") + logging.info(f"Not going to reconnect because: (forced: {self.forced_disconnection}, direct: {self.direct})") return # Check if learning cycle has finished - don't reconnect - if self.cm.learning_finished(): + if await self.cm.learning_finished(): logging.info(f"Not attempting reconnection to {self.addr} because learning cycle has finished") return @@ -358,7 +359,7 @@ async def send( return # Check if learning cycle has finished - don't send messages - if self.cm.learning_finished(): + if await self.cm.learning_finished(): logging.info(f"Not sending message to {self.addr} because learning cycle has finished") return @@ -378,9 +379,9 @@ async def send( await self._send_chunks(message_id, data_to_send) except Exception as e: logging.exception(f"Error sending data: {e}") - if self.direct and not self.cm.learning_finished(): + if self.direct and not await self.cm.learning_finished(): await self.reconnect() - elif self.cm.learning_finished(): + elif await self.cm.learning_finished(): logging.info(f"Not attempting reconnection to {self.addr} because learning cycle has finished") def _prepare_data(self, data: Any, pb: bool, encoding_type: str) -> tuple[bytes, bytes]: @@ -504,9 +505,10 @@ async def handle_incoming_message(self) -> None: except Exception as e: logging.exception(f"Error handling incoming message: {e}") finally: - if self.direct or self._prio == ConnectionPriority.HIGH: + if self.direct or self._prio == ConnectionPriority.HIGH: #and not await self.cm.learning_finished(): + logging.info("ERROR: handling incoming message. Trying to reconnect..") await self.reconnect() - elif self.cm.learning_finished(): + elif await self.cm.learning_finished(): logging.info(f"Not attempting reconnection to {self.addr} because learning cycle has finished") async def _read_exactly(self, num_bytes: int, max_retries: int = 3) -> bytes: diff --git a/nebula/core/network/propagator.py b/nebula/core/network/propagator.py index 2f37c227b..717ea5f94 100755 --- a/nebula/core/network/propagator.py +++ b/nebula/core/network/propagator.py @@ -3,6 +3,8 @@ import sys from abc import ABC, abstractmethod from collections import deque +from nebula.core.nebulaevents import ModelPropagationEvent +from nebula.core.eventmanager import EventManager from typing import TYPE_CHECKING, Any from nebula.addons.functions import print_msg_box @@ -23,7 +25,7 @@ class PropagationStrategy(ABC): """ @abstractmethod - def is_node_eligible(self, node: str) -> bool: + async def is_node_eligible(self, node: str) -> bool: """ Determine whether a given node should receive the model payload. @@ -68,16 +70,16 @@ def __init__(self, aggregator: "Aggregator", trainer: "Lightning", engine: "Engi self.trainer = trainer self.engine = engine - def get_round(self): + async def get_round(self): """ Get the current training round number from the engine. Returns: int: The current round index. """ - return self.engine.get_round() + return await self.engine.get_round() - def is_node_eligible(self, node: str) -> bool: + async def is_node_eligible(self, node: str) -> bool: """ Determine if a node has not yet received the initial model. @@ -124,16 +126,16 @@ def __init__(self, aggregator: "Aggregator", trainer: "Lightning", engine: "Engi self.engine = engine self.addr = self.engine.get_addr() - def get_round(self): + async def get_round(self): """ Get the current training round number from the engine. Returns: int: The current round index. """ - return self.engine.get_round() + return await self.engine.get_round() - def is_node_eligible(self, node: str) -> bool: + async def is_node_eligible(self, node: str) -> bool: """ Determine if a node requires a model update based on aggregation state. @@ -145,7 +147,7 @@ def is_node_eligible(self, node: str) -> bool: is less than the current round. """ return (node not in self.aggregator.get_nodes_pending_models_to_aggregate()) or ( - self.engine.cm.connections[node].get_federated_round() < self.get_round() + self.engine.cm.connections[node].get_federated_round() < await self.get_round() ) def prepare_model_payload(self, node: str) -> tuple[Any, float] | None: @@ -195,7 +197,7 @@ def cm(self): else: return self._cm - def start(self): + async def start(self): """ Initialize the Propagator by retrieving core components and configuration, setting up propagation intervals, history buffer, and strategy instances. @@ -203,6 +205,7 @@ def start(self): This method must be called before any propagation cycles to ensure that all dependencies (engine, trainer, aggregator, etc.) are available. """ + await EventManager.get_instance().subscribe_node_event(ModelPropagationEvent, self._propagate) self.engine: Engine = self.cm.engine self.config: Config = self.cm.get_config() self.addr = self.cm.get_addr() @@ -228,14 +231,14 @@ def start(self): ) self._running.set() - def get_round(self): + async def get_round(self): """ Retrieve the current federated learning round number. Returns: int: The current round index from the engine. """ - return self.engine.get_round() + return await self.engine.get_round() def update_and_check_neighbors(self, strategy, eligible_neighbors): """ @@ -285,7 +288,7 @@ def reset_status_history(self): """ self.status_history.clear() - async def propagate(self, strategy_id: str): + async def _propagate(self, mpe: ModelPropagationEvent): """ Execute a single propagation cycle using the specified strategy. @@ -304,21 +307,23 @@ async def propagate(self, strategy_id: str): Returns: bool: True if propagation occurred (payload sent), False if halted early. """ + eligible_neighbors, strategy_id = await mpe.get_event_data() + self.reset_status_history() if strategy_id not in self.strategies: logging.info(f"Strategy {strategy_id} not found.") return False - if self.get_round() is None: + if await self.get_round() is None: logging.info("Propagation halted: round is not set.") return False strategy = self.strategies[strategy_id] logging.info(f"Starting model propagation with strategy: {strategy_id}") - current_connections = await self.cm.get_addrs_current_connections(only_direct=True) - eligible_neighbors = [ - neighbor_addr for neighbor_addr in current_connections if strategy.is_node_eligible(neighbor_addr) - ] + # current_connections = await self.cm.get_addrs_current_connections(only_direct=True) + # eligible_neighbors = [ + # neighbor_addr for neighbor_addr in current_connections if await strategy.is_node_eligible(neighbor_addr) + # ] logging.info(f"Eligible neighbors for model propagation: {eligible_neighbors}") if not eligible_neighbors: logging.info("Propagation complete: No eligible neighbors.") @@ -337,12 +342,13 @@ async def propagate(self, strategy_id: str): else: serialized_model = None - round_number = -1 if strategy_id == "initialization" else self.get_round() + current_round = await self.get_round() + round_number = -1 if strategy_id == "initialization" else current_round parameters = serialized_model message = self.cm.create_message("model", "", round_number, parameters, weight) for neighbor_addr in eligible_neighbors: logging.info( - f"Sending model to {neighbor_addr} with round {self.get_round()}: weight={weight} |Β size={sys.getsizeof(serialized_model) / (1024** 2) if serialized_model is not None else 0} MB" + f"Sending model to {neighbor_addr} with round {await self.get_round()}: weight={weight} |Β size={sys.getsizeof(serialized_model) / (1024** 2) if serialized_model is not None else 0} MB" ) asyncio.create_task(self.cm.send_message(neighbor_addr, message, "model")) # asyncio.create_task(self.cm.send_model(neighbor_addr, round_number, serialized_model, weight)) @@ -371,7 +377,7 @@ async def get_model_information(self, dest_addr, strategy_id: str, init=False): if strategy_id not in self.strategies: logging.info(f"Strategy {strategy_id} not found.") return None - if self.get_round() is None: + if await self.get_round() is None: logging.info("Propagation halted: round is not set.") return None @@ -385,7 +391,7 @@ async def get_model_information(self, dest_addr, strategy_id: str, init=False): serialized_model = ( model_params if isinstance(model_params, bytes) else self.trainer.serialize_model(model_params) ) - return (serialized_model, rounds, self.get_round()) + return (serialized_model, rounds, await self.get_round()) return None diff --git a/nebula/core/node.py b/nebula/core/node.py index 267e1b28b..86a73cc2a 100755 --- a/nebula/core/node.py +++ b/nebula/core/node.py @@ -3,6 +3,8 @@ import random import sys import warnings +import socket +import docker import torch @@ -38,8 +40,7 @@ from nebula.core.models.fashionmnist.mlp import FashionMNISTModelMLP from nebula.core.models.mnist.cnn import MNISTModelCNN from nebula.core.models.mnist.mlp import MNISTModelMLP -from nebula.core.noderole import AggregatorNode, IdleNode, MaliciousNode, ServerNode, TrainerNode -from nebula.core.role import Role +from nebula.core.engine import Engine from nebula.core.training.lightning import Lightning from nebula.core.training.siamese import Siamese @@ -48,7 +49,7 @@ # os.environ["TORCHDYNAMO_VERBOSE"] = "1" -async def main(config): +async def main(config: Config): """ Main function to start the NEBULA node. @@ -175,20 +176,6 @@ async def main(config): else: raise ValueError(f"Trainer {trainer_str} not supported") - if config.participant["device_args"]["malicious"]: - node_cls = MaliciousNode - else: - if config.participant["device_args"]["role"] == Role.AGGREGATOR.value: - node_cls = AggregatorNode - elif config.participant["device_args"]["role"] == Role.TRAINER.value: - node_cls = TrainerNode - elif config.participant["device_args"]["role"] == Role.SERVER.value: - node_cls = ServerNode - elif config.participant["device_args"]["role"] == Role.IDLE.value: - node_cls = IdleNode - else: - raise ValueError(f"Role {config.participant['device_args']['role']} not supported") - VARIABILITY = 0.5 def randomize_value(value, variability): @@ -213,9 +200,10 @@ def randomize_value(value, variability): value = value[key] value[keys[-1]] = randomize_value(value[keys[-1]], VARIABILITY) - logging.info(f"Starting node {idx} with model {model_name}, trainer {trainer.__name__}, and as {node_cls.__name__}") + role = config.participant["device_args"]["role"] + logging.info(f"Starting node {idx} with model {model_name}, trainer {trainer.__name__}, and as {role}") - node = node_cls( + node = Engine( model=model, datamodule=datamodule, config=config, diff --git a/nebula/core/noderole.py b/nebula/core/noderole.py index 12283f9a2..9bd258fef 100644 --- a/nebula/core/noderole.py +++ b/nebula/core/noderole.py @@ -1,298 +1,473 @@ +from __future__ import annotations import logging +import asyncio from nebula.addons.attacks.attacks import create_attack +from nebula.addons.functions import print_msg_box from nebula.config.config import Config -from nebula.core.engine import Engine +from nebula.core.utils.locker import Locker from nebula.core.eventmanager import EventManager -from nebula.core.nebulaevents import UpdateReceivedEvent -from nebula.core.training.lightning import Lightning - +from nebula.core.nebulaevents import UpdateReceivedEvent, ModelPropagationEvent +import random from enum import Enum +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from nebula.core.engine import Engine -class MaliciousNode(Engine): - """ - Specialized Engine subclass representing a malicious participant in the Federated Learning scenario. +#TODO ensure attacks works properly - This node behaves similarly to a standard node but is designed to simulate adversarial or faulty behavior - within the federation. It can be used for testing the robustness of the FL protocol, defense mechanisms, - and detection strategies. +""" ############################## + # ROLE BEHAVIORS # + ############################## +""" - Inherits from: - Engine: The base class that defines the main control flow of the Federated Learning process. +class Role(Enum): + """ + This class defines the participant roles of the platform. + """ + TRAINER = "trainer" + AGGREGATOR = "aggregator" + TRAINER_AGGREGATOR = "trainer_aggregator" + PROXY = "proxy" + IDLE = "idle" + SERVER = "server" + MALICIOUS = "malicious" + +def factory_node_role(role: str) -> Role: + if role == "trainer": + return Role.TRAINER + elif role == "aggregator": + return Role.AGGREGATOR + elif role =="trainer_aggregator": + return Role.TRAINER_AGGREGATOR + elif role == "proxy": + return Role.PROXY + elif role == "idle": + return Role.IDLE + elif role == "server": + return Role.SERVER + elif role == "malicious": + return Role.MALICIOUS + else: + return "" + +class RoleBehavior(ABC): + """ + Abstract base class for defining the role-specific behavior of a node in CFL, DFL, or SDFL systems. - Typical malicious behaviors may include (depending on the scenario configuration): - - Sending incorrect or poisoned model updates. - - Dropping or delaying messages. - - Attempting to manipulate the reputation or aggregation process. - - Participating inconsistently to mimic byzantine or selfish nodes. + Each subclass encapsulates the logic needed for a particular node role (e.g., trainer, aggregator), + providing custom implementations for role-related operations such as training cycles, + update aggregation, and recovery strategies. Attributes: - Inherits all attributes from the base Engine class, but may override key methods related to - training, aggregation, message handling, or reporting. - - Note: - The behavior of this class is driven by scenario configuration parameters and any overridden methods - implementing specific attack strategies. - - Usage: - This class should be instantiated and used in place of the normal Engine to simulate a malicious node. - It integrates seamlessly into the existing federation infrastructure. + _next_role (Role): The role to which the node is expected to transition. + _next_role_locker (Locker): An asynchronous lock to protect access to _next_role. + _source_to_notificate (Optional[Any]): The source node to notify once a role change is applied. """ - def __init__( - self, - model, - datamodule, - config=Config, - trainer=Lightning, - security=False, - ): - super().__init__( - model, - datamodule, - config, - trainer, - security, + def __init__(self): + self._next_role: Role = None + self._next_role_locker = Locker("next_role_locker", async_lock=True) + self._source_to_notificate = None + + @abstractmethod + def get_role(self): + """ + Returns the Role enum value representing the current role of the node. + """ + raise NotImplementedError + + @abstractmethod + def get_role_name(self, effective=False): + """ + Returns a string representation of the current role. + + Args: + effective (bool): Whether to return the name of the current effective role when going as malicious. + + Returns: + str: Name of the role. + """ + raise NotImplementedError + + @abstractmethod + async def extended_learning_cycle(self): + """ + Performs the main learning or aggregation cycle associated with the current role. + + This method encapsulates all the logic tied to the behavior of the node in its current role, + including training, aggregating updates, and coordinating with neighbors. + """ + raise NotImplementedError + + @abstractmethod + async def select_nodes_to_wait(self): + """ + Determines which neighbors the node should wait for during the current cycle. + + This logic varies depending on whether the node is an aggregator, trainer, or other role. + + Returns: + Set[Any]: A set of neighbor node identifiers to wait for. + """ + raise NotImplementedError + + @abstractmethod + async def resolve_missing_updates(self): + """ + Defines the fallback strategy when expected model updates are not received. + + For example, an aggregator might default to a fresh model, while a trainer might proceed + with its own local model. + + Returns: + Any: The resolution outcome depending on the role's specific logic. + """ + raise NotImplementedError + + async def set_next_role(self, role: Role, source_to_notificate = None): + """ + Schedules a role change and optionally stores the source to notify upon completion. + + Args: + role (Role): The new role to transition to. + source_to_notificate (Optional[Any]): Identifier of the node that triggered the change. + """ + async with self._next_role_locker: + self._next_role = role + self._source_to_notificate = source_to_notificate + + async def get_next_role(self) -> Role: + """ + Retrieves and clears the next role value. + + Returns: + Role: The next role to transition into. + """ + async with self._next_role_locker: + next_role = self._next_role + self._next_role = None + return next_role + + async def get_source_to_notificate(self): + """ + Retrieves and clears the stored source to notify after a role change. + + Returns: + Any: The source node identifier, or None if not set. + """ + async with self._next_role_locker: + source_to_notificate = self._source_to_notificate + self._source_to_notificate = None + return source_to_notificate + + async def update_role_needed(self): + """ + Checks whether a role update is scheduled. + + Returns: + bool: True if a role update is pending, False otherwise. + """ + async with self._next_role_locker: + updt_needed = self._next_role != None + return updt_needed + +""" ############################## + # MALICIOUS BEHAVIOR # + ############################## +""" + +class MaliciousRoleBehavior(RoleBehavior): + def __init__(self, engine: Engine, config: Config): + super().__init__() + print_msg_box( + msg=f"Role Behavior Malicious initialization", + indent=2, + title="Role initialization", ) - self.attack = create_attack(self) - self.aggregator_bening = self._aggregator - - async def _extended_learning_cycle(self): + self._engine = engine + self._config = config + logging.info("Creating attack behavior...") + self.attack = create_attack(self._engine) + logging.info("Attack behavior created") + self.aggregator_bening = self._engine._aggregator + benign_role = self._config.participant["adversarial_args"]["fake_behavior"] + self._fake_role_behavior = factory_role_behavior(benign_role, self._engine, self._config) + self._role = factory_node_role("malicious") + + def get_role(self): + return self._role + + def get_role_name(self, effective=False): + if effective: + return self._fake_role_behavior.get_role_name() + return f"{self._role.value} as {self._fake_role_behavior.get_role_name()}" + + async def extended_learning_cycle(self): try: await self.attack.attack() except Exception: - attack_name = self.config.participant["adversarial_args"]["attack_params"]["attacks"] + attack_name = self._config.participant["adversarial_args"]["attacks"] logging.exception(f"Attack {attack_name} failed") - - if self.role.value == "aggregator": - await AggregatorNode._extended_learning_cycle(self) - if self.role.value == "trainer": - await TrainerNode._extended_learning_cycle(self) - if self.role.value == "server": - await ServerNode._extended_learning_cycle(self) - - -class AggregatorNode(Engine): - """ - Node in the Federated Learning system with full training capabilities and additional responsibilities - as an aggregator within the federation. - - This class extends `Engine`, inheriting the full Federated Learning pipeline, including: - - Local model training - - Communication and model sharing with neighboring nodes - - Participation in the aggregation process - - Additional Role: - AggregatorNode is distinguished by its responsibility to **perform model aggregation** from - other participants in its neighborhood or federation scope. This may include: - - Collecting local model updates from neighbors - - Applying aggregation functions (e.g., weighted averaging) - - Updating and distributing the aggregated model - - Managing round synchronization where necessary - - Use Cases: - - Decentralized or partially decentralized federations where aggregation is distributed - - Scenarios with multiple aggregators to increase resilience and scalability - - Hybrid setups with rotating or dynamically elected aggregators - - Attributes: - Inherits all attributes and methods from the `Engine` class. Aggregator-specific behaviors are - typically handled via the `Aggregator` component and configuration parameters. - - Note: - While this node performs aggregation, it also fully participates in trainingβ€”its role is dual: - **trainer and aggregator**, which makes it a powerful actor in the federation topology. - """ - def __init__( - self, - model, - datamodule, - config=Config, - trainer=Lightning, - security=False, - ): - super().__init__( - model, - datamodule, - config, - trainer, - security, - ) - - async def _extended_learning_cycle(self): - # Define the functionality of the aggregator node - await self.trainer.test() - await self.trainning_in_progress_lock.acquire_async() - await self.trainer.train() - await self.trainning_in_progress_lock.release_async() + + await self._fake_role_behavior.extended_learning_cycle() + + async def select_nodes_to_wait(self): + nodes = await self._fake_role_behavior.select_nodes_to_wait() + return nodes + + async def resolve_missing_updates(self): + return await self._fake_role_behavior.resolve_missing_updates() + +""" ############################### + # TRAINER AGGREGATOR BEHAVIOR # + ############################### +""" + +class TrainerAggregatorRoleBehavior(RoleBehavior): + def __init__(self, engine: Engine, config: Config): + super().__init__() + self._engine = engine + self._config = config + self._role = factory_node_role("trainer_aggregator") + + def get_role(self): + return self._role + + def get_role_name(self, effective=False): + return self._role.value + + async def extended_learning_cycle(self): + await self._engine.trainer.test() + await self._engine.trainning_in_progress_lock.acquire_async() + await self._engine.trainer.train() + await self._engine.trainning_in_progress_lock.release_async() self_update_event = UpdateReceivedEvent( - self.trainer.get_model_parameters(), self.trainer.get_model_weight(), self.addr, self.round + self._engine.trainer.get_model_parameters(), self._engine.trainer.get_model_weight(), self._engine.addr, self._engine.round ) await EventManager.get_instance().publish_node_event(self_update_event) - await self.cm.propagator.propagate("stable") - await self._waiting_model_updates() - - -class ServerNode(Engine): - """ - Server node extending the Engine class to manage the federation from a centralized perspective. - - This node does NOT perform local model training. Instead, it: - - Tests the aggregated global model. - - Performs model aggregation from participant updates. - - Propagates the aggregated global model to participant nodes. - - Main functionalities: - - Coordinating the aggregation of models received from participant nodes. - - Evaluating the aggregated global model to monitor performance. - - Disseminating the updated global model back to the federation. - - Managing communication and synchronization signals within the federation. - - Typical use cases: - - Centralized federated learning setups where training happens at participant nodes. - - Server node acts as the aggregator and evaluator of global model. - - Ensures the integrity and progress of the federated learning process by managing rounds and updates. - - Attributes: - Inherits all attributes and methods from `Engine` with specialized logic for aggregation, - evaluation, and propagation of the global model. - - Note: - The ServerNode does not execute training itself but relies on receiving model updates from - participant nodes for aggregation. - """ + mpe = ModelPropagationEvent(await self._engine.cm.get_addrs_current_connections(only_direct=True, myself=False), "stable") + await EventManager.get_instance().publish_node_event(mpe) + + await self._engine._waiting_model_updates() + + async def select_nodes_to_wait(self): + nodes = await self._engine.cm.get_addrs_current_connections(only_direct=True, myself=True) + return nodes + + async def resolve_missing_updates(self): + return {} + +""" ############################## + # AGGREGATOR BEHAVIOR # + ############################## +""" + +class AggregatorRoleBehavior(RoleBehavior): + def __init__(self, engine: Engine, config: Config): + super().__init__() + self._engine = engine + self._config = config + self._role = factory_node_role("aggregator") + self._transfer_send = False + + def get_role(self): + return self._role + + def get_role_name(self, effective=False): + return self._role.value + async def extended_learning_cycle(self): + await self._engine.trainer.test() + + await self._engine._waiting_model_updates() + + mpe = ModelPropagationEvent(await self._engine.cm.get_addrs_current_connections(only_direct=True, myself=False), "stable") + await EventManager.get_instance().publish_node_event(mpe) + + # Transfer leadership + neighbors = await self._engine.cm.get_addrs_current_connections(myself=False) + if len(neighbors) and not self._transfer_send: + random_neighbor = random.choice(list(neighbors)) + lt_message = self._engine.cm.create_message("control", "leadership_transfer") + logging.info(f"Sending transfer leadership to: {random_neighbor}") + asyncio.create_task(self._engine.cm.send_message(random_neighbor, lt_message)) + self._transfer_send = True + + async def select_nodes_to_wait(self): + nodes = await self._engine.cm.get_addrs_current_connections(only_direct=True, myself=False) + return nodes + + async def resolve_missing_updates(self): + return (self._engine.trainer.get_model_parameters(), self._engine.trainer.BYPASS_MODEL_WEIGHT) + +""" ############################## + # SERVER BEHAVIOR # + ############################## +""" + +class ServerRoleBehavior(RoleBehavior): from datetime import datetime - def __init__( - self, - model, - datamodule, - config=Config, - trainer=Lightning, - security=False, - ): - super().__init__( - model, - datamodule, - config, - trainer, - security, - ) - self._start_time = ServerNode.datetime.now().strftime("%d/%m/%Y %H:%M:%S") - - async def _extended_learning_cycle(self): - # Define the functionality of the server node - await self.trainer.test() - - self_update_event = UpdateReceivedEvent( - self.trainer.get_model_parameters(), self.trainer.BYPASS_MODEL_WEIGHT, self.addr, self.round - ) - await EventManager.get_instance().publish_node_event(self_update_event) - - await self._waiting_model_updates() - await self.cm.propagator.propagate("stable") - - -class TrainerNode(Engine): - """ - Trainer node extending the Engine class responsible exclusively for local training and model propagation. - - This node: - - Performs local model training using its own data. - - Propagates the locally trained model updates to aggregator or server nodes. + def __init__(self, engine: Engine, config: Config): + super().__init__() + self._engine = engine + self._config = config + self._start_time = ServerRoleBehavior.datetime.now().strftime("%d/%m/%Y %H:%M:%S") + self._role = factory_node_role("server") + + def get_role(self): + return self._role + + def get_role_name(self, effective=False): + return self._role.value + + async def extended_learning_cycle(self): + await self._engine.trainer.test() + + await self._engine._waiting_model_updates() + + mpe = ModelPropagationEvent(await self._engine.cm.get_addrs_current_connections(only_direct=True, myself=False), "stable") + await EventManager.get_instance().publish_node_event(mpe) + + async def select_nodes_to_wait(self): + nodes = await self._engine.cm.get_addrs_current_connections(only_direct=True, myself=False) + return nodes - It does NOT perform model aggregation. + async def resolve_missing_updates(self): + return (self._engine.trainer.get_model_parameters(), self._engine.trainer.BYPASS_MODEL_WEIGHT) + +""" ############################## + # TRAINER BEHAVIOR # + ############################## +""" + +class TrainerRoleBehavior(RoleBehavior): + def __init__(self, engine: Engine, config: Config): + super().__init__() + self._engine = engine + self._config = config + self._role = factory_node_role("trainer") + + def get_role(self): + return self._role + + def get_role_name(self, effective=False): + return self._role.value + + async def extended_learning_cycle(self): + logging.info("Waiting global update | Assign _waiting_global_update = True") - Main functionalities: - - Training the model locally according to the federated learning protocol. - - Sending updated model parameters to aggregator nodes or server. - - Managing communication related to local training progress and updates. + await self._engine.trainer.test() + await self._engine.trainer.train() + + mpe = ModelPropagationEvent(await self._engine.cm.get_addrs_current_connections(only_direct=True, myself=False), "stable") + await EventManager.get_instance().publish_node_event(mpe) + + await self._engine._waiting_model_updates() + + async def select_nodes_to_wait(self): + nodes = await self._engine.cm.get_addrs_current_connections(only_direct=True, myself=False) + return nodes - Typical use cases: - - Participant nodes in federated learning that contribute local updates. - - Nodes focusing solely on improving their local model and sharing updates. - - Attributes: - Inherits all attributes and methods from `Engine` but change behavior to exclude aggregation steps. - - Note: - Aggregation responsibilities are delegated to other nodes (e.g., ServerNode or AggregatorNode). - """ - def __init__( - self, - model, - datamodule, - config=Config, - trainer=Lightning, - security=False, - ): - super().__init__( - model, - datamodule, - config, - trainer, - security, - ) - - async def _extended_learning_cycle(self): - # Define the functionality of the trainer node + async def resolve_missing_updates(self): + return (self._engine.trainer.get_model_parameters(), self._engine.trainer.get_model_weight()) + +""" ############################## + # IDLE BEHAVIOR # + ############################## +""" + +class IdleRoleBehavior(RoleBehavior): + def __init__(self, engine: Engine, config: Config): + super().__init__() + self._engine = engine + self._config = config + self._role = factory_node_role("idle") + + def get_role(self): + return self._role + + def get_role_name(self, effective=False): + return self._role.value + + async def extended_learning_cycle(self): logging.info("Waiting global update | Assign _waiting_global_update = True") - - await self.trainer.test() - await self.trainer.train() - - self_update_event = UpdateReceivedEvent( - self.trainer.get_model_parameters(), self.trainer.get_model_weight(), self.addr, self.round, local=True - ) - await EventManager.get_instance().publish_node_event(self_update_event) - - await self.cm.propagator.propagate("stable") - await self._waiting_model_updates() - - -class IdleNode(Engine): - """ - Idle node extending the Engine class responsible for passively participating in the federated learning network. - - This node: - - Does not perform any local model training. - - Waits to receive and potentially forward model updates. + await self._engine._waiting_model_updates() + + async def select_nodes_to_wait(self): + nodes = await self._engine.cm.get_addrs_current_connections(only_direct=True, myself=False) + return nodes - It does NOT train models or perform aggregation. - - Main functionalities: - - Passively waiting for model updates from other nodes. - - Handling communication related to received model updates. + async def resolve_missing_updates(self): + raise NotImplementedError + +""" ############################## + # PROXY BEHAVIOR # + ############################## +""" + +class ProxyRoleBehavior(RoleBehavior): + def __init__(self, engine: Engine, config: Config): + super().__init__() + self._engine = engine + self._config = config + self._role = factory_node_role("proxy") + + def get_role(self): + return self._role + + def get_role_name(self, effective=False): + return self._role.value + + async def extended_learning_cycle(self): + logging.info("Waiting global update | Assign _waiting_global_update = True") + await self._engine._waiting_model_updates() + + async def select_nodes_to_wait(self): + nodes = await self._engine.cm.get_addrs_current_connections(only_direct=True, myself=False) + return nodes - Typical use cases: - - Passive participants in federated learning. - - Nodes with no data or limited resources that cannot contribute training but need to stay in sync. - - Observers or relays within the federated network. - - Attributes: - Inherits all attributes and methods from `Engine` but alters behavior to exclude training and aggregation. - - Note: - Training and aggregation responsibilities are delegated to other nodes (e.g., TrainerNode, ServerNode). - """ - def __init__( - self, - model, - datamodule, - config=Config, - trainer=Lightning, - security=False, - ): - super().__init__( - model, - datamodule, - config, - trainer, - security, - ) + async def resolve_missing_updates(self): + raise NotImplementedError + +""" ############################## + # UTILS ROLE BEHAVIORS # + ############################## +""" + +class roleBehaviorException(Exception): + pass + +def factory_role_behavior(role: str, engine: Engine, config: Config) -> RoleBehavior | None: + + role_behaviors = { + "malicious": MaliciousRoleBehavior, + "trainer": TrainerRoleBehavior, + "aggregator": AggregatorRoleBehavior, + "server": ServerRoleBehavior, + "trainer_aggregator": TrainerAggregatorRoleBehavior, + "proxy": ProxyRoleBehavior, + "idle": IdleRoleBehavior, + } + + node_role = role_behaviors.get(role, None) - async def _extended_learning_cycle(self): - # Define the functionality of the idle node - logging.info("Waiting global update | Assign _waiting_global_update = True") - await self._waiting_model_updates() \ No newline at end of file + if node_role: + return node_role(engine, config) + else: + raise roleBehaviorException(f"Node Role Behavior {role} not found") + +def change_role_behavior(old_role: RoleBehavior, new_role: Role, *parameters) -> RoleBehavior: + engine, config = parameters + if not isinstance(old_role, MaliciousRoleBehavior): + return factory_role_behavior(new_role.value, engine, config) + else: + fake_behavior = factory_role_behavior(new_role.value, engine, config) + old_role._fake_role_behavior = fake_behavior + return old_role + + + + diff --git a/nebula/core/situationalawareness/discovery/federationconnector.py b/nebula/core/situationalawareness/discovery/federationconnector.py index 8beb6ba52..85eaba3e5 100644 --- a/nebula/core/situationalawareness/discovery/federationconnector.py +++ b/nebula/core/situationalawareness/discovery/federationconnector.py @@ -216,8 +216,8 @@ async def _waiting_confirmation_from(self, addr): """ async with self.pending_confirmation_from_nodes_lock: found = addr in self.pending_confirmation_from_nodes - logging.info(f"pending confirmations:{self.pending_confirmation_from_nodes}") - logging.info(f"Waiting confirmation from source: {addr}, status: {found}") + # logging.info(f"pending confirmations:{self.pending_confirmation_from_nodes}") + # logging.info(f"Waiting confirmation from source: {addr}, status: {found}") return found async def _confirmation_received(self, addr, confirmation=True, joining=False): @@ -500,14 +500,14 @@ async def _connection_late_connect_callback(self, source, message): ct_actions, df_actions = await self._get_actions() if len(ct_actions): - logging.info(f"{ct_actions}") + # logging.info(f"{ct_actions}") cnt_msg = self.cm.create_message("link", "connect_to", addrs=ct_actions) await self.cm.send_message(source, cnt_msg) if len(df_actions): - logging.info(f"{df_actions}") + # logging.info(f"{df_actions}") for addr in df_actions.split(): - await self.cm.disconnect(addr, mutual_disconnection=True) + await self.cm.disconnect(addr, mutual_disconnection=False) await self._register_late_neighbor(source, joinning_federation=True) @@ -539,7 +539,7 @@ async def _connection_restructure_callback(self, source, message): if len(df_actions): for addr in df_actions.split(): - await self.cm.disconnect(addr, mutual_disconnection=True) + await self.cm.disconnect(addr, mutual_disconnection=False) # df_msg = self.cm.create_message("link", "disconnect_from", addrs=df_actions) # await self.cm.send_message(source, df_msg) @@ -553,7 +553,7 @@ async def _discover_discover_join_callback(self, source, message): await self.engine.trainning_in_progress_lock.acquire_async() model, rounds, round = ( await self.cm.propagator.get_model_information(source, "stable") - if self.engine.get_round() > 0 + if await self.engine.get_round() > 0 else await self.cm.propagator.get_model_information(source, "initialization") ) await self.engine.trainning_in_progress_lock.release_async() @@ -631,12 +631,12 @@ async def _link_connect_to_callback(self, source, message): logging.info(f"πŸ”— handle_link_message | Trigger | Received connect_to message from {source}") addrs = message.addrs for addr in addrs.split(): - await self._meet_node(addr) + asyncio.create_task(self._meet_node(addr)) async def _link_disconnect_from_callback(self, source, message): logging.info(f"πŸ”— handle_link_message | Trigger | Received disconnect_from message from {source}") for addr in message.addrs.split(): - await self.cm.disconnect(addr, mutual_disconnection=True) + asyncio.create_task(self.cm.disconnect(addr, mutual_disconnection=False)) async def stop(self): """ diff --git a/nebula/frontend/static/js/deployment/topology.js b/nebula/frontend/static/js/deployment/topology.js index a08ff3167..28d5ec0ad 100644 --- a/nebula/frontend/static/js/deployment/topology.js +++ b/nebula/frontend/static/js/deployment/topology.js @@ -423,6 +423,7 @@ const TopologyManager = (function() { } else { switch (node.role) { case 'aggregator': + case 'trainer_aggregator': geometry = new THREE.SphereGeometry(5); main_color = "#d95f02"; break; @@ -629,7 +630,7 @@ const TopologyManager = (function() { case "DFL": // All as aggregators for (let i = 0; i < nodes.length; i++) { - nodes[i].role = "aggregator"; + nodes[i].role = "trainer_aggregator"; } break; }