Skip to content

Commit 5db6f31

Browse files
committed
Update DGXCloudExecutor for improved workload management
- Renamed app_id and app_secret to client_id and client_secret for clarity.
- Introduced new methods for deleting workloads and checking workspace status.
- Enhanced data movement functionality to use a tarball when within character limits, falling back to individual file deployment otherwise.
- Updated RayCluster and RayJob to integrate DGXCloudExecutor and its corresponding classes.

Fixes #478

Signed-off-by: Rakesh Paul <rapaul@nvidia.com>
1 parent 91c8c2a commit 5db6f31

File tree

4 files changed

+572
-82
lines changed

4 files changed

+572
-82
lines changed

nemo_run/core/execution/dgxcloud.py

Lines changed: 204 additions & 82 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,11 @@
1515

1616
import base64
1717
import glob
18+
import gzip
1819
import json
1920
import logging
2021
import os
22+
import shutil
2123
import subprocess
2224
import tempfile
2325
import time
@@ -70,8 +72,8 @@ class DGXCloudExecutor(Executor):
7072

7173
base_url: str
7274
kube_apiserver_url: str
73-
app_id: str
74-
app_secret: str
75+
client_id: str
76+
client_secret: str
7577
project_name: str
7678
container_image: str
7779
pvc_nemo_run_dir: str
@@ -83,13 +85,14 @@ class DGXCloudExecutor(Executor):
8385
pvcs: list[dict[str, Any]] = field(default_factory=list)
8486
distributed_framework: str = "PyTorch"
8587
custom_spec: dict[str, Any] = field(default_factory=dict)
88+
MAX_ARGS_CHARS: int = 9500
8689

8790
def get_auth_token(self) -> Optional[str]:
8891
url = f"{self.base_url}/token"
8992
payload = {
90-
"grantType": "app_token",
91-
"appId": self.app_id,
92-
"appSecret": self.app_secret,
93+
"grantType": "client_credentials",
94+
"clientId": self.client_id,
95+
"clientSecret": self.client_secret,
9396
}
9497

9598
n_attempts = 0
@@ -138,18 +141,46 @@ def copy_directory_data_command(self, local_dir_path: str, dest_path: str) -> st
138141
cmd = f"rm -rf {dest_path} && mkdir -p {dest_path} && echo {encoded_data} | base64 -d > {dest_path}/archive.tar.gz && tar -xzf {dest_path}/archive.tar.gz -C {dest_path} && rm {dest_path}/archive.tar.gz"
139142
return cmd
140143

141-
def create_data_mover_workload(self, token: str, project_id: str, cluster_id: str):
142-
"""
143-
Creates a CPU only workload to move job directory into PVC using the provided project/cluster IDs.
144-
"""
144+
def delete_workload(self, token: str, workload_id: str):
145+
url = f"{self.base_url}/workloads/workspaces/{workload_id}"
146+
headers = self._default_headers(token=token)
145147

146-
cmd = self.copy_directory_data_command(self.job_dir, self.pvc_job_dir)
148+
response = requests.delete(url, headers=headers)
147149

148-
url = f"{self.base_url}/workloads/workspaces"
150+
logger.debug(
151+
"Delete interactive workspace; response code=%s, content=%s",
152+
response.status_code,
153+
response.text.strip(),
154+
)
155+
return response
156+
157+
def _workspace_status(self, workload_id: str) -> Optional[DGXCloudState]:
158+
"""Query workspace-specific status endpoint for data-mover workloads."""
159+
url = f"{self.base_url}/workloads/workspaces/{workload_id}"
160+
token = self.get_auth_token()
161+
if not token:
162+
return None
149163
headers = self._default_headers(token=token)
164+
response = requests.get(url, headers=headers)
165+
if response.status_code != 200:
166+
return None
167+
data = response.json()
168+
phase = data.get("actualPhase") or data.get("phase")
169+
return DGXCloudState(phase) if phase else None
150170

