Skip to content

Commit 449cd12

Browse files
[Feature] Include metadata files and tags for deployment/shutdown scenarios (#50)
* chore: remove unused components and variables, improve dev mode * chore: new metadata logic and improve deployment and shutdown of Docker containers * chore: check whether Docker containers exist with the same name * chore: save prefix in env file * chore: save prefix in env file * chore: reorganize variables in the init * chore: reorganize variables in the init * chore: fix the check of nebula production * chore: upgrade tag system for deployment * chore: update logging exceptions * chore: fix some issues with Docker networks
1 parent 21991bb commit 449cd12

14 files changed

Lines changed: 620 additions & 612 deletions

File tree

app/deployer.py

Lines changed: 368 additions & 116 deletions
Large diffs are not rendered by default.

app/main.py

Lines changed: 10 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -64,38 +64,36 @@
6464
help="Stop NEBULA platform or nodes only (use '--stop nodes' to stop only the nodes)",
6565
)
6666

67-
argparser.add_argument("-s", "--simulation", action="store_false", dest="simulation", help="Run simulation")
68-
6967
argparser.add_argument(
7068
"-c",
7169
"--config",
7270
dest="config",
7371
default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "config"),
74-
help="Config directory path",
72+
help="NEBULA config directory path",
7573
)
7674

7775
argparser.add_argument(
7876
"-d",
7977
"--database",
8078
dest="databases",
8179
default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "databases"),
82-
help="Nebula databases path",
80+
help="NEBULA databases directory path",
8381
)
8482

8583
argparser.add_argument(
8684
"-l",
8785
"--logs",
8886
dest="logs",
8987
default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs"),
90-
help="Logs directory path",
88+
help="NEBULA logs directory path",
9189
)
9290

9391
argparser.add_argument(
9492
"-ce",
9593
"--certs",
9694
dest="certs",
9795
default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "certs"),
98-
help="Certs directory path",
96+
help="NEBULA certs directory path",
9997
)
10098

10199
argparser.add_argument(
@@ -106,24 +104,21 @@
106104
help=".env file path",
107105
)
108106

109-
argparser.add_argument("-dev", "--developement", dest="developement", default=True, help="Nebula for devs")
110-
111107
argparser.add_argument(
112108
"-p",
113109
"--production",
114110
dest="production",
115111
action="store_true",
116112
default=False,
117-
help="Production mode",
113+
help="Deploy NEBULA in production mode",
118114
)
119115

120116
argparser.add_argument(
121-
"-ad",
122-
"--advanced",
123-
dest="advanced_analytics",
124-
action="store_true",
125-
default=False,
126-
help="Advanced analytics",
117+
"-pr",
118+
"--prefix",
119+
dest="prefix",
120+
default="dev",
121+
help="Deploy NEBULA components with a prefix",
127122
)
128123

