Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
484 changes: 368 additions & 116 deletions app/deployer.py

Large diffs are not rendered by default.

25 changes: 10 additions & 15 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,38 +64,36 @@
help="Stop NEBULA platform or nodes only (use '--stop nodes' to stop only the nodes)",
)

argparser.add_argument("-s", "--simulation", action="store_false", dest="simulation", help="Run simulation")

argparser.add_argument(
"-c",
"--config",
dest="config",
default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "config"),
help="Config directory path",
help="NEBULA config directory path",
)

argparser.add_argument(
"-d",
"--database",
dest="databases",
default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "databases"),
help="Nebula databases path",
help="NEBULA databases directory path",
)

argparser.add_argument(
"-l",
"--logs",
dest="logs",
default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs"),
help="Logs directory path",
help="NEBULA logs directory path",
)

argparser.add_argument(
"-ce",
"--certs",
dest="certs",
default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "certs"),
help="Certs directory path",
help="NEBULA certs directory path",
)

argparser.add_argument(
Expand All @@ -106,24 +104,21 @@
help=".env file path",
)

argparser.add_argument("-dev", "--developement", dest="developement", default=True, help="Nebula for devs")

argparser.add_argument(
"-p",
"--production",
dest="production",
action="store_true",
default=False,
help="Production mode",
help="Deploy NEBULA in production mode",
)

argparser.add_argument(
"-ad",
"--advanced",
dest="advanced_analytics",
action="store_true",
default=False,
help="Advanced analytics",
"-pr",
"--prefix",
dest="prefix",
default="dev",
help="Deploy NEBULA components with a prefix",
)

argparser.add_argument(
Expand Down
78 changes: 37 additions & 41 deletions nebula/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,24 +264,24 @@ async def get_available_gpu():

def validate_physical_fields(data: dict):
    """
    Validate the physical-deployment fields of a scenario payload.

    Nothing is checked unless ``data["deployment"]`` equals ``"physical"``.
    For a physical deployment, ``physical_ips`` must be present and non-empty,
    must contain exactly ``n_nodes`` entries, and every entry must be accepted
    by ``ipaddress.ip_address``.

    Args:
        data (dict): Scenario payload. The keys read are "deployment",
            "physical_ips" and "n_nodes".

    Raises:
        HTTPException: 400 if "physical_ips" is missing or empty, if its length
            differs from "n_nodes", or if any entry is not a valid IP address.
    """
    if data.get("deployment") != "physical":
        return

    ips = data.get("physical_ips")
    if not ips:
        raise HTTPException(
            status_code=400,
            detail="physical deployment requires 'physical_ips'",
        )

    if len(ips) != data.get("n_nodes"):
        raise HTTPException(
            status_code=400,
            detail="'physical_ips' must have the same length as 'n_nodes'",
        )

    try:
        for ip in ips:
            # Only the parse matters here; the returned address object is discarded.
            ipaddress.ip_address(ip)
    except ValueError as e:
        # Chain the original parsing error so it stays visible in tracebacks.
        raise HTTPException(status_code=400, detail=str(e)) from e
Expand Down Expand Up @@ -347,31 +347,27 @@ async def stop_scenario(
):
"""
Stops the execution of a federated learning scenario and performs cleanup operations.

This endpoint:
- Stops all participant containers associated with the specified scenario.
- Removes Docker containers and network resources tied to the scenario and user.
- Sets the scenario's status to "finished" in the database.
- Optionally finalizes all active scenarios if the 'all' flag is set.

Args:
scenario_name (str): Name of the scenario to stop.
username (str): User who initiated the stop operation.
all (bool): Whether to stop all running scenarios instead of just one (default: False).

Raises:
HTTPException: Returns a 500 status code if any step fails.

Note:
This function does not currently trigger statistics generation.
"""
from nebula.controller.scenarios import ScenarioManagement

# ScenarioManagement.stop_participants(scenario_name)
DockerUtils.remove_containers_by_prefix(f"{os.environ.get('NEBULA_CONTROLLER_NAME')}_{username}-participant")
DockerUtils.remove_docker_network(
f"{(os.environ.get('NEBULA_CONTROLLER_NAME'))}_{str(username).lower()}-nebula-net-scenario"
)
ScenarioManagement.cleanup_scenario_containers()
try:
if all:
scenario_set_all_status_to_finished()
Expand Down Expand Up @@ -847,27 +843,27 @@ async def discover_vpn():
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)

# 2) Wait for it to finish and capture stdout/stderr
out, err = await proc.communicate()
if proc.returncode != 0:
# If the CLI returned an error, raise to be caught below
raise RuntimeError(err.decode())

# 3) Parse the JSON output
data = json.loads(out.decode())

# 4) Collect only the IPv4 addresses from each peer
ips = []
for peer in data.get("Peer", {}).values():
for ip in peer.get("TailscaleIPs", []):
if ":" not in ip:
if ":" not in ip:
# Skip IPv6 entries (they contain colons)
ips.append(ip)

