class CredentialManager:
    """
    Generate, persist and look up the secrets NEBULA needs at deploy time.

    On construction the manager loads an existing ``.env`` file (if any) into
    ``os.environ``. ``check_credential`` then guarantees that a given key has a
    non-empty value: when it does not, a fresh secret is generated, exported to
    ``os.environ`` and appended to the ``.env`` file so later runs reuse it.

    Attributes:
        env_path (Path): Location of the environment file, built as
            ``Path.cwd() / env_dir / env_filename``.

    Typical usage example:
        manager = CredentialManager()
        manager.check_all_credentials()
    """

    # Length used for passwords generated by ``check_credential``.
    GENERATED_PASSWORD_LENGTH = 12

    # Quoting, escaping, grouping and comment characters are kept out of
    # generated passwords so the value can be written unquoted to the .env file.
    _EXCLUDED_PASSWORD_CHARS = "\"'\\`|(){}[]#"

    # Built once at class creation instead of on every password request.
    _PASSWORD_ALPHABET = (string.ascii_letters + string.digits + string.punctuation).translate(
        str.maketrans("", "", _EXCLUDED_PASSWORD_CHARS)
    )

    def __init__(self, env_dir="app", env_filename=".env"):
        """
        Resolve the environment file path and load it when it exists.

        Args:
            env_dir (str): Directory holding the environment file, relative to
                the current working directory. Defaults to 'app'.
            env_filename (str): Name of the environment file. Defaults to '.env'.
        """
        self.env_path = Path.cwd() / env_dir / env_filename
        # is_file() rather than exists(): a directory with that name must not
        # be handed to load_dotenv.
        if self.env_path.is_file():
            logging.info("Loading environment variables from %s", self.env_path)
            load_dotenv(self.env_path, override=True)

    def generate_secure_password(self, length=20):
        """
        Return a random password drawn with the ``secrets`` module.

        The alphabet is ASCII letters, digits and punctuation minus the
        characters listed in ``_EXCLUDED_PASSWORD_CHARS``.

        Args:
            length (int): Number of characters to generate. Defaults to 20.

        Returns:
            str: The generated password.

        Raises:
            ValueError: If ``length`` is smaller than 1 (an empty password
                would silently disable authentication).
        """
        if length < 1:
            raise ValueError("Password length must be at least 1")
        return "".join(secrets.choice(self._PASSWORD_ALPHABET) for _ in range(length))

    def check_credential(self, key, is_password=True):
        """
        Ensure ``key`` has a non-empty value in the environment.

        When the variable is missing or empty, a value is generated, exported
        to ``os.environ`` and appended to the environment file. Otherwise
        nothing is changed.

        Args:
            key (str): Environment variable name.
            is_password (bool): If True generate a password, otherwise a
                48-character hex token. Defaults to True.
        """
        # An empty string is treated like a missing value: an empty password is
        # never a usable credential.
        if os.environ.get(key):
            logging.info("%s already set", key)
            return

        logging.info("Generating value for %s", key)
        if is_password:
            value = self.generate_secure_password(self.GENERATED_PASSWORD_LENGTH)
        else:
            value = secrets.token_hex(24)
        os.environ[key] = value

        logging.info("Saving %s to %s", key, self.env_path)
        self._append_to_env_file(key, value)

    def _append_to_env_file(self, key, value):
        """Append ``key=value`` to the environment file on a line of its own."""
        self.env_path.parent.mkdir(parents=True, exist_ok=True)

        # If the file does not end with a newline, a plain append would glue
        # the new entry onto the last existing line and corrupt both values.
        needs_newline = False
        if self.env_path.is_file() and self.env_path.stat().st_size > 0:
            with self.env_path.open("rb") as existing:
                existing.seek(-1, os.SEEK_END)
                needs_newline = existing.read(1) != b"\n"

        # The 0o600 mode only applies when the file is created here; the
        # permissions of a pre-existing file are left untouched.
        fd = os.open(self.env_path, os.O_WRONLY | os.O_APPEND | os.O_CREAT, 0o600)
        with os.fdopen(fd, "a", encoding="utf-8") as env_file:
            prefix = "\n" if needs_newline else ""
            env_file.write(f"{prefix}{key}={value}\n")

    def check_all_credentials(self):
        """
        Ensure every credential the deployment needs is available.

        Covers:
            - SECRET_KEY (hex token)
            - GF_SECURITY_ADMIN_PASSWORD
            - POSTGRES_PASSWORD
            - NEBULA_ADMIN_PASSWORD
        """
        self.check_credential("SECRET_KEY", is_password=False)
        for key in ("GF_SECURITY_ADMIN_PASSWORD", "POSTGRES_PASSWORD", "NEBULA_ADMIN_PASSWORD"):
            self.check_credential(key)
def run_database(self):
    """
    Start the NEBULA PostgreSQL container and its PGWeb companion.

    Both containers are attached to the "net-base" Docker network with fixed
    addresses (``<base>.125`` for PostgreSQL, ``<base>.135`` for PGWeb) and are
    registered in the deployer metadata.

    Raises:
        RuntimeError: If the Docker endpoint is not available or
            POSTGRES_PASSWORD is missing/empty in the environment.
        FileNotFoundError: If the database init script is not present on the host.
    """
    if sys.platform == "win32":
        if not os.path.exists("//./pipe/docker_Engine"):
            raise RuntimeError(
                "Docker is not running, please check if Docker is running and Docker Compose is installed."
            )
    elif not os.path.exists("/var/run/docker.sock"):
        raise RuntimeError(
            "/var/run/docker.sock not found, please check if Docker is running and Docker Compose is installed."
        )

    # Fail before touching Docker: a missing password would otherwise only
    # surface later as a database container that cannot be used.
    postgres_password = os.environ.get("POSTGRES_PASSWORD")
    if not postgres_password:
        raise RuntimeError("POSTGRES_PASSWORD is not set; cannot start the nebula-database container.")

    # Docker creates a *directory* for a bind source that does not exist, which
    # would silently skip the schema initialisation.
    host_sql_path = os.path.join(self.root_path, "nebula/database/init-configs.sql")
    if not os.path.isfile(host_sql_path):
        raise FileNotFoundError(f"Database init script not found: {host_sql_path}")

    db_data_path = os.path.join(self.databases_dir, "postgres-data")
    os.makedirs(db_data_path, exist_ok=True)

    network_name = self.get_network_name("net-base")
    base = DockerUtils.create_docker_network(network_name)
    Deployer._add_network_to_metadata(network_name)

    client = docker.from_env()

    def launch(image, role_tag, ip_suffix, container_port, host_port, binds=None, environment=None):
        """Create, start and register one container on the base network."""
        container_name = self.get_container_name(role_tag)
        host_config = client.api.create_host_config(
            binds=binds,
            port_bindings={container_port: host_port},
        )
        networking_config = client.api.create_networking_config(
            {network_name: client.api.create_endpoint_config(ipv4_address=f"{base}.{ip_suffix}")}
        )
        container = client.api.create_container(
            image=image,
            name=container_name,
            detach=True,
            # docker-py's low-level API documents port publishing as two steps:
            # declare the container port here, bind it in the host config.
            ports=[container_port],
            environment=environment,
            host_config=host_config,
            networking_config=networking_config,
        )
        client.api.start(container)
        Deployer._add_container_to_metadata(container_name)

    # --- PostgreSQL ---
    launch(
        image="nebula-database",
        role_tag="nebula-database",
        ip_suffix=125,
        container_port=5432,
        host_port=5432,
        binds=[
            f"{host_sql_path}:/docker-entrypoint-initdb.d/init-configs.sql",
            f"{db_data_path}:/var/lib/postgresql/data",
        ],
        environment={
            "POSTGRES_USER": "nebula",
            "POSTGRES_PASSWORD": postgres_password,
            "POSTGRES_DB": "nebula",
        },
    )

    # --- PGWeb --- (container port 8081 published on host port 8085)
    launch(
        image="nebula-pgweb",
        role_tag="nebula-pgweb",
        ip_suffix=135,
        container_port=8081,
        host_port=8085,
    )
def _setup_connection_metrics(self):
    """
    Create a fresh ``Metrics`` entry in ``connection_metrics`` for every neighbor.

    The neighbor list is read from
    ``self._config.participant["network_args"]["neighbors"]``. Both a sequence
    of identifiers and a whitespace-separated string are accepted; a missing
    or empty value registers nothing.
    """
    neighbors = self._config.participant["network_args"]["neighbors"]
    if not neighbors:
        return
    # Iterating a plain string would register one entry per character, so a
    # space-separated string is split into identifiers first.
    if isinstance(neighbors, str):
        neighbors = neighbors.split()
    for neighbor in neighbors:
        self.connection_metrics[neighbor] = Metrics()