171+
def _run_workspace_and_wait(
172+
self,
173+
token: str,
174+
project_id: str,
175+
cluster_id: str,
176+
name: str,
177+
cmd: str,
178+
sleep: float = 10,
179+
timeout: int = 300,
180+
) -> None:
181+
"""Create a workspace workload, poll until done, then delete it."""
151182
payload = {
152-
"name": "data-mover",
183+
"name": name,
153184
"useGivenNameAsPrefix": True,
154185
"projectId": project_id,
155186
"clusterId": cluster_id,
@@ -160,76 +191,94 @@ def create_data_mover_workload(self, token: str, project_id: str, cluster_id: st
160191
"storage": {"pvc": self.pvcs},
161192
},
162193
}
163-
164-
response = requests.post(url, json=payload, headers=headers)
165-
166-
logger.debug(
167-
"Created workload; response code=%s, content=%s",
168-
response.status_code,
169-
response.text.strip(),
170-
)
171-
172-
return response
173-
174-
def delete_workload(self, token: str, workload_id: str):
175-
url = f"{self.base_url}/workloads/workspaces/{workload_id}"
176194
headers = self._default_headers(token=token)
177-
178-
response = requests.delete(url, headers=headers)
179-
180-
logger.debug(
181-
"Delete interactive workspace; response code=%s, content=%s",
182-
response.status_code,
183-
response.text.strip(),
184-
)
185-
return response
195+
resp = requests.post(f"{self.base_url}/workloads/workspaces", json=payload, headers=headers)
196+
if resp.status_code not in (200, 202):
197+
raise RuntimeError(f"Workload '{name}' failed: {resp.status_code} {resp.text}")
198+
wid = resp.json()["workloadId"]
199+
logger.info(" workload %s (%s) created", name, wid[:12])
200+
201+
elapsed = 0
202+
while elapsed < timeout:
203+
time.sleep(sleep)
204+
elapsed += sleep
205+
status = self._workspace_status(wid)
206+
if status == DGXCloudState.COMPLETED:
207+
self.delete_workload(token, wid)
208+
return
209+
if status in (DGXCloudState.FAILED, DGXCloudState.STOPPED, DGXCloudState.DEGRADED):
210+
self.delete_workload(token, wid)
211+
raise RuntimeError(f"Workload {wid} ended with: {status}")
212+
raise RuntimeError(f"Workload {wid} timed out after {timeout}s")
186213

187214
def move_data(self, token: str, project_id: str, cluster_id: str, sleep: float = 10) -> None:
188-
"""
189-
Moves job directory into PVC and deletes the workload after completion
190-
"""
191-
192-
resp = self.create_data_mover_workload(token, project_id, cluster_id)
193-
if resp.status_code not in [200, 202]:
194-
raise RuntimeError(
195-
f"Failed to create data mover workload, status_code={resp.status_code}, reason={resp.text}"
196-
)
197-
198-
resp_json = resp.json()
199-
workload_id = resp_json["workloadId"]
200-
status = DGXCloudState(resp_json["actualPhase"])
215+
"""Move job directory into PVC.
201216
202-
logger.info(f"Successfully created data movement workload {workload_id} on DGXCloud")
203-
204-
while status in [
205-
DGXCloudState.PENDING,
206-
DGXCloudState.CREATING,
207-
DGXCloudState.INITIALIZING,
208-
DGXCloudState.RUNNING,
209-
]:
210-
time.sleep(sleep)
211-
status = self.status(workload_id)
212-
logger.debug(
213-
f"Polling data movement workload {workload_id}'s status. Current status is: {status}"
214-
)
217+
Uses the fast single-command tarball when it fits within the API's
218+
10 000-char limit. Falls back to per-file deployment otherwise.
219+
"""
220+
cmd = self.copy_directory_data_command(self.job_dir, self.pvc_job_dir)
215221

216-
if status is not DGXCloudState.COMPLETED:
217-
raise RuntimeError(f"Failed to move data to PVC. Workload status is {status}")
222+
if len(cmd) <= self.MAX_ARGS_CHARS:
223+
self._run_workspace_and_wait(token, project_id, cluster_id, "data-mover", cmd, sleep)
224+
return
218225

