Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions nemo_run/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
RUNDIR_SPECIAL_NAME = "/$nemo_run"
SCRIPTS_DIR = "scripts"

# Metadata keys
USE_WITH_RAY_CLUSTER_KEY = "use_with_ray_cluster"


def get_nemorun_home() -> str:
"""
Expand Down
31 changes: 20 additions & 11 deletions nemo_run/run/ray/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,41 @@
# limitations under the License.

from dataclasses import dataclass
from typing import Optional
from typing import Optional, Type

from nemo_run.core.execution.base import Executor
from nemo_run.core.execution.kuberay import KubeRayExecutor
from nemo_run.core.execution.slurm import SlurmExecutor
from nemo_run.run.ray.kuberay import KubeRayCluster
from nemo_run.run.ray.slurm import SlurmRayCluster

USE_WITH_RAY_CLUSTER_KEY = "use_with_ray_cluster"
# Import guard for Kubernetes dependencies
try:
from nemo_run.core.execution.kuberay import KubeRayExecutor
from nemo_run.run.ray.kuberay import KubeRayCluster

_KUBERAY_AVAILABLE = True
except ImportError:
KubeRayExecutor = None
KubeRayCluster = None
_KUBERAY_AVAILABLE = False


@dataclass(kw_only=True)
class RayCluster:
BACKEND_MAP = {
KubeRayExecutor: KubeRayCluster,
SlurmExecutor: SlurmRayCluster,
}

name: str
executor: Executor

def __post_init__(self):
if self.executor.__class__ not in self.BACKEND_MAP:
backend_map: dict[Type[Executor], Type] = {
SlurmExecutor: SlurmRayCluster,
}

if _KUBERAY_AVAILABLE and KubeRayExecutor is not None and KubeRayCluster is not None:
backend_map[KubeRayExecutor] = KubeRayCluster

if self.executor.__class__ not in backend_map:
raise ValueError(f"Unsupported executor: {self.executor.__class__}")

backend_cls = self.BACKEND_MAP[self.executor.__class__]
backend_cls = backend_map[self.executor.__class__]
self.backend = backend_cls(name=self.name, executor=self.executor) # type: ignore[arg-type]

self._port_forward_map = {}
Expand Down
34 changes: 22 additions & 12 deletions nemo_run/run/ray/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,45 @@
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Optional
from typing import Any, Optional, Type

from nemo_run.core.execution.base import Executor
from nemo_run.core.execution.kuberay import KubeRayExecutor
from nemo_run.core.execution.slurm import SlurmExecutor
from nemo_run.run.ray.kuberay import KubeRayJob
from nemo_run.run.ray.slurm import SlurmRayJob

# Import guard for Kubernetes dependencies
try:
from nemo_run.core.execution.kuberay import KubeRayExecutor
from nemo_run.run.ray.kuberay import KubeRayJob

_KUBERAY_AVAILABLE = True
except ImportError:
KubeRayExecutor = None
KubeRayJob = None
_KUBERAY_AVAILABLE = False


@dataclass(kw_only=True)
class RayJob:
"""Backend-agnostic convenience wrapper around Ray *jobs*."""

BACKEND_MAP = {
KubeRayExecutor: KubeRayJob,
SlurmExecutor: SlurmRayJob,
}

name: str
executor: Executor
pre_ray_start_commands: Optional[list[str]] = None

def __post_init__(self) -> None: # noqa: D401 – simple implementation
if self.executor.__class__ not in self.BACKEND_MAP:
backend_map: dict[Type[Executor], Type[Any]] = {
SlurmExecutor: SlurmRayJob,
}

if _KUBERAY_AVAILABLE and KubeRayExecutor is not None and KubeRayJob is not None:
backend_map[KubeRayExecutor] = KubeRayJob

if self.executor.__class__ not in backend_map:
raise ValueError(f"Unsupported executor: {self.executor.__class__}")

self.backend = self.BACKEND_MAP[self.executor.__class__](
name=self.name, executor=self.executor
)
backend_cls = backend_map[self.executor.__class__]
self.backend = backend_cls(name=self.name, executor=self.executor)

# ------------------------------------------------------------------
# Public API
Expand Down
3 changes: 1 addition & 2 deletions nemo_run/run/torchx_backend/packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import fiddle._src.experimental.dataclasses as fdl_dc
from torchx import specs

from nemo_run.config import SCRIPTS_DIR, Partial, Script
from nemo_run.config import SCRIPTS_DIR, USE_WITH_RAY_CLUSTER_KEY, Partial, Script
from nemo_run.core.execution.base import Executor
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor
from nemo_run.core.execution.launcher import FaultTolerance, Torchrun
Expand All @@ -30,7 +30,6 @@
from nemo_run.core.execution.slurm import SlurmExecutor
from nemo_run.core.serialization.yaml import YamlSerializer
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
from nemo_run.run.ray.cluster import USE_WITH_RAY_CLUSTER_KEY
from nemo_run.run.torchx_backend.components import ft_launcher, torchrun

log: logging.Logger = logging.getLogger(__name__)
Expand Down
3 changes: 1 addition & 2 deletions nemo_run/run/torchx_backend/schedulers/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,11 @@
)
from torchx.specs.api import is_terminal

from nemo_run.config import RUNDIR_NAME, from_dict, get_nemorun_home
from nemo_run.config import RUNDIR_NAME, USE_WITH_RAY_CLUSTER_KEY, from_dict, get_nemorun_home
from nemo_run.core.execution.base import Executor
from nemo_run.core.execution.slurm import SlurmBatchRequest, SlurmExecutor, SlurmJobDetails
from nemo_run.core.tunnel.client import LocalTunnel, PackagingJob, SSHTunnel, Tunnel
from nemo_run.run import experiment as run_experiment
from nemo_run.run.ray.cluster import USE_WITH_RAY_CLUSTER_KEY
from nemo_run.run.ray.slurm import SlurmRayRequest
from nemo_run.run.torchx_backend.schedulers.api import SchedulerMixin

Expand Down
Loading