"fraction_parameters_changed"] - + if self._weighting_factor == "static": self._weight_model_arrival_latency = float( self._metrics.get("model_arrival_latency", {}).get("weight", default_weight) @@ -209,7 +209,7 @@ def _configure_metric_weights(self): elif not isinstance(self._metrics[metric_name], dict): self._metrics[metric_name] = {"enabled": bool(self._metrics[metric_name])} self._metrics[metric_name]["weight"] = default_weight - + self._weight_model_arrival_latency = default_weight self._weight_model_similarity = default_weight self._weight_num_messages = default_weight @@ -229,24 +229,24 @@ def engine(self): def _is_metric_enabled(self, metric_name: str, metrics_config: dict = None) -> bool: """ Check if a specific metric is enabled based on the provided configuration. - + Args: metric_name (str): The name of the metric to check. - metrics_config (dict, optional): The configuration dictionary for metrics. + metrics_config (dict, optional): The configuration dictionary for metrics. If None, uses the instance's _metrics. - + Returns: bool: True if the metric is enabled, False otherwise. """ config_to_use = metrics_config if metrics_config is not None else getattr(self, '_metrics', None) - + if not isinstance(config_to_use, dict): if metrics_config is not None: logging.warning(f"metrics_config is not a dictionary: {type(metrics_config)}") else: logging.warning("_metrics is not properly initialized") return False - + metric_config = config_to_use.get(metric_name) if metric_config is None: return False @@ -269,7 +269,7 @@ def save_data( ): """ Save data between nodes and aggregated models. 
- + Args: type_data: Type of data to save ('number_message', 'fraction_of_params_changed', 'model_arrival_latency') nei: Neighbor identifier @@ -290,7 +290,7 @@ def save_data( try: metrics_instance = self.connection_metrics[nei] - + if type_data == "number_message": message_data = {"time": time, "current_round": current_round} if not isinstance(metrics_instance.messages, list): @@ -345,19 +345,19 @@ async def init_reputation( ): """ Initialize the reputation system. - + Args: federation_nodes: List of federation node identifiers - round_num: Current round number + round_num: Current round number last_feedback_round: Last round that received feedback init_reputation: Initial reputation value to assign """ if not self._enabled: return - + if not self._validate_init_parameters(federation_nodes, round_num, init_reputation): return - + neighbors = self._validate_federation_nodes(federation_nodes) if not neighbors: logging.error("init_reputation | No valid neighbors found") @@ -370,13 +370,13 @@ def _validate_init_parameters(self, federation_nodes, round_num, init_reputation if not federation_nodes: logging.error("init_reputation | No federation nodes provided") return False - + if round_num is None: logging.warning("init_reputation | Round number not provided") - + if init_reputation is None: logging.warning("init_reputation | Initial reputation value not provided") - + return True async def _initialize_neighbor_reputations(self, neighbors: list, round_num: int, last_feedback_round: int, init_reputation: float): @@ -392,7 +392,7 @@ def _create_or_update_reputation_entry(self, nei: str, round_num: int, last_feed "round": round_num, "last_feedback_round": last_feedback_round, } - + if nei not in self.reputation: self.reputation[nei] = reputation_data elif self.reputation[nei].get("reputation") is None: @@ -401,21 +401,21 @@ def _create_or_update_reputation_entry(self, nei: str, round_num: int, last_feed def _validate_federation_nodes(self, federation_nodes) -> list: """ 
Validate and filter federation nodes. - + Args: federation_nodes: List of federation node identifiers - + Returns: list: List of valid node identifiers """ if not federation_nodes: return [] - + valid_nodes = [node for node in federation_nodes if node and str(node).strip()] - + if not valid_nodes: logging.warning("No valid federation nodes found after filtering") - + return valid_nodes async def _calculate_static_reputation( @@ -429,7 +429,7 @@ async def _calculate_static_reputation( Args: addr: The participant's address - nei: The neighbor's address + nei: The neighbor's address metric_values: Dictionary with metric values """ static_weights = { @@ -440,10 +440,10 @@ async def _calculate_static_reputation( } reputation_static = sum( - metric_values.get(metric_name, 0) * static_weights[metric_name] + metric_values.get(metric_name, 0) * static_weights[metric_name] for metric_name in static_weights ) - + logging.info(f"Static reputation for node {nei} at round {await self.engine.get_round()}: {reputation_static}") avg_reputation = await self.save_reputation_history_in_memory(self.engine.addr, nei, reputation_static) @@ -476,48 +476,48 @@ async def _calculate_dynamic_reputation(self, addr, neighbors): async def _calculate_average_weights(self): """Calculate average weights for all enabled metrics.""" average_weights = {} - + for metric_name in self.history_data.keys(): if self._is_metric_enabled(metric_name): average_weights[metric_name] = await self._get_metric_average_weight(metric_name) - + return average_weights - + async def _get_metric_average_weight(self, metric_name): """Get the average weight for a specific metric.""" if metric_name not in self.history_data or not self.history_data[metric_name]: logging.debug(f"No history data available for metric: {metric_name}") return 0 - + valid_entries = [ entry for entry in self.history_data[metric_name] - if (entry.get("round") is not None and - entry["round"] >= await self._engine.get_round() and + if 
(entry.get("round") is not None and + entry["round"] >= await self._engine.get_round() and entry.get("weight") not in [None, -1]) ] - + if not valid_entries: return 0 - + try: weights = [entry["weight"] for entry in valid_entries if entry.get("weight") is not None] return sum(weights) / len(weights) if weights else 0 except (TypeError, ZeroDivisionError) as e: logging.warning(f"Error calculating average weight for {metric_name}: {e}") return 0 - + async def _process_neighbors_reputation(self, addr, neighbors, average_weights): """Process reputation calculation for all neighbors.""" for nei in neighbors: metric_values = await self._get_neighbor_metric_values(nei) - + if all(metric_name in metric_values for metric_name in average_weights): await self._update_neighbor_reputation(addr, nei, metric_values, average_weights) - + async def _get_neighbor_metric_values(self, nei): """Get metric values for a specific neighbor in the current round.""" metric_values = {} - + for metric_name in self.history_data: if self._is_metric_enabled(metric_name): for entry in self.history_data.get(metric_name, []): @@ -526,16 +526,16 @@ async def _get_neighbor_metric_values(self, nei): entry.get("nei") == nei): metric_values[metric_name] = entry.get("metric_value", 0) break - + return metric_values - + async def _update_neighbor_reputation(self, addr, nei, metric_values, average_weights): """Update reputation for a specific neighbor.""" reputation_with_weights = sum( - metric_values.get(metric_name, 0) * average_weights[metric_name] + metric_values.get(metric_name, 0) * average_weights[metric_name] for metric_name in average_weights ) - + logging.info( f"Dynamic reputation with weights for {nei} at round {await self._engine.get_round()}: {reputation_with_weights}" ) @@ -564,7 +564,7 @@ async def _update_reputation_record(self, nei: str, reputation: float, data: dic data: Additional data to update (currently unused) """ current_round = await self._engine.get_round() - + if nei not in 
self.reputation: self.reputation[nei] = { "reputation": reputation, @@ -576,7 +576,7 @@ async def _update_reputation_record(self, nei: str, reputation: float, data: dic self.reputation[nei]["round"] = current_round logging.info(f"Reputation of node {nei}: {self.reputation[nei]['reputation']}") - + if self.reputation[nei]["reputation"] < self.REPUTATION_THRESHOLD and current_round > 0: self.rejected_nodes.add(nei) logging.info(f"Rejected node {nei} at round {current_round}") @@ -608,23 +608,23 @@ def calculate_weighted_values( reputation_metrics ) self._add_current_metrics_to_history(active_metrics, history_data, current_round, addr, nei) - + if current_round >= self.INITIAL_ROUND_FOR_REPUTATION and len(active_metrics) > 0: adjusted_weights = self._calculate_dynamic_weights(active_metrics, history_data) else: adjusted_weights = self._calculate_uniform_weights(active_metrics) - + self._update_history_with_weights(active_metrics, history_data, adjusted_weights, current_round, nei) def _ensure_history_data_structure(self, history_data: dict): """Ensure all required keys exist in history data structure.""" required_keys = [ "num_messages", - "model_similarity", + "model_similarity", "fraction_parameters_changed", "model_arrival_latency", ] - + for key in required_keys: if key not in history_data: history_data[key] = [] @@ -644,7 +644,7 @@ def _get_active_metrics( "fraction_parameters_changed": fraction_score_asign, "model_arrival_latency": avg_model_arrival_latency, } - + return {k: v for k, v in all_metrics.items() if self._is_metric_enabled(k, reputation_metrics)} def _add_current_metrics_to_history(self, active_metrics: dict, history_data: dict, current_round: int, addr: str, nei: str): @@ -662,7 +662,7 @@ def _add_current_metrics_to_history(self, active_metrics: dict, history_data: di def _calculate_dynamic_weights(self, active_metrics: dict, history_data: dict) -> dict: """Calculate dynamic weights based on metric deviations.""" deviations = 
self._calculate_metric_deviations(active_metrics, history_data) - + if all(deviation == 0.0 for deviation in deviations.values()): return self._generate_random_weights(active_metrics) else: @@ -672,7 +672,7 @@ def _calculate_dynamic_weights(self, active_metrics: dict, history_data: dict) - def _calculate_metric_deviations(self, active_metrics: dict, history_data: dict) -> dict: """Calculate deviations of current metrics from historical means.""" deviations = {} - + for metric_name, current_value in active_metrics.items(): historical_values = history_data[metric_name] metric_values = [ @@ -680,11 +680,11 @@ def _calculate_metric_deviations(self, active_metrics: dict, history_data: dict) for entry in historical_values if "metric_value" in entry and entry["metric_value"] != 0 ] - + mean_value = np.mean(metric_values) if metric_values else 0 deviation = abs(current_value - mean_value) deviations[metric_name] = deviation - + return deviations def _generate_random_weights(self, active_metrics: dict) -> dict: @@ -692,7 +692,7 @@ def _generate_random_weights(self, active_metrics: dict) -> dict: num_metrics = len(active_metrics) random_weights = [random.random() for _ in range(num_metrics)] total_random_weight = sum(random_weights) - + return { metric_name: weight / total_random_weight for metric_name, weight in zip(active_metrics, random_weights, strict=False) @@ -702,14 +702,14 @@ def _normalize_deviation_weights(self, deviations: dict) -> dict: """Normalize weights based on deviations.""" max_deviation = max(deviations.values()) if deviations else 1 normalized_weights = { - metric_name: (deviation / max_deviation) + metric_name: (deviation / max_deviation) for metric_name, deviation in deviations.items() } - + total_weight = sum(normalized_weights.values()) if total_weight > 0: return { - metric_name: weight / total_weight + metric_name: weight / total_weight for metric_name, weight in normalized_weights.items() } else: @@ -720,20 +720,20 @@ def 
_adjust_weights_with_minimum(self, normalized_weights: dict, deviations: dic """Apply minimum weight constraints and renormalize.""" mean_deviation = np.mean(list(deviations.values())) dynamic_min_weight = max(self.DYNAMIC_MIN_WEIGHT_THRESHOLD, mean_deviation / (mean_deviation + 1)) - + adjusted_weights = {} total_adjusted_weight = 0 - + for metric_name, weight in normalized_weights.items(): adjusted_weight = max(weight, dynamic_min_weight) adjusted_weights[metric_name] = adjusted_weight total_adjusted_weight += adjusted_weight - + # Renormalize if total weight exceeds 1 if total_adjusted_weight > 1: for metric_name in adjusted_weights: adjusted_weights[metric_name] /= total_adjusted_weight - + return adjusted_weights def _calculate_uniform_weights(self, active_metrics: dict) -> dict: @@ -748,8 +748,8 @@ def _update_history_with_weights(self, active_metrics: dict, history_data: dict, for metric_name in active_metrics: weight = weights.get(metric_name, -1) for entry in history_data[metric_name]: - if (entry["metric_name"] == metric_name and - entry["round"] == current_round and + if (entry["metric_name"] == metric_name and + entry["round"] == current_round and entry["nei"] == nei): entry["weight"] = weight @@ -765,7 +765,7 @@ async def calculate_value_metrics(self, addr, nei, metrics_active=None): try: current_round = await self._engine.get_round() metrics_instance = self.connection_metrics.get(nei) - + if not metrics_instance: logging.warning(f"No metrics found for neighbor {nei}") return self._get_default_metric_values() @@ -778,7 +778,7 @@ async def calculate_value_metrics(self, addr, nei, metrics_active=None): } self._log_metrics_graphics(metric_results, addr, nei, current_round) - + return ( metric_results["messages"]["avg"], metric_results["similarity"], @@ -802,7 +802,7 @@ def _process_num_messages_metric(self, metrics_instance, addr: str, nei: str, cu filtered_messages = [ msg for msg in metrics_instance.messages if msg.get("current_round") == current_round 
] - + for msg in filtered_messages: self.messages_number_message.append({ "number_message": msg.get("time"), @@ -813,9 +813,9 @@ def _process_num_messages_metric(self, metrics_instance, addr: str, nei: str, cu normalized, count = self.manage_metric_number_message( self.messages_number_message, addr, nei, current_round, True ) - + avg = self.save_number_message_history(addr, nei, normalized, current_round) - + if avg is None and current_round > self.HISTORY_ROUNDS_LOOKBACK: avg = self.number_message_history[(addr, nei)][current_round - 1]["avg_number_message"] @@ -901,7 +901,7 @@ def _process_model_arrival_latency_metric(self, metrics_instance, addr: str, nei if avg_latency is None and current_round > 1: avg_latency = self.model_arrival_latency_history[(addr, nei)][current_round - 1]["score"] return avg_latency or 0 - + return 0 def _process_model_similarity_metric(self, nei: str, current_round: int, metrics_active) -> float: @@ -938,7 +938,7 @@ def create_graphics_to_metrics( ): """ Create and log graphics for reputation metrics. 
- + Args: number_message_count: Count of messages for logging number_message_norm: Normalized message metric @@ -952,25 +952,25 @@ def create_graphics_to_metrics( """ if current_round is None or current_round >= total_rounds: return - + self.engine.trainer._logger.log_data( - {f"R-Model_arrival_latency_reputation/{addr}": {nei: model_arrival_latency}}, + {f"R-Model_arrival_latency_reputation/{addr}": {nei: model_arrival_latency}}, step=current_round ) self.engine.trainer._logger.log_data( - {f"R-Count_messages_number_message_reputation/{addr}": {nei: number_message_count}}, + {f"R-Count_messages_number_message_reputation/{addr}": {nei: number_message_count}}, step=current_round ) self.engine.trainer._logger.log_data( - {f"R-number_message_reputation/{addr}": {nei: number_message_norm}}, + {f"R-number_message_reputation/{addr}": {nei: number_message_norm}}, step=current_round ) self.engine.trainer._logger.log_data( - {f"R-Similarity_reputation/{addr}": {nei: similarity}}, + {f"R-Similarity_reputation/{addr}": {nei: similarity}}, step=current_round ) self.engine.trainer._logger.log_data( - {f"R-Fraction_reputation/{addr}": {nei: fraction}}, + {f"R-Fraction_reputation/{addr}": {nei: fraction}}, step=current_round ) @@ -991,7 +991,7 @@ def analyze_anomalies( try: key = (addr, nei, current_round) self._initialize_fraction_history_entry(key, fraction_changed, threshold) - + if current_round == 0: return self._handle_initial_round_anomalies(key, fraction_changed, threshold) else: @@ -1032,16 +1032,16 @@ def _handle_subsequent_round_anomalies( ) -> float: """Handle anomaly analysis for subsequent rounds.""" prev_stats = self._find_previous_valid_stats(addr, nei, current_round) - + if prev_stats is None: logging.warning(f"No valid previous stats found for {addr}, {nei}, round {current_round}") return 1.0 - + anomalies = self._detect_anomalies(fraction_changed, threshold, prev_stats) values = self._calculate_anomaly_values(fraction_changed, threshold, prev_stats, anomalies) 
fraction_score = self._calculate_combined_score(values) self._update_fraction_statistics(key, fraction_changed, threshold, prev_stats, anomalies, fraction_score) - + return max(fraction_score, 0) def _find_previous_valid_stats(self, addr: str, nei: str, current_round: int) -> dict: @@ -1049,18 +1049,18 @@ def _find_previous_valid_stats(self, addr: str, nei: str, current_round: int) -> for i in range(1, current_round + 1): candidate_key = (addr, nei, current_round - i) candidate_data = self.fraction_changed_history.get(candidate_key, {}) - + required_keys = ["mean_fraction", "std_dev_fraction", "mean_threshold", "std_dev_threshold"] if all(candidate_data.get(k) is not None for k in required_keys): return candidate_data - + return None def _detect_anomalies(self, current_fraction: float, current_threshold: float, prev_stats: dict) -> dict: """Detect if current values are anomalous compared to previous statistics.""" upper_mean_fraction = (prev_stats["mean_fraction"] + prev_stats["std_dev_fraction"]) * self.FRACTION_ANOMALY_MULTIPLIER upper_mean_threshold = (prev_stats["mean_threshold"] + prev_stats["std_dev_threshold"]) * self.THRESHOLD_ANOMALY_MULTIPLIER - + return { "fraction_anomaly": current_fraction > upper_mean_fraction, "threshold_anomaly": current_threshold > upper_mean_threshold, @@ -1074,19 +1074,19 @@ def _calculate_anomaly_values( """Calculate penalty values for fraction and threshold anomalies.""" fraction_value = 1.0 threshold_value = 1.0 - + if anomalies["fraction_anomaly"]: mean_fraction_prev = prev_stats["mean_fraction"] if mean_fraction_prev > 0: penalization_factor = abs(current_fraction - mean_fraction_prev) / mean_fraction_prev fraction_value = 1 - (1 / (1 + np.exp(-penalization_factor))) - + if anomalies["threshold_anomaly"]: mean_threshold_prev = prev_stats["mean_threshold"] if mean_threshold_prev > 0: penalization_factor = abs(current_threshold - mean_threshold_prev) / mean_threshold_prev threshold_value = 1 - (1 / (1 + 
np.exp(-penalization_factor))) - + return { "fraction_value": fraction_value, "threshold_value": threshold_value, @@ -1099,19 +1099,19 @@ def _calculate_combined_score(self, values: dict) -> float: return fraction_weight * values["fraction_value"] + threshold_weight * values["threshold_value"] def _update_fraction_statistics( - self, key: tuple, current_fraction: float, current_threshold: float, + self, key: tuple, current_fraction: float, current_threshold: float, prev_stats: dict, anomalies: dict, fraction_score: float ): """Update the fraction statistics for the current round.""" self.fraction_changed_history[key]["fraction_anomaly"] = anomalies["fraction_anomaly"] self.fraction_changed_history[key]["threshold_anomaly"] = anomalies["threshold_anomaly"] - + self.fraction_changed_history[key]["mean_fraction"] = (current_fraction + prev_stats["mean_fraction"]) / 2 self.fraction_changed_history[key]["mean_threshold"] = (current_threshold + prev_stats["mean_threshold"]) / 2 - + fraction_variance = ((current_fraction - prev_stats["mean_fraction"]) ** 2 + prev_stats["std_dev_fraction"] ** 2) / 2 threshold_variance = ((self.THRESHOLD_VARIANCE_MULTIPLIER * (current_threshold - prev_stats["mean_threshold"]) ** 2) + prev_stats["std_dev_threshold"] ** 2) / 2 - + self.fraction_changed_history[key]["std_dev_fraction"] = np.sqrt(fraction_variance) self.fraction_changed_history[key]["std_dev_threshold"] = np.sqrt(threshold_variance) self.fraction_changed_history[key]["fraction_score"] = fraction_score @@ -1132,9 +1132,9 @@ def manage_model_arrival_latency(self, addr, nei, latency, current_round, round_ """ try: current_key = nei - + self._initialize_latency_round_entry(current_round, current_key, latency) - + if current_round >= 1: score = self._calculate_latency_score(current_round, current_key, latency) self._update_latency_entry_with_score(current_round, current_key, score) @@ -1161,17 +1161,17 @@ def _calculate_latency_score(self, current_round: int, current_key: str, 
latency """Calculate the latency score based on historical data.""" target_round = self._get_target_round_for_latency(current_round) all_latencies = self._get_all_latencies_for_round(target_round) - + if not all_latencies: return 0.0 - + mean_latency = np.mean(all_latencies) augment_mean = mean_latency * self.LATENCY_AUGMENT_FACTOR - + if latency is None: logging.info(f"latency is None in round {current_round} for nei {current_key}") return -0.5 - + if latency <= augment_mean: return 1.0 else: @@ -1195,7 +1195,7 @@ def _update_latency_entry_with_score(self, current_round: int, current_key: str, target_round = self._get_target_round_for_latency(current_round) all_latencies = self._get_all_latencies_for_round(target_round) mean_latency = np.mean(all_latencies) if all_latencies else 0 - + self.model_arrival_latency_history[current_round][current_key].update({ "mean_latency": mean_latency, "score": score, @@ -1215,9 +1215,9 @@ def save_model_arrival_latency_history(self, nei, model_arrival_latency, round_n """ try: current_key = nei - + self._initialize_latency_history_entry(round_num, current_key, model_arrival_latency) - + if model_arrival_latency > 0 and round_num >= 1: avg_model_arrival_latency = self._calculate_latency_weighted_average_positive( round_num, current_key, model_arrival_latency @@ -1236,7 +1236,7 @@ def save_model_arrival_latency_history(self, nei, model_arrival_latency, round_n ) return avg_model_arrival_latency - + except Exception: logging.exception("Error saving model_arrival_latency history") @@ -1284,14 +1284,14 @@ def manage_metric_number_message( ) -> tuple[float, int]: """ Manage the number of messages metric for a specific neighbor. 
- + Args: messages_number_message: List of message data addr: Source address nei: Neighbor address current_round: Current round number metric_active: Whether the metric is active - + Returns: Tuple of (normalized_messages, messages_count) """ @@ -1301,13 +1301,13 @@ def manage_metric_number_message( messages_count = self._count_relevant_messages(messages_number_message, addr, nei, current_round) neighbor_stats = self._calculate_neighbor_statistics(messages_number_message, current_round) - + normalized_messages = self._calculate_normalized_messages(messages_count, neighbor_stats) - + normalized_messages = self._apply_historical_penalty( normalized_messages, addr, nei, current_round ) - + self._store_message_history(addr, nei, current_round, normalized_messages) normalized_messages = max(0.001, normalized_messages) @@ -1339,7 +1339,7 @@ def _calculate_neighbor_statistics(self, messages: list, current_round: int) -> neighbor_counts[key] = neighbor_counts.get(key, 0) + 1 counts_all_neighbors = list(neighbor_counts.values()) - + if not counts_all_neighbors: return { "percentile_reference": 0, @@ -1349,7 +1349,7 @@ def _calculate_neighbor_statistics(self, messages: list, current_round: int) -> } mean_messages = np.mean(counts_all_neighbors) - + return { "percentile_reference": np.percentile(counts_all_neighbors, 25), "std_dev": np.std(counts_all_neighbors), @@ -1361,10 +1361,10 @@ def _calculate_normalized_messages(self, messages_count: int, neighbor_stats: di """Calculate normalized message score with relative and extra penalties.""" normalized_messages = 1.0 penalties_applied = [] - + relative_increase = self._calculate_relative_increase(messages_count, neighbor_stats["percentile_reference"]) dynamic_margin = self._calculate_dynamic_margin(neighbor_stats) - + if relative_increase > dynamic_margin: penalty_ratio = self._calculate_penalty_ratio(relative_increase, dynamic_margin) normalized_messages *= np.exp(-(penalty_ratio**2)) @@ -1400,7 +1400,7 @@ def 
_calculate_penalty_ratio(self, relative_increase: float, dynamic_margin: flo def _should_apply_extra_penalty(self, messages_count: int, neighbor_stats: dict) -> bool: """Determine if extra penalty should be applied.""" - return (neighbor_stats["mean_messages"] > 0 and + return (neighbor_stats["mean_messages"] > 0 and messages_count > neighbor_stats["augment_mean"]) def _calculate_extra_penalty_factor(self, messages_count: int, neighbor_stats: dict) -> float: @@ -1408,7 +1408,7 @@ def _calculate_extra_penalty_factor(self, messages_count: int, neighbor_stats: d epsilon = 1e-6 mean_messages = neighbor_stats["mean_messages"] augment_mean = neighbor_stats["augment_mean"] - + extra_penalty = (messages_count - mean_messages) / (mean_messages + epsilon) amplification = 1 + (augment_mean / (mean_messages + epsilon)) return extra_penalty * amplification @@ -1417,27 +1417,27 @@ def _apply_historical_penalty(self, normalized_messages: float, addr: str, nei: """Apply historical penalty based on previous round's score.""" if current_round <= 1: return normalized_messages - + prev_data = ( self.number_message_history.get((addr, nei), {}) .get(current_round - 1, {}) ) - + prev_score = prev_data.get("normalized_messages") was_previously_penalized = prev_data.get("was_penalized", False) - + if prev_score is not None and prev_score < self.HISTORICAL_PENALTY_THRESHOLD: original_score = normalized_messages - + if was_previously_penalized: penalty_factor = self.HISTORICAL_PENALTY_THRESHOLD * 0.8 logging.debug(f"Repeated penalty applied to {nei}: stricter historical penalty") else: penalty_factor = self.HISTORICAL_PENALTY_THRESHOLD - + normalized_messages *= penalty_factor logging.debug(f"Historical penalty applied to {nei}: {original_score:.4f} -> {normalized_messages:.4f} (prev_score: {prev_score:.4f}, was_penalized: {was_previously_penalized})") - + return normalized_messages def _store_message_history(self, addr: str, nei: str, current_round: int, normalized_messages: float): @@ 
-1445,9 +1445,9 @@ def _store_message_history(self, addr: str, nei: str, current_round: int, normal key = (addr, nei) if key not in self.number_message_history: self.number_message_history[key] = {} - + was_penalized = normalized_messages < 1.0 - + self.number_message_history[key][current_round] = { "normalized_messages": normalized_messages, "was_penalized": was_penalized, @@ -1464,9 +1464,9 @@ def save_number_message_history(self, addr, nei, messages_number_message_normali """ try: key = (addr, nei) - + self._initialize_message_history_entry(key, current_round, messages_number_message_normalized) - + if messages_number_message_normalized > 0 and current_round >= 1: avg_number_message = self._calculate_weighted_average_positive(key, current_round, messages_number_message_normalized) elif messages_number_message_normalized == 0 and current_round >= 1: @@ -1478,7 +1478,7 @@ def save_number_message_history(self, addr, nei, messages_number_message_normali self.number_message_history[key][current_round]["avg_number_message"] = avg_number_message return avg_number_message - + except Exception: logging.exception("Error saving number_message history") return -1 @@ -1524,7 +1524,7 @@ async def save_reputation_history_in_memory(self, addr: str, nei: str, reputatio Args: addr: The node's identifier - nei: The neighboring node identifier + nei: The neighboring node identifier reputation: The reputation value to save Returns: @@ -1533,27 +1533,27 @@ async def save_reputation_history_in_memory(self, addr: str, nei: str, reputatio try: key = (addr, nei) current_round = await self._engine.get_round() - + if key not in self.reputation_history: self.reputation_history[key] = {} self.reputation_history[key][current_round] = reputation rounds = sorted(self.reputation_history[key].keys(), reverse=True)[:2] - + if len(rounds) >= 2: current_rep = self.reputation_history[key][rounds[0]] previous_rep = self.reputation_history[key][rounds[1]] - + current_weight = 
self.REPUTATION_CURRENT_WEIGHT previous_weight = self.REPUTATION_FEEDBACK_WEIGHT avg_reputation = (current_rep * current_weight) + (previous_rep * previous_weight) - + logging.info(f"Current reputation: {current_rep}, Previous reputation: {previous_rep}") logging.info(f"Reputation ponderated: {avg_reputation}") else: avg_reputation = reputation - + return avg_reputation except Exception: @@ -1577,23 +1577,23 @@ def calculate_similarity_from_metrics(self, nei: str, current_round: int) -> flo return 0.0 relevant_metrics = [ - metric for metric in metrics_instance.similarity + metric for metric in metrics_instance.similarity if metric.get("nei") == nei and metric.get("current_round") == current_round ] - + if not relevant_metrics: relevant_metrics = [ - metric for metric in metrics_instance.similarity + metric for metric in metrics_instance.similarity if metric.get("nei") == nei ] - + if not relevant_metrics: return 0.0 neighbor_metric = relevant_metrics[-1] similarity_weights = { "cosine": 0.25, - "euclidean": 0.25, + "euclidean": 0.25, "manhattan": 0.25, "pearson_correlation": 0.25, } @@ -1604,7 +1604,7 @@ def calculate_similarity_from_metrics(self, nei: str, current_round: int) -> flo ) return max(0.0, min(1.0, similarity_value)) - + except Exception: return 0.0 @@ -1620,9 +1620,9 @@ async def calculate_reputation(self, ae: AggregationEvent): (updates, _, _) = await ae.get_event_data() await self._log_reputation_calculation_start() - + neighbors = set(await self._engine._cm.get_addrs_current_connections(only_direct=True)) - + await self._process_neighbor_metrics(neighbors) await self._calculate_reputation_by_factor(neighbors) await self._handle_initial_reputation() @@ -1644,7 +1644,7 @@ async def _process_neighbor_metrics(self, neighbors): metrics = await self.calculate_value_metrics( self._addr, nei, metrics_active=self._metrics ) - + if self._weighting_factor == "dynamic": await self._process_dynamic_metrics(nei, metrics) elif self._weighting_factor == "static" 
and await self._engine.get_round() >= 1: @@ -1653,7 +1653,7 @@ async def _process_neighbor_metrics(self, neighbors): async def _process_dynamic_metrics(self, nei, metrics): """Process metrics for dynamic weighting factor.""" (metric_messages_number, metric_similarity, metric_fraction, metric_model_arrival_latency) = metrics - + self.calculate_weighted_values( metric_messages_number, metric_similarity, @@ -1669,7 +1669,7 @@ async def _process_dynamic_metrics(self, nei, metrics): async def _process_static_metrics(self, nei, metrics): """Process metrics for static weighting factor.""" (metric_messages_number, metric_similarity, metric_fraction, metric_model_arrival_latency) = metrics - + metric_values_dict = { "num_messages": metric_messages_number, "model_similarity": metric_similarity, @@ -1686,7 +1686,7 @@ async def _calculate_reputation_by_factor(self, neighbors): async def _handle_initial_reputation(self): """Handle reputation initialization for the first round.""" if await self._engine.get_round() < 1 and self._enabled: - federation = self._engine.config.participant["network_args"]["neighbors"].split() + federation = self._engine.config.participant["network_args"]["neighbors"] await self.init_reputation( federation_nodes=federation, round_num=await self._engine.get_round(), @@ -1698,7 +1698,7 @@ async def _process_feedback(self): """Process and include feedback in reputation.""" status = await self.include_feedback_in_reputation() current_round = await self._engine.get_round() - + if status: logging.info(f"Feedback included in reputation at round {current_round}") else: @@ -1735,7 +1735,7 @@ async def send_reputation_to_neighbors(self, neighbors): def create_graphic_reputation(self, addr: str, round_num: int): """ Log reputation data for visualization. 
- + Args: addr: The node address round_num: The round number for logging step @@ -1746,7 +1746,7 @@ def create_graphic_reputation(self, addr: str, round_num: int): for node_id, data in self.reputation.items() if data.get("reputation") is not None } - + if valid_reputations: reputation_data = {f"Reputation/{addr}": valid_reputations} self._engine.trainer._logger.log_data(reputation_data, step=round_num) @@ -1954,26 +1954,26 @@ def _recalculate_pending_latencies(self, current_round): async def recollect_similarity(self, ure: UpdateReceivedEvent): """ Collect and analyze model similarity metrics. - + Args: ure: UpdateReceivedEvent containing model and metadata """ (decoded_model, weight, nei, round_num, local) = await ure.get_event_data() - + if not (self._enabled and self._is_metric_enabled("model_similarity")): return - + if not self._engine.config.participant["adaptive_args"]["model_similarity"]: return - + if nei == self._addr: return - + logging.info("🤖 handle_model_message | Checking model similarity") - + local_model = self._engine.trainer.get_model_parameters() similarity_values = self._calculate_all_similarity_metrics(local_model, decoded_model) - + similarity_metrics = { "timestamp": datetime.now(), "nei": nei, @@ -1996,7 +1996,7 @@ def _calculate_all_similarity_metrics(self, local_model: dict, received_model: d "jaccard": 0.0, "minkowski": 0.0, } - + similarity_functions = [ ("cosine", cosine_metric), ("euclidean", euclidean_metric), @@ -2004,29 +2004,29 @@ def _calculate_all_similarity_metrics(self, local_model: dict, received_model: d ("pearson_correlation", pearson_correlation_metric), ("jaccard", jaccard_metric), ] - + similarity_values = {} - + for name, metric_func in similarity_functions: try: similarity_values[name] = metric_func(local_model, received_model, similarity=True) except Exception: similarity_values[name] = 0.0 - + try: similarity_values["minkowski"] = minkowski_metric( local_model, received_model, p=2, similarity=True ) except Exception: 
similarity_values["minkowski"] = 0.0 - + return similarity_values def _store_similarity_metrics(self, nei: str, similarity_metrics: dict): """Store similarity metrics for the given neighbor.""" if nei not in self.connection_metrics: self.connection_metrics[nei] = Metrics() - + self.connection_metrics[nei].similarity.append(similarity_metrics) def _check_similarity_threshold(self, nei: str, cosine_value: float): @@ -2064,25 +2064,25 @@ async def _record_message_data(self, source: str): async def recollect_fraction_of_parameters_changed(self, ure: UpdateReceivedEvent): """ Collect and analyze the fraction of parameters that changed between models. - + Args: ure: UpdateReceivedEvent containing model and metadata """ (decoded_model, weight, source, round_num, local) = await ure.get_event_data() - + current_round = await self._engine.get_round() parameters_local = self._engine.trainer.get_model_parameters() - + prev_threshold = self._get_previous_threshold(source, current_round) differences = self._calculate_parameter_differences(parameters_local, decoded_model) current_threshold = self._calculate_threshold(differences, prev_threshold) - + changed_params, total_params, changes_record = self._count_changed_parameters( parameters_local, decoded_model, current_threshold ) - + fraction_changed = changed_params / total_params if total_params > 0 else 0.0 - + self._store_fraction_data(source, current_round, { "fraction_changed": fraction_changed, "total_params": total_params, @@ -2102,7 +2102,7 @@ async def recollect_fraction_of_parameters_changed(self, ure: UpdateReceivedEven def _get_previous_threshold(self, source: str, current_round: int) -> float: """Get the threshold from the previous round for the given source.""" - if (source in self.fraction_of_params_changed and + if (source in self.fraction_of_params_changed and current_round - 1 in self.fraction_of_params_changed[source]): return self.fraction_of_params_changed[source][current_round - 1][-1]["threshold"] return 
None @@ -2122,7 +2122,7 @@ def _calculate_threshold(self, differences: list, prev_threshold: float) -> floa """Calculate the threshold for determining parameter changes.""" if not differences: return 0 - + mean_threshold = torch.mean(torch.tensor(differences)).item() if prev_threshold is not None: return (prev_threshold + mean_threshold) / 2 @@ -2133,20 +2133,20 @@ def _count_changed_parameters(self, local_params: dict, received_params: dict, t total_params = 0 changed_params = 0 changes_record = {} - + for key in local_params.keys(): if key in received_params: local_tensor = local_params[key].cpu() received_tensor = received_params[key].cpu() diff = torch.abs(local_tensor - received_tensor) total_params += diff.numel() - + num_changed = torch.sum(diff > threshold).item() changed_params += num_changed - + if num_changed > 0: changes_record[key] = num_changed - + return changed_params, total_params, changes_record def _store_fraction_data(self, source: str, current_round: int, data: dict): @@ -2155,5 +2155,5 @@ def _store_fraction_data(self, source: str, current_round: int, data: dict): self.fraction_of_params_changed[source] = {} if current_round not in self.fraction_of_params_changed[source]: self.fraction_of_params_changed[source][current_round] = [] - - self.fraction_of_params_changed[source][current_round].append(data) \ No newline at end of file + + self.fraction_of_params_changed[source][current_round].append(data) diff --git a/nebula/addons/topologymanager.py b/nebula/addons/topologymanager.py index c29937372..eda6429cd 100755 --- a/nebula/addons/topologymanager.py +++ b/nebula/addons/topologymanager.py @@ -322,15 +322,16 @@ def update_nodes(self, config_participants): def get_neighbors_string(self, node_idx): """ - Retrieves the neighbors of a given node as a string representation. + Retrieves the neighbors of a given node as a list of string representations. 
- This method checks the `topology` attribute to find the neighbors of the node at the specified index (`node_idx`). It then returns a string that lists the coordinates of each neighbor. + This method checks the `topology` attribute to find the neighbors of the node at the specified index (`node_idx`). + It then returns a list that contains the coordinates of each neighbor in string format. Parameters: node_idx (int): The index of the node for which neighbors are to be retrieved. Returns: - str: A space-separated string of neighbors' coordinates in the format "latitude:longitude". + list[str]: A list of neighbors' coordinates in the format "latitude:longitude". """ logging.info(f"Getting neighbors for node {node_idx}") logging.info(f"Topology shape: {self.topology.shape}") @@ -342,9 +343,8 @@ def get_neighbors_string(self, node_idx): logging.info(f"Found neighbor at index {i}: {self.nodes[i]}") neighbors_data_strings = [f"{i[0]}:{i[1]}" for i in neighbors_data] - neighbors_data_string = " ".join(neighbors_data_strings) - logging.info(f"Neighbors of node participant_{node_idx}: {neighbors_data_string}") - return neighbors_data_string + logging.info(f"Neighbors of node participant_{node_idx}: {neighbors_data_strings}") + return neighbors_data_strings def __ring_topology(self, increase_convergence=False): """ diff --git a/nebula/config/config.py b/nebula/config/config.py index 5ef336e3a..d5f3a813f 100755 --- a/nebula/config/config.py +++ b/nebula/config/config.py @@ -55,7 +55,7 @@ def reset_logging_configuration(self): self.__set_default_logging(mode="a") self.__set_training_logging(mode="a") - + def shutdown_logging(self): """ Properly shuts down all loggers and their handlers in the system. 
@@ -204,40 +204,46 @@ def add_participants_config(self, participants_config): def add_neighbor_from_config(self, addr): if self.participant != {}: - if self.participant["network_args"]["neighbors"] == "": - self.participant["network_args"]["neighbors"] = addr + neighbors = self.participant["network_args"]["neighbors"] + + if not neighbors: + self.participant["network_args"]["neighbors"] = [addr] self.participant["mobility_args"]["neighbors_distance"][addr] = None else: - if addr not in self.participant["network_args"]["neighbors"]: - self.participant["network_args"]["neighbors"] += " " + addr + if addr not in neighbors: + self.participant["network_args"]["neighbors"].append(addr) self.participant["mobility_args"]["neighbors_distance"][addr] = None def update_nodes_distance(self, distances: dict): self.participant["mobility_args"]["neighbors_distance"] = {node: dist for node, (dist, _) in distances.items()} def update_neighbors_from_config(self, current_connections, dest_addr): - final_neighbors = [] - for n in current_connections: - if n != dest_addr: - final_neighbors.append(n) - - final_neighbors_string = " ".join(final_neighbors) - # Update neighbors - self.participant["network_args"]["neighbors"] = final_neighbors_string + final_neighbors = [n for n in current_connections if n != dest_addr] + + # Update neighbors como lista (no string) + self.participant["network_args"]["neighbors"] = final_neighbors + # Update neighbors location self.participant["mobility_args"]["neighbors_distance"] = { n: self.participant["mobility_args"]["neighbors_distance"][n] for n in final_neighbors if n in self.participant["mobility_args"]["neighbors_distance"] } - logging.info(f"Final neighbors: {final_neighbors_string} (config updated))") + + logging.info(f"Final neighbors: {final_neighbors} (config updated)") + def remove_neighbor_from_config(self, addr): - if self.participant != {}: - if self.participant["network_args"]["neighbors"] != "": - 
self.participant["network_args"]["neighbors"] = ( - self.participant["network_args"]["neighbors"].replace(addr, "").replace(" ", " ").strip() - ) + if self.participant: + neighbors = self.participant["network_args"]["neighbors"] + + if addr in neighbors: + neighbors.remove(addr) + self.participant["network_args"]["neighbors"] = neighbors + + if addr in self.participant["mobility_args"]["neighbors_distance"]: + del self.participant["mobility_args"]["neighbors_distance"][addr] + def reload_config_file(self): config_dir = self.participant["tracking_args"]["config_dir"] diff --git a/nebula/controller/controller.py b/nebula/controller/controller.py index acf03c19f..7e59ae7ec 100755 --- a/nebula/controller/controller.py +++ b/nebula/controller/controller.py @@ -7,6 +7,7 @@ import logging import os import re +import copy from typing import Annotated import aiohttp @@ -15,7 +16,13 @@ from fastapi import Body, FastAPI, Request, status, HTTPException, Path, File, UploadFile from fastapi.concurrency import asynccontextmanager -from nebula.controller.database import scenario_set_all_status_to_finished, scenario_set_status_to_finished +from nebula.controller.database import ( + init_db_pool, + close_db_pool, + insert_default_admin, + scenario_set_all_status_to_finished, + scenario_set_status_to_finished, +) from nebula.controller.http_helpers import remote_get, remote_post_form from nebula.utils import DockerUtils @@ -66,9 +73,6 @@ def format(self, record): return super().format(record) -os.environ["NEBULA_CONTROLLER_NAME"] = os.environ.get("USER") - - def configure_logger(controller_log): """ Configures the logging system for the controller. @@ -105,17 +109,25 @@ def configure_logger(controller_log): @asynccontextmanager async def lifespan(app: FastAPI): - databases_dir: str = os.environ.get("NEBULA_DATABASES_DIR") + """ + Application lifespan context manager. + - Initializes the database connection pool on startup. + - Configures logging. 
+ - Cleans up resources like the database pool on shutdown. + """ + # Code to run on startup controller_log: str = os.environ.get("NEBULA_CONTROLLER_LOG") - - from nebula.controller.database import initialize_databases - - await initialize_databases(databases_dir) - configure_logger(controller_log) + # Initialize the database connection pool + await init_db_pool() + await insert_default_admin() + yield + # Code to run on shutdown + await close_db_pool() + # Initialize FastAPI app outside the Controller class app = FastAPI(lifespan=lifespan) @@ -309,6 +321,8 @@ async def run_scenario( validate_physical_fields(scenario_data) + db_scenario = copy.deepcopy(scenario_data) + # Manager for the actual scenario scenarioManagement = ScenarioManagement(scenario_data, user) @@ -342,7 +356,6 @@ async def run_scenario( @app.post("/scenarios/stop") async def stop_scenario( scenario_name: str = Body(..., embed=True), - username: str = Body(..., embed=True), all: bool = Body(False, embed=True), ): """ @@ -370,9 +383,9 @@ async def stop_scenario( ScenarioManagement.cleanup_scenario_containers() try: if all: - scenario_set_all_status_to_finished() + await scenario_set_all_status_to_finished() else: - scenario_set_status_to_finished(scenario_name) + await scenario_set_status_to_finished(scenario_name) except Exception as e: logging.exception(f"Error setting scenario {scenario_name} to finished: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -391,12 +404,13 @@ async def remove_scenario( Returns: dict: A message indicating successful removal. 
""" - from nebula.controller.database import remove_scenario_by_name + from nebula.controller.database import remove_scenario_by_name, get_user_by_scenario_name from nebula.controller.scenarios import ScenarioManagement try: - remove_scenario_by_name(scenario_name) + await remove_scenario_by_name(scenario_name) ScenarioManagement.remove_files_by_scenario(scenario_name) + except Exception as e: logging.exception(f"Error removing scenario {scenario_name}: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -422,11 +436,12 @@ async def get_scenarios( from nebula.controller.database import get_all_scenarios_and_check_completed, get_running_scenario try: - scenarios = get_all_scenarios_and_check_completed(username=user, role=role) + scenarios = await get_all_scenarios_and_check_completed(username=user, role=role) + if role == "admin": - scenario_running = get_running_scenario() + scenario_running = await get_running_scenario() else: - scenario_running = get_running_scenario(username=user) + scenario_running = await get_running_scenario(username=user) except Exception as e: logging.exception(f"Error obtaining scenarios: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -460,11 +475,9 @@ async def update_scenario( dict: A message confirming the update. 
""" from nebula.controller.database import scenario_update_record - from nebula.controller.scenarios import Scenario try: - scenario = Scenario.from_dict(scenario) - scenario_update_record(scenario_name, start_time, end_time, scenario, status, role, username) + await scenario_update_record(scenario_name, start_time, end_time, scenario, status, username) except Exception as e: logging.exception(f"Error updating scenario {scenario_name}: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -490,9 +503,9 @@ async def set_scenario_status_to_finished( try: if all: - scenario_set_all_status_to_finished() + await scenario_set_all_status_to_finished() else: - scenario_set_status_to_finished(scenario_name) + await scenario_set_status_to_finished(scenario_name) except Exception as e: logging.exception(f"Error setting scenario {scenario_name} to finished: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -514,14 +527,15 @@ async def get_running_scenario(get_all: bool = False): from nebula.controller.database import get_running_scenario try: - return get_running_scenario(get_all=get_all) + return await get_running_scenario(get_all=get_all) except Exception as e: logging.exception(f"Error obtaining running scenario: {e}") raise HTTPException(status_code=500, detail="Internal server error") -@app.get("/scenarios/check/{role}/{scenario_name}") +@app.get("/scenarios/check/{user}/{role}/{scenario_name}") async def check_scenario( + user: Annotated[str, Path(regex="^[a-zA-Z0-9_-]+$", min_length=1, max_length=50, description="Valid username")], role: Annotated[str, Path(regex="^[a-zA-Z0-9_-]+$", min_length=1, max_length=50, description="Valid role")], scenario_name: Annotated[ str, Path(regex="^[a-zA-Z0-9_-]+$", min_length=1, max_length=50, description="Valid scenario name") @@ -540,7 +554,7 @@ async def check_scenario( from nebula.controller.database import check_scenario_with_role try: - allowed = check_scenario_with_role(role, 
scenario_name) + allowed = await check_scenario_with_role(role, scenario_name, user) return {"allowed": allowed} except Exception as e: logging.exception(f"Error checking scenario with role: {e}") @@ -565,7 +579,7 @@ async def get_scenario_by_name( from nebula.controller.database import get_scenario_by_name try: - scenario = get_scenario_by_name(scenario_name) + scenario = await get_scenario_by_name(scenario_name) except Exception as e: logging.exception(f"Error obtaining scenario {scenario_name}: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -591,7 +605,7 @@ async def list_nodes_by_scenario_name( from nebula.controller.database import list_nodes_by_scenario_name try: - nodes = list_nodes_by_scenario_name(scenario_name) + nodes = await list_nodes_by_scenario_name(scenario_name) except Exception as e: logging.exception(f"Error obtaining nodes: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -629,7 +643,7 @@ async def update_nodes( str(config["network_args"]["ip"]), str(config["network_args"]["port"]), str(config["device_args"]["role"]), - str(config["network_args"]["neighbors"]), + config["network_args"]["neighbors"], str(config["mobility_args"]["latitude"]), str(config["mobility_args"]["longitude"]), str(timestamp), @@ -644,7 +658,7 @@ async def update_nodes( raise HTTPException(status_code=500, detail="Internal server error") url = ( - f"http://{os.environ['NEBULA_CONTROLLER_NAME']}_nebula-frontend/platform/dashboard/{scenario_name}/node/update" + f"http://{os.environ['NEBULA_ENV_TAG']}_{os.environ['NEBULA_PREFIX_TAG']}_{os.environ['NEBULA_USER_TAG']}_nebula-frontend/platform/dashboard/{scenario_name}/node/update" ) config["timestamp"] = str(timestamp) @@ -679,7 +693,7 @@ async def node_done( Returns the response from the frontend or raises an HTTPException if it fails. 
""" - url = f"http://{os.environ['NEBULA_CONTROLLER_NAME']}_nebula-frontend/platform/dashboard/{scenario_name}/node/done" + url = f"http://{os.environ['NEBULA_ENV_TAG']}_{os.environ['NEBULA_PREFIX_TAG']}_{os.environ['NEBULA_USER_TAG']}_nebula-frontend/platform/dashboard/{scenario_name}/node/done" data = await request.json() @@ -706,7 +720,7 @@ async def remove_nodes_by_scenario_name(scenario_name: str = Body(..., embed=Tru from nebula.controller.database import remove_nodes_by_scenario_name try: - remove_nodes_by_scenario_name(scenario_name) + await remove_nodes_by_scenario_name(scenario_name) except Exception as e: logging.exception(f"Error removing nodes: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -722,22 +736,21 @@ async def get_notes_by_scenario_name( ): """ Endpoint to retrieve notes associated with a scenario. - - Path Parameters: - - scenario_name: Name of the scenario. - - Returns the notes or raises an HTTPException on error. """ from nebula.controller.database import get_notes try: - notes = get_notes(scenario_name) + notes_record = await get_notes(scenario_name) + + if notes_record is not None: + notes_record = dict(notes_record.items()) + + return notes_record + except Exception as e: - logging.exception(f"Error obtaining notes {notes}: {e}") + logging.exception(f"Error obtaining notes for scenario {scenario_name}: {e}") raise HTTPException(status_code=500, detail="Internal server error") - return notes - @app.post("/notes/update") async def update_notes_by_scenario_name(scenario_name: str = Body(..., embed=True), notes: str = Body(..., embed=True)): @@ -753,7 +766,7 @@ async def update_notes_by_scenario_name(scenario_name: str = Body(..., embed=Tru from nebula.controller.database import save_notes try: - save_notes(scenario_name, notes) + await save_notes(scenario_name, notes) except Exception as e: logging.exception(f"Error updating notes: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ 
-774,7 +787,7 @@ async def remove_notes_by_scenario_name(scenario_name: str = Body(..., embed=Tru from nebula.controller.database import remove_note try: - remove_note(scenario_name) + await remove_note(scenario_name) except Exception as e: logging.exception(f"Error removing notes: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -795,9 +808,9 @@ async def list_users_controller(all_info: bool = False): from nebula.controller.database import list_users try: - user_list = list_users(all_info) + user_list = await list_users(all_info) if all_info: - # Convert each sqlite3.Row to a dictionary so that it is JSON serializable. + # Convert each asyncpg.Record to a dictionary so that it is JSON serializable. user_list = [dict(user) for user in user_list] return {"users": user_list} except Exception as e: @@ -821,7 +834,7 @@ async def get_user_by_scenario_name( from nebula.controller.database import get_user_by_scenario_name try: - user = get_user_by_scenario_name(scenario_name) + user = await get_user_by_scenario_name(scenario_name) except Exception as e: logging.exception(f"Error obtaining user {user}: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -1026,7 +1039,7 @@ async def add_user_controller(user: str = Body(...), password: str = Body(...), from nebula.controller.database import add_user try: - add_user(user, password, role) + await add_user(user, password, role) return {"detail": "User added successfully"} except Exception as e: logging.exception(f"Error adding user: {e}") @@ -1046,7 +1059,7 @@ async def remove_user_controller(user: str = Body(..., embed=True)): from nebula.controller.database import delete_user_from_db try: - delete_user_from_db(user) + await delete_user_from_db(user) return {"detail": "User deleted successfully"} except Exception as e: logging.exception(f"Error deleting user: {e}") @@ -1068,7 +1081,7 @@ async def update_user_controller(user: str = Body(...), password: str = Body(... 
from nebula.controller.database import update_user try: - update_user(user, password, role) + await update_user(user, password, role) return {"detail": "User updated successfully"} except Exception as e: logging.exception(f"Error updating user: {e}") @@ -1090,8 +1103,8 @@ async def verify_user_controller(user: str = Body(...), password: str = Body(... try: user_submitted = user.upper() - if (user_submitted in list_users()) and verify(user_submitted, password): - user_info = get_user_info(user_submitted) + if (await list_users() and await verify(user_submitted, password)): + user_info = await get_user_info(user_submitted) return {"user": user_submitted, "role": user_info[2]} else: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) diff --git a/nebula/controller/database.py b/nebula/controller/database.py index 7a012fd8a..407ce3908 100755 --- a/nebula/controller/database.py +++ b/nebula/controller/database.py @@ -1,297 +1,85 @@ -import asyncio -import datetime -import json import logging import os -import sqlite3 +import datetime +import json +import asyncpg +import asyncio -import aiosqlite -from argon2 import PasswordHasher +from passlib.context import CryptContext -user_db_file_location = None -node_db_file_location = None -scenario_db_file_location = None -notes_db_file_location = None +# --- Configuration --- +# Use environment variables for database credentials from the Docker Compose file +DATABASE_URL = f"postgresql://{os.environ.get('DB_USER')}:{os.environ.get('DB_PASSWORD')}@{os.environ.get('DB_HOST')}:{os.environ.get('DB_PORT')}/nebula" -_node_lock = asyncio.Lock() +# Password hashing context (using Argon2) +pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") -PRAGMA_SETTINGS = [ - "PRAGMA journal_mode=WAL;", - "PRAGMA synchronous=NORMAL;", - "PRAGMA journal_size_limit=1048576;", - "PRAGMA cache_size=10000;", - "PRAGMA temp_store=MEMORY;", - "PRAGMA cache_spill=0;", -] +# Asynchronous lock for node updates +_node_lock = 
asyncio.Lock() +# --- Connection Pool Management --- +# Global pool variable, should be initialized at application startup +POOL = None -async def setup_database(db_file_location): +async def init_db_pool(): """ - Initializes the SQLite database with the required PRAGMA settings. - - This function: - - Connects asynchronously to the specified SQLite database file. - - Applies a predefined list of PRAGMA settings to configure the database. - - Commits the changes after applying the settings. - - Args: - db_file_location (str): Path to the SQLite database file. - - Exceptions: - PermissionError: Logged if the application lacks permission to create or modify the database file. - Exception: Logs any other unexpected error that occurs during setup. + Initializes the asynchronous PostgreSQL connection pool. + This should be called once when the application starts. """ - try: - async with aiosqlite.connect(db_file_location) as db: - for pragma in PRAGMA_SETTINGS: - await db.execute(pragma) - await db.commit() - except PermissionError: - logging.info("No permission to create the databases. Change the default databases directory") - except Exception as e: - logging.exception(f"An error has ocurred during setup_database: {e}") - + global POOL + if POOL is None: + try: + POOL = await asyncpg.create_pool( + dsn=DATABASE_URL, + min_size=5, # Minimum number of connections in the pool + max_size=20, # Maximum number of connections in the pool + ) + logging.info("Database connection pool successfully created.") + except Exception as e: + logging.critical(f"Failed to create database connection pool: {e}", exc_info=True) + # Exit or handle the failure appropriately + raise -async def ensure_columns(conn, table_name, desired_columns): +async def close_db_pool(): + """ + Closes the asynchronous PostgreSQL connection pool. + This should be called once when the application shuts down gracefully. """ - Ensures that a table contains all the desired columns, adding any that are missing. 
+ global POOL + if POOL: + await POOL.close() + logging.info("Database connection pool closed.") - This function: - - Retrieves the current columns of the specified table. - - Compares them with the desired columns. - - Adds any missing columns to the table using ALTER TABLE statements. - Args: - conn (aiosqlite.Connection): Active connection to the SQLite database. - table_name (str): Name of the table to check and modify. - desired_columns (dict): Dictionary mapping column names to their SQL definitions. +# --- User Management Functions --- - Note: - This operation is committed immediately after all changes are applied. +async def insert_default_admin(): """ - _c = await conn.execute(f"PRAGMA table_info({table_name});") - existing_columns = [row[1] for row in await _c.fetchall()] - for column_name, column_definition in desired_columns.items(): - if column_name not in existing_columns: - await conn.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_definition};") - await conn.commit() - - -async def initialize_databases(databases_dir): + Inserts a default 'ADMIN' user into the database with a hashed password. + The password must be provided via the ADMIN_PASSWORD environment variable. """ - Initializes all required SQLite databases and their corresponding tables for the system. - - This function: - - Defines paths for user, node, scenario, and notes databases based on the provided directory. - - Sets up each database with appropriate PRAGMA settings. - - Creates necessary tables if they do not exist. - - Ensures all expected columns are present in each table, adding any missing ones. - - Creates a default admin user if no users are present. + admin_password = os.environ.get("NEBULA_ADMIN_PASSWORD") - Args: - databases_dir (str): Path to the directory where the database files will be created or accessed. 
+ hashed_password = pwd_context.hash(admin_password) - Note: - Default credentials (username and password) are taken from environment variables: - - NEBULA_DEFAULT_USER - - NEBULA_DEFAULT_PASSWORD + query = """ + INSERT INTO users ("user", password, role) + VALUES ($1, $2, $3) + ON CONFLICT ("user") DO NOTHING; """ - global user_db_file_location, node_db_file_location, scenario_db_file_location, notes_db_file_location - - user_db_file_location = os.path.join(databases_dir, "users.db") - node_db_file_location = os.path.join(databases_dir, "nodes.db") - scenario_db_file_location = os.path.join(databases_dir, "scenarios.db") - notes_db_file_location = os.path.join(databases_dir, "notes.db") - - await setup_database(user_db_file_location) - await setup_database(node_db_file_location) - await setup_database(scenario_db_file_location) - await setup_database(notes_db_file_location) - - async with aiosqlite.connect(user_db_file_location) as conn: - await conn.execute( - """ - CREATE TABLE IF NOT EXISTS users ( - user TEXT PRIMARY KEY, - password TEXT, - role TEXT - ); - """ - ) - desired_columns = {"user": "TEXT PRIMARY KEY", "password": "TEXT", "role": "TEXT"} - await ensure_columns(conn, "users", desired_columns) - - async with aiosqlite.connect(node_db_file_location) as conn: - await conn.execute( - """ - CREATE TABLE IF NOT EXISTS nodes ( - uid TEXT PRIMARY KEY, - idx TEXT, - ip TEXT, - port TEXT, - role TEXT, - neighbors TEXT, - latitude TEXT, - longitude TEXT, - timestamp TEXT, - federation TEXT, - round TEXT, - scenario TEXT, - hash TEXT, - malicious TEXT - ); - """ - ) - desired_columns = { - "uid": "TEXT PRIMARY KEY", - "idx": "TEXT", - "ip": "TEXT", - "port": "TEXT", - "role": "TEXT", - "neighbors": "TEXT", - "latitude": "TEXT", - "longitude": "TEXT", - "timestamp": "TEXT", - "federation": "TEXT", - "round": "TEXT", - "scenario": "TEXT", - "hash": "TEXT", - "malicious": "TEXT", - } - await ensure_columns(conn, "nodes", desired_columns) - - async with 
aiosqlite.connect(scenario_db_file_location) as conn: - await conn.execute( - """ - CREATE TABLE IF NOT EXISTS scenarios ( - name TEXT PRIMARY KEY, - start_time TEXT, - end_time TEXT, - title TEXT, - description TEXT, - deployment TEXT, - federation TEXT, - topology TEXT, - nodes TEXT, - nodes_graph TEXT, - n_nodes TEXT, - matrix TEXT, - random_topology_probability TEXT, - dataset TEXT, - iid TEXT, - partition_selection TEXT, - partition_parameter TEXT, - model TEXT, - agg_algorithm TEXT, - rounds TEXT, - logginglevel TEXT, - report_status_data_queue TEXT, - accelerator TEXT, - network_subnet TEXT, - network_gateway TEXT, - epochs TEXT, - attack_params TEXT, - reputation TEXT, - random_geo TEXT, - latitude TEXT, - longitude TEXT, - mobility TEXT, - mobility_type TEXT, - radius_federation TEXT, - scheme_mobility TEXT, - round_frequency TEXT, - mobile_participants_percent TEXT, - additional_participants TEXT, - schema_additional_participants TEXT, - status TEXT, - role TEXT, - username TEXT, - gpu_id TEXT - ); - """ - ) - desired_columns = { - "name": "TEXT PRIMARY KEY", - "start_time": "TEXT", - "end_time": "TEXT", - "title": "TEXT", - "description": "TEXT", - "deployment": "TEXT", - "federation": "TEXT", - "topology": "TEXT", - "nodes": "TEXT", - "nodes_graph": "TEXT", - "n_nodes": "TEXT", - "matrix": "TEXT", - "random_topology_probability": "TEXT", - "dataset": "TEXT", - "iid": "TEXT", - "partition_selection": "TEXT", - "partition_parameter": "TEXT", - "model": "TEXT", - "agg_algorithm": "TEXT", - "rounds": "TEXT", - "logginglevel": "TEXT", - "report_status_data_queue": "TEXT", - "accelerator": "TEXT", - "gpu_id": "TEXT", - "network_subnet": "TEXT", - "network_gateway": "TEXT", - "epochs": "TEXT", - "attack_params": "TEXT", - "reputation": "TEXT", - "random_geo": "TEXT", - "latitude": "TEXT", - "longitude": "TEXT", - "mobility": "TEXT", - "mobility_type": "TEXT", - "radius_federation": "TEXT", - "scheme_mobility": "TEXT", - "round_frequency": "TEXT", - 
"mobile_participants_percent": "TEXT", - "additional_participants": "TEXT", - "schema_additional_participants": "TEXT", - "status": "TEXT", - "role": "TEXT", - "username": "TEXT", - } - await ensure_columns(conn, "scenarios", desired_columns) - - async with aiosqlite.connect(notes_db_file_location) as conn: - await conn.execute( - """ - CREATE TABLE IF NOT EXISTS notes ( - scenario TEXT PRIMARY KEY, - scenario_notes TEXT - ); - """ - ) - desired_columns = {"scenario": "TEXT PRIMARY KEY", "scenario_notes": "TEXT"} - await ensure_columns(conn, "notes", desired_columns) - - username = os.environ.get("NEBULA_DEFAULT_USER", "admin") - password = os.environ.get("NEBULA_DEFAULT_PASSWORD", "admin") - if not list_users(): - add_user(username, password, "admin") - if not verify_hash_algorithm(username): - update_user(username, password, "admin") - + try: + async with POOL.acquire() as conn: + await conn.execute(query, "ADMIN", hashed_password, "admin") + logging.info("Default admin user inserted (or already exists).") + except Exception as e: + logging.error(f"Failed to insert default admin user: {e}", exc_info=True) -def list_users(all_info=False): +async def list_users(all_info=False): """ Retrieves a list of users from the users database. - - Args: - all_info (bool): If True, returns full user records; otherwise, returns only usernames. Default is False. - - Returns: - list: A list of usernames or full user records depending on the all_info flag. """ - with sqlite3.connect(user_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT * FROM users") - result = c.fetchall() + async with POOL.acquire() as conn: + result = await conn.fetch("SELECT * FROM users") if not all_info: result = [user["user"] for user in result] @@ -299,985 +87,549 @@ def list_users(all_info=False): return result -def get_user_info(user): +async def get_user_info(user): """ Fetches detailed information for a specific user from the users database. 
- - Args: - user (str): The username to retrieve information for. - - Returns: - sqlite3.Row or None: A row containing the user's information if found, otherwise None. """ - with sqlite3.connect(user_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - command = "SELECT * FROM users WHERE user = ?" - c.execute(command, (user,)) - result = c.fetchone() - - return result + async with POOL.acquire() as conn: + return await conn.fetchrow('SELECT * FROM users WHERE "user" = $1', user) -def verify(user, password): +async def verify(user, password): """ Verifies whether the provided password matches the stored hashed password for a user. - - Args: - user (str): The username to verify. - password (str): The plain text password to check against the stored hash. - - Returns: - bool: True if the password is correct, False otherwise. """ - ph = PasswordHasher() - with sqlite3.connect(user_db_file_location) as conn: - c = conn.cursor() - - c.execute("SELECT password FROM users WHERE user = ?", (user,)) - result = c.fetchone() - if result: - try: - return ph.verify(result[0], password) - except: - return False + async with POOL.acquire() as conn: + result = await conn.fetchrow('SELECT password FROM users WHERE "user" = $1', user) + if result: + try: + return pwd_context.verify(password, result[0]) + except Exception: + # Catch more general exceptions during verification to be safe + logging.error(f"Error during password verification for user {user}", exc_info=True) + return False return False -def verify_hash_algorithm(user): +async def verify_hash_algorithm(user): """ Checks if the stored password hash for a user uses a supported Argon2 algorithm. - - Args: - user (str): The username to check (case-insensitive, converted to uppercase). - - Returns: - bool: True if the password hash starts with a valid Argon2 prefix, False otherwise. 
""" user = user.upper() argon2_prefixes = ("$argon2i$", "$argon2id$") - - with sqlite3.connect(user_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - c.execute("SELECT password FROM users WHERE user = ?", (user,)) - result = c.fetchone() - if result: - password_hash = result["password"] - return password_hash.startswith(argon2_prefixes) - + async with POOL.acquire() as conn: + result = await conn.fetchrow('SELECT password FROM users WHERE "user" = $1', user) + if result: + password_hash = result["password"] + return password_hash.startswith(argon2_prefixes) return False -def delete_user_from_db(user): +async def delete_user_from_db(user): """ Deletes a user record from the users database. - - Args: - user (str): The username of the user to be deleted. """ - with sqlite3.connect(user_db_file_location) as conn: - c = conn.cursor() - c.execute("DELETE FROM users WHERE user = ?", (user,)) + async with POOL.acquire() as conn: + await conn.execute('DELETE FROM users WHERE "user" = $1', user) -def add_user(user, password, role): +async def add_user(user, password, role): """ Adds a new user to the users database with a hashed password. - - Args: - user (str): The username to add (stored in uppercase). - password (str): The plain text password to hash and store. - role (str): The role assigned to the user. - """ - ph = PasswordHasher() - with sqlite3.connect(user_db_file_location) as conn: - c = conn.cursor() - c.execute( - "INSERT INTO users VALUES (?, ?, ?)", - (user.upper(), ph.hash(password), role), + """ + hashed_password = pwd_context.hash(password) + async with POOL.acquire() as conn: + await conn.execute( + 'INSERT INTO users ("user", password, role) VALUES ($1, $2, $3)', + user.upper(), hashed_password, role, ) -def update_user(user, password, role): +async def update_user(user, password, role): """ Updates the password and role of an existing user in the users database. 
- - Args: - user (str): The username to update (case-insensitive, stored as uppercase). - password (str): The new plain text password to hash and store. - role (str): The new role to assign to the user. - """ - ph = PasswordHasher() - with sqlite3.connect(user_db_file_location) as conn: - c = conn.cursor() - c.execute( - "UPDATE users SET password = ?, role = ? WHERE user = ?", - (ph.hash(password), role, user.upper()), + """ + hashed_password = pwd_context.hash(password) + async with POOL.acquire() as conn: + await conn.execute( + 'UPDATE users SET password = $1, role = $2 WHERE "user" = $3', + hashed_password, role, user.upper(), ) +# --- Node Management Functions --- -def list_nodes(scenario_name=None, sort_by="idx"): +async def list_nodes(scenario_name=None, sort_by="idx"): """ Retrieves a list of nodes from the nodes database, optionally filtered by scenario and sorted. - - Args: - scenario_name (str, optional): Name of the scenario to filter nodes by. If None, returns all nodes. - sort_by (str): Column name to sort the results by. Defaults to "idx". - - Returns: - list or None: A list of sqlite3.Row objects representing nodes, or None if an error occurs. """ - try: - with sqlite3.connect(node_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() + # Validate sort_by to prevent SQL injection + allowed_sort_fields = ["uid", "idx", "ip", "port", "role", "timestamp", "federation", "round"] + if sort_by not in allowed_sort_fields: + sort_by = "idx" # Default to a safe field + try: + async with POOL.acquire() as conn: if scenario_name: - command = "SELECT * FROM nodes WHERE scenario = ? 
ORDER BY " + sort_by + ";" - c.execute(command, (scenario_name,)) + # Using f-string for column names is generally safe if validated as above + command = f"SELECT * FROM nodes WHERE scenario = $1 ORDER BY {sort_by};" + result = await conn.fetch(command, scenario_name) else: - command = "SELECT * FROM nodes ORDER BY " + sort_by + ";" - c.execute(command) - - result = c.fetchall() - + command = f"SELECT * FROM nodes ORDER BY {sort_by};" + result = await conn.fetch(command) return result - except sqlite3.Error as e: - print(f"Error occurred while listing nodes: {e}") + except asyncpg.PostgresError as e: + logging.error(f"Error occurred while listing nodes: {e}") return None -def list_nodes_by_scenario_name(scenario_name): +async def list_nodes_by_scenario_name(scenario_name): """ Fetches all nodes associated with a specific scenario, ordered by their index as integers. - - Args: - scenario_name (str): The name of the scenario to filter nodes by. - - Returns: - list or None: A list of sqlite3.Row objects for nodes in the scenario, or None if an error occurs. """ try: - with sqlite3.connect(node_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - command = "SELECT * FROM nodes WHERE scenario = ? 
ORDER BY CAST(idx AS INTEGER) ASC;" - c.execute(command, (scenario_name,)) - result = c.fetchall() - - return result - except sqlite3.Error as e: - print(f"Error occurred while listing nodes by scenario name: {e}") + async with POOL.acquire() as conn: + command = "SELECT * FROM nodes WHERE scenario = $1 ORDER BY CAST(idx AS INTEGER) ASC;" + result = await conn.fetch(command, scenario_name) + return [dict(record) for record in result] + except Exception as e: + logging.error(f"Error occurred while listing nodes by scenario name: {e}") return None async def update_node_record( - node_uid, - idx, - ip, - port, - role, - neighbors, - latitude, - longitude, - timestamp, - federation, - federation_round, - scenario, - run_hash, - malicious, + node_uid, idx, ip, port, role, neighbors, latitude, longitude, + timestamp, federation, federation_round, scenario, run_hash, malicious, ): """ Inserts or updates a node record in the database for a given scenario, ensuring thread-safe access. - - Args: - node_uid (str): Unique identifier of the node. - idx (str): Index or identifier within the scenario. - ip (str): IP address of the node. - port (str): Port used by the node. - role (str): Role of the node in the federation. - neighbors (str): Neighbors of the node (serialized). - latitude (str): Geographic latitude of the node. - longitude (str): Geographic longitude of the node. - timestamp (str): Timestamp of the last update. - federation (str): Federation identifier the node belongs to. - federation_round (str): Current federation round. - scenario (str): Scenario name the node is part of. - run_hash (str): Hash of the current run/state. - malicious (str): Indicator if the node is malicious. - - Returns: - dict or None: The updated or inserted node record as a dictionary, or None if insertion/update failed. 
- """ - global _node_lock + """ async with _node_lock: - async with aiosqlite.connect(node_db_file_location) as conn: - conn.row_factory = aiosqlite.Row - _c = await conn.cursor() - - # Check if the node already exists - await _c.execute("SELECT * FROM nodes WHERE uid = ? AND scenario = ?;", (node_uid, scenario)) - result = await _c.fetchone() - - if result is None: - # Insert new node - await _c.execute( - "INSERT INTO nodes (uid, idx, ip, port, role, neighbors, latitude, longitude, timestamp, federation, round, scenario, hash, malicious) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);", - ( - node_uid, - idx, - ip, - port, - role, - neighbors, - latitude, - longitude, - timestamp, - federation, - federation_round, - scenario, - run_hash, - malicious, - ), - ) - else: - # Update existing node - await _c.execute( - "UPDATE nodes SET idx = ?, ip = ?, port = ?, role = ?, neighbors = ?, latitude = ?, longitude = ?, timestamp = ?, federation = ?, round = ?, hash = ?, malicious = ? WHERE uid = ? AND scenario = ?;", - ( - idx, - ip, - port, - role, - neighbors, - latitude, - longitude, - timestamp, - federation, - federation_round, - run_hash, - malicious, - node_uid, - scenario, - ), - ) - - await conn.commit() - - # Fetch the updated or newly inserted row - await _c.execute("SELECT * FROM nodes WHERE uid = ? 
AND scenario = ?;", (node_uid, scenario)) - updated_row = await _c.fetchone() - return dict(updated_row) if updated_row else None - - -def remove_all_nodes(): + async with POOL.acquire() as conn: + try: + async with conn.transaction(): + result = await conn.fetchrow( + "SELECT * FROM nodes WHERE uid = $1 AND scenario = $2 FOR UPDATE;", + node_uid, scenario + ) + + if result is None: + # Insert new node + await conn.execute( + """ + INSERT INTO nodes (uid, idx, ip, port, role, neighbors, latitude, longitude, + timestamp, federation, round, scenario, hash, malicious) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14); + """, + node_uid, idx, ip, port, role, neighbors, latitude, longitude, + timestamp, federation, federation_round, scenario, run_hash, malicious, + ) + else: + # Update existing node + await conn.execute( + """ + UPDATE nodes SET idx = $1, ip = $2, port = $3, role = $4, neighbors = $5, + latitude = $6, longitude = $7, timestamp = $8, federation = $9, round = $10, + hash = $11, malicious = $12 + WHERE uid = $13 AND scenario = $14; + """, + idx, ip, port, role, neighbors, latitude, longitude, + timestamp, federation, federation_round, run_hash, malicious, + node_uid, scenario, + ) + + updated_row = await conn.fetchrow("SELECT * from nodes WHERE uid = $1 AND scenario = $2;", node_uid, scenario) + return dict(updated_row) if updated_row else None + except asyncpg.PostgresError as e: + logging.error(f"Database error during node record update: {e}", exc_info=True) + return None + + +async def remove_all_nodes(): """ Deletes all node records from the nodes database. - - This operation removes every entry in the nodes table. 
- - Returns: - None """ - with sqlite3.connect(node_db_file_location) as conn: - c = conn.cursor() - command = "DELETE FROM nodes;" - c.execute(command) + async with POOL.acquire() as conn: + await conn.execute("TRUNCATE nodes CASCADE;") # Use CASCADE if there are foreign key dependencies -def remove_nodes_by_scenario_name(scenario_name): +async def remove_nodes_by_scenario_name(scenario_name): """ Deletes all nodes associated with a specific scenario from the database. - - Args: - scenario_name (str): The name of the scenario whose nodes should be removed. - - Returns: - None - """ - with sqlite3.connect(node_db_file_location) as conn: - c = conn.cursor() - command = "DELETE FROM nodes WHERE scenario = ?;" - c.execute(command, (scenario_name,)) - - -def get_all_scenarios(username, role, sort_by="start_time"): """ - Retrieve all scenarios from the database filtered by user role and sorted by a specified field. - - Parameters: - username (str): The username of the requesting user. - role (str): The role of the user, e.g., "admin" or regular user. - sort_by (str, optional): The field name to sort the results by. Defaults to "start_time". - - Returns: - list[sqlite3.Row]: A list of scenario records as SQLite Row objects. - - Behavior: - - Admin users retrieve all scenarios. - - Non-admin users retrieve only scenarios associated with their username. - - Sorting by "start_time" applies custom datetime ordering. - - Other sort fields are applied directly in the ORDER BY clause. + async with POOL.acquire() as conn: + await conn.execute("DELETE FROM nodes WHERE scenario = $1;", scenario_name) + +# --- Scenario Management Functions --- + +async def get_all_scenarios(username, role, sort_by="start_time"): + """ + Retrieves all scenarios from the database, accessing fields from the 'config' (JSONB) column + and direct columns. Filters by user role and sorts by the specified field. 
+ """ + allowed_sort_fields = ["start_time", "title", "username", "status", "name"] + if sort_by not in allowed_sort_fields: + sort_by = "start_time" + + # Determine the ORDER BY clause based on sort_by + if sort_by == "start_time": + order_by_clause = """ + ORDER BY + CASE + WHEN start_time IS NULL OR start_time = '' THEN 1 + ELSE 0 + END, + to_timestamp(start_time, 'DD/MM/YYYY HH24:MI:SS') DESC + """ + elif sort_by in ["title", "model", "dataset", "rounds"]: # These are inside config JSONB + order_by_clause = f"ORDER BY config->>'{sort_by}'" + else: # For direct table columns like name, username, status + order_by_clause = f"ORDER BY {sort_by}" + + async with POOL.acquire() as conn: + # Select direct columns and relevant fields from config JSONB + command = """ + SELECT + name, + username, + status, + start_time, + end_time, + config->>'title' AS title, + config->>'model' AS model, + config->>'dataset' AS dataset, + config->>'rounds' AS rounds, + config -- return the full config JSONB + FROM scenarios + """ + params = [] + + if role != "admin": + command += " WHERE username = $1" # username is a direct column now + params.append(username) + + full_command = f"{command} {order_by_clause};" + return await conn.fetch(full_command, *params) + + +async def get_all_scenarios_and_check_completed(username, role, sort_by="start_time"): + """ + Retrieves all scenarios, sorts them, and updates the status if necessary. + Returns a list of dictionaries, where each dictionary represents a scenario. + """ + # Safe list of allowed sorting fields to prevent SQL injection. 
+ allowed_sort_fields = ["start_time", "title", "username", "status", "name"] + if sort_by not in allowed_sort_fields: + sort_by = "start_time" # Safe default value + + # Building the ORDER BY clause + if sort_by == "start_time": + order_by_clause = """ + ORDER BY + CASE + WHEN start_time IS NULL OR start_time = '' THEN 1 + ELSE 0 + END, + to_timestamp(start_time, 'DD/MM/YYYY HH24:MI:SS') DESC + """ + elif sort_by in ["title", "model", "dataset", "rounds"]: # These are inside config JSONB + order_by_clause = f"ORDER BY config->>'{sort_by}'" + else: # For direct table columns like name, username, status + order_by_clause = f"ORDER BY {sort_by}" + + async with POOL.acquire() as conn: + # Base query that extracts fields from the JSONB using the ->> operator + command = f""" + SELECT + name, + username, + status, + start_time, + end_time, + config->>'title' AS title, + config->>'model' AS model, + config->>'dataset' AS dataset, + config->>'rounds' AS rounds, + config -- Return the full config object + FROM scenarios + """ + params = [] + if role != "admin": + command += " WHERE username = $1" # username is a direct column + params.append(username) + + command += f" {order_by_clause};" + + result_dicts = await conn.fetch(command, *params) + + scenarios_to_return = [dict(s) for s in result_dicts] + + re_fetch_required = False + for scenario in scenarios_to_return: + if scenario["status"] == "running": + if await check_scenario_federation_completed(scenario["name"]): + await scenario_set_status_to_completed(scenario["name"]) + re_fetch_required = True + break + + if re_fetch_required: + # Recursively call to get fresh data after status update + return await get_all_scenarios_and_check_completed(username, role, sort_by) + + return scenarios_to_return + + +async def scenario_update_record(name, start_time, end_time, scenario_config, status, username): + """ + Inserts or updates a scenario record using the PostgreSQL "UPSERT" pattern. 
+ All configuration is saved in the 'config' column of type JSONB. + Direct columns (name, start_time, end_time, username, status) are also handled. + """ + # Ensure scenario_config is a dictionary before dumping to JSON + if not isinstance(scenario_config, dict): + try: + scenario_config = json.loads(scenario_config) + except (json.JSONDecodeError, TypeError): + logging.error("scenario_config is not a valid JSON string or dict.") + return + + command = """ + INSERT INTO scenarios (name, start_time, end_time, username, status, config) + VALUES ($1, $2, $3, $4, $5, $6::jsonb) + ON CONFLICT (name) DO UPDATE SET + start_time = EXCLUDED.start_time, + end_time = EXCLUDED.end_time, + username = EXCLUDED.username, + status = EXCLUDED.status, + config = scenarios.config || EXCLUDED.config; -- Merge JSONB + """ + async with POOL.acquire() as conn: + await conn.execute(command, name, start_time, end_time, username, status, json.dumps(scenario_config)) + + +async def scenario_set_all_status_to_finished(): + """ + Sets the status of all 'running' scenarios to 'finished' + and updates their 'end_time' (both in the direct column and within JSONB). + """ + current_time = datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S') # Consistent format + command = """ + UPDATE scenarios + SET + status = 'finished', + end_time = $1, + config = jsonb_set(config, '{status}', '"finished"') || + jsonb_set(config, '{end_time}', $2::jsonb) + WHERE status = 'running'; + """ + async with POOL.acquire() as conn: + await conn.execute(command, current_time, json.dumps(current_time)) + + +async def scenario_set_status_to_finished(scenario_name): + """ + Sets the status of a specific scenario to 'finished' and updates its 'end_time'. + Updates both the direct columns and the JSONB 'config'. 
+ """ + current_time = datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S') # Consistent format + command = """ + UPDATE scenarios + SET + status = 'finished', + end_time = $1, + config = jsonb_set( + jsonb_set(config, '{status}', '"finished"'), + '{end_time}', $2::jsonb + ) + WHERE name = $3; """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - if role == "admin": - if sort_by == "start_time": - command = """ - SELECT * FROM scenarios - ORDER BY strftime('%Y-%m-%d %H:%M:%S', substr(start_time, 7, 4) || '-' || substr(start_time, 4, 2) || '-' || substr(start_time, 1, 2) || ' ' || substr(start_time, 12, 8)); - """ - c.execute(command) - else: - command = "SELECT * FROM scenarios ORDER BY ?;" - c.execute(command, (sort_by,)) - else: - if sort_by == "start_time": - command = """ - SELECT * FROM scenarios - WHERE username = ? - ORDER BY strftime('%Y-%m-%d %H:%M:%S', substr(start_time, 7, 4) || '-' || substr(start_time, 4, 2) || '-' || substr(start_time, 1, 2) || ' ' || substr(start_time, 12, 8)); - """ - c.execute(command, (username,)) - else: - command = "SELECT * FROM scenarios WHERE username = ? ORDER BY ?;" - c.execute( - command, - ( - username, - sort_by, - ), - ) - result = c.fetchall() + async with POOL.acquire() as conn: + await conn.execute(command, current_time, json.dumps(current_time), scenario_name) - return result - -def get_all_scenarios_and_check_completed(username, role, sort_by="start_time"): - """ - Retrieve all scenarios with detailed fields and update the status of running scenarios if their federation is completed. - - Parameters: - username (str): The username of the requesting user. - role (str): The role of the user, e.g., "admin" or regular user. - sort_by (str, optional): The field name to sort the results by. Defaults to "start_time". - - Returns: - list[sqlite3.Row]: A list of scenario records including name, username, title, start_time, model, dataset, rounds, and status. 
- - Behavior: - - Admin users retrieve all scenarios. - - Non-admin users retrieve only scenarios associated with their username. - - Scenarios are sorted by start_time with special handling for null or empty values. - - For scenarios with status "running", checks if federation is completed: - - If completed, updates the scenario status to "completed". - - Refreshes the returned scenario list after updates. - """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - if role == "admin": - if sort_by == "start_time": - command = """ - SELECT name, username, title, start_time, model, dataset, rounds, status FROM scenarios - ORDER BY - CASE - WHEN start_time IS NULL OR start_time = '' THEN 1 - ELSE 0 - END, - strftime( - '%Y-%m-%d %H:%M:%S', - substr(start_time, 7, 4) || '-' || substr(start_time, 4, 2) || '-' || substr(start_time, 1, 2) || ' ' || substr(start_time, 12, 8) - ); - """ - c.execute(command) - else: - command = "SELECT name, username, title, start_time, model, dataset, rounds, status FROM scenarios ORDER BY ?;" - c.execute(command, (sort_by,)) - result = c.fetchall() - else: - if sort_by == "start_time": - command = """ - SELECT name, username, title, start_time, model, dataset, rounds, status FROM scenarios - WHERE username = ? - ORDER BY - CASE - WHEN start_time IS NULL OR start_time = '' THEN 1 - ELSE 0 - END, - strftime( - '%Y-%m-%d %H:%M:%S', - substr(start_time, 7, 4) || '-' || substr(start_time, 4, 2) || '-' || substr(start_time, 1, 2) || ' ' || substr(start_time, 12, 8) - ); - """ - c.execute(command, (username,)) - else: - command = "SELECT name, username, title, start_time, model, dataset, rounds, status FROM scenarios WHERE username = ? 
ORDER BY ?;" - c.execute( - command, - ( - username, - sort_by, - ), - ) - result = c.fetchall() - - for scenario in result: - if scenario["status"] == "running": - if check_scenario_federation_completed(scenario["name"]): - scenario_set_status_to_completed(scenario["name"]) - result = get_all_scenarios(username, role) - - return result - - -def scenario_update_record(name, start_time, end_time, scenario, status, role, username): - """ - Insert a new scenario record or update an existing one in the database based on the scenario name. - - Parameters: - name (str): The unique name identifier of the scenario. - start_time (str): The start time of the scenario. - end_time (str): The end time of the scenario. - scenario (object): An object containing detailed scenario attributes. - status (str): The current status of the scenario. - role (str): The role of the user performing the operation. - username (str): The username of the user performing the operation. - - Behavior: - - Checks if a scenario with the given name exists. - - If not, inserts a new record with all scenario details. - - If exists, updates the existing record with the provided data. - - Commits the transaction to persist changes. 
- """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - select_command = "SELECT * FROM scenarios WHERE name = ?;" - c.execute(select_command, (name,)) - result = c.fetchone() - - if result is None: - insert_command = """ - INSERT INTO scenarios ( - name, - start_time, - end_time, - title, - description, - deployment, - federation, - topology, - nodes, - nodes_graph, - n_nodes, - matrix, - random_topology_probability, - dataset, - iid, - partition_selection, - partition_parameter, - model, - agg_algorithm, - rounds, - logginglevel, - report_status_data_queue, - accelerator, - gpu_id, - network_subnet, - network_gateway, - epochs, - attack_params, - reputation, - random_geo, - latitude, - longitude, - mobility, - mobility_type, - radius_federation, - scheme_mobility, - round_frequency, - mobile_participants_percent, - additional_participants, - schema_additional_participants, - status, - role, - username - ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? 
- ); - """ - c.execute( - insert_command, - ( - name, - start_time, - end_time, - scenario.scenario_title, - scenario.scenario_description, - scenario.deployment, - scenario.federation, - scenario.topology, - json.dumps(scenario.nodes), - json.dumps(scenario.nodes_graph), - scenario.n_nodes, - json.dumps(scenario.matrix), - scenario.random_topology_probability, - scenario.dataset, - scenario.iid, - scenario.partition_selection, - scenario.partition_parameter, - scenario.model, - scenario.agg_algorithm, - scenario.rounds, - scenario.logginglevel, - scenario.report_status_data_queue, - scenario.accelerator, - json.dumps(scenario.gpu_id), - scenario.network_subnet, - scenario.network_gateway, - scenario.epochs, - json.dumps(scenario.attack_params), - json.dumps(scenario.reputation), - scenario.random_geo, - scenario.latitude, - scenario.longitude, - scenario.mobility, - scenario.mobility_type, - scenario.radius_federation, - scenario.scheme_mobility, - scenario.round_frequency, - scenario.mobile_participants_percent, - json.dumps(scenario.additional_participants), - scenario.schema_additional_participants, - status, - role, - username, - ), - ) - else: - update_command = """ - UPDATE scenarios SET - start_time = ?, - end_time = ?, - title = ?, - description = ?, - deployment = ?, - federation = ?, - topology = ?, - nodes = ?, - nodes_graph = ?, - n_nodes = ?, - matrix = ?, - random_topology_probability = ?, - dataset = ?, - iid = ?, - partition_selection = ?, - partition_parameter = ?, - model = ?, - agg_algorithm = ?, - rounds = ?, - logginglevel = ?, - report_status_data_queue = ?, - accelerator = ?, - gpu_id = ?, - network_subnet = ?, - network_gateway = ?, - epochs = ?, - attack_params = ?, - reputation = ?, - random_geo = ?, - latitude = ?, - longitude = ?, - mobility = ?, - mobility_type = ?, - radius_federation = ?, - scheme_mobility = ?, - round_frequency = ?, - mobile_participants_percent = ?, - additional_participants = ?, - schema_additional_participants = 
?, - status = ?, - role = ?, - username = ? - WHERE name = ?; - """ - c.execute( - update_command, - ( - start_time, - end_time, - scenario.scenario_title, - scenario.scenario_description, - scenario.deployment, - scenario.federation, - scenario.topology, - json.dumps(scenario.nodes), - json.dumps(scenario.nodes_graph), - scenario.n_nodes, - json.dumps(scenario.matrix), - scenario.random_topology_probability, - scenario.dataset, - scenario.iid, - scenario.partition_selection, - scenario.partition_parameter, - scenario.model, - scenario.agg_algorithm, - scenario.rounds, - scenario.logginglevel, - scenario.report_status_data_queue, - scenario.accelerator, - json.dumps(scenario.gpu_id), - scenario.network_subnet, - scenario.network_gateway, - scenario.epochs, - json.dumps(scenario.attack_params), - json.dumps(scenario.reputation), - scenario.random_geo, - scenario.latitude, - scenario.longitude, - scenario.mobility, - scenario.mobility_type, - scenario.radius_federation, - scenario.scheme_mobility, - scenario.round_frequency, - scenario.mobile_participants_percent, - json.dumps(scenario.additional_participants), - scenario.schema_additional_participants, - status, - role, - username, - name, - ), - ) - - conn.commit() - - -def scenario_set_all_status_to_finished(): - """ - Set the status of all currently running scenarios to "finished" and update their end time to the current datetime. - - Behavior: - - Finds all scenarios with status "running". - - Updates their status to "finished". - - Sets the end_time to the current timestamp. - - Commits the changes to the database. - """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - current_time = str(datetime.datetime.now()) - c.execute("UPDATE scenarios SET status = 'finished', end_time = ? 
WHERE status = 'running';", (current_time,)) - conn.commit() - - -def scenario_set_status_to_finished(scenario_name): - """ - Set the status of a specific scenario to "finished" and update its end time to the current datetime. - - Parameters: - scenario_name (str): The unique name identifier of the scenario to update. - - Behavior: - - Updates the scenario's status to "finished". - - Sets the end_time to the current timestamp. - - Commits the update to the database. +async def scenario_set_status_to_completed(scenario_name): """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - current_time = str(datetime.datetime.now()) - c.execute( - "UPDATE scenarios SET status = 'finished', end_time = ? WHERE name = ?;", (current_time, scenario_name) - ) - conn.commit() - - -def scenario_set_status_to_completed(scenario_name): + Sets the status of a specific scenario to 'completed'. + Updates both the direct column and the JSONB 'config'. """ - Set the status of a specific scenario to "completed". - - Parameters: - scenario_name (str): The unique name identifier of the scenario to update. - - Behavior: - - Updates the scenario's status to "completed". - - Commits the change to the database. + command = """ + UPDATE scenarios + SET + status = 'completed', + config = jsonb_set(config, '{status}', '"completed"') + WHERE name = $1; """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("UPDATE scenarios SET status = 'completed' WHERE name = ?;", (scenario_name,)) - conn.commit() + async with POOL.acquire() as conn: + await conn.execute(command, scenario_name) -def get_running_scenario(username=None, get_all=False): +async def get_running_scenario(username=None, get_all=False): """ - Retrieve running or completed scenarios from the database, optionally filtered by username. 
- - Parameters: - username (str, optional): The username to filter scenarios by. If None, no user filter is applied. - get_all (bool, optional): If True, returns all matching scenarios; otherwise returns only one. Defaults to False. - - Returns: - sqlite3.Row or list[sqlite3.Row]: A single scenario record or a list of scenario records matching the criteria. - - Behavior: - - Filters scenarios with status "running". - - Applies username filter if provided. - - Returns either one or all matching records depending on get_all. + Retrieves scenarios with a 'running' status, optionally filtered by user. + Returns full scenario record (including direct columns and config JSONB). """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() + async with POOL.acquire() as conn: + params = ["running"] + # Select all columns to get both direct and config data + command = "SELECT name, username, status, start_time, end_time, config FROM scenarios WHERE status = $1" if username: - command = """ - SELECT * FROM scenarios - WHERE (status = ?) AND username = ?; - """ - c.execute(command, ("running", username)) + command += " AND username = $2" + params.append(username) - result = c.fetchone() + if get_all: + result = [dict(row) for row in await conn.fetch(command, *params)] # Convert records to dicts else: - command = "SELECT * FROM scenarios WHERE status = ?;" - c.execute(command, ("running",)) - if get_all: - result = c.fetchall() - else: - result = c.fetchone() - + result_row = await conn.fetchrow(command, *params) + result = dict(result_row) if result_row else None return result -def get_completed_scenario(): +async def get_completed_scenario(): """ - Retrieve a single scenario with status "completed" from the database. - - Returns: - sqlite3.Row: A scenario record with status "completed", or None if no such scenario exists. - - Behavior: - - Fetches the first scenario found with status "completed". 
+ Retrieves a single scenario with a 'completed' status. + Returns full scenario record (including direct columns and config JSONB). """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - command = "SELECT * FROM scenarios WHERE status = ?;" - c.execute(command, ("completed",)) - result = c.fetchone() - - return result + async with POOL.acquire() as conn: + command = "SELECT name, username, status, start_time, end_time, config FROM scenarios WHERE status = $1;" + result_row = await conn.fetchrow(command, "completed") + return dict(result_row) if result_row else None -def get_scenario_by_name(scenario_name): +async def get_scenario_by_name(scenario_name): """ - Retrieve a scenario record by its unique name. + Retrieves the complete record of a scenario by its name. + """ + async with POOL.acquire() as conn: + result_row = await conn.fetchrow("SELECT name, start_time, end_time, username, status, config FROM scenarios WHERE name = $1;", scenario_name) - Parameters: - scenario_name (str): The unique name identifier of the scenario. + result = dict(result_row) if result_row else None - Returns: - sqlite3.Row: The scenario record matching the given name, or None if not found. 
- """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT * FROM scenarios WHERE name = ?;", (scenario_name,)) - result = c.fetchone() + if result and result.get('config'): + # Assuming 'config' is a JSON string from the DB, so we parse it + # It might already be a dict if asyncpg handles JSONB conversion automatically + config_data = result['config'] + if isinstance(config_data, str): + try: + config_data = json.loads(config_data) + except json.JSONDecodeError: + config_data = {} + + # Extract the 'scenario_title' and add it as a top-level key + result['title'] = config_data.get('scenario_title') + result['description'] = config_data.get('description') return result -def get_user_by_scenario_name(scenario_name): +async def get_user_by_scenario_name(scenario_name): """ - Retrieve the username associated with a given scenario name. - - Parameters: - scenario_name (str): The unique name identifier of the scenario. - - Returns: - str: The username linked to the specified scenario, or None if not found. + Retrieves the username associated with a scenario (from the direct 'username' column). """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT username FROM scenarios WHERE name = ?;", (scenario_name,)) - result = c.fetchone() - - return result["username"] + async with POOL.acquire() as conn: + return await conn.fetchval("SELECT username FROM scenarios WHERE name = $1;", scenario_name) -def remove_scenario_by_name(scenario_name): +async def remove_scenario_by_name(scenario_name): """ Delete a scenario from the database by its unique name. - - Parameters: - scenario_name (str): The unique name identifier of the scenario to be removed. - - Behavior: - - Removes the scenario record matching the given name. - - Commits the deletion to the database. 
""" - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("DELETE FROM scenarios WHERE name = ?;", (scenario_name,)) - conn.commit() + try: + async with POOL.acquire() as conn: + await conn.execute("DELETE FROM scenarios WHERE name = $1;", scenario_name) + logging.info(f"Scenario '{scenario_name}' successfully removed.") + except asyncpg.PostgresError as e: + logging.error(f"Error occurred while deleting scenario '{scenario_name}': {e}") -def check_scenario_federation_completed(scenario_name): +async def check_scenario_federation_completed(scenario_name): """ Check if all nodes in a given scenario have completed the required federation rounds. + """ + try: + async with POOL.acquire() as conn: + # Retrieve the total rounds for the scenario from the 'config' JSONB column + scenario_rounds_str = await conn.fetchval("SELECT config->>'rounds' AS rounds FROM scenarios WHERE name = $1;", scenario_name) - Parameters: - scenario_name (str): The unique name identifier of the scenario to check. + if not scenario_rounds_str: + logging.warning(f"Scenario '{scenario_name}' not found or 'rounds' not defined.") + return False - Returns: - bool: True if all nodes have completed the total rounds specified for the scenario, False otherwise or if an error occurs. + # Ensure total_rounds is an integer for comparison + try: + total_rounds = int(scenario_rounds_str) + except (ValueError, TypeError): + logging.error(f"Invalid 'rounds' value for scenario '{scenario_name}': {scenario_rounds_str}") + return False - Behavior: - - Retrieves the total number of rounds defined for the scenario. - - Fetches the current round progress of all nodes in that scenario. - - Returns True only if every node has reached the total rounds. - - Handles database errors and missing scenario cases gracefully. 
- """ - try: - # Connect to the scenario database to get the total rounds for the scenario - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT rounds FROM scenarios WHERE name = ?;", (scenario_name,)) - scenario = c.fetchone() - - if not scenario: - raise ValueError(f"Scenario '{scenario_name}' not found.") - - total_rounds = scenario["rounds"] - - # Connect to the node database to check the rounds for each node - with sqlite3.connect(node_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT round FROM nodes WHERE scenario = ?;", (scenario_name,)) - nodes = c.fetchall() - - if len(nodes) == 0: + # Fetch the current round progress of all nodes in that scenario + nodes = await conn.fetch("SELECT round FROM nodes WHERE scenario = $1;", scenario_name) + + if not nodes: + logging.info(f"No nodes found for scenario '{scenario_name}'. Federation not considered completed.") return False # Check if all nodes have completed the total rounds - total_rounds_str = str(total_rounds) - return all(str(node["round"]) == total_rounds_str for node in nodes) + return all(int(node["round"]) >= total_rounds for node in nodes) - except sqlite3.Error as e: - print(f"Database error: {e}") + except asyncpg.PostgresError as e: + logging.error(f"PostgreSQL error during check_scenario_federation_completed for '{scenario_name}': {e}") return False - except Exception as e: - print(f"An error occurred: {e}") + except ValueError as e: + logging.error(f"Data error during check_scenario_federation_completed for '{scenario_name}': {e}") return False -def check_scenario_with_role(role, scenario_name): +async def check_scenario_with_role(role, scenario_name, current_username=None): + """ + Verify if a scenario exists that the user with the given role and username can access. """ - Verify if a scenario exists with a specific role and name. 
+ scenario_info = await get_scenario_by_name(scenario_name) - Parameters: - role (str): The role associated with the scenario (e.g., "admin", "user"). - scenario_name (str): The unique name identifier of the scenario. + if not scenario_info: + return False # Scenario does not exist - Returns: - bool: True if a scenario with the given role and name exists, False otherwise. - """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute( - "SELECT * FROM scenarios WHERE role = ? AND name = ?;", - ( - role, - scenario_name, - ), + if role == "admin": + return True # Admins can access any existing scenario + + if current_username is None: + logging.info(f"[FER] db username {scenario_info.get('username')} current_username {current_username}") + logging.warning( + "check_scenario_with_role called for non-admin role without current_username." ) - result = c.fetchone() + return False - return result is not None + return scenario_info.get("username") == current_username +# --- Notes Management Functions --- -def save_notes(scenario, notes): +async def save_notes(scenario, notes): """ Save or update notes associated with a specific scenario. - - Parameters: - scenario (str): The unique identifier of the scenario. - notes (str): The textual notes to be saved for the scenario. - - Behavior: - - Inserts new notes if the scenario does not exist in the database. - - Updates existing notes if the scenario already has notes saved. - - Handles SQLite integrity and general database errors gracefully. """ try: - with sqlite3.connect(notes_db_file_location) as conn: - c = conn.cursor() - c.execute( + async with POOL.acquire() as conn: + await conn.execute( """ - INSERT INTO notes (scenario, scenario_notes) VALUES (?, ?) 
- ON CONFLICT(scenario) DO UPDATE SET scenario_notes = excluded.scenario_notes; + INSERT INTO notes (scenario, scenario_notes) VALUES ($1, $2) + ON CONFLICT(scenario) DO UPDATE SET scenario_notes = EXCLUDED.scenario_notes; """, - (scenario, notes), + scenario, notes, ) - conn.commit() - except sqlite3.IntegrityError as e: - print(f"SQLite integrity error: {e}") - except sqlite3.Error as e: - print(f"SQLite error: {e}") + except asyncpg.PostgresError as e: + logging.error(f"PostgreSQL error during save_notes: {e}") -def get_notes(scenario): +async def get_notes(scenario): """ Retrieve notes associated with a specific scenario. - - Parameters: - scenario (str): The unique identifier of the scenario. - - Returns: - sqlite3.Row or None: The notes record for the given scenario, or None if no notes exist. """ - with sqlite3.connect(notes_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT * FROM notes WHERE scenario = ?;", (scenario,)) - result = c.fetchone() + async with POOL.acquire() as conn: + return await conn.fetchrow("SELECT * FROM notes WHERE scenario = $1;", scenario) - return result - -def remove_note(scenario): +async def remove_note(scenario): """ Delete the note associated with a specific scenario. - - Parameters: - scenario (str): The unique identifier of the scenario whose note should be removed. - """ - with sqlite3.connect(notes_db_file_location) as conn: - c = conn.cursor() - c.execute("DELETE FROM notes WHERE scenario = ?;", (scenario,)) - conn.commit() - - -if __name__ == "__main__": - """ - Entry point for the script to print the list of users. - - When executed directly, this block calls the `list_users()` function - and prints its returned list of users. 
""" - print(list_users()) + async with POOL.acquire() as conn: + await conn.execute("DELETE FROM notes WHERE scenario = $1;", scenario) diff --git a/nebula/controller/scenarios.py b/nebula/controller/scenarios.py index 7276f3893..0d8f2286e 100644 --- a/nebula/controller/scenarios.py +++ b/nebula/controller/scenarios.py @@ -549,6 +549,26 @@ def from_dict(cls, data): return scenario + @staticmethod + def to_json(scenario_obj): + """ + Converts a Scenario object to a JSON string. + + Args: + scenario_obj (Scenario): An instance of the Scenario class. + + Returns: + str: A JSON string representation of the Scenario object. + """ + if not isinstance(scenario_obj, Scenario): + raise TypeError("Input must be an instance of the Scenario class.") + + # Get all attributes of the Scenario object + scenario_dict = scenario_obj.__dict__ + + # Convert the dictionary to a JSON string + return json.dumps(scenario_dict, indent=2) # Using indent for pretty-printing + # Class to manage the current scenario class ScenarioManagement: diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 151ae6b22..fc2fef4f6 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -94,7 +94,7 @@ def __init__( self.ip = config.participant["network_args"]["ip"] self.port = config.participant["network_args"]["port"] self.addr = config.participant["network_args"]["addr"] - + self.name = config.participant["device_args"]["name"] self.client = docker.from_env() @@ -187,7 +187,7 @@ def aggregator(self): def trainer(self): """Trainer""" return self._trainer - + @property def rb(self): """Role Behavior""" @@ -317,7 +317,7 @@ async def _control_alive_callback(self, source, message): async def _control_leadership_transfer_callback(self, source, message): logging.info(f"🔧 handle_control_message | Trigger | Received leadership transfer message from {source}") - + if await self._round_in_process_lock.locked_async(): logging.info("Learning cycle is executing, role behavior will be modified next 
round") await self.rb.set_next_role(Role.AGGREGATOR, source_to_notificate=source) @@ -354,7 +354,7 @@ async def _control_leadership_transfer_ack_callback(self, source, message): except TimeoutError: logging.info("Learning cycle is locked, role behavior will be modified next round") await self.rb.set_next_role(Role.TRAINER) - + async def _connection_connect_callback(self, source, message): logging.info(f"🔗 handle_connection_message | Trigger | Received connection message from {source}") @@ -600,7 +600,7 @@ async def start_communications(self): before other services or training processes begin. """ await self.register_events_callbacks() - initial_neighbors = self.config.participant["network_args"]["neighbors"].split() + initial_neighbors = self.config.participant["network_args"]["neighbors"] await self.cm.start_communications(initial_neighbors) await asyncio.sleep(self.config.participant["misc_args"]["grace_time_connection"] // 2) @@ -710,10 +710,10 @@ async def _start_learning(self): await self.get_federation_ready_lock().acquire_async() if self.config.participant["device_args"]["start"]: logging.info("Propagate initial model updates.") - + mpe = ModelPropagationEvent(await self.cm.get_addrs_current_connections(only_direct=True, myself=False), "initialization") await EventManager.get_instance().publish_node_event(mpe) - + await self.get_federation_ready_lock().release_async() self.trainer.set_epochs(epochs) @@ -764,7 +764,7 @@ async def learning_cycle_finished(self): return False else: return current_round >= self.total_rounds - + async def resolve_missing_updates(self): """ Delegates the resolution strategy for missing updates to the current role behavior. 
@@ -778,7 +778,7 @@ async def resolve_missing_updates(self): """ logging.info(f"Using Role behavior: {self.rb.get_role_name()} conflict resolve strategy") return await self.rb.resolve_missing_updates() - + async def update_self_role(self): """ Checks whether a role update is required and performs the transition if necessary. @@ -806,7 +806,7 @@ async def update_self_role(self): logging.info(f"Sending role modification ACK to transferer: {source_to_notificate}") message = self.cm.create_message("control", "leadership_transfer_ack") asyncio.create_task(self.cm.send_message(source_to_notificate, message)) - + async def _learning_cycle(self): """ Main asynchronous loop for executing the Federated Learning process across multiple rounds. @@ -837,9 +837,9 @@ async def _learning_cycle(self): indent=2, title="Round information", ) - + await self.update_self_role() - + logging.info(f"Federation nodes: {self.federation_nodes}") await self.update_federation_nodes( await self.cm.get_addrs_current_connections(only_direct=True, myself=True) @@ -851,10 +851,10 @@ async def _learning_cycle(self): logging.info(f"Expected nodes: {expected_nodes}") direct_connections = await self.cm.get_addrs_current_connections(only_direct=True) undirected_connections = await self.cm.get_addrs_current_connections(only_undirected=True) - + logging.info(f"Direct connections: {direct_connections} | Undirected connections: {undirected_connections}") logging.info(f"[Role {self.rb.get_role_name()}] Starting learning cycle...") - + await self.aggregator.update_federation_nodes(expected_nodes) async with self._role_behavior_performance_lock: await self.rb.extended_learning_cycle() @@ -882,13 +882,13 @@ async def _learning_cycle(self): self.trainer.on_learning_cycle_end() await self.trainer.test() - + # Shutdown protocol await self._shutdown_protocol() - + async def _shutdown_protocol(self): logging.info("Starting graceful shutdown process...") - + # 1.- Publish Experiment Finish Event to the last update on 
modules logging.info("Publishing Experiment Finish Event...") efe = ExperimentFinishEvent() diff --git a/nebula/database/Dockerfile b/nebula/database/Dockerfile new file mode 100644 index 000000000..04ed953e9 --- /dev/null +++ b/nebula/database/Dockerfile @@ -0,0 +1,52 @@ +FROM postgres:17.5-alpine3.22 + +# Rename the official entrypoint so we can wrap it +RUN mv /usr/local/bin/docker-entrypoint.sh /usr/local/bin/docker-entrypoint.sh.orig + +# Copy SQL init file and custom entrypoint script +COPY /nebula/database/docker-entrypoint.sh /usr/local/bin/docker-entrypoint.sh +RUN chmod +x /usr/local/bin/docker-entrypoint.sh + +# Install Python 3.11.7 from source +RUN apk add --no-cache \ + gcc \ + g++ \ + musl-dev \ + make \ + openssl-dev \ + bzip2-dev \ + zlib-dev \ + xz-dev \ + readline-dev \ + sqlite-dev \ + libffi-dev \ + curl \ + tar \ + bash + +ENV PYTHON_VERSION=3.11.7 + +RUN curl -O https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.tgz && \ + tar -xzf Python-${PYTHON_VERSION}.tgz && \ + cd Python-${PYTHON_VERSION} && \ + ./configure --prefix=/usr/local --enable-optimizations && \ + make -j$(nproc) && \ + make install && \ + cd .. && rm -rf Python-${PYTHON_VERSION}* + +RUN python3.11 --version + +# Install uv (alternative to pip, very fast) +ADD https://astral.sh/uv/install.sh /uv-installer.sh +RUN sh /uv-installer.sh && rm /uv-installer.sh +ENV PATH="/root/.local/bin/:$PATH" + +# Install Python dependencies using uv +COPY pyproject.toml . 
+RUN uv python pin 3.11.7 +RUN uv sync --group database + +ENV PATH=".venv/bin:$PATH" + +# ENTRYPOINT ["/usr/local/bin/docker-entrypoint.sh"] +# CMD ["postgres"] diff --git a/nebula/database/docker-entrypoint.sh b/nebula/database/docker-entrypoint.sh new file mode 100644 index 000000000..a298ca23b --- /dev/null +++ b/nebula/database/docker-entrypoint.sh @@ -0,0 +1,18 @@ +#!/bin/sh +set -e + +# 1) Run the original entrypoint and wait for it to finish initialization +/usr/local/bin/docker-entrypoint.sh.orig "$@" + +# 2) Wait until PostgreSQL accepts connections to the configured database +echo "⏳ Waiting for PostgreSQL to be ready..." +until pg_isready -U "$POSTGRES_USER" -d "$POSTGRES_DB" >/dev/null 2>&1; do + sleep 1 +done + +# 3) Apply the SQL initialization script +echo "🚀 Applying init-configs.sql..." +psql -v ON_ERROR_STOP=1 \ + -U "$POSTGRES_USER" \ + -d "$POSTGRES_DB" \ + -f /docker-entrypoint-initdb.d/init-configs.sql diff --git a/nebula/database/init-configs.sql b/nebula/database/init-configs.sql new file mode 100644 index 000000000..a34b17841 --- /dev/null +++ b/nebula/database/init-configs.sql @@ -0,0 +1,63 @@ +-- -------------------------------------------------- +-- init_postgres.sql +-- -------------------------------------------------- + +-- 1) (Optional) If you need to create the database, uncomment: +-- CREATE DATABASE nebula; +-- \c nebula + +-- 2) Users table +CREATE TABLE IF NOT EXISTS users ( + "user" TEXT PRIMARY KEY, + password TEXT, + role TEXT +); + +-- 2) Nodes como JSONB +CREATE TABLE IF NOT EXISTS nodes ( + uid TEXT PRIMARY KEY, + idx TEXT, + ip TEXT, + port TEXT, + role TEXT, + neighbors TEXT[], + latitude TEXT, + longitude TEXT, + timestamp TEXT, + federation TEXT, + round TEXT, + scenario TEXT, + hash TEXT, + malicious TEXT +); + +-- 3) Configs como JSONB +DROP INDEX IF EXISTS idx_configs_config_gin; +DROP TABLE IF EXISTS configs; +CREATE TABLE configs ( + id SERIAL PRIMARY KEY, + config JSONB NOT NULL, + created_at TIMESTAMPTZ DEFAULT 
NOW() +); +CREATE INDEX idx_configs_config_gin ON configs USING GIN (config); + +-- 4) Scenarios table as JSONB +CREATE TABLE IF NOT EXISTS scenarios ( + name TEXT PRIMARY KEY, + username TEXT NOT NULL, + status TEXT, + start_time TEXT, + end_time TEXT, + config JSONB NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW() +); + +-- Index for fast JSONB queries on scenarios.config +CREATE INDEX IF NOT EXISTS idx_scenarios_config_gin + ON scenarios USING GIN (config); + +-- 5) Notes table +CREATE TABLE IF NOT EXISTS notes ( + scenario TEXT PRIMARY KEY, + scenario_notes TEXT +); diff --git a/nebula/database/pgweb/Dockerfile b/nebula/database/pgweb/Dockerfile new file mode 100644 index 000000000..c9bb132e0 --- /dev/null +++ b/nebula/database/pgweb/Dockerfile @@ -0,0 +1 @@ +FROM sosedoff/pgweb diff --git a/nebula/frontend/app.py b/nebula/frontend/app.py index 830a87c85..42f5b3a68 100755 --- a/nebula/frontend/app.py +++ b/nebula/frontend/app.py @@ -121,15 +121,6 @@ class Settings: logging.info(f"NEBULA_PRODUCTION: {settings.env_tag == 'prod'}") logging.info(f"NEBULA_DEPLOYMENT_PREFIX: {settings.prefix_tag}") -if "SECRET_KEY" not in os.environ: - logging.info("Generating SECRET_KEY") - os.environ["SECRET_KEY"] = os.urandom(24).hex() - logging.info(f"Saving SECRET_KEY to {settings.env_file}") - with open(settings.env_file, "a") as f: - f.write(f"SECRET_KEY={os.environ['SECRET_KEY']}\n") -else: - logging.info("SECRET_KEY already set") - app = FastAPI() app.add_middleware( SessionMiddleware, @@ -696,7 +687,7 @@ async def remove_scenario_by_name(scenario_name): await controller_post(url, data) -async def check_scenario_with_role(role, scenario_name): +async def check_scenario_with_role(role, scenario_name, user): """ Check if a specific scenario is allowed for the session's role. @@ -710,7 +701,7 @@ async def check_scenario_with_role(role, scenario_name): Raises: HTTPException: If the underlying HTTP GET request fails. 
""" - url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/check/{role}/{scenario_name}" + url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/check/{user}/{role}/{scenario_name}" check_data = await controller_get(url) return check_data.get("allowed", False) @@ -1611,7 +1602,7 @@ async def nebula_dashboard_monitor(scenario_name: str, request: Request, session "ip": node["ip"], "port": node["port"], "role": node["role"], - "neighbors": node["neighbors"], + "neighbors": " ".join(node["neighbors"]), "latitude": node["latitude"], "longitude": node["longitude"], "timestamp": node["timestamp"], @@ -1715,8 +1706,7 @@ async def nebula_update_node(scenario_name: str, request: Request): "ip": config["network_args"]["ip"], "port": str(config["network_args"]["port"]), "role": config["device_args"]["role"], - "malicious": config["device_args"]["malicious"], - "neighbors": config["network_args"]["neighbors"], + "neighbors": " ".join(config["network_args"]["neighbors"]), "latitude": config["mobility_args"]["latitude"], "longitude": config["mobility_args"]["longitude"], "timestamp": config["timestamp"], @@ -1893,7 +1883,7 @@ async def nebula_relaunch_scenario( if session["role"] == "demo": raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) elif session["role"] == "user": - if not await check_scenario_with_role(session["role"], scenario_name): + if not await check_scenario_with_role(session["role"], scenario_name, session["user"]): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) scenario_path = FileUtils.check_path(settings.config_dir, os.path.join(scenario_name, "scenario.json")) @@ -1934,7 +1924,7 @@ async def nebula_remove_scenario(scenario_name: str, session: dict = Depends(get if session["role"] == "demo": raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) elif session["role"] == "user": - if not await check_scenario_with_role(session["role"], scenario_name): + if not await 
check_scenario_with_role(session["role"], scenario_name, session["user"]): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) await remove_scenario(scenario_name, session["user"]) return RedirectResponse(url="/platform/dashboard") @@ -2163,7 +2153,8 @@ async def assign_available_gpu(scenario_data, role): running_gpus = [] # Obtain associated gpus of the running scenarios for scenario in running_scenarios: - scenario_gpus = json.loads(scenario["gpu_id"]) + config = json.loads(scenario["config"]) + scenario_gpus = config.get("gpu_id", []) # Obtain the list of gpus in use without duplicates for gpu in scenario_gpus: if gpu not in running_gpus: diff --git a/nebula/physical/api.py b/nebula/physical/api.py index 6d44428e7..3e7da1bec 100644 --- a/nebula/physical/api.py +++ b/nebula/physical/api.py @@ -350,7 +350,7 @@ def run(): if TRAINING_PROC and TRAINING_PROC.poll() is None: _json_abort(409, "Training already running") - cmd = ["python", "/home/dietpi/prueba/nebula/nebula/node.py", json_files[0]] + cmd = ["python", "/home/dietpi/test/nebula/nebula/node.py", json_files[0]] TRAINING_PROC = subprocess.Popen(cmd) return jsonify(pid=TRAINING_PROC.pid, state="running") @@ -379,8 +379,8 @@ def setup_new_run(): Expected multipart-form fields ------------------------------- - * **config** – JSON with scenario, network and security arguments - * **global_test** – shared evaluation dataset (`*.h5`) + * **config** – JSON with scenario, network and security arguments + * **global_test** – shared evaluation dataset (`*.h5`) * **train_set** – participant-specific training dataset (`*.h5`) The function rewrites paths inside *config*, validates neighbour IPs @@ -489,4 +489,4 @@ def setup_new_run(): # ----------------------------------------------------------------------------- if __name__ == "__main__": # Local testing: python main.py - app.run(host="0.0.0.0", port=8000, debug=False) \ No newline at end of file + app.run(host="0.0.0.0", port=8000, debug=False) diff --git 
a/pyproject.toml b/pyproject.toml index d54ff3dcd..7c7e666b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,9 @@ docs = [ "mkdocstrings[python]<1.0.0,>=0.26.2", ] controller = [ + "psycopg2-binary==2.9.10", + "asyncpg==0.30.0", + "passlib==1.7.4", "aiohttp==3.10.5", "aiosqlite==0.20.0", "argon2-cffi==23.1.0", @@ -80,6 +83,10 @@ controller = [ "scikit-image==0.24.0", "scikit-learn==1.5.1", ] +database = [ + "asyncpg==0.30.0", + "psycopg2-binary==2.9.10" +] core = [ "aiohttp==3.10.5", "async-timeout==4.0.3",