219-
resp = self.delete_workload(token, workload_id)
220-
if resp.status_code >= 200 and resp.status_code < 300:
221-
logger.info(
222-
"Successfully deleted data movement workload %s on DGXCloud with response code %d",
223-
workload_id,
224-
resp.status_code,
225-
)
226-
else:
227-
logger.error(
228-
"Failed to delete data movement workload %s, response code=%d, reason=%s",
229-
workload_id,
230-
resp.status_code,
231-
resp.text,
232-
)
226+
logger.info(
227+
"Tarball is %d chars (limit %d), deploying files individually",
228+
len(cmd),
229+
self.MAX_ARGS_CHARS,
230+
)
231+
for root, _, filenames in os.walk(self.job_dir):
232+
for fn in filenames:
233+
if fn.endswith(".tar.gz"):
234+
continue
235+
abs_path = os.path.join(root, fn)
236+
rel_path = os.path.relpath(abs_path, self.job_dir)
237+
dest = os.path.join(self.pvc_job_dir, rel_path)
238+
with open(abs_path, "rb") as f:
239+
data = f.read()
240+
241+
compressed = gzip.compress(data, compresslevel=9)
242+
encoded = base64.b64encode(compressed).decode()
243+
overhead = len(f"mkdir -p $(dirname {dest}) && echo | base64 -d | gunzip > {dest}")
244+
chunk_b64_limit = self.MAX_ARGS_CHARS - overhead - 50
245+
246+
if len(encoded) <= chunk_b64_limit:
247+
file_cmd = f"mkdir -p $(dirname {dest}) && echo {encoded} | base64 -d | gunzip > {dest}"
248+
logger.info(
249+
" deploying %s (%d→%d bytes)", rel_path, len(data), len(compressed)
250+
)
251+
self._run_workspace_and_wait(
252+
token, project_id, cluster_id, "data-mover", file_cmd, sleep
253+
)
254+
else:
255+
chunk_size = (chunk_b64_limit * 3) // 4
256+
raw_chunks = [
257+
compressed[i : i + chunk_size]
258+
for i in range(0, len(compressed), chunk_size)
259+
]
260+
logger.info(
261+
" deploying %s in %d chunks (%d→%d bytes)",
262+
rel_path,
263+
len(raw_chunks),
264+
len(data),
265+
len(compressed),
266+
)
267+
for ci, chunk in enumerate(raw_chunks):
268+
b64 = base64.b64encode(chunk).decode()
269+
if ci == 0:
270+
file_cmd = (
271+
f"mkdir -p $(dirname {dest}) && echo {b64} | base64 -d > {dest}.gz"
272+
)
273+
else:
274+
file_cmd = f"echo {b64} | base64 -d >> {dest}.gz"
275+
self._run_workspace_and_wait(
276+
token, project_id, cluster_id, "data-mover", file_cmd, sleep
277+
)
278+
gunzip_cmd = f"gunzip -f {dest}.gz"
279+
self._run_workspace_and_wait(
280+
token, project_id, cluster_id, "data-mover", gunzip_cmd, sleep
281+
)
233282

