Skip to content

Commit 57c9ea2

Browse files
committed
Implement Additional Packages Configuration for Executor
Add support for specifying additional package installations in the KubeflowExecutor. This enhancement allows users to configure packages to be installed within the training container, improving flexibility for various training requirements. - Introduced an AdditionalPackages dataclass to encapsulate installation parameters. - Updated the CommandTrainer instantiation to accept additional packages if configured. - Modified get_volume_name and get_pvc_claim_name to ensure lowercase names for Kubernetes compatibility. This change enables more robust customization of the training environment, facilitating the inclusion of necessary dependencies directly through executor configuration. Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
1 parent 8c3ffa2 commit 57c9ea2

2 files changed

Lines changed: 64 additions & 9 deletions

File tree

nemo_run/core/execution/kubeflow.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import logging
1717
import os
1818
import re
19-
from dataclasses import dataclass, field
19+
from dataclasses import asdict, dataclass, field
2020
from typing import Any, Dict, Optional, Union
2121

2222
import yaml
@@ -139,13 +139,35 @@ def to_template_fragment(self, index: int) -> dict[str, Any]:
139139
def get_volume_name(self, index: int) -> str:
140140
"""Return a DNS-1123 safe volume name, defaulting to pvc-{index}."""
141141
base = self.name or f"pvc-{index}"
142-
return sanitize_kubernetes_name(base)
142+
return sanitize_kubernetes_name(base).lower()
143143

144144
def get_pvc_claim_name(self) -> Optional[str]:
145145
"""Return a DNS-1123 safe PVC claim name or None if unset."""
146146
if not self.pvc_claim_name:
147147
return None
148-
return sanitize_kubernetes_name(self.pvc_claim_name)
148+
return sanitize_kubernetes_name(self.pvc_claim_name).lower()
149+
150+
151+
@dataclass
152+
class AdditionalPackages:
153+
"""Optional package installation configuration for the training container.
154+
155+
Fields map directly to SDK `CommandTrainer` parameters.
156+
"""
157+
158+
packages_to_install: Optional[list[str]] = None
159+
pip_index_urls: Optional[list[str]] = None
160+
pip_extra_args: Optional[list[str]] = None
161+
162+
def as_trainer_kwargs(self) -> Dict[str, Any]:
163+
"""Return subset of kwargs for CommandTrainer based on configured fields."""
164+
allowed = {"packages_to_install", "pip_index_urls", "pip_extra_args"}
165+
return asdict(
166+
self,
167+
dict_factory=lambda items: {
168+
k: (list(v) if isinstance(v, list) else v) for k, v in items if k in allowed and v
169+
},
170+
)
149171

150172

151173
@dataclass(kw_only=True)
@@ -224,6 +246,9 @@ class KubeflowExecutor(Executor):
224246

225247
storage_mounts: list["StorageMount"] = field(default_factory=list)
226248

249+
#: Optional package installation configuration
250+
additional_packages: Optional[AdditionalPackages] = None
251+
227252
def __post_init__(self):
228253
"""Validate executor configuration and setup Kubernetes access."""
229254
if self.nodes < 1:
@@ -533,12 +558,16 @@ def _get_custom_trainer(self, task) -> CommandTrainer:
533558
mounted_path = f"{self.workspace_mount_path}/{self.training_entry}"
534559
command, args = _build_trainer_command(task, mounted_path)
535560

536-
trainer = CommandTrainer(
537-
command=command,
538-
args=args,
539-
num_nodes=self.nodes,
540-
resources_per_node=resources_per_node,
541-
)
561+
trainer_kwargs: Dict[str, Any] = {
562+
"command": command,
563+
"args": args,
564+
"num_nodes": self.nodes,
565+
"resources_per_node": resources_per_node,
566+
}
567+
if self.additional_packages:
568+
trainer_kwargs.update(self.additional_packages.as_trainer_kwargs())
569+
570+
trainer = CommandTrainer(**trainer_kwargs)
542571

543572
logger.info(
544573
f"CommandTrainer created with command={trainer.command}, args={trainer.args}, "

test/core/execution/test_kubeflow.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from nemo_run.config import Partial, Script
2323
from nemo_run.core.execution.kubeflow import (
24+
AdditionalPackages,
2425
KubeflowExecutor,
2526
StorageMount,
2627
)
@@ -787,3 +788,28 @@ def _dummy(x, y=2):
787788
assert "--nproc_per_node ${PET_NPROC_PER_NODE}" in args_joined
788789
assert "--rdzv_backend c10d" in args_joined
789790
assert "--rdzv_endpoint ${PET_MASTER_ADDR}:${PET_MASTER_PORT}" in args_joined
791+
792+
793+
def test_executor_additional_packages_forwarding():
794+
script_task = Script(inline="python train.py")
795+
executor = KubeflowExecutor(nodes=1, ntasks_per_node=4)
796+
executor.packager = ConfigMapPackager()
797+
executor.assign("exp-abc123", "/tmp/exp", "task-1", "task_dir")
798+
799+
executor.additional_packages = AdditionalPackages(
800+
packages_to_install=["nemo==2.0.0", "deepspeed>=0.14.0"],
801+
pip_index_urls=["https://pypi.org/simple", "https://extra/simple"],
802+
pip_extra_args=["--no-cache-dir", "--find-links", "/wheels"],
803+
)
804+
805+
with patch("nemo_run.core.execution.kubeflow.CommandTrainer") as mock_trainer:
806+
instance = MagicMock()
807+
mock_trainer.return_value = instance
808+
809+
res = executor._get_custom_trainer(script_task)
810+
811+
assert res == instance
812+
kwargs = mock_trainer.call_args[1]
813+
assert kwargs["packages_to_install"] == ["nemo==2.0.0", "deepspeed>=0.14.0"]
814+
assert kwargs["pip_index_urls"] == ["https://pypi.org/simple", "https://extra/simple"]
815+
assert kwargs["pip_extra_args"] == ["--no-cache-dir", "--find-links", "/wheels"]

0 commit comments

Comments
 (0)