1515
1616import base64
1717import glob
18+ import gzip
1819import json
1920import logging
2021import os
22+ import shutil
2123import subprocess
2224import tempfile
2325import time
@@ -70,8 +72,8 @@ class DGXCloudExecutor(Executor):
7072
7173 base_url : str
7274 kube_apiserver_url : str
73- app_id : str
74- app_secret : str
75+ client_id : str
76+ client_secret : str
7577 project_name : str
7678 container_image : str
7779 pvc_nemo_run_dir : str
@@ -83,13 +85,14 @@ class DGXCloudExecutor(Executor):
8385 pvcs : list [dict [str , Any ]] = field (default_factory = list )
8486 distributed_framework : str = "PyTorch"
8587 custom_spec : dict [str , Any ] = field (default_factory = dict )
88+ MAX_ARGS_CHARS : int = 9500
8689
8790 def get_auth_token (self ) -> Optional [str ]:
8891 url = f"{ self .base_url } /token"
8992 payload = {
90- "grantType" : "app_token " ,
91- "appId " : self .app_id ,
92- "appSecret " : self .app_secret ,
93+ "grantType" : "client_credentials " ,
94+ "clientId " : self .client_id ,
95+ "clientSecret " : self .client_secret ,
9396 }
9497
9598 n_attempts = 0
@@ -138,18 +141,46 @@ def copy_directory_data_command(self, local_dir_path: str, dest_path: str) -> st
138141 cmd = f"rm -rf { dest_path } && mkdir -p { dest_path } && echo { encoded_data } | base64 -d > { dest_path } /archive.tar.gz && tar -xzf { dest_path } /archive.tar.gz -C { dest_path } && rm { dest_path } /archive.tar.gz"
139142 return cmd
140143
141- def create_data_mover_workload (self , token : str , project_id : str , cluster_id : str ):
142- """
143- Creates a CPU only workload to move job directory into PVC using the provided project/cluster IDs.
144- """
144+ def delete_workload (self , token : str , workload_id : str ):
145+ url = f"{ self .base_url } /workloads/workspaces/{ workload_id } "
146+ headers = self ._default_headers (token = token )
145147
146- cmd = self . copy_directory_data_command ( self . job_dir , self . pvc_job_dir )
148+ response = requests . delete ( url , headers = headers )
147149
148- url = f"{ self .base_url } /workloads/workspaces"
150+ logger .debug (
151+ "Delete interactive workspace; response code=%s, content=%s" ,
152+ response .status_code ,
153+ response .text .strip (),
154+ )
155+ return response
156+
157+ def _workspace_status (self , workload_id : str ) -> Optional [DGXCloudState ]:
158+ """Query workspace-specific status endpoint for data-mover workloads."""
159+ url = f"{ self .base_url } /workloads/workspaces/{ workload_id } "
160+ token = self .get_auth_token ()
161+ if not token :
162+ return None
149163 headers = self ._default_headers (token = token )
164+ response = requests .get (url , headers = headers )
165+ if response .status_code != 200 :
166+ return None
167+ data = response .json ()
168+ phase = data .get ("actualPhase" ) or data .get ("phase" )
169+ return DGXCloudState (phase ) if phase else None
150170
171+ def _run_workspace_and_wait (
172+ self ,
173+ token : str ,
174+ project_id : str ,
175+ cluster_id : str ,
176+ name : str ,
177+ cmd : str ,
178+ sleep : float = 10 ,
179+ timeout : int = 300 ,
180+ ) -> None :
181+ """Create a workspace workload, poll until done, then delete it."""
151182 payload = {
152- "name" : "data-mover" ,
183+ "name" : name ,
153184 "useGivenNameAsPrefix" : True ,
154185 "projectId" : project_id ,
155186 "clusterId" : cluster_id ,
@@ -160,76 +191,94 @@ def create_data_mover_workload(self, token: str, project_id: str, cluster_id: st
160191 "storage" : {"pvc" : self .pvcs },
161192 },
162193 }
163-
164- response = requests .post (url , json = payload , headers = headers )
165-
166- logger .debug (
167- "Created workload; response code=%s, content=%s" ,
168- response .status_code ,
169- response .text .strip (),
170- )
171-
172- return response
173-
174- def delete_workload (self , token : str , workload_id : str ):
175- url = f"{ self .base_url } /workloads/workspaces/{ workload_id } "
176194 headers = self ._default_headers (token = token )
177-
178- response = requests .delete (url , headers = headers )
179-
180- logger .debug (
181- "Delete interactive workspace; response code=%s, content=%s" ,
182- response .status_code ,
183- response .text .strip (),
184- )
185- return response
195+ resp = requests .post (f"{ self .base_url } /workloads/workspaces" , json = payload , headers = headers )
196+ if resp .status_code not in (200 , 202 ):
197+ raise RuntimeError (f"Workload '{ name } ' failed: { resp .status_code } { resp .text } " )
198+ wid = resp .json ()["workloadId" ]
199+ logger .info (" workload %s (%s) created" , name , wid [:12 ])
200+
201+ elapsed = 0
202+ while elapsed < timeout :
203+ time .sleep (sleep )
204+ elapsed += sleep
205+ status = self ._workspace_status (wid )
206+ if status == DGXCloudState .COMPLETED :
207+ self .delete_workload (token , wid )
208+ return
209+ if status in (DGXCloudState .FAILED , DGXCloudState .STOPPED , DGXCloudState .DEGRADED ):
210+ self .delete_workload (token , wid )
211+ raise RuntimeError (f"Workload { wid } ended with: { status } " )
212+ raise RuntimeError (f"Workload { wid } timed out after { timeout } s" )
186213
187214 def move_data (self , token : str , project_id : str , cluster_id : str , sleep : float = 10 ) -> None :
188- """
189- Moves job directory into PVC and deletes the workload after completion
190- """
191-
192- resp = self .create_data_mover_workload (token , project_id , cluster_id )
193- if resp .status_code not in [200 , 202 ]:
194- raise RuntimeError (
195- f"Failed to create data mover workload, status_code={ resp .status_code } , reason={ resp .text } "
196- )
197-
198- resp_json = resp .json ()
199- workload_id = resp_json ["workloadId" ]
200- status = DGXCloudState (resp_json ["actualPhase" ])
215+ """Move job directory into PVC.
201216
202- logger .info (f"Successfully created data movement workload { workload_id } on DGXCloud" )
203-
204- while status in [
205- DGXCloudState .PENDING ,
206- DGXCloudState .CREATING ,
207- DGXCloudState .INITIALIZING ,
208- DGXCloudState .RUNNING ,
209- ]:
210- time .sleep (sleep )
211- status = self .status (workload_id )
212- logger .debug (
213- f"Polling data movement workload { workload_id } 's status. Current status is: { status } "
214- )
217+ Uses the fast single-command tarball when it fits within the API's
218+ 10 000-char limit. Falls back to per-file deployment otherwise.
219+ """
220+ cmd = self .copy_directory_data_command (self .job_dir , self .pvc_job_dir )
215221
216- if status is not DGXCloudState .COMPLETED :
217- raise RuntimeError (f"Failed to move data to PVC. Workload status is { status } " )
222+ if len (cmd ) <= self .MAX_ARGS_CHARS :
223+ self ._run_workspace_and_wait (token , project_id , cluster_id , "data-mover" , cmd , sleep )
224+ return
218225
219- resp = self .delete_workload (token , workload_id )
220- if resp .status_code >= 200 and resp .status_code < 300 :
221- logger .info (
222- "Successfully deleted data movement workload %s on DGXCloud with response code %d" ,
223- workload_id ,
224- resp .status_code ,
225- )
226- else :
227- logger .error (
228- "Failed to delete data movement workload %s, response code=%d, reason=%s" ,
229- workload_id ,
230- resp .status_code ,
231- resp .text ,
232- )
226+ logger .info (
227+ "Tarball is %d chars (limit %d), deploying files individually" ,
228+ len (cmd ),
229+ self .MAX_ARGS_CHARS ,
230+ )
231+ for root , _ , filenames in os .walk (self .job_dir ):
232+ for fn in filenames :
233+ if fn .endswith (".tar.gz" ):
234+ continue
235+ abs_path = os .path .join (root , fn )
236+ rel_path = os .path .relpath (abs_path , self .job_dir )
237+ dest = os .path .join (self .pvc_job_dir , rel_path )
238+ with open (abs_path , "rb" ) as f :
239+ data = f .read ()
240+
241+ compressed = gzip .compress (data , compresslevel = 9 )
242+ encoded = base64 .b64encode (compressed ).decode ()
243+ overhead = len (f"mkdir -p $(dirname { dest } ) && echo | base64 -d | gunzip > { dest } " )
244+ chunk_b64_limit = self .MAX_ARGS_CHARS - overhead - 50
245+
246+ if len (encoded ) <= chunk_b64_limit :
247+ file_cmd = f"mkdir -p $(dirname { dest } ) && echo { encoded } | base64 -d | gunzip > { dest } "
248+ logger .info (
249+ " deploying %s (%d→%d bytes)" , rel_path , len (data ), len (compressed )
250+ )
251+ self ._run_workspace_and_wait (
252+ token , project_id , cluster_id , "data-mover" , file_cmd , sleep
253+ )
254+ else :
255+ chunk_size = (chunk_b64_limit * 3 ) // 4
256+ raw_chunks = [
257+ compressed [i : i + chunk_size ]
258+ for i in range (0 , len (compressed ), chunk_size )
259+ ]
260+ logger .info (
261+ " deploying %s in %d chunks (%d→%d bytes)" ,
262+ rel_path ,
263+ len (raw_chunks ),
264+ len (data ),
265+ len (compressed ),
266+ )
267+ for ci , chunk in enumerate (raw_chunks ):
268+ b64 = base64 .b64encode (chunk ).decode ()
269+ if ci == 0 :
270+ file_cmd = (
271+ f"mkdir -p $(dirname { dest } ) && echo { b64 } | base64 -d > { dest } .gz"
272+ )
273+ else :
274+ file_cmd = f"echo { b64 } | base64 -d >> { dest } .gz"
275+ self ._run_workspace_and_wait (
276+ token , project_id , cluster_id , "data-mover" , file_cmd , sleep
277+ )
278+ gunzip_cmd = f"gunzip -f { dest } .gz"
279+ self ._run_workspace_and_wait (
280+ token , project_id , cluster_id , "data-mover" , gunzip_cmd , sleep
281+ )
233282
234283 def create_training_job (
235284 self , token : str , project_id : str , cluster_id : str , name : str
@@ -272,7 +321,7 @@ def create_training_job(
272321 common_spec = {
273322 "command" : f"/bin/bash { self .pvc_job_dir } /launch_script.sh" ,
274323 "image" : self .container_image ,
275- "compute" : {"gpuDevicesRequest" : self .gpus_per_node },
324+ "compute" : {"gpuDevicesRequest" : self .gpus_per_node , "largeShmRequest" : True },
276325 "storage" : {"pvc" : self .pvcs },
277326 "environmentVariables" : [
278327 {"name" : key , "value" : value } for key , value in self .env_vars .items ()
@@ -321,6 +370,17 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
321370 if not project_id or not cluster_id :
322371 raise RuntimeError ("Unable to determine project/cluster IDs for job submission" )
323372
373+ # Copy experiment-level files referenced in cmd into job_dir
374+ # so they are included in the data mover transfer to the PVC
375+ cmd_str = " " .join (cmd )
376+ for fname in os .listdir (self .experiment_dir ):
377+ fpath = os .path .join (self .experiment_dir , fname )
378+ if os .path .isfile (fpath ) and fpath in cmd_str :
379+ shutil .copy2 (fpath , os .path .join (self .job_dir , fname ))
380+
381+ # Rewrite local paths in cmd to point to the PVC job directory
382+ cmd = [c .replace (self .experiment_dir , self .pvc_job_dir ) for c in cmd ]
383+
324384 # prepare launch script and move data to PVC
325385 launch_script = f"""
326386ln -s { self .pvc_job_dir } / /nemo_run
@@ -390,9 +450,37 @@ def fetch_logs(
390450 stderr : Optional [bool ] = None ,
391451 stdout : Optional [bool ] = None ,
392452 ) -> Iterable [str ]:
393- while self .status (job_id ) != DGXCloudState .RUNNING :
394- logger .info ("Waiting for job to start..." )
453+ state = self .status (job_id )
454+ while state != DGXCloudState .RUNNING :
455+ logger .info ("Job %s — status: %s" , job_id [:12 ], state .value if state else "Unknown" )
456+ if state in (
457+ DGXCloudState .COMPLETED ,
458+ DGXCloudState .FAILED ,
459+ DGXCloudState .STOPPED ,
460+ DGXCloudState .DEGRADED ,
461+ ):
462+ logger .warning ("Job reached terminal state %s before logs were available" , state )
463+ return
395464 time .sleep (15 )
465+ state = self .status (job_id )
466+
467+ if not self .launched_from_cluster :
468+ logger .info ("Job %s is RUNNING. Logs are available in the Run:AI UI." , job_id [:12 ])
469+ terminal = (
470+ DGXCloudState .COMPLETED ,
471+ DGXCloudState .FAILED ,
472+ DGXCloudState .STOPPED ,
473+ DGXCloudState .DEGRADED ,
474+ )
475+ while True :
476+ time .sleep (30 )
477+ state = self .status (job_id )
478+ logger .info ("Job %s — status: %s" , job_id [:12 ], state .value if state else "Unknown" )
479+ if state in terminal :
480+ yield f"Job finished with status: { state .value } "
481+ return
482+
483+ logger .info ("Job %s is RUNNING, waiting for log files..." , job_id [:12 ])
396484
397485 cmd = ["tail" ]
398486
@@ -405,12 +493,21 @@ def fetch_logs(
405493 self .pvc_job_dir = os .path .join (self .pvc_nemo_run_dir , job_subdir )
406494
407495 files = []
496+ poll_count = 0
408497 while len (files ) < self .nodes :
409498 files = list (glob .glob (f"{ self .pvc_job_dir } /log_*.out" ))
410499 files = [f for f in files if "log-allranks_0" not in f ]
411- logger .info (
412- f"Waiting for { self .nodes + 1 - len (files )} log files to be created in { self .pvc_job_dir } ..."
413- )
500+ if poll_count == 0 or poll_count % 10 == 0 :
501+ logger .info (
502+ "Log files: %d/%d ready (watching %s)" ,
503+ len (files ),
504+ self .nodes ,
505+ self .pvc_job_dir ,
506+ )
507+ poll_count += 1
508+ if poll_count > 100 :
509+ logger .warning ("Timed out waiting for log files after 5 minutes" )
510+ return
414511 time .sleep (3 )
415512
416513 cmd .extend (files )
@@ -526,6 +623,30 @@ def assign(
526623 )
527624 self .experiment_id = exp_id
528625
626+ def deploy_script_to_pvc (
627+ self ,
628+ script_content : str ,
629+ dest_path : str ,
630+ token : Optional [str ] = None ,
631+ project_id : Optional [str ] = None ,
632+ cluster_id : Optional [str ] = None ,
633+ ) -> None :
634+ """Write a script to the PVC via a short-lived busybox workspace."""
635+ if not token :
636+ token = self .get_auth_token ()
637+ if not token :
638+ raise RuntimeError ("Failed to get auth token for script deployment" )
639+ if not project_id or not cluster_id :
640+ project_id , cluster_id = self .get_project_and_cluster_id (token )
641+
642+ encoded = base64 .b64encode (gzip .compress (script_content .encode (), compresslevel = 9 )).decode ()
643+ cmd = (
644+ f"mkdir -p $(dirname { dest_path } ) && "
645+ f"echo { encoded } | base64 -d | gunzip > { dest_path } && "
646+ f"chmod +x { dest_path } "
647+ )
648+ self ._run_workspace_and_wait (token , project_id , cluster_id , "script-deploy" , cmd )
649+
529650 def get_launcher_prefix (self ) -> Optional [list [str ]]:
530651 launcher = self .get_launcher ()
531652 if launcher .nsys_profile :
@@ -574,6 +695,7 @@ def package(self, packager: Packager, job_name: str):
574695 ctx .run (
575696 f"tar -xvzf { local_pkg } -C { local_code_extraction_path } --ignore-zeros" , hide = True
576697 )
698+ os .remove (local_pkg )
577699
578700 def macro_values (self ) -> Optional [ExecutorMacros ]:
579701 return None
0 commit comments