234283
def create_training_job(
235284
self, token: str, project_id: str, cluster_id: str, name: str
@@ -272,7 +321,7 @@ def create_training_job(
272321
common_spec = {
273322
"command": f"/bin/bash {self.pvc_job_dir}/launch_script.sh",
274323
"image": self.container_image,
275-
"compute": {"gpuDevicesRequest": self.gpus_per_node},
324+
"compute": {"gpuDevicesRequest": self.gpus_per_node, "largeShmRequest": True},
276325
"storage": {"pvc": self.pvcs},
277326
"environmentVariables": [
278327
{"name": key, "value": value} for key, value in self.env_vars.items()
@@ -321,6 +370,17 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
321370
if not project_id or not cluster_id:
322371
raise RuntimeError("Unable to determine project/cluster IDs for job submission")
323372

373+
# Copy experiment-level files referenced in cmd into job_dir
374+
# so they are included in the data mover transfer to the PVC
375+
cmd_str = " ".join(cmd)
376+
for fname in os.listdir(self.experiment_dir):
377+
fpath = os.path.join(self.experiment_dir, fname)
378+
if os.path.isfile(fpath) and fpath in cmd_str:
379+
shutil.copy2(fpath, os.path.join(self.job_dir, fname))
380+
381+
# Rewrite local paths in cmd to point to the PVC job directory
382+
cmd = [c.replace(self.experiment_dir, self.pvc_job_dir) for c in cmd]
383+
324384
# prepare launch script and move data to PVC
325385
launch_script = f"""
326386
ln -s {self.pvc_job_dir}/ /nemo_run
@@ -390,9 +450,37 @@ def fetch_logs(
390450
stderr: Optional[bool] = None,
391451
stdout: Optional[bool] = None,
392452
) -> Iterable[str]:
393-
while self.status(job_id) != DGXCloudState.RUNNING:
394-
logger.info("Waiting for job to start...")
453+
state = self.status(job_id)
454+
while state != DGXCloudState.RUNNING:
455+
logger.info("Job %s — status: %s", job_id[:12], state.value if state else "Unknown")
456+
if state in (
457+
DGXCloudState.COMPLETED,
458+
DGXCloudState.FAILED,
459+
DGXCloudState.STOPPED,
460+
DGXCloudState.DEGRADED,
461+
):
462+
logger.warning("Job reached terminal state %s before logs were available", state)
463+
return
395464
time.sleep(15)
465+
state = self.status(job_id)
466+
467+
if not self.launched_from_cluster:
468+
logger.info("Job %s is RUNNING. Logs are available in the Run:AI UI.", job_id[:12])
469+
terminal = (
470+
DGXCloudState.COMPLETED,
471+
DGXCloudState.FAILED,
472+
DGXCloudState.STOPPED,
473+
DGXCloudState.DEGRADED,
474+
)
475+
while True:
476+
time.sleep(30)
477+
state = self.status(job_id)
478+
logger.info("Job %s — status: %s", job_id[:12], state.value if state else "Unknown")
479+
if state in terminal:
480+
yield f"Job finished with status: {state.value}"
481+
return
482+
483+
logger.info("Job %s is RUNNING, waiting for log files...", job_id[:12])
396484

397485
cmd = ["tail"]
398486

@@ -405,12 +493,21 @@ def fetch_logs(
405493
self.pvc_job_dir = os.path.join(self.pvc_nemo_run_dir, job_subdir)
406494

407495
files = []
496+
poll_count = 0
408497
while len(files) < self.nodes:
409498
files = list(glob.glob(f"{self.pvc_job_dir}/log_*.out"))
410499
files = [f for f in files if "log-allranks_0" not in f]
411-
logger.info(
412-
f"Waiting for {self.nodes + 1 - len(files)} log files to be created in {self.pvc_job_dir}..."
413-
)
500+
if poll_count == 0 or poll_count % 10 == 0:
501+
logger.info(
502+
"Log files: %d/%d ready (watching %s)",
503+
len(files),
504+
self.nodes,
505+
self.pvc_job_dir,
506+
)
507+
poll_count += 1
508+
if poll_count > 100:
509+
logger.warning("Timed out waiting for log files after 5 minutes")
510+
return
414511
time.sleep(3)
415512

416513
cmd.extend(files)
@@ -526,6 +623,30 @@ def assign(
526623
)
527624
self.experiment_id = exp_id
528625

626+
def deploy_script_to_pvc(
627+
self,
628+
script_content: str,
629+
dest_path: str,
630+
token: Optional[str] = None,
631+
project_id: Optional[str] = None,
632+
cluster_id: Optional[str] = None,
633+
) -> None:
634+
"""Write a script to the PVC via a short-lived busybox workspace."""
635+
if not token:
636+
token = self.get_auth_token()
637+
if not token:
638+
raise RuntimeError("Failed to get auth token for script deployment")
639+
if not project_id or not cluster_id:
640+
project_id, cluster_id = self.get_project_and_cluster_id(token)
641+
642+
encoded = base64.b64encode(gzip.compress(script_content.encode(), compresslevel=9)).decode()
643+
cmd = (
644+
f"mkdir -p $(dirname {dest_path}) && "
645+
f"echo {encoded} | base64 -d | gunzip > {dest_path} && "
646+
f"chmod +x {dest_path}"
647+
)
648+
self._run_workspace_and_wait(token, project_id, cluster_id, "script-deploy", cmd)
649+
529650
def get_launcher_prefix(self) -> Optional[list[str]]:
530651
launcher = self.get_launcher()
531652
if launcher.nsys_profile:
@@ -574,6 +695,7 @@ def package(self, packager: Packager, job_name: str):
574695
ctx.run(
575696
f"tar -xvzf {local_pkg} -C {local_code_extraction_path} --ignore-zeros", hide=True
576697
)
698+
os.remove(local_pkg)
577699

578700
def macro_values(self) -> Optional[ExecutorMacros]:
579701
return None

0 commit comments

Comments (0)