Skip to content

Commit dfff05e

Browse files
committed
Refactor KubeflowExecutor for Improved Task Handling
This commit refactors the KubeflowExecutor class to improve task handling and streamline creation of the ClusterTrainingRuntime. Key changes:
- Removed the `_nemo_inline_entry_params` function, simplifying inline script handling.
- Introduced `_get_additional_files`, a new method that selects additional files to stage in the ConfigMap based on the task type.
- Updated `create_trainjob` to accept `runtime_name` directly, making job submission more explicit.
- Changed `_runtime_name` to build the name from the experiment identifier and a hash, to avoid name collisions.
- Improved logging for better traceability during execution.
These modifications aim to simplify the executor's interface and make it easier to use for developers working with Kubeflow.
1 parent 93f4c58 commit dfff05e

4 files changed

Lines changed: 418 additions & 255 deletions

File tree

nemo_run/core/execution/kubeflow.py

Lines changed: 168 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,12 @@
1515

1616
import logging
1717
import os
18+
import re
1819
from dataclasses import dataclass, field
19-
from typing import Optional, Union
20+
from typing import Any, Dict, Optional, Union
2021

2122
import yaml
22-
from kubeflow.trainer.api.trainer_client import TrainerClient
23-
from kubeflow.trainer.types.types import (
24-
CustomTrainer,
25-
Runtime,
26-
)
23+
from kubeflow.trainer import CustomTrainer, TrainerClient
2724
from kubernetes import client, config
2825
from kubernetes.client.exceptions import ApiException
2926

@@ -36,29 +33,6 @@
3633
logger = logging.getLogger(__name__)
3734

3835

39-
def _nemo_inline_entry_params(params: dict):
40-
"""Execute inline Script content using the SDK's func_args injection style.
41-
42-
The SDK injects a single positional dict when func_args is provided; this
43-
function unpacks the dict and executes the content via bash or python.
44-
"""
45-
if not isinstance(params, dict):
46-
raise ValueError("Expected params to be a dict with keys 'script' and 'entrypoint'.")
47-
48-
script = params.get("script", "")
49-
entrypoint = params.get("entrypoint", "bash")
50-
51-
# Self-contained to work when injected by the SDK: include imports here
52-
import subprocess as _sp
53-
import textwrap as _tw
54-
55-
script = _tw.dedent(script)
56-
if "python" in entrypoint:
57-
exec(script, {})
58-
return
59-
_sp.run(["bash", "-lc", script], check=True)
60-
61-
6236
@dataclass(kw_only=True)
6337
class KubeflowExecutor(Executor):
6438
"""
@@ -88,9 +62,6 @@ class KubeflowExecutor(Executor):
8862
exp.run()
8963
"""
9064

91-
#: Unique logical name for this executor; used for CRT and ConfigMap naming
92-
name: str
93-
9465
#: Number of nodes for distributed training
9566
nodes: int = 1
9667

@@ -112,6 +83,9 @@ class KubeflowExecutor(Executor):
11283
#: Container image for training jobs
11384
image: str = "nvcr.io/nvidia/nemo:dev"
11485

86+
#: Filename under which the training entry script is staged and referenced below the volume mount path
87+
training_entry: str = "experiment"
88+
11589
#: Volume mount path for staged files (default: /src)
11690
volume_mount_path: str = "/src"
11791

@@ -177,10 +151,15 @@ def assign(
177151
):
178152
"""Assign experiment and task information to the executor."""
179153
self.experiment_id = exp_id
154+
self.experiment_name = re.sub(r"([_\d]+)", "", exp_id)
180155
self.experiment_dir = exp_dir
181156
self.job_dir = os.path.join(exp_dir, task_dir)
182157
self.job_name = task_id
183158

159+
logger.info(
160+
f"KubeflowExecutor assigned: experiment_id={self.experiment_id}, job_name={self.job_name}"
161+
)
162+
184163
def set_detach_mode(self, detach: bool):
185164
"""Set detach mode for the executor."""
186165
self._detach_mode = detach
@@ -237,15 +216,10 @@ def _get_trainer_client(self) -> TrainerClient:
237216
self._trainer_client = TrainerClient(namespace=self.namespace)
238217
return self._trainer_client
239218

240-
def _get_runtime(self, trainer=None) -> Runtime:
241-
"""Get the Runtime configuration for the training job."""
242-
client = self._get_trainer_client()
243-
runtime_name = self._runtime_name()
244-
return client.get_runtime(runtime_name)
245-
246-
def _create_cluster_training_runtime(self, configmap_name: str) -> str:
219+
def _create_cluster_training_runtime(self, configmap_name: str, sha: str) -> str:
247220
"""Create or replace a ClusterTrainingRuntime bound to the given ConfigMap."""
248-
runtime_name = self._runtime_name()
221+
runtime_name = self._runtime_name(sha)
222+
249223
if not hasattr(self, "_kubernetes_available") or not self._kubernetes_available:
250224
raise RuntimeError("Kubernetes is not available; cannot create ClusterTrainingRuntime")
251225

