From e039f2a4bd257264984470d4bcba1e913dc4b027 Mon Sep 17 00:00:00 2001 From: FerTV Date: Wed, 2 Jul 2025 13:09:36 +0200 Subject: [PATCH 01/20] postgres docker created --- Makefile | 6 + app/deployer.py | 113 +- nebula/controller/controller.py | 20 +- nebula/controller/database.py | 1451 +++++++------------------- nebula/controller/scenarios.py | 21 + nebula/database/Dockerfile | 55 + nebula/database/docker-entrypoint.sh | 23 + nebula/database/init-configs.sql | 69 ++ nebula/frontend/app.py | 2 + pyproject.toml | 7 + 10 files changed, 694 insertions(+), 1073 deletions(-) create mode 100644 nebula/database/Dockerfile create mode 100644 nebula/database/docker-entrypoint.sh create mode 100644 nebula/database/init-configs.sql diff --git a/Makefile b/Makefile index 52d0d3569..01eb836b0 100644 --- a/Makefile +++ b/Makefile @@ -74,6 +74,12 @@ update-dockers: ## Update docker images else \ echo "Skipping nebula-controller docker build."; \ fi + @echo "🐳 Building nebula-database docker image. Do you want to continue (overrides existing image)? (y/n)" + @read ans; if [ "$${ans:-N}" = y ]; then \ + docker build -t nebula-database -f nebula/database/Dockerfile .; \ + else \ + echo "Skipping nebula-database docker build."; \ + fi @echo "" @echo "🐳 Building nebula-frontend docker image. Do you want to continue (overrides existing image)? 
(y/n)" @read ans; if [ "$${ans:-N}" = y ]; then \ diff --git a/app/deployer.py b/app/deployer.py index a0a5d83d7..c3696f413 100644 --- a/app/deployer.py +++ b/app/deployer.py @@ -761,7 +761,8 @@ def start(self): self.run_controller() logging.info("NEBULA Controller is running") - logging.info(f"NEBULA Databases created in {self.databases_dir}") + self.run_database() + logging.info(f"NEBULA Databases docker is running") self.run_frontend() logging.info(f"NEBULA Frontend is running at http://localhost:{self.frontend_port}") if self.production and self.prefix == "production": @@ -912,6 +913,112 @@ def run_frontend(self): # Add to metadata Deployer._add_container_to_metadata(frontend_container_name) + def run_database(self): + """ + Runs the Nebula database within a Docker container, ensuring the required Docker environment is available. + """ + if sys.platform == "win32": + if not os.path.exists("//./pipe/docker_Engine"): + raise Exception( + "Docker is not running, please check if Docker is running and Docker Compose is installed." + ) + else: + if not os.path.exists("/var/run/docker.sock"): + raise Exception( + "/var/run/docker.sock not found, please check if Docker is running and Docker Compose is installed." 
+ ) + + network_name = f"{os.environ['USER']}_nebula-net-base" + + ############### + # POSTGRES DB # + ############### + + host_port = 54312 + + # Create the Docker network + base = DockerUtils.create_docker_network(network_name) + + client = docker.from_env() + + environment = { + "POSTGRES_USER": "nebula", + "POSTGRES_PASSWORD": "nebula", + "POSTGRES_DB": "nebula", + "NEBULA_DATABASES_PORT": self.controller_host, + } + + host_sql_path = os.path.join(self.root_path, "nebula/database/init-configs.sql") + container_sql_path = "/docker-entrypoint-initdb.d/init-configs.sql" + + host_config = client.api.create_host_config( + binds=[ + f"{host_sql_path}:{container_sql_path}", + ], + extra_hosts={"host.docker.internal": "host-gateway"}, + port_bindings={5432: host_port}, + ) + + networking_config = client.api.create_networking_config({ + f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.125") + }) + + container_id = client.api.create_container( + image="nebula-database", + name=f"{os.environ['USER']}_nebula-database", + detach=True, + environment=environment, + host_config=host_config, + networking_config=networking_config, + ) + + client.api.start(container_id) + + ################ + # POSTGRES WEB # + ################ + + pgweb_host_port = 8085 + pgweb_container_port = 8081 + + pgweb_host_config = client.api.create_host_config( + port_bindings={pgweb_container_port: pgweb_host_port}, + device_requests=[{ + "Driver": "nvidia", + "Count": -1, + "Capabilities": [["gpu"]], + }] if self.gpu_available else None, + ) + + pgweb_networking_config = client.api.create_networking_config({ + f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.135") + }) + + pgweb_container_name = f"{os.environ.get('USER')}_nebula-pgweb" + + pgweb_container_id = client.api.create_container( + image="nebula-pgweb", + name=pgweb_container_name, + detach=True, + host_config=pgweb_host_config, + networking_config=pgweb_networking_config, + ) + + 
client.api.start(pgweb_container_id) + + def stop_database(self): + """ + Stops and removes all NEBULA database Docker containers associated with the current user. + + Responsibilities: + - Detects running Docker containers with names starting with '_nebula-database'. + - Gracefully stops and removes these database containers. + + Typical use cases: + - Cleaning up database containers during shutdown or redeployment processes. + """ + DockerUtils.remove_containers_by_prefix(f"{os.environ['USER']}_nebula-database") + def run_controller(self): if sys.platform == "win32": if not os.path.exists("//./pipe/docker_Engine"): @@ -953,6 +1060,10 @@ def run_controller(self): "NEBULA_CONTROLLER_PORT": self.controller_port, "NEBULA_CONTROLLER_HOST": self.controller_host, "NEBULA_FRONTEND_PORT": self.frontend_port, + "DB_HOST": f"{os.environ['USER']}_nebula-database", + "DB_PORT": 5432, + "DB_USER": "nebula", + "DB_PASSWORD": "nebula", } volumes = ["/nebula", "/var/run/docker.sock"] diff --git a/nebula/controller/controller.py b/nebula/controller/controller.py index acf03c19f..4c7605036 100755 --- a/nebula/controller/controller.py +++ b/nebula/controller/controller.py @@ -7,6 +7,7 @@ import logging import os import re +import copy from typing import Annotated import aiohttp @@ -15,7 +16,7 @@ from fastapi import Body, FastAPI, Request, status, HTTPException, Path, File, UploadFile from fastapi.concurrency import asynccontextmanager -from nebula.controller.database import scenario_set_all_status_to_finished, scenario_set_status_to_finished +# from nebula.controller.database import scenario_set_all_status_to_finished, scenario_set_status_to_finished from nebula.controller.http_helpers import remote_get, remote_post_form from nebula.utils import DockerUtils @@ -105,12 +106,12 @@ def configure_logger(controller_log): @asynccontextmanager async def lifespan(app: FastAPI): - databases_dir: str = os.environ.get("NEBULA_DATABASES_DIR") + # databases_dir: str = 
os.environ.get("NEBULA_DATABASES_DIR") controller_log: str = os.environ.get("NEBULA_CONTROLLER_LOG") - from nebula.controller.database import initialize_databases + # from nebula.controller.database import initialize_databases - await initialize_databases(databases_dir) + # await initialize_databases(databases_dir) configure_logger(controller_log) @@ -309,13 +310,18 @@ async def run_scenario( validate_physical_fields(scenario_data) + db_scenario = copy.deepcopy(scenario_data) + # Manager for the actual scenario scenarioManagement = ScenarioManagement(scenario_data, user) + logging.info(f"[FER] scenario run_scenario {scenario_data}") + await update_scenario( scenario_name=scenarioManagement.scenario_name, start_time=scenarioManagement.start_date_scenario, end_time="", + scenario=db_scenario, scenario=scenario_data, status="running", role=role, @@ -460,11 +466,11 @@ async def update_scenario( dict: A message confirming the update. """ from nebula.controller.database import scenario_update_record - from nebula.controller.scenarios import Scenario + # from nebula.controller.scenarios import Scenario try: - scenario = Scenario.from_dict(scenario) - scenario_update_record(scenario_name, start_time, end_time, scenario, status, role, username) + logging.info(f"[FER] scenario controller.py {scenario}") + scenario_update_record(scenario_name, start_time, end_time, scenario, status, username) except Exception as e: logging.exception(f"Error updating scenario {scenario_name}: {e}") raise HTTPException(status_code=500, detail="Internal server error") diff --git a/nebula/controller/database.py b/nebula/controller/database.py index 7a012fd8a..017f84d3d 100755 --- a/nebula/controller/database.py +++ b/nebula/controller/database.py @@ -1,299 +1,51 @@ -import asyncio -import datetime -import json import logging import os -import sqlite3 - -import aiosqlite -from argon2 import PasswordHasher - -user_db_file_location = None -node_db_file_location = None -scenario_db_file_location = 
None -notes_db_file_location = None - -_node_lock = asyncio.Lock() - -PRAGMA_SETTINGS = [ - "PRAGMA journal_mode=WAL;", - "PRAGMA synchronous=NORMAL;", - "PRAGMA journal_size_limit=1048576;", - "PRAGMA cache_size=10000;", - "PRAGMA temp_store=MEMORY;", - "PRAGMA cache_spill=0;", -] - - -async def setup_database(db_file_location): - """ - Initializes the SQLite database with the required PRAGMA settings. - - This function: - - Connects asynchronously to the specified SQLite database file. - - Applies a predefined list of PRAGMA settings to configure the database. - - Commits the changes after applying the settings. - - Args: - db_file_location (str): Path to the SQLite database file. - - Exceptions: - PermissionError: Logged if the application lacks permission to create or modify the database file. - Exception: Logs any other unexpected error that occurs during setup. - """ - try: - async with aiosqlite.connect(db_file_location) as db: - for pragma in PRAGMA_SETTINGS: - await db.execute(pragma) - await db.commit() - except PermissionError: - logging.info("No permission to create the databases. Change the default databases directory") - except Exception as e: - logging.exception(f"An error has ocurred during setup_database: {e}") +import psycopg2 +import psycopg2.extras +from passlib.context import CryptContext +import datetime +import json +import asyncpg +import asyncio +from nebula.controller.scenarios import Scenario -async def ensure_columns(conn, table_name, desired_columns): - """ - Ensures that a table contains all the desired columns, adding any that are missing. +# --- Configuration --- +# Use environment variables for database credentials from the Docker Compose file +DATABASE_URL = f"postgresql://{os.environ.get('DB_USER')}:{os.environ.get('DB_PASSWORD')}@{os.environ.get('DB_HOST')}:{os.environ.get('DB_PORT')}/nebula" - This function: - - Retrieves the current columns of the specified table. - - Compares them with the desired columns. 
- - Adds any missing columns to the table using ALTER TABLE statements. +# Password hashing context (using Argon2) +pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") - Args: - conn (aiosqlite.Connection): Active connection to the SQLite database. - table_name (str): Name of the table to check and modify. - desired_columns (dict): Dictionary mapping column names to their SQL definitions. +# Asynchronous lock for node updates +_node_lock = asyncio.Lock() - Note: - This operation is committed immediately after all changes are applied. - """ - _c = await conn.execute(f"PRAGMA table_info({table_name});") - existing_columns = [row[1] for row in await _c.fetchall()] - for column_name, column_definition in desired_columns.items(): - if column_name not in existing_columns: - await conn.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_definition};") - await conn.commit() +# --- Connection Management Helper Functions --- -async def initialize_databases(databases_dir): - """ - Initializes all required SQLite databases and their corresponding tables for the system. - - This function: - - Defines paths for user, node, scenario, and notes databases based on the provided directory. - - Sets up each database with appropriate PRAGMA settings. - - Creates necessary tables if they do not exist. - - Ensures all expected columns are present in each table, adding any missing ones. - - Creates a default admin user if no users are present. - - Args: - databases_dir (str): Path to the directory where the database files will be created or accessed. 
- - Note: - Default credentials (username and password) are taken from environment variables: - - NEBULA_DEFAULT_USER - - NEBULA_DEFAULT_PASSWORD - """ - global user_db_file_location, node_db_file_location, scenario_db_file_location, notes_db_file_location +def get_sync_conn(): + """Establishes a synchronous PostgreSQL connection.""" + return psycopg2.connect(DATABASE_URL) - user_db_file_location = os.path.join(databases_dir, "users.db") - node_db_file_location = os.path.join(databases_dir, "nodes.db") - scenario_db_file_location = os.path.join(databases_dir, "scenarios.db") - notes_db_file_location = os.path.join(databases_dir, "notes.db") - await setup_database(user_db_file_location) - await setup_database(node_db_file_location) - await setup_database(scenario_db_file_location) - await setup_database(notes_db_file_location) +async def get_async_conn(): + """Establishes an asynchronous PostgreSQL connection.""" + return await asyncpg.connect(DATABASE_URL) - async with aiosqlite.connect(user_db_file_location) as conn: - await conn.execute( - """ - CREATE TABLE IF NOT EXISTS users ( - user TEXT PRIMARY KEY, - password TEXT, - role TEXT - ); - """ - ) - desired_columns = {"user": "TEXT PRIMARY KEY", "password": "TEXT", "role": "TEXT"} - await ensure_columns(conn, "users", desired_columns) - - async with aiosqlite.connect(node_db_file_location) as conn: - await conn.execute( - """ - CREATE TABLE IF NOT EXISTS nodes ( - uid TEXT PRIMARY KEY, - idx TEXT, - ip TEXT, - port TEXT, - role TEXT, - neighbors TEXT, - latitude TEXT, - longitude TEXT, - timestamp TEXT, - federation TEXT, - round TEXT, - scenario TEXT, - hash TEXT, - malicious TEXT - ); - """ - ) - desired_columns = { - "uid": "TEXT PRIMARY KEY", - "idx": "TEXT", - "ip": "TEXT", - "port": "TEXT", - "role": "TEXT", - "neighbors": "TEXT", - "latitude": "TEXT", - "longitude": "TEXT", - "timestamp": "TEXT", - "federation": "TEXT", - "round": "TEXT", - "scenario": "TEXT", - "hash": "TEXT", - "malicious": "TEXT", - } - 
await ensure_columns(conn, "nodes", desired_columns) - - async with aiosqlite.connect(scenario_db_file_location) as conn: - await conn.execute( - """ - CREATE TABLE IF NOT EXISTS scenarios ( - name TEXT PRIMARY KEY, - start_time TEXT, - end_time TEXT, - title TEXT, - description TEXT, - deployment TEXT, - federation TEXT, - topology TEXT, - nodes TEXT, - nodes_graph TEXT, - n_nodes TEXT, - matrix TEXT, - random_topology_probability TEXT, - dataset TEXT, - iid TEXT, - partition_selection TEXT, - partition_parameter TEXT, - model TEXT, - agg_algorithm TEXT, - rounds TEXT, - logginglevel TEXT, - report_status_data_queue TEXT, - accelerator TEXT, - network_subnet TEXT, - network_gateway TEXT, - epochs TEXT, - attack_params TEXT, - reputation TEXT, - random_geo TEXT, - latitude TEXT, - longitude TEXT, - mobility TEXT, - mobility_type TEXT, - radius_federation TEXT, - scheme_mobility TEXT, - round_frequency TEXT, - mobile_participants_percent TEXT, - additional_participants TEXT, - schema_additional_participants TEXT, - status TEXT, - role TEXT, - username TEXT, - gpu_id TEXT - ); - """ - ) - desired_columns = { - "name": "TEXT PRIMARY KEY", - "start_time": "TEXT", - "end_time": "TEXT", - "title": "TEXT", - "description": "TEXT", - "deployment": "TEXT", - "federation": "TEXT", - "topology": "TEXT", - "nodes": "TEXT", - "nodes_graph": "TEXT", - "n_nodes": "TEXT", - "matrix": "TEXT", - "random_topology_probability": "TEXT", - "dataset": "TEXT", - "iid": "TEXT", - "partition_selection": "TEXT", - "partition_parameter": "TEXT", - "model": "TEXT", - "agg_algorithm": "TEXT", - "rounds": "TEXT", - "logginglevel": "TEXT", - "report_status_data_queue": "TEXT", - "accelerator": "TEXT", - "gpu_id": "TEXT", - "network_subnet": "TEXT", - "network_gateway": "TEXT", - "epochs": "TEXT", - "attack_params": "TEXT", - "reputation": "TEXT", - "random_geo": "TEXT", - "latitude": "TEXT", - "longitude": "TEXT", - "mobility": "TEXT", - "mobility_type": "TEXT", - "radius_federation": "TEXT", - 
"scheme_mobility": "TEXT", - "round_frequency": "TEXT", - "mobile_participants_percent": "TEXT", - "additional_participants": "TEXT", - "schema_additional_participants": "TEXT", - "status": "TEXT", - "role": "TEXT", - "username": "TEXT", - } - await ensure_columns(conn, "scenarios", desired_columns) - - async with aiosqlite.connect(notes_db_file_location) as conn: - await conn.execute( - """ - CREATE TABLE IF NOT EXISTS notes ( - scenario TEXT PRIMARY KEY, - scenario_notes TEXT - ); - """ - ) - desired_columns = {"scenario": "TEXT PRIMARY KEY", "scenario_notes": "TEXT"} - await ensure_columns(conn, "notes", desired_columns) - - username = os.environ.get("NEBULA_DEFAULT_USER", "admin") - password = os.environ.get("NEBULA_DEFAULT_PASSWORD", "admin") - if not list_users(): - add_user(username, password, "admin") - if not verify_hash_algorithm(username): - update_user(username, password, "admin") +# --- User Management Functions --- def list_users(all_info=False): """ Retrieves a list of users from the users database. - - Args: - all_info (bool): If True, returns full user records; otherwise, returns only usernames. Default is False. - - Returns: - list: A list of usernames or full user records depending on the all_info flag. """ - with sqlite3.connect(user_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT * FROM users") - result = c.fetchall() + with get_sync_conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: + c.execute("SELECT * FROM users") + result = c.fetchall() if not all_info: + # In PostgreSQL, you can access columns by key from DictCursor result = [user["user"] for user in result] return result @@ -302,150 +54,102 @@ def list_users(all_info=False): def get_user_info(user): """ Fetches detailed information for a specific user from the users database. - - Args: - user (str): The username to retrieve information for. 
- - Returns: - sqlite3.Row or None: A row containing the user's information if found, otherwise None. """ - with sqlite3.connect(user_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - command = "SELECT * FROM users WHERE user = ?" - c.execute(command, (user,)) - result = c.fetchone() - + with get_sync_conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: + c.execute("SELECT * FROM users WHERE \"user\" = %s", (user,)) + result = c.fetchone() return result def verify(user, password): """ Verifies whether the provided password matches the stored hashed password for a user. - - Args: - user (str): The username to verify. - password (str): The plain text password to check against the stored hash. - - Returns: - bool: True if the password is correct, False otherwise. """ - ph = PasswordHasher() - with sqlite3.connect(user_db_file_location) as conn: - c = conn.cursor() - - c.execute("SELECT password FROM users WHERE user = ?", (user,)) - result = c.fetchone() - if result: - try: - return ph.verify(result[0], password) - except: - return False + with get_sync_conn() as conn: + with conn.cursor() as c: + c.execute("SELECT password FROM users WHERE \"user\" = %s", (user,)) + result = c.fetchone() + if result: + try: + return pwd_context.verify(password, result[0]) + except: + return False return False def verify_hash_algorithm(user): """ Checks if the stored password hash for a user uses a supported Argon2 algorithm. - - Args: - user (str): The username to check (case-insensitive, converted to uppercase). - - Returns: - bool: True if the password hash starts with a valid Argon2 prefix, False otherwise. 
""" user = user.upper() argon2_prefixes = ("$argon2i$", "$argon2id$") - - with sqlite3.connect(user_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - c.execute("SELECT password FROM users WHERE user = ?", (user,)) - result = c.fetchone() - if result: - password_hash = result["password"] - return password_hash.startswith(argon2_prefixes) - + with get_sync_conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: + c.execute("SELECT password FROM users WHERE \"user\" = %s", (user,)) + result = c.fetchone() + if result: + password_hash = result["password"] + return password_hash.startswith(argon2_prefixes) return False def delete_user_from_db(user): """ Deletes a user record from the users database. - - Args: - user (str): The username of the user to be deleted. """ - with sqlite3.connect(user_db_file_location) as conn: - c = conn.cursor() - c.execute("DELETE FROM users WHERE user = ?", (user,)) + with get_sync_conn() as conn: + with conn.cursor() as c: + c.execute("DELETE FROM users WHERE \"user\" = %s", (user,)) + conn.commit() def add_user(user, password, role): """ Adds a new user to the users database with a hashed password. - - Args: - user (str): The username to add (stored in uppercase). - password (str): The plain text password to hash and store. - role (str): The role assigned to the user. """ - ph = PasswordHasher() - with sqlite3.connect(user_db_file_location) as conn: - c = conn.cursor() - c.execute( - "INSERT INTO users VALUES (?, ?, ?)", - (user.upper(), ph.hash(password), role), - ) + with get_sync_conn() as conn: + with conn.cursor() as c: + hashed_password = pwd_context.hash(password) + c.execute( + "INSERT INTO users (\"user\", password, role) VALUES (%s, %s, %s)", + (user.upper(), hashed_password, role), + ) + conn.commit() def update_user(user, password, role): """ Updates the password and role of an existing user in the users database. 
- - Args: - user (str): The username to update (case-insensitive, stored as uppercase). - password (str): The new plain text password to hash and store. - role (str): The new role to assign to the user. """ - ph = PasswordHasher() - with sqlite3.connect(user_db_file_location) as conn: - c = conn.cursor() - c.execute( - "UPDATE users SET password = ?, role = ? WHERE user = ?", - (ph.hash(password), role, user.upper()), - ) + with get_sync_conn() as conn: + with conn.cursor() as c: + hashed_password = pwd_context.hash(password) + c.execute( + "UPDATE users SET password = %s, role = %s WHERE \"user\" = %s", + (hashed_password, role, user.upper()), + ) + conn.commit() +# --- Node Management Functions --- def list_nodes(scenario_name=None, sort_by="idx"): """ Retrieves a list of nodes from the nodes database, optionally filtered by scenario and sorted. - - Args: - scenario_name (str, optional): Name of the scenario to filter nodes by. If None, returns all nodes. - sort_by (str): Column name to sort the results by. Defaults to "idx". - - Returns: - list or None: A list of sqlite3.Row objects representing nodes, or None if an error occurs. """ try: - with sqlite3.connect(node_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - if scenario_name: - command = "SELECT * FROM nodes WHERE scenario = ? 
ORDER BY " + sort_by + ";" - c.execute(command, (scenario_name,)) - else: - command = "SELECT * FROM nodes ORDER BY " + sort_by + ";" - c.execute(command) - - result = c.fetchall() + with get_sync_conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: + if scenario_name: + command = f"SELECT * FROM nodes WHERE scenario = %s ORDER BY {psycopg2.extensions.AsIs(sort_by)};" + c.execute(command, (scenario_name,)) + else: + command = f"SELECT * FROM nodes ORDER BY {psycopg2.extensions.AsIs(sort_by)};" + c.execute(command) + result = c.fetchall() return result - except sqlite3.Error as e: + except psycopg2.Error as e: print(f"Error occurred while listing nodes: {e}") return None @@ -453,737 +157,365 @@ def list_nodes(scenario_name=None, sort_by="idx"): def list_nodes_by_scenario_name(scenario_name): """ Fetches all nodes associated with a specific scenario, ordered by their index as integers. - - Args: - scenario_name (str): The name of the scenario to filter nodes by. - - Returns: - list or None: A list of sqlite3.Row objects for nodes in the scenario, or None if an error occurs. """ try: - with sqlite3.connect(node_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - command = "SELECT * FROM nodes WHERE scenario = ? 
ORDER BY CAST(idx AS INTEGER) ASC;" - c.execute(command, (scenario_name,)) - result = c.fetchall() - + with get_sync_conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: + # Use a specific cast in PostgreSQL to order by integer value of idx + command = "SELECT * FROM nodes WHERE scenario = %s ORDER BY CAST(idx AS INTEGER) ASC;" + c.execute(command, (scenario_name,)) + result = c.fetchall() return result - except sqlite3.Error as e: + except psycopg2.Error as e: print(f"Error occurred while listing nodes by scenario name: {e}") return None async def update_node_record( - node_uid, - idx, - ip, - port, - role, - neighbors, - latitude, - longitude, - timestamp, - federation, - federation_round, - scenario, - run_hash, - malicious, + node_uid, idx, ip, port, role, neighbors, latitude, longitude, + timestamp, federation, federation_round, scenario, run_hash, malicious, ): """ Inserts or updates a node record in the database for a given scenario, ensuring thread-safe access. - - Args: - node_uid (str): Unique identifier of the node. - idx (str): Index or identifier within the scenario. - ip (str): IP address of the node. - port (str): Port used by the node. - role (str): Role of the node in the federation. - neighbors (str): Neighbors of the node (serialized). - latitude (str): Geographic latitude of the node. - longitude (str): Geographic longitude of the node. - timestamp (str): Timestamp of the last update. - federation (str): Federation identifier the node belongs to. - federation_round (str): Current federation round. - scenario (str): Scenario name the node is part of. - run_hash (str): Hash of the current run/state. - malicious (str): Indicator if the node is malicious. - - Returns: - dict or None: The updated or inserted node record as a dictionary, or None if insertion/update failed. 
""" - global _node_lock async with _node_lock: - async with aiosqlite.connect(node_db_file_location) as conn: - conn.row_factory = aiosqlite.Row - _c = await conn.cursor() - - # Check if the node already exists - await _c.execute("SELECT * FROM nodes WHERE uid = ? AND scenario = ?;", (node_uid, scenario)) - result = await _c.fetchone() - - if result is None: - # Insert new node - await _c.execute( - "INSERT INTO nodes (uid, idx, ip, port, role, neighbors, latitude, longitude, timestamp, federation, round, scenario, hash, malicious) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);", - ( - node_uid, - idx, - ip, - port, - role, - neighbors, - latitude, - longitude, - timestamp, - federation, - federation_round, - scenario, - run_hash, - malicious, - ), - ) - else: - # Update existing node - await _c.execute( - "UPDATE nodes SET idx = ?, ip = ?, port = ?, role = ?, neighbors = ?, latitude = ?, longitude = ?, timestamp = ?, federation = ?, round = ?, hash = ?, malicious = ? WHERE uid = ? AND scenario = ?;", - ( - idx, - ip, - port, - role, - neighbors, - latitude, - longitude, - timestamp, - federation, - federation_round, - run_hash, - malicious, - node_uid, - scenario, - ), + # CORRECTED: Use get_async_conn() directly as the async context manager. + # This assumes get_async_conn() returns a connection object that supports the + # asynchronous context manager protocol (i.e., has __aenter__ and __aexit__). + # A typical implementation of get_async_conn() using asyncpg would be 'return pool.acquire()'. + async with get_async_conn() as conn: + # Use a connection-bound cursor + async with conn.transaction(): + # Use a SELECT ... FOR UPDATE to lock the row and avoid race conditions + result = await conn.fetchrow( + "SELECT * FROM nodes WHERE uid = $1 AND scenario = $2 FOR UPDATE;", + node_uid, scenario ) - await conn.commit() - - # Fetch the updated or newly inserted row - await _c.execute("SELECT * FROM nodes WHERE uid = ? 
AND scenario = ?;", (node_uid, scenario)) - updated_row = await _c.fetchone() - return dict(updated_row) if updated_row else None + if result is None: + # Insert new node + await conn.execute( + """ + INSERT INTO nodes (uid, idx, ip, port, role, neighbors, latitude, longitude, + timestamp, federation, round, scenario, hash, malicious) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14); + """, + node_uid, idx, ip, port, role, neighbors, latitude, longitude, + timestamp, federation, federation_round, scenario, run_hash, malicious, + ) + else: + # Update existing node + await conn.execute( + """ + UPDATE nodes SET idx = $1, ip = $2, port = $3, role = $4, neighbors = $5, + latitude = $6, longitude = $7, timestamp = $8, federation = $9, round = $10, + hash = $11, malicious = $12 + WHERE uid = $13 AND scenario = $14; + """, + idx, ip, port, role, neighbors, latitude, longitude, + timestamp, federation, federation_round, run_hash, malicious, + node_uid, scenario, + ) + + # Fetch the updated or newly inserted row + updated_row = await conn.fetchrow("SELECT * from nodes WHERE uid = $1 AND scenario = $2;", node_uid, scenario) + return dict(updated_row) if updated_row else None def remove_all_nodes(): """ Deletes all node records from the nodes database. - - This operation removes every entry in the nodes table. - - Returns: - None """ - with sqlite3.connect(node_db_file_location) as conn: - c = conn.cursor() - command = "DELETE FROM nodes;" - c.execute(command) + with get_sync_conn() as conn: + with conn.cursor() as c: + c.execute("TRUNCATE nodes;") # TRUNCATE is faster than DELETE FROM for clearing tables + conn.commit() def remove_nodes_by_scenario_name(scenario_name): """ Deletes all nodes associated with a specific scenario from the database. - - Args: - scenario_name (str): The name of the scenario whose nodes should be removed. 
- - Returns: - None """ - with sqlite3.connect(node_db_file_location) as conn: - c = conn.cursor() - command = "DELETE FROM nodes WHERE scenario = ?;" - c.execute(command, (scenario_name,)) + with get_sync_conn() as conn: + with conn.cursor() as c: + c.execute("DELETE FROM nodes WHERE scenario = %s;", (scenario_name,)) + conn.commit() +# --- Scenario Management Functions --- def get_all_scenarios(username, role, sort_by="start_time"): """ - Retrieve all scenarios from the database filtered by user role and sorted by a specified field. - - Parameters: - username (str): The username of the requesting user. - role (str): The role of the user, e.g., "admin" or regular user. - sort_by (str, optional): The field name to sort the results by. Defaults to "start_time". - - Returns: - list[sqlite3.Row]: A list of scenario records as SQLite Row objects. - - Behavior: - - Admin users retrieve all scenarios. - - Non-admin users retrieve only scenarios associated with their username. - - Sorting by "start_time" applies custom datetime ordering. - - Other sort fields are applied directly in the ORDER BY clause. - """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - if role == "admin": - if sort_by == "start_time": - command = """ - SELECT * FROM scenarios - ORDER BY strftime('%Y-%m-%d %H:%M:%S', substr(start_time, 7, 4) || '-' || substr(start_time, 4, 2) || '-' || substr(start_time, 1, 2) || ' ' || substr(start_time, 12, 8)); - """ - c.execute(command) - else: - command = "SELECT * FROM scenarios ORDER BY ?;" - c.execute(command, (sort_by,)) - else: - if sort_by == "start_time": - command = """ - SELECT * FROM scenarios - WHERE username = ? - ORDER BY strftime('%Y-%m-%d %H:%M:%S', substr(start_time, 7, 4) || '-' || substr(start_time, 4, 2) || '-' || substr(start_time, 1, 2) || ' ' || substr(start_time, 12, 8)); - """ - c.execute(command, (username,)) - else: - command = "SELECT * FROM scenarios WHERE username = ? 
ORDER BY ?;" - c.execute( - command, - ( - username, - sort_by, - ), - ) - result = c.fetchall() - - return result - - -def get_all_scenarios_and_check_completed(username, role, sort_by="start_time"): - """ - Retrieve all scenarios with detailed fields and update the status of running scenarios if their federation is completed. - - Parameters: - username (str): The username of the requesting user. - role (str): The role of the user, e.g., "admin" or regular user. - sort_by (str, optional): The field name to sort the results by. Defaults to "start_time". - - Returns: - list[sqlite3.Row]: A list of scenario records including name, username, title, start_time, model, dataset, rounds, and status. - - Behavior: - - Admin users retrieve all scenarios. - - Non-admin users retrieve only scenarios associated with their username. - - Scenarios are sorted by start_time with special handling for null or empty values. - - For scenarios with status "running", checks if federation is completed: - - If completed, updates the scenario status to "completed". - - Refreshes the returned scenario list after updates. 
- """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - if role == "admin": - if sort_by == "start_time": - command = """ - SELECT name, username, title, start_time, model, dataset, rounds, status FROM scenarios - ORDER BY - CASE - WHEN start_time IS NULL OR start_time = '' THEN 1 - ELSE 0 - END, - strftime( - '%Y-%m-%d %H:%M:%S', - substr(start_time, 7, 4) || '-' || substr(start_time, 4, 2) || '-' || substr(start_time, 1, 2) || ' ' || substr(start_time, 12, 8) - ); - """ - c.execute(command) - else: - command = "SELECT name, username, title, start_time, model, dataset, rounds, status FROM scenarios ORDER BY ?;" - c.execute(command, (sort_by,)) - result = c.fetchall() - else: - if sort_by == "start_time": - command = """ - SELECT name, username, title, start_time, model, dataset, rounds, status FROM scenarios - WHERE username = ? - ORDER BY - CASE - WHEN start_time IS NULL OR start_time = '' THEN 1 - ELSE 0 - END, - strftime( - '%Y-%m-%d %H:%M:%S', - substr(start_time, 7, 4) || '-' || substr(start_time, 4, 2) || '-' || substr(start_time, 1, 2) || ' ' || substr(start_time, 12, 8) - ); - """ - c.execute(command, (username,)) - else: - command = "SELECT name, username, title, start_time, model, dataset, rounds, status FROM scenarios WHERE username = ? ORDER BY ?;" - c.execute( - command, - ( - username, - sort_by, - ), - ) + Retrieves all scenarios from the database, accessing the fields + inside the 'config' (JSONB) column. + Filters by user role and sorts by the specified field. 
+ """ + # Safe list of allowed sorting fields to prevent SQL injection + allowed_sort_fields = ["start_time", "title", "username", "status", "name"] + if sort_by not in allowed_sort_fields: + sort_by = "start_time" # Use a safe default value + + # Building the ORDER BY clause + if sort_by == "start_time": + # Special sorting for dates saved as text, handling nulls/empty strings + order_by_clause = """ + ORDER BY + CASE + WHEN start_time IS NULL OR start_time = '' THEN 1 + ELSE 0 + END, + to_timestamp(start_time, 'YYYY/MM/DD HH24:MI:SS') DESC + """ + else: + # Sorting is built dynamically but safely, + # since 'sort_by' has been validated against the allowed list. + order_by_clause = f"ORDER BY config->>'{sort_by}'" + + with get_sync_conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: + # Base query. It's flexible to return 'name' and the full 'config' object. + # The code that calls this function can access any field from the config object. + command = "SELECT name, username, start_time, end_time, status, config FROM scenarios" + params = [] + + # Conditionally add the WHERE filter if the role is not admin + if role != "admin": + command += " WHERE config->>'username' = %s" + params.append(username) + + # Combine the query with the sorting clause + full_command = f"{command} {order_by_clause};" + + c.execute(full_command, tuple(params)) result = c.fetchall() - - for scenario in result: - if scenario["status"] == "running": - if check_scenario_federation_completed(scenario["name"]): - scenario_set_status_to_completed(scenario["name"]) - result = get_all_scenarios(username, role) - + return result -def scenario_update_record(name, start_time, end_time, scenario, status, role, username): - """ - Insert a new scenario record or update an existing one in the database based on the scenario name. - - Parameters: - name (str): The unique name identifier of the scenario. - start_time (str): The start time of the scenario. 
- end_time (str): The end time of the scenario. - scenario (object): An object containing detailed scenario attributes. - status (str): The current status of the scenario. - role (str): The role of the user performing the operation. - username (str): The username of the user performing the operation. - - Behavior: - - Checks if a scenario with the given name exists. - - If not, inserts a new record with all scenario details. - - If exists, updates the existing record with the provided data. - - Commits the transaction to persist changes. +def get_all_scenarios_and_check_completed(username, role, sort_by="start_time"): """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - select_command = "SELECT * FROM scenarios WHERE name = ?;" - c.execute(select_command, (name,)) - result = c.fetchone() - - if result is None: - insert_command = """ - INSERT INTO scenarios ( + Retrieves all scenarios from the JSONB field, sorts them, and updates the status if necessary. + Returns a list of dictionaries, where each dictionary represents a scenario. + """ + # Safe list of allowed sorting fields to prevent SQL injection. 
+ allowed_sort_fields = ["start_time", "title", "username", "status", "name"] + if sort_by not in allowed_sort_fields: + sort_by = "start_time" # Safe default value + + # Building the ORDER BY clause + if sort_by == "start_time": + order_by_clause = """ + ORDER BY + CASE + WHEN start_time IS NULL OR start_time = '' THEN 1 + ELSE 0 + END, + to_timestamp(start_time, 'YYYY/MM/DD HH24:MI:SS') DESC + """ + else: + # We use a safe f-string because the value of sort_by has been validated + order_by_clause = f"ORDER BY config->>'{sort_by}'" + + with get_sync_conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: + # Base query that extracts fields from the JSONB using the ->> operator + command = f""" + SELECT name, - start_time, - end_time, - title, - description, - deployment, - federation, - topology, - nodes, - nodes_graph, - n_nodes, - matrix, - random_topology_probability, - dataset, - iid, - partition_selection, - partition_parameter, - model, - agg_algorithm, - rounds, - logginglevel, - report_status_data_queue, - accelerator, - gpu_id, - network_subnet, - network_gateway, - epochs, - attack_params, - reputation, - random_geo, - latitude, - longitude, - mobility, - mobility_type, - radius_federation, - scheme_mobility, - round_frequency, - mobile_participants_percent, - additional_participants, - schema_additional_participants, + username, status, - role, - username - ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? 
- ); - """ - c.execute( - insert_command, - ( - name, start_time, end_time, - scenario.scenario_title, - scenario.scenario_description, - scenario.deployment, - scenario.federation, - scenario.topology, - json.dumps(scenario.nodes), - json.dumps(scenario.nodes_graph), - scenario.n_nodes, - json.dumps(scenario.matrix), - scenario.random_topology_probability, - scenario.dataset, - scenario.iid, - scenario.partition_selection, - scenario.partition_parameter, - scenario.model, - scenario.agg_algorithm, - scenario.rounds, - scenario.logginglevel, - scenario.report_status_data_queue, - scenario.accelerator, - json.dumps(scenario.gpu_id), - scenario.network_subnet, - scenario.network_gateway, - scenario.epochs, - json.dumps(scenario.attack_params), - json.dumps(scenario.reputation), - scenario.random_geo, - scenario.latitude, - scenario.longitude, - scenario.mobility, - scenario.mobility_type, - scenario.radius_federation, - scenario.scheme_mobility, - scenario.round_frequency, - scenario.mobile_participants_percent, - json.dumps(scenario.additional_participants), - scenario.schema_additional_participants, - status, - role, - username, - ), - ) - else: - update_command = """ - UPDATE scenarios SET - start_time = ?, - end_time = ?, - title = ?, - description = ?, - deployment = ?, - federation = ?, - topology = ?, - nodes = ?, - nodes_graph = ?, - n_nodes = ?, - matrix = ?, - random_topology_probability = ?, - dataset = ?, - iid = ?, - partition_selection = ?, - partition_parameter = ?, - model = ?, - agg_algorithm = ?, - rounds = ?, - logginglevel = ?, - report_status_data_queue = ?, - accelerator = ?, - gpu_id = ?, - network_subnet = ?, - network_gateway = ?, - epochs = ?, - attack_params = ?, - reputation = ?, - random_geo = ?, - latitude = ?, - longitude = ?, - mobility = ?, - mobility_type = ?, - radius_federation = ?, - scheme_mobility = ?, - round_frequency = ?, - mobile_participants_percent = ?, - additional_participants = ?, - schema_additional_participants = ?, - 
status = ?, - role = ?, - username = ? - WHERE name = ?; + config->>'title' AS title, + config->>'model' AS model, + config->>'dataset' AS dataset, + config->>'rounds' AS rounds, + config -- Return the full config object + FROM scenarios """ - c.execute( - update_command, - ( - start_time, - end_time, - scenario.scenario_title, - scenario.scenario_description, - scenario.deployment, - scenario.federation, - scenario.topology, - json.dumps(scenario.nodes), - json.dumps(scenario.nodes_graph), - scenario.n_nodes, - json.dumps(scenario.matrix), - scenario.random_topology_probability, - scenario.dataset, - scenario.iid, - scenario.partition_selection, - scenario.partition_parameter, - scenario.model, - scenario.agg_algorithm, - scenario.rounds, - scenario.logginglevel, - scenario.report_status_data_queue, - scenario.accelerator, - json.dumps(scenario.gpu_id), - scenario.network_subnet, - scenario.network_gateway, - scenario.epochs, - json.dumps(scenario.attack_params), - json.dumps(scenario.reputation), - scenario.random_geo, - scenario.latitude, - scenario.longitude, - scenario.mobility, - scenario.mobility_type, - scenario.radius_federation, - scenario.scheme_mobility, - scenario.round_frequency, - scenario.mobile_participants_percent, - json.dumps(scenario.additional_participants), - scenario.schema_additional_participants, - status, - role, - username, - name, - ), - ) - - conn.commit() + params = [] + if role != "admin": + command += " WHERE config->>'username' = %s" + params.append(username) + + command += f" {order_by_clause};" + + c.execute(command, tuple(params)) + result_dicts = c.fetchall() # This already returns a list of DictRow objects (which act like dicts) + + # Logic to check for completed scenarios and update status. + # It's important to modify the `result_dicts` directly or handle the recursion carefully. 
+ + # Create a mutable list from DictRow objects for potential status updates + scenarios_to_return = [dict(s) for s in result_dicts] + + re_fetch_required = False + for scenario in scenarios_to_return: + if scenario["status"] == "running": + if check_scenario_federation_completed(scenario["name"]): + scenario_set_status_to_completed(scenario["name"]) + # If a scenario's status changes, it's best to re-query the database + # to ensure the most up-to-date information is returned for ALL scenarios. + # This avoids inconsistencies if multiple scenarios complete in a single call. + re_fetch_required = True + break # Break after finding one completed scenario to trigger re-fetch + + if re_fetch_required: + # If any status was updated, recursively call the function again to get the fresh data. + # This ensures the returned list reflects the updated status from the DB. + # Make sure `get_all_scenarios` is indeed this function if you rename it. + return get_all_scenarios(username, role, sort_by) + + return scenarios_to_return + + +def scenario_update_record(name, start_time, end_time, scenario, status, username): + """ + Inserts or updates a scenario record using the PostgreSQL "UPSERT" pattern. + All configuration is saved in the 'config' column of type JSONB. + """ + with get_sync_conn() as conn: + with conn.cursor() as c: + command = """ + INSERT INTO scenarios (name, start_time, end_time, username, status, config) + VALUES (%s, %s, %s, %s, %s, %s) + ON CONFLICT (name) DO UPDATE SET + config = scenarios.config || excluded.config; + """ + logging.info(f"[FER] scenario database.py {json.dumps(scenario, indent=2)}") + c.execute(command, (name, start_time, end_time, username, status, json.dumps(scenario, indent=2))) + conn.commit() def scenario_set_all_status_to_finished(): """ - Set the status of all currently running scenarios to "finished" and update their end time to the current datetime. - - Behavior: - - Finds all scenarios with status "running". 
- - Updates their status to "finished". - - Sets the end_time to the current timestamp. - - Commits the changes to the database. + Sets the status of all 'running' scenarios to 'finished' + and updates their 'end_time' within the JSONB. """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - current_time = str(datetime.datetime.now()) - c.execute("UPDATE scenarios SET status = 'finished', end_time = ? WHERE status = 'running';", (current_time,)) - conn.commit() + with get_sync_conn() as conn: + with conn.cursor() as c: + current_time = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S') + # We use jsonb_set to update specific fields within the JSONB. + # We nest the calls to update multiple fields. + command = """ + UPDATE scenarios + SET status = 'finished', end_time = %s + WHERE status = 'running'; + """ + c.execute(command, (json.dumps(current_time),)) + conn.commit() def scenario_set_status_to_finished(scenario_name): """ - Set the status of a specific scenario to "finished" and update its end time to the current datetime. - - Parameters: - scenario_name (str): The unique name identifier of the scenario to update. - - Behavior: - - Updates the scenario's status to "finished". - - Sets the end_time to the current timestamp. - - Commits the update to the database. + Sets the status of a specific scenario to 'finished' and updates its 'end_time'. """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - current_time = str(datetime.datetime.now()) - c.execute( - "UPDATE scenarios SET status = 'finished', end_time = ? 
WHERE name = ?;", (current_time, scenario_name) - ) - conn.commit() + with get_sync_conn() as conn: + with conn.cursor() as c: + current_time = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S') + command = """ + UPDATE scenarios + SET config = jsonb_set( + jsonb_set(config, '{status}', '"finished"'), + '{end_time}', %s::jsonb + ) + WHERE name = %s; + """ + c.execute(command, (json.dumps(current_time), scenario_name)) + conn.commit() def scenario_set_status_to_completed(scenario_name): """ - Set the status of a specific scenario to "completed". - - Parameters: - scenario_name (str): The unique name identifier of the scenario to update. - - Behavior: - - Updates the scenario's status to "completed". - - Commits the change to the database. + Sets the status of a specific scenario to 'completed'. """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("UPDATE scenarios SET status = 'completed' WHERE name = ?;", (scenario_name,)) - conn.commit() + with get_sync_conn() as conn: + with conn.cursor() as c: + command = """ + UPDATE scenarios + SET status = "completed" + WHERE name = %s; + """ + c.execute(command, (scenario_name,)) + conn.commit() def get_running_scenario(username=None, get_all=False): """ - Retrieve running or completed scenarios from the database, optionally filtered by username. - - Parameters: - username (str, optional): The username to filter scenarios by. If None, no user filter is applied. - get_all (bool, optional): If True, returns all matching scenarios; otherwise returns only one. Defaults to False. - - Returns: - sqlite3.Row or list[sqlite3.Row]: A single scenario record or a list of scenario records matching the criteria. - - Behavior: - - Filters scenarios with status "running". - - Applies username filter if provided. - - Returns either one or all matching records depending on get_all. 
- """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - if username: - command = """ - SELECT * FROM scenarios - WHERE (status = ?) AND username = ?; - """ - c.execute(command, ("running", username)) - - result = c.fetchone() - else: - command = "SELECT * FROM scenarios WHERE status = ?;" - c.execute(command, ("running",)) - if get_all: - result = c.fetchall() + Retrieves scenarios with a 'running' status, optionally filtered by user. + """ + with get_sync_conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: + params = ["running"] + command = "SELECT name, config FROM scenarios WHERE config->>'status' = %s" + + if username: + command += " AND config->>'username' = %s" + params.append(username) + + c.execute(command, tuple(params)) + + if get_all: + raw_results = c.fetchall() + if raw_results: + processed_results = [] + for row in raw_results: + processed_results.append({ + 'name': row['name'], + 'config': row['config'] + }) + result = processed_results else: result = c.fetchone() - + if result: + result = result['config'] return result def get_completed_scenario(): """ - Retrieve a single scenario with status "completed" from the database. - - Returns: - sqlite3.Row: A scenario record with status "completed", or None if no such scenario exists. - - Behavior: - - Fetches the first scenario found with status "completed". + Retrieves a single scenario with a 'completed' status. 
""" - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - command = "SELECT * FROM scenarios WHERE status = ?;" - c.execute(command, ("completed",)) - result = c.fetchone() - + with get_sync_conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: + command = "SELECT name, config FROM scenarios WHERE config->>'status' = %s;" + c.execute(command, ("completed",)) + result = c.fetchone() return result def get_scenario_by_name(scenario_name): """ - Retrieve a scenario record by its unique name. - - Parameters: - scenario_name (str): The unique name identifier of the scenario. - - Returns: - sqlite3.Row: The scenario record matching the given name, or None if not found. + Retrieves the complete record of a scenario by its name. """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT * FROM scenarios WHERE name = ?;", (scenario_name,)) - result = c.fetchone() - + with get_sync_conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: + c.execute("SELECT name, start_time, end_time, username, status, config FROM scenarios WHERE name = %s;", (scenario_name,)) + result = c.fetchone() return result def get_user_by_scenario_name(scenario_name): """ - Retrieve the username associated with a given scenario name. - - Parameters: - scenario_name (str): The unique name identifier of the scenario. - - Returns: - str: The username linked to the specified scenario, or None if not found. - """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT username FROM scenarios WHERE name = ?;", (scenario_name,)) - result = c.fetchone() - - return result["username"] - - -def remove_scenario_by_name(scenario_name): - """ - Delete a scenario from the database by its unique name. 
- - Parameters: - scenario_name (str): The unique name identifier of the scenario to be removed. - - Behavior: - - Removes the scenario record matching the given name. - - Commits the deletion to the database. + Retrieves the username associated with a scenario from the JSONB field. """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("DELETE FROM scenarios WHERE name = ?;", (scenario_name,)) - conn.commit() - + with get_sync_conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: + c.execute("SELECT username FROM scenarios WHERE name = %s;", (scenario_name,)) + result = c.fetchone() + return result["username"] if result else None +# Placeholder for `check_scenario_federation_completed`. +# You need to implement this based on your application logic. def check_scenario_federation_completed(scenario_name): """ - Check if all nodes in a given scenario have completed the required federation rounds. - - Parameters: - scenario_name (str): The unique name identifier of the scenario to check. - - Returns: - bool: True if all nodes have completed the total rounds specified for the scenario, False otherwise or if an error occurs. - - Behavior: - - Retrieves the total number of rounds defined for the scenario. - - Fetches the current round progress of all nodes in that scenario. - - Returns True only if every node has reached the total rounds. - - Handles database errors and missing scenario cases gracefully. + Placeholder function to check if a scenario's federation is completed. + This should be implemented based on your specific application logic. + For example, it could check if the last round has been reached. 
""" - try: - # Connect to the scenario database to get the total rounds for the scenario - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT rounds FROM scenarios WHERE name = ?;", (scenario_name,)) - scenario = c.fetchone() - - if not scenario: - raise ValueError(f"Scenario '{scenario_name}' not found.") - - total_rounds = scenario["rounds"] - - # Connect to the node database to check the rounds for each node - with sqlite3.connect(node_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT round FROM nodes WHERE scenario = ?;", (scenario_name,)) - nodes = c.fetchall() - - if len(nodes) == 0: - return False - - # Check if all nodes have completed the total rounds - total_rounds_str = str(total_rounds) - return all(str(node["round"]) == total_rounds_str for node in nodes) - - except sqlite3.Error as e: - print(f"Database error: {e}") - return False - except Exception as e: - print(f"An error occurred: {e}") - return False - + print(f"Checking if scenario '{scenario_name}' is completed...") + # Example logic: + # return get_current_round(scenario_name) >= get_total_rounds(scenario_name) + return False # Placeholder value def check_scenario_with_role(role, scenario_name): """ @@ -1196,21 +528,17 @@ def check_scenario_with_role(role, scenario_name): Returns: bool: True if a scenario with the given role and name exists, False otherwise. """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute( - "SELECT * FROM scenarios WHERE role = ? 
AND name = ?;", - ( - role, - scenario_name, - ), - ) - result = c.fetchone() + with get_sync_conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: + # Use %s placeholders for query parameters + c.execute( + "SELECT 1 FROM scenarios WHERE role = %s AND name = %s;", + (role, scenario_name), + ) + result = c.fetchone() return result is not None - def save_notes(scenario, notes): """ Save or update notes associated with a specific scenario. @@ -1222,24 +550,24 @@ def save_notes(scenario, notes): Behavior: - Inserts new notes if the scenario does not exist in the database. - Updates existing notes if the scenario already has notes saved. - - Handles SQLite integrity and general database errors gracefully. + - Handles database errors gracefully. """ try: - with sqlite3.connect(notes_db_file_location) as conn: - c = conn.cursor() - c.execute( - """ - INSERT INTO notes (scenario, scenario_notes) VALUES (?, ?) - ON CONFLICT(scenario) DO UPDATE SET scenario_notes = excluded.scenario_notes; - """, - (scenario, notes), - ) + with get_sync_conn() as conn: + with conn.cursor() as c: + # Use INSERT ... ON CONFLICT (UPSERT) + c.execute( + """ + INSERT INTO notes (scenario, scenario_notes) VALUES (%s, %s) + ON CONFLICT(scenario) DO UPDATE SET scenario_notes = EXCLUDED.scenario_notes; + """, + (scenario, notes), + ) conn.commit() - except sqlite3.IntegrityError as e: - print(f"SQLite integrity error: {e}") - except sqlite3.Error as e: - print(f"SQLite error: {e}") - + except psycopg2.IntegrityError as e: + print(f"PostgreSQL integrity error: {e}") + except psycopg2.Error as e: + print(f"PostgreSQL error: {e}") def get_notes(scenario): """ @@ -1249,17 +577,14 @@ def get_notes(scenario): scenario (str): The unique identifier of the scenario. Returns: - sqlite3.Row or None: The notes record for the given scenario, or None if no notes exist. + psycopg2.extras.DictRow or None: The notes record for the given scenario, or None if no notes exist. 
""" - with sqlite3.connect(notes_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT * FROM notes WHERE scenario = ?;", (scenario,)) - result = c.fetchone() - + with get_sync_conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: + c.execute("SELECT * FROM notes WHERE scenario = %s;", (scenario,)) + result = c.fetchone() return result - def remove_note(scenario): """ Delete the note associated with a specific scenario. @@ -1267,17 +592,13 @@ def remove_note(scenario): Parameters: scenario (str): The unique identifier of the scenario whose note should be removed. """ - with sqlite3.connect(notes_db_file_location) as conn: - c = conn.cursor() - c.execute("DELETE FROM notes WHERE scenario = ?;", (scenario,)) + with get_sync_conn() as conn: + with conn.cursor() as c: + c.execute("DELETE FROM notes WHERE scenario = %s;", (scenario,)) conn.commit() - if __name__ == "__main__": """ Entry point for the script to print the list of users. - - When executed directly, this block calls the `list_users()` function - and prints its returned list of users. """ - print(list_users()) + print(list_users()) \ No newline at end of file diff --git a/nebula/controller/scenarios.py b/nebula/controller/scenarios.py index 7276f3893..919f4834c 100644 --- a/nebula/controller/scenarios.py +++ b/nebula/controller/scenarios.py @@ -548,6 +548,26 @@ def from_dict(cls, data): scenario = cls(**scenario_data) return scenario + + @staticmethod + def to_json(scenario_obj): + """ + Converts a Scenario object to a JSON string. + + Args: + scenario_obj (Scenario): An instance of the Scenario class. + + Returns: + str: A JSON string representation of the Scenario object. 
+ """ + if not isinstance(scenario_obj, Scenario): + raise TypeError("Input must be an instance of the Scenario class.") + + # Get all attributes of the Scenario object + scenario_dict = scenario_obj.__dict__ + + # Convert the dictionary to a JSON string + return json.dumps(scenario_dict, indent=2) # Using indent for pretty-printing # Class to manage the current scenario @@ -572,6 +592,7 @@ class ScenarioManagement: def __init__(self, scenario, user=None): # Current scenario self.scenario = Scenario.from_dict(scenario) + logging.info(f"[FER] scenario from scenarios.py {Scenario.to_json(self.scenario)}") # Uid of the user self.user = user # Scenario management settings diff --git a/nebula/database/Dockerfile b/nebula/database/Dockerfile new file mode 100644 index 000000000..ba3ce607d --- /dev/null +++ b/nebula/database/Dockerfile @@ -0,0 +1,55 @@ +FROM postgres:17.5-alpine3.22 + +# Rename the official entrypoint so we can wrap it +RUN mv /usr/local/bin/docker-entrypoint.sh /usr/local/bin/docker-entrypoint.sh.orig + +ENV POSTGRES_DB=configdb +ENV POSTGRES_USER=appuser + +# Copy SQL init file and custom entrypoint script +COPY /nebula/database/docker-entrypoint.sh /usr/local/bin/docker-entrypoint.sh +RUN chmod +x /usr/local/bin/docker-entrypoint.sh + +# Install Python 3.11.7 from source +RUN apk add --no-cache \ + gcc \ + g++ \ + musl-dev \ + make \ + openssl-dev \ + bzip2-dev \ + zlib-dev \ + xz-dev \ + readline-dev \ + sqlite-dev \ + libffi-dev \ + curl \ + tar \ + bash + +ENV PYTHON_VERSION=3.11.7 + +RUN curl -O https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.tgz && \ + tar -xzf Python-${PYTHON_VERSION}.tgz && \ + cd Python-${PYTHON_VERSION} && \ + ./configure --prefix=/usr/local --enable-optimizations && \ + make -j$(nproc) && \ + make install && \ + cd .. 
&& rm -rf Python-${PYTHON_VERSION}* + +RUN python3.11 --version + +# Install uv (alternative to pip, very fast) +ADD https://astral.sh/uv/install.sh /uv-installer.sh +RUN sh /uv-installer.sh && rm /uv-installer.sh +ENV PATH="/root/.local/bin/:$PATH" + +# Install Python dependencies using uv +COPY pyproject.toml . +RUN uv python pin 3.11.7 +RUN uv sync --group database + +ENV PATH=".venv/bin:$PATH" + +ENTRYPOINT ["/usr/local/bin/docker-entrypoint.sh"] +CMD ["postgres"] diff --git a/nebula/database/docker-entrypoint.sh b/nebula/database/docker-entrypoint.sh new file mode 100644 index 000000000..e6bbde73c --- /dev/null +++ b/nebula/database/docker-entrypoint.sh @@ -0,0 +1,23 @@ +#!/bin/sh +set -e + +# 1) Launch the original entrypoint in the background +exec /usr/local/bin/docker-entrypoint.sh.orig "$@" & + +pid="$!" + +# 2) Wait until PostgreSQL is ready to accept connections +echo "⏳ Waiting for PostgreSQL to be ready..." +until pg_isready -U "$POSTGRES_USER" -d "$POSTGRES_DB" >/dev/null 2>&1; do + sleep 1 +done + +# 3) Always apply our init SQL +echo "🚀 Applying init-configs.sql..." 
+psql -v ON_ERROR_STOP=1 \ + -U "$POSTGRES_USER" \ + -d "$POSTGRES_DB" \ + -f /docker-entrypoint-initdb.d/init-configs.sql + +# 4) Wait on the main Postgres process +wait "$pid" \ No newline at end of file diff --git a/nebula/database/init-configs.sql b/nebula/database/init-configs.sql new file mode 100644 index 000000000..f125342c5 --- /dev/null +++ b/nebula/database/init-configs.sql @@ -0,0 +1,69 @@ +-- -------------------------------------------------- +-- init_postgres.sql +-- -------------------------------------------------- + +-- 1) (Optional) If you need to create the database, uncomment: +-- CREATE DATABASE nebula; +-- \c nebula + +-- 2) Users table +CREATE TABLE IF NOT EXISTS users ( + "user" TEXT PRIMARY KEY, + password TEXT, + role TEXT +); + +-- 2) Nodes como JSONB +CREATE TABLE IF NOT EXISTS nodes ( + uid TEXT PRIMARY KEY, + idx TEXT, + ip TEXT, + port TEXT, + role TEXT, + neighbors TEXT, + latitude TEXT, + longitude TEXT, + timestamp TEXT, + federation TEXT, + round TEXT, + scenario TEXT, + hash TEXT, + malicious TEXT +); + +-- 3) Configs como JSONB +DROP INDEX IF EXISTS idx_configs_config_gin; +DROP TABLE IF EXISTS configs; +CREATE TABLE configs ( + id SERIAL PRIMARY KEY, + config JSONB NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW() +); +CREATE INDEX idx_configs_config_gin ON configs USING GIN (config); + +-- 4) Scenarios table as JSONB +CREATE TABLE IF NOT EXISTS scenarios ( + name TEXT PRIMARY KEY, + username TEXT NOT NULL, + status TEXT, + start_time TEXT, + end_time TEXT, + config JSONB NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW() +); + +-- Index for fast JSONB queries on scenarios.config +CREATE INDEX IF NOT EXISTS idx_scenarios_config_gin + ON scenarios USING GIN (config); + +-- 5) Notes table +CREATE TABLE IF NOT EXISTS notes ( + scenario TEXT PRIMARY KEY, + scenario_notes TEXT +); + +-- 6) Insert the default 'admin' user with a hashed password +-- The hash must be generated by a Python script using passlib. 
+-- Replace the placeholder with your generated hash. +INSERT INTO users ("user", password, role) VALUES ('ADMIN', '$argon2id$v=19$m=65536,t=3,p=4$OobPh8BkZeT6D5s+Rt11mQ$JjI2M3U5+4lupdr87/GrIn46ImzoQujNEyVd7IGYiXY', 'admin') +ON CONFLICT ("user") DO NOTHING; \ No newline at end of file diff --git a/nebula/frontend/app.py b/nebula/frontend/app.py index 830a87c85..0693e367c 100755 --- a/nebula/frontend/app.py +++ b/nebula/frontend/app.py @@ -589,6 +589,7 @@ async def deploy_scenario(scenario_data, role, user): HTTPException: If the underlying HTTP POST request fails. """ url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/run" + logging.info(f"[FER] scenario {scenario_data}") data = {"scenario_data": scenario_data, "role": role, "user": user} return await controller_post(url, data) @@ -1539,6 +1540,7 @@ async def nebula_dashboard(request: Request, session: dict = Depends(get_session scenario_running = None bool_completed = False + logging.info(f"[FER] scenarios {scenarios} scenario_running {scenario_running}") if scenario_running: bool_completed = scenario_running["status"] == "completed" if scenarios: diff --git a/pyproject.toml b/pyproject.toml index d54ff3dcd..7c7e666b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,9 @@ docs = [ "mkdocstrings[python]<1.0.0,>=0.26.2", ] controller = [ + "psycopg2-binary==2.9.10", + "asyncpg==0.30.0", + "passlib==1.7.4", "aiohttp==3.10.5", "aiosqlite==0.20.0", "argon2-cffi==23.1.0", @@ -80,6 +83,10 @@ controller = [ "scikit-image==0.24.0", "scikit-learn==1.5.1", ] +database = [ + "asyncpg==0.30.0", + "psycopg2-binary==2.9.10" +] core = [ "aiohttp==3.10.5", "async-timeout==4.0.3", From 3db0a5b55e55c84ba4da6341ac87f0d816c624fa Mon Sep 17 00:00:00 2001 From: FerTV Date: Mon, 7 Jul 2025 17:12:07 +0200 Subject: [PATCH 02/20] fix postgres db endpoints --- nebula/controller/controller.py | 24 +- nebula/controller/database.py | 423 +++++++++++++++++++++----------- nebula/controller/scenarios.py | 1 
- nebula/frontend/app.py | 2 - 4 files changed, 295 insertions(+), 155 deletions(-) diff --git a/nebula/controller/controller.py b/nebula/controller/controller.py index 4c7605036..a58b74e41 100755 --- a/nebula/controller/controller.py +++ b/nebula/controller/controller.py @@ -17,6 +17,7 @@ from fastapi.concurrency import asynccontextmanager # from nebula.controller.database import scenario_set_all_status_to_finished, scenario_set_status_to_finished +from nebula.controller.database import scenario_set_all_status_to_finished, scenario_set_status_to_finished from nebula.controller.http_helpers import remote_get, remote_post_form from nebula.utils import DockerUtils @@ -315,13 +316,10 @@ async def run_scenario( # Manager for the actual scenario scenarioManagement = ScenarioManagement(scenario_data, user) - logging.info(f"[FER] scenario run_scenario {scenario_data}") - await update_scenario( scenario_name=scenarioManagement.scenario_name, start_time=scenarioManagement.start_date_scenario, end_time="", - scenario=db_scenario, scenario=scenario_data, status="running", role=role, @@ -469,7 +467,6 @@ async def update_scenario( # from nebula.controller.scenarios import Scenario try: - logging.info(f"[FER] scenario controller.py {scenario}") scenario_update_record(scenario_name, start_time, end_time, scenario, status, username) except Exception as e: logging.exception(f"Error updating scenario {scenario_name}: {e}") @@ -597,7 +594,7 @@ async def list_nodes_by_scenario_name( from nebula.controller.database import list_nodes_by_scenario_name try: - nodes = list_nodes_by_scenario_name(scenario_name) + nodes = await list_nodes_by_scenario_name(scenario_name) except Exception as e: logging.exception(f"Error obtaining nodes: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -728,22 +725,21 @@ async def get_notes_by_scenario_name( ): """ Endpoint to retrieve notes associated with a scenario. - - Path Parameters: - - scenario_name: Name of the scenario. 
- - Returns the notes or raises an HTTPException on error. """ from nebula.controller.database import get_notes try: - notes = get_notes(scenario_name) + notes_record = get_notes(scenario_name) + + if notes_record is not None: + notes_record = dict(notes_record.items()) + + return notes_record + except Exception as e: - logging.exception(f"Error obtaining notes {notes}: {e}") + logging.exception(f"Error obtaining notes for scenario {scenario_name}: {e}") raise HTTPException(status_code=500, detail="Internal server error") - return notes - @app.post("/notes/update") async def update_notes_by_scenario_name(scenario_name: str = Body(..., embed=True), notes: str = Body(..., embed=True)): diff --git a/nebula/controller/database.py b/nebula/controller/database.py index 017f84d3d..b3ffd9345 100755 --- a/nebula/controller/database.py +++ b/nebula/controller/database.py @@ -45,7 +45,6 @@ def list_users(all_info=False): result = c.fetchall() if not all_info: - # In PostgreSQL, you can access columns by key from DictCursor result = [user["user"] for user in result] return result @@ -73,7 +72,9 @@ def verify(user, password): if result: try: return pwd_context.verify(password, result[0]) - except: + except Exception: + # Catch more general exceptions during verification to be safe + logging.error(f"Error during password verification for user {user}", exc_info=True) return False return False @@ -140,7 +141,13 @@ def list_nodes(scenario_name=None, sort_by="idx"): try: with get_sync_conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: + # Validate sort_by to prevent SQL injection + allowed_sort_fields = ["uid", "idx", "ip", "port", "role", "timestamp", "federation", "round"] + if sort_by not in allowed_sort_fields: + sort_by = "idx" # Default to a safe field + if scenario_name: + # Using psycopg2.extensions.AsIs for safe insertion of column names command = f"SELECT * FROM nodes WHERE scenario = %s ORDER BY {psycopg2.extensions.AsIs(sort_by)};" 
c.execute(command, (scenario_name,)) else: @@ -150,25 +157,26 @@ def list_nodes(scenario_name=None, sort_by="idx"): result = c.fetchall() return result except psycopg2.Error as e: - print(f"Error occurred while listing nodes: {e}") + logging.error(f"Error occurred while listing nodes: {e}") return None -def list_nodes_by_scenario_name(scenario_name): +async def list_nodes_by_scenario_name(scenario_name): """ Fetches all nodes associated with a specific scenario, ordered by their index as integers. """ + conn = None try: - with get_sync_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: - # Use a specific cast in PostgreSQL to order by integer value of idx - command = "SELECT * FROM nodes WHERE scenario = %s ORDER BY CAST(idx AS INTEGER) ASC;" - c.execute(command, (scenario_name,)) - result = c.fetchall() - return result - except psycopg2.Error as e: - print(f"Error occurred while listing nodes by scenario name: {e}") + conn = await get_async_conn() + command = "SELECT * FROM nodes WHERE scenario = $1 ORDER BY CAST(idx AS INTEGER) ASC;" + result = await conn.fetch(command, scenario_name) + return [dict(record) for record in result] + except Exception as e: + logging.error(f"Error occurred while listing nodes by scenario name: {e}") return None + finally: + if conn: + await conn.close() async def update_node_record( @@ -179,14 +187,10 @@ async def update_node_record( Inserts or updates a node record in the database for a given scenario, ensuring thread-safe access. """ async with _node_lock: - # CORRECTED: Use get_async_conn() directly as the async context manager. - # This assumes get_async_conn() returns a connection object that supports the - # asynchronous context manager protocol (i.e., has __aenter__ and __aexit__). - # A typical implementation of get_async_conn() using asyncpg would be 'return pool.acquire()'. 
- async with get_async_conn() as conn: - # Use a connection-bound cursor + # Await the get_async_conn() call to get the actual connection object + conn = await get_async_conn() + try: async with conn.transaction(): - # Use a SELECT ... FOR UPDATE to lock the row and avoid race conditions result = await conn.fetchrow( "SELECT * FROM nodes WHERE uid = $1 AND scenario = $2 FOR UPDATE;", node_uid, scenario @@ -200,7 +204,7 @@ async def update_node_record( timestamp, federation, round, scenario, hash, malicious) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14); """, - node_uid, idx, ip, port, role, neighbors, latitude, longitude, + node_uid, idx, ip, port, role, json.dumps(neighbors), latitude, longitude, timestamp, federation, federation_round, scenario, run_hash, malicious, ) else: @@ -212,14 +216,16 @@ async def update_node_record( hash = $11, malicious = $12 WHERE uid = $13 AND scenario = $14; """, - idx, ip, port, role, neighbors, latitude, longitude, + idx, ip, port, role, json.dumps(neighbors), latitude, longitude, timestamp, federation, federation_round, run_hash, malicious, node_uid, scenario, ) - - # Fetch the updated or newly inserted row + updated_row = await conn.fetchrow("SELECT * from nodes WHERE uid = $1 AND scenario = $2;", node_uid, scenario) return dict(updated_row) if updated_row else None + finally: + # Ensure the connection is closed after use + await conn.close() def remove_all_nodes(): @@ -228,7 +234,7 @@ def remove_all_nodes(): """ with get_sync_conn() as conn: with conn.cursor() as c: - c.execute("TRUNCATE nodes;") # TRUNCATE is faster than DELETE FROM for clearing tables + c.execute("TRUNCATE nodes CASCADE;") # Use CASCADE if there are foreign key dependencies conn.commit() @@ -245,18 +251,15 @@ def remove_nodes_by_scenario_name(scenario_name): def get_all_scenarios(username, role, sort_by="start_time"): """ - Retrieves all scenarios from the database, accessing the fields - inside the 'config' (JSONB) column. 
- Filters by user role and sorts by the specified field. + Retrieves all scenarios from the database, accessing fields from the 'config' (JSONB) column + and direct columns. Filters by user role and sorts by the specified field. """ - # Safe list of allowed sorting fields to prevent SQL injection allowed_sort_fields = ["start_time", "title", "username", "status", "name"] if sort_by not in allowed_sort_fields: - sort_by = "start_time" # Use a safe default value + sort_by = "start_time" - # Building the ORDER BY clause + # Determine the ORDER BY clause based on sort_by if sort_by == "start_time": - # Special sorting for dates saved as text, handling nulls/empty strings order_by_clause = """ ORDER BY CASE @@ -265,26 +268,36 @@ def get_all_scenarios(username, role, sort_by="start_time"): END, to_timestamp(start_time, 'YYYY/MM/DD HH24:MI:SS') DESC """ - else: - # Sorting is built dynamically but safely, - # since 'sort_by' has been validated against the allowed list. + elif sort_by in ["title", "model", "dataset", "rounds"]: # These are inside config JSONB order_by_clause = f"ORDER BY config->>'{sort_by}'" + else: # For direct table columns like name, username, status + order_by_clause = f"ORDER BY {sort_by}" + with get_sync_conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: - # Base query. It's flexible to return 'name' and the full 'config' object. - # The code that calls this function can access any field from the config object. 
- command = "SELECT name, username, start_time, end_time, status, config FROM scenarios" + # Select direct columns and relevant fields from config JSONB + command = """ + SELECT + name, + username, + status, + start_time, + end_time, + config->>'title' AS title, + config->>'model' AS model, + config->>'dataset' AS dataset, + config->>'rounds' AS rounds, + config -- return the full config JSONB + FROM scenarios + """ params = [] - # Conditionally add the WHERE filter if the role is not admin if role != "admin": - command += " WHERE config->>'username' = %s" + command += " WHERE username = %s" # username is a direct column now params.append(username) - # Combine the query with the sorting clause full_command = f"{command} {order_by_clause};" - c.execute(full_command, tuple(params)) result = c.fetchall() @@ -293,7 +306,7 @@ def get_all_scenarios(username, role, sort_by="start_time"): def get_all_scenarios_and_check_completed(username, role, sort_by="start_time"): """ - Retrieves all scenarios from the JSONB field, sorts them, and updates the status if necessary. + Retrieves all scenarios, sorts them, and updates the status if necessary. Returns a list of dictionaries, where each dictionary represents a scenario. """ # Safe list of allowed sorting fields to prevent SQL injection. 
@@ -301,7 +314,7 @@ def get_all_scenarios_and_check_completed(username, role, sort_by="start_time"): if sort_by not in allowed_sort_fields: sort_by = "start_time" # Safe default value - # Building the ORDER BY clause + # Building the ORDER BY clause (same as get_all_scenarios) if sort_by == "start_time": order_by_clause = """ ORDER BY @@ -309,11 +322,13 @@ def get_all_scenarios_and_check_completed(username, role, sort_by="start_time"): WHEN start_time IS NULL OR start_time = '' THEN 1 ELSE 0 END, - to_timestamp(start_time, 'YYYY/MM/DD HH24:MI:SS') DESC + -- CORRECTED: Changed 'DD/MM/YYYY' to 'YYYY/MM/DD' to match the storage format + to_timestamp(start_time, 'DD/MM/YYYY HH24:MI:SS') DESC """ - else: - # We use a safe f-string because the value of sort_by has been validated + elif sort_by in ["title", "model", "dataset", "rounds"]: # These are inside config JSONB order_by_clause = f"ORDER BY config->>'{sort_by}'" + else: # For direct table columns like name, username, status + order_by_clause = f"ORDER BY {sort_by}" with get_sync_conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: @@ -334,18 +349,14 @@ def get_all_scenarios_and_check_completed(username, role, sort_by="start_time"): """ params = [] if role != "admin": - command += " WHERE config->>'username' = %s" + command += " WHERE username = %s" # username is a direct column params.append(username) command += f" {order_by_clause};" c.execute(command, tuple(params)) - result_dicts = c.fetchall() # This already returns a list of DictRow objects (which act like dicts) - - # Logic to check for completed scenarios and update status. - # It's important to modify the `result_dicts` directly or handle the recursion carefully. 
+ result_dicts = c.fetchall() - # Create a mutable list from DictRow objects for potential status updates scenarios_to_return = [dict(s) for s in result_dicts] re_fetch_required = False @@ -353,86 +364,103 @@ def get_all_scenarios_and_check_completed(username, role, sort_by="start_time"): if scenario["status"] == "running": if check_scenario_federation_completed(scenario["name"]): scenario_set_status_to_completed(scenario["name"]) - # If a scenario's status changes, it's best to re-query the database - # to ensure the most up-to-date information is returned for ALL scenarios. - # This avoids inconsistencies if multiple scenarios complete in a single call. re_fetch_required = True - break # Break after finding one completed scenario to trigger re-fetch + break if re_fetch_required: - # If any status was updated, recursively call the function again to get the fresh data. - # This ensures the returned list reflects the updated status from the DB. - # Make sure `get_all_scenarios` is indeed this function if you rename it. - return get_all_scenarios(username, role, sort_by) + # Recursively call get_all_scenarios_and_check_completed to get fresh data + return get_all_scenarios_and_check_completed(username, role, sort_by) return scenarios_to_return -def scenario_update_record(name, start_time, end_time, scenario, status, username): +def scenario_update_record(name, start_time, end_time, scenario_config, status, username): """ Inserts or updates a scenario record using the PostgreSQL "UPSERT" pattern. All configuration is saved in the 'config' column of type JSONB. + Direct columns (name, start_time, end_time, username, status) are also handled. 
""" with get_sync_conn() as conn: with conn.cursor() as c: + # Ensure scenario_config is a dictionary before dumping to JSON + if not isinstance(scenario_config, dict): + try: + scenario_config = json.loads(scenario_config) + except json.JSONDecodeError: + logging.error("scenario_config is not a valid JSON string or dict.") + return + command = """ INSERT INTO scenarios (name, start_time, end_time, username, status, config) - VALUES (%s, %s, %s, %s, %s, %s) + VALUES (%s, %s, %s, %s, %s, %s::jsonb) ON CONFLICT (name) DO UPDATE SET - config = scenarios.config || excluded.config; + start_time = EXCLUDED.start_time, + end_time = EXCLUDED.end_time, + username = EXCLUDED.username, + status = EXCLUDED.status, + config = scenarios.config || EXCLUDED.config; -- Merge JSONB """ - logging.info(f"[FER] scenario database.py {json.dumps(scenario, indent=2)}") - c.execute(command, (name, start_time, end_time, username, status, json.dumps(scenario, indent=2))) + c.execute(command, (name, start_time, end_time, username, status, json.dumps(scenario_config))) conn.commit() def scenario_set_all_status_to_finished(): """ Sets the status of all 'running' scenarios to 'finished' - and updates their 'end_time' within the JSONB. + and updates their 'end_time' (both in the direct column and within JSONB). """ with get_sync_conn() as conn: with conn.cursor() as c: - current_time = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S') - # We use jsonb_set to update specific fields within the JSONB. - # We nest the calls to update multiple fields. 
+ current_time = datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S') # Consistent format + # Update direct columns first, then update JSONB within config command = """ UPDATE scenarios - SET status = 'finished', end_time = %s + SET + status = 'finished', + end_time = %s, + config = jsonb_set(config, '{status}', '"finished"') || + jsonb_set(config, '{end_time}', %s::jsonb) WHERE status = 'running'; """ - c.execute(command, (json.dumps(current_time),)) + c.execute(command, (current_time, json.dumps(current_time))) conn.commit() def scenario_set_status_to_finished(scenario_name): """ Sets the status of a specific scenario to 'finished' and updates its 'end_time'. + Updates both the direct columns and the JSONB 'config'. """ with get_sync_conn() as conn: with conn.cursor() as c: - current_time = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S') + current_time = datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S') # Consistent format command = """ UPDATE scenarios - SET config = jsonb_set( + SET + status = 'finished', + end_time = %s, + config = jsonb_set( jsonb_set(config, '{status}', '"finished"'), '{end_time}', %s::jsonb ) WHERE name = %s; """ - c.execute(command, (json.dumps(current_time), scenario_name)) + c.execute(command, (current_time, json.dumps(current_time), scenario_name)) conn.commit() def scenario_set_status_to_completed(scenario_name): """ Sets the status of a specific scenario to 'completed'. + Updates both the direct column and the JSONB 'config'. """ with get_sync_conn() as conn: with conn.cursor() as c: command = """ UPDATE scenarios - SET status = "completed" + SET + status = 'completed', + config = jsonb_set(config, '{status}', '"completed"') WHERE name = %s; """ c.execute(command, (scenario_name,)) @@ -442,44 +470,40 @@ def scenario_set_status_to_completed(scenario_name): def get_running_scenario(username=None, get_all=False): """ Retrieves scenarios with a 'running' status, optionally filtered by user. 
+ Returns full scenario record (including direct columns and config JSONB). """ with get_sync_conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: params = ["running"] - command = "SELECT name, config FROM scenarios WHERE config->>'status' = %s" + # Select all columns to get both direct and config data + command = "SELECT name, username, status, start_time, end_time, config FROM scenarios WHERE status = %s" if username: - command += " AND config->>'username' = %s" + command += " AND username = %s" params.append(username) c.execute(command, tuple(params)) if get_all: - raw_results = c.fetchall() - if raw_results: - processed_results = [] - for row in raw_results: - processed_results.append({ - 'name': row['name'], - 'config': row['config'] - }) - result = processed_results + result = [dict(row) for row in c.fetchall()] # Convert DictRows to dicts else: - result = c.fetchone() - if result: - result = result['config'] + result_row = c.fetchone() + result = dict(result_row) if result_row else None return result def get_completed_scenario(): """ Retrieves a single scenario with a 'completed' status. + Returns full scenario record (including direct columns and config JSONB). 
""" with get_sync_conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: - command = "SELECT name, config FROM scenarios WHERE config->>'status' = %s;" + # The status is now a direct column, not just in config->>'status' + command = "SELECT name, username, status, start_time, end_time, config FROM scenarios WHERE status = %s;" c.execute(command, ("completed",)) - result = c.fetchone() + result_row = c.fetchone() + result = dict(result_row) if result_row else None return result @@ -490,13 +514,26 @@ def get_scenario_by_name(scenario_name): with get_sync_conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: c.execute("SELECT name, start_time, end_time, username, status, config FROM scenarios WHERE name = %s;", (scenario_name,)) - result = c.fetchone() + result_row = c.fetchone() + result = dict(result_row) if result_row else None + + if result and result.get('config'): + # Assuming 'config' is already parsed into a Python dictionary by DictCursor + # If it's still a string, you might need: config_data = json.loads(result['config']) + config_data = result['config'] + + # Extract the 'scenario_title' and add it as a top-level key + # Use .get() for safety in case 'scenario_title' is also missing within config + result['title'] = config_data.get('scenario_title') + + # Also, if 'description' is inside config, you'll need to extract it similarly + result['description'] = config_data.get('description') # Assuming 'description' is also in config return result def get_user_by_scenario_name(scenario_name): """ - Retrieves the username associated with a scenario from the JSONB field. + Retrieves the username associated with a scenario (from the direct 'username' column). 
""" with get_sync_conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: @@ -504,58 +541,129 @@ def get_user_by_scenario_name(scenario_name): result = c.fetchone() return result["username"] if result else None -# Placeholder for `check_scenario_federation_completed`. -# You need to implement this based on your application logic. -def check_scenario_federation_completed(scenario_name): + +def remove_scenario_by_name(scenario_name): """ - Placeholder function to check if a scenario's federation is completed. - This should be implemented based on your specific application logic. - For example, it could check if the last round has been reached. + Delete a scenario from the database by its unique name. + + Parameters: + scenario_name (str): The unique name identifier of the scenario to be removed. + + Behavior: + - Removes the scenario record matching the given name. + - Commits the deletion to the database. """ - print(f"Checking if scenario '{scenario_name}' is completed...") - # Example logic: - # return get_current_round(scenario_name) >= get_total_rounds(scenario_name) - return False # Placeholder value + try: + with get_sync_conn() as conn: + with conn.cursor() as c: + c.execute("DELETE FROM scenarios WHERE name = %s;", (scenario_name,)) + conn.commit() + logging.info(f"Scenario '{scenario_name}' successfully removed.") + except psycopg2.Error as e: + logging.error(f"Error occurred while deleting scenario '{scenario_name}': {e}") -def check_scenario_with_role(role, scenario_name): + +def check_scenario_federation_completed(scenario_name): """ - Verify if a scenario exists with a specific role and name. + Check if all nodes in a given scenario have completed the required federation rounds. Parameters: - role (str): The role associated with the scenario (e.g., "admin", "user"). - scenario_name (str): The unique name identifier of the scenario. + scenario_name (str): The unique name identifier of the scenario to check. 
Returns: - bool: True if a scenario with the given role and name exists, False otherwise. + bool: True if all nodes have completed the total rounds specified for the scenario, False otherwise or if an error occurs. + + Behavior: + - Retrieves the total number of rounds defined for the scenario. + - Fetches the current round progress of all nodes in that scenario. + - Returns True only if every node has reached the total rounds. + - Handles database errors and missing scenario cases gracefully. """ - with get_sync_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: - # Use %s placeholders for query parameters - c.execute( - "SELECT 1 FROM scenarios WHERE role = %s AND name = %s;", - (role, scenario_name), - ) - result = c.fetchone() + try: + with get_sync_conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: + # Retrieve the total rounds for the scenario from the 'config' JSONB column + c.execute("SELECT config->>'rounds' AS rounds FROM scenarios WHERE name = %s;", (scenario_name,)) + scenario = c.fetchone() - return result is not None + if not scenario or scenario["rounds"] is None: + logging.warning(f"Scenario '{scenario_name}' not found or 'rounds' not defined.") + return False -def save_notes(scenario, notes): + # Ensure total_rounds is an integer for comparison + try: + total_rounds = int(scenario["rounds"]) + except (ValueError, TypeError): + logging.error(f"Invalid 'rounds' value for scenario '{scenario_name}': {scenario['rounds']}") + return False + + # Fetch the current round progress of all nodes in that scenario + # The 'round' column in 'nodes' is a direct column + c.execute("SELECT round FROM nodes WHERE scenario = %s;", (scenario_name,)) + nodes = c.fetchall() + + if not nodes: + logging.info(f"No nodes found for scenario '{scenario_name}'. 
Federation not considered completed.") + return False + + # Check if all nodes have completed the total rounds + # The 'round' column in nodes is likely stored as a string or a numeric type. + # Assuming 'round' in 'nodes' is a numeric type, we convert it to int for comparison. + return all(int(node["round"]) >= total_rounds for node in nodes) + + except psycopg2.Error as e: + logging.error(f"PostgreSQL error during check_scenario_federation_completed for scenario '{scenario_name}': {e}") + return False + except ValueError as e: + logging.error(f"Data error during check_scenario_federation_completed for scenario '{scenario_name}': {e}") + return False + + +def check_scenario_with_role(role, scenario_name, current_username=None): """ - Save or update notes associated with a specific scenario. + Verify if a scenario exists that the user with the given role and username can access. Parameters: - scenario (str): The unique identifier of the scenario. - notes (str): The textual notes to be saved for the scenario. + role (str): The role of the current user (e.g., "admin", "user"). + scenario_name (str): The unique name identifier of the scenario to check. + current_username (str, optional): The username of the currently authenticated user. + Required for non-admin roles. + + Returns: + bool: True if the scenario exists and the user has access, False otherwise. Behavior: - - Inserts new notes if the scenario does not exist in the database. - - Updates existing notes if the scenario already has notes saved. - - Handles database errors gracefully. + - If the user's role is "admin", they can access any existing scenario. + - If the user's role is not "admin", they can only access scenarios where the + scenario's 'username' matches their `current_username`. 
+ """ + scenario_info = get_scenario_by_name(scenario_name) + + if not scenario_info: + return False # Scenario does not exist + + if role == "admin": + return True # Admins can access any existing scenario + else: + # For non-admin roles, check if the scenario's username matches the current user's username + if current_username is None: + logging.warning( + "`check_scenario_with_role` called for non-admin role without `current_username`. " + "Cannot verify user-specific scenario access." + ) + return False # Cannot verify access without the current user's username + + return scenario_info.get("username") == current_username + +# --- Notes Management Functions --- + +def save_notes(scenario, notes): + """ + Save or update notes associated with a specific scenario. """ try: with get_sync_conn() as conn: with conn.cursor() as c: - # Use INSERT ... ON CONFLICT (UPSERT) c.execute( """ INSERT INTO notes (scenario, scenario_notes) VALUES (%s, %s) @@ -565,19 +673,14 @@ def save_notes(scenario, notes): ) conn.commit() except psycopg2.IntegrityError as e: - print(f"PostgreSQL integrity error: {e}") + logging.error(f"PostgreSQL integrity error during save_notes: {e}") except psycopg2.Error as e: - print(f"PostgreSQL error: {e}") + logging.error(f"PostgreSQL error during save_notes: {e}") + def get_notes(scenario): """ Retrieve notes associated with a specific scenario. - - Parameters: - scenario (str): The unique identifier of the scenario. - - Returns: - psycopg2.extras.DictRow or None: The notes record for the given scenario, or None if no notes exist. """ with get_sync_conn() as conn: with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: @@ -585,20 +688,64 @@ def get_notes(scenario): result = c.fetchone() return result + def remove_note(scenario): """ Delete the note associated with a specific scenario. - - Parameters: - scenario (str): The unique identifier of the scenario whose note should be removed. 
""" with get_sync_conn() as conn: with conn.cursor() as c: c.execute("DELETE FROM notes WHERE scenario = %s;", (scenario,)) conn.commit() + if __name__ == "__main__": """ Entry point for the script to print the list of users. """ - print(list_users()) \ No newline at end of file + # Example usage (assuming DB_USER, DB_PASSWORD, DB_HOST, DB_PORT are set in env) + # os.environ['DB_USER'] = 'your_db_user' + # os.environ['DB_PASSWORD'] = 'your_db_password' + # os.environ['DB_HOST'] = 'localhost' + # os.environ['DB_PORT'] = '5432' + + logging.basicConfig(level=logging.INFO) + + print("Listing users:") + users = list_users(all_info=True) + for user in users: + print(f"- User: {user['user']}, Role: {user['role']}") + + # Example of adding/updating a user + # print("\nAdding/Updating test user:") + # add_user("TESTUSER", "testpassword123", "user") + # print(get_user_info("TESTUSER")) + + # Example of scenario operations + # print("\nScenario operations:") + # current_time_str = datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S') + # scenario_data = { + # "title": "My Test Scenario", + # "model": "NN", + # "dataset": "MNIST", + # "rounds": "10", + # "description": "A test scenario for demonstration." 
+ # } + # scenario_update_record("test_scenario_1", current_time_str, "", scenario_data, "running", "ADMIN") + # print("Running scenarios:") + # print(get_running_scenario(username="ADMIN", get_all=True)) + + # print("\nAll scenarios:") + # all_scenarios = get_all_scenarios_and_check_completed("ADMIN", "admin", sort_by="start_time") + # for s in all_scenarios: + # print(f"Scenario: {s['name']}, Status: {s['status']}, Title: {s.get('title')}") + + # print("\nSetting a scenario to finished:") + # scenario_set_status_to_finished("test_scenario_1") + # print(get_scenario_by_name("test_scenario_1")) + + # print("\nTesting notes:") + # save_notes("test_scenario_1", "These are some notes for test scenario 1.") + # print(get_notes("test_scenario_1")) + # remove_note("test_scenario_1") + # print(get_notes("test_scenario_1")) \ No newline at end of file diff --git a/nebula/controller/scenarios.py b/nebula/controller/scenarios.py index 919f4834c..13fc35e7f 100644 --- a/nebula/controller/scenarios.py +++ b/nebula/controller/scenarios.py @@ -592,7 +592,6 @@ class ScenarioManagement: def __init__(self, scenario, user=None): # Current scenario self.scenario = Scenario.from_dict(scenario) - logging.info(f"[FER] scenario from scenarios.py {Scenario.to_json(self.scenario)}") # Uid of the user self.user = user # Scenario management settings diff --git a/nebula/frontend/app.py b/nebula/frontend/app.py index 0693e367c..830a87c85 100755 --- a/nebula/frontend/app.py +++ b/nebula/frontend/app.py @@ -589,7 +589,6 @@ async def deploy_scenario(scenario_data, role, user): HTTPException: If the underlying HTTP POST request fails. 
""" url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/run" - logging.info(f"[FER] scenario {scenario_data}") data = {"scenario_data": scenario_data, "role": role, "user": user} return await controller_post(url, data) @@ -1540,7 +1539,6 @@ async def nebula_dashboard(request: Request, session: dict = Depends(get_session scenario_running = None bool_completed = False - logging.info(f"[FER] scenarios {scenarios} scenario_running {scenario_running}") if scenario_running: bool_completed = scenario_running["status"] == "completed" if scenarios: From 7da642ae889a2be041e23bcabffaa8f0a93fe8a1 Mon Sep 17 00:00:00 2001 From: FerTV Date: Mon, 7 Jul 2025 17:12:59 +0200 Subject: [PATCH 03/20] redis docker added --- app/deployer.py | 67 +++++++++++++++++++++++ nebula/controller/database.py | 34 ++++++------ nebula/controller/scenarios.py | 2 +- nebula/database/Dockerfile | 7 +-- nebula/database/docker-entrypoint.sh | 13 ++--- nebula/database/redis/Dockerfile | 1 + nebula/database/rediscommander/Dockerfile | 1 + 7 files changed, 93 insertions(+), 32 deletions(-) create mode 100644 nebula/database/redis/Dockerfile create mode 100644 nebula/database/rediscommander/Dockerfile diff --git a/app/deployer.py b/app/deployer.py index c3696f413..ef2ce4fe0 100644 --- a/app/deployer.py +++ b/app/deployer.py @@ -1006,6 +1006,70 @@ def run_database(self): client.api.start(pgweb_container_id) + ######### + # REDIS # + ######### + + # network_name = "redis-network" + # base = DockerUtils.create_docker_network(network_name) + + # --- REDIS --- + + host_port_redis = 6379 # You can change this if you want a different external port + + host_config_redis = client.api.create_host_config( + binds=[ + f"redis:/var/lib/redis", + ], + extra_hosts={"host.docker.internal": "host-gateway"}, + port_bindings={6379: host_port_redis}, + ) + + redis_networking_config = client.api.create_networking_config({ + f"{network_name}": 
client.api.create_endpoint_config(ipv4_address=f"{base}.126") + }) + + redis_container = client.api.create_container( + image="nebula-redis", + name=f"{os.environ['USER']}_redis", + detach=True, + command="redis-server", + host_config=host_config_redis, + networking_config=redis_networking_config, + ) + + client.api.start(redis_container) + + # --- REDIS COMMANDER --- + + host_port_commander = 8081 + + environment_commander = { + "REDIS_HOSTS": "local:redis:6379", + "HTTP_USER": "root", + "HTTP_PASSWORD": "root", + } + + host_config_commander = client.api.create_host_config( + extra_hosts={"host.docker.internal": "host-gateway"}, + port_bindings={8081: host_port_commander}, + ) + + commander_networking_config = client.api.create_networking_config({ + f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.127") + }) + + commander_container = client.api.create_container( + image="nebula-rediscommander", + name=f"{os.environ['USER']}_redis_commander", + detach=True, + environment=environment_commander, + host_config=host_config_commander, + networking_config=commander_networking_config, + ) + + client.api.start(commander_container) + def stop_database(self): """ Stops and removes all NEBULA database Docker containers associated with the current user. @@ -1018,6 +1082,9 @@ def stop_database(self): - Cleaning up database containers during shutdown or redeployment processes. 
""" DockerUtils.remove_containers_by_prefix(f"{os.environ['USER']}_nebula-database") + DockerUtils.remove_containers_by_prefix(f"{os.environ['USER']}_nebula-pgweb") + DockerUtils.remove_containers_by_prefix(f"{os.environ['USER']}_nebula-redis") + DockerUtils.remove_containers_by_prefix(f"{os.environ['USER']}_nebula-rediscommander") def run_controller(self): if sys.platform == "win32": diff --git a/nebula/controller/database.py b/nebula/controller/database.py index b3ffd9345..54279800f 100755 --- a/nebula/controller/database.py +++ b/nebula/controller/database.py @@ -188,7 +188,7 @@ async def update_node_record( """ async with _node_lock: # Await the get_async_conn() call to get the actual connection object - conn = await get_async_conn() + conn = await get_async_conn() try: async with conn.transaction(): result = await conn.fetchrow( @@ -300,7 +300,7 @@ def get_all_scenarios(username, role, sort_by="start_time"): full_command = f"{command} {order_by_clause};" c.execute(full_command, tuple(params)) result = c.fetchall() - + return result @@ -353,10 +353,10 @@ def get_all_scenarios_and_check_completed(username, role, sort_by="start_time"): params.append(username) command += f" {order_by_clause};" - + c.execute(command, tuple(params)) - result_dicts = c.fetchall() - + result_dicts = c.fetchall() + scenarios_to_return = [dict(s) for s in result_dicts] re_fetch_required = False @@ -369,8 +369,8 @@ def get_all_scenarios_and_check_completed(username, role, sort_by="start_time"): if re_fetch_required: # Recursively call get_all_scenarios_and_check_completed to get fresh data - return get_all_scenarios_and_check_completed(username, role, sort_by) - + return get_all_scenarios_and_check_completed(username, role, sort_by) + return scenarios_to_return @@ -477,14 +477,14 @@ def get_running_scenario(username=None, get_all=False): params = ["running"] # Select all columns to get both direct and config data command = "SELECT name, username, status, start_time, end_time, config FROM 
scenarios WHERE status = %s" - + if username: command += " AND username = %s" params.append(username) - + c.execute(command, tuple(params)) - - if get_all: + + if get_all: result = [dict(row) for row in c.fetchall()] # Convert DictRows to dicts else: result_row = c.fetchone() @@ -516,16 +516,16 @@ def get_scenario_by_name(scenario_name): c.execute("SELECT name, start_time, end_time, username, status, config FROM scenarios WHERE name = %s;", (scenario_name,)) result_row = c.fetchone() result = dict(result_row) if result_row else None - + if result and result.get('config'): # Assuming 'config' is already parsed into a Python dictionary by DictCursor # If it's still a string, you might need: config_data = json.loads(result['config']) config_data = result['config'] - + # Extract the 'scenario_title' and add it as a top-level key # Use .get() for safety in case 'scenario_title' is also missing within config - result['title'] = config_data.get('scenario_title') - + result['title'] = config_data.get('scenario_title') + # Also, if 'description' is inside config, you'll need to extract it similarly result['description'] = config_data.get('description') # Assuming 'description' is also in config return result @@ -708,7 +708,7 @@ def remove_note(scenario): # os.environ['DB_PASSWORD'] = 'your_db_password' # os.environ['DB_HOST'] = 'localhost' # os.environ['DB_PORT'] = '5432' - + logging.basicConfig(level=logging.INFO) print("Listing users:") @@ -748,4 +748,4 @@ def remove_note(scenario): # save_notes("test_scenario_1", "These are some notes for test scenario 1.") # print(get_notes("test_scenario_1")) # remove_note("test_scenario_1") - # print(get_notes("test_scenario_1")) \ No newline at end of file + # print(get_notes("test_scenario_1")) diff --git a/nebula/controller/scenarios.py b/nebula/controller/scenarios.py index 13fc35e7f..0d8f2286e 100644 --- a/nebula/controller/scenarios.py +++ b/nebula/controller/scenarios.py @@ -548,7 +548,7 @@ def from_dict(cls, data): scenario = 
cls(**scenario_data) return scenario - + @staticmethod def to_json(scenario_obj): """ diff --git a/nebula/database/Dockerfile b/nebula/database/Dockerfile index ba3ce607d..04ed953e9 100644 --- a/nebula/database/Dockerfile +++ b/nebula/database/Dockerfile @@ -3,9 +3,6 @@ FROM postgres:17.5-alpine3.22 # Rename the official entrypoint so we can wrap it RUN mv /usr/local/bin/docker-entrypoint.sh /usr/local/bin/docker-entrypoint.sh.orig -ENV POSTGRES_DB=configdb -ENV POSTGRES_USER=appuser - # Copy SQL init file and custom entrypoint script COPY /nebula/database/docker-entrypoint.sh /usr/local/bin/docker-entrypoint.sh RUN chmod +x /usr/local/bin/docker-entrypoint.sh @@ -51,5 +48,5 @@ RUN uv sync --group database ENV PATH=".venv/bin:$PATH" -ENTRYPOINT ["/usr/local/bin/docker-entrypoint.sh"] -CMD ["postgres"] +# ENTRYPOINT ["/usr/local/bin/docker-entrypoint.sh"] +# CMD ["postgres"] diff --git a/nebula/database/docker-entrypoint.sh b/nebula/database/docker-entrypoint.sh index e6bbde73c..a298ca23b 100644 --- a/nebula/database/docker-entrypoint.sh +++ b/nebula/database/docker-entrypoint.sh @@ -1,23 +1,18 @@ #!/bin/sh set -e -# 1) Launch the original entrypoint in the background -exec /usr/local/bin/docker-entrypoint.sh.orig "$@" & +# 1) Run the original entrypoint and wait for it to finish initialization +/usr/local/bin/docker-entrypoint.sh.orig "$@" -pid="$!" - -# 2) Wait until PostgreSQL is ready to accept connections +# 2) Wait until PostgreSQL accepts connections to the configured database echo "⏳ Waiting for PostgreSQL to be ready..." until pg_isready -U "$POSTGRES_USER" -d "$POSTGRES_DB" >/dev/null 2>&1; do sleep 1 done -# 3) Always apply our init SQL +# 3) Apply the SQL initialization script echo "🚀 Applying init-configs.sql..." 
psql -v ON_ERROR_STOP=1 \ -U "$POSTGRES_USER" \ -d "$POSTGRES_DB" \ -f /docker-entrypoint-initdb.d/init-configs.sql - -# 4) Wait on the main Postgres process -wait "$pid" \ No newline at end of file diff --git a/nebula/database/redis/Dockerfile b/nebula/database/redis/Dockerfile new file mode 100644 index 000000000..31ff6af02 --- /dev/null +++ b/nebula/database/redis/Dockerfile @@ -0,0 +1 @@ +FROM redis:latest \ No newline at end of file diff --git a/nebula/database/rediscommander/Dockerfile b/nebula/database/rediscommander/Dockerfile new file mode 100644 index 000000000..0eb1e1ead --- /dev/null +++ b/nebula/database/rediscommander/Dockerfile @@ -0,0 +1 @@ +FROM rediscommander/redis-commander:latest \ No newline at end of file From 13ef3d62bb873d54dc8f56c9ffb67a30d690e43a Mon Sep 17 00:00:00 2001 From: FerTV Date: Wed, 9 Jul 2025 12:43:42 +0200 Subject: [PATCH 04/20] fix(frontend): connections between nodes in monitor fix(backend): saving connections between nodes in the database --- nebula/addons/reputation/reputation.py | 352 ++++++++++++------------- nebula/addons/topologymanager.py | 12 +- nebula/config/config.py | 44 ++-- nebula/controller/controller.py | 2 +- nebula/controller/database.py | 43 +-- nebula/core/engine.py | 34 +-- nebula/database/init-configs.sql | 6 +- nebula/frontend/app.py | 5 +- 8 files changed, 232 insertions(+), 266 deletions(-) diff --git a/nebula/addons/reputation/reputation.py b/nebula/addons/reputation/reputation.py index 561199513..bca072cea 100644 --- a/nebula/addons/reputation/reputation.py +++ b/nebula/addons/reputation/reputation.py @@ -61,7 +61,7 @@ class Reputation: The class handles collection of metrics, calculation of static and dynamic reputation, updating history, and communication of reputation scores to neighbors. 
""" - + REPUTATION_THRESHOLD = 0.6 SIMILARITY_THRESHOLD = 0.6 INITIAL_ROUND_FOR_REPUTATION = 1 @@ -70,12 +70,12 @@ class Reputation: WEIGHTED_HISTORY_ROUNDS = 3 FRACTION_ANOMALY_MULTIPLIER = 1.20 THRESHOLD_ANOMALY_MULTIPLIER = 1.15 - + # Augmentation factors LATENCY_AUGMENT_FACTOR = 1.4 MESSAGE_AUGMENT_FACTOR_EARLY = 2.0 MESSAGE_AUGMENT_FACTOR_NORMAL = 1.1 - + # Penalty and decay factors HISTORICAL_PENALTY_THRESHOLD = 0.9 NEGATIVE_LATENCY_PENALTY = 0.3 @@ -104,7 +104,7 @@ def __init__(self, engine: "Engine", config: "Config"): self._addr = engine.addr self._log_dir = engine.log_dir self._idx = engine.idx - + self._initialize_data_structures() self._configure_constants() self._load_configuration() @@ -116,7 +116,7 @@ def _configure_constants(self): """Configure system constants from config or use defaults.""" reputation_config = self._config.participant.get("defense_args", {}).get("reputation", {}) constants_config = reputation_config.get("constants", {}) - + self.REPUTATION_THRESHOLD = constants_config.get("reputation_threshold", self.REPUTATION_THRESHOLD) self.SIMILARITY_THRESHOLD = constants_config.get("similarity_threshold", self.SIMILARITY_THRESHOLD) self.INITIAL_ROUND_FOR_REPUTATION = constants_config.get("initial_round_for_reputation", self.INITIAL_ROUND_FOR_REPUTATION) @@ -180,15 +180,15 @@ def _load_configuration(self): def _setup_connection_metrics(self): """Initialize metrics for each neighbor.""" - neighbors_str = self._config.participant["network_args"]["neighbors"] - for neighbor in neighbors_str.split(): + neighbors = self._config.participant["network_args"]["neighbors"] + for neighbor in neighbors: self.connection_metrics[neighbor] = Metrics() def _configure_metric_weights(self): """Configure weights for different metrics based on weighting factor.""" default_weight = 0.25 metric_names = ["model_arrival_latency", "model_similarity", "num_messages", "fraction_parameters_changed"] - + if self._weighting_factor == "static": 
self._weight_model_arrival_latency = float( self._metrics.get("model_arrival_latency", {}).get("weight", default_weight) @@ -209,7 +209,7 @@ def _configure_metric_weights(self): elif not isinstance(self._metrics[metric_name], dict): self._metrics[metric_name] = {"enabled": bool(self._metrics[metric_name])} self._metrics[metric_name]["weight"] = default_weight - + self._weight_model_arrival_latency = default_weight self._weight_model_similarity = default_weight self._weight_num_messages = default_weight @@ -229,24 +229,24 @@ def engine(self): def _is_metric_enabled(self, metric_name: str, metrics_config: dict = None) -> bool: """ Check if a specific metric is enabled based on the provided configuration. - + Args: metric_name (str): The name of the metric to check. - metrics_config (dict, optional): The configuration dictionary for metrics. + metrics_config (dict, optional): The configuration dictionary for metrics. If None, uses the instance's _metrics. - + Returns: bool: True if the metric is enabled, False otherwise. """ config_to_use = metrics_config if metrics_config is not None else getattr(self, '_metrics', None) - + if not isinstance(config_to_use, dict): if metrics_config is not None: logging.warning(f"metrics_config is not a dictionary: {type(metrics_config)}") else: logging.warning("_metrics is not properly initialized") return False - + metric_config = config_to_use.get(metric_name) if metric_config is None: return False @@ -269,7 +269,7 @@ def save_data( ): """ Save data between nodes and aggregated models. 
- + Args: type_data: Type of data to save ('number_message', 'fraction_of_params_changed', 'model_arrival_latency') nei: Neighbor identifier @@ -290,7 +290,7 @@ def save_data( try: metrics_instance = self.connection_metrics[nei] - + if type_data == "number_message": message_data = {"time": time, "current_round": current_round} if not isinstance(metrics_instance.messages, list): @@ -345,19 +345,19 @@ async def init_reputation( ): """ Initialize the reputation system. - + Args: federation_nodes: List of federation node identifiers - round_num: Current round number + round_num: Current round number last_feedback_round: Last round that received feedback init_reputation: Initial reputation value to assign """ if not self._enabled: return - + if not self._validate_init_parameters(federation_nodes, round_num, init_reputation): return - + neighbors = self._validate_federation_nodes(federation_nodes) if not neighbors: logging.error("init_reputation | No valid neighbors found") @@ -370,13 +370,13 @@ def _validate_init_parameters(self, federation_nodes, round_num, init_reputation if not federation_nodes: logging.error("init_reputation | No federation nodes provided") return False - + if round_num is None: logging.warning("init_reputation | Round number not provided") - + if init_reputation is None: logging.warning("init_reputation | Initial reputation value not provided") - + return True async def _initialize_neighbor_reputations(self, neighbors: list, round_num: int, last_feedback_round: int, init_reputation: float): @@ -392,7 +392,7 @@ def _create_or_update_reputation_entry(self, nei: str, round_num: int, last_feed "round": round_num, "last_feedback_round": last_feedback_round, } - + if nei not in self.reputation: self.reputation[nei] = reputation_data elif self.reputation[nei].get("reputation") is None: @@ -401,21 +401,21 @@ def _create_or_update_reputation_entry(self, nei: str, round_num: int, last_feed def _validate_federation_nodes(self, federation_nodes) -> list: """ 
Validate and filter federation nodes. - + Args: federation_nodes: List of federation node identifiers - + Returns: list: List of valid node identifiers """ if not federation_nodes: return [] - + valid_nodes = [node for node in federation_nodes if node and str(node).strip()] - + if not valid_nodes: logging.warning("No valid federation nodes found after filtering") - + return valid_nodes async def _calculate_static_reputation( @@ -429,7 +429,7 @@ async def _calculate_static_reputation( Args: addr: The participant's address - nei: The neighbor's address + nei: The neighbor's address metric_values: Dictionary with metric values """ static_weights = { @@ -440,10 +440,10 @@ async def _calculate_static_reputation( } reputation_static = sum( - metric_values.get(metric_name, 0) * static_weights[metric_name] + metric_values.get(metric_name, 0) * static_weights[metric_name] for metric_name in static_weights ) - + logging.info(f"Static reputation for node {nei} at round {await self.engine.get_round()}: {reputation_static}") avg_reputation = await self.save_reputation_history_in_memory(self.engine.addr, nei, reputation_static) @@ -476,48 +476,48 @@ async def _calculate_dynamic_reputation(self, addr, neighbors): async def _calculate_average_weights(self): """Calculate average weights for all enabled metrics.""" average_weights = {} - + for metric_name in self.history_data.keys(): if self._is_metric_enabled(metric_name): average_weights[metric_name] = await self._get_metric_average_weight(metric_name) - + return average_weights - + async def _get_metric_average_weight(self, metric_name): """Get the average weight for a specific metric.""" if metric_name not in self.history_data or not self.history_data[metric_name]: logging.debug(f"No history data available for metric: {metric_name}") return 0 - + valid_entries = [ entry for entry in self.history_data[metric_name] - if (entry.get("round") is not None and - entry["round"] >= await self._engine.get_round() and + if 
(entry.get("round") is not None and + entry["round"] >= await self._engine.get_round() and entry.get("weight") not in [None, -1]) ] - + if not valid_entries: return 0 - + try: weights = [entry["weight"] for entry in valid_entries if entry.get("weight") is not None] return sum(weights) / len(weights) if weights else 0 except (TypeError, ZeroDivisionError) as e: logging.warning(f"Error calculating average weight for {metric_name}: {e}") return 0 - + async def _process_neighbors_reputation(self, addr, neighbors, average_weights): """Process reputation calculation for all neighbors.""" for nei in neighbors: metric_values = await self._get_neighbor_metric_values(nei) - + if all(metric_name in metric_values for metric_name in average_weights): await self._update_neighbor_reputation(addr, nei, metric_values, average_weights) - + async def _get_neighbor_metric_values(self, nei): """Get metric values for a specific neighbor in the current round.""" metric_values = {} - + for metric_name in self.history_data: if self._is_metric_enabled(metric_name): for entry in self.history_data.get(metric_name, []): @@ -526,16 +526,16 @@ async def _get_neighbor_metric_values(self, nei): entry.get("nei") == nei): metric_values[metric_name] = entry.get("metric_value", 0) break - + return metric_values - + async def _update_neighbor_reputation(self, addr, nei, metric_values, average_weights): """Update reputation for a specific neighbor.""" reputation_with_weights = sum( - metric_values.get(metric_name, 0) * average_weights[metric_name] + metric_values.get(metric_name, 0) * average_weights[metric_name] for metric_name in average_weights ) - + logging.info( f"Dynamic reputation with weights for {nei} at round {await self._engine.get_round()}: {reputation_with_weights}" ) @@ -564,7 +564,7 @@ async def _update_reputation_record(self, nei: str, reputation: float, data: dic data: Additional data to update (currently unused) """ current_round = await self._engine.get_round() - + if nei not in 
self.reputation: self.reputation[nei] = { "reputation": reputation, @@ -576,7 +576,7 @@ async def _update_reputation_record(self, nei: str, reputation: float, data: dic self.reputation[nei]["round"] = current_round logging.info(f"Reputation of node {nei}: {self.reputation[nei]['reputation']}") - + if self.reputation[nei]["reputation"] < self.REPUTATION_THRESHOLD and current_round > 0: self.rejected_nodes.add(nei) logging.info(f"Rejected node {nei} at round {current_round}") @@ -608,23 +608,23 @@ def calculate_weighted_values( reputation_metrics ) self._add_current_metrics_to_history(active_metrics, history_data, current_round, addr, nei) - + if current_round >= self.INITIAL_ROUND_FOR_REPUTATION and len(active_metrics) > 0: adjusted_weights = self._calculate_dynamic_weights(active_metrics, history_data) else: adjusted_weights = self._calculate_uniform_weights(active_metrics) - + self._update_history_with_weights(active_metrics, history_data, adjusted_weights, current_round, nei) def _ensure_history_data_structure(self, history_data: dict): """Ensure all required keys exist in history data structure.""" required_keys = [ "num_messages", - "model_similarity", + "model_similarity", "fraction_parameters_changed", "model_arrival_latency", ] - + for key in required_keys: if key not in history_data: history_data[key] = [] @@ -644,7 +644,7 @@ def _get_active_metrics( "fraction_parameters_changed": fraction_score_asign, "model_arrival_latency": avg_model_arrival_latency, } - + return {k: v for k, v in all_metrics.items() if self._is_metric_enabled(k, reputation_metrics)} def _add_current_metrics_to_history(self, active_metrics: dict, history_data: dict, current_round: int, addr: str, nei: str): @@ -662,7 +662,7 @@ def _add_current_metrics_to_history(self, active_metrics: dict, history_data: di def _calculate_dynamic_weights(self, active_metrics: dict, history_data: dict) -> dict: """Calculate dynamic weights based on metric deviations.""" deviations = 
self._calculate_metric_deviations(active_metrics, history_data) - + if all(deviation == 0.0 for deviation in deviations.values()): return self._generate_random_weights(active_metrics) else: @@ -672,7 +672,7 @@ def _calculate_dynamic_weights(self, active_metrics: dict, history_data: dict) - def _calculate_metric_deviations(self, active_metrics: dict, history_data: dict) -> dict: """Calculate deviations of current metrics from historical means.""" deviations = {} - + for metric_name, current_value in active_metrics.items(): historical_values = history_data[metric_name] metric_values = [ @@ -680,11 +680,11 @@ def _calculate_metric_deviations(self, active_metrics: dict, history_data: dict) for entry in historical_values if "metric_value" in entry and entry["metric_value"] != 0 ] - + mean_value = np.mean(metric_values) if metric_values else 0 deviation = abs(current_value - mean_value) deviations[metric_name] = deviation - + return deviations def _generate_random_weights(self, active_metrics: dict) -> dict: @@ -692,7 +692,7 @@ def _generate_random_weights(self, active_metrics: dict) -> dict: num_metrics = len(active_metrics) random_weights = [random.random() for _ in range(num_metrics)] total_random_weight = sum(random_weights) - + return { metric_name: weight / total_random_weight for metric_name, weight in zip(active_metrics, random_weights, strict=False) @@ -702,14 +702,14 @@ def _normalize_deviation_weights(self, deviations: dict) -> dict: """Normalize weights based on deviations.""" max_deviation = max(deviations.values()) if deviations else 1 normalized_weights = { - metric_name: (deviation / max_deviation) + metric_name: (deviation / max_deviation) for metric_name, deviation in deviations.items() } - + total_weight = sum(normalized_weights.values()) if total_weight > 0: return { - metric_name: weight / total_weight + metric_name: weight / total_weight for metric_name, weight in normalized_weights.items() } else: @@ -720,20 +720,20 @@ def 
_adjust_weights_with_minimum(self, normalized_weights: dict, deviations: dic """Apply minimum weight constraints and renormalize.""" mean_deviation = np.mean(list(deviations.values())) dynamic_min_weight = max(self.DYNAMIC_MIN_WEIGHT_THRESHOLD, mean_deviation / (mean_deviation + 1)) - + adjusted_weights = {} total_adjusted_weight = 0 - + for metric_name, weight in normalized_weights.items(): adjusted_weight = max(weight, dynamic_min_weight) adjusted_weights[metric_name] = adjusted_weight total_adjusted_weight += adjusted_weight - + # Renormalize if total weight exceeds 1 if total_adjusted_weight > 1: for metric_name in adjusted_weights: adjusted_weights[metric_name] /= total_adjusted_weight - + return adjusted_weights def _calculate_uniform_weights(self, active_metrics: dict) -> dict: @@ -748,8 +748,8 @@ def _update_history_with_weights(self, active_metrics: dict, history_data: dict, for metric_name in active_metrics: weight = weights.get(metric_name, -1) for entry in history_data[metric_name]: - if (entry["metric_name"] == metric_name and - entry["round"] == current_round and + if (entry["metric_name"] == metric_name and + entry["round"] == current_round and entry["nei"] == nei): entry["weight"] = weight @@ -765,7 +765,7 @@ async def calculate_value_metrics(self, addr, nei, metrics_active=None): try: current_round = await self._engine.get_round() metrics_instance = self.connection_metrics.get(nei) - + if not metrics_instance: logging.warning(f"No metrics found for neighbor {nei}") return self._get_default_metric_values() @@ -778,7 +778,7 @@ async def calculate_value_metrics(self, addr, nei, metrics_active=None): } self._log_metrics_graphics(metric_results, addr, nei, current_round) - + return ( metric_results["messages"]["avg"], metric_results["similarity"], @@ -802,7 +802,7 @@ def _process_num_messages_metric(self, metrics_instance, addr: str, nei: str, cu filtered_messages = [ msg for msg in metrics_instance.messages if msg.get("current_round") == current_round 
] - + for msg in filtered_messages: self.messages_number_message.append({ "number_message": msg.get("time"), @@ -813,9 +813,9 @@ def _process_num_messages_metric(self, metrics_instance, addr: str, nei: str, cu normalized, count = self.manage_metric_number_message( self.messages_number_message, addr, nei, current_round, True ) - + avg = self.save_number_message_history(addr, nei, normalized, current_round) - + if avg is None and current_round > self.HISTORY_ROUNDS_LOOKBACK: avg = self.number_message_history[(addr, nei)][current_round - 1]["avg_number_message"] @@ -901,7 +901,7 @@ def _process_model_arrival_latency_metric(self, metrics_instance, addr: str, nei if avg_latency is None and current_round > 1: avg_latency = self.model_arrival_latency_history[(addr, nei)][current_round - 1]["score"] return avg_latency or 0 - + return 0 def _process_model_similarity_metric(self, nei: str, current_round: int, metrics_active) -> float: @@ -938,7 +938,7 @@ def create_graphics_to_metrics( ): """ Create and log graphics for reputation metrics. 
- + Args: number_message_count: Count of messages for logging number_message_norm: Normalized message metric @@ -952,25 +952,25 @@ def create_graphics_to_metrics( """ if current_round is None or current_round >= total_rounds: return - + self.engine.trainer._logger.log_data( - {f"R-Model_arrival_latency_reputation/{addr}": {nei: model_arrival_latency}}, + {f"R-Model_arrival_latency_reputation/{addr}": {nei: model_arrival_latency}}, step=current_round ) self.engine.trainer._logger.log_data( - {f"R-Count_messages_number_message_reputation/{addr}": {nei: number_message_count}}, + {f"R-Count_messages_number_message_reputation/{addr}": {nei: number_message_count}}, step=current_round ) self.engine.trainer._logger.log_data( - {f"R-number_message_reputation/{addr}": {nei: number_message_norm}}, + {f"R-number_message_reputation/{addr}": {nei: number_message_norm}}, step=current_round ) self.engine.trainer._logger.log_data( - {f"R-Similarity_reputation/{addr}": {nei: similarity}}, + {f"R-Similarity_reputation/{addr}": {nei: similarity}}, step=current_round ) self.engine.trainer._logger.log_data( - {f"R-Fraction_reputation/{addr}": {nei: fraction}}, + {f"R-Fraction_reputation/{addr}": {nei: fraction}}, step=current_round ) @@ -991,7 +991,7 @@ def analyze_anomalies( try: key = (addr, nei, current_round) self._initialize_fraction_history_entry(key, fraction_changed, threshold) - + if current_round == 0: return self._handle_initial_round_anomalies(key, fraction_changed, threshold) else: @@ -1032,16 +1032,16 @@ def _handle_subsequent_round_anomalies( ) -> float: """Handle anomaly analysis for subsequent rounds.""" prev_stats = self._find_previous_valid_stats(addr, nei, current_round) - + if prev_stats is None: logging.warning(f"No valid previous stats found for {addr}, {nei}, round {current_round}") return 1.0 - + anomalies = self._detect_anomalies(fraction_changed, threshold, prev_stats) values = self._calculate_anomaly_values(fraction_changed, threshold, prev_stats, anomalies) 
fraction_score = self._calculate_combined_score(values) self._update_fraction_statistics(key, fraction_changed, threshold, prev_stats, anomalies, fraction_score) - + return max(fraction_score, 0) def _find_previous_valid_stats(self, addr: str, nei: str, current_round: int) -> dict: @@ -1049,18 +1049,18 @@ def _find_previous_valid_stats(self, addr: str, nei: str, current_round: int) -> for i in range(1, current_round + 1): candidate_key = (addr, nei, current_round - i) candidate_data = self.fraction_changed_history.get(candidate_key, {}) - + required_keys = ["mean_fraction", "std_dev_fraction", "mean_threshold", "std_dev_threshold"] if all(candidate_data.get(k) is not None for k in required_keys): return candidate_data - + return None def _detect_anomalies(self, current_fraction: float, current_threshold: float, prev_stats: dict) -> dict: """Detect if current values are anomalous compared to previous statistics.""" upper_mean_fraction = (prev_stats["mean_fraction"] + prev_stats["std_dev_fraction"]) * self.FRACTION_ANOMALY_MULTIPLIER upper_mean_threshold = (prev_stats["mean_threshold"] + prev_stats["std_dev_threshold"]) * self.THRESHOLD_ANOMALY_MULTIPLIER - + return { "fraction_anomaly": current_fraction > upper_mean_fraction, "threshold_anomaly": current_threshold > upper_mean_threshold, @@ -1074,19 +1074,19 @@ def _calculate_anomaly_values( """Calculate penalty values for fraction and threshold anomalies.""" fraction_value = 1.0 threshold_value = 1.0 - + if anomalies["fraction_anomaly"]: mean_fraction_prev = prev_stats["mean_fraction"] if mean_fraction_prev > 0: penalization_factor = abs(current_fraction - mean_fraction_prev) / mean_fraction_prev fraction_value = 1 - (1 / (1 + np.exp(-penalization_factor))) - + if anomalies["threshold_anomaly"]: mean_threshold_prev = prev_stats["mean_threshold"] if mean_threshold_prev > 0: penalization_factor = abs(current_threshold - mean_threshold_prev) / mean_threshold_prev threshold_value = 1 - (1 / (1 + 
np.exp(-penalization_factor))) - + return { "fraction_value": fraction_value, "threshold_value": threshold_value, @@ -1099,19 +1099,19 @@ def _calculate_combined_score(self, values: dict) -> float: return fraction_weight * values["fraction_value"] + threshold_weight * values["threshold_value"] def _update_fraction_statistics( - self, key: tuple, current_fraction: float, current_threshold: float, + self, key: tuple, current_fraction: float, current_threshold: float, prev_stats: dict, anomalies: dict, fraction_score: float ): """Update the fraction statistics for the current round.""" self.fraction_changed_history[key]["fraction_anomaly"] = anomalies["fraction_anomaly"] self.fraction_changed_history[key]["threshold_anomaly"] = anomalies["threshold_anomaly"] - + self.fraction_changed_history[key]["mean_fraction"] = (current_fraction + prev_stats["mean_fraction"]) / 2 self.fraction_changed_history[key]["mean_threshold"] = (current_threshold + prev_stats["mean_threshold"]) / 2 - + fraction_variance = ((current_fraction - prev_stats["mean_fraction"]) ** 2 + prev_stats["std_dev_fraction"] ** 2) / 2 threshold_variance = ((self.THRESHOLD_VARIANCE_MULTIPLIER * (current_threshold - prev_stats["mean_threshold"]) ** 2) + prev_stats["std_dev_threshold"] ** 2) / 2 - + self.fraction_changed_history[key]["std_dev_fraction"] = np.sqrt(fraction_variance) self.fraction_changed_history[key]["std_dev_threshold"] = np.sqrt(threshold_variance) self.fraction_changed_history[key]["fraction_score"] = fraction_score @@ -1132,9 +1132,9 @@ def manage_model_arrival_latency(self, addr, nei, latency, current_round, round_ """ try: current_key = nei - + self._initialize_latency_round_entry(current_round, current_key, latency) - + if current_round >= 1: score = self._calculate_latency_score(current_round, current_key, latency) self._update_latency_entry_with_score(current_round, current_key, score) @@ -1161,17 +1161,17 @@ def _calculate_latency_score(self, current_round: int, current_key: str, 
latency """Calculate the latency score based on historical data.""" target_round = self._get_target_round_for_latency(current_round) all_latencies = self._get_all_latencies_for_round(target_round) - + if not all_latencies: return 0.0 - + mean_latency = np.mean(all_latencies) augment_mean = mean_latency * self.LATENCY_AUGMENT_FACTOR - + if latency is None: logging.info(f"latency is None in round {current_round} for nei {current_key}") return -0.5 - + if latency <= augment_mean: return 1.0 else: @@ -1195,7 +1195,7 @@ def _update_latency_entry_with_score(self, current_round: int, current_key: str, target_round = self._get_target_round_for_latency(current_round) all_latencies = self._get_all_latencies_for_round(target_round) mean_latency = np.mean(all_latencies) if all_latencies else 0 - + self.model_arrival_latency_history[current_round][current_key].update({ "mean_latency": mean_latency, "score": score, @@ -1215,9 +1215,9 @@ def save_model_arrival_latency_history(self, nei, model_arrival_latency, round_n """ try: current_key = nei - + self._initialize_latency_history_entry(round_num, current_key, model_arrival_latency) - + if model_arrival_latency > 0 and round_num >= 1: avg_model_arrival_latency = self._calculate_latency_weighted_average_positive( round_num, current_key, model_arrival_latency @@ -1236,7 +1236,7 @@ def save_model_arrival_latency_history(self, nei, model_arrival_latency, round_n ) return avg_model_arrival_latency - + except Exception: logging.exception("Error saving model_arrival_latency history") @@ -1284,14 +1284,14 @@ def manage_metric_number_message( ) -> tuple[float, int]: """ Manage the number of messages metric for a specific neighbor. 
- + Args: messages_number_message: List of message data addr: Source address nei: Neighbor address current_round: Current round number metric_active: Whether the metric is active - + Returns: Tuple of (normalized_messages, messages_count) """ @@ -1301,13 +1301,13 @@ def manage_metric_number_message( messages_count = self._count_relevant_messages(messages_number_message, addr, nei, current_round) neighbor_stats = self._calculate_neighbor_statistics(messages_number_message, current_round) - + normalized_messages = self._calculate_normalized_messages(messages_count, neighbor_stats) - + normalized_messages = self._apply_historical_penalty( normalized_messages, addr, nei, current_round ) - + self._store_message_history(addr, nei, current_round, normalized_messages) normalized_messages = max(0.001, normalized_messages) @@ -1339,7 +1339,7 @@ def _calculate_neighbor_statistics(self, messages: list, current_round: int) -> neighbor_counts[key] = neighbor_counts.get(key, 0) + 1 counts_all_neighbors = list(neighbor_counts.values()) - + if not counts_all_neighbors: return { "percentile_reference": 0, @@ -1349,7 +1349,7 @@ def _calculate_neighbor_statistics(self, messages: list, current_round: int) -> } mean_messages = np.mean(counts_all_neighbors) - + return { "percentile_reference": np.percentile(counts_all_neighbors, 25), "std_dev": np.std(counts_all_neighbors), @@ -1361,10 +1361,10 @@ def _calculate_normalized_messages(self, messages_count: int, neighbor_stats: di """Calculate normalized message score with relative and extra penalties.""" normalized_messages = 1.0 penalties_applied = [] - + relative_increase = self._calculate_relative_increase(messages_count, neighbor_stats["percentile_reference"]) dynamic_margin = self._calculate_dynamic_margin(neighbor_stats) - + if relative_increase > dynamic_margin: penalty_ratio = self._calculate_penalty_ratio(relative_increase, dynamic_margin) normalized_messages *= np.exp(-(penalty_ratio**2)) @@ -1400,7 +1400,7 @@ def 
_calculate_penalty_ratio(self, relative_increase: float, dynamic_margin: flo def _should_apply_extra_penalty(self, messages_count: int, neighbor_stats: dict) -> bool: """Determine if extra penalty should be applied.""" - return (neighbor_stats["mean_messages"] > 0 and + return (neighbor_stats["mean_messages"] > 0 and messages_count > neighbor_stats["augment_mean"]) def _calculate_extra_penalty_factor(self, messages_count: int, neighbor_stats: dict) -> float: @@ -1408,7 +1408,7 @@ def _calculate_extra_penalty_factor(self, messages_count: int, neighbor_stats: d epsilon = 1e-6 mean_messages = neighbor_stats["mean_messages"] augment_mean = neighbor_stats["augment_mean"] - + extra_penalty = (messages_count - mean_messages) / (mean_messages + epsilon) amplification = 1 + (augment_mean / (mean_messages + epsilon)) return extra_penalty * amplification @@ -1417,27 +1417,27 @@ def _apply_historical_penalty(self, normalized_messages: float, addr: str, nei: """Apply historical penalty based on previous round's score.""" if current_round <= 1: return normalized_messages - + prev_data = ( self.number_message_history.get((addr, nei), {}) .get(current_round - 1, {}) ) - + prev_score = prev_data.get("normalized_messages") was_previously_penalized = prev_data.get("was_penalized", False) - + if prev_score is not None and prev_score < self.HISTORICAL_PENALTY_THRESHOLD: original_score = normalized_messages - + if was_previously_penalized: penalty_factor = self.HISTORICAL_PENALTY_THRESHOLD * 0.8 logging.debug(f"Repeated penalty applied to {nei}: stricter historical penalty") else: penalty_factor = self.HISTORICAL_PENALTY_THRESHOLD - + normalized_messages *= penalty_factor logging.debug(f"Historical penalty applied to {nei}: {original_score:.4f} -> {normalized_messages:.4f} (prev_score: {prev_score:.4f}, was_penalized: {was_previously_penalized})") - + return normalized_messages def _store_message_history(self, addr: str, nei: str, current_round: int, normalized_messages: float): @@ 
-1445,9 +1445,9 @@ def _store_message_history(self, addr: str, nei: str, current_round: int, normal key = (addr, nei) if key not in self.number_message_history: self.number_message_history[key] = {} - + was_penalized = normalized_messages < 1.0 - + self.number_message_history[key][current_round] = { "normalized_messages": normalized_messages, "was_penalized": was_penalized, @@ -1464,9 +1464,9 @@ def save_number_message_history(self, addr, nei, messages_number_message_normali """ try: key = (addr, nei) - + self._initialize_message_history_entry(key, current_round, messages_number_message_normalized) - + if messages_number_message_normalized > 0 and current_round >= 1: avg_number_message = self._calculate_weighted_average_positive(key, current_round, messages_number_message_normalized) elif messages_number_message_normalized == 0 and current_round >= 1: @@ -1478,7 +1478,7 @@ def save_number_message_history(self, addr, nei, messages_number_message_normali self.number_message_history[key][current_round]["avg_number_message"] = avg_number_message return avg_number_message - + except Exception: logging.exception("Error saving number_message history") return -1 @@ -1524,7 +1524,7 @@ async def save_reputation_history_in_memory(self, addr: str, nei: str, reputatio Args: addr: The node's identifier - nei: The neighboring node identifier + nei: The neighboring node identifier reputation: The reputation value to save Returns: @@ -1533,27 +1533,27 @@ async def save_reputation_history_in_memory(self, addr: str, nei: str, reputatio try: key = (addr, nei) current_round = await self._engine.get_round() - + if key not in self.reputation_history: self.reputation_history[key] = {} self.reputation_history[key][current_round] = reputation rounds = sorted(self.reputation_history[key].keys(), reverse=True)[:2] - + if len(rounds) >= 2: current_rep = self.reputation_history[key][rounds[0]] previous_rep = self.reputation_history[key][rounds[1]] - + current_weight = 
self.REPUTATION_CURRENT_WEIGHT previous_weight = self.REPUTATION_FEEDBACK_WEIGHT avg_reputation = (current_rep * current_weight) + (previous_rep * previous_weight) - + logging.info(f"Current reputation: {current_rep}, Previous reputation: {previous_rep}") logging.info(f"Reputation ponderated: {avg_reputation}") else: avg_reputation = reputation - + return avg_reputation except Exception: @@ -1577,23 +1577,23 @@ def calculate_similarity_from_metrics(self, nei: str, current_round: int) -> flo return 0.0 relevant_metrics = [ - metric for metric in metrics_instance.similarity + metric for metric in metrics_instance.similarity if metric.get("nei") == nei and metric.get("current_round") == current_round ] - + if not relevant_metrics: relevant_metrics = [ - metric for metric in metrics_instance.similarity + metric for metric in metrics_instance.similarity if metric.get("nei") == nei ] - + if not relevant_metrics: return 0.0 neighbor_metric = relevant_metrics[-1] similarity_weights = { "cosine": 0.25, - "euclidean": 0.25, + "euclidean": 0.25, "manhattan": 0.25, "pearson_correlation": 0.25, } @@ -1604,7 +1604,7 @@ def calculate_similarity_from_metrics(self, nei: str, current_round: int) -> flo ) return max(0.0, min(1.0, similarity_value)) - + except Exception: return 0.0 @@ -1620,9 +1620,9 @@ async def calculate_reputation(self, ae: AggregationEvent): (updates, _, _) = await ae.get_event_data() await self._log_reputation_calculation_start() - + neighbors = set(await self._engine._cm.get_addrs_current_connections(only_direct=True)) - + await self._process_neighbor_metrics(neighbors) await self._calculate_reputation_by_factor(neighbors) await self._handle_initial_reputation() @@ -1644,7 +1644,7 @@ async def _process_neighbor_metrics(self, neighbors): metrics = await self.calculate_value_metrics( self._addr, nei, metrics_active=self._metrics ) - + if self._weighting_factor == "dynamic": await self._process_dynamic_metrics(nei, metrics) elif self._weighting_factor == "static" 
and await self._engine.get_round() >= 1: @@ -1653,7 +1653,7 @@ async def _process_neighbor_metrics(self, neighbors): async def _process_dynamic_metrics(self, nei, metrics): """Process metrics for dynamic weighting factor.""" (metric_messages_number, metric_similarity, metric_fraction, metric_model_arrival_latency) = metrics - + self.calculate_weighted_values( metric_messages_number, metric_similarity, @@ -1669,7 +1669,7 @@ async def _process_dynamic_metrics(self, nei, metrics): async def _process_static_metrics(self, nei, metrics): """Process metrics for static weighting factor.""" (metric_messages_number, metric_similarity, metric_fraction, metric_model_arrival_latency) = metrics - + metric_values_dict = { "num_messages": metric_messages_number, "model_similarity": metric_similarity, @@ -1686,7 +1686,7 @@ async def _calculate_reputation_by_factor(self, neighbors): async def _handle_initial_reputation(self): """Handle reputation initialization for the first round.""" if await self._engine.get_round() < 1 and self._enabled: - federation = self._engine.config.participant["network_args"]["neighbors"].split() + federation = self._engine.config.participant["network_args"]["neighbors"] await self.init_reputation( federation_nodes=federation, round_num=await self._engine.get_round(), @@ -1698,7 +1698,7 @@ async def _process_feedback(self): """Process and include feedback in reputation.""" status = await self.include_feedback_in_reputation() current_round = await self._engine.get_round() - + if status: logging.info(f"Feedback included in reputation at round {current_round}") else: @@ -1735,7 +1735,7 @@ async def send_reputation_to_neighbors(self, neighbors): def create_graphic_reputation(self, addr: str, round_num: int): """ Log reputation data for visualization. 
- + Args: addr: The node address round_num: The round number for logging step @@ -1746,7 +1746,7 @@ def create_graphic_reputation(self, addr: str, round_num: int): for node_id, data in self.reputation.items() if data.get("reputation") is not None } - + if valid_reputations: reputation_data = {f"Reputation/{addr}": valid_reputations} self._engine.trainer._logger.log_data(reputation_data, step=round_num) @@ -1954,26 +1954,26 @@ def _recalculate_pending_latencies(self, current_round): async def recollect_similarity(self, ure: UpdateReceivedEvent): """ Collect and analyze model similarity metrics. - + Args: ure: UpdateReceivedEvent containing model and metadata """ (decoded_model, weight, nei, round_num, local) = await ure.get_event_data() - + if not (self._enabled and self._is_metric_enabled("model_similarity")): return - + if not self._engine.config.participant["adaptive_args"]["model_similarity"]: return - + if nei == self._addr: return - + logging.info("🤖 handle_model_message | Checking model similarity") - + local_model = self._engine.trainer.get_model_parameters() similarity_values = self._calculate_all_similarity_metrics(local_model, decoded_model) - + similarity_metrics = { "timestamp": datetime.now(), "nei": nei, @@ -1996,7 +1996,7 @@ def _calculate_all_similarity_metrics(self, local_model: dict, received_model: d "jaccard": 0.0, "minkowski": 0.0, } - + similarity_functions = [ ("cosine", cosine_metric), ("euclidean", euclidean_metric), @@ -2004,29 +2004,29 @@ def _calculate_all_similarity_metrics(self, local_model: dict, received_model: d ("pearson_correlation", pearson_correlation_metric), ("jaccard", jaccard_metric), ] - + similarity_values = {} - + for name, metric_func in similarity_functions: try: similarity_values[name] = metric_func(local_model, received_model, similarity=True) except Exception: similarity_values[name] = 0.0 - + try: similarity_values["minkowski"] = minkowski_metric( local_model, received_model, p=2, similarity=True ) except Exception: 
similarity_values["minkowski"] = 0.0 - + return similarity_values def _store_similarity_metrics(self, nei: str, similarity_metrics: dict): """Store similarity metrics for the given neighbor.""" if nei not in self.connection_metrics: self.connection_metrics[nei] = Metrics() - + self.connection_metrics[nei].similarity.append(similarity_metrics) def _check_similarity_threshold(self, nei: str, cosine_value: float): @@ -2064,25 +2064,25 @@ async def _record_message_data(self, source: str): async def recollect_fraction_of_parameters_changed(self, ure: UpdateReceivedEvent): """ Collect and analyze the fraction of parameters that changed between models. - + Args: ure: UpdateReceivedEvent containing model and metadata """ (decoded_model, weight, source, round_num, local) = await ure.get_event_data() - + current_round = await self._engine.get_round() parameters_local = self._engine.trainer.get_model_parameters() - + prev_threshold = self._get_previous_threshold(source, current_round) differences = self._calculate_parameter_differences(parameters_local, decoded_model) current_threshold = self._calculate_threshold(differences, prev_threshold) - + changed_params, total_params, changes_record = self._count_changed_parameters( parameters_local, decoded_model, current_threshold ) - + fraction_changed = changed_params / total_params if total_params > 0 else 0.0 - + self._store_fraction_data(source, current_round, { "fraction_changed": fraction_changed, "total_params": total_params, @@ -2102,7 +2102,7 @@ async def recollect_fraction_of_parameters_changed(self, ure: UpdateReceivedEven def _get_previous_threshold(self, source: str, current_round: int) -> float: """Get the threshold from the previous round for the given source.""" - if (source in self.fraction_of_params_changed and + if (source in self.fraction_of_params_changed and current_round - 1 in self.fraction_of_params_changed[source]): return self.fraction_of_params_changed[source][current_round - 1][-1]["threshold"] return 
None @@ -2122,7 +2122,7 @@ def _calculate_threshold(self, differences: list, prev_threshold: float) -> floa """Calculate the threshold for determining parameter changes.""" if not differences: return 0 - + mean_threshold = torch.mean(torch.tensor(differences)).item() if prev_threshold is not None: return (prev_threshold + mean_threshold) / 2 @@ -2133,20 +2133,20 @@ def _count_changed_parameters(self, local_params: dict, received_params: dict, t total_params = 0 changed_params = 0 changes_record = {} - + for key in local_params.keys(): if key in received_params: local_tensor = local_params[key].cpu() received_tensor = received_params[key].cpu() diff = torch.abs(local_tensor - received_tensor) total_params += diff.numel() - + num_changed = torch.sum(diff > threshold).item() changed_params += num_changed - + if num_changed > 0: changes_record[key] = num_changed - + return changed_params, total_params, changes_record def _store_fraction_data(self, source: str, current_round: int, data: dict): @@ -2155,5 +2155,5 @@ def _store_fraction_data(self, source: str, current_round: int, data: dict): self.fraction_of_params_changed[source] = {} if current_round not in self.fraction_of_params_changed[source]: self.fraction_of_params_changed[source][current_round] = [] - - self.fraction_of_params_changed[source][current_round].append(data) \ No newline at end of file + + self.fraction_of_params_changed[source][current_round].append(data) diff --git a/nebula/addons/topologymanager.py b/nebula/addons/topologymanager.py index c29937372..eda6429cd 100755 --- a/nebula/addons/topologymanager.py +++ b/nebula/addons/topologymanager.py @@ -322,15 +322,16 @@ def update_nodes(self, config_participants): def get_neighbors_string(self, node_idx): """ - Retrieves the neighbors of a given node as a string representation. + Retrieves the neighbors of a given node as a list of string representations. 
- This method checks the `topology` attribute to find the neighbors of the node at the specified index (`node_idx`). It then returns a string that lists the coordinates of each neighbor. + This method checks the `topology` attribute to find the neighbors of the node at the specified index (`node_idx`). + It then returns a list that contains the coordinates of each neighbor in string format. Parameters: node_idx (int): The index of the node for which neighbors are to be retrieved. Returns: - str: A space-separated string of neighbors' coordinates in the format "latitude:longitude". + list[str]: A list of neighbors' coordinates in the format "latitude:longitude". """ logging.info(f"Getting neighbors for node {node_idx}") logging.info(f"Topology shape: {self.topology.shape}") @@ -342,9 +343,8 @@ def get_neighbors_string(self, node_idx): logging.info(f"Found neighbor at index {i}: {self.nodes[i]}") neighbors_data_strings = [f"{i[0]}:{i[1]}" for i in neighbors_data] - neighbors_data_string = " ".join(neighbors_data_strings) - logging.info(f"Neighbors of node participant_{node_idx}: {neighbors_data_string}") - return neighbors_data_string + logging.info(f"Neighbors of node participant_{node_idx}: {neighbors_data_strings}") + return neighbors_data_strings def __ring_topology(self, increase_convergence=False): """ diff --git a/nebula/config/config.py b/nebula/config/config.py index 5ef336e3a..d5f3a813f 100755 --- a/nebula/config/config.py +++ b/nebula/config/config.py @@ -55,7 +55,7 @@ def reset_logging_configuration(self): self.__set_default_logging(mode="a") self.__set_training_logging(mode="a") - + def shutdown_logging(self): """ Properly shuts down all loggers and their handlers in the system. 
@@ -204,40 +204,46 @@ def add_participants_config(self, participants_config): def add_neighbor_from_config(self, addr): if self.participant != {}: - if self.participant["network_args"]["neighbors"] == "": - self.participant["network_args"]["neighbors"] = addr + neighbors = self.participant["network_args"]["neighbors"] + + if not neighbors: + self.participant["network_args"]["neighbors"] = [addr] self.participant["mobility_args"]["neighbors_distance"][addr] = None else: - if addr not in self.participant["network_args"]["neighbors"]: - self.participant["network_args"]["neighbors"] += " " + addr + if addr not in neighbors: + self.participant["network_args"]["neighbors"].append(addr) self.participant["mobility_args"]["neighbors_distance"][addr] = None def update_nodes_distance(self, distances: dict): self.participant["mobility_args"]["neighbors_distance"] = {node: dist for node, (dist, _) in distances.items()} def update_neighbors_from_config(self, current_connections, dest_addr): - final_neighbors = [] - for n in current_connections: - if n != dest_addr: - final_neighbors.append(n) - - final_neighbors_string = " ".join(final_neighbors) - # Update neighbors - self.participant["network_args"]["neighbors"] = final_neighbors_string + final_neighbors = [n for n in current_connections if n != dest_addr] + + # Update neighbors como lista (no string) + self.participant["network_args"]["neighbors"] = final_neighbors + # Update neighbors location self.participant["mobility_args"]["neighbors_distance"] = { n: self.participant["mobility_args"]["neighbors_distance"][n] for n in final_neighbors if n in self.participant["mobility_args"]["neighbors_distance"] } - logging.info(f"Final neighbors: {final_neighbors_string} (config updated))") + + logging.info(f"Final neighbors: {final_neighbors} (config updated)") + def remove_neighbor_from_config(self, addr): - if self.participant != {}: - if self.participant["network_args"]["neighbors"] != "": - 
self.participant["network_args"]["neighbors"] = ( - self.participant["network_args"]["neighbors"].replace(addr, "").replace(" ", " ").strip() - ) + if self.participant: + neighbors = self.participant["network_args"]["neighbors"] + + if addr in neighbors: + neighbors.remove(addr) + self.participant["network_args"]["neighbors"] = neighbors + + if addr in self.participant["mobility_args"]["neighbors_distance"]: + del self.participant["mobility_args"]["neighbors_distance"][addr] + def reload_config_file(self): config_dir = self.participant["tracking_args"]["config_dir"] diff --git a/nebula/controller/controller.py b/nebula/controller/controller.py index a58b74e41..266e8b03b 100755 --- a/nebula/controller/controller.py +++ b/nebula/controller/controller.py @@ -632,7 +632,7 @@ async def update_nodes( str(config["network_args"]["ip"]), str(config["network_args"]["port"]), str(config["device_args"]["role"]), - str(config["network_args"]["neighbors"]), + config["network_args"]["neighbors"], str(config["mobility_args"]["latitude"]), str(config["mobility_args"]["longitude"]), str(timestamp), diff --git a/nebula/controller/database.py b/nebula/controller/database.py index 54279800f..f0548f10c 100755 --- a/nebula/controller/database.py +++ b/nebula/controller/database.py @@ -204,7 +204,7 @@ async def update_node_record( timestamp, federation, round, scenario, hash, malicious) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14); """, - node_uid, idx, ip, port, role, json.dumps(neighbors), latitude, longitude, + node_uid, idx, ip, port, role, neighbors, latitude, longitude, timestamp, federation, federation_round, scenario, run_hash, malicious, ) else: @@ -216,7 +216,7 @@ async def update_node_record( hash = $11, malicious = $12 WHERE uid = $13 AND scenario = $14; """, - idx, ip, port, role, json.dumps(neighbors), latitude, longitude, + idx, ip, port, role, neighbors, latitude, longitude, timestamp, federation, federation_round, run_hash, malicious, node_uid, 
scenario, ) @@ -710,42 +710,3 @@ def remove_note(scenario): # os.environ['DB_PORT'] = '5432' logging.basicConfig(level=logging.INFO) - - print("Listing users:") - users = list_users(all_info=True) - for user in users: - print(f"- User: {user['user']}, Role: {user['role']}") - - # Example of adding/updating a user - # print("\nAdding/Updating test user:") - # add_user("TESTUSER", "testpassword123", "user") - # print(get_user_info("TESTUSER")) - - # Example of scenario operations - # print("\nScenario operations:") - # current_time_str = datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S') - # scenario_data = { - # "title": "My Test Scenario", - # "model": "NN", - # "dataset": "MNIST", - # "rounds": "10", - # "description": "A test scenario for demonstration." - # } - # scenario_update_record("test_scenario_1", current_time_str, "", scenario_data, "running", "ADMIN") - # print("Running scenarios:") - # print(get_running_scenario(username="ADMIN", get_all=True)) - - # print("\nAll scenarios:") - # all_scenarios = get_all_scenarios_and_check_completed("ADMIN", "admin", sort_by="start_time") - # for s in all_scenarios: - # print(f"Scenario: {s['name']}, Status: {s['status']}, Title: {s.get('title')}") - - # print("\nSetting a scenario to finished:") - # scenario_set_status_to_finished("test_scenario_1") - # print(get_scenario_by_name("test_scenario_1")) - - # print("\nTesting notes:") - # save_notes("test_scenario_1", "These are some notes for test scenario 1.") - # print(get_notes("test_scenario_1")) - # remove_note("test_scenario_1") - # print(get_notes("test_scenario_1")) diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 151ae6b22..fc2fef4f6 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -94,7 +94,7 @@ def __init__( self.ip = config.participant["network_args"]["ip"] self.port = config.participant["network_args"]["port"] self.addr = config.participant["network_args"]["addr"] - + self.name = config.participant["device_args"]["name"] 
self.client = docker.from_env() @@ -187,7 +187,7 @@ def aggregator(self): def trainer(self): """Trainer""" return self._trainer - + @property def rb(self): """Role Behavior""" @@ -317,7 +317,7 @@ async def _control_alive_callback(self, source, message): async def _control_leadership_transfer_callback(self, source, message): logging.info(f"🔧 handle_control_message | Trigger | Received leadership transfer message from {source}") - + if await self._round_in_process_lock.locked_async(): logging.info("Learning cycle is executing, role behavior will be modified next round") await self.rb.set_next_role(Role.AGGREGATOR, source_to_notificate=source) @@ -354,7 +354,7 @@ async def _control_leadership_transfer_ack_callback(self, source, message): except TimeoutError: logging.info("Learning cycle is locked, role behavior will be modified next round") await self.rb.set_next_role(Role.TRAINER) - + async def _connection_connect_callback(self, source, message): logging.info(f"🔗 handle_connection_message | Trigger | Received connection message from {source}") @@ -600,7 +600,7 @@ async def start_communications(self): before other services or training processes begin. 
""" await self.register_events_callbacks() - initial_neighbors = self.config.participant["network_args"]["neighbors"].split() + initial_neighbors = self.config.participant["network_args"]["neighbors"] await self.cm.start_communications(initial_neighbors) await asyncio.sleep(self.config.participant["misc_args"]["grace_time_connection"] // 2) @@ -710,10 +710,10 @@ async def _start_learning(self): await self.get_federation_ready_lock().acquire_async() if self.config.participant["device_args"]["start"]: logging.info("Propagate initial model updates.") - + mpe = ModelPropagationEvent(await self.cm.get_addrs_current_connections(only_direct=True, myself=False), "initialization") await EventManager.get_instance().publish_node_event(mpe) - + await self.get_federation_ready_lock().release_async() self.trainer.set_epochs(epochs) @@ -764,7 +764,7 @@ async def learning_cycle_finished(self): return False else: return current_round >= self.total_rounds - + async def resolve_missing_updates(self): """ Delegates the resolution strategy for missing updates to the current role behavior. @@ -778,7 +778,7 @@ async def resolve_missing_updates(self): """ logging.info(f"Using Role behavior: {self.rb.get_role_name()} conflict resolve strategy") return await self.rb.resolve_missing_updates() - + async def update_self_role(self): """ Checks whether a role update is required and performs the transition if necessary. @@ -806,7 +806,7 @@ async def update_self_role(self): logging.info(f"Sending role modification ACK to transferer: {source_to_notificate}") message = self.cm.create_message("control", "leadership_transfer_ack") asyncio.create_task(self.cm.send_message(source_to_notificate, message)) - + async def _learning_cycle(self): """ Main asynchronous loop for executing the Federated Learning process across multiple rounds. 
@@ -837,9 +837,9 @@ async def _learning_cycle(self): indent=2, title="Round information", ) - + await self.update_self_role() - + logging.info(f"Federation nodes: {self.federation_nodes}") await self.update_federation_nodes( await self.cm.get_addrs_current_connections(only_direct=True, myself=True) @@ -851,10 +851,10 @@ async def _learning_cycle(self): logging.info(f"Expected nodes: {expected_nodes}") direct_connections = await self.cm.get_addrs_current_connections(only_direct=True) undirected_connections = await self.cm.get_addrs_current_connections(only_undirected=True) - + logging.info(f"Direct connections: {direct_connections} | Undirected connections: {undirected_connections}") logging.info(f"[Role {self.rb.get_role_name()}] Starting learning cycle...") - + await self.aggregator.update_federation_nodes(expected_nodes) async with self._role_behavior_performance_lock: await self.rb.extended_learning_cycle() @@ -882,13 +882,13 @@ async def _learning_cycle(self): self.trainer.on_learning_cycle_end() await self.trainer.test() - + # Shutdown protocol await self._shutdown_protocol() - + async def _shutdown_protocol(self): logging.info("Starting graceful shutdown process...") - + # 1.- Publish Experiment Finish Event to the last update on modules logging.info("Publishing Experiment Finish Event...") efe = ExperimentFinishEvent() diff --git a/nebula/database/init-configs.sql b/nebula/database/init-configs.sql index f125342c5..0d6d3631d 100644 --- a/nebula/database/init-configs.sql +++ b/nebula/database/init-configs.sql @@ -20,7 +20,7 @@ CREATE TABLE IF NOT EXISTS nodes ( ip TEXT, port TEXT, role TEXT, - neighbors TEXT, + neighbors TEXT[], latitude TEXT, longitude TEXT, timestamp TEXT, @@ -40,7 +40,7 @@ CREATE TABLE configs ( created_at TIMESTAMPTZ DEFAULT NOW() ); CREATE INDEX idx_configs_config_gin ON configs USING GIN (config); - + -- 4) Scenarios table as JSONB CREATE TABLE IF NOT EXISTS scenarios ( name TEXT PRIMARY KEY, @@ -66,4 +66,4 @@ CREATE TABLE IF NOT EXISTS 
notes ( -- The hash must be generated by a Python script using passlib. -- Replace the placeholder with your generated hash. INSERT INTO users ("user", password, role) VALUES ('ADMIN', '$argon2id$v=19$m=65536,t=3,p=4$OobPh8BkZeT6D5s+Rt11mQ$JjI2M3U5+4lupdr87/GrIn46ImzoQujNEyVd7IGYiXY', 'admin') -ON CONFLICT ("user") DO NOTHING; \ No newline at end of file +ON CONFLICT ("user") DO NOTHING; diff --git a/nebula/frontend/app.py b/nebula/frontend/app.py index 830a87c85..76ba8dabe 100755 --- a/nebula/frontend/app.py +++ b/nebula/frontend/app.py @@ -1611,7 +1611,7 @@ async def nebula_dashboard_monitor(scenario_name: str, request: Request, session "ip": node["ip"], "port": node["port"], "role": node["role"], - "neighbors": node["neighbors"], + "neighbors": " ".join(node["neighbors"]), "latitude": node["latitude"], "longitude": node["longitude"], "timestamp": node["timestamp"], @@ -1715,8 +1715,7 @@ async def nebula_update_node(scenario_name: str, request: Request): "ip": config["network_args"]["ip"], "port": str(config["network_args"]["port"]), "role": config["device_args"]["role"], - "malicious": config["device_args"]["malicious"], - "neighbors": config["network_args"]["neighbors"], + "neighbors": " ".join(config["network_args"]["neighbors"]), "latitude": config["mobility_args"]["latitude"], "longitude": config["mobility_args"]["longitude"], "timestamp": config["timestamp"], From 5126bc1a4066e9411d045e117a497841207af281 Mon Sep 17 00:00:00 2001 From: FerTV Date: Wed, 9 Jul 2025 12:44:39 +0200 Subject: [PATCH 05/20] chore: renamed log in physical api --- nebula/physical/api.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nebula/physical/api.py b/nebula/physical/api.py index 6d44428e7..3e7da1bec 100644 --- a/nebula/physical/api.py +++ b/nebula/physical/api.py @@ -350,7 +350,7 @@ def run(): if TRAINING_PROC and TRAINING_PROC.poll() is None: _json_abort(409, "Training already running") - cmd = ["python", "/home/dietpi/prueba/nebula/nebula/node.py", 
json_files[0]] + cmd = ["python", "/home/dietpi/test/nebula/nebula/node.py", json_files[0]] TRAINING_PROC = subprocess.Popen(cmd) return jsonify(pid=TRAINING_PROC.pid, state="running") @@ -379,8 +379,8 @@ def setup_new_run(): Expected multipart-form fields ------------------------------- - * **config** – JSON with scenario, network and security arguments - * **global_test** – shared evaluation dataset (`*.h5`) + * **config** – JSON with scenario, network and security arguments + * **global_test** – shared evaluation dataset (`*.h5`) * **train_set** – participant-specific training dataset (`*.h5`) The function rewrites paths inside *config*, validates neighbour IPs @@ -489,4 +489,4 @@ def setup_new_run(): # ----------------------------------------------------------------------------- if __name__ == "__main__": # Local testing: python main.py - app.run(host="0.0.0.0", port=8000, debug=False) \ No newline at end of file + app.run(host="0.0.0.0", port=8000, debug=False) From 1fe4ab3d9c5e00c2699d9ba486e421d899aa6c46 Mon Sep 17 00:00:00 2001 From: FerTV Date: Thu, 10 Jul 2025 10:45:01 +0200 Subject: [PATCH 06/20] feature: credentialmanager created --- app/deployer.py | 121 ++++++++++++++++++++++++++++++++++++----- nebula/frontend/app.py | 9 --- 2 files changed, 107 insertions(+), 23 deletions(-) diff --git a/app/deployer.py b/app/deployer.py index ef2ce4fe0..27ef8e8ce 100644 --- a/app/deployer.py +++ b/app/deployer.py @@ -1,7 +1,9 @@ import json import logging import os +import secrets import signal +import string import subprocess import sys import threading @@ -19,6 +21,92 @@ from nebula.controller.scenarios import ScenarioManagement from nebula.utils import DockerUtils, FileUtils, SocketUtils +class CredentialManager: + """ + CredentialManager handles the generation, storage, and validation of environment-based credentials. + + This class is designed to manage credentials required for different system components like the frontend, + Grafana, and the database. 
It ensures that secure values are generated and persisted in a `.env` file if + they are not already defined in the environment. + + Attributes: + env_path (Path): Absolute path to the environment file where credentials will be stored. + + Typical usage example: + manager = CredentialManager() + manager.check_all_credentials() + """ + + def __init__(self, env_dir="app", env_filename=".env"): + """ + Initializes the CredentialManager and loads existing environment variables from file. + + Args: + env_dir (str): Directory where the .env file is located. Defaults to 'app'. + env_filename (str): Name of the environment file. Defaults to '.env'. + + Behavior: + - Sets up the absolute path to the .env file. + - Loads any existing environment variables from the file using `load_dotenv`. + """ + self.env_path = Path.cwd() / env_dir / env_filename + if os.path.exists(self.env_path): + logging.info(f"Loading environment variables from {self.env_path}") + load_dotenv(self.env_path, override=True) + + def generate_secure_password(self, length=20): + """ + Generates a cryptographically secure and readable password including symbols. + + Args: + length (int): Length of the password. Defaults to 20. + + Returns: + str: A randomly generated secure password, excluding confusing or problematic characters. + """ + alphabet = string.ascii_letters + string.digits + string.punctuation + for char in ['"', "'", "\\", "`", "|", "(", ")", "{", "}", "[", "]", "#"]: + alphabet = alphabet.replace(char, "") + return ''.join(secrets.choice(alphabet) for _ in range(length)) + + def check_credential(self, key, is_password=True): + """ + Checks if a given credential key is present in the environment. If not, generates and saves it. + + Args: + key (str): The environment variable key to check or create. + is_password (bool): If True, generates a secure password. If False, generates a hex token. Defaults to True. 
+ + Behavior: + - If the key is missing, a value is generated and stored both in the environment and the `.env` file. + - If the key exists, no action is taken. + """ + if key not in os.environ: + logging.info(f"Generating value for {key}") + value = self.generate_secure_password(12) if is_password else secrets.token_hex(24) + os.environ[key] = value + logging.info(f"Saving {key} to {self.env_path}") + with self.env_path.open("a") as f: + f.write(f"{key}={value}\n") + else: + logging.info(f"{key} already set") + + def check_all_credentials(self): + """ + Checks and sets all required credentials for the application. + + This method should be called at startup to ensure all necessary keys are initialized. + + Includes: + - Frontend secret key + - Grafana admin password + - (Optional) Database password + """ + self.check_credential("SECRET_KEY", is_password=False) + self.check_credential("GF_SECURITY_ADMIN_PASSWORD") + self.check_credential("POSTGRES_PASSWORD") + self.check_credential("HTTP_PASSWORD") + class NebulaEventHandler(PatternMatchingEventHandler): """ @@ -513,6 +601,10 @@ def __init__(self, args): ) logging.exception(warning_msg) sys.exit(1) + + self.configure_logger() + self.credentialmanager = CredentialManager() + self.credentialmanager.check_all_credentials() # --- Tag logic: CLI args > environment > fallback --- arg_production = getattr(args, "production", False) @@ -565,7 +657,6 @@ def __init__(self, args): self.loki_port = int(args.lokiport) if hasattr(args, "lokiport") else 6010 self.statistics_port = int(args.statsport) if hasattr(args, "statsport") else 8080 - self.configure_logger() def get_container_name(self, role_tag: str) -> str: """ @@ -854,6 +945,8 @@ def run_frontend(self): client = docker.from_env() environment = { + "NEBULA_CONTROLLER_NAME": os.environ["USER"], + "SECRET_KEY": os.environ.get("SECRET_KEY"), "NEBULA_PRODUCTION": self.production, "NEBULA_ENV_TAG": self.env_tag, "NEBULA_PREFIX_TAG": self.prefix_tag, @@ -929,11 +1022,11 @@ def 
run_database(self): ) network_name = f"{os.environ['USER']}_nebula-net-base" - + ############### # POSTGRES DB # ############### - + host_port = 54312 # Create the Docker network @@ -943,7 +1036,7 @@ def run_database(self): environment = { "POSTGRES_USER": "nebula", - "POSTGRES_PASSWORD": "nebula", + "POSTGRES_PASSWORD": os.environ.get("POSTGRES_PASSWORD"), "POSTGRES_DB": "nebula", "NEBULA_DATABASES_PORT": self.controller_host, } @@ -973,14 +1066,14 @@ def run_database(self): ) client.api.start(container_id) - + ################ # POSTGRES WEB # ################ - + pgweb_host_port = 8085 - pgweb_container_port = 8081 - + pgweb_container_port = 8081 + pgweb_host_config = client.api.create_host_config( port_bindings={pgweb_container_port: pgweb_host_port}, device_requests=[{ @@ -989,13 +1082,13 @@ def run_database(self): "Capabilities": [["gpu"]], }] if self.gpu_available else None, ) - + pgweb_networking_config = client.api.create_networking_config({ f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.135") }) - + pgweb_container_name = f"{os.environ.get('USER')}_nebula-pgweb" - + pgweb_container_id = client.api.create_container( image="nebula-pgweb", name=pgweb_container_name, @@ -1003,7 +1096,7 @@ def run_database(self): host_config=pgweb_host_config, networking_config=pgweb_networking_config, ) - + client.api.start(pgweb_container_id) ######### @@ -1047,7 +1140,7 @@ def run_database(self): environment_commander = { "REDIS_HOSTS": "local:redis:6379", "HTTP_USER": "root", - "HTTP_PASSWORD": "root", + "HTTP_PASSWORD": os.environ.get("HTTP_PASSWORD"), } host_config_commander = client.api.create_host_config( @@ -1243,7 +1336,7 @@ def run_waf(self): Deployer._add_container_to_metadata(waf_container_name) environment = { - "GF_SECURITY_ADMIN_PASSWORD": "admin", + "GF_SECURITY_ADMIN_PASSWORD": os.environ.get("GF_SECURITY_ADMIN_PASSWORD"), "GF_USERS_ALLOW_SIGN_UP": "false", "GF_SERVER_HTTP_PORT": "3000", "GF_SERVER_PROTOCOL": "http", diff --git 
a/nebula/frontend/app.py b/nebula/frontend/app.py index 76ba8dabe..3510aa0f0 100755 --- a/nebula/frontend/app.py +++ b/nebula/frontend/app.py @@ -121,15 +121,6 @@ class Settings: logging.info(f"NEBULA_PRODUCTION: {settings.env_tag == 'prod'}") logging.info(f"NEBULA_DEPLOYMENT_PREFIX: {settings.prefix_tag}") -if "SECRET_KEY" not in os.environ: - logging.info("Generating SECRET_KEY") - os.environ["SECRET_KEY"] = os.urandom(24).hex() - logging.info(f"Saving SECRET_KEY to {settings.env_file}") - with open(settings.env_file, "a") as f: - f.write(f"SECRET_KEY={os.environ['SECRET_KEY']}\n") -else: - logging.info("SECRET_KEY already set") - app = FastAPI() app.add_middleware( SessionMiddleware, From 7bec1307dc011248bcb5c9e378d72271eaa3f6d0 Mon Sep 17 00:00:00 2001 From: FerTV Date: Thu, 10 Jul 2025 11:33:44 +0200 Subject: [PATCH 07/20] chore: all querys changed to async --- nebula/controller/controller.py | 48 +-- nebula/controller/database.py | 668 ++++++++++++++++---------------- 2 files changed, 364 insertions(+), 352 deletions(-) diff --git a/nebula/controller/controller.py b/nebula/controller/controller.py index 266e8b03b..02a5d6da3 100755 --- a/nebula/controller/controller.py +++ b/nebula/controller/controller.py @@ -374,9 +374,9 @@ async def stop_scenario( ScenarioManagement.cleanup_scenario_containers() try: if all: - scenario_set_all_status_to_finished() + await scenario_set_all_status_to_finished() else: - scenario_set_status_to_finished(scenario_name) + await scenario_set_status_to_finished(scenario_name) except Exception as e: logging.exception(f"Error setting scenario {scenario_name} to finished: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -399,7 +399,7 @@ async def remove_scenario( from nebula.controller.scenarios import ScenarioManagement try: - remove_scenario_by_name(scenario_name) + await remove_scenario_by_name(scenario_name) ScenarioManagement.remove_files_by_scenario(scenario_name) except Exception as e: 
logging.exception(f"Error removing scenario {scenario_name}: {e}") @@ -426,11 +426,11 @@ async def get_scenarios( from nebula.controller.database import get_all_scenarios_and_check_completed, get_running_scenario try: - scenarios = get_all_scenarios_and_check_completed(username=user, role=role) + scenarios = await get_all_scenarios_and_check_completed(username=user, role=role) if role == "admin": - scenario_running = get_running_scenario() + scenario_running = await get_running_scenario() else: - scenario_running = get_running_scenario(username=user) + scenario_running = await get_running_scenario(username=user) except Exception as e: logging.exception(f"Error obtaining scenarios: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -467,7 +467,7 @@ async def update_scenario( # from nebula.controller.scenarios import Scenario try: - scenario_update_record(scenario_name, start_time, end_time, scenario, status, username) + await scenario_update_record(scenario_name, start_time, end_time, scenario, status, username) except Exception as e: logging.exception(f"Error updating scenario {scenario_name}: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -493,9 +493,9 @@ async def set_scenario_status_to_finished( try: if all: - scenario_set_all_status_to_finished() + await scenario_set_all_status_to_finished() else: - scenario_set_status_to_finished(scenario_name) + await scenario_set_status_to_finished(scenario_name) except Exception as e: logging.exception(f"Error setting scenario {scenario_name} to finished: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -517,7 +517,7 @@ async def get_running_scenario(get_all: bool = False): from nebula.controller.database import get_running_scenario try: - return get_running_scenario(get_all=get_all) + return await get_running_scenario(get_all=get_all) except Exception as e: logging.exception(f"Error obtaining running scenario: {e}") raise 
HTTPException(status_code=500, detail="Internal server error") @@ -543,7 +543,7 @@ async def check_scenario( from nebula.controller.database import check_scenario_with_role try: - allowed = check_scenario_with_role(role, scenario_name) + allowed = await check_scenario_with_role(role, scenario_name) return {"allowed": allowed} except Exception as e: logging.exception(f"Error checking scenario with role: {e}") @@ -568,7 +568,7 @@ async def get_scenario_by_name( from nebula.controller.database import get_scenario_by_name try: - scenario = get_scenario_by_name(scenario_name) + scenario = await get_scenario_by_name(scenario_name) except Exception as e: logging.exception(f"Error obtaining scenario {scenario_name}: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -709,7 +709,7 @@ async def remove_nodes_by_scenario_name(scenario_name: str = Body(..., embed=Tru from nebula.controller.database import remove_nodes_by_scenario_name try: - remove_nodes_by_scenario_name(scenario_name) + await remove_nodes_by_scenario_name(scenario_name) except Exception as e: logging.exception(f"Error removing nodes: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -729,7 +729,7 @@ async def get_notes_by_scenario_name( from nebula.controller.database import get_notes try: - notes_record = get_notes(scenario_name) + notes_record = await get_notes(scenario_name) if notes_record is not None: notes_record = dict(notes_record.items()) @@ -755,7 +755,7 @@ async def update_notes_by_scenario_name(scenario_name: str = Body(..., embed=Tru from nebula.controller.database import save_notes try: - save_notes(scenario_name, notes) + await save_notes(scenario_name, notes) except Exception as e: logging.exception(f"Error updating notes: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -776,7 +776,7 @@ async def remove_notes_by_scenario_name(scenario_name: str = Body(..., embed=Tru from nebula.controller.database import 
remove_note try: - remove_note(scenario_name) + await remove_note(scenario_name) except Exception as e: logging.exception(f"Error removing notes: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -797,9 +797,9 @@ async def list_users_controller(all_info: bool = False): from nebula.controller.database import list_users try: - user_list = list_users(all_info) + user_list = await list_users(all_info) if all_info: - # Convert each sqlite3.Row to a dictionary so that it is JSON serializable. + # Convert each asyncpg.Record to a dictionary so that it is JSON serializable. user_list = [dict(user) for user in user_list] return {"users": user_list} except Exception as e: @@ -823,7 +823,7 @@ async def get_user_by_scenario_name( from nebula.controller.database import get_user_by_scenario_name try: - user = get_user_by_scenario_name(scenario_name) + user = await get_user_by_scenario_name(scenario_name) except Exception as e: logging.exception(f"Error obtaining user {user}: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -1028,7 +1028,7 @@ async def add_user_controller(user: str = Body(...), password: str = Body(...), from nebula.controller.database import add_user try: - add_user(user, password, role) + await add_user(user, password, role) return {"detail": "User added successfully"} except Exception as e: logging.exception(f"Error adding user: {e}") @@ -1048,7 +1048,7 @@ async def remove_user_controller(user: str = Body(..., embed=True)): from nebula.controller.database import delete_user_from_db try: - delete_user_from_db(user) + await delete_user_from_db(user) return {"detail": "User deleted successfully"} except Exception as e: logging.exception(f"Error deleting user: {e}") @@ -1070,7 +1070,7 @@ async def update_user_controller(user: str = Body(...), password: str = Body(... 
from nebula.controller.database import update_user try: - update_user(user, password, role) + await update_user(user, password, role) return {"detail": "User updated successfully"} except Exception as e: logging.exception(f"Error updating user: {e}") @@ -1092,8 +1092,8 @@ async def verify_user_controller(user: str = Body(...), password: str = Body(... try: user_submitted = user.upper() - if (user_submitted in list_users()) and verify(user_submitted, password): - user_info = get_user_info(user_submitted) + if (await list_users() and await verify(user_submitted, password)): + user_info = await get_user_info(user_submitted) return {"user": user_submitted, "role": user_info[2]} else: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) diff --git a/nebula/controller/database.py b/nebula/controller/database.py index f0548f10c..50d61db0d 100755 --- a/nebula/controller/database.py +++ b/nebula/controller/database.py @@ -23,11 +23,6 @@ # --- Connection Management Helper Functions --- -def get_sync_conn(): - """Establishes a synchronous PostgreSQL connection.""" - return psycopg2.connect(DATABASE_URL) - - async def get_async_conn(): """Establishes an asynchronous PostgreSQL connection.""" return await asyncpg.connect(DATABASE_URL) @@ -35,14 +30,15 @@ async def get_async_conn(): # --- User Management Functions --- -def list_users(all_info=False): +async def list_users(all_info=False): """ Retrieves a list of users from the users database. 
""" - with get_sync_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: - c.execute("SELECT * FROM users") - result = c.fetchall() + conn = await get_async_conn() + try: + result = await conn.fetch("SELECT * FROM users") + finally: + await conn.close() if not all_info: result = [user["user"] for user in result] @@ -50,115 +46,121 @@ def list_users(all_info=False): return result -def get_user_info(user): +async def get_user_info(user): """ Fetches detailed information for a specific user from the users database. """ - with get_sync_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: - c.execute("SELECT * FROM users WHERE \"user\" = %s", (user,)) - result = c.fetchone() + conn = await get_async_conn() + try: + result = await conn.fetchrow('SELECT * FROM users WHERE "user" = $1', user) + finally: + await conn.close() return result -def verify(user, password): +async def verify(user, password): """ Verifies whether the provided password matches the stored hashed password for a user. 
""" - with get_sync_conn() as conn: - with conn.cursor() as c: - c.execute("SELECT password FROM users WHERE \"user\" = %s", (user,)) - result = c.fetchone() - if result: - try: - return pwd_context.verify(password, result[0]) - except Exception: - # Catch more general exceptions during verification to be safe - logging.error(f"Error during password verification for user {user}", exc_info=True) - return False + conn = await get_async_conn() + try: + result = await conn.fetchrow('SELECT password FROM users WHERE "user" = $1', user) + if result: + try: + return pwd_context.verify(password, result[0]) + except Exception: + # Catch more general exceptions during verification to be safe + logging.error(f"Error during password verification for user {user}", exc_info=True) + return False + finally: + await conn.close() return False -def verify_hash_algorithm(user): +async def verify_hash_algorithm(user): """ Checks if the stored password hash for a user uses a supported Argon2 algorithm. """ user = user.upper() argon2_prefixes = ("$argon2i$", "$argon2id$") - with get_sync_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: - c.execute("SELECT password FROM users WHERE \"user\" = %s", (user,)) - result = c.fetchone() - if result: - password_hash = result["password"] - return password_hash.startswith(argon2_prefixes) + conn = await get_async_conn() + try: + result = await conn.fetchrow('SELECT password FROM users WHERE "user" = $1', user) + if result: + password_hash = result["password"] + return password_hash.startswith(argon2_prefixes) + finally: + await conn.close() return False -def delete_user_from_db(user): +async def delete_user_from_db(user): """ Deletes a user record from the users database. 
""" - with get_sync_conn() as conn: - with conn.cursor() as c: - c.execute("DELETE FROM users WHERE \"user\" = %s", (user,)) - conn.commit() + conn = await get_async_conn() + try: + await conn.execute('DELETE FROM users WHERE "user" = $1', user) + finally: + await conn.close() -def add_user(user, password, role): +async def add_user(user, password, role): """ Adds a new user to the users database with a hashed password. """ - with get_sync_conn() as conn: - with conn.cursor() as c: - hashed_password = pwd_context.hash(password) - c.execute( - "INSERT INTO users (\"user\", password, role) VALUES (%s, %s, %s)", - (user.upper(), hashed_password, role), - ) - conn.commit() + conn = await get_async_conn() + try: + hashed_password = pwd_context.hash(password) + await conn.execute( + 'INSERT INTO users ("user", password, role) VALUES ($1, $2, $3)', + user.upper(), hashed_password, role, + ) + finally: + await conn.close() -def update_user(user, password, role): +async def update_user(user, password, role): """ Updates the password and role of an existing user in the users database. """ - with get_sync_conn() as conn: - with conn.cursor() as c: - hashed_password = pwd_context.hash(password) - c.execute( - "UPDATE users SET password = %s, role = %s WHERE \"user\" = %s", - (hashed_password, role, user.upper()), - ) - conn.commit() + conn = await get_async_conn() + try: + hashed_password = pwd_context.hash(password) + await conn.execute( + 'UPDATE users SET password = $1, role = $2 WHERE "user" = $3', + hashed_password, role, user.upper(), + ) + finally: + await conn.close() # --- Node Management Functions --- -def list_nodes(scenario_name=None, sort_by="idx"): +async def list_nodes(scenario_name=None, sort_by="idx"): """ Retrieves a list of nodes from the nodes database, optionally filtered by scenario and sorted. 
""" + conn = await get_async_conn() try: - with get_sync_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: - # Validate sort_by to prevent SQL injection - allowed_sort_fields = ["uid", "idx", "ip", "port", "role", "timestamp", "federation", "round"] - if sort_by not in allowed_sort_fields: - sort_by = "idx" # Default to a safe field - - if scenario_name: - # Using psycopg2.extensions.AsIs for safe insertion of column names - command = f"SELECT * FROM nodes WHERE scenario = %s ORDER BY {psycopg2.extensions.AsIs(sort_by)};" - c.execute(command, (scenario_name,)) - else: - command = f"SELECT * FROM nodes ORDER BY {psycopg2.extensions.AsIs(sort_by)};" - c.execute(command) - - result = c.fetchall() - return result - except psycopg2.Error as e: + # Validate sort_by to prevent SQL injection + allowed_sort_fields = ["uid", "idx", "ip", "port", "role", "timestamp", "federation", "round"] + if sort_by not in allowed_sort_fields: + sort_by = "idx" # Default to a safe field + + if scenario_name: + # Using f-string for column names is generally safe if validated as above + command = f"SELECT * FROM nodes WHERE scenario = $1 ORDER BY {sort_by};" + result = await conn.fetch(command, scenario_name) + else: + command = f"SELECT * FROM nodes ORDER BY {sort_by};" + result = await conn.fetch(command) + + return result + except asyncpg.PostgresError as e: logging.error(f"Error occurred while listing nodes: {e}") return None + finally: + await conn.close() async def list_nodes_by_scenario_name(scenario_name): @@ -228,28 +230,30 @@ async def update_node_record( await conn.close() -def remove_all_nodes(): +async def remove_all_nodes(): """ Deletes all node records from the nodes database. 
""" - with get_sync_conn() as conn: - with conn.cursor() as c: - c.execute("TRUNCATE nodes CASCADE;") # Use CASCADE if there are foreign key dependencies - conn.commit() + conn = await get_async_conn() + try: + await conn.execute("TRUNCATE nodes CASCADE;") # Use CASCADE if there are foreign key dependencies + finally: + await conn.close() -def remove_nodes_by_scenario_name(scenario_name): +async def remove_nodes_by_scenario_name(scenario_name): """ Deletes all nodes associated with a specific scenario from the database. """ - with get_sync_conn() as conn: - with conn.cursor() as c: - c.execute("DELETE FROM nodes WHERE scenario = %s;", (scenario_name,)) - conn.commit() + conn = await get_async_conn() + try: + await conn.execute("DELETE FROM nodes WHERE scenario = $1;", scenario_name) + finally: + await conn.close() # --- Scenario Management Functions --- -def get_all_scenarios(username, role, sort_by="start_time"): +async def get_all_scenarios(username, role, sort_by="start_time"): """ Retrieves all scenarios from the database, accessing fields from the 'config' (JSONB) column and direct columns. Filters by user role and sorts by the specified field. 
@@ -273,38 +277,38 @@ def get_all_scenarios(username, role, sort_by="start_time"): else: # For direct table columns like name, username, status order_by_clause = f"ORDER BY {sort_by}" + conn = await get_async_conn() + try: + # Select direct columns and relevant fields from config JSONB + command = """ + SELECT + name, + username, + status, + start_time, + end_time, + config->>'title' AS title, + config->>'model' AS model, + config->>'dataset' AS dataset, + config->>'rounds' AS rounds, + config -- return the full config JSONB + FROM scenarios + """ + params = [] - with get_sync_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: - # Select direct columns and relevant fields from config JSONB - command = """ - SELECT - name, - username, - status, - start_time, - end_time, - config->>'title' AS title, - config->>'model' AS model, - config->>'dataset' AS dataset, - config->>'rounds' AS rounds, - config -- return the full config JSONB - FROM scenarios - """ - params = [] - - if role != "admin": - command += " WHERE username = %s" # username is a direct column now - params.append(username) + if role != "admin": + command += " WHERE username = $1" # username is a direct column now + params.append(username) - full_command = f"{command} {order_by_clause};" - c.execute(full_command, tuple(params)) - result = c.fetchall() + full_command = f"{command} {order_by_clause};" + result = await conn.fetch(full_command, *params) + finally: + await conn.close() return result -def get_all_scenarios_and_check_completed(username, role, sort_by="start_time"): +async def get_all_scenarios_and_check_completed(username, role, sort_by="start_time"): """ Retrieves all scenarios, sorts them, and updates the status if necessary. Returns a list of dictionaries, where each dictionary represents a scenario. 
@@ -330,219 +334,226 @@ def get_all_scenarios_and_check_completed(username, role, sort_by="start_time"): else: # For direct table columns like name, username, status order_by_clause = f"ORDER BY {sort_by}" - with get_sync_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: - # Base query that extracts fields from the JSONB using the ->> operator - command = f""" - SELECT - name, - username, - status, - start_time, - end_time, - config->>'title' AS title, - config->>'model' AS model, - config->>'dataset' AS dataset, - config->>'rounds' AS rounds, - config -- Return the full config object - FROM scenarios - """ - params = [] - if role != "admin": - command += " WHERE username = %s" # username is a direct column - params.append(username) + conn = await get_async_conn() + try: + # Base query that extracts fields from the JSONB using the ->> operator + command = f""" + SELECT + name, + username, + status, + start_time, + end_time, + config->>'title' AS title, + config->>'model' AS model, + config->>'dataset' AS dataset, + config->>'rounds' AS rounds, + config -- Return the full config object + FROM scenarios + """ + params = [] + if role != "admin": + command += " WHERE username = $1" # username is a direct column + params.append(username) - command += f" {order_by_clause};" + command += f" {order_by_clause};" - c.execute(command, tuple(params)) - result_dicts = c.fetchall() + result_dicts = await conn.fetch(command, *params) - scenarios_to_return = [dict(s) for s in result_dicts] + scenarios_to_return = [dict(s) for s in result_dicts] - re_fetch_required = False - for scenario in scenarios_to_return: - if scenario["status"] == "running": - if check_scenario_federation_completed(scenario["name"]): - scenario_set_status_to_completed(scenario["name"]) - re_fetch_required = True - break + re_fetch_required = False + for scenario in scenarios_to_return: + if scenario["status"] == "running": + if await 
check_scenario_federation_completed(scenario["name"]): + await scenario_set_status_to_completed(scenario["name"]) + re_fetch_required = True + break - if re_fetch_required: - # Recursively call get_all_scenarios_and_check_completed to get fresh data - return get_all_scenarios_and_check_completed(username, role, sort_by) + if re_fetch_required: + # Recursively call get_all_scenarios_and_check_completed to get fresh data + return await get_all_scenarios_and_check_completed(username, role, sort_by) + finally: + await conn.close() return scenarios_to_return -def scenario_update_record(name, start_time, end_time, scenario_config, status, username): +async def scenario_update_record(name, start_time, end_time, scenario_config, status, username): """ Inserts or updates a scenario record using the PostgreSQL "UPSERT" pattern. All configuration is saved in the 'config' column of type JSONB. Direct columns (name, start_time, end_time, username, status) are also handled. """ - with get_sync_conn() as conn: - with conn.cursor() as c: - # Ensure scenario_config is a dictionary before dumping to JSON - if not isinstance(scenario_config, dict): - try: - scenario_config = json.loads(scenario_config) - except json.JSONDecodeError: - logging.error("scenario_config is not a valid JSON string or dict.") - return - - command = """ - INSERT INTO scenarios (name, start_time, end_time, username, status, config) - VALUES (%s, %s, %s, %s, %s, %s::jsonb) - ON CONFLICT (name) DO UPDATE SET - start_time = EXCLUDED.start_time, - end_time = EXCLUDED.end_time, - username = EXCLUDED.username, - status = EXCLUDED.status, - config = scenarios.config || EXCLUDED.config; -- Merge JSONB - """ - c.execute(command, (name, start_time, end_time, username, status, json.dumps(scenario_config))) - conn.commit() + conn = await get_async_conn() + try: + # Ensure scenario_config is a dictionary before dumping to JSON + if not isinstance(scenario_config, dict): + try: + scenario_config = 
json.loads(scenario_config) + except json.JSONDecodeError: + logging.error("scenario_config is not a valid JSON string or dict.") + return + + command = """ + INSERT INTO scenarios (name, start_time, end_time, username, status, config) + VALUES ($1, $2, $3, $4, $5, $6::jsonb) + ON CONFLICT (name) DO UPDATE SET + start_time = EXCLUDED.start_time, + end_time = EXCLUDED.end_time, + username = EXCLUDED.username, + status = EXCLUDED.status, + config = scenarios.config || EXCLUDED.config; -- Merge JSONB + """ + await conn.execute(command, name, start_time, end_time, username, status, json.dumps(scenario_config)) + finally: + await conn.close() -def scenario_set_all_status_to_finished(): +async def scenario_set_all_status_to_finished(): """ Sets the status of all 'running' scenarios to 'finished' and updates their 'end_time' (both in the direct column and within JSONB). """ - with get_sync_conn() as conn: - with conn.cursor() as c: - current_time = datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S') # Consistent format - # Update direct columns first, then update JSONB within config - command = """ - UPDATE scenarios - SET - status = 'finished', - end_time = %s, - config = jsonb_set(config, '{status}', '"finished"') || - jsonb_set(config, '{end_time}', %s::jsonb) - WHERE status = 'running'; - """ - c.execute(command, (current_time, json.dumps(current_time))) - conn.commit() + conn = await get_async_conn() + try: + current_time = datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S') # Consistent format + # Update direct columns first, then update JSONB within config + command = """ + UPDATE scenarios + SET + status = 'finished', + end_time = $1, + config = jsonb_set(config, '{status}', '"finished"') || + jsonb_set(config, '{end_time}', $2::jsonb) + WHERE status = 'running'; + """ + await conn.execute(command, current_time, json.dumps(current_time)) + finally: + await conn.close() -def scenario_set_status_to_finished(scenario_name): +async def 
scenario_set_status_to_finished(scenario_name): """ Sets the status of a specific scenario to 'finished' and updates its 'end_time'. Updates both the direct columns and the JSONB 'config'. """ - with get_sync_conn() as conn: - with conn.cursor() as c: - current_time = datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S') # Consistent format - command = """ - UPDATE scenarios - SET - status = 'finished', - end_time = %s, - config = jsonb_set( - jsonb_set(config, '{status}', '"finished"'), - '{end_time}', %s::jsonb - ) - WHERE name = %s; - """ - c.execute(command, (current_time, json.dumps(current_time), scenario_name)) - conn.commit() + conn = await get_async_conn() + try: + current_time = datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S') # Consistent format + command = """ + UPDATE scenarios + SET + status = 'finished', + end_time = $1, + config = jsonb_set( + jsonb_set(config, '{status}', '"finished"'), + '{end_time}', $2::jsonb + ) + WHERE name = $3; + """ + await conn.execute(command, current_time, json.dumps(current_time), scenario_name) + finally: + await conn.close() -def scenario_set_status_to_completed(scenario_name): +async def scenario_set_status_to_completed(scenario_name): """ Sets the status of a specific scenario to 'completed'. Updates both the direct column and the JSONB 'config'. 
""" - with get_sync_conn() as conn: - with conn.cursor() as c: - command = """ - UPDATE scenarios - SET - status = 'completed', - config = jsonb_set(config, '{status}', '"completed"') - WHERE name = %s; - """ - c.execute(command, (scenario_name,)) - conn.commit() + conn = await get_async_conn() + try: + command = """ + UPDATE scenarios + SET + status = 'completed', + config = jsonb_set(config, '{status}', '"completed"') + WHERE name = $1; + """ + await conn.execute(command, scenario_name) + finally: + await conn.close() -def get_running_scenario(username=None, get_all=False): +async def get_running_scenario(username=None, get_all=False): """ Retrieves scenarios with a 'running' status, optionally filtered by user. Returns full scenario record (including direct columns and config JSONB). """ - with get_sync_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: - params = ["running"] - # Select all columns to get both direct and config data - command = "SELECT name, username, status, start_time, end_time, config FROM scenarios WHERE status = %s" - - if username: - command += " AND username = %s" - params.append(username) - - c.execute(command, tuple(params)) - - if get_all: - result = [dict(row) for row in c.fetchall()] # Convert DictRows to dicts - else: - result_row = c.fetchone() - result = dict(result_row) if result_row else None + conn = await get_async_conn() + try: + params = ["running"] + # Select all columns to get both direct and config data + command = "SELECT name, username, status, start_time, end_time, config FROM scenarios WHERE status = $1" + + if username: + command += " AND username = $2" + params.append(username) + + if get_all: + result = [dict(row) for row in await conn.fetch(command, *params)] # Convert records to dicts + else: + result_row = await conn.fetchrow(command, *params) + result = dict(result_row) if result_row else None + finally: + await conn.close() return result -def get_completed_scenario(): +async 
def get_completed_scenario(): """ Retrieves a single scenario with a 'completed' status. Returns full scenario record (including direct columns and config JSONB). """ - with get_sync_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: - # The status is now a direct column, not just in config->>'status' - command = "SELECT name, username, status, start_time, end_time, config FROM scenarios WHERE status = %s;" - c.execute(command, ("completed",)) - result_row = c.fetchone() - result = dict(result_row) if result_row else None + conn = await get_async_conn() + try: + # The status is now a direct column, not just in config->>'status' + command = "SELECT name, username, status, start_time, end_time, config FROM scenarios WHERE status = $1;" + result_row = await conn.fetchrow(command, "completed") + result = dict(result_row) if result_row else None + finally: + await conn.close() return result -def get_scenario_by_name(scenario_name): +async def get_scenario_by_name(scenario_name): """ Retrieves the complete record of a scenario by its name. 
""" - with get_sync_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: - c.execute("SELECT name, start_time, end_time, username, status, config FROM scenarios WHERE name = %s;", (scenario_name,)) - result_row = c.fetchone() - result = dict(result_row) if result_row else None + conn = await get_async_conn() + try: + result_row = await conn.fetchrow("SELECT name, start_time, end_time, username, status, config FROM scenarios WHERE name = $1;", scenario_name) + result = dict(result_row) if result_row else None - if result and result.get('config'): - # Assuming 'config' is already parsed into a Python dictionary by DictCursor - # If it's still a string, you might need: config_data = json.loads(result['config']) - config_data = result['config'] + if result and result.get('config'): + # Assuming 'config' is a JSON string, so we parse it + config_data = json.loads(result['config']) - # Extract the 'scenario_title' and add it as a top-level key - # Use .get() for safety in case 'scenario_title' is also missing within config - result['title'] = config_data.get('scenario_title') + # Extract the 'scenario_title' and add it as a top-level key + # Use .get() for safety in case 'scenario_title' is also missing within config + result['title'] = config_data.get('scenario_title') - # Also, if 'description' is inside config, you'll need to extract it similarly - result['description'] = config_data.get('description') # Assuming 'description' is also in config + # Also, if 'description' is inside config, you'll need to extract it similarly + result['description'] = config_data.get('description') # Assuming 'description' is also in config + finally: + await conn.close() return result -def get_user_by_scenario_name(scenario_name): +async def get_user_by_scenario_name(scenario_name): """ Retrieves the username associated with a scenario (from the direct 'username' column). 
""" - with get_sync_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: - c.execute("SELECT username FROM scenarios WHERE name = %s;", (scenario_name,)) - result = c.fetchone() - return result["username"] if result else None + conn = await get_async_conn() + try: + result = await conn.fetchval("SELECT username FROM scenarios WHERE name = $1;", scenario_name) + finally: + await conn.close() + return result -def remove_scenario_by_name(scenario_name): +async def remove_scenario_by_name(scenario_name): """ Delete a scenario from the database by its unique name. @@ -553,17 +564,17 @@ def remove_scenario_by_name(scenario_name): - Removes the scenario record matching the given name. - Commits the deletion to the database. """ + conn = await get_async_conn() try: - with get_sync_conn() as conn: - with conn.cursor() as c: - c.execute("DELETE FROM scenarios WHERE name = %s;", (scenario_name,)) - conn.commit() + await conn.execute("DELETE FROM scenarios WHERE name = $1;", scenario_name) logging.info(f"Scenario '{scenario_name}' successfully removed.") - except psycopg2.Error as e: + except asyncpg.PostgresError as e: logging.error(f"Error occurred while deleting scenario '{scenario_name}': {e}") + finally: + await conn.close() -def check_scenario_federation_completed(scenario_name): +async def check_scenario_federation_completed(scenario_name): """ Check if all nodes in a given scenario have completed the required federation rounds. @@ -579,47 +590,46 @@ def check_scenario_federation_completed(scenario_name): - Returns True only if every node has reached the total rounds. - Handles database errors and missing scenario cases gracefully. 
""" + conn = await get_async_conn() try: - with get_sync_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: - # Retrieve the total rounds for the scenario from the 'config' JSONB column - c.execute("SELECT config->>'rounds' AS rounds FROM scenarios WHERE name = %s;", (scenario_name,)) - scenario = c.fetchone() - - if not scenario or scenario["rounds"] is None: - logging.warning(f"Scenario '{scenario_name}' not found or 'rounds' not defined.") - return False - - # Ensure total_rounds is an integer for comparison - try: - total_rounds = int(scenario["rounds"]) - except (ValueError, TypeError): - logging.error(f"Invalid 'rounds' value for scenario '{scenario_name}': {scenario['rounds']}") - return False - - # Fetch the current round progress of all nodes in that scenario - # The 'round' column in 'nodes' is a direct column - c.execute("SELECT round FROM nodes WHERE scenario = %s;", (scenario_name,)) - nodes = c.fetchall() - - if not nodes: - logging.info(f"No nodes found for scenario '{scenario_name}'. Federation not considered completed.") - return False - - # Check if all nodes have completed the total rounds - # The 'round' column in nodes is likely stored as a string or a numeric type. - # Assuming 'round' in 'nodes' is a numeric type, we convert it to int for comparison. 
- return all(int(node["round"]) >= total_rounds for node in nodes) - - except psycopg2.Error as e: + # Retrieve the total rounds for the scenario from the 'config' JSONB column + scenario_rounds_str = await conn.fetchval("SELECT config->>'rounds' AS rounds FROM scenarios WHERE name = $1;", scenario_name) + + if not scenario_rounds_str: + logging.warning(f"Scenario '{scenario_name}' not found or 'rounds' not defined.") + return False + + # Ensure total_rounds is an integer for comparison + try: + total_rounds = int(scenario_rounds_str) + except (ValueError, TypeError): + logging.error(f"Invalid 'rounds' value for scenario '{scenario_name}': {scenario_rounds_str}") + return False + + # Fetch the current round progress of all nodes in that scenario + # The 'round' column in 'nodes' is a direct column + nodes = await conn.fetch("SELECT round FROM nodes WHERE scenario = $1;", scenario_name) + + if not nodes: + logging.info(f"No nodes found for scenario '{scenario_name}'. Federation not considered completed.") + return False + + # Check if all nodes have completed the total rounds + # The 'round' column in nodes is likely stored as a string or a numeric type. + # Assuming 'round' in 'nodes' is a numeric type, we convert it to int for comparison. + return all(int(node["round"]) >= total_rounds for node in nodes) + + except asyncpg.PostgresError as e: logging.error(f"PostgreSQL error during check_scenario_federation_completed for scenario '{scenario_name}': {e}") return False except ValueError as e: logging.error(f"Data error during check_scenario_federation_completed for scenario '{scenario_name}': {e}") return False + finally: + await conn.close() -def check_scenario_with_role(role, scenario_name, current_username=None): +async def check_scenario_with_role(role, scenario_name, current_username=None): """ Verify if a scenario exists that the user with the given role and username can access. 
@@ -637,7 +647,7 @@ def check_scenario_with_role(role, scenario_name, current_username=None): - If the user's role is not "admin", they can only access scenarios where the scenario's 'username' matches their `current_username`. """ - scenario_info = get_scenario_by_name(scenario_name) + scenario_info = await get_scenario_by_name(scenario_name) if not scenario_info: return False # Scenario does not exist @@ -657,46 +667,48 @@ def check_scenario_with_role(role, scenario_name, current_username=None): # --- Notes Management Functions --- -def save_notes(scenario, notes): +async def save_notes(scenario, notes): """ Save or update notes associated with a specific scenario. """ + conn = await get_async_conn() try: - with get_sync_conn() as conn: - with conn.cursor() as c: - c.execute( - """ - INSERT INTO notes (scenario, scenario_notes) VALUES (%s, %s) - ON CONFLICT(scenario) DO UPDATE SET scenario_notes = EXCLUDED.scenario_notes; - """, - (scenario, notes), - ) - conn.commit() - except psycopg2.IntegrityError as e: + await conn.execute( + """ + INSERT INTO notes (scenario, scenario_notes) VALUES ($1, $2) + ON CONFLICT(scenario) DO UPDATE SET scenario_notes = EXCLUDED.scenario_notes; + """, + scenario, notes, + ) + except asyncpg.IntegrityConstraintViolationError as e: logging.error(f"PostgreSQL integrity error during save_notes: {e}") - except psycopg2.Error as e: + except asyncpg.PostgresError as e: logging.error(f"PostgreSQL error during save_notes: {e}") + finally: + await conn.close() -def get_notes(scenario): +async def get_notes(scenario): """ Retrieve notes associated with a specific scenario. 
""" - with get_sync_conn() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as c: - c.execute("SELECT * FROM notes WHERE scenario = %s;", (scenario,)) - result = c.fetchone() + conn = await get_async_conn() + try: + result = await conn.fetchrow("SELECT * FROM notes WHERE scenario = $1;", scenario) + finally: + await conn.close() return result -def remove_note(scenario): +async def remove_note(scenario): """ Delete the note associated with a specific scenario. """ - with get_sync_conn() as conn: - with conn.cursor() as c: - c.execute("DELETE FROM notes WHERE scenario = %s;", (scenario,)) - conn.commit() + conn = await get_async_conn() + try: + await conn.execute("DELETE FROM notes WHERE scenario = $1;", scenario) + finally: + await conn.close() if __name__ == "__main__": From f438a9a52cbba20ecced8fd345c5b7c9905fcf16 Mon Sep 17 00:00:00 2001 From: FerTV Date: Thu, 10 Jul 2025 11:48:50 +0200 Subject: [PATCH 08/20] feature: persistence for postgresql added --- app/databases/__init__.py | 0 app/deployer.py | 4 ++++ 2 files changed, 4 insertions(+) delete mode 100755 app/databases/__init__.py diff --git a/app/databases/__init__.py b/app/databases/__init__.py deleted file mode 100755 index e69de29bb..000000000 diff --git a/app/deployer.py b/app/deployer.py index 27ef8e8ce..34b2160fa 100644 --- a/app/deployer.py +++ b/app/deployer.py @@ -1044,9 +1044,13 @@ def run_database(self): host_sql_path = os.path.join(self.root_path, "nebula/database/init-configs.sql") container_sql_path = "/docker-entrypoint-initdb.d/init-configs.sql" + db_data_path = os.path.join(self.databases_dir, "postgres-data") + os.makedirs(db_data_path, exist_ok=True) + host_config = client.api.create_host_config( binds=[ f"{host_sql_path}:{container_sql_path}", + f"{db_data_path}:/var/lib/postgresql/data", ], extra_hosts={"host.docker.internal": "host-gateway"}, port_bindings={5432: host_port}, From e88488716c33a7fca3c8d0a7066b6597589ecf75 Mon Sep 17 00:00:00 2001 From: FerTV Date: 
Thu, 10 Jul 2025 12:59:33 +0200 Subject: [PATCH 09/20] feature: pool added for asynchronous connections to the database --- nebula/controller/controller.py | 27 +- nebula/controller/database.py | 599 ++++++++++++++------------------ 2 files changed, 277 insertions(+), 349 deletions(-) diff --git a/nebula/controller/controller.py b/nebula/controller/controller.py index 02a5d6da3..d2ba4c7a1 100755 --- a/nebula/controller/controller.py +++ b/nebula/controller/controller.py @@ -16,8 +16,12 @@ from fastapi import Body, FastAPI, Request, status, HTTPException, Path, File, UploadFile from fastapi.concurrency import asynccontextmanager -# from nebula.controller.database import scenario_set_all_status_to_finished, scenario_set_status_to_finished -from nebula.controller.database import scenario_set_all_status_to_finished, scenario_set_status_to_finished +from nebula.controller.database import ( + init_db_pool, + close_db_pool, + scenario_set_all_status_to_finished, + scenario_set_status_to_finished, +) from nebula.controller.http_helpers import remote_get, remote_post_form from nebula.utils import DockerUtils @@ -107,17 +111,24 @@ def configure_logger(controller_log): @asynccontextmanager async def lifespan(app: FastAPI): - # databases_dir: str = os.environ.get("NEBULA_DATABASES_DIR") + """ + Application lifespan context manager. + - Initializes the database connection pool on startup. + - Configures logging. + - Cleans up resources like the database pool on shutdown. 
+ """ + # Code to run on startup controller_log: str = os.environ.get("NEBULA_CONTROLLER_LOG") - - # from nebula.controller.database import initialize_databases - - # await initialize_databases(databases_dir) - configure_logger(controller_log) + # Initialize the database connection pool + await init_db_pool() + yield + # Code to run on shutdown + await close_db_pool() + # Initialize FastAPI app outside the Controller class app = FastAPI(lifespan=lifespan) diff --git a/nebula/controller/database.py b/nebula/controller/database.py index 50d61db0d..5b01f0be2 100755 --- a/nebula/controller/database.py +++ b/nebula/controller/database.py @@ -1,13 +1,11 @@ import logging import os -import psycopg2 -import psycopg2.extras -from passlib.context import CryptContext import datetime import json import asyncpg import asyncio +from passlib.context import CryptContext from nebula.controller.scenarios import Scenario # --- Configuration --- @@ -20,12 +18,38 @@ # Asynchronous lock for node updates _node_lock = asyncio.Lock() +# --- Connection Pool Management --- +# Global pool variable, should be initialized at application startup +POOL = None -# --- Connection Management Helper Functions --- +async def init_db_pool(): + """ + Initializes the asynchronous PostgreSQL connection pool. + This should be called once when the application starts. 
+ """ + global POOL + if POOL is None: + try: + POOL = await asyncpg.create_pool( + dsn=DATABASE_URL, + min_size=5, # Minimum number of connections in the pool + max_size=20, # Maximum number of connections in the pool + ) + logging.info("Database connection pool successfully created.") + except Exception as e: + logging.critical(f"Failed to create database connection pool: {e}", exc_info=True) + # Exit or handle the failure appropriately + raise -async def get_async_conn(): - """Establishes an asynchronous PostgreSQL connection.""" - return await asyncpg.connect(DATABASE_URL) +async def close_db_pool(): + """ + Closes the asynchronous PostgreSQL connection pool. + This should be called once when the application shuts down gracefully. + """ + global POOL + if POOL: + await POOL.close() + logging.info("Database connection pool closed.") # --- User Management Functions --- @@ -34,11 +58,8 @@ async def list_users(all_info=False): """ Retrieves a list of users from the users database. """ - conn = await get_async_conn() - try: + async with POOL.acquire() as conn: result = await conn.fetch("SELECT * FROM users") - finally: - await conn.close() if not all_info: result = [user["user"] for user in result] @@ -50,30 +71,23 @@ async def get_user_info(user): """ Fetches detailed information for a specific user from the users database. """ - conn = await get_async_conn() - try: - result = await conn.fetchrow('SELECT * FROM users WHERE "user" = $1', user) - finally: - await conn.close() - return result + async with POOL.acquire() as conn: + return await conn.fetchrow('SELECT * FROM users WHERE "user" = $1', user) async def verify(user, password): """ Verifies whether the provided password matches the stored hashed password for a user. 
""" - conn = await get_async_conn() - try: + async with POOL.acquire() as conn: result = await conn.fetchrow('SELECT password FROM users WHERE "user" = $1', user) - if result: - try: - return pwd_context.verify(password, result[0]) - except Exception: - # Catch more general exceptions during verification to be safe - logging.error(f"Error during password verification for user {user}", exc_info=True) - return False - finally: - await conn.close() + if result: + try: + return pwd_context.verify(password, result[0]) + except Exception: + # Catch more general exceptions during verification to be safe + logging.error(f"Error during password verification for user {user}", exc_info=True) + return False return False @@ -83,14 +97,11 @@ async def verify_hash_algorithm(user): """ user = user.upper() argon2_prefixes = ("$argon2i$", "$argon2id$") - conn = await get_async_conn() - try: + async with POOL.acquire() as conn: result = await conn.fetchrow('SELECT password FROM users WHERE "user" = $1', user) - if result: - password_hash = result["password"] - return password_hash.startswith(argon2_prefixes) - finally: - await conn.close() + if result: + password_hash = result["password"] + return password_hash.startswith(argon2_prefixes) return False @@ -98,41 +109,32 @@ async def delete_user_from_db(user): """ Deletes a user record from the users database. """ - conn = await get_async_conn() - try: + async with POOL.acquire() as conn: await conn.execute('DELETE FROM users WHERE "user" = $1', user) - finally: - await conn.close() async def add_user(user, password, role): """ Adds a new user to the users database with a hashed password. 
""" - conn = await get_async_conn() - try: - hashed_password = pwd_context.hash(password) + hashed_password = pwd_context.hash(password) + async with POOL.acquire() as conn: await conn.execute( 'INSERT INTO users ("user", password, role) VALUES ($1, $2, $3)', user.upper(), hashed_password, role, ) - finally: - await conn.close() async def update_user(user, password, role): """ Updates the password and role of an existing user in the users database. """ - conn = await get_async_conn() - try: - hashed_password = pwd_context.hash(password) + hashed_password = pwd_context.hash(password) + async with POOL.acquire() as conn: await conn.execute( 'UPDATE users SET password = $1, role = $2 WHERE "user" = $3', hashed_password, role, user.upper(), ) - finally: - await conn.close() # --- Node Management Functions --- @@ -140,45 +142,38 @@ async def list_nodes(scenario_name=None, sort_by="idx"): """ Retrieves a list of nodes from the nodes database, optionally filtered by scenario and sorted. """ - conn = await get_async_conn() - try: - # Validate sort_by to prevent SQL injection - allowed_sort_fields = ["uid", "idx", "ip", "port", "role", "timestamp", "federation", "round"] - if sort_by not in allowed_sort_fields: - sort_by = "idx" # Default to a safe field - - if scenario_name: - # Using f-string for column names is generally safe if validated as above - command = f"SELECT * FROM nodes WHERE scenario = $1 ORDER BY {sort_by};" - result = await conn.fetch(command, scenario_name) - else: - command = f"SELECT * FROM nodes ORDER BY {sort_by};" - result = await conn.fetch(command) + # Validate sort_by to prevent SQL injection + allowed_sort_fields = ["uid", "idx", "ip", "port", "role", "timestamp", "federation", "round"] + if sort_by not in allowed_sort_fields: + sort_by = "idx" # Default to a safe field - return result + try: + async with POOL.acquire() as conn: + if scenario_name: + # Using f-string for column names is generally safe if validated as above + command = f"SELECT * 
FROM nodes WHERE scenario = $1 ORDER BY {sort_by};" + result = await conn.fetch(command, scenario_name) + else: + command = f"SELECT * FROM nodes ORDER BY {sort_by};" + result = await conn.fetch(command) + return result except asyncpg.PostgresError as e: logging.error(f"Error occurred while listing nodes: {e}") return None - finally: - await conn.close() async def list_nodes_by_scenario_name(scenario_name): """ Fetches all nodes associated with a specific scenario, ordered by their index as integers. """ - conn = None try: - conn = await get_async_conn() - command = "SELECT * FROM nodes WHERE scenario = $1 ORDER BY CAST(idx AS INTEGER) ASC;" - result = await conn.fetch(command, scenario_name) - return [dict(record) for record in result] + async with POOL.acquire() as conn: + command = "SELECT * FROM nodes WHERE scenario = $1 ORDER BY CAST(idx AS INTEGER) ASC;" + result = await conn.fetch(command, scenario_name) + return [dict(record) for record in result] except Exception as e: logging.error(f"Error occurred while listing nodes by scenario name: {e}") return None - finally: - if conn: - await conn.close() async def update_node_record( @@ -189,67 +184,60 @@ async def update_node_record( Inserts or updates a node record in the database for a given scenario, ensuring thread-safe access. 
""" async with _node_lock: - # Await the get_async_conn() call to get the actual connection object - conn = await get_async_conn() - try: - async with conn.transaction(): - result = await conn.fetchrow( - "SELECT * FROM nodes WHERE uid = $1 AND scenario = $2 FOR UPDATE;", - node_uid, scenario - ) - - if result is None: - # Insert new node - await conn.execute( - """ - INSERT INTO nodes (uid, idx, ip, port, role, neighbors, latitude, longitude, - timestamp, federation, round, scenario, hash, malicious) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14); - """, - node_uid, idx, ip, port, role, neighbors, latitude, longitude, - timestamp, federation, federation_round, scenario, run_hash, malicious, - ) - else: - # Update existing node - await conn.execute( - """ - UPDATE nodes SET idx = $1, ip = $2, port = $3, role = $4, neighbors = $5, - latitude = $6, longitude = $7, timestamp = $8, federation = $9, round = $10, - hash = $11, malicious = $12 - WHERE uid = $13 AND scenario = $14; - """, - idx, ip, port, role, neighbors, latitude, longitude, - timestamp, federation, federation_round, run_hash, malicious, - node_uid, scenario, + async with POOL.acquire() as conn: + try: + async with conn.transaction(): + result = await conn.fetchrow( + "SELECT * FROM nodes WHERE uid = $1 AND scenario = $2 FOR UPDATE;", + node_uid, scenario ) - updated_row = await conn.fetchrow("SELECT * from nodes WHERE uid = $1 AND scenario = $2;", node_uid, scenario) - return dict(updated_row) if updated_row else None - finally: - # Ensure the connection is closed after use - await conn.close() + if result is None: + # Insert new node + await conn.execute( + """ + INSERT INTO nodes (uid, idx, ip, port, role, neighbors, latitude, longitude, + timestamp, federation, round, scenario, hash, malicious) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14); + """, + node_uid, idx, ip, port, role, neighbors, latitude, longitude, + timestamp, federation, federation_round, 
scenario, run_hash, malicious, + ) + else: + # Update existing node + await conn.execute( + """ + UPDATE nodes SET idx = $1, ip = $2, port = $3, role = $4, neighbors = $5, + latitude = $6, longitude = $7, timestamp = $8, federation = $9, round = $10, + hash = $11, malicious = $12 + WHERE uid = $13 AND scenario = $14; + """, + idx, ip, port, role, neighbors, latitude, longitude, + timestamp, federation, federation_round, run_hash, malicious, + node_uid, scenario, + ) + + updated_row = await conn.fetchrow("SELECT * from nodes WHERE uid = $1 AND scenario = $2;", node_uid, scenario) + return dict(updated_row) if updated_row else None + except asyncpg.PostgresError as e: + logging.error(f"Database error during node record update: {e}", exc_info=True) + return None async def remove_all_nodes(): """ Deletes all node records from the nodes database. """ - conn = await get_async_conn() - try: + async with POOL.acquire() as conn: await conn.execute("TRUNCATE nodes CASCADE;") # Use CASCADE if there are foreign key dependencies - finally: - await conn.close() async def remove_nodes_by_scenario_name(scenario_name): """ Deletes all nodes associated with a specific scenario from the database. 
""" - conn = await get_async_conn() - try: + async with POOL.acquire() as conn: await conn.execute("DELETE FROM nodes WHERE scenario = $1;", scenario_name) - finally: - await conn.close() # --- Scenario Management Functions --- @@ -270,15 +258,14 @@ async def get_all_scenarios(username, role, sort_by="start_time"): WHEN start_time IS NULL OR start_time = '' THEN 1 ELSE 0 END, - to_timestamp(start_time, 'YYYY/MM/DD HH24:MI:SS') DESC + to_timestamp(start_time, 'DD/MM/YYYY HH24:MI:SS') DESC """ elif sort_by in ["title", "model", "dataset", "rounds"]: # These are inside config JSONB order_by_clause = f"ORDER BY config->>'{sort_by}'" else: # For direct table columns like name, username, status order_by_clause = f"ORDER BY {sort_by}" - conn = await get_async_conn() - try: + async with POOL.acquire() as conn: # Select direct columns and relevant fields from config JSONB command = """ SELECT @@ -301,11 +288,7 @@ async def get_all_scenarios(username, role, sort_by="start_time"): params.append(username) full_command = f"{command} {order_by_clause};" - result = await conn.fetch(full_command, *params) - finally: - await conn.close() - - return result + return await conn.fetch(full_command, *params) async def get_all_scenarios_and_check_completed(username, role, sort_by="start_time"): @@ -318,7 +301,7 @@ async def get_all_scenarios_and_check_completed(username, role, sort_by="start_t if sort_by not in allowed_sort_fields: sort_by = "start_time" # Safe default value - # Building the ORDER BY clause (same as get_all_scenarios) + # Building the ORDER BY clause if sort_by == "start_time": order_by_clause = """ ORDER BY @@ -326,7 +309,6 @@ async def get_all_scenarios_and_check_completed(username, role, sort_by="start_t WHEN start_time IS NULL OR start_time = '' THEN 1 ELSE 0 END, - -- CORRECTED: Changed 'DD/MM/YYYY' to 'YYYY/MM/DD' to match the storage format to_timestamp(start_time, 'DD/MM/YYYY HH24:MI:SS') DESC """ elif sort_by in ["title", "model", "dataset", "rounds"]: # These 
are inside config JSONB @@ -334,8 +316,7 @@ async def get_all_scenarios_and_check_completed(username, role, sort_by="start_t else: # For direct table columns like name, username, status order_by_clause = f"ORDER BY {sort_by}" - conn = await get_async_conn() - try: + async with POOL.acquire() as conn: # Base query that extracts fields from the JSONB using the ->> operator command = f""" SELECT @@ -360,21 +341,19 @@ async def get_all_scenarios_and_check_completed(username, role, sort_by="start_t result_dicts = await conn.fetch(command, *params) - scenarios_to_return = [dict(s) for s in result_dicts] + scenarios_to_return = [dict(s) for s in result_dicts] - re_fetch_required = False - for scenario in scenarios_to_return: - if scenario["status"] == "running": - if await check_scenario_federation_completed(scenario["name"]): - await scenario_set_status_to_completed(scenario["name"]) - re_fetch_required = True - break + re_fetch_required = False + for scenario in scenarios_to_return: + if scenario["status"] == "running": + if await check_scenario_federation_completed(scenario["name"]): + await scenario_set_status_to_completed(scenario["name"]) + re_fetch_required = True + break - if re_fetch_required: - # Recursively call get_all_scenarios_and_check_completed to get fresh data - return await get_all_scenarios_and_check_completed(username, role, sort_by) - finally: - await conn.close() + if re_fetch_required: + # Recursively call to get fresh data after status update + return await get_all_scenarios_and_check_completed(username, role, sort_by) return scenarios_to_return @@ -385,29 +364,26 @@ async def scenario_update_record(name, start_time, end_time, scenario_config, st All configuration is saved in the 'config' column of type JSONB. Direct columns (name, start_time, end_time, username, status) are also handled. 
""" - conn = await get_async_conn() - try: - # Ensure scenario_config is a dictionary before dumping to JSON - if not isinstance(scenario_config, dict): - try: - scenario_config = json.loads(scenario_config) - except json.JSONDecodeError: - logging.error("scenario_config is not a valid JSON string or dict.") - return - - command = """ - INSERT INTO scenarios (name, start_time, end_time, username, status, config) - VALUES ($1, $2, $3, $4, $5, $6::jsonb) - ON CONFLICT (name) DO UPDATE SET - start_time = EXCLUDED.start_time, - end_time = EXCLUDED.end_time, - username = EXCLUDED.username, - status = EXCLUDED.status, - config = scenarios.config || EXCLUDED.config; -- Merge JSONB - """ + # Ensure scenario_config is a dictionary before dumping to JSON + if not isinstance(scenario_config, dict): + try: + scenario_config = json.loads(scenario_config) + except (json.JSONDecodeError, TypeError): + logging.error("scenario_config is not a valid JSON string or dict.") + return + + command = """ + INSERT INTO scenarios (name, start_time, end_time, username, status, config) + VALUES ($1, $2, $3, $4, $5, $6::jsonb) + ON CONFLICT (name) DO UPDATE SET + start_time = EXCLUDED.start_time, + end_time = EXCLUDED.end_time, + username = EXCLUDED.username, + status = EXCLUDED.status, + config = scenarios.config || EXCLUDED.config; -- Merge JSONB + """ + async with POOL.acquire() as conn: await conn.execute(command, name, start_time, end_time, username, status, json.dumps(scenario_config)) - finally: - await conn.close() async def scenario_set_all_status_to_finished(): @@ -415,22 +391,18 @@ async def scenario_set_all_status_to_finished(): Sets the status of all 'running' scenarios to 'finished' and updates their 'end_time' (both in the direct column and within JSONB). 
""" - conn = await get_async_conn() - try: - current_time = datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S') # Consistent format - # Update direct columns first, then update JSONB within config - command = """ - UPDATE scenarios - SET - status = 'finished', - end_time = $1, - config = jsonb_set(config, '{status}', '"finished"') || - jsonb_set(config, '{end_time}', $2::jsonb) - WHERE status = 'running'; - """ + current_time = datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S') # Consistent format + command = """ + UPDATE scenarios + SET + status = 'finished', + end_time = $1, + config = jsonb_set(config, '{status}', '"finished"') || + jsonb_set(config, '{end_time}', $2::jsonb) + WHERE status = 'running'; + """ + async with POOL.acquire() as conn: await conn.execute(command, current_time, json.dumps(current_time)) - finally: - await conn.close() async def scenario_set_status_to_finished(scenario_name): @@ -438,23 +410,20 @@ async def scenario_set_status_to_finished(scenario_name): Sets the status of a specific scenario to 'finished' and updates its 'end_time'. Updates both the direct columns and the JSONB 'config'. 
""" - conn = await get_async_conn() - try: - current_time = datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S') # Consistent format - command = """ - UPDATE scenarios - SET - status = 'finished', - end_time = $1, - config = jsonb_set( - jsonb_set(config, '{status}', '"finished"'), - '{end_time}', $2::jsonb - ) - WHERE name = $3; - """ + current_time = datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S') # Consistent format + command = """ + UPDATE scenarios + SET + status = 'finished', + end_time = $1, + config = jsonb_set( + jsonb_set(config, '{status}', '"finished"'), + '{end_time}', $2::jsonb + ) + WHERE name = $3; + """ + async with POOL.acquire() as conn: await conn.execute(command, current_time, json.dumps(current_time), scenario_name) - finally: - await conn.close() async def scenario_set_status_to_completed(scenario_name): @@ -462,18 +431,15 @@ async def scenario_set_status_to_completed(scenario_name): Sets the status of a specific scenario to 'completed'. Updates both the direct column and the JSONB 'config'. """ - conn = await get_async_conn() - try: - command = """ - UPDATE scenarios - SET - status = 'completed', - config = jsonb_set(config, '{status}', '"completed"') - WHERE name = $1; - """ + command = """ + UPDATE scenarios + SET + status = 'completed', + config = jsonb_set(config, '{status}', '"completed"') + WHERE name = $1; + """ + async with POOL.acquire() as conn: await conn.execute(command, scenario_name) - finally: - await conn.close() async def get_running_scenario(username=None, get_all=False): @@ -481,8 +447,7 @@ async def get_running_scenario(username=None, get_all=False): Retrieves scenarios with a 'running' status, optionally filtered by user. Returns full scenario record (including direct columns and config JSONB). 
""" - conn = await get_async_conn() - try: + async with POOL.acquire() as conn: params = ["running"] # Select all columns to get both direct and config data command = "SELECT name, username, status, start_time, end_time, config FROM scenarios WHERE status = $1" @@ -496,8 +461,6 @@ async def get_running_scenario(username=None, get_all=False): else: result_row = await conn.fetchrow(command, *params) result = dict(result_row) if result_row else None - finally: - await conn.close() return result @@ -506,38 +469,35 @@ async def get_completed_scenario(): Retrieves a single scenario with a 'completed' status. Returns full scenario record (including direct columns and config JSONB). """ - conn = await get_async_conn() - try: - # The status is now a direct column, not just in config->>'status' + async with POOL.acquire() as conn: command = "SELECT name, username, status, start_time, end_time, config FROM scenarios WHERE status = $1;" result_row = await conn.fetchrow(command, "completed") - result = dict(result_row) if result_row else None - finally: - await conn.close() - return result + return dict(result_row) if result_row else None async def get_scenario_by_name(scenario_name): """ Retrieves the complete record of a scenario by its name. 
""" - conn = await get_async_conn() - try: + async with POOL.acquire() as conn: result_row = await conn.fetchrow("SELECT name, start_time, end_time, username, status, config FROM scenarios WHERE name = $1;", scenario_name) - result = dict(result_row) if result_row else None - if result and result.get('config'): - # Assuming 'config' is a JSON string, so we parse it - config_data = json.loads(result['config']) + result = dict(result_row) if result_row else None + + if result and result.get('config'): + # Assuming 'config' is a JSON string from the DB, so we parse it + # It might already be a dict if asyncpg handles JSONB conversion automatically + config_data = result['config'] + if isinstance(config_data, str): + try: + config_data = json.loads(config_data) + except json.JSONDecodeError: + config_data = {} - # Extract the 'scenario_title' and add it as a top-level key - # Use .get() for safety in case 'scenario_title' is also missing within config - result['title'] = config_data.get('scenario_title') + # Extract the 'scenario_title' and add it as a top-level key + result['title'] = config_data.get('scenario_title') + result['description'] = config_data.get('description') - # Also, if 'description' is inside config, you'll need to extract it similarly - result['description'] = config_data.get('description') # Assuming 'description' is also in config - finally: - await conn.close() return result @@ -545,107 +505,63 @@ async def get_user_by_scenario_name(scenario_name): """ Retrieves the username associated with a scenario (from the direct 'username' column). 
""" - conn = await get_async_conn() - try: - result = await conn.fetchval("SELECT username FROM scenarios WHERE name = $1;", scenario_name) - finally: - await conn.close() - return result + async with POOL.acquire() as conn: + return await conn.fetchval("SELECT username FROM scenarios WHERE name = $1;", scenario_name) async def remove_scenario_by_name(scenario_name): """ Delete a scenario from the database by its unique name. - - Parameters: - scenario_name (str): The unique name identifier of the scenario to be removed. - - Behavior: - - Removes the scenario record matching the given name. - - Commits the deletion to the database. """ - conn = await get_async_conn() try: - await conn.execute("DELETE FROM scenarios WHERE name = $1;", scenario_name) + async with POOL.acquire() as conn: + await conn.execute("DELETE FROM scenarios WHERE name = $1;", scenario_name) logging.info(f"Scenario '{scenario_name}' successfully removed.") except asyncpg.PostgresError as e: logging.error(f"Error occurred while deleting scenario '{scenario_name}': {e}") - finally: - await conn.close() async def check_scenario_federation_completed(scenario_name): """ Check if all nodes in a given scenario have completed the required federation rounds. - - Parameters: - scenario_name (str): The unique name identifier of the scenario to check. - - Returns: - bool: True if all nodes have completed the total rounds specified for the scenario, False otherwise or if an error occurs. - - Behavior: - - Retrieves the total number of rounds defined for the scenario. - - Fetches the current round progress of all nodes in that scenario. - - Returns True only if every node has reached the total rounds. - - Handles database errors and missing scenario cases gracefully. 
""" - conn = await get_async_conn() try: - # Retrieve the total rounds for the scenario from the 'config' JSONB column - scenario_rounds_str = await conn.fetchval("SELECT config->>'rounds' AS rounds FROM scenarios WHERE name = $1;", scenario_name) + async with POOL.acquire() as conn: + # Retrieve the total rounds for the scenario from the 'config' JSONB column + scenario_rounds_str = await conn.fetchval("SELECT config->>'rounds' AS rounds FROM scenarios WHERE name = $1;", scenario_name) - if not scenario_rounds_str: - logging.warning(f"Scenario '{scenario_name}' not found or 'rounds' not defined.") - return False + if not scenario_rounds_str: + logging.warning(f"Scenario '{scenario_name}' not found or 'rounds' not defined.") + return False - # Ensure total_rounds is an integer for comparison - try: - total_rounds = int(scenario_rounds_str) - except (ValueError, TypeError): - logging.error(f"Invalid 'rounds' value for scenario '{scenario_name}': {scenario_rounds_str}") - return False + # Ensure total_rounds is an integer for comparison + try: + total_rounds = int(scenario_rounds_str) + except (ValueError, TypeError): + logging.error(f"Invalid 'rounds' value for scenario '{scenario_name}': {scenario_rounds_str}") + return False - # Fetch the current round progress of all nodes in that scenario - # The 'round' column in 'nodes' is a direct column - nodes = await conn.fetch("SELECT round FROM nodes WHERE scenario = $1;", scenario_name) + # Fetch the current round progress of all nodes in that scenario + nodes = await conn.fetch("SELECT round FROM nodes WHERE scenario = $1;", scenario_name) - if not nodes: - logging.info(f"No nodes found for scenario '{scenario_name}'. Federation not considered completed.") - return False + if not nodes: + logging.info(f"No nodes found for scenario '{scenario_name}'. 
Federation not considered completed.") + return False - # Check if all nodes have completed the total rounds - # The 'round' column in nodes is likely stored as a string or a numeric type. - # Assuming 'round' in 'nodes' is a numeric type, we convert it to int for comparison. - return all(int(node["round"]) >= total_rounds for node in nodes) + # Check if all nodes have completed the total rounds + return all(int(node["round"]) >= total_rounds for node in nodes) except asyncpg.PostgresError as e: - logging.error(f"PostgreSQL error during check_scenario_federation_completed for scenario '{scenario_name}': {e}") + logging.error(f"PostgreSQL error during check_scenario_federation_completed for '{scenario_name}': {e}") return False except ValueError as e: - logging.error(f"Data error during check_scenario_federation_completed for scenario '{scenario_name}': {e}") + logging.error(f"Data error during check_scenario_federation_completed for '{scenario_name}': {e}") return False - finally: - await conn.close() async def check_scenario_with_role(role, scenario_name, current_username=None): """ Verify if a scenario exists that the user with the given role and username can access. - - Parameters: - role (str): The role of the current user (e.g., "admin", "user"). - scenario_name (str): The unique name identifier of the scenario to check. - current_username (str, optional): The username of the currently authenticated user. - Required for non-admin roles. - - Returns: - bool: True if the scenario exists and the user has access, False otherwise. - - Behavior: - - If the user's role is "admin", they can access any existing scenario. - - If the user's role is not "admin", they can only access scenarios where the - scenario's 'username' matches their `current_username`. 
""" scenario_info = await get_scenario_by_name(scenario_name) @@ -654,16 +570,14 @@ async def check_scenario_with_role(role, scenario_name, current_username=None): if role == "admin": return True # Admins can access any existing scenario - else: - # For non-admin roles, check if the scenario's username matches the current user's username - if current_username is None: - logging.warning( - "`check_scenario_with_role` called for non-admin role without `current_username`. " - "Cannot verify user-specific scenario access." - ) - return False # Cannot verify access without the current user's username - return scenario_info.get("username") == current_username + if current_username is None: + logging.warning( + "check_scenario_with_role called for non-admin role without current_username." + ) + return False + + return scenario_info.get("username") == current_username # --- Notes Management Functions --- @@ -671,54 +585,57 @@ async def save_notes(scenario, notes): """ Save or update notes associated with a specific scenario. """ - conn = await get_async_conn() try: - await conn.execute( - """ - INSERT INTO notes (scenario, scenario_notes) VALUES ($1, $2) - ON CONFLICT(scenario) DO UPDATE SET scenario_notes = EXCLUDED.scenario_notes; - """, - scenario, notes, - ) - except asyncpg.IntegrityConstraintViolationError as e: - logging.error(f"PostgreSQL integrity error during save_notes: {e}") + async with POOL.acquire() as conn: + await conn.execute( + """ + INSERT INTO notes (scenario, scenario_notes) VALUES ($1, $2) + ON CONFLICT(scenario) DO UPDATE SET scenario_notes = EXCLUDED.scenario_notes; + """, + scenario, notes, + ) except asyncpg.PostgresError as e: logging.error(f"PostgreSQL error during save_notes: {e}") - finally: - await conn.close() async def get_notes(scenario): """ Retrieve notes associated with a specific scenario. 
""" - conn = await get_async_conn() - try: - result = await conn.fetchrow("SELECT * FROM notes WHERE scenario = $1;", scenario) - finally: - await conn.close() - return result + async with POOL.acquire() as conn: + return await conn.fetchrow("SELECT * FROM notes WHERE scenario = $1;", scenario) async def remove_note(scenario): """ Delete the note associated with a specific scenario. """ - conn = await get_async_conn() - try: + async with POOL.acquire() as conn: await conn.execute("DELETE FROM notes WHERE scenario = $1;", scenario) - finally: - await conn.close() if __name__ == "__main__": - """ - Entry point for the script to print the list of users. - """ - # Example usage (assuming DB_USER, DB_PASSWORD, DB_HOST, DB_PORT are set in env) - # os.environ['DB_USER'] = 'your_db_user' - # os.environ['DB_PASSWORD'] = 'your_db_password' - # os.environ['DB_HOST'] = 'localhost' - # os.environ['DB_PORT'] = '5432' + # This is an example of how to use the new pool-based functions. + # In a real application, init_db_pool() would be called at startup, + # and close_db_pool() at shutdown. 
+ + async def main(): + # Set environment variables for local testing if not already set + os.environ.setdefault('DB_USER', 'your_user') + os.environ.setdefault('DB_PASSWORD', 'your_password') + os.environ.setdefault('DB_HOST', 'localhost') + os.environ.setdefault('DB_PORT', '5432') + + logging.basicConfig(level=logging.INFO) + + await init_db_pool() + try: + # Example: list all users + users = await list_users() + logging.info(f"Found users: {users}") + finally: + await close_db_pool() - logging.basicConfig(level=logging.INFO) + # To run this example: + # asyncio.run(main()) + pass From 0d5b11fb902e4dc222e5ad9de6e75c780759a230 Mon Sep 17 00:00:00 2001 From: FerTV Date: Thu, 10 Jul 2025 17:55:18 +0200 Subject: [PATCH 10/20] feature: redis implemented --- app/deployer.py | 4 +++- nebula/controller/controller.py | 36 +++++++++++++++++++++++++++------ nebula/controller/database.py | 26 ++++++++++++++++++++++++ pyproject.toml | 1 + 4 files changed, 60 insertions(+), 7 deletions(-) diff --git a/app/deployer.py b/app/deployer.py index 34b2160fa..7763d6a8d 100644 --- a/app/deployer.py +++ b/app/deployer.py @@ -1142,7 +1142,7 @@ def run_database(self): host_port_commander = 8081 environment_commander = { - "REDIS_HOSTS": "local:redis:6379", + "REDIS_HOSTS": f"{os.environ['USER']}_redis", "HTTP_USER": "root", "HTTP_PASSWORD": os.environ.get("HTTP_PASSWORD"), } @@ -1228,6 +1228,8 @@ def run_controller(self): "DB_PORT": 5432, "DB_USER": "nebula", "DB_PASSWORD": "nebula", + "REDIS_HOST": f"{os.environ['USER']}_redis", + "REDIS_PORT": 6379, } volumes = ["/nebula", "/var/run/docker.sock"] diff --git a/nebula/controller/controller.py b/nebula/controller/controller.py index d2ba4c7a1..6445312d7 100755 --- a/nebula/controller/controller.py +++ b/nebula/controller/controller.py @@ -19,6 +19,7 @@ from nebula.controller.database import ( init_db_pool, close_db_pool, + init_redis_pool, scenario_set_all_status_to_finished, scenario_set_status_to_finished, ) @@ -123,6 +124,7 @@ async def 
lifespan(app: FastAPI): # Initialize the database connection pool await init_db_pool() + await init_redis_pool() yield @@ -406,12 +408,18 @@ async def remove_scenario( Returns: dict: A message indicating successful removal. """ - from nebula.controller.database import remove_scenario_by_name + from nebula.controller.database import remove_scenario_by_name, get_user_by_scenario_name, REDIS_POOL from nebula.controller.scenarios import ScenarioManagement try: + user = await get_user_by_scenario_name(scenario_name) await remove_scenario_by_name(scenario_name) ScenarioManagement.remove_files_by_scenario(scenario_name) + # Invalidate caches + if user: + await REDIS_POOL.delete(f"scenarios:{user['user']}:{user['role']}") + await REDIS_POOL.delete(f"scenario:{scenario_name}") + except Exception as e: logging.exception(f"Error removing scenario {scenario_name}: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -434,10 +442,17 @@ async def get_scenarios( Returns: dict: A list of scenarios and the currently running scenario. """ - from nebula.controller.database import get_all_scenarios_and_check_completed, get_running_scenario + from nebula.controller.database import get_all_scenarios_and_check_completed, get_running_scenario, REDIS_POOL try: - scenarios = await get_all_scenarios_and_check_completed(username=user, role=role) + # Try to get from cache first + cached_scenarios = await REDIS_POOL.get(f"scenarios:{user}:{role}") + if cached_scenarios: + scenarios = json.loads(cached_scenarios) + else: + scenarios = await get_all_scenarios_and_check_completed(username=user, role=role) + await REDIS_POOL.set(f"scenarios:{user}:{role}", json.dumps(scenarios), ex=3600) # Cache for 1 hour + if role == "admin": scenario_running = await get_running_scenario() else: @@ -474,11 +489,14 @@ async def update_scenario( Returns: dict: A message confirming the update. 
""" - from nebula.controller.database import scenario_update_record + from nebula.controller.database import scenario_update_record, REDIS_POOL # from nebula.controller.scenarios import Scenario try: await scenario_update_record(scenario_name, start_time, end_time, scenario, status, username) + # Invalidate caches + await REDIS_POOL.delete(f"scenarios:{username}:{role}") + await REDIS_POOL.delete(f"scenario:{scenario_name}") except Exception as e: logging.exception(f"Error updating scenario {scenario_name}: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -576,10 +594,16 @@ async def get_scenario_by_name( Returns: dict: The scenario data. """ - from nebula.controller.database import get_scenario_by_name + from nebula.controller.database import get_scenario_by_name, REDIS_POOL try: - scenario = await get_scenario_by_name(scenario_name) + cached_scenario = await REDIS_POOL.get(f"scenario:{scenario_name}") + if cached_scenario: + scenario = json.loads(cached_scenario) + else: + scenario = await get_scenario_by_name(scenario_name) + if scenario: + await REDIS_POOL.set(f"scenario:{scenario_name}", json.dumps(scenario), ex=3600) # Cache for 1 hour except Exception as e: logging.exception(f"Error obtaining scenario {scenario_name}: {e}") raise HTTPException(status_code=500, detail="Internal server error") diff --git a/nebula/controller/database.py b/nebula/controller/database.py index 5b01f0be2..a07205388 100755 --- a/nebula/controller/database.py +++ b/nebula/controller/database.py @@ -4,6 +4,7 @@ import json import asyncpg import asyncio +import redis.asyncio as aioredis from passlib.context import CryptContext from nebula.controller.scenarios import Scenario @@ -11,6 +12,7 @@ # --- Configuration --- # Use environment variables for database credentials from the Docker Compose file DATABASE_URL = f"postgresql://{os.environ.get('DB_USER')}:{os.environ.get('DB_PASSWORD')}@{os.environ.get('DB_HOST')}:{os.environ.get('DB_PORT')}/nebula" 
+REDIS_URL = f"redis://{os.environ.get('REDIS_HOST', 'localhost')}:{os.environ.get('REDIS_PORT', 6379)}" # Password hashing context (using Argon2) pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") @@ -21,6 +23,29 @@ # --- Connection Pool Management --- # Global pool variable, should be initialized at application startup POOL = None +REDIS_POOL = None + +async def init_redis_pool(): + """ + Initializes the asynchronous Redis connection pool. + """ + global REDIS_POOL + if REDIS_POOL is None: + try: + REDIS_POOL = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True) + logging.info("Redis connection pool successfully created.") + except Exception as e: + logging.critical(f"Failed to create Redis connection pool: {e}", exc_info=True) + raise + +async def close_redis_pool(): + """ + Closes the asynchronous Redis connection pool. + """ + global REDIS_POOL + if REDIS_POOL: + await REDIS_POOL.close() + logging.info("Redis connection pool closed.") async def init_db_pool(): """ @@ -50,6 +75,7 @@ async def close_db_pool(): if POOL: await POOL.close() logging.info("Database connection pool closed.") + await close_redis_pool() # --- User Management Functions --- diff --git a/pyproject.toml b/pyproject.toml index 7c7e666b5..bcd0d1920 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,7 @@ controller = [ "seaborn==0.13.2", "scikit-image==0.24.0", "scikit-learn==1.5.1", + "redis==5.0.7", ] database = [ "asyncpg==0.30.0", From 90458ffd4702b63fca4fbb7e525761a058daa38d Mon Sep 17 00:00:00 2001 From: FerTV Date: Thu, 10 Jul 2025 19:00:08 +0200 Subject: [PATCH 11/20] fix: docker containers tags and removal of containers --- app/deployer.py | 169 ++++++++++---------------------- nebula/controller/controller.py | 25 +++-- 2 files changed, 69 insertions(+), 125 deletions(-) diff --git a/app/deployer.py b/app/deployer.py index 7763d6a8d..6729a88fa 100644 --- a/app/deployer.py +++ b/app/deployer.py @@ -601,7 +601,7 @@ def __init__(self, 
args): ) logging.exception(warning_msg) sys.exit(1) - + self.configure_logger() self.credentialmanager = CredentialManager() self.credentialmanager.check_all_credentials() @@ -945,7 +945,6 @@ def run_frontend(self): client = docker.from_env() environment = { - "NEBULA_CONTROLLER_NAME": os.environ["USER"], "SECRET_KEY": os.environ.get("SECRET_KEY"), "NEBULA_PRODUCTION": self.production, "NEBULA_ENV_TAG": self.env_tag, @@ -1021,167 +1020,105 @@ def run_database(self): "/var/run/docker.sock not found, please check if Docker is running and Docker Compose is installed." ) - network_name = f"{os.environ['USER']}_nebula-net-base" - - ############### - # POSTGRES DB # - ############### - - host_port = 54312 - - # Create the Docker network + network_name = self.get_network_name("net-base") base = DockerUtils.create_docker_network(network_name) + Deployer._add_network_to_metadata(network_name) client = docker.from_env() - environment = { + # --- PostgreSQL --- + pg_container_name = self.get_container_name("nebula-database") + pg_environment = { "POSTGRES_USER": "nebula", "POSTGRES_PASSWORD": os.environ.get("POSTGRES_PASSWORD"), "POSTGRES_DB": "nebula", - "NEBULA_DATABASES_PORT": self.controller_host, } - host_sql_path = os.path.join(self.root_path, "nebula/database/init-configs.sql") - container_sql_path = "/docker-entrypoint-initdb.d/init-configs.sql" - db_data_path = os.path.join(self.databases_dir, "postgres-data") os.makedirs(db_data_path, exist_ok=True) - host_config = client.api.create_host_config( + pg_host_config = client.api.create_host_config( binds=[ - f"{host_sql_path}:{container_sql_path}", + f"{host_sql_path}:/docker-entrypoint-initdb.d/init-configs.sql", f"{db_data_path}:/var/lib/postgresql/data", ], - extra_hosts={"host.docker.internal": "host-gateway"}, - port_bindings={5432: host_port}, + port_bindings={5432: 5432}, + ) + pg_networking_config = client.api.create_networking_config( + {f"{network_name}": 
client.api.create_endpoint_config(ipv4_address=f"{base}.125")} ) - networking_config = client.api.create_networking_config({ - f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.125") - }) - - container_id = client.api.create_container( + pg_container = client.api.create_container( image="nebula-database", - name=f"{os.environ['USER']}_nebula-database", + name=pg_container_name, detach=True, - environment=environment, - host_config=host_config, - networking_config=networking_config, + environment=pg_environment, + host_config=pg_host_config, + networking_config=pg_networking_config, ) - - client.api.start(container_id) - - ################ - # POSTGRES WEB # - ################ - - pgweb_host_port = 8085 - pgweb_container_port = 8081 - - pgweb_host_config = client.api.create_host_config( - port_bindings={pgweb_container_port: pgweb_host_port}, - device_requests=[{ - "Driver": "nvidia", - "Count": -1, - "Capabilities": [["gpu"]], - }] if self.gpu_available else None, + client.api.start(pg_container) + Deployer._add_container_to_metadata(pg_container_name) + + # --- PGWeb --- + pgweb_container_name = self.get_container_name("nebula-pgweb") + pgweb_host_config = client.api.create_host_config(port_bindings={8081: 8085}) + pgweb_networking_config = client.api.create_networking_config( + {f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.135")} ) - pgweb_networking_config = client.api.create_networking_config({ - f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.135") - }) - - pgweb_container_name = f"{os.environ.get('USER')}_nebula-pgweb" - - pgweb_container_id = client.api.create_container( + pgweb_container = client.api.create_container( image="nebula-pgweb", name=pgweb_container_name, detach=True, host_config=pgweb_host_config, networking_config=pgweb_networking_config, ) - - client.api.start(pgweb_container_id) - - ######### - # REDIS # - ######### - - # network_name = "redis-network" - # base 
= DockerUtils.create_docker_network(network_name) - - # --- REDIS --- - - host_port_redis = 6379 # You can change this if you want a different external port - - host_config_redis = client.api.create_host_config( - binds=[ - f"redis:/var/lib/redis", - ], - extra_hosts={"host.docker.internal": "host-gateway"}, - port_bindings={6379: host_port_redis}, + client.api.start(pgweb_container) + Deployer._add_container_to_metadata(pgweb_container_name) + + # --- Redis --- + redis_container_name = self.get_container_name("redis") + redis_host_config = client.api.create_host_config( + binds=[f"{self.get_container_name('redis_data')}:/var/lib/redis"], + port_bindings={6379: 6379}, + ) + redis_networking_config = client.api.create_networking_config( + {f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.126")} ) - - redis_networking_config = client.api.create_networking_config({ - f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.126") - }) redis_container = client.api.create_container( image="nebula-redis", - name=f"{os.environ['USER']}_redis", + name=redis_container_name, detach=True, command="redis-server", - host_config=host_config_redis, + host_config=redis_host_config, networking_config=redis_networking_config, ) - client.api.start(redis_container) + Deployer._add_container_to_metadata(redis_container_name) - # --- REDIS COMMANDER --- - - host_port_commander = 8081 - - environment_commander = { - "REDIS_HOSTS": f"{os.environ['USER']}_redis", + # --- Redis Commander --- + commander_container_name = self.get_container_name("redis_commander") + commander_environment = { + "REDIS_HOSTS": f"local:{redis_container_name}:6379", "HTTP_USER": "root", "HTTP_PASSWORD": os.environ.get("HTTP_PASSWORD"), } - - host_config_commander = client.api.create_host_config( - extra_hosts={"host.docker.internal": "host-gateway"}, - port_bindings={8081: host_port_commander}, + commander_host_config = client.api.create_host_config(port_bindings={8081: 
8081}) + commander_networking_config = client.api.create_networking_config( + {f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.127")} ) - commander_networking_config = client.api.create_networking_config({ - f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.127") - }) - commander_container = client.api.create_container( image="nebula-rediscommander", - name=f"{os.environ['USER']}_redis_commander", + name=commander_container_name, detach=True, - environment=environment_commander, - host_config=host_config_commander, + environment=commander_environment, + host_config=commander_host_config, networking_config=commander_networking_config, ) - client.api.start(commander_container) - - def stop_database(self): - """ - Stops and removes all NEBULA database Docker containers associated with the current user. - - Responsibilities: - - Detects running Docker containers with names starting with '_nebula-database'. - - Gracefully stops and removes these database containers. - - Typical use cases: - - Cleaning up database containers during shutdown or redeployment processes. 
- """ - DockerUtils.remove_containers_by_prefix(f"{os.environ['USER']}_nebula-database") - DockerUtils.remove_containers_by_prefix(f"{os.environ['USER']}_nebula-pgweb") - DockerUtils.remove_containers_by_prefix(f"{os.environ['USER']}_nebula-redis") - DockerUtils.remove_containers_by_prefix(f"{os.environ['USER']}_nebula-rediscommander") + Deployer._add_container_to_metadata(commander_container_name) def run_controller(self): if sys.platform == "win32": @@ -1224,11 +1161,11 @@ def run_controller(self): "NEBULA_CONTROLLER_PORT": self.controller_port, "NEBULA_CONTROLLER_HOST": self.controller_host, "NEBULA_FRONTEND_PORT": self.frontend_port, - "DB_HOST": f"{os.environ['USER']}_nebula-database", + "DB_HOST": self.get_container_name("nebula-database"), "DB_PORT": 5432, "DB_USER": "nebula", "DB_PASSWORD": "nebula", - "REDIS_HOST": f"{os.environ['USER']}_redis", + "REDIS_HOST": self.get_container_name("redis"), "REDIS_PORT": 6379, } diff --git a/nebula/controller/controller.py b/nebula/controller/controller.py index 6445312d7..67250c7a7 100755 --- a/nebula/controller/controller.py +++ b/nebula/controller/controller.py @@ -73,9 +73,6 @@ def format(self, record): return super().format(record) -os.environ["NEBULA_CONTROLLER_NAME"] = os.environ.get("USER") - - def configure_logger(controller_log): """ Configures the logging system for the controller. @@ -383,6 +380,7 @@ async def stop_scenario( This function does not currently trigger statistics generation. 
""" from nebula.controller.scenarios import ScenarioManagement + from nebula.controller.database import get_user_info, REDIS_POOL ScenarioManagement.cleanup_scenario_containers() try: @@ -390,6 +388,13 @@ async def stop_scenario( await scenario_set_all_status_to_finished() else: await scenario_set_status_to_finished(scenario_name) + + # Invalidate caches + if username: + user = await get_user_info(username) + if user: + await REDIS_POOL.delete(f"scenarios:{user['user']}:{user['role']}") + await REDIS_POOL.delete(f"scenario:{scenario_name}") except Exception as e: logging.exception(f"Error setting scenario {scenario_name} to finished: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -408,16 +413,18 @@ async def remove_scenario( Returns: dict: A message indicating successful removal. """ - from nebula.controller.database import remove_scenario_by_name, get_user_by_scenario_name, REDIS_POOL + from nebula.controller.database import remove_scenario_by_name, get_user_by_scenario_name, get_user_info, REDIS_POOL from nebula.controller.scenarios import ScenarioManagement try: - user = await get_user_by_scenario_name(scenario_name) + username = await get_user_by_scenario_name(scenario_name) await remove_scenario_by_name(scenario_name) ScenarioManagement.remove_files_by_scenario(scenario_name) # Invalidate caches - if user: - await REDIS_POOL.delete(f"scenarios:{user['user']}:{user['role']}") + if username: + user = await get_user_info(username) + if user: + await REDIS_POOL.delete(f"scenarios:{user['user']}:{user['role']}") await REDIS_POOL.delete(f"scenario:{scenario_name}") except Exception as e: @@ -682,7 +689,7 @@ async def update_nodes( raise HTTPException(status_code=500, detail="Internal server error") url = ( - f"http://{os.environ['NEBULA_CONTROLLER_NAME']}_nebula-frontend/platform/dashboard/{scenario_name}/node/update" + f"http://{os.environ['NEBULA_CONTROLLER_HOST']}_nebula-frontend/platform/dashboard/{scenario_name}/node/update" ) 
config["timestamp"] = str(timestamp) @@ -717,7 +724,7 @@ async def node_done( Returns the response from the frontend or raises an HTTPException if it fails. """ - url = f"http://{os.environ['NEBULA_CONTROLLER_NAME']}_nebula-frontend/platform/dashboard/{scenario_name}/node/done" + url = f"http://{os.environ['NEBULA_CONTROLLER_HOST']}_nebula-frontend/platform/dashboard/{scenario_name}/node/done" data = await request.json() From a1da74e6667b52f7137096ffbe782d8ed74e8cb0 Mon Sep 17 00:00:00 2001 From: FerTV Date: Fri, 11 Jul 2025 11:25:43 +0200 Subject: [PATCH 12/20] chore: temporarily disable Redis --- nebula/controller/controller.py | 42 ++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/nebula/controller/controller.py b/nebula/controller/controller.py index 67250c7a7..3ea52d79b 100755 --- a/nebula/controller/controller.py +++ b/nebula/controller/controller.py @@ -393,8 +393,9 @@ async def stop_scenario( if username: user = await get_user_info(username) if user: - await REDIS_POOL.delete(f"scenarios:{user['user']}:{user['role']}") - await REDIS_POOL.delete(f"scenario:{scenario_name}") + # await REDIS_POOL.delete(f"scenarios:{user['user']}:{user['role']}") + pass + # await REDIS_POOL.delete(f"scenario:{scenario_name}") except Exception as e: logging.exception(f"Error setting scenario {scenario_name} to finished: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -424,8 +425,9 @@ async def remove_scenario( if username: user = await get_user_info(username) if user: - await REDIS_POOL.delete(f"scenarios:{user['user']}:{user['role']}") - await REDIS_POOL.delete(f"scenario:{scenario_name}") + # await REDIS_POOL.delete(f"scenarios:{user['user']}:{user['role']}") + pass + # await REDIS_POOL.delete(f"scenario:{scenario_name}") except Exception as e: logging.exception(f"Error removing scenario {scenario_name}: {e}") @@ -453,12 +455,13 @@ async def get_scenarios( try: # Try to get from cache first - 
cached_scenarios = await REDIS_POOL.get(f"scenarios:{user}:{role}") - if cached_scenarios: - scenarios = json.loads(cached_scenarios) - else: - scenarios = await get_all_scenarios_and_check_completed(username=user, role=role) - await REDIS_POOL.set(f"scenarios:{user}:{role}", json.dumps(scenarios), ex=3600) # Cache for 1 hour + # cached_scenarios = await REDIS_POOL.get(f"scenarios:{user}:{role}") + # if cached_scenarios: + # scenarios = json.loads(cached_scenarios) + # else: + # scenarios = await get_all_scenarios_and_check_completed(username=user, role=role) + # await REDIS_POOL.set(f"scenarios:{user}:{role}", json.dumps(scenarios), ex=3600) # Cache for 1 hour + scenarios = await get_all_scenarios_and_check_completed(username=user, role=role) if role == "admin": scenario_running = await get_running_scenario() @@ -502,8 +505,8 @@ async def update_scenario( try: await scenario_update_record(scenario_name, start_time, end_time, scenario, status, username) # Invalidate caches - await REDIS_POOL.delete(f"scenarios:{username}:{role}") - await REDIS_POOL.delete(f"scenario:{scenario_name}") + # await REDIS_POOL.delete(f"scenarios:{username}:{role}") + # await REDIS_POOL.delete(f"scenario:{scenario_name}") except Exception as e: logging.exception(f"Error updating scenario {scenario_name}: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -604,13 +607,14 @@ async def get_scenario_by_name( from nebula.controller.database import get_scenario_by_name, REDIS_POOL try: - cached_scenario = await REDIS_POOL.get(f"scenario:{scenario_name}") - if cached_scenario: - scenario = json.loads(cached_scenario) - else: - scenario = await get_scenario_by_name(scenario_name) - if scenario: - await REDIS_POOL.set(f"scenario:{scenario_name}", json.dumps(scenario), ex=3600) # Cache for 1 hour + # cached_scenario = await REDIS_POOL.get(f"scenario:{scenario_name}") + # if cached_scenario: + # scenario = json.loads(cached_scenario) + # else: + # scenario = await 
get_scenario_by_name(scenario_name) + # if scenario: + # await REDIS_POOL.set(f"scenario:{scenario_name}", json.dumps(scenario), ex=3600) # Cache for 1 hour + scenario = await get_scenario_by_name(scenario_name) except Exception as e: logging.exception(f"Error obtaining scenario {scenario_name}: {e}") raise HTTPException(status_code=500, detail="Internal server error") From 381fb5d806e5dfe5598ef2ee5d643f6fea21b6db Mon Sep 17 00:00:00 2001 From: FerTV Date: Fri, 11 Jul 2025 11:59:58 +0200 Subject: [PATCH 13/20] fix: docker tags --- nebula/controller/controller.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nebula/controller/controller.py b/nebula/controller/controller.py index 3ea52d79b..e1a259466 100755 --- a/nebula/controller/controller.py +++ b/nebula/controller/controller.py @@ -693,7 +693,7 @@ async def update_nodes( raise HTTPException(status_code=500, detail="Internal server error") url = ( - f"http://{os.environ['NEBULA_CONTROLLER_HOST']}_nebula-frontend/platform/dashboard/{scenario_name}/node/update" + f"http://{os.environ['NEBULA_ENV_TAG']}_{os.environ['NEBULA_PREFIX_TAG']}_{os.environ['NEBULA_USER_TAG']}_nebula-frontend/platform/dashboard/{scenario_name}/node/update" ) config["timestamp"] = str(timestamp) @@ -728,7 +728,7 @@ async def node_done( Returns the response from the frontend or raises an HTTPException if it fails. 
""" - url = f"http://{os.environ['NEBULA_CONTROLLER_HOST']}_nebula-frontend/platform/dashboard/{scenario_name}/node/done" + url = f"http://{os.environ['NEBULA_ENV_TAG']}_{os.environ['NEBULA_PREFIX_TAG']}_{os.environ['NEBULA_USER_TAG']}_nebula-frontend/platform/dashboard/{scenario_name}/node/done" data = await request.json() From e6677e981209a483a221872508a4fd782a1403dd Mon Sep 17 00:00:00 2001 From: FerTV Date: Fri, 11 Jul 2025 12:03:59 +0200 Subject: [PATCH 14/20] fix: launching scenarios with different users simultaneously --- nebula/frontend/app.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nebula/frontend/app.py b/nebula/frontend/app.py index 3510aa0f0..84d4b1ca6 100755 --- a/nebula/frontend/app.py +++ b/nebula/frontend/app.py @@ -2153,7 +2153,8 @@ async def assign_available_gpu(scenario_data, role): running_gpus = [] # Obtain associated gpus of the running scenarios for scenario in running_scenarios: - scenario_gpus = json.loads(scenario["gpu_id"]) + config = json.loads(scenario["config"]) + scenario_gpus = config.get("gpu_id", []) # Obtain the list of gpus in use without duplicates for gpu in scenario_gpus: if gpu not in running_gpus: From be8ba6c26a1a80fa943bb2cb86a8ab8a8b334a32 Mon Sep 17 00:00:00 2001 From: FerTV Date: Fri, 11 Jul 2025 17:53:56 +0200 Subject: [PATCH 15/20] fix: delete redis --- app/deployer.py | 47 --------------------------------- nebula/controller/controller.py | 46 +++----------------------------- nebula/controller/database.py | 27 ------------------- pyproject.toml | 1 - 4 files changed, 4 insertions(+), 117 deletions(-) diff --git a/app/deployer.py b/app/deployer.py index 6729a88fa..e37dcb45a 100644 --- a/app/deployer.py +++ b/app/deployer.py @@ -105,7 +105,6 @@ def check_all_credentials(self): self.check_credential("SECRET_KEY", is_password=False) self.check_credential("GF_SECURITY_ADMIN_PASSWORD") self.check_credential("POSTGRES_PASSWORD") - self.check_credential("HTTP_PASSWORD") class 
NebulaEventHandler(PatternMatchingEventHandler): @@ -1076,50 +1075,6 @@ def run_database(self): client.api.start(pgweb_container) Deployer._add_container_to_metadata(pgweb_container_name) - # --- Redis --- - redis_container_name = self.get_container_name("redis") - redis_host_config = client.api.create_host_config( - binds=[f"{self.get_container_name('redis_data')}:/var/lib/redis"], - port_bindings={6379: 6379}, - ) - redis_networking_config = client.api.create_networking_config( - {f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.126")} - ) - - redis_container = client.api.create_container( - image="nebula-redis", - name=redis_container_name, - detach=True, - command="redis-server", - host_config=redis_host_config, - networking_config=redis_networking_config, - ) - client.api.start(redis_container) - Deployer._add_container_to_metadata(redis_container_name) - - # --- Redis Commander --- - commander_container_name = self.get_container_name("redis_commander") - commander_environment = { - "REDIS_HOSTS": f"local:{redis_container_name}:6379", - "HTTP_USER": "root", - "HTTP_PASSWORD": os.environ.get("HTTP_PASSWORD"), - } - commander_host_config = client.api.create_host_config(port_bindings={8081: 8081}) - commander_networking_config = client.api.create_networking_config( - {f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.127")} - ) - - commander_container = client.api.create_container( - image="nebula-rediscommander", - name=commander_container_name, - detach=True, - environment=commander_environment, - host_config=commander_host_config, - networking_config=commander_networking_config, - ) - client.api.start(commander_container) - Deployer._add_container_to_metadata(commander_container_name) - def run_controller(self): if sys.platform == "win32": if not os.path.exists("//./pipe/docker_Engine"): @@ -1165,8 +1120,6 @@ def run_controller(self): "DB_PORT": 5432, "DB_USER": "nebula", "DB_PASSWORD": "nebula", - 
"REDIS_HOST": self.get_container_name("redis"), - "REDIS_PORT": 6379, } volumes = ["/nebula", "/var/run/docker.sock"] diff --git a/nebula/controller/controller.py b/nebula/controller/controller.py index e1a259466..1805fd3f0 100755 --- a/nebula/controller/controller.py +++ b/nebula/controller/controller.py @@ -19,7 +19,6 @@ from nebula.controller.database import ( init_db_pool, close_db_pool, - init_redis_pool, scenario_set_all_status_to_finished, scenario_set_status_to_finished, ) @@ -121,7 +120,6 @@ async def lifespan(app: FastAPI): # Initialize the database connection pool await init_db_pool() - await init_redis_pool() yield @@ -356,7 +354,6 @@ async def run_scenario( @app.post("/scenarios/stop") async def stop_scenario( scenario_name: str = Body(..., embed=True), - username: str = Body(..., embed=True), all: bool = Body(False, embed=True), ): """ @@ -380,7 +377,6 @@ async def stop_scenario( This function does not currently trigger statistics generation. """ from nebula.controller.scenarios import ScenarioManagement - from nebula.controller.database import get_user_info, REDIS_POOL ScenarioManagement.cleanup_scenario_containers() try: @@ -388,14 +384,6 @@ async def stop_scenario( await scenario_set_all_status_to_finished() else: await scenario_set_status_to_finished(scenario_name) - - # Invalidate caches - if username: - user = await get_user_info(username) - if user: - # await REDIS_POOL.delete(f"scenarios:{user['user']}:{user['role']}") - pass - # await REDIS_POOL.delete(f"scenario:{scenario_name}") except Exception as e: logging.exception(f"Error setting scenario {scenario_name} to finished: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -414,20 +402,12 @@ async def remove_scenario( Returns: dict: A message indicating successful removal. 
""" - from nebula.controller.database import remove_scenario_by_name, get_user_by_scenario_name, get_user_info, REDIS_POOL + from nebula.controller.database import remove_scenario_by_name, get_user_by_scenario_name from nebula.controller.scenarios import ScenarioManagement try: - username = await get_user_by_scenario_name(scenario_name) await remove_scenario_by_name(scenario_name) ScenarioManagement.remove_files_by_scenario(scenario_name) - # Invalidate caches - if username: - user = await get_user_info(username) - if user: - # await REDIS_POOL.delete(f"scenarios:{user['user']}:{user['role']}") - pass - # await REDIS_POOL.delete(f"scenario:{scenario_name}") except Exception as e: logging.exception(f"Error removing scenario {scenario_name}: {e}") @@ -451,16 +431,9 @@ async def get_scenarios( Returns: dict: A list of scenarios and the currently running scenario. """ - from nebula.controller.database import get_all_scenarios_and_check_completed, get_running_scenario, REDIS_POOL + from nebula.controller.database import get_all_scenarios_and_check_completed, get_running_scenario try: - # Try to get from cache first - # cached_scenarios = await REDIS_POOL.get(f"scenarios:{user}:{role}") - # if cached_scenarios: - # scenarios = json.loads(cached_scenarios) - # else: - # scenarios = await get_all_scenarios_and_check_completed(username=user, role=role) - # await REDIS_POOL.set(f"scenarios:{user}:{role}", json.dumps(scenarios), ex=3600) # Cache for 1 hour scenarios = await get_all_scenarios_and_check_completed(username=user, role=role) if role == "admin": @@ -499,14 +472,10 @@ async def update_scenario( Returns: dict: A message confirming the update. 
""" - from nebula.controller.database import scenario_update_record, REDIS_POOL - # from nebula.controller.scenarios import Scenario + from nebula.controller.database import scenario_update_record try: await scenario_update_record(scenario_name, start_time, end_time, scenario, status, username) - # Invalidate caches - # await REDIS_POOL.delete(f"scenarios:{username}:{role}") - # await REDIS_POOL.delete(f"scenario:{scenario_name}") except Exception as e: logging.exception(f"Error updating scenario {scenario_name}: {e}") raise HTTPException(status_code=500, detail="Internal server error") @@ -604,16 +573,9 @@ async def get_scenario_by_name( Returns: dict: The scenario data. """ - from nebula.controller.database import get_scenario_by_name, REDIS_POOL + from nebula.controller.database import get_scenario_by_name try: - # cached_scenario = await REDIS_POOL.get(f"scenario:{scenario_name}") - # if cached_scenario: - # scenario = json.loads(cached_scenario) - # else: - # scenario = await get_scenario_by_name(scenario_name) - # if scenario: - # await REDIS_POOL.set(f"scenario:{scenario_name}", json.dumps(scenario), ex=3600) # Cache for 1 hour scenario = await get_scenario_by_name(scenario_name) except Exception as e: logging.exception(f"Error obtaining scenario {scenario_name}: {e}") diff --git a/nebula/controller/database.py b/nebula/controller/database.py index a07205388..7392c5967 100755 --- a/nebula/controller/database.py +++ b/nebula/controller/database.py @@ -4,15 +4,12 @@ import json import asyncpg import asyncio -import redis.asyncio as aioredis from passlib.context import CryptContext -from nebula.controller.scenarios import Scenario # --- Configuration --- # Use environment variables for database credentials from the Docker Compose file DATABASE_URL = f"postgresql://{os.environ.get('DB_USER')}:{os.environ.get('DB_PASSWORD')}@{os.environ.get('DB_HOST')}:{os.environ.get('DB_PORT')}/nebula" -REDIS_URL = f"redis://{os.environ.get('REDIS_HOST', 
'localhost')}:{os.environ.get('REDIS_PORT', 6379)}" # Password hashing context (using Argon2) pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") @@ -23,29 +20,6 @@ # --- Connection Pool Management --- # Global pool variable, should be initialized at application startup POOL = None -REDIS_POOL = None - -async def init_redis_pool(): - """ - Initializes the asynchronous Redis connection pool. - """ - global REDIS_POOL - if REDIS_POOL is None: - try: - REDIS_POOL = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True) - logging.info("Redis connection pool successfully created.") - except Exception as e: - logging.critical(f"Failed to create Redis connection pool: {e}", exc_info=True) - raise - -async def close_redis_pool(): - """ - Closes the asynchronous Redis connection pool. - """ - global REDIS_POOL - if REDIS_POOL: - await REDIS_POOL.close() - logging.info("Redis connection pool closed.") async def init_db_pool(): """ @@ -75,7 +49,6 @@ async def close_db_pool(): if POOL: await POOL.close() logging.info("Database connection pool closed.") - await close_redis_pool() # --- User Management Functions --- diff --git a/pyproject.toml b/pyproject.toml index bcd0d1920..7c7e666b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,7 +82,6 @@ controller = [ "seaborn==0.13.2", "scikit-image==0.24.0", "scikit-learn==1.5.1", - "redis==5.0.7", ] database = [ "asyncpg==0.30.0", From 1bfb35b40ff72e4552ba4268d4dca741548f0a06 Mon Sep 17 00:00:00 2001 From: FerTV Date: Fri, 11 Jul 2025 17:55:52 +0200 Subject: [PATCH 16/20] fix: delete unused dockerfiles --- nebula/database/redis/Dockerfile | 1 - nebula/database/rediscommander/Dockerfile | 1 - 2 files changed, 2 deletions(-) delete mode 100644 nebula/database/redis/Dockerfile delete mode 100644 nebula/database/rediscommander/Dockerfile diff --git a/nebula/database/redis/Dockerfile b/nebula/database/redis/Dockerfile deleted file mode 100644 index 31ff6af02..000000000 --- 
a/nebula/database/redis/Dockerfile +++ /dev/null @@ -1 +0,0 @@ -FROM redis:latest \ No newline at end of file diff --git a/nebula/database/rediscommander/Dockerfile b/nebula/database/rediscommander/Dockerfile deleted file mode 100644 index 0eb1e1ead..000000000 --- a/nebula/database/rediscommander/Dockerfile +++ /dev/null @@ -1 +0,0 @@ -FROM rediscommander/redis-commander:latest \ No newline at end of file From dc52d667ea07d781e80e467a5bd601abaee01c39 Mon Sep 17 00:00:00 2001 From: FerTV Date: Fri, 11 Jul 2025 18:10:13 +0200 Subject: [PATCH 17/20] feature: dockerfile for pgweb added --- Makefile | 1 + nebula/database/pgweb/Dockerfile | 1 + 2 files changed, 2 insertions(+) create mode 100644 nebula/database/pgweb/Dockerfile diff --git a/Makefile b/Makefile index 01eb836b0..c1e0cd768 100644 --- a/Makefile +++ b/Makefile @@ -77,6 +77,7 @@ update-dockers: ## Update docker images @echo "🐳 Building nebula-database docker image. Do you want to continue (overrides existing image)? (y/n)" @read ans; if [ "$${ans:-N}" = y ]; then \ docker build -t nebula-database -f nebula/database/Dockerfile .; \ + docker build -t nebula-pgweb -f nebula/database/pgweb/Dockerfile .; \ else \ echo "Skipping nebula-database docker build."; \ fi diff --git a/nebula/database/pgweb/Dockerfile b/nebula/database/pgweb/Dockerfile new file mode 100644 index 000000000..c9bb132e0 --- /dev/null +++ b/nebula/database/pgweb/Dockerfile @@ -0,0 +1 @@ +FROM sosedoff/pgweb From 4ac0d7c5d4f9c9e411a5fdde38291fe4b0198231 Mon Sep 17 00:00:00 2001 From: FerTV Date: Wed, 16 Jul 2025 11:43:14 +0200 Subject: [PATCH 18/20] fix: static admin password --- app/deployer.py | 4 +++- nebula/controller/controller.py | 2 ++ nebula/controller/database.py | 21 +++++++++++++++++++++ nebula/database/init-configs.sql | 6 ------ 4 files changed, 26 insertions(+), 7 deletions(-) diff --git a/app/deployer.py b/app/deployer.py index e37dcb45a..95bffda7e 100644 --- a/app/deployer.py +++ b/app/deployer.py @@ -105,6 +105,7 @@ def
check_all_credentials(self): self.check_credential("SECRET_KEY", is_password=False) self.check_credential("GF_SECURITY_ADMIN_PASSWORD") self.check_credential("POSTGRES_PASSWORD") + self.check_credential("NEBULA_ADMIN_PASSWORD") class NebulaEventHandler(PatternMatchingEventHandler): @@ -1119,7 +1120,8 @@ def run_controller(self): "DB_HOST": self.get_container_name("nebula-database"), "DB_PORT": 5432, "DB_USER": "nebula", - "DB_PASSWORD": "nebula", + "DB_PASSWORD": os.environ.get("POSTGRES_PASSWORD"), + "NEBULA_ADMIN_PASSWORD": os.environ.get("NEBULA_ADMIN_PASSWORD") } volumes = ["/nebula", "/var/run/docker.sock"] diff --git a/nebula/controller/controller.py b/nebula/controller/controller.py index 1805fd3f0..a134fb034 100755 --- a/nebula/controller/controller.py +++ b/nebula/controller/controller.py @@ -19,6 +19,7 @@ from nebula.controller.database import ( init_db_pool, close_db_pool, + insert_default_admin, scenario_set_all_status_to_finished, scenario_set_status_to_finished, ) @@ -120,6 +121,7 @@ async def lifespan(app: FastAPI): # Initialize the database connection pool await init_db_pool() + await insert_default_admin() yield diff --git a/nebula/controller/database.py b/nebula/controller/database.py index 7392c5967..8fb1b7cc9 100755 --- a/nebula/controller/database.py +++ b/nebula/controller/database.py @@ -53,6 +53,27 @@ async def close_db_pool(): # --- User Management Functions --- +async def insert_default_admin(): + """ + Inserts a default 'ADMIN' user into the database with a hashed password. + The password must be provided via the NEBULA_ADMIN_PASSWORD environment variable.
+ """ + admin_password = os.environ.get("NEBULA_ADMIN_PASSWORD") + + hashed_password = pwd_context.hash(admin_password) + + query = """ + INSERT INTO users ("user", password, role) + VALUES ($1, $2, $3) + ON CONFLICT ("user") DO NOTHING; + """ + try: + async with POOL.acquire() as conn: + await conn.execute(query, "ADMIN", hashed_password, "admin") + logging.info("Default admin user inserted (or already exists).") + except Exception as e: + logging.error(f"Failed to insert default admin user: {e}", exc_info=True) + async def list_users(all_info=False): """ Retrieves a list of users from the users database. diff --git a/nebula/database/init-configs.sql b/nebula/database/init-configs.sql index 0d6d3631d..a34b17841 100644 --- a/nebula/database/init-configs.sql +++ b/nebula/database/init-configs.sql @@ -61,9 +61,3 @@ CREATE TABLE IF NOT EXISTS notes ( scenario TEXT PRIMARY KEY, scenario_notes TEXT ); - --- 6) Insert the default 'admin' user with a hashed password --- The hash must be generated by a Python script using passlib. --- Replace the placeholder with your generated hash. 
-INSERT INTO users ("user", password, role) VALUES ('ADMIN', '$argon2id$v=19$m=65536,t=3,p=4$OobPh8BkZeT6D5s+Rt11mQ$JjI2M3U5+4lupdr87/GrIn46ImzoQujNEyVd7IGYiXY', 'admin') -ON CONFLICT ("user") DO NOTHING; From d3cfdc5c796f3206b082487f9818390cf90af9c4 Mon Sep 17 00:00:00 2001 From: FerTV Date: Fri, 18 Jul 2025 12:16:35 +0200 Subject: [PATCH 19/20] chore: remove example --- nebula/controller/database.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/nebula/controller/database.py b/nebula/controller/database.py index 8fb1b7cc9..f7d04565d 100755 --- a/nebula/controller/database.py +++ b/nebula/controller/database.py @@ -632,30 +632,3 @@ async def remove_note(scenario): """ async with POOL.acquire() as conn: await conn.execute("DELETE FROM notes WHERE scenario = $1;", scenario) - - -if __name__ == "__main__": - # This is an example of how to use the new pool-based functions. - # In a real application, init_db_pool() would be called at startup, - # and close_db_pool() at shutdown. 
- - async def main(): - # Set environment variables for local testing if not already set - os.environ.setdefault('DB_USER', 'your_user') - os.environ.setdefault('DB_PASSWORD', 'your_password') - os.environ.setdefault('DB_HOST', 'localhost') - os.environ.setdefault('DB_PORT', '5432') - - logging.basicConfig(level=logging.INFO) - - await init_db_pool() - try: - # Example: list all users - users = await list_users() - logging.info(f"Found users: {users}") - finally: - await close_db_pool() - - # To run this example: - # asyncio.run(main()) - pass From 397aabd241d5e31d3e75f7ebe787a573d530a562 Mon Sep 17 00:00:00 2001 From: FerTV Date: Tue, 22 Jul 2025 11:55:08 +0200 Subject: [PATCH 20/20] fix: remove scenario and relaunch with user role fixed --- nebula/controller/controller.py | 5 +++-- nebula/controller/database.py | 1 + nebula/frontend/app.py | 8 ++++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/nebula/controller/controller.py b/nebula/controller/controller.py index a134fb034..7e59ae7ec 100755 --- a/nebula/controller/controller.py +++ b/nebula/controller/controller.py @@ -533,8 +533,9 @@ async def get_running_scenario(get_all: bool = False): raise HTTPException(status_code=500, detail="Internal server error") -@app.get("/scenarios/check/{role}/{scenario_name}") +@app.get("/scenarios/check/{user}/{role}/{scenario_name}") async def check_scenario( + user: Annotated[str, Path(regex="^[a-zA-Z0-9_-]+$", min_length=1, max_length=50, description="Valid username")], role: Annotated[str, Path(regex="^[a-zA-Z0-9_-]+$", min_length=1, max_length=50, description="Valid role")], scenario_name: Annotated[ str, Path(regex="^[a-zA-Z0-9_-]+$", min_length=1, max_length=50, description="Valid scenario name") @@ -553,7 +554,7 @@ async def check_scenario( from nebula.controller.database import check_scenario_with_role try: - allowed = await check_scenario_with_role(role, scenario_name) + allowed = await check_scenario_with_role(role, scenario_name, user) return 
{"allowed": allowed} except Exception as e: logging.exception(f"Error checking scenario with role: {e}") diff --git a/nebula/controller/database.py b/nebula/controller/database.py index f7d04565d..407ce3908 100755 --- a/nebula/controller/database.py +++ b/nebula/controller/database.py @@ -592,6 +592,7 @@ async def check_scenario_with_role(role, scenario_name, current_username=None): return True # Admins can access any existing scenario if current_username is None: + logging.info(f"[FER] db username {scenario_info.get('username')} current_username {current_username}") logging.warning( "check_scenario_with_role called for non-admin role without current_username." ) diff --git a/nebula/frontend/app.py b/nebula/frontend/app.py index 84d4b1ca6..42f5b3a68 100755 --- a/nebula/frontend/app.py +++ b/nebula/frontend/app.py @@ -687,7 +687,7 @@ async def remove_scenario_by_name(scenario_name): await controller_post(url, data) -async def check_scenario_with_role(role, scenario_name): +async def check_scenario_with_role(role, scenario_name, user): """ Check if a specific scenario is allowed for the session's role. @@ -701,7 +701,7 @@ async def check_scenario_with_role(role, scenario_name): Raises: HTTPException: If the underlying HTTP GET request fails. 
""" - url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/check/{role}/{scenario_name}" + url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/check/{user}/{role}/{scenario_name}" check_data = await controller_get(url) return check_data.get("allowed", False) @@ -1883,7 +1883,7 @@ async def nebula_relaunch_scenario( if session["role"] == "demo": raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) elif session["role"] == "user": - if not await check_scenario_with_role(session["role"], scenario_name): + if not await check_scenario_with_role(session["role"], scenario_name, session["user"]): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) scenario_path = FileUtils.check_path(settings.config_dir, os.path.join(scenario_name, "scenario.json")) @@ -1924,7 +1924,7 @@ async def nebula_remove_scenario(scenario_name: str, session: dict = Depends(get if session["role"] == "demo": raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) elif session["role"] == "user": - if not await check_scenario_with_role(session["role"], scenario_name): + if not await check_scenario_with_role(session["role"], scenario_name, session["user"]): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) await remove_scenario(scenario_name, session["user"]) return RedirectResponse(url="/platform/dashboard")