diff --git a/nemo_run/config.py b/nemo_run/config.py index 34a7d162..8bccfc36 100644 --- a/nemo_run/config.py +++ b/nemo_run/config.py @@ -49,6 +49,9 @@ RUNDIR_SPECIAL_NAME = "/$nemo_run" SCRIPTS_DIR = "scripts" +# Metadata keys +USE_WITH_RAY_CLUSTER_KEY = "use_with_ray_cluster" + def get_nemorun_home() -> str: """ diff --git a/nemo_run/run/ray/cluster.py b/nemo_run/run/ray/cluster.py index e955af7a..e165adea 100644 --- a/nemo_run/run/ray/cluster.py +++ b/nemo_run/run/ray/cluster.py @@ -14,32 +14,41 @@ # limitations under the License. from dataclasses import dataclass -from typing import Optional +from typing import Optional, Type from nemo_run.core.execution.base import Executor -from nemo_run.core.execution.kuberay import KubeRayExecutor from nemo_run.core.execution.slurm import SlurmExecutor -from nemo_run.run.ray.kuberay import KubeRayCluster from nemo_run.run.ray.slurm import SlurmRayCluster -USE_WITH_RAY_CLUSTER_KEY = "use_with_ray_cluster" +# Import guard for Kubernetes dependencies +try: + from nemo_run.core.execution.kuberay import KubeRayExecutor + from nemo_run.run.ray.kuberay import KubeRayCluster + + _KUBERAY_AVAILABLE = True +except ImportError: + KubeRayExecutor = None + KubeRayCluster = None + _KUBERAY_AVAILABLE = False @dataclass(kw_only=True) class RayCluster: - BACKEND_MAP = { - KubeRayExecutor: KubeRayCluster, - SlurmExecutor: SlurmRayCluster, - } - name: str executor: Executor def __post_init__(self): - if self.executor.__class__ not in self.BACKEND_MAP: + backend_map: dict[Type[Executor], Type] = { + SlurmExecutor: SlurmRayCluster, + } + + if _KUBERAY_AVAILABLE and KubeRayExecutor is not None and KubeRayCluster is not None: + backend_map[KubeRayExecutor] = KubeRayCluster + + if self.executor.__class__ not in backend_map: raise ValueError(f"Unsupported executor: {self.executor.__class__}") - backend_cls = self.BACKEND_MAP[self.executor.__class__] + backend_cls = backend_map[self.executor.__class__] self.backend = backend_cls(name=self.name, executor=self.executor) # type: ignore[arg-type] self._port_forward_map = {} diff --git a/nemo_run/run/ray/job.py b/nemo_run/run/ray/job.py index b0ed2548..2abe6f6c 100644 --- a/nemo_run/run/ray/job.py +++ b/nemo_run/run/ray/job.py @@ -14,35 +14,45 @@ # limitations under the License. from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, Optional, Type from nemo_run.core.execution.base import Executor -from nemo_run.core.execution.kuberay import KubeRayExecutor from nemo_run.core.execution.slurm import SlurmExecutor -from nemo_run.run.ray.kuberay import KubeRayJob from nemo_run.run.ray.slurm import SlurmRayJob +# Import guard for Kubernetes dependencies +try: + from nemo_run.core.execution.kuberay import KubeRayExecutor + from nemo_run.run.ray.kuberay import KubeRayJob + + _KUBERAY_AVAILABLE = True +except ImportError: + KubeRayExecutor = None + KubeRayJob = None + _KUBERAY_AVAILABLE = False + @dataclass(kw_only=True) class RayJob: """Backend-agnostic convenience wrapper around Ray *jobs*.""" - BACKEND_MAP = { - KubeRayExecutor: KubeRayJob, - SlurmExecutor: SlurmRayJob, - } - name: str executor: Executor pre_ray_start_commands: Optional[list[str]] = None def __post_init__(self) -> None: # noqa: D401 – simple implementation - if self.executor.__class__ not in self.BACKEND_MAP: + backend_map: dict[Type[Executor], Type[Any]] = { + SlurmExecutor: SlurmRayJob, + } + + if _KUBERAY_AVAILABLE and KubeRayExecutor is not None and KubeRayJob is not None: + backend_map[KubeRayExecutor] = KubeRayJob + + if self.executor.__class__ not in backend_map: raise ValueError(f"Unsupported executor: {self.executor.__class__}") - self.backend = self.BACKEND_MAP[self.executor.__class__]( - name=self.name, executor=self.executor - ) + backend_cls = backend_map[self.executor.__class__] + self.backend = backend_cls(name=self.name, executor=self.executor) # ------------------------------------------------------------------ # Public API diff --git a/nemo_run/run/torchx_backend/packaging.py b/nemo_run/run/torchx_backend/packaging.py index 92008db9..e915e9b0 100644 --- a/nemo_run/run/torchx_backend/packaging.py +++ b/nemo_run/run/torchx_backend/packaging.py @@ -21,7 +21,7 @@ import fiddle._src.experimental.dataclasses as fdl_dc from torchx import specs -from nemo_run.config import SCRIPTS_DIR, Partial, Script +from nemo_run.config import SCRIPTS_DIR, USE_WITH_RAY_CLUSTER_KEY, Partial, Script from nemo_run.core.execution.base import Executor from nemo_run.core.execution.dgxcloud import DGXCloudExecutor from nemo_run.core.execution.launcher import FaultTolerance, Torchrun @@ -30,7 +30,6 @@ from nemo_run.core.execution.slurm import SlurmExecutor from nemo_run.core.serialization.yaml import YamlSerializer from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer -from nemo_run.run.ray.cluster import USE_WITH_RAY_CLUSTER_KEY from nemo_run.run.torchx_backend.components import ft_launcher, torchrun log: logging.Logger = logging.getLogger(__name__) diff --git a/nemo_run/run/torchx_backend/schedulers/slurm.py b/nemo_run/run/torchx_backend/schedulers/slurm.py index f20778fd..f0d4f6f0 100644 --- a/nemo_run/run/torchx_backend/schedulers/slurm.py +++ b/nemo_run/run/torchx_backend/schedulers/slurm.py @@ -50,12 +50,11 @@ ) from torchx.specs.api import is_terminal -from nemo_run.config import RUNDIR_NAME, from_dict, get_nemorun_home +from nemo_run.config import RUNDIR_NAME, USE_WITH_RAY_CLUSTER_KEY, from_dict, get_nemorun_home from nemo_run.core.execution.base import Executor from nemo_run.core.execution.slurm import SlurmBatchRequest, SlurmExecutor, SlurmJobDetails from nemo_run.core.tunnel.client import LocalTunnel, PackagingJob, SSHTunnel, Tunnel from nemo_run.run import experiment as run_experiment -from nemo_run.run.ray.cluster import USE_WITH_RAY_CLUSTER_KEY from nemo_run.run.ray.slurm import SlurmRayRequest from nemo_run.run.torchx_backend.schedulers.api import SchedulerMixin