# 5) Return the list of IPv4s
return {"ips": ips}

except Exception as e:
# 6) Log any failure and respond with HTTP 500
logging.error(f"Error discovering VPN devices: {e}")
Expand All @@ -877,14 +873,14 @@ async def discover_vpn():
@app.get("/physical/run/{ip}", tags=["physical"])
async def physical_run(ip: str):
    """
    Relay the result of ``remote_get(ip, "/run/")`` to the caller.

    Args:
        ip (str): Path parameter passed through to ``remote_get`` unchanged.

    Returns:
        The data returned by ``remote_get`` when the reported status is 200.

    Raises:
        HTTPException: 502 when ``remote_get`` reports no status (``None``);
            otherwise the reported non-200 status, with the returned data
            as the detail.
    """
    code, payload = await remote_get(ip, "/run/")

    # No status at all: report it as a bad-gateway error.
    if code is None:
        raise HTTPException(status_code=502, detail=f"Node unreachable: {payload}")

    # Any other non-200 status is propagated as-is.
    if code != 200:
        raise HTTPException(status_code=code, detail=payload)

    return payload


@app.get("/physical/stop/{ip}", tags=["physical"])
async def physical_stop(ip: str):
status, data = await remote_get(ip, "/stop/")
Expand All @@ -893,8 +889,8 @@ async def physical_stop(ip: str):
if status is None:
raise HTTPException(status_code=502, detail=f"Node unreachable: {data}")
raise HTTPException(status_code=status, detail=data)


@app.put("/physical/setup/{ip}", tags=["physical"],
status_code=status.HTTP_201_CREATED)
async def physical_setup(
Expand All @@ -903,7 +899,7 @@ async def physical_setup(
global_test: UploadFile = File(..., description="Global Dataset*.h5*"),
train_set: UploadFile = File(..., description="Training dataset*.h5*"),
):

form = aiohttp.FormData()
await config.seek(0)
form.add_field("config", config.file,
Expand All @@ -914,40 +910,40 @@ async def physical_setup(
await train_set.seek(0)
form.add_field("train_set", train_set.file,
filename=train_set.filename, content_type="application/octet-stream")

status_code, data = await remote_post_form(
ip, "/setup/", form, method="PUT"
)

if status_code == 201:
return data
if status_code is None:
raise HTTPException(status_code=502, detail=f"Node unreachable: {data}")
raise HTTPException(status_code=status_code, detail=data)

# ──────────────────────────────────────────────────────────────
# Physical · single-node state
# ──────────────────────────────────────────────────────────────
@app.get("/physical/state/{ip}", tags=["physical"])
async def get_physical_node_state(ip: str):
"""
Query a single Raspberry Pi (or other node) for its training state.

Parameters
----------
ip : str
IP address or hostname of the node.

Returns
-------
dict
• running (bool) – True if a training process is active.
• running (bool) – True if a training process is active.
• error (str) – Optional error message when the node is unreachable
or returns a non-200 HTTP status.
"""
# Short global timeout so a dead node doesn't block the whole request
timeout = aiohttp.ClientTimeout(total=3) # seconds

try:
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(f"http://{ip}/state/") as resp:
Expand All @@ -960,21 +956,21 @@ async def get_physical_node_state(ip: str):
except Exception as exc:
# Network errors, timeouts, DNS failures, …
return {"running": False, "error": str(exc)}


# ──────────────────────────────────────────────────────────────
# Physical · aggregate state for an entire scenario
# ──────────────────────────────────────────────────────────────
@app.get("/physical/scenario-state/{scenario_name}", tags=["physical"])
async def get_physical_scenario_state(scenario_name: str):
"""
Check the training state of *every* physical node assigned to a scenario.

Parameters
----------
scenario_name : str
Scenario identifier.

Returns
-------
dict
Expand All @@ -989,16 +985,16 @@ async def get_physical_scenario_state(scenario_name: str):
scenario = await get_scenario_by_name(scenario_name)
if not scenario:
raise HTTPException(status_code=404, detail="Scenario not found")

nodes = await list_nodes_by_scenario_name(scenario_name)
if not nodes:
raise HTTPException(status_code=404, detail="No nodes found for scenario")

# 2) Probe all nodes concurrently
ips = [n["ip"] for n in nodes]
tasks = [get_physical_node_state(ip) for ip in ips]
states = await asyncio.gather(*tasks) # parallel HTTP calls

# 3) Aggregate results
nodes_state = dict(zip(ips, states))
any_running = any(s.get("running") for s in states)
Expand All @@ -1007,7 +1003,7 @@ async def get_physical_scenario_state(scenario_name: str):
all_available = all(
(not s.get("running")) and (not s.get("error")) for s in states
)

return {
"running": any_running,
"nodes_state": nodes_state,
Expand Down
Loading