129124
argparser.add_argument(

nebula/controller/controller.py

Lines changed: 37 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -264,24 +264,24 @@ async def get_available_gpu():
264264

265265
def validate_physical_fields(data: dict):
266266
if data.get("deployment") != "physical":
267-
return
268-
267+
return
268+
269269
ips = data.get("physical_ips")
270270
if not ips:
271271
raise HTTPException(
272272
status_code=400,
273273
detail="physical deployment requires 'physical_ips'"
274274
)
275-
275+
276276
if len(ips) != data.get("n_nodes"):
277277
raise HTTPException(
278278
status_code=400,
279279
detail="'physical_ips' must have the same length as 'n_nodes'"
280280
)
281-
281+
282282
try:
283283
for ip in ips:
284-
ipaddress.ip_address(ip)
284+
ipaddress.ip_address(ip)
285285
print(ip)
286286
except ValueError as e:
287287
raise HTTPException(status_code=400, detail=str(e))
@@ -347,31 +347,27 @@ async def stop_scenario(
347347
):
348348
"""
349349
Stops the execution of a federated learning scenario and performs cleanup operations.
350-
350+
351351
This endpoint:
352352
- Stops all participant containers associated with the specified scenario.
353353
- Removes Docker containers and network resources tied to the scenario and user.
354354
- Sets the scenario's status to "finished" in the database.
355355
- Optionally finalizes all active scenarios if the 'all' flag is set.
356-
356+
357357
Args:
358358
scenario_name (str): Name of the scenario to stop.
359359
username (str): User who initiated the stop operation.
360360
all (bool): Whether to stop all running scenarios instead of just one (default: False).
361-
361+
362362
Raises:
363363
HTTPException: Returns a 500 status code if any step fails.
364-
364+
365365
Note:
366366
This function does not currently trigger statistics generation.
367367
"""
368368
from nebula.controller.scenarios import ScenarioManagement
369369

370-
# ScenarioManagement.stop_participants(scenario_name)
371-
DockerUtils.remove_containers_by_prefix(f"{os.environ.get('NEBULA_CONTROLLER_NAME')}_{username}-participant")
372-
DockerUtils.remove_docker_network(
373-
f"{(os.environ.get('NEBULA_CONTROLLER_NAME'))}_{str(username).lower()}-nebula-net-scenario"
374-
)
370+
ScenarioManagement.cleanup_scenario_containers()
375371
try:
376372
if all:
377373
scenario_set_all_status_to_finished()
@@ -847,27 +843,27 @@ async def discover_vpn():
847843
stdout=asyncio.subprocess.PIPE,
848844
stderr=asyncio.subprocess.PIPE,
849845
)
850-
846+
851847
# 2) Wait for it to finish and capture stdout/stderr
852848
out, err = await proc.communicate()
853849
if proc.returncode != 0:
854850
# If the CLI returned an error, raise to be caught below
855851
raise RuntimeError(err.decode())
856-
852+
857853
# 3) Parse the JSON output
858854
data = json.loads(out.decode())
859-
855+
860856
# 4) Collect only the IPv4 addresses from each peer
861857
ips = []
862858
for peer in data.get("Peer", {}).values():
863859
for ip in peer.get("TailscaleIPs", []):
864-
if ":" not in ip:
860+
if ":" not in ip:
865861
# Skip IPv6 entries (they contain colons)
866862
ips.append(ip)
867-
863+
868864
# 5) Return the list of IPv4s
869865
return {"ips": ips}
870-
866+
871867
except Exception as e:
872868
# 6) Log any failure and respond with HTTP 500
873869
logging.error(f"Error discovering VPN devices: {e}")
@@ -877,14 +873,14 @@ async def discover_vpn():
877873
@app.get("/physical/run/{ip}", tags=["physical"])
878874
async def physical_run(ip: str):
879875
status, data = await remote_get(ip, "/run/")
880-
876+
881877
if status == 200:
882878
return data
883879
if status is None:
884880
raise HTTPException(status_code=502, detail=f"Node unreachable: {data}")
885881
raise HTTPException(status_code=status, detail=data)
886-
887-
882+
883+
888884
@app.get("/physical/stop/{ip}", tags=["physical"])
889885
async def physical_stop(ip: str):
890886
status, data = await remote_get(ip, "/stop/")
@@ -893,8 +889,8 @@ async def physical_stop(ip: str):
893889
if status is None:
894890
raise HTTPException(status_code=502, detail=f"Node unreachable: {data}")
895891
raise HTTPException(status_code=status, detail=data)
896-
897-
892+
893+
898894
@app.put("/physical/setup/{ip}", tags=["physical"],
899895
status_code=status.HTTP_201_CREATED)
900896
async def physical_setup(
@@ -903,7 +899,7 @@ async def physical_setup(
903899
global_test: UploadFile = File(..., description="Global Dataset*.h5*"),
904900
train_set: UploadFile = File(..., description="Training dataset*.h5*"),
905901
):
906-
902+
907903
form = aiohttp.FormData()
908904
await config.seek(0)
909905
form.add_field("config", config.file,
@@ -914,40 +910,40 @@ async def physical_setup(
914910
await train_set.seek(0)
915911
form.add_field("train_set", train_set.file,
916912
filename=train_set.filename, content_type="application/octet-stream")
917-
913+
918914
status_code, data = await remote_post_form(
919915
ip, "/setup/", form, method="PUT"
920916
)
921-
917+
922918
if status_code == 201:
923919
return data
924920
if status_code is None:
925921
raise HTTPException(status_code=502, detail=f"Node unreachable: {data}")
926922
raise HTTPException(status_code=status_code, detail=data)
927-
923+
928924
# ──────────────────────────────────────────────────────────────
929925
# Physical · single-node state
930926
# ──────────────────────────────────────────────────────────────
931927
@app.get("/physical/state/{ip}", tags=["physical"])
932928
async def get_physical_node_state(ip: str):
933929
"""
934930
Query a single Raspberry Pi (or other node) for its training state.
935-
931+
936932
Parameters
937933
----------
938934
ip : str
939935
IP address or hostname of the node.
940-
936+
941937
Returns
942938
-------
943939
dict
944-
• running (bool) – True if a training process is active.
940+
• running (bool) – True if a training process is active.
945941
• error (str) – Optional error message when the node is unreachable
946942
or returns a non-200 HTTP status.
947943
"""
948944
# Short global timeout so a dead node doesn't block the whole request
949945
timeout = aiohttp.ClientTimeout(total=3) # seconds
950-
946+
951947
try:
952948
async with aiohttp.ClientSession(timeout=timeout) as session:
953949
async with session.get(f"http://{ip}/state/") as resp:
@@ -960,21 +956,21 @@ async def get_physical_node_state(ip: str):
960956
except Exception as exc:
961957
# Network errors, timeouts, DNS failures, …
962958
return {"running": False, "error": str(exc)}
963-
964-
959+
960+
965961
# ──────────────────────────────────────────────────────────────
966962
# Physical · aggregate state for an entire scenario
967963
# ──────────────────────────────────────────────────────────────
968964
@app.get("/physical/scenario-state/{scenario_name}", tags=["physical"])
969965
async def get_physical_scenario_state(scenario_name: str):
970966
"""
971967
Check the training state of *every* physical node assigned to a scenario.
972-
968+
973969
Parameters
974970
----------
975971
scenario_name : str
976972
Scenario identifier.
977-
973+
978974
Returns
979975
-------
980976
dict
@@ -989,16 +985,16 @@ async def get_physical_scenario_state(scenario_name: str):
989985
scenario = await get_scenario_by_name(scenario_name)
990986
if not scenario:
991987
raise HTTPException(status_code=404, detail="Scenario not found")
992-
988+
993989
nodes = await list_nodes_by_scenario_name(scenario_name)
994990
if not nodes:
995991
raise HTTPException(status_code=404, detail="No nodes found for scenario")
996-
992+
997993
# 2) Probe all nodes concurrently
998994
ips = [n["ip"] for n in nodes]
999995
tasks = [get_physical_node_state(ip) for ip in ips]
1000996
states = await asyncio.gather(*tasks) # parallel HTTP calls
1001-
997+
1002998
# 3) Aggregate results
1003999
nodes_state = dict(zip(ips, states))
10041000
any_running = any(s.get("running") for s in states)
@@ -1007,7 +1003,7 @@ async def get_physical_scenario_state(scenario_name: str):
10071003
all_available = all(
10081004
(not s.get("running")) and (not s.get("error")) for s in states
10091005
)
1010-
1006+
10111007
return {
10121008
"running": any_running,
10131009
"nodes_state": nodes_state,

0 commit comments

Comments
 (0)