@@ -265,7 +239,7 @@ def _create_cluster_training_runtime(self, configmap_name: str) -> str:
265239
template_name="kubeflow_clustertrainingruntime.yaml.j2",
266240
variables=template_vars,
267241
)
268-
runtime_body = yaml.safe_load(rendered)
242+
runtime_body = yaml.safe_load(rendered) # type: ignore[assignment]
269243

270244
try:
271245
api_client.create_cluster_custom_object(
@@ -277,31 +251,124 @@ def _create_cluster_training_runtime(self, configmap_name: str) -> str:
277251
logger.info(f"Created ClusterTrainingRuntime: {runtime_name}")
278252
except ApiException as e:
279253
if e.status == 409:
280-
# Replace to ensure the ClusterTrainingRuntime is updated
281-
api_client.replace_cluster_custom_object(
282-
group="trainer.kubeflow.org",
283-
version="v1alpha1",
284-
plural="clustertrainingruntimes",
285-
name=runtime_name,
286-
body=runtime_body,
287-
)
288-
logger.info(f"Replaced existing ClusterTrainingRuntime: {runtime_name}")
254+
# Resource already exists, fetch it first to get resourceVersion
255+
try:
256+
existing_runtime_obj = api_client.get_cluster_custom_object(
257+
group="trainer.kubeflow.org",
258+
version="v1alpha1",
259+
plural="clustertrainingruntimes",
260+
name=runtime_name,
261+
)
262+
existing_runtime: Dict[str, Any] = existing_runtime_obj # type: ignore[assignment]
263+
# Update the resourceVersion in our new body
264+
runtime_body["metadata"]["resourceVersion"] = existing_runtime["metadata"][
265+
"resourceVersion"
266+
] # type: ignore[index]
267+
268+
# Replace the existing ClusterTrainingRuntime
269+
api_client.replace_cluster_custom_object(
270+
group="trainer.kubeflow.org",
271+
version="v1alpha1",
272+
plural="clustertrainingruntimes",
273+
name=runtime_name,
274+
body=runtime_body,
275+
)
276+
logger.info(f"Replaced existing ClusterTrainingRuntime: {runtime_name}")
277+
except Exception as replace_error:
278+
logger.error(
279+
f"Failed to replace existing ClusterTrainingRuntime: {replace_error}"
280+
)
281+
raise
289282
else:
290-
logger.error(f"Failed to create/replace ClusterTrainingRuntime: {e}")
283+
logger.error(f"Failed to create ClusterTrainingRuntime: {e}")
291284
raise
292285
return runtime_name
293286

294-
def stage_files(self, task_dir: str, task=None) -> str:
295-
"""Stage files using the packager."""
296-
if isinstance(self.packager, ConfigMapPackager):
297-
return self.packager.package_default(self.name)
298-
else:
299-
return task_dir
287+
def _get_additional_files(self, task) -> dict[str, tuple[str, str]]:
288+
"""Get additional files to stage based on task type.
289+
290+
Returns:
291+
Dict mapping filename to (content, entrypoint) tuples
292+
"""
293+
files_to_stage = {}
294+
295+
if task is None:
296+
return files_to_stage
297+
298+
if hasattr(task, "inline") and task.inline:
299+
# Script task - stage the script content in ConfigMap
300+
content: Optional[str] = None
301+
entrypoint = getattr(task, "entrypoint", "bash")
302+
303+
# Check if inline content is a file path (processed by TorchX packaging)
304+
if task.inline.strip().startswith("/") and task.inline.strip().endswith(".sh"):
305+
# This is a script file path created by TorchX packaging
306+
script_path = task.inline.strip()
307+
# Convert TorchX path to local path
308+
local_script_path = script_path.replace(
309+
"/nemo_run/scripts/", f"{self.job_dir}/scripts/"
310+
)
311+
if os.path.exists(local_script_path):
312+
with open(local_script_path, "r", encoding="utf-8") as f:
313+
content = f.read()
314+
logger.info(
315+
f"Read script content from TorchX-generated file: {local_script_path}"
316+
)
317+
else:
318+
logger.warning(f"TorchX script file not found, skipping: {local_script_path}")
319+
return files_to_stage
320+
else:
321+
# Direct inline content
322+
content = task.inline
323+
324+
if content:
325+
files_to_stage[self.training_entry] = (content, entrypoint)
326+
logger.info("Script task - will stage content in ConfigMap")
327+
328+
elif hasattr(task, "__fn_or_cls__"):
329+
# Partial task - will be handled directly by CustomTrainer, no ConfigMap staging needed
330+
logger.info(
331+
"Partial task - will be passed directly to CustomTrainer, skipping ConfigMap staging"
332+
)
333+
334+
return files_to_stage
335+
336+
def stage_files(self, task_dir: str, task=None) -> tuple[str, str]:
337+
"""Stage files using the packager.
338+
339+
Adds additional files based on task content and packages along with
340+
any original files configured on the packager. Returns a (ConfigMap name, sha) tuple, or (task_dir, "") when the packager is not a ConfigMapPackager.
341+
"""
342+
if not isinstance(self.packager, ConfigMapPackager):
343+
return (task_dir, "")
344+
345+
# Get additional files to stage based on task type
346+
additional_files = self._get_additional_files(task)
347+
348+
# Stage all additional files
349+
experiment_id = self._get_experiment_identifier()
350+
for filename, (content, entrypoint) in additional_files.items():
351+
self.packager.add_file(experiment_id, filename, content, entrypoint=entrypoint)
352+
353+
try:
354+
configmap_name, sha = self.packager.package_with_hash(experiment_id)
355+
logger.info(f"Staged files into ConfigMap: {configmap_name} (sha={sha or 'n/a'})")
356+
return (configmap_name, sha)
357+
except Exception as e:
358+
logger.error(f"Failed to stage files: {e}")
359+
raise
360+
361+
def _get_experiment_identifier(self) -> str:
362+
"""Return experiment_id; raise if not assigned yet."""
363+
if hasattr(self, "experiment_name") and self.experiment_name:
364+
return f"{self.experiment_name}"
365+
raise RuntimeError("Executor not assigned to experiment; missing experiment_name")
300366

301367
def cleanup_files(self, task_dir: str, task=None):
302368
"""Clean up staged files."""
303369
if isinstance(self.packager, ConfigMapPackager):
304-
self.packager.cleanup(self.name)
370+
# Use experiment-specific naming for cleanup
371+
self.packager.cleanup(self._get_experiment_identifier())
305372

306373
def _get_custom_trainer(self, task) -> CustomTrainer:
307374
"""Get the CustomTrainer configuration for the training job."""
@@ -315,35 +382,40 @@ def _get_custom_trainer(self, task) -> CustomTrainer:
315382
resources_per_node["nvidia.com/gpu"] = str(self.gpus)
316383
trainer_kwargs["resources_per_node"] = resources_per_node
317384

318-
if hasattr(task, "inline") and task.inline:
319-
trainer_kwargs["func"] = _nemo_inline_entry_params
320-
trainer_kwargs["func_args"] = {
321-
"script": task.inline,
322-
"entrypoint": getattr(task, "entrypoint", "bash"),
323-
}
324-
elif hasattr(task, "__fn_or_cls__"):
385+
if hasattr(task, "__fn_or_cls__"):
325386
trainer_kwargs["func"] = task.__fn_or_cls__
387+
if hasattr(task, "__arguments__") and task.__arguments__:
388+
trainer_kwargs["func_args"] = task.__arguments__
326389
else:
327-
raise ValueError("Task must be a Script or Partial object")
390+
# Script task - set python_file and check for bash scripts
391+
trainer_kwargs["python_file"] = f"{self.volume_mount_path}/{self.training_entry}"
328392

329-
return CustomTrainer(**trainer_kwargs)
393+
# Check if this is a bash script and set appropriate command
394+
if hasattr(task, "inline") and task.inline:
395+
entrypoint = getattr(task, "entrypoint", "bash")
396+
if entrypoint and "bash" in entrypoint.lower():
397+
trainer_kwargs["command"] = ["/bin/bash"]
398+
logger.info("Using bash command for script execution")
399+
# For Python scripts, let SDK auto-detect based on runtime
330400

331-
def _get_staged_file_path(self, filename: str) -> str:
332-
"""Get the staged file path for a given filename."""
333-
if isinstance(self.packager, ConfigMapPackager):
334-
# Use executor name for mounted path grouping
335-
effective_dir = sanitize_kubernetes_name(self.name)
336-
sanitized_filename = filename.replace("/", "-")
337-
return f"{self.volume_mount_path}/{effective_dir}-{sanitized_filename}"
338-
else:
339-
return filename
401+
# Debug logging to see what we're passing to CustomTrainer
402+
logger.info(f"Creating CustomTrainer with kwargs: {trainer_kwargs}")
403+
404+
trainer = CustomTrainer(**trainer_kwargs)
405+
406+
# Debug logging to see what CustomTrainer actually received
407+
logger.info(f"CustomTrainer created with func: {trainer.func}")
408+
logger.info(f"CustomTrainer created with func_args: {trainer.func_args}")
409+
logger.info(f"CustomTrainer created with python_file: {trainer.python_file}")
340410

341-
def create_trainjob(self, job_name: str, task) -> str:
411+
return trainer
412+
413+
def create_trainjob(self, job_name: str, task, runtime_name: str) -> str:
342414
"""Create a TrainJob using the Kubeflow SDK."""
343415
try:
344416
client = self._get_trainer_client()
345417
trainer = self._get_custom_trainer(task)
346-
runtime = self._get_runtime(trainer=trainer)
418+
runtime = client.get_runtime(runtime_name)
347419
job_id = client.train(runtime=runtime, trainer=trainer)
348420
logger.info(f"Created TrainJob: {job_id}")
349421
return job_id
@@ -379,41 +451,30 @@ def get_trainjob_logs(self, job_name: str, follow: bool = False) -> dict:
379451
logger.error(f"Failed to get TrainJob logs: {e}")
380452
return {}
381453

382-
def prepare_runtime(self) -> str:
454+
def prepare_runtime(self, task=None) -> tuple[str, str]:
383455
"""Atomically prepare runtime dependencies for this executor.
384456
385457
Steps:
386-
- Upsert the ConfigMap for this executor's name (if using ConfigMapPackager)
387-
- Create/replace the ClusterTrainingRuntime that references that ConfigMap
458+
- Create a unique ConfigMap for this experiment that includes:
459+
* Initial training code (from ConfigMapPackager)
460+
* Dynamic experiment scripts (created during task execution)
461+
- Create a unique ClusterTrainingRuntime that references that ConfigMap
388462
389-
Returns the runtime name. Raises on failure so callers don't proceed to submit().
463+
Returns (runtime_name, sha). Raises on failure so callers don't proceed to submit().
390464
"""
391-
configmap_name: Optional[str] = None
392-
if isinstance(self.packager, ConfigMapPackager):
393-
try:
394-
# package_default returns the fully resolved ConfigMap name (with prefix)
395-
configmap_name = self.packager.package_default(self.name)
396-
logger.info(f"Prepared ConfigMap: {configmap_name}")
397-
except Exception as e:
398-
logger.error(f"Failed to prepare ConfigMap for '{self.name}': {e}")
399-
raise
465+
# Stage files to ensure we have the latest content and ConfigMap
466+
configmap_name, sha = self.stage_files(task_dir="", task=task)
400467

468+
# Create runtime bound to this ConfigMap
401469
try:
402470
runtime_name = self._create_cluster_training_runtime(
403-
configmap_name=configmap_name or self.name
471+
configmap_name=configmap_name, sha=sha
404472
)
405473
logger.info(f"Prepared runtime: {runtime_name}")
406-
return runtime_name
474+
return (runtime_name, sha)
407475
except Exception:
408476
raise
409477

410-
# Backwards-compatible helpers call the atomic method
411-
def ensure_configmap(self) -> str:
412-
return self.prepare_runtime()
413-
414-
def ensure_runtime(self) -> str:
415-
return self.prepare_runtime()
416-
417478
def submit(self, task, job_name: str) -> str:
418479
"""
419480
Submit a job using the Kubeflow SDK.
@@ -425,10 +486,10 @@ def submit(self, task, job_name: str) -> str:
425486
raise RuntimeError("Executor not assigned to experiment")
426487

427488
try:
428-
# Prepare runtime dependencies on every submit; K8s upserts make this safe
429-
self.prepare_runtime()
489+
# Prepare runtime dependencies (stages files and creates runtime)
490+
runtime_name, _ = self.prepare_runtime(task=task)
430491

431-
job_id = self.create_trainjob(job_name, task)
492+
job_id = self.create_trainjob(job_name, task, runtime_name)
432493
logger.info(f"Submitted job {job_name} with ID: {job_id}")
433494
return job_id
434495

@@ -464,5 +525,7 @@ def info(self) -> str:
464525
"""Get information about the executor configuration."""
465526
return f"KubeflowExecutor (nodes={self.nodes}, gpus={self.gpus or 0})"
466527

467-
def _runtime_name(self) -> str:
468-
return f"nemo-runtime-{sanitize_kubernetes_name(self.name)}"
528+
def _runtime_name(self, sha: str) -> str:
529+
"""Build CRT name from the shared experiment identifier and sha."""
530+
identifier = self._get_experiment_identifier()
531+
return sanitize_kubernetes_name(f"nemo-runtime-{identifier}-{sha}")

0 commit comments

Comments
 (0)