1515
1616import logging
1717import os
18+ import re
1819from dataclasses import dataclass , field
19- from typing import Optional , Union
20+ from typing import Any , Dict , Optional , Union
2021
2122import yaml
22- from kubeflow .trainer .api .trainer_client import TrainerClient
23- from kubeflow .trainer .types .types import (
24- CustomTrainer ,
25- Runtime ,
26- )
23+ from kubeflow .trainer import CustomTrainer , TrainerClient
2724from kubernetes import client , config
2825from kubernetes .client .exceptions import ApiException
2926
3633logger = logging .getLogger (__name__ )
3734
3835
39- def _nemo_inline_entry_params (params : dict ):
40- """Execute inline Script content using the SDK's func_args injection style.
41-
42- The SDK injects a single positional dict when func_args is provided; this
43- function unpacks the dict and executes the content via bash or python.
44- """
45- if not isinstance (params , dict ):
46- raise ValueError ("Expected params to be a dict with keys 'script' and 'entrypoint'." )
47-
48- script = params .get ("script" , "" )
49- entrypoint = params .get ("entrypoint" , "bash" )
50-
51- # Self-contained to work when injected by the SDK: include imports here
52- import subprocess as _sp
53- import textwrap as _tw
54-
55- script = _tw .dedent (script )
56- if "python" in entrypoint :
57- exec (script , {})
58- return
59- _sp .run (["bash" , "-lc" , script ], check = True )
60-
61-
6236@dataclass (kw_only = True )
6337class KubeflowExecutor (Executor ):
6438 """
@@ -88,9 +62,6 @@ class KubeflowExecutor(Executor):
8862 exp.run()
8963 """
9064
91- #: Unique logical name for this executor; used for CRT and ConfigMap naming
92- name : str
93-
9465 #: Number of nodes for distributed training
9566 nodes : int = 1
9667
@@ -112,6 +83,9 @@ class KubeflowExecutor(Executor):
11283 #: Container image for training jobs
11384 image : str = "nvcr.io/nvidia/nemo:dev"
11485
86+ #: Training job filename
87+ training_entry : str = "experiment"
88+
11589 #: Volume mount path for staged files (default: /src)
11690 volume_mount_path : str = "/src"
11791
@@ -177,10 +151,15 @@ def assign(
177151 ):
178152 """Assign experiment and task information to the executor."""
179153 self .experiment_id = exp_id
154+ self .experiment_name = re .sub (r"([_\d]+)" , "" , exp_id )
180155 self .experiment_dir = exp_dir
181156 self .job_dir = os .path .join (exp_dir , task_dir )
182157 self .job_name = task_id
183158
159+ logger .info (
160+ f"KubeflowExecutor assigned: experiment_id={ self .experiment_id } , job_name={ self .job_name } "
161+ )
162+
184163 def set_detach_mode (self , detach : bool ):
185164 """Set detach mode for the executor."""
186165 self ._detach_mode = detach
@@ -237,15 +216,10 @@ def _get_trainer_client(self) -> TrainerClient:
237216 self ._trainer_client = TrainerClient (namespace = self .namespace )
238217 return self ._trainer_client
239218
240- def _get_runtime (self , trainer = None ) -> Runtime :
241- """Get the Runtime configuration for the training job."""
242- client = self ._get_trainer_client ()
243- runtime_name = self ._runtime_name ()
244- return client .get_runtime (runtime_name )
245-
246- def _create_cluster_training_runtime (self , configmap_name : str ) -> str :
219+ def _create_cluster_training_runtime (self , configmap_name : str , sha : str ) -> str :
247220 """Create or replace a ClusterTrainingRuntime bound to the given ConfigMap."""
248- runtime_name = self ._runtime_name ()
221+ runtime_name = self ._runtime_name (sha )
222+
249223 if not hasattr (self , "_kubernetes_available" ) or not self ._kubernetes_available :
250224 raise RuntimeError ("Kubernetes is not available; cannot create ClusterTrainingRuntime" )
251225
@@ -265,7 +239,7 @@ def _create_cluster_training_runtime(self, configmap_name: str) -> str:
265239 template_name = "kubeflow_clustertrainingruntime.yaml.j2" ,
266240 variables = template_vars ,
267241 )
268- runtime_body = yaml .safe_load (rendered )
242+ runtime_body = yaml .safe_load (rendered ) # type: ignore[assignment]
269243
270244 try :
271245 api_client .create_cluster_custom_object (
@@ -277,31 +251,124 @@ def _create_cluster_training_runtime(self, configmap_name: str) -> str:
277251 logger .info (f"Created ClusterTrainingRuntime: { runtime_name } " )
278252 except ApiException as e :
279253 if e .status == 409 :
280- # Replace to ensure the ClusterTrainingRuntime is updated
281- api_client .replace_cluster_custom_object (
282- group = "trainer.kubeflow.org" ,
283- version = "v1alpha1" ,
284- plural = "clustertrainingruntimes" ,
285- name = runtime_name ,
286- body = runtime_body ,
287- )
288- logger .info (f"Replaced existing ClusterTrainingRuntime: { runtime_name } " )
254+ # Resource already exists, fetch it first to get resourceVersion
255+ try :
256+ existing_runtime_obj = api_client .get_cluster_custom_object (
257+ group = "trainer.kubeflow.org" ,
258+ version = "v1alpha1" ,
259+ plural = "clustertrainingruntimes" ,
260+ name = runtime_name ,
261+ )
262+ existing_runtime : Dict [str , Any ] = existing_runtime_obj # type: ignore[assignment]
263+ # Update the resourceVersion in our new body
264+ runtime_body ["metadata" ]["resourceVersion" ] = existing_runtime ["metadata" ][
265+ "resourceVersion"
266+ ] # type: ignore[index]
267+
268+ # Replace the existing ClusterTrainingRuntime
269+ api_client .replace_cluster_custom_object (
270+ group = "trainer.kubeflow.org" ,
271+ version = "v1alpha1" ,
272+ plural = "clustertrainingruntimes" ,
273+ name = runtime_name ,
274+ body = runtime_body ,
275+ )
276+ logger .info (f"Replaced existing ClusterTrainingRuntime: { runtime_name } " )
277+ except Exception as replace_error :
278+ logger .error (
279+ f"Failed to replace existing ClusterTrainingRuntime: { replace_error } "
280+ )
281+ raise
289282 else :
290- logger .error (f"Failed to create/replace ClusterTrainingRuntime: { e } " )
283+ logger .error (f"Failed to create ClusterTrainingRuntime: { e } " )
291284 raise
292285 return runtime_name
293286
294- def stage_files (self , task_dir : str , task = None ) -> str :
295- """Stage files using the packager."""
296- if isinstance (self .packager , ConfigMapPackager ):
297- return self .packager .package_default (self .name )
298- else :
299- return task_dir
def _get_additional_files(self, task) -> dict[str, tuple[str, str]]:
    """Collect the extra files that must be staged for ``task``.

    Only tasks with a truthy ``inline`` attribute (Script-style tasks) produce
    a file; its content is registered under ``self.training_entry``. A task
    that is ``None`` or that only exposes ``__fn_or_cls__`` (Partial-style)
    yields an empty mapping.

    ``inline`` is interpreted as the path of a TorchX-generated script only
    when it is a single line that starts with ``/`` and ends with ``.sh``.
    Multi-line content is always treated as the script body itself, so an
    inline script that merely begins with an absolute path and ends with a
    ``.sh`` invocation is no longer mistaken for a file path and dropped.

    Args:
        task: The task to inspect; may be ``None``.

    Returns:
        Dict mapping filename to ``(content, entrypoint)`` tuples.
    """
    files_to_stage: dict[str, tuple[str, str]] = {}

    if task is None:
        return files_to_stage

    inline = getattr(task, "inline", None)
    if inline:
        # Script task - stage the script content in ConfigMap
        entrypoint = getattr(task, "entrypoint", "bash")
        stripped = inline.strip()

        # A file path is a single line; anything spanning several lines is
        # script content even if it starts with "/" and ends with ".sh".
        is_script_path = (
            stripped.startswith("/")
            and stripped.endswith(".sh")
            and len(stripped.splitlines()) == 1
        )

        if is_script_path:
            # Map the TorchX path onto the local job directory
            local_script_path = stripped.replace(
                "/nemo_run/scripts/", f"{self.job_dir}/scripts/"
            )
            if not os.path.exists(local_script_path):
                logger.warning(f"TorchX script file not found, skipping: {local_script_path}")
                return files_to_stage
            with open(local_script_path, "r", encoding="utf-8") as f:
                content = f.read()
            logger.info(f"Read script content from TorchX-generated file: {local_script_path}")
        else:
            # Direct inline content
            content = inline

        if content:
            files_to_stage[self.training_entry] = (content, entrypoint)
            logger.info("Script task - will stage content in ConfigMap")

    elif hasattr(task, "__fn_or_cls__"):
        # Partial task - handled directly by CustomTrainer, nothing to stage
        logger.info(
            "Partial task - will be passed directly to CustomTrainer, skipping ConfigMap staging"
        )

    return files_to_stage
335+
def stage_files(self, task_dir: str, task=None) -> tuple[str, str]:
    """Stage task files through the packager.

    When the packager is a ``ConfigMapPackager``, the files reported by
    ``_get_additional_files(task)`` are added under the experiment identifier
    and everything is packaged with ``package_with_hash``. For any other
    packager nothing is staged.

    Args:
        task_dir: Returned unchanged as the first tuple element when the
            packager is not a ``ConfigMapPackager``.
        task: Optional task whose content may contribute extra files.

    Returns:
        ``(configmap_name, sha)`` as produced by ``package_with_hash``, or
        ``(task_dir, "")`` when the packager is not a ``ConfigMapPackager``.

    Raises:
        RuntimeError: If the executor has no experiment name yet.
        Exception: Any packaging failure is logged and re-raised.
    """
    if not isinstance(self.packager, ConfigMapPackager):
        return (task_dir, "")

    # Get additional files to stage based on task type
    additional_files = self._get_additional_files(task)

    # Register every additional file under the experiment identifier
    experiment_id = self._get_experiment_identifier()
    for filename, (content, entrypoint) in additional_files.items():
        self.packager.add_file(experiment_id, filename, content, entrypoint=entrypoint)

    # Keep the try body limited to the call that can actually fail
    try:
        configmap_name, sha = self.packager.package_with_hash(experiment_id)
    except Exception as e:
        # logger.exception keeps the traceback alongside the message
        logger.exception(f"Failed to stage files: {e}")
        raise

    logger.info(f"Staged files into ConfigMap: {configmap_name} (sha={sha or 'n/a'})")
    return (configmap_name, sha)
360+
def _get_experiment_identifier(self) -> str:
    """Return ``self.experiment_name``, the identifier shared by staged files and runtimes.

    Returns:
        The experiment name stored on the executor.

    Raises:
        RuntimeError: If ``experiment_name`` is missing or empty, i.e. the
            executor has not been assigned to an experiment yet.
    """
    experiment_name = getattr(self, "experiment_name", None)
    if not experiment_name:
        raise RuntimeError("Executor not assigned to experiment; missing experiment_name")
    return experiment_name
300366
def cleanup_files(self, task_dir: str, task=None):
    """Remove the files staged for this experiment.

    Does nothing unless the packager is a ``ConfigMapPackager``. ``task_dir``
    and ``task`` are accepted but not used here.
    """
    if not isinstance(self.packager, ConfigMapPackager):
        return

    # Staged files are keyed by the experiment identifier, so cleanup uses it too
    identifier = self._get_experiment_identifier()
    self.packager.cleanup(identifier)
305372
306373 def _get_custom_trainer (self , task ) -> CustomTrainer :
307374 """Get the CustomTrainer configuration for the training job."""
@@ -315,35 +382,40 @@ def _get_custom_trainer(self, task) -> CustomTrainer:
315382 resources_per_node ["nvidia.com/gpu" ] = str (self .gpus )
316383 trainer_kwargs ["resources_per_node" ] = resources_per_node
317384
318- if hasattr (task , "inline" ) and task .inline :
319- trainer_kwargs ["func" ] = _nemo_inline_entry_params
320- trainer_kwargs ["func_args" ] = {
321- "script" : task .inline ,
322- "entrypoint" : getattr (task , "entrypoint" , "bash" ),
323- }
324- elif hasattr (task , "__fn_or_cls__" ):
385+ if hasattr (task , "__fn_or_cls__" ):
325386 trainer_kwargs ["func" ] = task .__fn_or_cls__
387+ if hasattr (task , "__arguments__" ) and task .__arguments__ :
388+ trainer_kwargs ["func_args" ] = task .__arguments__
326389 else :
327- raise ValueError ("Task must be a Script or Partial object" )
390+ # Script task - set python_file and check for bash scripts
391+ trainer_kwargs ["python_file" ] = f"{ self .volume_mount_path } /{ self .training_entry } "
328392
329- return CustomTrainer (** trainer_kwargs )
393+ # Check if this is a bash script and set appropriate command
394+ if hasattr (task , "inline" ) and task .inline :
395+ entrypoint = getattr (task , "entrypoint" , "bash" )
396+ if entrypoint and "bash" in entrypoint .lower ():
397+ trainer_kwargs ["command" ] = ["/bin/bash" ]
398+ logger .info ("Using bash command for script execution" )
399+ # For Python scripts, let SDK auto-detect based on runtime
330400
331- def _get_staged_file_path ( self , filename : str ) -> str :
332- """Get the staged file path for a given filename."""
333- if isinstance ( self . packager , ConfigMapPackager ):
334- # Use executor name for mounted path grouping
335- effective_dir = sanitize_kubernetes_name ( self . name )
336- sanitized_filename = filename . replace ( "/" , "-" )
337- return f" { self . volume_mount_path } / { effective_dir } - { sanitized_filename } "
338- else :
339- return filename
401+ # Debug logging to see what we're passing to CustomTrainer
402+ logger . info ( f"Creating CustomTrainer with kwargs: { trainer_kwargs } " )
403+
404+ trainer = CustomTrainer ( ** trainer_kwargs )
405+
406+ # Debug logging to see what CustomTrainer actually received
407+ logger . info ( f"CustomTrainer created with func: { trainer . func } " )
408+ logger . info ( f"CustomTrainer created with func_args: { trainer . func_args } " )
409+ logger . info ( f"CustomTrainer created with python_file: { trainer . python_file } " )
340410
341- def create_trainjob (self , job_name : str , task ) -> str :
411+ return trainer
412+
413+ def create_trainjob (self , job_name : str , task , runtime_name : str ) -> str :
342414 """Create a TrainJob using the Kubeflow SDK."""
343415 try :
344416 client = self ._get_trainer_client ()
345417 trainer = self ._get_custom_trainer (task )
346- runtime = self . _get_runtime ( trainer = trainer )
418+ runtime = client . get_runtime ( runtime_name )
347419 job_id = client .train (runtime = runtime , trainer = trainer )
348420 logger .info (f"Created TrainJob: { job_id } " )
349421 return job_id
@@ -379,41 +451,30 @@ def get_trainjob_logs(self, job_name: str, follow: bool = False) -> dict:
379451 logger .error (f"Failed to get TrainJob logs: { e } " )
380452 return {}
381453
def prepare_runtime(self, task=None) -> tuple[str, str]:
    """Prepare the runtime dependencies for this executor.

    Steps:
        - Stage files via ``stage_files`` to obtain the ConfigMap name and
          its content sha.
        - Create (or replace) the ClusterTrainingRuntime that references
          that ConfigMap.

    Args:
        task: Optional task forwarded to ``stage_files``.

    Returns:
        ``(runtime_name, sha)``.

    Raises:
        Exception: Failures from staging or runtime creation propagate
            unchanged so callers don't proceed to submit().
    """
    # Stage files to ensure we have the latest content and ConfigMap
    configmap_name, sha = self.stage_files(task_dir="", task=task)

    # Create runtime bound to this ConfigMap; errors propagate as-is
    runtime_name = self._create_cluster_training_runtime(
        configmap_name=configmap_name, sha=sha
    )
    logger.info(f"Prepared runtime: {runtime_name}")
    return (runtime_name, sha)
409477
410- # Backwards-compatible helpers call the atomic method
411- def ensure_configmap (self ) -> str :
412- return self .prepare_runtime ()
413-
414- def ensure_runtime (self ) -> str :
415- return self .prepare_runtime ()
416-
417478 def submit (self , task , job_name : str ) -> str :
418479 """
419480 Submit a job using the Kubeflow SDK.
@@ -425,10 +486,10 @@ def submit(self, task, job_name: str) -> str:
425486 raise RuntimeError ("Executor not assigned to experiment" )
426487
427488 try :
428- # Prepare runtime dependencies on every submit; K8s upserts make this safe
429- self .prepare_runtime ()
489+ # Prepare runtime dependencies (stages files and creates runtime)
490+ runtime_name , _ = self .prepare_runtime (task = task )
430491
431- job_id = self .create_trainjob (job_name , task )
492+ job_id = self .create_trainjob (job_name , task , runtime_name )
432493 logger .info (f"Submitted job { job_name } with ID: { job_id } " )
433494 return job_id
434495
@@ -464,5 +525,7 @@ def info(self) -> str:
464525 """Get information about the executor configuration."""
465526 return f"KubeflowExecutor (nodes={ self .nodes } , gpus={ self .gpus or 0 } )"
466527
467- def _runtime_name (self ) -> str :
468- return f"nemo-runtime-{ sanitize_kubernetes_name (self .name )} "
def _runtime_name(self, sha: str) -> str:
    """Build the ClusterTrainingRuntime name from the experiment identifier and sha.

    The sha segment is appended only when ``sha`` is non-empty, so an empty
    sha does not leave a dangling trailing hyphen in the name handed to
    ``sanitize_kubernetes_name``.

    Args:
        sha: Content hash of the staged files; may be an empty string.

    Returns:
        The sanitized runtime name.

    Raises:
        RuntimeError: If the executor has no experiment name yet.
    """
    parts = ["nemo-runtime", self._get_experiment_identifier()]
    if sha:
        parts.append(sha)
    return sanitize_kubernetes_name("-".join(parts))
0 commit comments