From 3369160497a66f7891dd2c89e5029b7254f321e0 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Fri, 16 May 2025 08:43:15 -0700 Subject: [PATCH 01/18] Add SlurmRayCluster Signed-off-by: Hemil Desai --- nemo_run/core/execution/utils.py | 13 +- nemo_run/run/ray/cluster.py | 18 +- nemo_run/run/ray/kuberay.py | 34 +- nemo_run/run/ray/slurm.py | 731 ++++++++++++++++++++++++++ nemo_run/run/ray/templates/ray.sub.j2 | 312 +++++++++++ 5 files changed, 1093 insertions(+), 15 deletions(-) create mode 100644 nemo_run/run/ray/slurm.py create mode 100644 nemo_run/run/ray/templates/ray.sub.j2 diff --git a/nemo_run/core/execution/utils.py b/nemo_run/core/execution/utils.py index 7ff5b4d5..5c4f48ac 100644 --- a/nemo_run/core/execution/utils.py +++ b/nemo_run/core/execution/utils.py @@ -14,22 +14,23 @@ # limitations under the License. import os +from typing import Optional import jinja2 -def fill_template(template_name: str, variables: dict) -> str: +def fill_template(template_name: str, variables: dict, template_dir: Optional[str] = None) -> str: """Create a file from a Jinja template and return the filename.""" assert template_name.endswith(".j2"), template_name - root_dir = os.path.dirname(__file__) - template_path = os.path.join(root_dir, "templates", template_name) + template_dir = template_dir or os.path.join(os.path.dirname(__file__), "templates") + template_path = os.path.join(template_dir, template_name) if not os.path.exists(template_path): raise FileNotFoundError(f'Template "{template_name}" does not exist.') with open(template_path, "r", encoding="utf-8") as fin: template = fin.read() - j2_template = jinja2.Environment( - loader=jinja2.FileSystemLoader(os.path.join(os.path.dirname(__file__), "templates")) - ).from_string(template) + j2_template = jinja2.Environment(loader=jinja2.FileSystemLoader(template_dir)).from_string( + template + ) content = j2_template.render(**variables) return content diff --git a/nemo_run/run/ray/cluster.py b/nemo_run/run/ray/cluster.py index d193562c..454d0acc 100644 --- a/nemo_run/run/ray/cluster.py +++ b/nemo_run/run/ray/cluster.py @@ -14,32 +14,42 @@ # limitations under the License. from dataclasses import dataclass +from typing import Optional from nemo_run.core.execution.base import Executor from nemo_run.core.execution.kuberay import KubeRayExecutor +from nemo_run.core.execution.slurm import SlurmExecutor from nemo_run.run.ray.kuberay import KubeRayCluster +from nemo_run.run.ray.slurm import SlurmRayCluster @dataclass(kw_only=True) class RayCluster: BACKEND_MAP = { KubeRayExecutor: KubeRayCluster, + SlurmExecutor: SlurmRayCluster, } name: str executor: Executor + pre_ray_start_commands: Optional[list[str]] = None def __post_init__(self): if self.executor.__class__ not in self.BACKEND_MAP: - raise ValueError(f"Unsupported executor: {self.executor}") + raise ValueError(f"Unsupported executor: {self.executor.__class__}") self.backend = self.BACKEND_MAP[self.executor.__class__]() self._port_forward_map = {} - def start(self, wait_until_ready: bool = True, timeout: int = 1000): + def start(self, wait_until_ready: bool = True, timeout: int = 1000, dryrun: bool = False): assert isinstance(self.executor, self.backend.EXECUTOR_CLS) - self.backend.create_ray_cluster(name=self.name, executor=self.executor) + self.backend.create_ray_cluster( + name=self.name, + executor=self.executor, + pre_ray_start_commands=self.pre_ray_start_commands, + dryrun=dryrun, + ) if wait_until_ready: self.backend.wait_until_ray_cluster_running( name=self.name, executor=self.executor, timeout=timeout @@ -52,7 +62,7 @@ def port_forward(self, port: int = 8265, target_port: int = 8265, wait: bool = F self._port_forward_map[port] = self.backend.port_forward( name=self.name, - k8s_namespace=self.executor.namespace, + executor=self.executor, port=port, target_port=target_port, wait=wait, diff --git a/nemo_run/run/ray/kuberay.py b/nemo_run/run/ray/kuberay.py index f71ad858..eb313e3d 100644 --- a/nemo_run/run/ray/kuberay.py +++ b/nemo_run/run/ray/kuberay.py @@ -18,6 +18,7 @@ import time from typing import Any, Optional +import yaml from kubernetes import client, config from kubernetes.client.rest import ApiException @@ -171,9 +172,9 @@ def wait_until_ray_cluster_running( self, name: str, executor: KubeRayExecutor, - k8s_namespace: Optional[str] = None, timeout: int = 60, delay_between_attempts: int = 5, + k8s_namespace: Optional[str] = None, ) -> bool: namespace = k8s_namespace or executor.namespace logger.info( @@ -216,7 +217,12 @@ def wait_until_ray_cluster_running( return False def create_ray_cluster( - self, name: str, executor: KubeRayExecutor, k8s_namespace: Optional[str] = None + self, + name: str, + executor: KubeRayExecutor, + pre_ray_start_commands: Optional[list[str]] = None, + dryrun: bool = False, + k8s_namespace: Optional[str] = None, ) -> Any: namespace = k8s_namespace or executor.namespace logger.info(f"Creating Ray cluster: {name} in namespace: {namespace}") @@ -229,12 +235,30 @@ def create_ray_cluster( Returns: Any: The created custom resource, or None if it already exists or there was an error. """ + if pre_ray_start_commands: + k8s_pre_ray_start_commands = "\n".join(pre_ray_start_commands) + executor.lifecycle_kwargs["postStart"] = { + "exec": { + "command": [ + "/bin/sh", + "-c", + k8s_pre_ray_start_commands, + ] + } + } + + body = executor.get_cluster_body(name) + + if dryrun: + print(yaml.dump(body)) + return + try: resource: Any = self.api.create_namespaced_custom_object( group=GROUP, version=VERSION, plural=PLURAL, - body=executor.get_cluster_body(name), + body=body, namespace=k8s_namespace or executor.namespace, ) return resource @@ -405,7 +429,7 @@ def port_forward( name: str, port: int, target_port: int, - k8s_namespace: str, + executor: KubeRayExecutor, wait: bool = False, ): """Port forward a Ray cluster service using kubectl in a daemon thread. @@ -435,7 +459,7 @@ def port_forward( import time # Get cluster details - cluster = self.get_ray_cluster(name, k8s_namespace or "default") + cluster = self.get_ray_cluster(name, executor.namespace or "default") if not cluster: raise RuntimeError(f"Could not find Ray cluster {name}") diff --git a/nemo_run/run/ray/slurm.py b/nemo_run/run/ray/slurm.py new file mode 100644 index 00000000..fcd61ff9 --- /dev/null +++ b/nemo_run/run/ray/slurm.py @@ -0,0 +1,731 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import queue +import shlex +import socket +import subprocess +import tempfile +import threading +import time +import warnings +from dataclasses import asdict, dataclass +from typing import Any, Dict, Optional, TypeAlias, Union + +from nemo_run.core.execution.slurm import SlurmExecutor, _as_sbatch_flag +from nemo_run.core.execution.utils import fill_template + +noquote: TypeAlias = str + +logger = logging.getLogger(__name__) + + +@dataclass(kw_only=True) +class SlurmRayRequest: + name: str + cluster_dir: str + template_path: str + executor: SlurmExecutor + pre_ray_start_commands: Optional[list[str]] = None + + @staticmethod + def get_job_name(executor: SlurmExecutor, name: str) -> str: + job_name_prefix = ( + executor.job_name_prefix + if executor.job_name_prefix + else f"{executor.account}-{executor.account.split('_')[-1]}." + ) + return f"{job_name_prefix}{name}" + + def materialize(self) -> str: + args = asdict(self.executor) # noqa: F821 + parameters = { + k: v for k, v in args.items() if v is not None and k in SlurmExecutor.SBATCH_FLAGS + } + + # rename and reformat parameters + + if "cpus_per_gpu" in parameters and "gpus_per_task" not in parameters: + warnings.warn( # noqa: F821 + '"cpus_per_gpu" requires to set "gpus_per_task" to work (and not "gpus_per_node")' + ) + # add necessary parameters + job_name = SlurmRayRequest.get_job_name(self.executor, self.name) + slurm_job_dir = self.cluster_dir + job_details = self.executor.job_details + + if not job_details.job_name: + job_details.job_name = job_name + + if not job_details.folder: + job_details.folder = os.path.join(slurm_job_dir, "logs") + + parameters["job_name"] = job_details.job_name + + stdout = str(job_details.stdout) + stderr = str(job_details.stderr) + + assert self.executor.array is None, "array is not supported for ray clusters" + parameters["output"] = stdout.replace("%t", "0") + + if not self.executor.stderr_to_stdout: + parameters["error"] = stderr.replace("%t", "0") + + if self.executor.additional_parameters is not None: + parameters.update(self.executor.additional_parameters) + + sbatch_flags = [] + assert not self.executor.heterogeneous, "heterogeneous is not supported for ray clusters" + for k in sorted(parameters): + sbatch_flags.append(_as_sbatch_flag(k, parameters[k])) + + if self.executor.dependencies: + slurm_deps = self.executor.parse_deps() + sbatch_flags.append( + _as_sbatch_flag( + "dependency", f"{self.executor.dependency_type}:{':'.join(slurm_deps)}" + ) + ) + + env_vars = [] + for key, value in self.executor.env_vars.items(): + env_vars.append(f"export {key.upper()}={value}") + + def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str: + _srun_flags = [f"--container-image={container_image}"] if container_image else [] + _srun_flags.append("--no-container-mount-home") + _srun_flags.append("--mpi=pmix") + _srun_flags.append(f"-A={self.executor.account}") + _srun_flags.append(f"-p={self.executor.partition}") + if self.executor.gres: + _srun_flags.append(f"--gres={self.executor.gres}") + elif self.executor.gpus_per_node: + _srun_flags.append(f"--gres=gpu:{self.executor.gpus_per_node}") + else: + _srun_flags.append("--gres=gpu:8") + + _srun_flags.append(f"--container-workdir={self.cluster_dir}") + _srun_flags += ["--container-mounts", ",".join(mounts)] + + return " ".join(_srun_flags) + + vars_to_fill = { + "sbatch_flags": sbatch_flags, + "cluster_dir": self.cluster_dir, + "log_dir": os.path.join(self.cluster_dir, "logs"), + "uv_cache_dir": os.path.join(self.cluster_dir, "uv_cache"), + "num_retries": max(1, self.executor.retries), + "env_vars": env_vars, + "setup_lines": self.executor.setup_lines, + "common_srun_args": get_srun_flags( + self.executor.container_mounts, self.executor.container_image + ), + } + + if self.pre_ray_start_commands: + vars_to_fill["pre_ray_start_commands"] = "\n".join(self.pre_ray_start_commands) + + sbatch_script = fill_template(self.template_path, vars_to_fill) + return sbatch_script + + def __repr__(self) -> str: + return f""" +#---------------- +# SBATCH_SCRIPT +#---------------- + +{self.materialize()}""" + + +class SlurmRayCluster: + EXECUTOR_CLS = SlurmExecutor + + def __init__(self): + self.cluster_map = {} + + def _get_ray_cluster_info(self, name: str, executor: SlurmExecutor) -> Dict[str, Any]: + """Get Ray cluster information from ray_cluster_info.json file. + + Parameters: + - name (str): The name of the Ray cluster + - executor (SlurmExecutor): The executor containing the tunnel + + Returns: + Dict containing Ray cluster information or empty dict if info not found + """ + executor.tunnel.connect() + cluster_dir = os.path.join(executor.tunnel.job_dir, name) + cmd = f"test -f {cluster_dir}/ray_cluster_info.json && cat {cluster_dir}/ray_cluster_info.json" + result = executor.tunnel.run(cmd, warn=True) + + if result.return_code == 0 and result.stdout.strip(): + try: + return json.loads(result.stdout.strip()) + except json.JSONDecodeError: + logger.error(f"Failed to parse Ray cluster info for '{name}'") + return {} + return {} + + def get_ray_cluster_status( + self, + name: str, + executor: SlurmExecutor, + ) -> Dict[str, Union[str, bool, None]]: + executor.tunnel.connect() + + # Try to find the job by name + job_name = SlurmRayRequest.get_job_name(executor, name) + + cmd = f"squeue -n {job_name} -h -o %A" + result = executor.tunnel.run(cmd) + + job_id = result.stdout.strip() + + # If job not found in running jobs, check if it's in cluster_map + if not job_id and name in self.cluster_map: + job_id = self.cluster_map[name] + # Verify this job_id exists + cmd = f"squeue -j {job_id} -h -o %A" + result = executor.tunnel.run(cmd) + if not result.stdout.strip(): + # Job might be completed, check sacct + cmd = f"sacct -j {job_id} --format=State --noheader --parsable2" + result = executor.tunnel.run(cmd) + if result.stdout.strip(): + state = result.stdout.strip().split("\n")[0] + return {"state": state, "job_id": job_id, "ray_ready": state == "COMPLETED"} + # Job not found in sacct either, so it doesn't exist + return {"state": "NOT_FOUND", "job_id": None, "ray_ready": False} + + if not job_id: + return {"state": "NOT_FOUND", "job_id": None, "ray_ready": False} + + # Store job_id in cluster_map for future reference + self.cluster_map[name] = job_id + + # Check job status + cmd = f"squeue -j {job_id} -h -o %T" + result = executor.tunnel.run(cmd) + + if not result.stdout.strip(): + # Job not found in squeue, check sacct + cmd = f"sacct -j {job_id} --format=State --noheader --parsable2" + result = executor.tunnel.run(cmd) + status = result.stdout.strip().split("\n")[0] if result.stdout.strip() else "UNKNOWN" + + return {"state": status, "job_id": job_id, "ray_ready": status == "COMPLETED"} + + status = result.stdout.strip() + + # When running, also check if ray is actually ready + ray_ready = False + if status == "RUNNING": + ray_cluster_info = self._get_ray_cluster_info(name, executor) + if ray_cluster_info: + ray_ready = True + + return {"state": status, "job_id": job_id, "ray_ready": ray_ready} + + def create_ray_cluster( + self, + name: str, + executor: SlurmExecutor, + pre_ray_start_commands: Optional[list[str]] = None, + dryrun: bool = False, + ) -> Any: + # Check if a cluster with this name already exists + status = self.get_ray_cluster_status(name, executor) + + if status["job_id"] is not None: + job_state = status["state"] + if job_state in ["PENDING", "RUNNING", "CONFIGURING", "COMPLETING"]: + logger.info( + f"Ray cluster '{name}' already exists with job ID {status['job_id']} " + f"and is currently in {job_state} state. " + f"Skipping creation." + ) + return None + elif job_state not in ["COMPLETED", "CANCELLED", "FAILED", "TIMEOUT", "NOT_FOUND"]: + logger.warning( + f"Ray cluster '{name}' exists with job ID {status['job_id']} " + f"in state {job_state}. Creating new cluster anyway." + ) + + cluster_dir = os.path.join(executor.tunnel.job_dir, name) + ray_sbatch = SlurmRayRequest( + name=name, + cluster_dir=cluster_dir, + template_path=os.path.join( + os.path.dirname(os.path.abspath(__file__)), "templates", "ray.sub.j2" + ), + executor=executor, + pre_ray_start_commands=pre_ray_start_commands, + ).materialize() + + if dryrun: + print(ray_sbatch) + return + + executor.tunnel.connect() + executor.tunnel.run(f"mkdir -p {cluster_dir}") + + with tempfile.NamedTemporaryFile(mode="w", delete=True) as f: + f.write(ray_sbatch) + f.flush() + os.fsync(f.fileno()) + ray_sbatch_path = f.name + executor.tunnel.put(ray_sbatch_path, os.path.join(cluster_dir, "ray.sub")) + + sbatch_cmd = ["sbatch", "--parsable", os.path.join(cluster_dir, "ray.sub")] + job_id = executor.tunnel.run(" ".join(sbatch_cmd)).stdout.strip() + + # Store job_id in cluster_map + self.cluster_map[name] = job_id + + logger.info(f"Slurm job for Ray cluster '{name}' created with job ID {job_id}") + + return job_id + + def wait_until_ray_cluster_running( + self, + name: str, + executor: SlurmExecutor, + timeout: int = 600, + delay_between_attempts: int = 30, + ) -> bool: + start_time = time.time() + while time.time() - start_time < timeout: + status = self.get_ray_cluster_status(name, executor) + + if status["ray_ready"]: + logger.info(f"Ray cluster '{name}' is ready.") + return True + + # If job failed or was cancelled, return False + if status["state"] in ["FAILED", "CANCELLED", "TIMEOUT", "NOT_FOUND"]: + logger.error(f"Ray cluster '{name}' failed to start. Job state: {status['state']}") + return False + + logger.info(f"Ray cluster '{name}' is not ready, waiting for it to be ready...") + time.sleep(delay_between_attempts) + + logger.info(f"Ray cluster '{name}' is not ready after {timeout} seconds") + return False + + def delete_ray_cluster( + self, + name: str, + executor: SlurmExecutor, + wait: bool = False, + timeout: int = 60, + poll_interval: int = 5, + ) -> bool: + status = self.get_ray_cluster_status(name, executor) + + if status["job_id"] is None: + logger.warning(f"Ray cluster '{name}' does not exist or is already deleted") + return True + + job_id = status["job_id"] + + # If job is already completed or failed, no need to cancel + if any( + state in status["state"] # type: ignore + for state in ["COMPLETED", "FAILED", "CANCELLED", "TIMEOUT", "NOT_FOUND"] + ): + logger.info(f"Ray cluster '{name}' job {job_id} is already in state {status['state']}") + # Remove from cluster_map + if name in self.cluster_map: + del self.cluster_map[name] + return True + # Cancel the job + executor.tunnel.connect() + cmd = f"scancel {job_id}" + logger.info(f"Cancelling Ray cluster '{name}' job {job_id}") + + try: + executor.tunnel.run(cmd) + except Exception as e: + logger.error(f"Failed to cancel Ray cluster '{name}' job {job_id}: {e}") + return False + + # Remove from cluster_map if it exists + if name in self.cluster_map: + del self.cluster_map[name] + + # Wait for job to be fully terminated if requested + if wait: + start_time = time.time() + while time.time() - start_time < timeout: + status = self.get_ray_cluster_status(name, executor) + + # If job is not found anymore, it's been successfully cancelled + if status["job_id"] is None: + logger.info( + f"Ray cluster '{name}' job {job_id} has been successfully cancelled" + ) + if name in self.cluster_map: + del self.cluster_map[name] + return True + + # If job is in a terminated state, success + if any(state in status["state"] for state in ["CANCELLED", "FAILED", "TIMEOUT"]): # type: ignore + logger.info( + f"Ray cluster '{name}' job {job_id} is now in state {status['state']}" + ) + if name in self.cluster_map: + del self.cluster_map[name] + return True + + logger.info(f"Waiting for Ray cluster '{name}' job {job_id} to terminate...") + time.sleep(poll_interval) + + logger.warning(f"Timed out waiting for Ray cluster '{name}' job {job_id} to terminate") + return False + + return True + + def port_forward( + self, + name: str, + port: int, + target_port: int, + executor: SlurmExecutor, + wait: bool = False, + ): + """Port forward to a Ray cluster using SSH tunnel. + + When you want to stop the forwarding: + forward_thread.stop_forwarding() # Call this method to stop forwarding + + If wait=True, this function will block until interrupted (e.g., with Ctrl+C). + + Parameters: + - name (str): The name of the Ray cluster. + - port (int): The local port to use for forwarding. + - target_port (int): The target port on the Ray cluster to forward to. + - executor (SlurmExecutor): The executor containing the tunnel configuration. + - wait (bool, optional): If True, block indefinitely until interrupted. Defaults to False. + + Returns: + - ForwardingThread: A thread object with stop_forwarding method. + + Raises: + - RuntimeError: If the Ray cluster info cannot be found or is incomplete. + - TimeoutError: If port forwarding fails to establish within the timeout period. + """ + # Check if cluster exists and is running + status = self.get_ray_cluster_status(name, executor) + if status["job_id"] is None: + raise RuntimeError(f"Could not find Ray cluster {name}") + + if not status["ray_ready"]: + raise RuntimeError(f"Ray cluster {name} is not running or not ready yet") + + # Get cluster info + ray_cluster_info = self._get_ray_cluster_info(name, executor) + if not ray_cluster_info: + raise RuntimeError(f"Could not find Ray cluster info for {name}") + + if "head_ip" not in ray_cluster_info: + raise RuntimeError(f"Ray cluster info for {name} does not contain head_ip") + + head_ip = ray_cluster_info["head_ip"] + + # Use a queue for thread communication + status_queue = queue.Queue() + stop_event = threading.Event() + + class ForwardingThread(threading.Thread): + def __init__(self, daemon=True): + super().__init__(daemon=daemon) + self._stop_event = stop_event + self._ssh_process = None + + def run(self): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + sock.bind(("localhost", port)) + sock.close() + except socket.error: + sock.close() + raise RuntimeError(f"Port {port} is already in use locally") + + self._ssh_process = None + ssh_cmd_list_for_error_reporting = [] + ssh_cmd_list = [] + + try: + ssh_cmd_list = ["ssh"] + ssh_cmd_list.extend(["-L", f"{port}:localhost:{target_port}"]) + ssh_cmd_list.extend( + [ + "-N", + "-o", + "StrictHostKeyChecking=no", + "-o", + "UserKnownHostsFile=/dev/null", + "-o", + "ExitOnForwardFailure=yes", + "-o", + "ConnectTimeout=10", + "-o", + "IdentitiesOnly=yes", + ] + ) + + jump_arg_str = f"{executor.tunnel.user}@{executor.tunnel.host}" + raw_jump_identity = getattr(executor.tunnel, "identity", None) + jump_identity_path_for_proxy = None + if raw_jump_identity: + expanded_path = os.path.expanduser(str(raw_jump_identity)) + if os.path.isfile(expanded_path): + jump_identity_path_for_proxy = expanded_path + logger.debug( + f"Using jump identity {jump_identity_path_for_proxy} for ProxyCommand to {jump_arg_str}" + ) + else: + logger.warning( + f"Jump host identity path {expanded_path} (from {raw_jump_identity}) not found." + ) + logger.debug(f"Using jump host spec for ProxyCommand: {jump_arg_str}") + + if jump_arg_str: + proxy_ssh_parts = ["ssh"] + if jump_identity_path_for_proxy: + proxy_ssh_parts.extend(["-i", jump_identity_path_for_proxy]) + ssh_cmd_list.extend(["-i", jump_identity_path_for_proxy]) + proxy_ssh_parts.extend( + [ + "-o", + "StrictHostKeyChecking=no", + "-o", + "UserKnownHostsFile=/dev/null", + "-o", + "ConnectTimeout=10", + ] + ) + proxy_ssh_parts.extend(["-W", "%h:%p", jump_arg_str]) + proxy_command_value = shlex.join(proxy_ssh_parts) + ssh_cmd_list.extend(["-o", f"ProxyCommand={proxy_command_value}"]) + logger.debug(f"Using ProxyCommand: {proxy_command_value}") + + target_user = getattr(executor.tunnel, "user", None) + if target_user: + target_spec = f"{str(target_user)}@{head_ip}" + else: + target_spec = head_ip + logger.warning( + f"No explicit user for target {head_ip}, SSH will use default." + ) + ssh_cmd_list.append(target_spec) + + ssh_cmd_list = [ + p for p in ssh_cmd_list if isinstance(p, str) and p.strip() != "" + ] + + if not ssh_cmd_list or "ssh" not in ssh_cmd_list[0]: + err_msg_empty_cmd = "SSH command list is invalid or empty before Popen. Cannot start forwarding." + logger.error(err_msg_empty_cmd) + status_queue.put(("error", err_msg_empty_cmd)) + return + + ssh_cmd_list_for_error_reporting = list(ssh_cmd_list) + logger.debug( + f"Constructed SSH command: {' '.join(shlex.quote(p) for p in ssh_cmd_list)}" + ) + + self._ssh_process = subprocess.Popen( + ssh_cmd_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + status_queue.put(("success", None)) + pid_info = str(self._ssh_process.pid) if self._ssh_process else "unknown" + logger.info( + f"SSH tunnel process (PID: {pid_info}) launched. " + f"Forwarding localhost:{port} to {head_ip}:{target_port}" + ) + + while not self._stop_event.is_set() and self._ssh_process: + process_return_code = self._ssh_process.poll() + if process_return_code is not None: + stdout_bytes, stderr_bytes = self._ssh_process.communicate() + decoded_stderr = stderr_bytes.decode(errors="replace") + decoded_stdout = stdout_bytes.decode(errors="replace") + logger.error( + f"SSH tunnel process terminated unexpectedly. " + f"Return code: {process_return_code}.\\n" + f"Command: {' '.join(shlex.quote(p) for p in ssh_cmd_list_for_error_reporting)}\\n" + f"Stdout: {decoded_stdout}\\n" + f"Stderr: {decoded_stderr}" + ) + self._ssh_process = None + break + time.sleep(0.5) + + except Exception as e: + logger.error( + f"Exception in port forwarding thread run method: {str(e)}", exc_info=True + ) + cmd_for_report = ( + " ".join(shlex.quote(p) for p in ssh_cmd_list_for_error_reporting) + if ssh_cmd_list_for_error_reporting + else "[command construction failed]" + ) + error_detail = f"Error starting or managing SSH tunnel: {str(e)}. Command (if available): {cmd_for_report}" + try: + status_queue.put_nowait(("error", error_detail)) + except queue.Full: + logger.warning( + "Status queue was full when trying to report SSH setup/Popen error." + ) + except Exception as q_e: + logger.error(f"Failed to put error on status_queue: {q_e}") + + finally: + self._cleanup() + + def _cleanup(self): + if hasattr(self, "_ssh_process") and self._ssh_process: + process = self._ssh_process + pid_info = "unknown" # Default pid_info + try: + # Check if PID exists before trying to access it, in case Popen failed partially + if hasattr(process, "pid") and process.pid is not None: + pid_info = str(process.pid) + except Exception: # Broad catch if .pid access itself errors + pass # pid_info remains "unknown" + + if process.poll() is None: # Process is still running + logger.debug(f"Attempting to stop SSH tunnel process (PID: {pid_info})...") + process.terminate() # SIGTERM + try: + process.wait(timeout=2) # Short wait for graceful termination + logger.debug( + f"SSH tunnel process (PID: {pid_info}) terminated gracefully (SIGTERM), exit code: {process.returncode}." + ) + except subprocess.TimeoutExpired: + logger.warning( + f"SSH tunnel process (PID: {pid_info}) did not respond to SIGTERM within 2s. Sending SIGKILL." + ) + process.kill() # SIGKILL + try: + process.wait(timeout=1) # Shorter wait for SIGKILL + logger.debug( + f"SSH tunnel process (PID: {pid_info}) killed (SIGKILL), exit code: {process.returncode}." + ) + except subprocess.TimeoutExpired: + logger.error( + f"SSH tunnel process (PID: {pid_info}) did not terminate even after SIGKILL and 1s wait." + ) + except Exception as e: + # Catch other exceptions during wait, e.g., if process died between poll() and wait() + logger.error( + f"Exception while waiting for SSH process (PID: {pid_info}) termination: {e}" + ) + if process.poll() is not None: + logger.debug( + f"SSH tunnel process (PID: {pid_info}) had already exited with code: {process.returncode} during exception handling." + ) + else: # Process had already exited before cleanup explicitly tried to stop it + # communicate() might have been called already if termination was handled in the run loop. + # Calling it again can lead to errors if pipes are closed. + # Just log that it was already stopped. + logger.debug( + f"SSH tunnel process (PID: {pid_info}) was already stopped. Exit code: {process.returncode}." + ) + self._ssh_process = None # Ensure it's cleared + + def stop_forwarding(self): + logger.info("Stopping port forwarding") + self._stop_event.set() + + # Create and start the forwarding thread + forward_thread = ForwardingThread() + forward_thread.start() + + # Wait for port forwarding to establish or fail with a timeout + try: + status, error_msg = status_queue.get(timeout=30) + if status == "error": + raise RuntimeError(f"Failed to establish port forwarding: {error_msg}") + except queue.Empty: + stop_event.set() + time.sleep(0.2) # Give it time to clean up + raise TimeoutError("Timed out waiting for port forwarding to establish") + + # If wait option is set, block indefinitely until interrupted + if wait: + try: + # Set up signal handlers for graceful shutdown + import signal + + original_sigint_handler = signal.getsignal(signal.SIGINT) + original_sigterm_handler = signal.getsignal(signal.SIGTERM) + + def signal_handler(sig, frame): + logger.info(f"Received signal {sig} to stop port forwarding") + stop_event.set() + + # Restore original signal handlers + signal.signal(signal.SIGINT, original_sigint_handler) + signal.signal(signal.SIGTERM, original_sigterm_handler) + + # Set up signal handlers + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + logger.info("Port forwarding is active. Press Ctrl+C to stop...") + while not stop_event.is_set(): + if not forward_thread.is_alive(): + logger.error( + "Port forwarding thread died unexpectedly after successful start." + ) + break + time.sleep(1) + except KeyboardInterrupt: + logger.debug("Keyboard interrupt received, stopping port forwarding") + finally: + logger.debug("Wait loop for port forwarding ended. Ensuring stop event is set.") + stop_event.set() + forward_thread.join(timeout=10) + if forward_thread.is_alive(): + logger.warning( + "Port forwarding thread did not terminate in time after stop signal." + ) + if ( + hasattr(forward_thread, "_ssh_process") + and forward_thread._ssh_process + and forward_thread._ssh_process.poll() is None + ): + pid_info = ( + str(forward_thread._ssh_process.pid) + if forward_thread._ssh_process + else "unknown" + ) + logger.warning( + f"SSH process (PID: {pid_info}) appears to be still running. Attempting to kill." + ) + forward_thread._ssh_process.kill() + try: + forward_thread._ssh_process.wait(timeout=2) + except subprocess.TimeoutExpired: + logger.error( + f"SSH process (PID: {forward_thread._ssh_process.pid}) did not respond to kill." + ) + return forward_thread diff --git a/nemo_run/run/ray/templates/ray.sub.j2 b/nemo_run/run/ray/templates/ray.sub.j2 new file mode 100644 index 00000000..7320fde6 --- /dev/null +++ b/nemo_run/run/ray/templates/ray.sub.j2 @@ -0,0 +1,312 @@ +#!/bin/bash +# +# Generated by NeMo Run +# + +# Parameters +{%- for sbatch_flag in sbatch_flags %} +{{sbatch_flag}} +{%- endfor %} + +set -eoux pipefail + +######################################################## +# User defined variables +######################################################## +export PYTHONUNBUFFERED=1 +export SLURM_UNBUFFEREDIO=1 + +{%- for env_var in env_vars %} +{{env_var}} +{%- endfor %} + +# Ports for all nodes (should be odd numbers since we place head/worker[0] on the same node) so all workers get the odd ports, but the head will get +1 the ports +NODE_MANAGER_PORT=${NODE_MANAGER_PORT:-53001} +OBJECT_MANAGER_PORT=${OBJECT_MANAGER_PORT:-53003} +RUNTIME_ENV_AGENT_PORT=${RUNTIME_ENV_AGENT_PORT:-53005} +DASHBOARD_AGENT_GRPC_PORT=${DASHBOARD_AGENT_GRPC_PORT:-53007} +METRICS_EXPORT_PORT=${METRICS_EXPORT_PORT:-53009} + +# Ports for the head node +PORT=${PORT:-6379} +RAY_CLIENT_SERVER_PORT=${RAY_CLIENT_SERVER_PORT:-10001} +#REDIT_SHARD_PORTS=${REDIT_SHARD_PORTS:-"random"} ?? +DASHBOARD_GRPC_PORT=${DASHBOARD_GRPC_PORT:-52367} +DASHBOARD_PORT=${DASHBOARD_PORT:-8265} # Also used by debugger +DASHBOARD_AGENT_LISTEN_PORT=${DASHBOARD_AGENT_LISTEN_PORT:-52365} + +# On our clusters, the largest port range on an idle worker appeared between 52369-64607 +# (not including the other ports set by this script). So this range is chosen to be +# somewhere in the middle +MIN_WORKER_PORT=${MIN_WORKER_PORT:-54001} +MAX_WORKER_PORT=${MAX_WORKER_PORT:-54257} + +# Directory setup +export CLUSTER_DIR={{ cluster_dir }} +mkdir -p $CLUSTER_DIR/scripts + +export LOG_DIR={{ log_dir }} +mkdir -p $LOG_DIR + +# Clean up any previous run files +rm -f $LOG_DIR/STARTED_RAY_HEAD +rm -f $LOG_DIR/ENDED + +# Defaults to placing uv cache inside the CLUSTER_DIR +# This directory is mounted into the container at /home/ray/.cache/uv so it is shared between the head and worker nodes +# UV_CACHE_DIR={{ uv_cache_dir }} +# mkdir -p $UV_CACHE_DIR +######################################################## + +# Number of GPUs per node +gpus_per_node=8 + +num_retries={{ num_retries }} + +# Getting the node names and IP addresses in the SLURM allocation +nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") +nodes_array=($nodes) +ip_addresses_array=() + +for node in $nodes; do + ip_address=$(host $node | awk '/has address/ { print $4 }') + # Add the IP address to the array + ip_addresses_array+=("$ip_address") +done + +head_node=${nodes_array[0]} +head_node_ip=${ip_addresses_array[0]} + +ip_head=$head_node_ip:$PORT + +{%- if setup_lines %} +{{setup_lines}} +{%- endif %} + +######################################################## +# Ray cluster setup +######################################################## +# First we start the head of the ray cluster on one of the physical nodes +# Set GPU/CPU resources to 0 to avoid scheduling on the head node + +head_cmd=$(cat </dev/null); do + echo "[INFO][$(date)] Waiting for Ray head node container to start and be ready..." + sleep 2 +done + +NUM_ACTORS=$((gpus_per_node * SLURM_JOB_NUM_NODES)) + +# Start Ray worker nodes +# We want 1 Ray worker node per physical node +# Worker nodes are started with ray start but without the --head flag +for ((i = 1; i < SLURM_JOB_NUM_NODES; i++)); do + node_i=${nodes_array[$i]} + + worker_cmd=$(cat <$CLUSTER_DIR/ray_cluster_info.json +{ + "head_ip": "$head_node_ip", + "dashboard_port": "$DASHBOARD_PORT", + "port": "$PORT" +} +EOF +# Set up trap to clean up cluster info on job termination +cleanup_cluster_info() { + echo "[INFO] Cleaning up Ray cluster information" + rm -f $CLUSTER_DIR/ray_cluster_info.json +} + +# Register the cleanup function to run on script exit +trap cleanup_cluster_info EXIT + + +echo "[INFO] Ray cluster information saved to $CLUSTER_DIR/ray_cluster_info.json" + +######################################################## + +# We can now launch a job on this cluster +# We do so by launching a driver process on the physical node that the head node is on +# This driver process is responsible for launching a job on the Ray cluster +CONTAINER_CWD=$(scontrol show job $SLURM_JOB_ID --json | jq -r '.jobs[].current_working_directory') +# Define command to be empty by default +COMMAND="" + +if [[ -n "$COMMAND" ]]; then + srun --no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/ray-driver.log bash -c "$COMMAND" +else + echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:" + cat <$CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh +# No args launches on the head node +WORKER_NUM=\${1:-} +if [[ -z "\$WORKER_NUM" ]]; then + # Empty means we are on the head node + srun --no-container-mount-home --gpus=0 -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" --jobid $SLURM_JOB_ID --pty bash +else + nodes_array=($nodes) + srun --no-container-mount-home --gres=gpu:8 -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID --pty bash +fi +EOF + chmod +x $CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh + echo " bash $CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh" + sleep infinity +fi From f93a01138f9f063d24c112d2a1641e141748aa6c Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Fri, 16 May 2025 08:51:49 -0700 Subject: [PATCH 02/18] fix logging Signed-off-by: Hemil Desai --- nemo_run/run/ray/kuberay.py | 68 ++++++++++++++++++++----------------- nemo_run/run/ray/slurm.py | 13 +++---- 2 files changed, 40 insertions(+), 41 deletions(-) diff --git a/nemo_run/run/ray/kuberay.py b/nemo_run/run/ray/kuberay.py index eb313e3d..93b5b9be 100644 --- a/nemo_run/run/ray/kuberay.py +++ b/nemo_run/run/ray/kuberay.py @@ -87,7 +87,7 @@ def list_ray_clusters( return None def get_ray_cluster(self, name: str, k8s_namespace: str = "default") -> Any: - logger.info(f"Getting Ray cluster: {name} in namespace: {k8s_namespace}") + logger.info(f"Getting Ray cluster '{name}' in namespace '{k8s_namespace}'") """Get a specific Ray cluster in a given namespace. Parameters: @@ -111,10 +111,12 @@ def get_ray_cluster(self, name: str, k8s_namespace: str = "default") -> Any: return resource except ApiException as e: if e.status == 404: - logger.error("raycluster resource is not found. error = {}".format(e)) + logger.error(f"Ray cluster '{name}' not found in namespace '{k8s_namespace}': {e}") return None else: - logger.error("error fetching custom resource: {}".format(e)) + logger.error( + f"Error fetching Ray cluster '{name}' in namespace '{k8s_namespace}': {e}" + ) return None def get_ray_cluster_status( @@ -125,7 +127,7 @@ def get_ray_cluster_status( delay_between_attempts: int = 5, ) -> Any: logger.info( - f"Getting Ray cluster status: {name} in namespace: {k8s_namespace}, timeout: {timeout}s, delay: {delay_between_attempts}s" + f"Getting Ray cluster status for '{name}' in namespace '{k8s_namespace}', timeout: {timeout}s, delay: {delay_between_attempts}s" ) """Get a specific Ray cluster in a given namespace. @@ -152,20 +154,24 @@ def get_ray_cluster_status( ) except ApiException as e: if e.status == 404: - logger.error("raycluster resource is not found. error = {}".format(e)) + logger.error( + f"Ray cluster '{name}' status fetch failed: resource not found: {e}" + ) return None else: - logger.error("error fetching custom resource: {}".format(e)) + logger.error( + f"Error fetching status for Ray cluster '{name}' in namespace '{k8s_namespace}': {e}" + ) return None if "status" in resource and resource["status"]: return resource["status"] else: - logger.info("raycluster {} status not set yet, waiting...".format(name)) + logger.info(f"Ray cluster '{name}' status not set yet, waiting...") time.sleep(delay_between_attempts) timeout -= delay_between_attempts - logger.info("raycluster {} status not set yet, timing out...".format(name)) + logger.info(f"Ray cluster '{name}' status not set yet, timing out...") return None def wait_until_ray_cluster_running( @@ -178,7 +184,7 @@ def wait_until_ray_cluster_running( ) -> bool: namespace = k8s_namespace or executor.namespace logger.info( - f"Waiting until Ray cluster: {name} in namespace: {namespace} is running, timeout: {timeout}s, delay: {delay_between_attempts}s" + f"Waiting until Ray cluster '{name}' is running in namespace '{namespace}', timeout: {timeout}s, delay: {delay_between_attempts}s" ) """Get a specific Ray cluster in a given namespace. @@ -197,23 +203,21 @@ def wait_until_ray_cluster_running( name, k8s_namespace or executor.namespace, timeout, delay_between_attempts ) if not status: - logger.info(f"Ray cluster {name} status could not be retrieved") + logger.info(f"Ray cluster '{name}' status could not be retrieved") return False # TODO: once we add State to Status, we should check for that as well if status and status["head"] and status["head"]["serviceIP"]: - logger.info(f"Ray cluster {name} is running") + logger.info(f"Ray cluster '{name}' is running") return True logger.info( - "raycluster {} status is not running yet, current status is {}".format( - name, status["state"] if status and "state" in status else "unknown" - ) + f"Ray cluster '{name}' status is not running yet, current status: {status.get('state', 'unknown')}" ) time.sleep(delay_between_attempts) timeout -= delay_between_attempts - logger.info("raycluster {} status is not running yet, timing out...".format(name)) + logger.info(f"Ray cluster '{name}' status is not running yet, timing out...") return False def create_ray_cluster( @@ -225,7 +229,7 @@ def create_ray_cluster( k8s_namespace: Optional[str] = None, ) -> Any: namespace = k8s_namespace or executor.namespace - logger.info(f"Creating Ray cluster: {name} in namespace: {namespace}") + logger.info(f"Creating Ray cluster '{name}' in namespace '{namespace}'") """Create a new Ray cluster custom resource. Parameters: @@ -264,10 +268,10 @@ def create_ray_cluster( return resource except ApiException as e: if e.status == 409: - logger.error("raycluster resource already exists. error = {}".format(e.reason)) + logger.error(f"Ray cluster '{name}' already exists: {e.reason}") return None else: - logger.error("error creating custom resource: {}".format(e)) + logger.error(f"Error creating Ray cluster '{name}' in namespace '{namespace}': {e}") return None def delete_ray_cluster( @@ -293,7 +297,7 @@ def delete_ray_cluster( Optional[bool]: True if deletion was successful, None if already deleted or there was an error. """ namespace = k8s_namespace or executor.namespace - logger.info(f"Deleting Ray cluster: {name} in namespace: {namespace}") + logger.info(f"Deleting Ray cluster '{name}' in namespace '{namespace}'") try: self.api.delete_namespaced_custom_object( @@ -307,7 +311,7 @@ def delete_ray_cluster( if not wait: return True - logger.info(f"Waiting for Ray cluster {name} and its pods to be fully deleted...") + logger.info(f"Waiting for Ray cluster '{name}' and its pods to be fully deleted...") start_time = time.time() cluster_deleted = False @@ -318,11 +322,11 @@ def delete_ray_cluster( try: cluster = self.get_ray_cluster(name, namespace) if not cluster: - logger.info(f"Ray cluster CR {name} has been deleted") + logger.info(f"Ray cluster CR '{name}' has been deleted") cluster_deleted = True except ApiException as e: if e.status == 404: - logger.info(f"Ray cluster CR {name} has been deleted") + logger.info(f"Ray cluster CR '{name}' has been deleted") cluster_deleted = True else: logger.error(f"Error checking Ray cluster status during deletion: {e}") @@ -336,7 +340,7 @@ def delete_ray_cluster( ) if not pods.items: - logger.info(f"All pods for Ray cluster {name} have been terminated") + logger.info(f"All pods for Ray cluster '{name}' have been terminated") return True active_pods = [pod.metadata.name for pod in pods.items] @@ -357,14 +361,14 @@ def delete_ray_cluster( # If we reach here, we've timed out logger.warning( - f"Timed out waiting for Ray cluster {name} to be fully deleted after {timeout} seconds" + f"Timed out waiting for Ray cluster '{name}' to be fully deleted after {timeout} seconds" ) # Check final state try: cluster_exists = self.get_ray_cluster(name, namespace) is not None if cluster_exists: - logger.warning(f"Ray cluster CR {name} still exists after timeout") + logger.warning(f"Ray cluster CR '{name}' still exists after timeout") pods = self.core_v1_api.list_namespaced_pod( namespace=namespace, label_selector=f"ray.io/cluster={name}" @@ -372,19 +376,19 @@ def delete_ray_cluster( if pods.items: pod_names = [pod.metadata.name for pod in pods.items] logger.warning( - f"Ray cluster {name} still has {len(pod_names)} pods: {', '.join(pod_names[:5])}" + f"Ray cluster '{name}' still has {len(pod_names)} pods: {', '.join(pod_names[:5])}" ) except Exception as e: - logger.error(f"Error checking final state of Ray cluster {name}: {e}") + logger.error(f"Error checking final state of Ray cluster '{name}': {e}") return False except ApiException as e: if e.status == 404: - logger.warning(f"Ray cluster {name} was already deleted") + logger.warning(f"Ray cluster '{name}' was already deleted") return None else: - logger.error(f"Error deleting the Ray cluster {name}: {e}") + logger.error(f"Error deleting Ray cluster '{name}': {e}") return None def patch_ray_cluster( @@ -395,7 +399,7 @@ def patch_ray_cluster( k8s_namespace: Optional[str] = None, ) -> Any: namespace = k8s_namespace or executor.namespace - logger.info(f"Patching Ray cluster: {name} in namespace: {namespace}") + logger.info(f"Patching Ray cluster '{name}' in namespace '{namespace}'") """Patch an existing Ray cluster custom resource. Parameters: @@ -417,10 +421,10 @@ def patch_ray_cluster( namespace=namespace, ) except ApiException as e: - logger.error("raycluster `{}` failed to patch, with error: {}".format(name, e)) + logger.error(f"Failed to patch Ray cluster '{name}': {e}") return False else: - logger.info("raycluster `%s` is patched successfully", name) + logger.info(f"Ray cluster '{name}' patched successfully") return True diff --git a/nemo_run/run/ray/slurm.py b/nemo_run/run/ray/slurm.py index fcd61ff9..e45ba78f 100644 --- a/nemo_run/run/ray/slurm.py +++ b/nemo_run/run/ray/slurm.py @@ -159,15 +159,6 @@ def __init__(self): self.cluster_map = {} def _get_ray_cluster_info(self, name: str, executor: SlurmExecutor) -> Dict[str, Any]: - """Get Ray cluster information from ray_cluster_info.json file. - - Parameters: - - name (str): The name of the Ray cluster - - executor (SlurmExecutor): The executor containing the tunnel - - Returns: - Dict containing Ray cluster information or empty dict if info not found - """ executor.tunnel.connect() cluster_dir = os.path.join(executor.tunnel.job_dir, name) cmd = f"test -f {cluster_dir}/ray_cluster_info.json && cat {cluster_dir}/ray_cluster_info.json" @@ -186,6 +177,7 @@ def get_ray_cluster_status( name: str, executor: SlurmExecutor, ) -> Dict[str, Union[str, bool, None]]: + logger.info(f"Getting Ray cluster status for '{name}'") executor.tunnel.connect() # Try to find the job by name @@ -248,6 +240,7 @@ def create_ray_cluster( pre_ray_start_commands: Optional[list[str]] = None, dryrun: bool = False, ) -> Any: + logger.info(f"Creating Ray cluster '{name}'") # Check if a cluster with this name already exists status = self.get_ray_cluster_status(name, executor) @@ -308,6 +301,7 @@ def wait_until_ray_cluster_running( timeout: int = 600, delay_between_attempts: int = 30, ) -> bool: + logger.info(f"Waiting until Ray cluster '{name}' is running") start_time = time.time() while time.time() - start_time < timeout: status = self.get_ray_cluster_status(name, executor) @@ -335,6 +329,7 @@ def delete_ray_cluster( timeout: int = 60, poll_interval: int = 5, ) -> bool: + logger.info(f"Deleting Ray cluster '{name}'") status = self.get_ray_cluster_status(name, executor) if status["job_id"] is None: From aeb4902cd0c1916ecaac33fc92b2c7834f2a306f Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Fri, 16 May 2025 08:58:39 -0700 Subject: [PATCH 03/18] fix Signed-off-by: Hemil Desai --- nemo_run/core/execution/kuberay.py | 2 +- nemo_run/run/ray/kuberay.py | 117 ----------------------------- 2 files changed, 1 insertion(+), 118 deletions(-) diff --git a/nemo_run/core/execution/kuberay.py b/nemo_run/core/execution/kuberay.py index ab75bbbc..80df90d5 100644 --- a/nemo_run/core/execution/kuberay.py +++ b/nemo_run/core/execution/kuberay.py @@ -80,7 +80,7 @@ class KubeRayExecutor(Executor): head_ports: list[dict[str, Any]] = field(default_factory=list) volume_mounts: list[dict[str, Any]] = field(default_factory=list) volumes: list[dict[str, Any]] = field(default_factory=list) - reuse_volumes_in_worker_groups: bool = False + reuse_volumes_in_worker_groups: bool = True spec_kwargs: dict[str, Any] = field(default_factory=dict) lifecycle_kwargs: dict[str, Any] = field(default_factory=dict) diff --git a/nemo_run/run/ray/kuberay.py b/nemo_run/run/ray/kuberay.py index 93b5b9be..edf562ea 100644 --- a/nemo_run/run/ray/kuberay.py +++ b/nemo_run/run/ray/kuberay.py @@ -28,17 +28,6 @@ class KubeRayCluster: - """ - RayClusterApi provides APIs to list, get, create, build, update, delete rayclusters. - - Methods: - - list_ray_clusters(k8s_namespace: str = "default", async_req: bool = False) -> Any: - - get_ray_cluster(name: str, k8s_namespace: str = "default") -> Any: - - create_ray_cluster(body: Any, k8s_namespace: str = "default") -> Any: - - delete_ray_cluster(name: str, k8s_namespace: str = "default") -> bool: - - patch_ray_cluster(name: str, ray_patch: Any, k8s_namespace: str = "default") -> Any: - """ - EXECUTOR_CLS = KubeRayExecutor # initial config to setup the kube client @@ -54,18 +43,7 @@ def list_ray_clusters( logger.info( f"Listing Ray clusters in namespace: {k8s_namespace}, label_selector: {label_selector}, async_req: {async_req}" ) - """List Ray clusters in a given namespace. - - Parameters: - - k8s_namespace (str, optional): The namespace in which to list the Ray clusters. Defaults to "default". - - async_req (bool, optional): Whether to make the request asynchronously. Defaults to False. - Returns: - Any: The custom resource for Ray clusters in the specified namespace, or None if not found. - - Raises: - ApiException: If there was an error fetching the custom resource. - """ try: resource: Any = self.api.list_namespaced_custom_object( group=GROUP, @@ -88,18 +66,7 @@ def list_ray_clusters( def get_ray_cluster(self, name: str, k8s_namespace: str = "default") -> Any: logger.info(f"Getting Ray cluster '{name}' in namespace '{k8s_namespace}'") - """Get a specific Ray cluster in a given namespace. - - Parameters: - - name (str): The name of the Ray cluster custom resource. Defaults to "". - - k8s_namespace (str, optional): The namespace in which to retrieve the Ray cluster. Defaults to "default". - Returns: - Any: The custom resource for the specified Ray cluster, or None if not found. - - Raises: - ApiException: If there was an error fetching the custom resource. - """ try: resource: Any = self.api.get_namespaced_custom_object( group=GROUP, @@ -129,20 +96,7 @@ def get_ray_cluster_status( logger.info( f"Getting Ray cluster status for '{name}' in namespace '{k8s_namespace}', timeout: {timeout}s, delay: {delay_between_attempts}s" ) - """Get a specific Ray cluster in a given namespace. - - Parameters: - - name (str): The name of the Ray cluster custom resource. Defaults to "". - - k8s_namespace (str, optional): The namespace in which to retrieve the Ray cluster. Defaults to "default". - - timeout (int, optional): The duration in seconds after which we stop trying to get status if still not set. Defaults to 60 seconds. - - delay_between_attempts (int, optional): The duration in seconds to wait between attempts to get status if not set. Defaults to 5 seconds. - - Returns: - Any: The custom resource status for the specified Ray cluster, or None if not found. - Raises: - ApiException: If there was an error fetching the custom resource. - """ while timeout > 0: try: resource: Any = self.api.get_namespaced_custom_object_status( @@ -186,18 +140,7 @@ def wait_until_ray_cluster_running( logger.info( f"Waiting until Ray cluster '{name}' is running in namespace '{namespace}', timeout: {timeout}s, delay: {delay_between_attempts}s" ) - """Get a specific Ray cluster in a given namespace. - - Parameters: - - name (str): The name of the Ray cluster custom resource. Defaults to "". - - k8s_namespace (str, optional): The namespace in which to retrieve the Ray cluster. Defaults to "default". - - timeout (int, optional): The duration in seconds after which we stop trying to get status. Defaults to 60 seconds. - - delay_between_attempts (int, optional): The duration in seconds to wait between attempts to get status if not set. Defaults to 5 seconds. - - Returns: - Bool: True if the raycluster status is Running, False otherwise. - """ while timeout > 0: status = self.get_ray_cluster_status( name, k8s_namespace or executor.namespace, timeout, delay_between_attempts @@ -230,15 +173,7 @@ def create_ray_cluster( ) -> Any: namespace = k8s_namespace or executor.namespace logger.info(f"Creating Ray cluster '{name}' in namespace '{namespace}'") - """Create a new Ray cluster custom resource. - Parameters: - - body (Any): The data of the custom resource to create. - - k8s_namespace (str, optional): The namespace in which to create the custom resource. Defaults to "default". - - Returns: - Any: The created custom resource, or None if it already exists or there was an error. - """ if pre_ray_start_commands: k8s_pre_ray_start_commands = "\n".join(pre_ray_start_commands) executor.lifecycle_kwargs["postStart"] = { @@ -283,19 +218,6 @@ def delete_ray_cluster( timeout: int = 300, poll_interval: int = 5, ) -> Optional[bool]: - """Delete a Ray cluster custom resource and optionally wait for deletion to complete. - - Parameters: - - name (str): The name of the Ray cluster custom resource to delete. - - executor (KubeRayExecutor): The executor containing configuration details. - - k8s_namespace (str, optional): The namespace in which the Ray cluster exists. - - wait (bool, optional): Whether to wait for the cluster and all its pods to be fully deleted. Defaults to False. - - timeout (int, optional): Maximum time to wait for deletion in seconds. Defaults to 300 seconds (5 minutes). - - poll_interval (int, optional): Time between checks for deletion status in seconds. Defaults to 5 seconds. - - Returns: - Optional[bool]: True if deletion was successful, None if already deleted or there was an error. - """ namespace = k8s_namespace or executor.namespace logger.info(f"Deleting Ray cluster '{name}' in namespace '{namespace}'") @@ -400,16 +322,6 @@ def patch_ray_cluster( ) -> Any: namespace = k8s_namespace or executor.namespace logger.info(f"Patching Ray cluster '{name}' in namespace '{namespace}'") - """Patch an existing Ray cluster custom resource. - - Parameters: - - name (str): The name of the Ray cluster custom resource to be patched. - - ray_patch (Any): The patch data for the Ray cluster. - - k8s_namespace (str, optional): The namespace in which the Ray cluster exists. Defaults to "default". - - Returns: - bool: True if the patch was successful, False otherwise. - """ try: # we patch the existing raycluster with the new config self.api.patch_namespaced_custom_object( @@ -436,27 +348,6 @@ def port_forward( executor: KubeRayExecutor, wait: bool = False, ): - """Port forward a Ray cluster service using kubectl in a daemon thread. - - When you want to stop the forwarding: - forward_thread.stop_forwarding() # Call this method to stop forwarding - - If wait=True, this function will block until interrupted (e.g., with Ctrl+C). - - Parameters: - - name (str): The name of the Ray cluster custom resource. - - port (int): The local port to use for forwarding. - - target_port (int): The target port on the Ray cluster to forward to. - - k8s_namespace (str, optional): The namespace in which the Ray cluster exists. - - wait (bool, optional): If True, block indefinitely until interrupted. Defaults to False. - - Returns: - - ForwardingThread: A thread object with stop_forwarding method. - - Raises: - - RuntimeError: If the Ray head service cannot be found. - - TimeoutError: If port forwarding fails to establish within the timeout period. - """ import queue import subprocess import threading @@ -618,14 +509,6 @@ def forward_port_daemon(): return forward_thread def _wait_for_forwarding_termination(self, forward_thread, stop_event): - """Helper method to wait for port forwarding termination. - - Sets up signal handlers and blocks until interrupted or the stop_event is set. - - Parameters: - - forward_thread: The thread running the port forwarding. - - stop_event: The event used to signal the thread to stop. - """ import signal import time From 96f4bd81d617b3601cbc2735dd78503cb5903279 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Fri, 16 May 2025 14:41:47 -0700 Subject: [PATCH 04/18] Add schedule_job to ray cluster Signed-off-by: Hemil Desai --- nemo_run/run/ray/cluster.py | 8 +++ nemo_run/run/ray/kuberay.py | 9 +++ nemo_run/run/ray/slurm.py | 132 +++++++++++++++++++++++++++++++----- 3 files changed, 132 insertions(+), 17 deletions(-) diff --git a/nemo_run/run/ray/cluster.py b/nemo_run/run/ray/cluster.py index 454d0acc..4c0f5014 100644 --- a/nemo_run/run/ray/cluster.py +++ b/nemo_run/run/ray/cluster.py @@ -55,6 +55,14 @@ def start(self, wait_until_ready: bool = True, timeout: int = 1000, dryrun: bool name=self.name, executor=self.executor, timeout=timeout ) + def schedule_job( + self, name: str, executor: Executor, command: str, workdir: str, dryrun: bool = False + ): + assert isinstance(self.executor, self.backend.EXECUTOR_CLS) + self.backend.schedule_ray_job( + name=name, executor=executor, command=command, workdir=workdir, dryrun=dryrun + ) + def port_forward(self, port: int = 8265, target_port: int = 8265, wait: bool = False): assert isinstance(self.executor, self.backend.EXECUTOR_CLS) if self._port_forward_map.get(port) is not None: diff --git a/nemo_run/run/ray/kuberay.py b/nemo_run/run/ray/kuberay.py index edf562ea..514bb889 100644 --- a/nemo_run/run/ray/kuberay.py +++ b/nemo_run/run/ray/kuberay.py @@ -209,6 +209,15 @@ def create_ray_cluster( logger.error(f"Error creating Ray cluster '{name}' in namespace '{namespace}': {e}") return None + def schedule_ray_job( + self, + name: str, + executor: KubeRayExecutor, + command: str, + workdir: str, + ): + raise NotImplementedError("KubeRay does not support scheduling jobs") + def delete_ray_cluster( self, name: str, diff --git a/nemo_run/run/ray/slurm.py b/nemo_run/run/ray/slurm.py index e45ba78f..5acd717c 100644 --- a/nemo_run/run/ray/slurm.py +++ b/nemo_run/run/ray/slurm.py @@ -25,10 +25,14 @@ import time import warnings from dataclasses import asdict, dataclass +from pathlib import Path from typing import Any, Dict, Optional, TypeAlias, Union from nemo_run.core.execution.slurm import SlurmExecutor, _as_sbatch_flag from nemo_run.core.execution.utils import fill_template +from nemo_run.core.packaging.git import GitArchivePackager +from nemo_run.core.tunnel.client import SSHTunnel +from nemo_run.core.tunnel.rsync import rsync noquote: TypeAlias = str @@ -42,6 +46,8 @@ class SlurmRayRequest: template_path: str executor: SlurmExecutor pre_ray_start_commands: Optional[list[str]] = None + command: Optional[str] = None + workdir: Optional[str] = None @staticmethod def get_job_name(executor: SlurmExecutor, name: str) -> str: @@ -135,6 +141,8 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str: "common_srun_args": get_srun_flags( self.executor.container_mounts, self.executor.container_image ), + "command": self.command, + "command_workdir": self.workdir, } if self.pre_ray_start_commands: @@ -239,41 +247,53 @@ def create_ray_cluster( executor: SlurmExecutor, pre_ray_start_commands: Optional[list[str]] = None, dryrun: bool = False, + command: Optional[str] = None, + workdir: Optional[str] = None, ) -> Any: + cluster_dir = os.path.join(executor.tunnel.job_dir, name) + ray_sbatch = SlurmRayRequest( + name=name, + cluster_dir=cluster_dir, + template_path=os.path.join( + os.path.dirname(os.path.abspath(__file__)), "templates", "ray.sub.j2" + ), + executor=executor, + pre_ray_start_commands=pre_ray_start_commands, + command=command, + workdir=workdir, + ).materialize() + + if dryrun: + logger.info(f"Dry run: Ray cluster '{name}'") + print(ray_sbatch) + return + logger.info(f"Creating Ray cluster '{name}'") # Check if a cluster with this name already exists status = self.get_ray_cluster_status(name, executor) if status["job_id"] is not None: job_state = status["state"] - if job_state in ["PENDING", "RUNNING", "CONFIGURING", "COMPLETING"]: + if job_state in ["PENDING", "RUNNING", "CONFIGURING"]: logger.info( f"Ray cluster '{name}' already exists with job ID {status['job_id']} " f"and is currently in {job_state} state. " f"Skipping creation." ) return None - elif job_state not in ["COMPLETED", "CANCELLED", "FAILED", "TIMEOUT", "NOT_FOUND"]: + elif job_state not in [ + "COMPLETING", + "COMPLETED", + "CANCELLED", + "FAILED", + "TIMEOUT", + "NOT_FOUND", + ]: logger.warning( f"Ray cluster '{name}' exists with job ID {status['job_id']} " f"in state {job_state}. Creating new cluster anyway." ) - cluster_dir = os.path.join(executor.tunnel.job_dir, name) - ray_sbatch = SlurmRayRequest( - name=name, - cluster_dir=cluster_dir, - template_path=os.path.join( - os.path.dirname(os.path.abspath(__file__)), "templates", "ray.sub.j2" - ), - executor=executor, - pre_ray_start_commands=pre_ray_start_commands, - ).materialize() - - if dryrun: - print(ray_sbatch) - return - executor.tunnel.connect() executor.tunnel.run(f"mkdir -p {cluster_dir}") @@ -294,6 +314,84 @@ def create_ray_cluster( return job_id + def schedule_ray_job( + self, + name: str, + executor: SlurmExecutor, + command: str, + workdir: Optional[str] = None, + pre_ray_start_commands: Optional[list[str]] = None, + dryrun: bool = False, + ): + remote_workdir = None + if workdir: + if isinstance(executor.tunnel, SSHTunnel): + # Rsync workdir honoring .gitignore + remote_workdir = os.path.join(executor.tunnel.job_dir, name, "code") + if not dryrun: + executor.tunnel.connect() + assert executor.tunnel.session is not None, "Tunnel session is not connected" + rsync( + executor.tunnel.session, + workdir, + remote_workdir, + rsync_opts="--filter=':- .gitignore'", + ) + else: + remote_workdir = workdir + elif executor.packager: + if not dryrun: + if isinstance(executor.tunnel, SSHTunnel): + package_dir_ref = tempfile.TemporaryDirectory() + package_dir = package_dir_ref.name + else: + package_dir_ref = None + package_dir = os.path.join(executor.tunnel.job_dir, name) + + if isinstance(executor.packager, GitArchivePackager): + output = subprocess.run( + ["git", "rev-parse", "--show-toplevel"], + check=True, + stdout=subprocess.PIPE, + ) + path = output.stdout.splitlines()[0].decode() + base_path = Path(path).absolute() + else: + base_path = Path(os.getcwd()).absolute() + + local_tar_file = executor.packager.package(base_path, package_dir, name) + local_code_extraction_path = os.path.join(package_dir, "code") + os.makedirs(local_code_extraction_path, exist_ok=True) + subprocess.run( + f"tar -xvzf {local_tar_file} -C {local_code_extraction_path} --ignore-zeros", + shell=True, + check=True, + ) + + if isinstance(executor.tunnel, SSHTunnel): + remote_workdir = os.path.join(executor.tunnel.job_dir, name, "code") + executor.tunnel.connect() + assert executor.tunnel.session is not None, "Tunnel session is not connected" + rsync( + executor.tunnel.session, + os.path.join(local_code_extraction_path, ""), + remote_workdir, + rsync_opts="--filter=':- .gitignore'", + ) + else: + remote_workdir = local_code_extraction_path + + assert remote_workdir is not None, "workdir is not set" + job_id = self.create_ray_cluster( + name, + executor, + pre_ray_start_commands=pre_ray_start_commands, + dryrun=dryrun, + command=command, + workdir=remote_workdir, + ) + return job_id + def wait_until_ray_cluster_running( self, name: str, From b2ff9f569d43a818a5697136e677cbbdc28f38db Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Mon, 19 May 2025 10:39:41 -0700 Subject: [PATCH 05/18] Add experiment integration Signed-off-by: Hemil Desai --- nemo_run/config.py | 2 + nemo_run/core/execution/slurm.py | 103 +++++++++--------- nemo_run/core/execution/utils.py | 2 +- nemo_run/run/ray/cluster.py | 2 + nemo_run/run/ray/slurm.py | 35 ++++-- nemo_run/run/ray/templates/ray.sub.j2 | 5 +- nemo_run/run/torchx_backend/packaging.py | 12 ++ .../run/torchx_backend/schedulers/slurm.py | 88 +++++++++------ test/core/execution/test_slurm_templates.py | 86 +++++++-------- 9 files changed, 192 insertions(+), 143 deletions(-) diff --git a/nemo_run/config.py b/nemo_run/config.py index c018be1b..34a7d162 100644 --- a/nemo_run/config.py +++ b/nemo_run/config.py @@ -449,6 +449,8 @@ class Script(ConfigurableMixin): #: Whether to use ``python -m`` when executing via python. m: bool = False + metadata: dict[str, Any] = dataclasses.field(default_factory=dict) + def __post_init__(self): assert self.path or self.inline assert self.entrypoint, "Need to provide an entrypoint for script." diff --git a/nemo_run/core/execution/slurm.py b/nemo_run/core/execution/slurm.py index 6f436bdf..3a6f921a 100644 --- a/nemo_run/core/execution/slurm.py +++ b/nemo_run/core/execution/slurm.py @@ -748,10 +748,10 @@ def _as_sbatch_flag(key: str, value: Any) -> str: @dataclass(kw_only=True) class SlurmBatchRequest: - cmd: list[str] + launch_cmd: list[str] jobs: list[str] command_groups: list[list[str]] - slurm_config: SlurmExecutor + executor: SlurmExecutor max_retries: int setup: Optional[list[str]] = None extra_env: dict[str, str] @@ -786,7 +786,7 @@ def materialize(self) -> str: In case an erroneous keyword argument is added, a list of all eligible parameters is printed, with their default values """ - args = asdict(self.slurm_config) # noqa: F821 + args = asdict(self.executor) # noqa: F821 parameters = { k: v for k, v in args.items() if v is not None and k in SlurmExecutor.SBATCH_FLAGS } @@ -800,18 +800,16 @@ def materialize(self) -> str: # add necessary parameters original_job_name: str = self.jobs[0] # type: ignore job_name_prefix = ( - self.slurm_config.job_name_prefix - if self.slurm_config.job_name_prefix - else f"{self.slurm_config.account}-{self.slurm_config.account.split('_')[-1]}." + self.executor.job_name_prefix + if self.executor.job_name_prefix + else f"{self.executor.account}-{self.executor.account.split('_')[-1]}." ) job_name = f"{job_name_prefix}{original_job_name}" slurm_job_dir = ( - self.slurm_config.tunnel.job_dir - if self.slurm_config.tunnel - else self.slurm_config.job_dir + self.executor.tunnel.job_dir if self.executor.tunnel else self.executor.job_dir ) - job_directory_name = Path(self.slurm_config.job_dir).name - job_details = self.slurm_config.job_details + job_directory_name = Path(self.executor.job_dir).name + job_details = self.executor.job_details if not job_details.job_name: job_details.job_name = job_name @@ -824,41 +822,41 @@ def materialize(self) -> str: stdout = str(job_details.stdout) stderr = str(job_details.stderr) - if self.slurm_config.array is not None: + if self.executor.array is not None: stdout = stdout.replace("%j", "%A_%a") stderr = stderr.replace("%j", "%A_%a") parameters["output"] = stdout.replace("%t", "0") - if not self.slurm_config.stderr_to_stdout: + if not self.executor.stderr_to_stdout: parameters["error"] = stderr.replace("%t", "0") - if self.slurm_config.additional_parameters is not None: - parameters.update(self.slurm_config.additional_parameters) + if self.executor.additional_parameters is not None: + parameters.update(self.executor.additional_parameters) # now create - sbatch_cmd = " ".join([shlex.quote(arg) for arg in self.cmd]) + sbatch_cmd = " ".join([shlex.quote(arg) for arg in self.launch_cmd]) sbatch_flags = [] - if self.slurm_config.heterogeneous: - assert len(self.jobs) == len(self.slurm_config.resource_group), ( - f"Number of jobs {len(self.jobs)} must match number of resource group requests {len(self.slurm_config.resource_group)}.\nIf you are just submitting a single job, make sure that heterogeneous=False in the executor." + if self.executor.heterogeneous: + assert len(self.jobs) == len(self.executor.resource_group), ( + f"Number of jobs {len(self.jobs)} must match number of resource group requests {len(self.executor.resource_group)}.\nIf you are just submitting a single job, make sure that heterogeneous=False in the executor." ) - final_group_index = len(self.slurm_config.resource_group) - 1 - if self.slurm_config.het_group_indices: - final_group_index = self.slurm_config.het_group_indices.index( - max(self.slurm_config.het_group_indices) + final_group_index = len(self.executor.resource_group) - 1 + if self.executor.het_group_indices: + final_group_index = self.executor.het_group_indices.index( + max(self.executor.het_group_indices) ) - for i in range(len(self.slurm_config.resource_group)): - resource_req = self.slurm_config.resource_group[i] + for i in range(len(self.executor.resource_group)): + resource_req = self.executor.resource_group[i] if resource_req.het_group_index: - assert self.slurm_config.resource_group[i - 1].het_group_index is not None, ( + assert self.executor.resource_group[i - 1].het_group_index is not None, ( "het_group_index must be set for all requests in resource_group" ) if ( i > 0 and resource_req.het_group_index - == self.slurm_config.resource_group[i - 1].het_group_index + == self.executor.resource_group[i - 1].het_group_index ): continue @@ -887,33 +885,31 @@ def materialize(self) -> str: for k in sorted(parameters): sbatch_flags.append(_as_sbatch_flag(k, parameters[k])) - if self.slurm_config.dependencies: - slurm_deps = self.slurm_config.parse_deps() + if self.executor.dependencies: + slurm_deps = self.executor.parse_deps() sbatch_flags.append( _as_sbatch_flag( - "dependency", f"{self.slurm_config.dependency_type}:{':'.join(slurm_deps)}" + "dependency", f"{self.executor.dependency_type}:{':'.join(slurm_deps)}" ) ) env_vars = [] - full_env_vars = self.slurm_config.env_vars | self.extra_env + full_env_vars = self.executor.env_vars | self.extra_env for key, value in full_env_vars.items(): env_vars.append(f"export {key.upper()}={value}") # commandline (this will run the function and args specified in the file provided as argument) # We pass --output and --error here, because the SBATCH command doesn't work as expected with a filename pattern - stderr_flags = [] if self.slurm_config.stderr_to_stdout else ["--error", stderr] + stderr_flags = [] if self.executor.stderr_to_stdout else ["--error", stderr] srun_commands = [] group_env_vars = [] srun_stdout = noquote(job_details.srun_stdout) stderr_flags = ( - [] - if self.slurm_config.stderr_to_stdout - else ["--error", noquote(job_details.srun_stderr)] + [] if self.executor.stderr_to_stdout else ["--error", noquote(job_details.srun_stderr)] ) memory_measure_out = None - if self.slurm_config.memory_measure: + if self.executor.memory_measure: memory_measure_out = srun_stdout def get_container_flags( @@ -937,10 +933,10 @@ def get_container_flags( return _container_flags for group_ind, command_group in enumerate(self.command_groups): - if self.slurm_config.run_as_group and len(self.slurm_config.resource_group) == len( + if self.executor.run_as_group and len(self.executor.resource_group) == len( self.command_groups ): - resource_req = self.slurm_config.resource_group[group_ind] + resource_req = self.executor.resource_group[group_ind] if not resource_req.job_details.job_name: resource_req.job_details.job_name = f"{job_name_prefix}{self.jobs[group_ind]}" @@ -952,7 +948,7 @@ def get_container_flags( cmd_stdout = noquote(resource_req.job_details.srun_stdout) cmd_stderr = ( [] - if self.slurm_config.stderr_to_stdout + if self.executor.stderr_to_stdout else [ "--error", noquote(resource_req.job_details.srun_stderr), @@ -980,20 +976,20 @@ def get_container_flags( if cmd_stderr: cmd_stderr[-1] = cmd_stderr[-1].replace(original_job_name, self.jobs[group_ind]) _container_flags = get_container_flags( - base_mounts=self.slurm_config.container_mounts, + base_mounts=self.executor.container_mounts, src_job_dir=os.path.join( slurm_job_dir, job_directory_name, ), - container_image=self.slurm_config.container_image, + container_image=self.executor.container_image, ) _srun_args = ["--wait=60", "--kill-on-bad-exit=1"] - _srun_args.extend(self.slurm_config.srun_args or []) + _srun_args.extend(self.executor.srun_args or []) - if self.slurm_config.run_as_group and self.slurm_config.heterogeneous: + if self.executor.run_as_group and self.executor.heterogeneous: het_group_index = ( - self.slurm_config.resource_group[group_ind].het_group_index - if self.slurm_config.resource_group[group_ind].het_group_index is not None + self.executor.resource_group[group_ind].het_group_index + if self.executor.resource_group[group_ind].het_group_index is not None else group_ind ) het_group_flag = [f"--het-group={het_group_index}"] @@ -1018,10 +1014,10 @@ def get_container_flags( ) command = " ".join(command_group) - if self.slurm_config.run_as_group: + if self.executor.run_as_group: srun_command = f"{srun_cmd} {command} & pids[{group_ind}]=$!" if group_ind != len(self.command_groups) - 1: - srun_command += f"\n\nsleep {self.slurm_config.wait_time_for_group_job}\n" + srun_command += f"\n\nsleep {self.executor.wait_time_for_group_job}\n" else: srun_command = f"{srun_cmd} {command}" @@ -1033,15 +1029,14 @@ def get_container_flags( "max_retries": self.max_retries, "env_vars": env_vars, "head_node_ip_var": SlurmExecutor.HEAD_NODE_IP_VAR, - "setup_lines": self.slurm_config.setup_lines, + "setup_lines": self.executor.setup_lines, "memory_measure": memory_measure_out, "srun_commands": srun_commands, "group_env_vars": group_env_vars, - "heterogeneous": self.slurm_config.heterogeneous, - "run_as_group": self.slurm_config.run_as_group, - "monitor_group_job": self.slurm_config.run_as_group - and self.slurm_config.monitor_group_job, - "monitor_group_job_wait_time": self.slurm_config.monitor_group_job_wait_time, + "heterogeneous": self.executor.heterogeneous, + "run_as_group": self.executor.run_as_group, + "monitor_group_job": self.executor.run_as_group and self.executor.monitor_group_job, + "monitor_group_job_wait_time": self.executor.monitor_group_job_wait_time, "het_group_host_var": SlurmExecutor.HET_GROUP_HOST_VAR, "ft_enabled": self.launcher and isinstance(self.launcher, FaultTolerance), } @@ -1060,7 +1055,7 @@ def get_container_flags( return sbatch_script def __repr__(self) -> str: - return f"""{" ".join(self.cmd + ["$SBATCH_SCRIPT"])} + return f"""{" ".join(self.launch_cmd + ["$SBATCH_SCRIPT"])} #---------------- # SBATCH_SCRIPT diff --git a/nemo_run/core/execution/utils.py b/nemo_run/core/execution/utils.py index 5c4f48ac..3b61d0d9 100644 --- a/nemo_run/core/execution/utils.py +++ b/nemo_run/core/execution/utils.py @@ -25,7 +25,7 @@ def fill_template(template_name: str, variables: dict, template_dir: Optional[st template_dir = template_dir or os.path.join(os.path.dirname(__file__), "templates") template_path = os.path.join(template_dir, template_name) if not os.path.exists(template_path): - raise FileNotFoundError(f'Template "{template_name}" does not exist.') + raise FileNotFoundError(f'Template "{template_path}" does not exist.') with open(template_path, "r", encoding="utf-8") as fin: template = fin.read() diff --git a/nemo_run/run/ray/cluster.py b/nemo_run/run/ray/cluster.py index 4c0f5014..9bb569b2 100644 --- a/nemo_run/run/ray/cluster.py +++ b/nemo_run/run/ray/cluster.py @@ -22,6 +22,8 @@ from nemo_run.run.ray.kuberay import KubeRayCluster from nemo_run.run.ray.slurm import SlurmRayCluster +USE_WITH_RAY_CLUSTER_KEY = "use_with_ray_cluster" + @dataclass(kw_only=True) class RayCluster: diff --git a/nemo_run/run/ray/slurm.py b/nemo_run/run/ray/slurm.py index 5acd717c..87f622a9 100644 --- a/nemo_run/run/ray/slurm.py +++ b/nemo_run/run/ray/slurm.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import json import logging import os @@ -28,6 +29,7 @@ from pathlib import Path from typing import Any, Dict, Optional, TypeAlias, Union +from nemo_run.config import RUNDIR_NAME, RUNDIR_SPECIAL_NAME from nemo_run.core.execution.slurm import SlurmExecutor, _as_sbatch_flag from nemo_run.core.execution.utils import fill_template from nemo_run.core.packaging.git import GitArchivePackager @@ -43,11 +45,14 @@ class SlurmRayRequest: name: str cluster_dir: str - template_path: str + template_name: str + template_dir: Optional[str] = None executor: SlurmExecutor pre_ray_start_commands: Optional[list[str]] = None command: Optional[str] = None workdir: Optional[str] = None + nemo_run_dir: Optional[str] = None + launch_cmd: list[str] @staticmethod def get_job_name(executor: SlurmExecutor, name: str) -> str: @@ -125,8 +130,19 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str: else: _srun_flags.append("--gres=gpu:8") - _srun_flags.append(f"--container-workdir={self.cluster_dir}") - _srun_flags += ["--container-mounts", ",".join(mounts)] + if self.nemo_run_dir: + new_mounts = copy.deepcopy(mounts) + for i, mount in enumerate(new_mounts): + if mount.startswith(RUNDIR_SPECIAL_NAME): + new_mounts[i] = mount.replace(RUNDIR_SPECIAL_NAME, self.nemo_run_dir, 1) + + new_mounts.append(f"{self.nemo_run_dir}:/{RUNDIR_NAME}") + else: + new_mounts = mounts + + _srun_flags += ["--container-mounts", ",".join(new_mounts)] + container_workdir = self.workdir or self.cluster_dir + _srun_flags.append(f"--container-workdir={container_workdir}") return " ".join(_srun_flags) @@ -148,7 +164,12 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str: if self.pre_ray_start_commands: vars_to_fill["pre_ray_start_commands"] = "\n".join(self.pre_ray_start_commands) - sbatch_script = fill_template(self.template_path, vars_to_fill) + sbatch_script = fill_template( + self.template_name, + vars_to_fill, + template_dir=self.template_dir + or os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates"), + ) return sbatch_script def __repr__(self) -> str: @@ -254,13 +275,12 @@ def create_ray_cluster( ray_sbatch = SlurmRayRequest( name=name, cluster_dir=cluster_dir, - template_path=os.path.join( - os.path.dirname(os.path.abspath(__file__)), "templates", "ray.sub.j2" - ), + template_name="ray.sub.j2", executor=executor, pre_ray_start_commands=pre_ray_start_commands, command=command, workdir=workdir, + launch_cmd=["sbatch", "--requeue", "--parsable"], ).materialize() if dryrun: @@ -321,6 +341,7 @@ def schedule_ray_job( command: str, workdir: Optional[str] = None, pre_ray_start_commands: Optional[list[str]] = None, + runtime_env_yaml: Optional[str] = None, dryrun: bool = False, ): remote_workdir = None diff --git a/nemo_run/run/ray/templates/ray.sub.j2 b/nemo_run/run/ray/templates/ray.sub.j2 index 7320fde6..e4341b40 100644 --- a/nemo_run/run/ray/templates/ray.sub.j2 +++ b/nemo_run/run/ray/templates/ray.sub.j2 @@ -289,10 +289,11 @@ echo "[INFO] Ray cluster information saved to $CLUSTER_DIR/ray_cluster_info.json # This driver process is responsible for launching a job on the Ray cluster CONTAINER_CWD=$(scontrol show job $SLURM_JOB_ID --json | jq -r '.jobs[].current_working_directory') # Define command to be empty by default -COMMAND="" +COMMAND="${COMMAND:-{{ command }}}" +COMMAND_WORKDIR={{ command_workdir | default('$CONTAINER_CWD') }} if [[ -n "$COMMAND" ]]; then - srun --no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/ray-driver.log bash -c "$COMMAND" + srun --no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/ray-job.log bash -c "$COMMAND" else echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:" cat <$CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh diff --git a/nemo_run/run/torchx_backend/packaging.py b/nemo_run/run/torchx_backend/packaging.py index 49857c90..55062967 100644 --- a/nemo_run/run/torchx_backend/packaging.py +++ b/nemo_run/run/torchx_backend/packaging.py @@ -26,8 +26,10 @@ from nemo_run.core.execution.dgxcloud import DGXCloudExecutor from nemo_run.core.execution.launcher import FaultTolerance, Torchrun from nemo_run.core.execution.local import LocalExecutor +from nemo_run.core.execution.slurm import SlurmExecutor from nemo_run.core.serialization.yaml import YamlSerializer from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer +from nemo_run.run.ray.cluster import USE_WITH_RAY_CLUSTER_KEY from nemo_run.run.torchx_backend.components import ft_launcher, torchrun log: logging.Logger = logging.getLogger(__name__) @@ -130,10 +132,12 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool): if isinstance(fn_or_script, Partial): role_args, args, m, no_python, script, entrypoint = _get_details_from_partial(fn_or_script) + metadata = {} else: role_args, args, m, no_python, script, entrypoint = _get_details_from_script( fn_or_script, serialize_configs=True ) + metadata = fn_or_script.metadata env = env | fn_or_script.env launcher = executor.get_launcher() @@ -223,6 +227,14 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool): role.entrypoint = "nsys" role.args = nsys_prefix + role.args + if metadata: + if USE_WITH_RAY_CLUSTER_KEY in metadata: + assert isinstance(executor, SlurmExecutor), ( + f"{USE_WITH_RAY_CLUSTER_KEY} is only supported for SlurmExecutor" + ) + assert len(app_def.roles) == 1, "Only one command is supported for Ray jobs." + + app_def.metadata = metadata return app_def diff --git a/nemo_run/run/torchx_backend/schedulers/slurm.py b/nemo_run/run/torchx_backend/schedulers/slurm.py index 4674d13c..f20778fd 100644 --- a/nemo_run/run/torchx_backend/schedulers/slurm.py +++ b/nemo_run/run/torchx_backend/schedulers/slurm.py @@ -50,11 +50,13 @@ ) from torchx.specs.api import is_terminal -from nemo_run.config import from_dict, get_nemorun_home +from nemo_run.config import RUNDIR_NAME, from_dict, get_nemorun_home from nemo_run.core.execution.base import Executor from nemo_run.core.execution.slurm import SlurmBatchRequest, SlurmExecutor, SlurmJobDetails from nemo_run.core.tunnel.client import LocalTunnel, PackagingJob, SSHTunnel, Tunnel from nemo_run.run import experiment as run_experiment +from nemo_run.run.ray.cluster import USE_WITH_RAY_CLUSTER_KEY +from nemo_run.run.ray.slurm import SlurmRayRequest from nemo_run.run.torchx_backend.schedulers.api import SchedulerMixin log: logging.Logger = logging.getLogger(__name__) @@ -101,38 +103,52 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t executor.package(packager=executor.packager, job_name=Path(job_dir).name) - srun_cmds: list[list[str]] = [] - jobs = [] - envs = {} - values = executor.macro_values() - - if values: - executor.env_vars = { - key: values.substitute(arg) for key, arg in executor.env_vars.items() - } - for resource_req in executor.resource_group: - resource_req.env_vars = { - key: values.substitute(arg) for key, arg in resource_req.env_vars.items() - } + if app.metadata and app.metadata.get(USE_WITH_RAY_CLUSTER_KEY, False): + assert len(app.roles) == 1, "Only one command is supported for Ray jobs." + command = [app.roles[0].entrypoint] + app.roles[0].args + req = SlurmRayRequest( + name=app.roles[0].name, + launch_cmd=["sbatch", "--requeue", "--parsable"], + command=" ".join(command), + cluster_dir=os.path.join(executor.tunnel.job_dir, Path(job_dir).name, "ray"), + template_name="ray.sub.j2", + executor=executor, + workdir=f"/{RUNDIR_NAME}/code", + nemo_run_dir=os.path.join(executor.tunnel.job_dir, Path(job_dir).name), + ) + else: + srun_cmds: list[list[str]] = [] + jobs = [] + envs = {} + values = executor.macro_values() - for role in app.roles: if values: - role = values.apply(role) - srun_cmd = [role.entrypoint] + role.args - srun_cmds.append([" ".join(srun_cmd)]) - jobs.append(role.name) - envs |= role.env - - cmd = ["sbatch", "--requeue", "--parsable"] - req = SlurmBatchRequest( - cmd=cmd, - jobs=jobs, - command_groups=srun_cmds, - slurm_config=executor, - max_retries=min(role.max_retries for role in app.roles), - extra_env=envs, - launcher=executor.get_launcher(), - ) + executor.env_vars = { + key: values.substitute(arg) for key, arg in executor.env_vars.items() + } + for resource_req in executor.resource_group: + resource_req.env_vars = { + key: values.substitute(arg) for key, arg in resource_req.env_vars.items() + } + + for role in app.roles: + if values: + role = values.apply(role) + srun_cmd = [role.entrypoint] + role.args + srun_cmds.append([" ".join(srun_cmd)]) + jobs.append(role.name) + envs |= role.env + + cmd = ["sbatch", "--requeue", "--parsable"] + req = SlurmBatchRequest( + launch_cmd=cmd, + jobs=jobs, + command_groups=srun_cmds, + executor=executor, + max_retries=min(role.max_retries for role in app.roles), + extra_env=envs, + launcher=executor.get_launcher(), + ) # Write and copy sbatch script sbatch_dir = executor.experiment_dir @@ -144,10 +160,10 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t return AppDryRunInfo(req, repr) - def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str: # type: ignore + def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest | SlurmRayRequest]) -> str: # type: ignore # Setup req = dryrun_info.request - slurm_executor = dryrun_info.request.slurm_config + slurm_executor = dryrun_info.request.executor assert slurm_executor.experiment_id, "Executor not assigned to experiment." job_dir = slurm_executor.job_dir @@ -164,11 +180,11 @@ def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str: # typ cmd = ["sbatch", "--requeue", "--parsable"] slurm_deps = slurm_executor.parse_deps() cmd.append(f"--dependency={slurm_executor.dependency_type}:{':'.join(slurm_deps)}") - req.cmd = cmd + req.launch_cmd = cmd # Run sbatch script - req.cmd += [dst_path] - job_id = self.tunnel.run(" ".join(req.cmd)).stdout.strip() + req.launch_cmd += [dst_path] + job_id = self.tunnel.run(" ".join(req.launch_cmd)).stdout.strip() # Save metadata _save_job_dir(job_id, job_dir, tunnel, slurm_executor.job_details.ls_term) diff --git a/test/core/execution/test_slurm_templates.py b/test/core/execution/test_slurm_templates.py index dbb33690..0aa32619 100644 --- a/test/core/execution/test_slurm_templates.py +++ b/test/core/execution/test_slurm_templates.py @@ -60,10 +60,10 @@ def dummy_slurm_request_with_artifact( extra_env = {"ENV_VAR": "value"} return ( SlurmBatchRequest( - cmd=cmd, + launch_cmd=cmd, jobs=["sample_job"], command_groups=command_groups, - slurm_config=slurm_config, + executor=slurm_config, max_retries=max_retries, extra_env=extra_env, ), @@ -95,10 +95,10 @@ def ft_slurm_request_with_artifact( extra_env = {"ENV_VAR": "value"} return ( SlurmBatchRequest( - cmd=cmd, + launch_cmd=cmd, jobs=["sample_job"], command_groups=command_groups, - slurm_config=slurm_config, + executor=slurm_config, max_retries=max_retries, extra_env=extra_env, launcher=slurm_config.get_launcher(), @@ -140,10 +140,10 @@ def group_slurm_request_with_artifact( extra_env = {"ENV_VAR": "value"} return ( SlurmBatchRequest( - cmd=cmd, + launch_cmd=cmd, jobs=["sample_job-0", "sample_job-1"], command_groups=command_groups, - slurm_config=slurm_config, + executor=slurm_config, max_retries=max_retries, extra_env=extra_env, ), @@ -155,7 +155,7 @@ def group_no_monitor_slurm_request_with_artifact( self, group_slurm_request_with_artifact ) -> tuple[SlurmBatchRequest, str]: req, _ = group_slurm_request_with_artifact - req.slurm_config.monitor_group_job = False + req.executor.monitor_group_job = False return ( req, os.path.join(ARTIFACTS_DIR, "group_slurm_no_monitor.sh"), @@ -201,10 +201,10 @@ def group_resource_req_slurm_request_with_artifact( extra_env = {"ENV_VAR": "value"} return ( SlurmBatchRequest( - cmd=cmd, + launch_cmd=cmd, jobs=["sample_job-0", "sample_job-1"], command_groups=command_groups, - slurm_config=executor, + executor=executor, max_retries=max_retries, extra_env=extra_env, ), @@ -271,10 +271,10 @@ def het_slurm_request_with_artifact( extra_env = {"ENV_VAR": "value"} return ( SlurmBatchRequest( - cmd=cmd, + launch_cmd=cmd, jobs=["sample_job-0", "sample_job-1"], command_groups=command_groups, - slurm_config=slurm_config, + executor=slurm_config, max_retries=max_retries, extra_env=extra_env, ), @@ -336,10 +336,10 @@ def ft_het_slurm_request_with_artifact( extra_env = {"ENV_VAR": "value"} return ( SlurmBatchRequest( - cmd=cmd, + launch_cmd=cmd, jobs=["sample_job-0", "sample_job-1"], command_groups=command_groups, - slurm_config=slurm_config, + executor=slurm_config, max_retries=max_retries, extra_env=extra_env, launcher=slurm_config.get_launcher(), @@ -382,14 +382,14 @@ def test_dummy_batch_request_dependencies( dummy_slurm_request_with_artifact: tuple[SlurmBatchRequest, str], ): dummy_slurm_request, _ = dummy_slurm_request_with_artifact - dummy_slurm_request.slurm_config.dependencies = [ + dummy_slurm_request.executor.dependencies = [ "slurm_tunnel://nemo_run/depend1", "slurm_tunnel://nemo_run/depend2", ] sbatch_script = dummy_slurm_request.materialize() assert "#SBATCH --dependency=afterok:depend1:depend2" in sbatch_script - dummy_slurm_request.slurm_config.dependency_type = "afterany" + dummy_slurm_request.executor.dependency_type = "afterany" sbatch_script = dummy_slurm_request.materialize() assert "#SBATCH --dependency=afterany:depend1:depend2" in sbatch_script @@ -398,11 +398,11 @@ def test_dummy_batch_request_memory_measure( dummy_slurm_request_with_artifact: tuple[SlurmBatchRequest, str], ): dummy_slurm_request, _ = dummy_slurm_request_with_artifact - dummy_slurm_request.slurm_config.dependencies = [ + dummy_slurm_request.executor.dependencies = [ "slurm_tunnel://nemo_run/depend1", "slurm_tunnel://nemo_run/depend2", ] - dummy_slurm_request.slurm_config.memory_measure = True + dummy_slurm_request.executor.memory_measure = True sbatch_script = dummy_slurm_request.materialize() assert ( "srun --ntasks=1 --ntasks-per-node=1 --output /root/sample_job/log-account-account.sample_job_%j_${SLURM_RESTART_COUNT:-0}.out --wait=60 --kill-on-bad-exit=1 --overlap nvidia-smi" @@ -425,7 +425,7 @@ def srun_stdout(self) -> Path: return Path(self.folder) / "log_job.out" dummy_slurm_request, _ = dummy_slurm_request_with_artifact - dummy_slurm_request.slurm_config.job_details = CustomJobDetails() + dummy_slurm_request.executor.job_details = CustomJobDetails() sbatch_script = dummy_slurm_request.materialize() assert "#SBATCH --job-name=account-account.sample_job" in sbatch_script assert "--output /root/sample_job/log_job.out" in sbatch_script @@ -447,7 +447,7 @@ def srun_stdout(self) -> Path: return Path(self.folder) / "log_job.out" dummy_slurm_request, _ = dummy_slurm_request_with_artifact - dummy_slurm_request.slurm_config.job_details = CustomJobDetails( + dummy_slurm_request.executor.job_details = CustomJobDetails( job_name="custom_sample_job", folder="/custom_folder" ) sbatch_script = dummy_slurm_request.materialize() @@ -460,8 +460,8 @@ def test_dummy_batch_request_nsys( dummy_slurm_request_with_artifact: tuple[SlurmBatchRequest, str], ): dummy_slurm_request, _ = dummy_slurm_request_with_artifact - dummy_slurm_request.slurm_config.get_launcher().nsys_profile = True - launcher_prefix = dummy_slurm_request.slurm_config.get_launcher_prefix() + dummy_slurm_request.executor.get_launcher().nsys_profile = True + launcher_prefix = dummy_slurm_request.executor.get_launcher_prefix() assert launcher_prefix == [ "profile", "-s", @@ -482,8 +482,8 @@ def test_dummy_batch_request_warn( dummy_slurm_request_with_artifact: tuple[SlurmBatchRequest, str], ): dummy_slurm_request, _ = dummy_slurm_request_with_artifact - dummy_slurm_request.slurm_config.cpus_per_gpu = 10 - dummy_slurm_request.slurm_config.gpus_per_task = None + dummy_slurm_request.executor.cpus_per_gpu = 10 + dummy_slurm_request.executor.gpus_per_task = None with pytest.warns(match='"cpus_per_gpu" requires to set "gpus_per_task"'): dummy_slurm_request.materialize() @@ -493,7 +493,7 @@ def test_dummy_batch_request_array( dummy_slurm_request_with_artifact: tuple[SlurmBatchRequest, str], ): dummy_slurm_request, _ = dummy_slurm_request_with_artifact - dummy_slurm_request.slurm_config.array = "0-10" + dummy_slurm_request.executor.array = "0-10" sbatch_script = dummy_slurm_request.materialize() assert "#SBATCH --array=0-10" in sbatch_script @@ -507,7 +507,7 @@ def test_dummy_batch_additonal_params( dummy_slurm_request_with_artifact: tuple[SlurmBatchRequest, str], ): dummy_slurm_request, _ = dummy_slurm_request_with_artifact - dummy_slurm_request.slurm_config.additional_parameters = {"abc": "def"} + dummy_slurm_request.executor.additional_parameters = {"abc": "def"} sbatch_script = dummy_slurm_request.materialize() assert "#SBATCH --abc=def" in sbatch_script @@ -517,7 +517,7 @@ def test_dummy_batch_job_name_prefix( dummy_slurm_request_with_artifact: tuple[SlurmBatchRequest, str], ): dummy_slurm_request, _ = dummy_slurm_request_with_artifact - dummy_slurm_request.slurm_config.job_name_prefix = "my-custom-prefix:" + dummy_slurm_request.executor.job_name_prefix = "my-custom-prefix:" sbatch_script = dummy_slurm_request.materialize() assert "#SBATCH --job-name=my-custom-prefix:sample_job" in sbatch_script @@ -537,7 +537,7 @@ def test_het_batch_request_materialize( het_slurm_request_with_artifact: tuple[SlurmBatchRequest, str], ): het_slurm_request, artifact = het_slurm_request_with_artifact - executor = het_slurm_request.slurm_config + executor = het_slurm_request.executor self.apply_macros(executor) sbatch_script = het_slurm_request.materialize() expected = Path(artifact).read_text() @@ -548,7 +548,7 @@ def test_het_batch_request_dependencies( het_slurm_request_with_artifact: tuple[SlurmBatchRequest, str], ): het_slurm_request, _ = het_slurm_request_with_artifact - het_slurm_request.slurm_config.dependencies = [ + het_slurm_request.executor.dependencies = [ "slurm_tunnel://nemo_run/depend1", "slurm_tunnel://nemo_run/depend2", ] @@ -560,8 +560,8 @@ def test_group_batch_request_materialize( group_slurm_request_with_artifact: tuple[SlurmBatchRequest, str], ): group_slurm_request, artifact = group_slurm_request_with_artifact - executor = group_slurm_request.slurm_config - group_slurm_request.slurm_config = SlurmExecutor.merge([executor], num_tasks=2) + executor = group_slurm_request.executor + group_slurm_request.executor = SlurmExecutor.merge([executor], num_tasks=2) self.apply_macros(executor) sbatch_script = group_slurm_request.materialize() expected = Path(artifact).read_text() @@ -572,8 +572,8 @@ def test_group_no_monitor_batch_request_materialize( group_no_monitor_slurm_request_with_artifact: tuple[SlurmBatchRequest, str], ): group_slurm_request, artifact = group_no_monitor_slurm_request_with_artifact - executor = group_slurm_request.slurm_config - group_slurm_request.slurm_config = SlurmExecutor.merge([executor], num_tasks=2) + executor = group_slurm_request.executor + group_slurm_request.executor = SlurmExecutor.merge([executor], num_tasks=2) self.apply_macros(executor) sbatch_script = group_slurm_request.materialize() expected = Path(artifact).read_text() @@ -584,8 +584,8 @@ def test_group_resource_req_batch_request_materialize( group_resource_req_slurm_request_with_artifact: tuple[SlurmBatchRequest, str], ): group_slurm_request, artifact = group_resource_req_slurm_request_with_artifact - executor = group_slurm_request.slurm_config - group_slurm_request.slurm_config = SlurmExecutor.merge([executor], num_tasks=2) + executor = group_slurm_request.executor + group_slurm_request.executor = SlurmExecutor.merge([executor], num_tasks=2) self.apply_macros(executor) sbatch_script = group_slurm_request.materialize() expected = Path(artifact).read_text() @@ -607,15 +607,15 @@ def srun_stdout(self) -> Path: return Path(self.folder) / f"log_{self.job_name}.out" group_resource_req_slurm_request, _ = group_resource_req_slurm_request_with_artifact - group_resource_req_slurm_request.slurm_config.job_details = CustomJobDetails( + group_resource_req_slurm_request.executor.job_details = CustomJobDetails( job_name="custom_sample_job", folder="/custom_folder" ) - group_resource_req_slurm_request.slurm_config.resource_group[0].job_details = copy.deepcopy( - group_resource_req_slurm_request.slurm_config.job_details + group_resource_req_slurm_request.executor.resource_group[0].job_details = copy.deepcopy( + group_resource_req_slurm_request.executor.job_details + ) + group_resource_req_slurm_request.executor.resource_group[1].job_details = CustomJobDetails( + job_name="custom_sample_job_2", folder="/custom_folder_2" ) - group_resource_req_slurm_request.slurm_config.resource_group[ - 1 - ].job_details = CustomJobDetails(job_name="custom_sample_job_2", folder="/custom_folder_2") sbatch_script = group_resource_req_slurm_request.materialize() assert "#SBATCH --job-name=custom_sample_job" in sbatch_script @@ -637,7 +637,7 @@ def test_ft_het_slurm_request_materialize( self, ft_het_slurm_request_with_artifact: tuple[SlurmBatchRequest, str] ): ft_het_slurm_request, artifact = ft_het_slurm_request_with_artifact - executor = ft_het_slurm_request.slurm_config + executor = ft_het_slurm_request.executor self.apply_macros(executor) sbatch_script = ft_het_slurm_request.materialize() expected = Path(artifact).read_text() @@ -648,7 +648,7 @@ def test_ft_het_slurm_request_materialize( def test_het_job_name_prefix(self, het_slurm_request_with_artifact): # Set the job_name_prefix to a custom value het_request, _ = het_slurm_request_with_artifact - het_request.slurm_config.job_name_prefix = "prefix_" + het_request.executor.job_name_prefix = "prefix_" # Materialize the batch request script sbatch_script = het_request.materialize() @@ -676,7 +676,7 @@ def srun_stdout(self): return Path(self.folder) / "log_job.out" custom_name = "custom_het_job" - het_request.slurm_config.job_details = CustomJobDetails( + het_request.executor.job_details = CustomJobDetails( job_name=custom_name, folder="/custom_folder" ) sbatch_script = het_request.materialize() From dc650313c6572c77a06821e0f90977df57deee2d Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Mon, 19 May 2025 12:38:27 -0700 Subject: [PATCH 06/18] Potential fix for code scanning alert no. 424: Variable defined multiple times Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Signed-off-by: Hemil Desai Signed-off-by: Hemil Desai --- nemo_run/core/execution/slurm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_run/core/execution/slurm.py b/nemo_run/core/execution/slurm.py index 3a6f921a..3f662723 100644 --- a/nemo_run/core/execution/slurm.py +++ b/nemo_run/core/execution/slurm.py @@ -900,7 +900,7 @@ def materialize(self) -> str: # commandline (this will run the function and args specified in the file provided as argument) # We pass --output and --error here, because the SBATCH command doesn't work as expected with a filename pattern - stderr_flags = [] if self.executor.stderr_to_stdout else ["--error", stderr] + # Removed redundant assignment to stderr_flags srun_commands = [] group_env_vars = [] From 6ca73019461173ff59664ea00c838b99c736e0e0 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Mon, 19 May 2025 12:38:39 -0700 Subject: [PATCH 07/18] Potential fix for code scanning alert no. 428: Jinja2 templating with autoescape=False Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Signed-off-by: Hemil Desai Signed-off-by: Hemil Desai --- nemo_run/core/execution/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/nemo_run/core/execution/utils.py b/nemo_run/core/execution/utils.py index 3b61d0d9..1ad270c0 100644 --- a/nemo_run/core/execution/utils.py +++ b/nemo_run/core/execution/utils.py @@ -17,6 +17,7 @@ from typing import Optional import jinja2 +from jinja2 import select_autoescape def fill_template(template_name: str, variables: dict, template_dir: Optional[str] = None) -> str: @@ -29,8 +30,9 @@ def fill_template(template_name: str, variables: dict, template_dir: Optional[st with open(template_path, "r", encoding="utf-8") as fin: template = fin.read() - j2_template = jinja2.Environment(loader=jinja2.FileSystemLoader(template_dir)).from_string( - template - ) + j2_template = jinja2.Environment( + loader=jinja2.FileSystemLoader(template_dir), + autoescape=select_autoescape(['html', 'xml']) + ).from_string(template) content = j2_template.render(**variables) return content From 3906fa2536fd6c8a9484943670731fb6100bba7d Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Mon, 19 May 2025 12:39:20 -0700 Subject: [PATCH 08/18] Potential fix for code scanning alert no. 427: Explicit returns mixed with implicit (fall through) returns Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Signed-off-by: Hemil Desai Signed-off-by: Hemil Desai --- nemo_run/run/ray/slurm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo_run/run/ray/slurm.py b/nemo_run/run/ray/slurm.py index 87f622a9..96712ec9 100644 --- a/nemo_run/run/ray/slurm.py +++ b/nemo_run/run/ray/slurm.py @@ -334,6 +334,8 @@ def create_ray_cluster( return job_id + return None + def schedule_ray_job( self, name: str, From 1d74c5b4f60ffe07673d99a2060a26893f019a77 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Mon, 19 May 2025 12:45:19 -0700 Subject: [PATCH 09/18] Fixes Signed-off-by: Hemil Desai --- nemo_run/run/ray/slurm.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/nemo_run/run/ray/slurm.py b/nemo_run/run/ray/slurm.py index 96712ec9..296664c6 100644 --- a/nemo_run/run/ray/slurm.py +++ b/nemo_run/run/ray/slurm.py @@ -334,8 +334,6 @@ def create_ray_cluster( return job_id - return None - def schedule_ray_job( self, name: str, @@ -413,6 +411,19 @@ def schedule_ray_job( command=command, workdir=remote_workdir, ) + + # Descriptive log for the user with useful paths / identifiers + cluster_dir = os.path.join(executor.tunnel.job_dir, name) + logger.info( + f"""\n\n\033[1;34mRay job submitted to Slurm cluster at {executor.tunnel.key}:\033[0m + • \033[1mJob ID\033[0m : \033[32m{job_id}\033[0m + • \033[1mCluster dir\033[0m : {cluster_dir} + • \033[1mLogs directory\033[0m : {os.path.join(cluster_dir, "logs")} + • \033[1mSBATCH script\033[0m : {os.path.join(cluster_dir, "ray.sub")} + • \033[1mRemote workdir\033[0m : {remote_workdir} + (use `squeue -j {job_id}` to check status, `scancel {job_id}` to cancel)\n""" + ) + return job_id def wait_until_ray_cluster_running( From 0e6d51abd13f8a82638c992c48c9d5fb1daeb5a1 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Mon, 19 May 2025 14:18:49 -0700 Subject: [PATCH 10/18] fix Signed-off-by: Hemil Desai --- nemo_run/core/execution/utils.py | 2 -- nemo_run/run/ray/slurm.py | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo_run/core/execution/utils.py b/nemo_run/core/execution/utils.py index 1ad270c0..0eda43da 100644 --- a/nemo_run/core/execution/utils.py +++ b/nemo_run/core/execution/utils.py @@ -17,7 +17,6 @@ from typing import Optional import jinja2 -from jinja2 import select_autoescape def fill_template(template_name: str, variables: dict, template_dir: Optional[str] = None) -> str: @@ -32,7 +31,6 @@ def fill_template(template_name: str, variables: dict, template_dir: Optional[st j2_template = jinja2.Environment( loader=jinja2.FileSystemLoader(template_dir), - autoescape=select_autoescape(['html', 'xml']) ).from_string(template) content = j2_template.render(**variables) return content diff --git a/nemo_run/run/ray/slurm.py b/nemo_run/run/ray/slurm.py index 296664c6..be87a8b2 100644 --- a/nemo_run/run/ray/slurm.py +++ b/nemo_run/run/ray/slurm.py @@ -140,6 +140,8 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str: else: new_mounts = mounts + new_mounts.append(f"{self.cluster_dir}:{self.cluster_dir}") + _srun_flags += ["--container-mounts", ",".join(new_mounts)] container_workdir = self.workdir or self.cluster_dir _srun_flags.append(f"--container-workdir={container_workdir}") From f78e106bd70156bced0d6a0f516ec0a25839e1ea Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Mon, 19 May 2025 14:47:52 -0700 Subject: [PATCH 11/18] fix Signed-off-by: Hemil Desai --- nemo_run/run/ray/templates/ray.sub.j2 | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo_run/run/ray/templates/ray.sub.j2 b/nemo_run/run/ray/templates/ray.sub.j2 index e4341b40..6ea2816d 100644 --- a/nemo_run/run/ray/templates/ray.sub.j2 +++ b/nemo_run/run/ray/templates/ray.sub.j2 @@ -157,7 +157,7 @@ EOF srun {{ common_srun_args }} --container-name=ray-head --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/ray-head.log bash -x -c "$head_cmd" & # Wait for the head node container to start and for Ray to be ready -while ! (srun --overlap --nodes=1 --ntasks=1 -w $head_node test -f $LOG_DIR/STARTED_RAY_HEAD && srun --overlap --container-name=ray-head --nodes=1 --ntasks=1 -w $head_node ray status 2>/dev/null); do +while ! (srun --overlap --nodes=1 --ntasks=1 -w $head_node test -f $LOG_DIR/STARTED_RAY_HEAD && srun --overlap --container-name=ray-head --nodes=1 --ntasks=1 -w $head_node ray status --address $ip_head 2>/dev/null); do echo "[INFO][$(date)] Waiting for Ray head node container to start and be ready..." sleep 2 done @@ -238,7 +238,7 @@ done # Before we launch a job on this cluster we need to make sure that the bringup is complete # We do so by querying the number of worker_units in the ray cluster and asserting = NUM_ACTORS extract_worker_units() { - status_output=$(srun --overlap --container-name=ray-head --nodes=1 --ntasks=1 -w "$head_node" ray status) + status_output=$(srun --overlap --container-name=ray-head --nodes=1 --ntasks=1 -w "$head_node" ray status --address $ip_head) if echo "$status_output" | grep -q "worker_units"; then worker_units=$(echo "$status_output" | grep "worker_units" | awk -F'[/. ]' '{print $4}') echo $worker_units From ba124e14230a8c94a91fd96f3e41890a464e321a Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Wed, 21 May 2025 09:37:55 -0700 Subject: [PATCH 12/18] Unify apis and add a guide Signed-off-by: Hemil Desai --- docs/source/guides/index.rst | 1 + docs/source/guides/ray.md | 263 ++++++++++ nemo_run/core/execution/kuberay.py | 163 ++++++- nemo_run/run/ray/cluster.py | 37 +- nemo_run/run/ray/job.py | 81 +++ nemo_run/run/ray/kuberay.py | 652 ++++++++++++++++++++----- nemo_run/run/ray/slurm.py | 676 ++++++++++++++++++-------- nemo_run/run/ray/templates/ray.sub.j2 | 10 +- 8 files changed, 1544 insertions(+), 339 deletions(-) create mode 100644 docs/source/guides/ray.md create mode 100644 nemo_run/run/ray/job.py diff --git a/docs/source/guides/index.rst b/docs/source/guides/index.rst index f68a2497..ead59c65 100644 --- a/docs/source/guides/index.rst +++ b/docs/source/guides/index.rst @@ -8,3 +8,4 @@ Guides execution management why-use-nemo-run + ray diff --git a/docs/source/guides/ray.md b/docs/source/guides/ray.md new file mode 100644 index 00000000..3be50d0a --- /dev/null +++ b/docs/source/guides/ray.md @@ -0,0 +1,263 @@ +# Ray Clusters & Jobs + +--- + +> **Audience**: You already know how to configure executors with NeMo-Run and want distributed *Ray* on either Kubernetes **or** Slurm. +> +> **TL;DR**: `RayCluster` manages the _cluster_; `RayJob` submits a job with an ephemeral cluster. Everything else is syntactic sugar. + +## RayCluster vs. RayJob – which one do I need? + +| Aspect | RayCluster (interactive) | RayJob (batch) | +|--------|--------------------------|----------------| +| Cluster lifetime | Remains until you call `.stop()` | Ephemeral – cluster auto-deletes after the job finishes | +| Spin-up cost per run | Paid **once** (reuse for many jobs) | Paid **per** submission | +| Multiple jobs on same cluster | Yes | No (one job per submission) | +| Dashboard access | `.port_forward()` to open | Not exposed by default | +| Best for | Debugging, notebooks, iterative dev, hyper-param sweeps that reuse workers | CI/CD pipelines, scheduled training/eval, one-off runs | +| Resource efficiency | Great when you launch many jobs interactively | Great when you just need results & want resources freed asap | + +**Rules of thumb** + +• Pick **RayCluster** when you want a long-lived playground: start it once, poke around with the Ray CLI or a Jupyter notebook, submit multiple Ray Jobs yourself, and tear it down when you're done. + +• Pick **RayJob** when you simply need *"run this script with N GPUs and tell me when you're done"* – the backend spins up a transient cluster, runs the entrypoint, collects logs/status, and cleans everything up automatically. + +## 1. Mental model + +| Object | What it abstracts | Back-ends supported | +|-------------|-------------------|---------------------| +| `run.ray.cluster.RayCluster` | Lifecycle of a Ray **cluster** (create ⇒ wait ⇢ status ⇢ port-forward ⇢ delete). | `KubeRayExecutor`, `SlurmExecutor` | +| `run.ray.job.RayJob` | Lifecycle of a Ray **job** (submit ⇒ monitor ⇢ logs ⇢ cancel). | same | + +The two helpers share a uniform API; the chosen *Executor* decides whether we talk to the **KubeRay** operator (K8s) or a **Slurm** job under the hood. + +```mermaid +classDiagram + RayCluster <|-- KubeRayCluster + RayCluster <|-- SlurmRayCluster + RayJob <|-- KubeRayJob + RayJob <|-- SlurmRayJob +``` + +## 2. KubeRay quick-start + +```python +from nemo_run.core.execution.kuberay import KubeRayExecutor, KubeRayWorkerGroup +from nemo_run.run.ray.cluster import RayCluster +from nemo_run.run.ray.job import RayJob + +# 1) Configure a KubeRay executor (resources + cluster policy) +executor = KubeRayExecutor( + namespace="my-k8s-namespace", + ray_version="2.43.0", + image="anyscale/ray:2.43.0-py312-cu125", + head_cpu="4", + head_memory="12Gi", + worker_groups=[ + KubeRayWorkerGroup( + group_name="worker", # arbitrary string + replicas=2, # two worker pods + gpus_per_worker=8, + ) + ], + # Optional tweaks ---------------------------------------------------- + reuse_volumes_in_worker_groups=True, # mount PVCs on workers too + spec_kwargs={"schedulerName": "runai-scheduler"}, # e.g. Run:ai + volume_mounts=[{"name": "workspace", "mountPath": "/workspace"}], + volumes=[{ + "name": "workspace", + "persistentVolumeClaim": {"claimName": "my-workspace-pvc"}, + }], + env_vars={ + "UV_PROJECT_ENVIRONMENT": "/home/ray/venvs/driver", + "NEMO_RL_VENV_DIR": "/home/ray/venvs", + "HF_HOME": "/workspace/hf_cache", + }, + container_kwargs={ + "securityContext": { + "allowPrivilegeEscalation": False, + "runAsUser": 0, + } + }, +) + +# 2) Commands executed in EVERY Ray container before the daemon starts +pre_ray_start = [ + "pip install uv", + "echo 'unset RAY_RUNTIME_ENV_HOOK' >> /home/ray/.bashrc", +] + +# 3) Spin-up the cluster & expose the dashboard +cluster = RayCluster(name="demo-kuberay-cluster", executor=executor) +cluster.start(timeout=900, pre_ray_start_commands=pre_ray_start) +cluster.port_forward(port=8265, target_port=8265, wait=False) # dashboard → http://localhost:8265 + +# 4) Submit a Ray Job that runs inside that cluster +job = RayJob(name="demo-kuberay-job", executor=executor) +job.start( + command="uv run python examples/train.py --config cfgs/train.yaml", + workdir="/path/to/project/", # synced to PVC automatically + runtime_env_yaml="/path/to/runtime_env.yaml", # optional + pre_ray_start_commands=pre_ray_start, +) +job.follow_logs_until_completion() + +# 5) Clean-up +cluster.stop() +``` + +### Notes +1. `workdir` is rsync'ed into the first declared `volume_mounts` on the executor, so relative imports *just work*. +2. Add `pre_ray_start_commands=["apt-get update && …"]` to inject shell snippets that run inside the **head** and **worker** containers **before** Ray starts. + +## 3. Slurm quick-start + +```python +import os +from pathlib import Path + +import nemo_run as run +from nemo_run.core.execution.slurm import SlurmExecutor, SlurmJobDetails, SSHTunnel +from nemo_run.run.ray.cluster import RayCluster +from nemo_run.run.ray.job import RayJob + +# 1) SSH tunnel to the Slurm login node so we can launch remotely +ssh = SSHTunnel( + host="login.my-hpc.com", # public hostname of login node + user="jdoe", # your cluster username + job_dir="/scratch/jdoe/runs", # where NeMo-Run stores Ray artefacts like logs, code, etc. + identity="~/.ssh/id_ed25519", # optional SSH key +) + +# 2) Create a Slurm executor and tweak defaults +executor = SlurmExecutor( + account="gpu-dept", + partition="a100", + nodes=2, + gpus_per_node=8, + gres="gpu:8", + time="04:00:00", + container_image="nvcr.io/nvidia/pytorch:24.05-py3", + container_mounts=["/scratch:/scratch"], + env_vars={"HF_HOME": "/scratch/hf_cache"}, + tunnel=ssh, +) + +# Optional: customise where Slurm writes stdout/err +class CustomJobDetails(SlurmJobDetails): + @property + def stdout(self) -> Path: # noqa: D401 – illustrative only + assert self.folder + return Path(self.folder) / "sbatch_stdout.out" # Will write sbatch output here. + + @property + def stderr(self) -> Path: # noqa: D401 + assert self.folder + return Path(self.folder) / "sbatch_stderr.err" + +executor.job_details = CustomJobDetails() + +# 3) Commands executed on every node right before Ray starts +pre_ray_start = [ + "pip install uv", +] + +# 4) Bring up the Ray cluster (Slurm array job under the hood) +cluster = RayCluster(name="demo-slurm-ray", executor=executor) +cluster.start(timeout=1800, pre_ray_start_commands=pre_ray_start) +cluster.port_forward(port=8265, target_port=8265) # dashboard → http://localhost:8265 + +# 5) Submit a Ray job that runs inside the cluster +job = RayJob(name="demo-slurm-job", executor=executor) +job.start( + command="uv run python train.py --config cfgs/train.yaml cluster.num_nodes=2", + workdir="/path/to/project/", # rsync'ed via SSH to the cluster_dir/code/ + pre_ray_start_commands=pre_ray_start, +) +job.follow_logs_until_completion() + +# 6) Tear everything down (or just let the wall-time expire) +cluster.stop() +``` + +### Tips for Slurm users +* `executor.packager = run.GitArchivePackager()` if you prefer packaging a git tree instead of rsync. +* `cluster.port_forward()` opens an SSH tunnel from *your laptop* to the Ray dashboard running on the head node. + +## 4. API reference cheat-sheet + +```python +cluster = RayCluster(name, executor) +cluster.start(wait_until_ready=True, timeout=600, pre_ray_start_commands=["pip install -r …"]) +cluster.status(display=True) +cluster.port_forward(port=8265, target_port=8265, wait=False) +cluster.stop() + +job = RayJob(name, executor) +job.start(command, workdir, runtime_env_yaml=None, pre_ray_start_commands=None) +job.status() +job.logs(follow=True) +job.stop() +``` + +All methods are synchronous and **return immediately** when their work is done; the helpers hide the messy details (kubectl, squeue, ssh, …). + +## 5. Rolling your own CLI + +Because `RayCluster` and `RayJob` are plain Python, you can compose them inside **argparse**, **Typer**, **Click** – anything. Here is a minimal **argparse** script: + +```python +import argparse +from nemo_run.core.execution.kuberay import KubeRayExecutor, KubeRayWorkerGroup +from nemo_run.run.ray.cluster import RayCluster +from nemo_run.run.ray.job import RayJob + + +def main() -> None: + parser = argparse.ArgumentParser(description="Submit a Ray job via NeMo-Run") + parser.add_argument("--name", default="demo", help="Base name for cluster + job") + parser.add_argument( + "--image", + default="anyscale/ray:2.43.0-py312-cu125", + help="Ray container image", + ) + parser.add_argument( + "--command", + default="python script.py", + help="Entrypoint to execute inside Ray job", + ) + args = parser.parse_args() + + # 1) Build the executor programmatically + executor = KubeRayExecutor( + namespace="ml-team", + ray_version="2.43.0", + image=args.image, + worker_groups=[KubeRayWorkerGroup(group_name="worker", replicas=1, gpus_per_worker=8)], + ) + + # 2) Spin up a cluster and keep it for the lifetime of the script + cluster = RayCluster(name=f"{args.name}-cluster", executor=executor) + cluster.start() + + # 3) Submit a job against that cluster + job = RayJob(name=f"{args.name}-job", executor=executor) + job.start(command=args.command, workdir="./") + + # 4) Stream logs and block until completion + job.follow_logs_until_completion() + + # 5) Tidy-up + cluster.stop() + + +if __name__ == "__main__": + main() +``` + +From there you can wrap the script with `uvx`, bake it into a Docker image, or integrate it into a larger orchestration system – the underlying NeMo-Run APIs stay the same. + +--- + +Happy distributed hacking! 🚀 diff --git a/nemo_run/core/execution/kuberay.py b/nemo_run/core/execution/kuberay.py index 80df90d5..a3c16e78 100644 --- a/nemo_run/core/execution/kuberay.py +++ b/nemo_run/core/execution/kuberay.py @@ -16,10 +16,17 @@ import copy import logging +import os import re +import subprocess +import time from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple +from kubernetes import client, watch +from kubernetes.client import CoreV1Api +from kubernetes.client.rest import ApiException + from nemo_run.core.execution.base import Executor # Group, Version, Plural @@ -82,6 +89,7 @@ class KubeRayExecutor(Executor): volumes: list[dict[str, Any]] = field(default_factory=list) reuse_volumes_in_worker_groups: bool = True spec_kwargs: dict[str, Any] = field(default_factory=dict) + container_kwargs: dict[str, Any] = field(default_factory=dict) lifecycle_kwargs: dict[str, Any] = field(default_factory=dict) def __post_init__(self): @@ -116,10 +124,12 @@ def get_cluster_body(self, name: str) -> dict[str, Any]: memory_limits=self.head_memory, ray_start_params=self.ray_start_params, head_ports=self.head_ports, + env_vars=self.env_vars, volumes=self.volumes, volume_mounts=self.volume_mounts, spec_kwargs=self.spec_kwargs, lifecycle_kwargs=self.lifecycle_kwargs, + container_kwargs=self.container_kwargs, ) for worker_group in self.worker_groups: cluster = populate_worker_group( @@ -141,6 +151,8 @@ def get_cluster_body(self, name: str) -> dict[str, Any]: annotations=worker_group.annotations, spec_kwargs=self.spec_kwargs, lifecycle_kwargs=self.lifecycle_kwargs, + container_kwargs=self.container_kwargs, + env_vars=self.env_vars, ) return cluster @@ -183,15 +195,17 @@ def populate_ray_head( memory_limits: str, ray_start_params: dict, head_ports: list[dict[str, Any]], + env_vars: dict[str, str], volume_mounts: list[dict[str, Any]], volumes: list[dict[str, Any]], spec_kwargs: dict[str, Any], lifecycle_kwargs: dict[str, Any], + container_kwargs: dict[str, Any], ) -> dict[str, Any]: # make sure metadata exists if "spec" in cluster.keys(): if "headGroupSpec" not in cluster.keys(): - logger.info(f"setting the headGroupSpec for cluster {cluster['metadata']['name']}") + logger.debug(f"setting the headGroupSpec for cluster {cluster['metadata']['name']}") cluster["spec"]["headGroupSpec"] = [] else: logger.error("error creating ray head, the spec and/or metadata is not define") @@ -211,6 +225,7 @@ def populate_ray_head( "image": ray_image, "name": "ray-head", "ports": head_ports, + "env": [{"name": k, "value": v} for k, v in env_vars.items()], "lifecycle": { "preStop": {"exec": {"command": ["/bin/sh", "-c", "ray stop"]}}, **lifecycle_kwargs, @@ -223,6 +238,7 @@ def populate_ray_head( "limits": {"cpu": cpu_limits, "memory": memory_limits}, }, "volumeMounts": volume_mounts, + **container_kwargs, } ], "volumes": volumes, @@ -253,6 +269,8 @@ def populate_worker_group( annotations: dict[str, Any], spec_kwargs: dict[str, Any], lifecycle_kwargs: dict[str, Any], + container_kwargs: dict[str, Any], + env_vars: dict[str, str], ) -> dict[str, Any]: assert is_valid_name(group_name) assert max_replicas >= min_replicas @@ -286,16 +304,18 @@ def populate_worker_group( "containers": [ { "image": ray_image, + "name": "ray-worker", + "env": [{"name": k, "value": v} for k, v in env_vars.items()], "lifecycle": { "preStop": {"exec": {"command": ["/bin/sh", "-c", "ray stop"]}}, **lifecycle_kwargs, }, - "name": "ray-worker", "resources": { "requests": resource_requests, "limits": resource_limits, }, "volumeMounts": volume_mounts, + **container_kwargs, } ], "volumes": volumes, @@ -435,3 +455,142 @@ def is_valid_label(name: str) -> bool: logger.error(msg) return False return True + + +def sync_workdir_via_pod( + *, + name: str, + namespace: str, + workdir: str, + core_v1_api: CoreV1Api, + volumes: List[dict[str, object]], + volume_mounts: List[dict[str, object]], + workspace_path: str = "/workspace", + image: str = "alpine:3.19", + cleanup: bool = False, + cleanup_timeout: int = 5, +) -> None: + """Spin up a throw-away Pod that mounts the same volumes as the Ray + cluster and streams *workdir* into *workspace_path* inside the mount. + + The function blocks until the copy is complete and the Pod is removed. + Requires that the *kubectl* binary is available in PATH and can access + the same cluster context as the Kubernetes Python client. + """ + + pod_name = f"{name}-dm" + + # Pod manifest + pod_body = client.V1Pod( + metadata=client.V1ObjectMeta(name=pod_name, namespace=namespace), + spec=client.V1PodSpec( + restart_policy="Never", + containers=[ + client.V1Container( + name="mover", + image=image, + command=["sh", "-c", "sleep infinity"], + volume_mounts=volume_mounts, + lifecycle={ + "postStart": { + "exec": { + "command": [ + "sh", + "-c", + # Install rsync on first container start if missing (Alpine) + "command -v rsync >/dev/null 2>&1 || apk add --no-cache rsync", + ] + } + } + }, + ) + ], + volumes=volumes, + ), + ) + + # Create Pod (idempotent – reuse if already exists) + logger.info( + f"Creating data-mover pod '{pod_name}' in namespace '{namespace}' (or re-using if present)" + ) + try: + core_v1_api.create_namespaced_pod(namespace=namespace, body=pod_body) + except ApiException as e: + if e.status == 409: # AlreadyExists + logger.info(f"Data-mover pod '{pod_name}' already exists – will reuse it") + else: + raise + + # Wait until pod is Running + w = watch.Watch() + for event in w.stream( + core_v1_api.list_namespaced_pod, + namespace=namespace, + field_selector=f"metadata.name={pod_name}", + timeout_seconds=120, + ): + pod_obj: client.V1Pod = event.get("object") # type: ignore[assignment] + phase = pod_obj.status.phase if pod_obj.status else None + if phase == "Running": + w.stop() + break + else: + raise RuntimeError("Data-mover pod did not reach Running state in time") + + # Ensure workspace dir exists + subprocess.check_call( + [ + "kubectl", + "exec", + "-n", + namespace, + pod_name, + "--", + "mkdir", + "-p", + workspace_path, + ] + ) + + # Use rsync over kubectl exec + rsync_cmd: list[str] = [ + "rsync", + "-az", + "--delete", + ] + + # Respect .gitignore rules if present in the workdir + if os.path.isfile(os.path.join(workdir, ".gitignore")): + rsync_cmd.extend(["--filter=:- .gitignore"]) + + # Tell rsync to reach the remote side via kubectl exec + rsync_cmd.extend( + [ + "-e", + f"kubectl exec -i -n {namespace} {pod_name}", + "--", # Marks end-of-options for rsync – mandatory when the dest starts with "--:" + f"{os.path.abspath(workdir).rstrip(os.sep)}/", + f"--:{workspace_path.rstrip('/')}/", + ] + ) + + # Emit the full command for easier troubleshooting + logger.debug("Running rsync command: %s", " ".join(rsync_cmd)) + + subprocess.check_call(rsync_cmd) + + if cleanup: + logger.info("Workdir synced to PVC via data-mover pod. Cleaning up…") + core_v1_api.delete_namespaced_pod( + name=pod_name, namespace=namespace, body=client.V1DeleteOptions() + ) + + # Wait for termination + timeout = time.time() + cleanup_timeout + while time.time() < timeout: + try: + core_v1_api.read_namespaced_pod(name=pod_name, namespace=namespace) + except ApiException as e: + if e.status == 404: + break + time.sleep(2) diff --git a/nemo_run/run/ray/cluster.py b/nemo_run/run/ray/cluster.py index 9bb569b2..e955af7a 100644 --- a/nemo_run/run/ray/cluster.py +++ b/nemo_run/run/ray/cluster.py @@ -34,36 +34,33 @@ class RayCluster: name: str executor: Executor - pre_ray_start_commands: Optional[list[str]] = None def __post_init__(self): if self.executor.__class__ not in self.BACKEND_MAP: raise ValueError(f"Unsupported executor: {self.executor.__class__}") - self.backend = self.BACKEND_MAP[self.executor.__class__]() + backend_cls = self.BACKEND_MAP[self.executor.__class__] + self.backend = backend_cls(name=self.name, executor=self.executor) # type: ignore[arg-type] self._port_forward_map = {} - def start(self, wait_until_ready: bool = True, timeout: int = 1000, dryrun: bool = False): + def start( + self, + wait_until_ready: bool = True, + timeout: int = 1000, + dryrun: bool = False, + pre_ray_start_commands: Optional[list[str]] = None, + ): assert isinstance(self.executor, self.backend.EXECUTOR_CLS) - self.backend.create_ray_cluster( - name=self.name, - executor=self.executor, - pre_ray_start_commands=self.pre_ray_start_commands, + self.backend.create( + pre_ray_start_commands=pre_ray_start_commands, dryrun=dryrun, ) - if wait_until_ready: - self.backend.wait_until_ray_cluster_running( - name=self.name, executor=self.executor, timeout=timeout - ) + if wait_until_ready and not dryrun: + self.backend.wait_until_running(timeout=timeout) - def schedule_job( - self, name: str, executor: Executor, command: str, workdir: str, dryrun: bool = False - ): - assert isinstance(self.executor, self.backend.EXECUTOR_CLS) - self.backend.schedule_ray_job( - name=name, executor=executor, command=command, workdir=workdir, dryrun=dryrun - ) + def status(self, display: bool = True): + return self.backend.status(display=display) # type: ignore[attr-defined] def port_forward(self, port: int = 8265, target_port: int = 8265, wait: bool = False): assert isinstance(self.executor, self.backend.EXECUTOR_CLS) @@ -71,8 +68,6 @@ def port_forward(self, port: int = 8265, target_port: int = 8265, wait: bool = F self._port_forward_map[port].stop_forwarding() self._port_forward_map[port] = self.backend.port_forward( - name=self.name, - executor=self.executor, port=port, target_port=target_port, wait=wait, @@ -83,4 +78,4 @@ def stop(self): for port_forward in self._port_forward_map.values(): port_forward.stop_forwarding() - self.backend.delete_ray_cluster(name=self.name, executor=self.executor, wait=True) + self.backend.delete(wait=True) diff --git a/nemo_run/run/ray/job.py b/nemo_run/run/ray/job.py new file mode 100644 index 00000000..b0ed2548 --- /dev/null +++ b/nemo_run/run/ray/job.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Optional + +from nemo_run.core.execution.base import Executor +from nemo_run.core.execution.kuberay import KubeRayExecutor +from nemo_run.core.execution.slurm import SlurmExecutor +from nemo_run.run.ray.kuberay import KubeRayJob +from nemo_run.run.ray.slurm import SlurmRayJob + + +@dataclass(kw_only=True) +class RayJob: + """Backend-agnostic convenience wrapper around Ray *jobs*.""" + + BACKEND_MAP = { + KubeRayExecutor: KubeRayJob, + SlurmExecutor: SlurmRayJob, + } + + name: str + executor: Executor + pre_ray_start_commands: Optional[list[str]] = None + + def __post_init__(self) -> None: # noqa: D401 – simple implementation + if self.executor.__class__ not in self.BACKEND_MAP: + raise ValueError(f"Unsupported executor: {self.executor.__class__}") + + self.backend = self.BACKEND_MAP[self.executor.__class__]( + name=self.name, executor=self.executor + ) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def start( + self, + command: str, + workdir: str, + runtime_env_yaml: Optional[str] | None = None, + pre_ray_start_commands: Optional[list[str]] = None, + dryrun: bool = False, + ) -> Any: + """Submit a Ray job and return a live helper (backend specific). + + The *pre_ray_start_commands* provided at construction time are forwarded + to the backend implementation so callers can inject arbitrary shell + commands that run inside the Ray *head* container right before the + cluster starts. + """ + self.backend.start( # type: ignore[attr-defined] + command=command, + workdir=workdir, + runtime_env_yaml=runtime_env_yaml, + pre_ray_start_commands=pre_ray_start_commands, + dryrun=dryrun, + ) + + def stop(self) -> None: + self.backend.stop() # type: ignore[attr-defined] + + def status(self, display: bool = True): + return self.backend.status(display=display) # type: ignore[attr-defined] + + def logs(self, *, follow: bool = False, lines: int = 100, timeout: int = 100): + self.backend.logs(follow=follow, lines=lines, timeout=timeout) # type: ignore[attr-defined] diff --git a/nemo_run/run/ray/kuberay.py b/nemo_run/run/ray/kuberay.py index 514bb889..d3947c9a 100644 --- a/nemo_run/run/ray/kuberay.py +++ b/nemo_run/run/ray/kuberay.py @@ -15,8 +15,12 @@ # Based on https://github.com/ray-project/kuberay/blob/master/clients/python-client/python_client/kuberay_cluster_api.py import logging +import os +import subprocess import time -from typing import Any, Optional +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Optional import yaml from kubernetes import client, config @@ -27,45 +31,51 @@ logger = logging.getLogger(__name__) +@dataclass(kw_only=True) class KubeRayCluster: + """Lightweight helper around the KubeRay operator RayCluster CRD lifecycle. + + The class mirrors :class:`SlurmRayCluster` and :class:`KubeRayJob` – it's + a lightweight dataclass holding just two identifying fields (*name* and + *executor*). All Kubernetes clients are instantiated lazily in + :py:meth:`__post_init__`. + """ + + # ------------------------------------------------------------------ + # Class-level constants / type bindings + # ------------------------------------------------------------------ EXECUTOR_CLS = KubeRayExecutor - # initial config to setup the kube client - def __init__(self): - # loading the config - self.kube_config: Optional[Any] = config.load_kube_config() + # ------------------------------------------------------------------ + # Primary identifiers (mirrors SlurmRayCluster API) + # ------------------------------------------------------------------ + name: str + executor: KubeRayExecutor + + # ------------------------------------------------------------------ + # Dataclass lifecycle hooks + # ------------------------------------------------------------------ + def __post_init__(self) -> None: # noqa: D401 – simple verb is fine + """Initialise Kubernetes API clients once the instance is created.""" + # Load local kube-config once; the function returns *None* so we don't store it. + config.load_kube_config() + + # The dedicated clients are what we interact with throughout the class + # – separating CoreV1 for pods/services from CustomObjects for CRDs. self.api = client.CustomObjectsApi() self.core_v1_api = client.CoreV1Api() - def list_ray_clusters( - self, k8s_namespace: str = "default", label_selector: str = "", async_req: bool = False + def _get( + self, + name: Optional[str] = None, + k8s_namespace: Optional[str] = None, ) -> Any: - logger.info( - f"Listing Ray clusters in namespace: {k8s_namespace}, label_selector: {label_selector}, async_req: {async_req}" - ) + # Return the RayCluster custom object, if present. - try: - resource: Any = self.api.list_namespaced_custom_object( - group=GROUP, - version=VERSION, - plural=PLURAL, - namespace=k8s_namespace, - label_selector=label_selector, - async_req=async_req, - ) - if "items" in resource: - return resource - return None - except ApiException as e: - if e.status == 404: - logger.error("raycluster resource is not found. error = {}".format(e)) - return None - else: - logger.error("error fetching custom resource: {}".format(e)) - return None + name = name or self.name + namespace = k8s_namespace or self.executor.namespace or "default" - def get_ray_cluster(self, name: str, k8s_namespace: str = "default") -> Any: - logger.info(f"Getting Ray cluster '{name}' in namespace '{k8s_namespace}'") + logger.debug(f"Getting Ray cluster '{name}' in namespace '{namespace}'") try: resource: Any = self.api.get_namespaced_custom_object( @@ -73,124 +83,135 @@ def get_ray_cluster(self, name: str, k8s_namespace: str = "default") -> Any: version=VERSION, plural=PLURAL, name=name, - namespace=k8s_namespace, + namespace=namespace, ) return resource except ApiException as e: if e.status == 404: - logger.error(f"Ray cluster '{name}' not found in namespace '{k8s_namespace}': {e}") + logger.error(f"Ray cluster '{name}' not found in namespace '{namespace}': {e}") return None else: - logger.error( - f"Error fetching Ray cluster '{name}' in namespace '{k8s_namespace}': {e}" - ) + logger.error(f"Error fetching Ray cluster '{name}' in namespace '{namespace}': {e}") return None - def get_ray_cluster_status( + def status( self, - name: str, - k8s_namespace: str = "default", timeout: int = 60, delay_between_attempts: int = 5, + *, + display: bool = False, ) -> Any: + """Return the ``status`` stanza of the RayCluster CR (blocking). + + Polls until the CR contains a *status* field or *timeout* is reached. + """ + + namespace = self.executor.namespace or "default" + name = self.name + logger.info( - f"Getting Ray cluster status for '{name}' in namespace '{k8s_namespace}', timeout: {timeout}s, delay: {delay_between_attempts}s" + f"Getting Ray cluster status for '{name}' in namespace '{namespace}', " + f"timeout: {timeout}s, delay: {delay_between_attempts}s" ) - while timeout > 0: + remaining = timeout + while remaining > 0: try: resource: Any = self.api.get_namespaced_custom_object_status( group=GROUP, version=VERSION, plural=PLURAL, name=name, - namespace=k8s_namespace, + namespace=namespace, ) except ApiException as e: if e.status == 404: - logger.error( + logger.debug( f"Ray cluster '{name}' status fetch failed: resource not found: {e}" ) return None - else: - logger.error( - f"Error fetching status for Ray cluster '{name}' in namespace '{k8s_namespace}': {e}" - ) - return None + logger.error( + f"Error fetching status for Ray cluster '{name}' in namespace '{namespace}': {e}" + ) + return None - if "status" in resource and resource["status"]: - return resource["status"] - else: - logger.info(f"Ray cluster '{name}' status not set yet, waiting...") - time.sleep(delay_between_attempts) - timeout -= delay_between_attempts + if resource.get("status"): + status_dict = resource["status"] + if display: + self._display_banner(status_dict) + return status_dict + + logger.debug(f"Ray cluster '{name}' status not set yet, waiting...") + time.sleep(delay_between_attempts) + remaining -= delay_between_attempts - logger.info(f"Ray cluster '{name}' status not set yet, timing out...") + logger.debug(f"Ray cluster '{name}' status not set yet, timing out...") return None - def wait_until_ray_cluster_running( + def wait_until_running( self, - name: str, - executor: KubeRayExecutor, timeout: int = 60, delay_between_attempts: int = 5, - k8s_namespace: Optional[str] = None, ) -> bool: - namespace = k8s_namespace or executor.namespace + """Block until the Ray head service has a reachable IP (or timeout).""" + + namespace = self.executor.namespace or "default" + name = self.name + logger.info( - f"Waiting until Ray cluster '{name}' is running in namespace '{namespace}', timeout: {timeout}s, delay: {delay_between_attempts}s" + f"Waiting until Ray cluster '{name}' is running in namespace '{namespace}', " + f"timeout: {timeout}s, delay: {delay_between_attempts}s" ) - while timeout > 0: - status = self.get_ray_cluster_status( - name, k8s_namespace or executor.namespace, timeout, delay_between_attempts - ) + remaining = timeout + while remaining > 0: + poll_window = min(delay_between_attempts, remaining) + status = self.status(poll_window, poll_window, display=False) if not status: logger.info(f"Ray cluster '{name}' status could not be retrieved") return False - # TODO: once we add State to Status, we should check for that as well - if status and status["head"] and status["head"]["serviceIP"]: + # TODO: once the operator exposes a proper .state field, use that + # For now we infer readiness from the presence of head.serviceIP + if status.get("head", {}).get("serviceIP"): logger.info(f"Ray cluster '{name}' is running") return True - logger.info( + logger.debug( f"Ray cluster '{name}' status is not running yet, current status: {status.get('state', 'unknown')}" ) - time.sleep(delay_between_attempts) - timeout -= delay_between_attempts + remaining -= poll_window - logger.info(f"Ray cluster '{name}' status is not running yet, timing out...") + logger.debug(f"Ray cluster '{name}' status is not running yet, timing out...") return False - def create_ray_cluster( + def create( self, - name: str, - executor: KubeRayExecutor, pre_ray_start_commands: Optional[list[str]] = None, dryrun: bool = False, - k8s_namespace: Optional[str] = None, ) -> Any: - namespace = k8s_namespace or executor.namespace + """Create the RayCluster CR (idempotent).""" + + namespace = self.executor.namespace or "default" + name = self.name + logger.info(f"Creating Ray cluster '{name}' in namespace '{namespace}'") + # Ensure lifecycle_kwargs dict exists (older executor versions may omit it) + if not hasattr(self.executor, "lifecycle_kwargs") or self.executor.lifecycle_kwargs is None: + self.executor.lifecycle_kwargs = {} + if pre_ray_start_commands: k8s_pre_ray_start_commands = "\n".join(pre_ray_start_commands) - executor.lifecycle_kwargs["postStart"] = { - "exec": { - "command": [ - "/bin/sh", - "-c", - k8s_pre_ray_start_commands, - ] - } + self.executor.lifecycle_kwargs["postStart"] = { + "exec": {"command": ["/bin/sh", "-c", k8s_pre_ray_start_commands]} } - body = executor.get_cluster_body(name) + body = self.executor.get_cluster_body(name) if dryrun: print(yaml.dump(body)) - return + return body try: resource: Any = self.api.create_namespaced_custom_object( @@ -198,36 +219,46 @@ def create_ray_cluster( version=VERSION, plural=PLURAL, body=body, - namespace=k8s_namespace or executor.namespace, + namespace=namespace, ) return resource except ApiException as e: if e.status == 409: logger.error(f"Ray cluster '{name}' already exists: {e.reason}") return None - else: - logger.error(f"Error creating Ray cluster '{name}' in namespace '{namespace}': {e}") - return None - - def schedule_ray_job( - self, - name: str, - executor: KubeRayExecutor, - command: str, - workdir: str, - ): - raise NotImplementedError("KubeRay does not support scheduling jobs") + logger.error(f"Error creating Ray cluster '{name}' in namespace '{namespace}': {e}") + return None - def delete_ray_cluster( + def delete( self, - name: str, - executor: KubeRayExecutor, - k8s_namespace: Optional[str] = None, wait: bool = False, timeout: int = 300, poll_interval: int = 5, ) -> Optional[bool]: - namespace = k8s_namespace or executor.namespace + """Delete the RayCluster CR and, optionally, wait for full teardown. + + Parameters + ---------- + wait : bool, default False + When *True*, block until the RayCluster CR and all its pods have + disappeared. A best-effort poll is performed every + *poll_interval* seconds up to *timeout* seconds. + timeout : int, default 300 + Maximum time in seconds to wait for deletion when *wait* is + enabled. + poll_interval : int, default 5 + Interval between successive status checks. + + Returns + ------- + bool | None + • *True* – deletion confirmed. + • *False* – timed out while waiting. + • *None* – cluster already absent before the call. + """ + namespace = self.executor.namespace or "default" + name = self.name + logger.info(f"Deleting Ray cluster '{name}' in namespace '{namespace}'") try: @@ -242,7 +273,7 @@ def delete_ray_cluster( if not wait: return True - logger.info(f"Waiting for Ray cluster '{name}' and its pods to be fully deleted...") + logger.debug(f"Waiting for Ray cluster '{name}' and its pods to be fully deleted...") start_time = time.time() cluster_deleted = False @@ -251,7 +282,7 @@ def delete_ray_cluster( # Check if CR still exists if not cluster_deleted: try: - cluster = self.get_ray_cluster(name, namespace) + cluster = self._get(name=name, k8s_namespace=namespace) if not cluster: logger.info(f"Ray cluster CR '{name}' has been deleted") cluster_deleted = True @@ -275,7 +306,7 @@ def delete_ray_cluster( return True active_pods = [pod.metadata.name for pod in pods.items] - logger.info( + logger.debug( f"Waiting for {len(active_pods)} pods to terminate: {', '.join(active_pods[:3])}" + ( f"... and {len(active_pods) - 3} more" @@ -297,7 +328,7 @@ def delete_ray_cluster( # Check final state try: - cluster_exists = self.get_ray_cluster(name, namespace) is not None + cluster_exists = self._get(name=name, k8s_namespace=namespace) is not None if cluster_exists: logger.warning(f"Ray cluster CR '{name}' still exists after timeout") @@ -322,14 +353,27 @@ def delete_ray_cluster( logger.error(f"Error deleting Ray cluster '{name}': {e}") return None - def patch_ray_cluster( + def patch( self, - name: str, ray_patch: Any, - executor: KubeRayExecutor, - k8s_namespace: Optional[str] = None, ) -> Any: - namespace = k8s_namespace or executor.namespace + """Patch the RayCluster custom resource with a user-supplied body. + + The patch is applied using the Kubernetes *merge* strategy, mirroring + ``kubectl patch --type=merge``. + + Parameters + ---------- + ray_patch : Any + A JSON-serialisable object representing the patch to apply. + + Returns + ------- + bool + *True* on success, *False* if the API call raised an exception. + """ + namespace = self.executor.namespace or "default" + name = self.name logger.info(f"Patching Ray cluster '{name}' in namespace '{namespace}'") try: # we patch the existing raycluster with the new config @@ -351,19 +395,37 @@ def patch_ray_cluster( def port_forward( self, - name: str, port: int, target_port: int, - executor: KubeRayExecutor, wait: bool = False, ): + """Expose the Ray head service locally via *kubectl port-forward*. + + Parameters + ---------- + port : int + Local port on which to listen. + target_port : int + Port number of the Ray head service inside the cluster. + wait : bool, default False + If *True*, block until the user terminates the process (SIGINT or + SIGTERM). Otherwise a daemon thread is returned immediately. + + Returns + ------- + threading.Thread + The daemon thread encapsulating the port-forwarding subprocess. + """ import queue import subprocess import threading import time + name = self.name + executor = self.executor + # Get cluster details - cluster = self.get_ray_cluster(name, executor.namespace or "default") + cluster = self._get(name=name, k8s_namespace=executor.namespace or "default") if not cluster: raise RuntimeError(f"Could not find Ray cluster {name}") @@ -393,11 +455,11 @@ def __init__(self, target, daemon): self._stop_event = stop_event def stop_forwarding(self): - logger.info("Stopping port forwarding") + logger.debug("Stopping port forwarding") self._stop_event.set() def forward_port_daemon(): - logger.info( + logger.debug( f"Starting port forwarding from localhost:{port} to service {service_name}:{target_port} in namespace {namespace}" ) @@ -417,10 +479,13 @@ def forward_port_daemon(): namespace, ] - logger.info(f"Running command: {' '.join(cmd)}") + logger.debug(f"Running command: {' '.join(cmd)}") process = subprocess.Popen( - cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True + cmd, + stdout=subprocess.DEVNULL, # avoid dead-lock on unread STDOUT + stderr=subprocess.PIPE, + text=True, ) # Signal success to the main thread after short wait to ensure it started @@ -429,7 +494,7 @@ def forward_port_daemon(): if not connection_established: connection_established = True status_queue.put(("success", None)) - logger.info("Port forwarding connection established") + logger.debug("Port forwarding connection established") # Wait for the process to complete or be killed while not stop_event.is_set() and process.poll() is None: @@ -449,7 +514,6 @@ def forward_port_daemon(): # If process exited with error, log it if process.returncode != 0: - # Safe way to read stderr that handles None case stderr_output = "" if process.stderr: stderr_output = process.stderr.read() or "" @@ -459,7 +523,7 @@ def forward_port_daemon(): ) # If we get here, the connection was closed unexpectedly - logger.info( + logger.debug( "Port forwarding connection closed, reconnecting in 5 seconds..." ) time.sleep(5) @@ -526,7 +590,7 @@ def _wait_for_forwarding_termination(self, forward_thread, stop_event): original_sigterm_handler = signal.getsignal(signal.SIGTERM) def signal_handler(sig, frame): - logger.info("Received signal to stop port forwarding") + logger.info("Received signal to stop port forwarding.") stop_event.set() # Restore original signal handlers @@ -551,5 +615,345 @@ def signal_handler(sig, frame): stop_event.set() # Wait for the thread to finish + logger.info("Waiting for port forwarding thread to finish for 5 seconds...") forward_thread.join(timeout=5) logger.info("Port forwarding stopped") + + # Helper to print banner + def _display_banner(self, status_dict: Any) -> None: + namespace = self.executor.namespace or "default" + logger.info( + f"""\n\n\033[1;34mRay cluster status (KubeRay) in namespace {namespace}:\033[0m + • \033[1mName\033[0m : {self.name} + • \033[1mState\033[0m : {status_dict.get("state", "UNKNOWN") if isinstance(status_dict, dict) else "UNKNOWN"} + • \033[1mHead svc IP\033[0m: {status_dict.get("head", {}).get("serviceIP") if isinstance(status_dict, dict) else "N/A"} + (use `kubectl get rayclusters {self.name} -n {namespace}` to inspect, `kubectl delete rayclusters {self.name} -n {namespace}` to delete)\n""" + ) + + +@dataclass(kw_only=True) +class KubeRayJob: + """Helper object for interacting with a KubeRay RayJob. + + Parameters + ---------- + name : str + Name of the RayJob custom resource. + namespace : str + Kubernetes namespace in which the job was created. + """ + + name: str + executor: KubeRayExecutor + + def __post_init__(self): + # Lazily create K8s API clients if not supplied + self.api = client.CustomObjectsApi() + self.core_v1_api = client.CoreV1Api() + # Ensure backward-compat: if cluster is None we still work (stand-alone usage) + + # ------------------------------------------------------------------ + # Public helpers mirroring SlurmRayJob API for downstream symmetry. + # ------------------------------------------------------------------ + + def stop(self) -> None: + """Delete the RayJob custom resource (equivalent to job cancellation).""" + logger.debug(f"Cancelling RayJob '{self.name}' in namespace '{self.executor.namespace}'") + try: + self.api.delete_namespaced_custom_object( + group="ray.io", + version="v1", + plural="rayjobs", + name=self.name, + namespace=self.executor.namespace, + ) + logger.debug(f"RayJob '{self.name}' cancellation requested (CR deleted)") + except ApiException as e: + if e.status == 404: + logger.warning(f"RayJob '{self.name}' not found – maybe already deleted") + else: + logger.error(f"Failed to cancel RayJob '{self.name}': {e}") + + def logs(self, follow: bool = False, lines: int = 100) -> None: + """Stream or show logs from the RayJob submitter pod. + + This simply shells out to ``kubectl logs -l job-name=`` which + is how the Ray docs recommend fetching RayJob logs. + """ + + cmd = [ + "kubectl", + "logs", + "-l", + f"job-name={self.name}", + "-n", + self.executor.namespace, + ] + + if follow: + cmd.append("-f") + else: + cmd.extend(["--tail", str(lines)]) + + logger.info( + f"Running: {' '.join(cmd)} (streaming={'yes' if follow else 'no'}, tail={lines})" + ) + + try: + if follow: + subprocess.run(cmd, check=False) + else: + output = subprocess.check_output(cmd, text=True) + print(output) + except FileNotFoundError: + logger.error("kubectl not found in PATH – cannot fetch logs") + except subprocess.CalledProcessError as e: + logger.error(f"kubectl logs returned non-zero exit status {e.returncode}") + + def status(self, display: bool = True) -> Dict[str, Any]: + """Return current RayJob status as a lightweight dict and pretty-print it.""" + + try: + resource = self.api.get_namespaced_custom_object( + group="ray.io", + version="v1", + plural="rayjobs", + name=self.name, + namespace=self.executor.namespace, + ) + except ApiException as e: + logger.error(f"Failed to fetch status for RayJob '{self.name}': {e}") + return {"jobStatus": "ERROR", "jobDeploymentStatus": "ERROR"} + + status = resource.get("status", {}) if isinstance(resource, dict) else {} + job_status = status.get("jobStatus", "UNKNOWN") + deployment_status = status.get("jobDeploymentStatus", "UNKNOWN") + + if display: + logger.info( + f"""\n\n\033[1;34mRay Job status for KubeRay cluster in namespace {self.executor.namespace}:\033[0m + • \033[1mName\033[0m : {self.name} + • \033[1mJob status\033[0m : {job_status} + • \033[1mDeployment\033[0m : {deployment_status} + (use `kubectl logs -l job-name={self.name} -n {self.executor.namespace} -f` to view logs)\n""" + ) + + return {"jobStatus": job_status, "jobDeploymentStatus": deployment_status} + + # ------------------------------------------------------------------ + # Convenience: tail logs asynchronously while waiting for completion, + # then optionally delete the RayJob CR once finished. + # ------------------------------------------------------------------ + + def follow_logs_until_completion( + self, + poll_interval: int = 10, + delete_on_finish: bool = True, + ) -> None: + """Stream job logs in real-time and clean up when the RayJob ends. + + This helper starts a background thread running ``kubectl logs -f`` + while the main thread polls the RayJob status every *poll_interval* + seconds. As soon as the job transitions to a terminal state + (SUCCEEDED/FAILED or Deployment Complete/Failed) the log thread is + joined and – if *delete_on_finish* is *True* – the RayJob CR is + deleted. + """ + + # ------------------------------------------------------------------ + # 1) Poll until the RayJob is actually running – only then start logs + # ------------------------------------------------------------------ + + RUNNING_DEPLOY_STATUS = "Running" + + while True: + st = self.status(display=True) + if st.get("jobDeploymentStatus") == RUNNING_DEPLOY_STATUS: + break + + # If job already finished/failed before reaching Running, bail out + if st.get("jobDeploymentStatus") in {"Complete", "Failed"}: + if delete_on_finish: + self.stop() + return + time.sleep(poll_interval) + + # ------------------------------------------------------------------ + # 2) Start log streaming in a daemon thread + # ------------------------------------------------------------------ + + def _tail(): + try: + self.logs(follow=True) + except Exception as e: # pragma: no cover – logging only + logger.error(f"Log tailing thread encountered an error: {e}") + + import threading + + log_thread = threading.Thread(target=_tail, daemon=True) + log_thread.start() + + # ------------------------------------------------------------------ + # 3) Poll until RayJob ends, then cleanup + # ------------------------------------------------------------------ + + TERMINAL_JOB_STATUSES = {"SUCCEEDED", "FAILED"} + TERMINAL_DEPLOY_STATUSES = {"Complete", "Failed"} + + try: + while True: + status = self.status(display=False) + if ( + status.get("jobStatus") in TERMINAL_JOB_STATUSES + or status.get("jobDeploymentStatus") in TERMINAL_DEPLOY_STATUSES + ): + break + time.sleep(poll_interval) + finally: + log_thread.join(timeout=5) + + if delete_on_finish: + try: + self.stop() + except Exception as e: # pragma: no cover + logger.debug(f"Ignoring error during job cleanup: {e}") + + def start( + self, + command: str, + workdir: str | None = None, + runtime_env_yaml: str | None = None, + pre_ray_start_commands: Optional[list[str]] = None, + dryrun: bool = False, + ): + """Create a RayJob CR via the KubeRay operator and return a live helper. + + This is a front-door convenience wrapper around + :py:meth:`KubeRayCluster.schedule_ray_job` so users can directly do:: + + KubeRayJob.start( + name="my-job", + executor=my_kuberay_executor, + command="python train.py", + workdir="./src", + ) + """ + # We directly replicate the logic previously living in + # `KubeRayCluster.schedule_ray_job` so that callers interact solely with + # *job* helpers, keeping cluster classes focused on cluster lifecycle + # only. + + # ------------------------------------------------------------------ + # 1. Handle optional *workdir* sync (data-mover pod). + # ------------------------------------------------------------------ + from nemo_run.core.execution.kuberay import sync_workdir_via_pod + + name = self.name + executor = self.executor + namespace = executor.namespace + + # Ensure lifecycle_kwargs dict exists on executor + if not hasattr(executor, "lifecycle_kwargs") or executor.lifecycle_kwargs is None: + executor.lifecycle_kwargs = {} + + if pre_ray_start_commands: + k8s_pre_cmds = "\n".join(pre_ray_start_commands) + executor.lifecycle_kwargs["postStart"] = { + "exec": {"command": ["/bin/sh", "-c", k8s_pre_cmds]} + } + + if workdir: + if not executor.volumes or not executor.volume_mounts: + raise ValueError( + "`workdir` specified but executor has no volumes/volume_mounts to mount it." + ) + + workspace_path = os.path.join( + executor.volume_mounts[0]["mountPath"], Path(workdir).name + ) + + if not dryrun: + sync_workdir_via_pod( + name=name, + namespace=namespace, + workdir=workdir, + core_v1_api=self.core_v1_api, + volumes=executor.volumes, + volume_mounts=executor.volume_mounts, + workspace_path=workspace_path, + ) + + # In-place patch of executor.lifecycle_kwargs with *postStart* if needed + if pre_ray_start_commands: + executor.lifecycle_kwargs["postStart"] = { + "exec": {"command": ["/bin/sh", "-c", "\n".join(pre_ray_start_commands)]} + } + + # ------------------------------------------------------------------ + # 2. Build RayCluster spec (via executor). + # ------------------------------------------------------------------ + cluster_name = f"{name}-raycluster" + ray_cluster_body = executor.get_cluster_body(cluster_name) + ray_cluster_spec = ray_cluster_body.get("spec", {}) + + # Ensure consistent workingDir inside all Ray containers so that relative + # paths in `ray job submit` resolve as expected. + container_workdir = "/workspace" + if workdir: + container_workdir = os.path.join( + executor.volume_mounts[0]["mountPath"], Path(workdir).name + ) + + def _apply_workdir(pod_template: dict): + try: + for c in pod_template["spec"]["containers"]: + c["workingDir"] = container_workdir + except Exception: + pass # ignore malformed specs + + if "headGroupSpec" in ray_cluster_spec: + _apply_workdir(ray_cluster_spec["headGroupSpec"]["template"]) + + for w in ray_cluster_spec.get("workerGroupSpecs", []): + _apply_workdir(w["template"]) # type: ignore[arg-type] + + # ------------------------------------------------------------------ + # 3. Assemble RayJob CRD manifest + # ------------------------------------------------------------------ + if runtime_env_yaml and os.path.isfile(Path(runtime_env_yaml)): + with open(runtime_env_yaml, "r") as f: + runtime_env_yaml = f.read() + + rayjob_body = { + "apiVersion": "ray.io/v1", + "kind": "RayJob", + "metadata": { + "name": name, + "namespace": namespace, + }, + "spec": { + "entrypoint": command, + "shutdownAfterJobFinishes": True, + "rayClusterSpec": ray_cluster_spec, + "runtimeEnvYAML": runtime_env_yaml, + }, + } + + if dryrun: + print(yaml.dump(rayjob_body)) + return rayjob_body + + # Create the RayJob CR via Kubernetes API + try: + self.api.create_namespaced_custom_object( + group="ray.io", + version="v1", + plural="rayjobs", + body=rayjob_body, + namespace=namespace, + ) + self.status() + except ApiException as e: + if e.status == 409: + raise RuntimeError(f"RayJob '{name}' already exists: {e.reason}") + raise RuntimeError(f"Error creating RayJob '{name}': {e}") diff --git a/nemo_run/run/ray/slurm.py b/nemo_run/run/ray/slurm.py index be87a8b2..c0c98624 100644 --- a/nemo_run/run/ray/slurm.py +++ b/nemo_run/run/ray/slurm.py @@ -27,7 +27,7 @@ import warnings from dataclasses import asdict, dataclass from pathlib import Path -from typing import Any, Dict, Optional, TypeAlias, Union +from typing import Any, Optional, TypeAlias from nemo_run.config import RUNDIR_NAME, RUNDIR_SPECIAL_NAME from nemo_run.core.execution.slurm import SlurmExecutor, _as_sbatch_flag @@ -40,6 +40,71 @@ logger = logging.getLogger(__name__) +# ----------------------------------------------------------------------------- +# Shared helper: cancel a Slurm job (used by SlurmRayCluster & SlurmRayJob) +# ----------------------------------------------------------------------------- + + +def cancel_slurm_job( + executor: SlurmExecutor, + name: str, + job_id: int | str, + *, + wait: bool = False, + timeout: int = 60, + poll_interval: int = 5, +) -> bool: + """Cancel a Slurm *job_id* and optionally wait until it terminates.""" + + executor.tunnel.connect() + logger.info(f"Cancelling Slurm job {job_id} for '{name}'") + + try: + executor.tunnel.run(f"scancel {job_id}") + except Exception as e: + logger.error(f"Failed to cancel job {job_id} for '{name}': {e}") + return False + + if not wait: + return True + + start_ts = time.time() + while time.time() - start_ts < timeout: + res = executor.tunnel.run(f"squeue -j {job_id} -h -o %T", warn=True) + state = res.stdout.strip() + + if not state: + logger.info(f"Job {job_id} for '{name}' successfully cancelled") + return True + + if state in {"FAILED", "CANCELLED", "TIMEOUT", "COMPLETED"}: + logger.info(f"Job {job_id} for '{name}' now in terminal state {state}") + return True + + logger.debug(f"Waiting for job {job_id} ('{name}') to terminate…") + time.sleep(poll_interval) + + logger.warning(f"Timed-out waiting for job {job_id} ('{name}') to cancel") + return False + + +def get_last_job_id(cluster_dir: str, executor: SlurmExecutor) -> Optional[int]: + """Return the last job ID for this cluster.""" + job_ids_file = os.path.join(cluster_dir, "job_ids.json") + if isinstance(executor.tunnel, SSHTunnel): + job_ids_result = executor.tunnel.run(f"cat {job_ids_file}", warn=True) + if job_ids_result.return_code == 0: + job_ids = json.loads(job_ids_result.stdout) + return int(job_ids[-1]) + else: + return None + else: + if not os.path.exists(job_ids_file): + return None + with open(job_ids_file, "r") as f: + job_ids = json.load(f) + return int(job_ids[-1]) + @dataclass(kw_only=True) class SlurmRayRequest: @@ -183,13 +248,26 @@ def __repr__(self) -> str: {self.materialize()}""" +@dataclass(kw_only=True) class SlurmRayCluster: EXECUTOR_CLS = SlurmExecutor - def __init__(self): - self.cluster_map = {} + name: str + executor: SlurmExecutor + + def __post_init__(self): + self.cluster_map: dict[str, str] = {} + + def _get_ray_cluster_info( + self, + name: Optional[str] = None, + executor: Optional[SlurmExecutor] = None, + ) -> dict[str, Any]: + # Private helper – intentionally undocumented (no public docstring) + + name = name or self.name + executor = executor or self.executor - def _get_ray_cluster_info(self, name: str, executor: SlurmExecutor) -> Dict[str, Any]: executor.tunnel.connect() cluster_dir = os.path.join(executor.tunnel.job_dir, name) cmd = f"test -f {cluster_dir}/ray_cluster_info.json && cat {cluster_dir}/ray_cluster_info.json" @@ -203,12 +281,12 @@ def _get_ray_cluster_info(self, name: str, executor: SlurmExecutor) -> Dict[str, return {} return {} - def get_ray_cluster_status( + def _status( self, - name: str, - executor: SlurmExecutor, - ) -> Dict[str, Union[str, bool, None]]: - logger.info(f"Getting Ray cluster status for '{name}'") + ) -> dict[str, str | bool | None]: + name = self.name + executor = self.executor + logger.debug(f"Getting Ray cluster status for '{name}'") executor.tunnel.connect() # Try to find the job by name @@ -220,38 +298,29 @@ def get_ray_cluster_status( job_id = result.stdout.strip() # If job not found in running jobs, check if it's in cluster_map - if not job_id and name in self.cluster_map: - job_id = self.cluster_map[name] - # Verify this job_id exists - cmd = f"squeue -j {job_id} -h -o %A" - result = executor.tunnel.run(cmd) - if not result.stdout.strip(): - # Job might be completed, check sacct - cmd = f"sacct -j {job_id} --format=State --noheader --parsable2" - result = executor.tunnel.run(cmd) - if result.stdout.strip(): - state = result.stdout.strip().split("\n")[0] - return {"state": state, "job_id": job_id, "ray_ready": state == "COMPLETED"} - # Job not found in sacct either, so it doesn't exist - return {"state": "NOT_FOUND", "job_id": None, "ray_ready": False} + if not job_id: + if name in self.cluster_map: + job_id = self.cluster_map[name] + else: + job_id = get_last_job_id(os.path.join(executor.tunnel.job_dir, name), executor) if not job_id: return {"state": "NOT_FOUND", "job_id": None, "ray_ready": False} # Store job_id in cluster_map for future reference - self.cluster_map[name] = job_id + self.cluster_map[name] = str(job_id) # Check job status cmd = f"squeue -j {job_id} -h -o %T" - result = executor.tunnel.run(cmd) + result = executor.tunnel.run(cmd, warn=True) - if not result.stdout.strip(): + if result.return_code != 0 or not result.stdout.strip(): # Job not found in squeue, check sacct cmd = f"sacct -j {job_id} --format=State --noheader --parsable2" result = executor.tunnel.run(cmd) status = result.stdout.strip().split("\n")[0] if result.stdout.strip() else "UNKNOWN" - return {"state": status, "job_id": job_id, "ray_ready": status == "COMPLETED"} + return {"state": status, "job_id": str(job_id), "ray_ready": status == "COMPLETED"} status = result.stdout.strip() @@ -262,17 +331,74 @@ def get_ray_cluster_status( if ray_cluster_info: ray_ready = True - return {"state": status, "job_id": job_id, "ray_ready": ray_ready} + return {"state": status, "job_id": str(job_id), "ray_ready": ray_ready} + + def status( + self, + *, + display: bool = False, + ) -> dict[str, Any]: + """Return the current Slurm and Ray status for this cluster. + + Parameters + ---------- + display : bool, optional + When *True* print a pretty, colourised summary to the logger. Defaults to *False*. + + Returns + ------- + dict[str, Any] + Mapping with keys ``state`` (str), ``job_id`` (str | None) and ``ray_ready`` (bool). + """ + status_dict = self._status() + if display: + cluster_dir = os.path.join(self.executor.tunnel.job_dir, self.name) + logs_dir = os.path.join(cluster_dir, "logs") + logger.info( + f"""\n\n\033[1;34mRay cluster status (Slurm) at {self.executor.tunnel.key}:\033[0m + • \033[1mName\033[0m : {self.name} + • \033[1mJob ID\033[0m : {status_dict.get("job_id")} + • \033[1mState\033[0m : {status_dict.get("state")} + • \033[1mRay ready\033[0m : {status_dict.get("ray_ready")} + • \033[1mCluster dir\033[0m: {cluster_dir} + • \033[1mLogs dir\033[0m : {logs_dir} + (use `squeue -j {status_dict.get("job_id")}` to check status, `scancel {status_dict.get("job_id")}` to cancel)\n""" + ) + + return status_dict - def create_ray_cluster( + def create( self, - name: str, - executor: SlurmExecutor, pre_ray_start_commands: Optional[list[str]] = None, dryrun: bool = False, command: Optional[str] = None, workdir: Optional[str] = None, ) -> Any: + """Create (or reuse) a Slurm-backed Ray cluster and return its job-id. + + If an active cluster with the same *name* already exists, that cluster is reused and + *None* is returned. With *dryrun=True* the generated SBATCH script is printed instead of + being submitted. + + Parameters + ---------- + pre_ray_start_commands : list[str] | None + Shell commands to run on each node *before* Ray is started. + dryrun : bool, optional + When *True* do **not** submit the job – only print the SBATCH script. Defaults to + *False*. + command : str | None + Optional command executed after the Ray head node is ready (e.g. ``ray job submit``). + workdir : str | None + Remote working directory that becomes the CWD inside the container. + + Returns + ------- + str | None + The Slurm job-id string, or *None* for dry-run / reuse cases. + """ + name = self.name + executor = self.executor cluster_dir = os.path.join(executor.tunnel.job_dir, name) ray_sbatch = SlurmRayRequest( name=name, @@ -286,19 +412,19 @@ def create_ray_cluster( ).materialize() if dryrun: - logger.info(f"Dry run: Ray cluster '{name}'") + logger.debug(f"Dry run: Ray cluster '{name}'") print(ray_sbatch) - return + return None logger.info(f"Creating Ray cluster '{name}'") # Check if a cluster with this name already exists - status = self.get_ray_cluster_status(name, executor) + status = self.status() if status["job_id"] is not None: job_state = status["state"] if job_state in ["PENDING", "RUNNING", "CONFIGURING"]: - logger.info( - f"Ray cluster '{name}' already exists with job ID {status['job_id']} " + logger.debug( + f"Ray cluster '{name}' already exists with ID {status['job_id']} " f"and is currently in {job_state} state. " f"Skipping creation." ) @@ -312,7 +438,7 @@ def create_ray_cluster( "NOT_FOUND", ]: logger.warning( - f"Ray cluster '{name}' exists with job ID {status['job_id']} " + f"Ray cluster '{name}' exists with ID {status['job_id']} " f"in state {job_state}. Creating new cluster anyway." ) @@ -332,113 +458,25 @@ def create_ray_cluster( # Store job_id in cluster_map self.cluster_map[name] = job_id - logger.info(f"Slurm job for Ray cluster '{name}' created with job ID {job_id}") + logger.info(f"Slurm job for Ray cluster '{name}' created with ID {job_id}") return job_id - def schedule_ray_job( + def wait_until_running( self, - name: str, - executor: SlurmExecutor, - command: str, - workdir: Optional[str] = None, - pre_ray_start_commands: Optional[list[str]] = None, - runtime_env_yaml: Optional[str] = None, - dryrun: bool = False, - ): - remote_workdir = None - if workdir: - if isinstance(executor.tunnel, SSHTunnel): - # Rsync workdir honoring .gitignore - remote_workdir = os.path.join(executor.tunnel.job_dir, name, "code") - if not dryrun: - executor.tunnel.connect() - assert executor.tunnel.session is not None, "Tunnel session is not connected" - rsync( - executor.tunnel.session, - workdir, - remote_workdir, - rsync_opts="--filter=':- .gitignore'", - ) - else: - remote_workdir = workdir - elif executor.packager: - if not dryrun: - if isinstance(executor.tunnel, SSHTunnel): - package_dir_ref = tempfile.TemporaryDirectory() - package_dir = package_dir_ref.name - else: - package_dir_ref = None - package_dir = os.path.join(executor.tunnel.job_dir, name) - - if isinstance(executor.packager, GitArchivePackager): - output = subprocess.run( - ["git", "rev-parse", "--show-toplevel"], - check=True, - stdout=subprocess.PIPE, - ) - path = output.stdout.splitlines()[0].decode() - base_path = Path(path).absolute() - else: - base_path = Path(os.getcwd()).absolute() - - local_tar_file = executor.packager.package(base_path, package_dir, name) - local_code_extraction_path = os.path.join(package_dir, "code") - os.makedirs(local_code_extraction_path, exist_ok=True) - subprocess.run( - f"tar -xvzf {local_tar_file} -C {local_code_extraction_path} --ignore-zeros", - shell=True, - check=True, - ) - - if isinstance(executor.tunnel, SSHTunnel): - remote_workdir = os.path.join(executor.tunnel.job_dir, name, "code") - executor.tunnel.connect() - assert executor.tunnel.session is not None, "Tunnel session is not connected" - rsync( - executor.tunnel.session, - os.path.join(local_code_extraction_path, ""), - remote_workdir, - rsync_opts="--filter=':- .gitignore'", - ) - else: - remote_workdir = local_code_extraction_path - - assert remote_workdir is not None, "workdir is not set" - job_id = self.create_ray_cluster( - name, - executor, - pre_ray_start_commands=pre_ray_start_commands, - dryrun=dryrun, - command=command, - workdir=remote_workdir, - ) - - # Descriptive log for the user with useful paths / identifiers - cluster_dir = os.path.join(executor.tunnel.job_dir, name) - logger.info( - f"""\n\n\033[1;34mRay job submitted to Slurm cluster at {executor.tunnel.key}:\033[0m - • \033[1mJob ID\033[0m : \033[32m{job_id}\033[0m - • \033[1mCluster dir\033[0m : {cluster_dir} - • \033[1mLogs directory\033[0m : {os.path.join(cluster_dir, "logs")} - • \033[1mSBATCH script\033[0m : {os.path.join(cluster_dir, "ray.sub")} - • \033[1mRemote workdir\033[0m : {remote_workdir} - (use `squeue -j {job_id}` to check status, `scancel {job_id}` to cancel)\n""" - ) - - return job_id - - def wait_until_ray_cluster_running( - self, - name: str, - executor: SlurmExecutor, timeout: int = 600, delay_between_attempts: int = 30, ) -> bool: + """Block until the Ray head reports *ready* or the timeout expires. + + Returns *True* when the cluster reaches the ``RUNNING`` + ``ray_ready`` state, otherwise + *False*. + """ + name = self.name logger.info(f"Waiting until Ray cluster '{name}' is running") start_time = time.time() while time.time() - start_time < timeout: - status = self.get_ray_cluster_status(name, executor) + status = self.status() if status["ray_ready"]: logger.info(f"Ray cluster '{name}' is ready.") @@ -449,22 +487,38 @@ def wait_until_ray_cluster_running( logger.error(f"Ray cluster '{name}' failed to start. Job state: {status['state']}") return False - logger.info(f"Ray cluster '{name}' is not ready, waiting for it to be ready...") + logger.debug(f"Ray cluster '{name}' is not ready, waiting for it to be ready...") time.sleep(delay_between_attempts) - logger.info(f"Ray cluster '{name}' is not ready after {timeout} seconds") + logger.debug(f"Ray cluster '{name}' is not ready after {timeout} seconds") return False - def delete_ray_cluster( + def delete( self, - name: str, - executor: SlurmExecutor, wait: bool = False, timeout: int = 60, poll_interval: int = 5, ) -> bool: - logger.info(f"Deleting Ray cluster '{name}'") - status = self.get_ray_cluster_status(name, executor) + """Terminate the Slurm job backing this Ray cluster. + + Parameters + ---------- + wait : bool, optional + If *True* block until the job leaves the queue (or *timeout* elapses). + timeout : int, optional + Maximum seconds to wait when *wait* is *True*. Defaults to *60*. + poll_interval : int, optional + Seconds between successive ``squeue`` polls. Defaults to *5*. + + Returns + ------- + bool + *True* if the job was successfully cancelled (or already gone), *False* otherwise. + """ + name = self.name + executor = self.executor + logger.debug(f"Deleting Ray cluster '{name}'") + status = self.status() if status["job_id"] is None: logger.warning(f"Ray cluster '{name}' does not exist or is already deleted") @@ -477,64 +531,30 @@ def delete_ray_cluster( state in status["state"] # type: ignore for state in ["COMPLETED", "FAILED", "CANCELLED", "TIMEOUT", "NOT_FOUND"] ): - logger.info(f"Ray cluster '{name}' job {job_id} is already in state {status['state']}") + logger.debug(f"Ray cluster '{name}' {job_id} is already in state {status['state']}") # Remove from cluster_map if name in self.cluster_map: del self.cluster_map[name] return True - # Cancel the job - executor.tunnel.connect() - cmd = f"scancel {job_id}" - logger.info(f"Cancelling Ray cluster '{name}' job {job_id}") - try: - executor.tunnel.run(cmd) - except Exception as e: - logger.error(f"Failed to cancel Ray cluster '{name}' job {job_id}: {e}") - return False + success = cancel_slurm_job( + executor, + name, + job_id, + wait=wait, + timeout=timeout, + poll_interval=poll_interval, + ) - # Remove from cluster_map if it exists if name in self.cluster_map: del self.cluster_map[name] - # Wait for job to be fully terminated if requested - if wait: - start_time = time.time() - while time.time() - start_time < timeout: - status = self.get_ray_cluster_status(name, executor) - - # If job is not found anymore, it's been successfully cancelled - if status["job_id"] is None: - logger.info( - f"Ray cluster '{name}' job {job_id} has been successfully cancelled" - ) - if name in self.cluster_map: - del self.cluster_map[name] - return True - - # If job is in a terminated state, success - if any(state in status["state"] for state in ["CANCELLED", "FAILED", "TIMEOUT"]): # type: ignore - logger.info( - f"Ray cluster '{name}' job {job_id} is now in state {status['state']}" - ) - if name in self.cluster_map: - del self.cluster_map[name] - return True - - logger.info(f"Waiting for Ray cluster '{name}' job {job_id} to terminate...") - time.sleep(poll_interval) - - logger.warning(f"Timed out waiting for Ray cluster '{name}' job {job_id} to terminate") - return False - - return True + return success def port_forward( self, - name: str, - port: int, - target_port: int, - executor: SlurmExecutor, + port: int = 8265, + target_port: int = 8265, wait: bool = False, ): """Port forward to a Ray cluster using SSH tunnel. @@ -559,7 +579,9 @@ def port_forward( - TimeoutError: If port forwarding fails to establish within the timeout period. """ # Check if cluster exists and is running - status = self.get_ray_cluster_status(name, executor) + name = self.name + executor = self.executor + status = self.status() if status["job_id"] is None: raise RuntimeError(f"Could not find Ray cluster {name}") @@ -781,7 +803,7 @@ def _cleanup(self): self._ssh_process = None # Ensure it's cleared def stop_forwarding(self): - logger.info("Stopping port forwarding") + logger.debug("Stopping port forwarding") self._stop_event.set() # Create and start the forwarding thread @@ -808,7 +830,7 @@ def stop_forwarding(self): original_sigterm_handler = signal.getsignal(signal.SIGTERM) def signal_handler(sig, frame): - logger.info(f"Received signal {sig} to stop port forwarding") + logger.debug(f"Received signal {sig} to stop port forwarding") stop_event.set() # Restore original signal handlers @@ -819,7 +841,7 @@ def signal_handler(sig, frame): signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - logger.info("Port forwarding is active. Press Ctrl+C to stop...") + logger.debug("Port forwarding is active. Press Ctrl+C to stop...") while not stop_event.is_set(): if not forward_thread.is_alive(): logger.error( @@ -858,3 +880,275 @@ def signal_handler(sig, frame): f"SSH process (PID: {forward_thread._ssh_process.pid}) did not respond to kill." ) return forward_thread + + +@dataclass(kw_only=True) +class SlurmRayJob: + """Lightweight helper around a single Ray Slurm job returned by ``schedule_ray_job``. + + Parameters + ---------- + name : str + Logical name of the Ray cluster (not necessarily the Slurm job-name). + job_id : str + Numeric Slurm job id returned by ``sbatch``. + cluster_dir : str + Remote directory where cluster artefacts (logs, SBATCH script, etc.) are stored. + executor : SlurmExecutor + The executor used to submit/run the job. We only need it for its tunnel. + """ + + name: str + executor: SlurmExecutor + + # --------------------------------------------------------------------- + # Internals + # --------------------------------------------------------------------- + def __post_init__(self): + self.cluster_dir = os.path.join(self.executor.tunnel.job_dir, self.name) + self.job_id = None + + def _logs_path(self) -> str: + # Private helper – path construction only (no public docstring) + assert self.cluster_dir is not None, "cluster_dir is not set" + return os.path.join(self.cluster_dir, "logs", "ray-job.log") + + # ------------------------------------------------------------------ + # Public helpers + # ------------------------------------------------------------------ + + def stop( + self, + *, + wait: bool = False, + timeout: int = 60, + poll_interval: int = 5, + ) -> bool: + """Cancel this Slurm Ray *job* (optionally blocking until it disappears). + + Parameters + ---------- + wait : bool, optional + If *True* block until the job is gone / in a terminal state, up to + *timeout* seconds. Defaults to *False* (fire-and-forget). + timeout : int, optional + Max seconds to wait when *wait* is *True*. Defaults to *60*. + poll_interval : int, optional + Seconds between ``squeue`` polls when waiting. Defaults to *5*. + """ + + if self.job_id is None: + self.job_id = get_last_job_id(self.cluster_dir, self.executor) + if self.job_id is None: + raise RuntimeError(f"Ray job '{self.name}' has no job_id") + + return cancel_slurm_job( + self.executor, + self.name, + self.job_id, + wait=wait, + timeout=timeout, + poll_interval=poll_interval, + ) + + def logs(self, follow: bool = False, lines: int = 100, timeout: int = 100) -> None: + """Show the remote ``ray-job.log``. + + Parameters + ---------- + follow : bool, optional + If *True* we stream the log (`tail -f`). Otherwise the last *lines* + lines are printed. Defaults to *False*. + lines : int, optional + Number of lines to show when *follow* is *False*. Ignored when + *follow* is *True*. + timeout : int, optional + Max seconds to wait for the log file to appear on the remote host + before giving up. Only applies if the file does not yet exist. + """ + # Lazily resolve missing job-id and fail only if still unavailable + if self.job_id is None: + self.job_id = get_last_job_id(self.cluster_dir, self.executor) + if self.job_id is None: + raise RuntimeError(f"Ray job '{self.name}' has no job_id") + + self.executor.tunnel.connect() + log_path = self._logs_path() + if follow: + # Run tail in background on remote host and poll Slurm until the + # job disappears from `squeue`. When it is gone, we kill the + # background tail which makes the whole SSH command exit + # gracefully so our local call returns without manual Ctrl+C. + cmd = ( + "bash -c '" + f'tail -n {lines} -F "{log_path}" & ' + "TAIL_PID=$!; " + f"while squeue -j {self.job_id} -h | grep -q .; do sleep 5; done; " + "kill $TAIL_PID; wait $TAIL_PID'" + ) + else: + cmd = f"tail -n {lines} {log_path}" + + # Ensure file exists or wait up to *timeout* seconds + start_ts = time.time() + exists = False + while time.time() - start_ts < timeout: + print(f"Checking if {log_path} exists") + test_result = self.executor.tunnel.run(f"test -f {log_path}", hide=True, warn=True) + if test_result.return_code == 0: + exists = True + break + time.sleep(2) + + if not exists: + logger.warning( + f"Log file {log_path} not found after {timeout}s. Skipping tail." # noqa: G004 + ) + return + + try: + self.executor.tunnel.run(cmd, hide=False, warn=True) + except KeyboardInterrupt: + # User interrupted tailing; stop remote process (connection will close automatically). + logger.debug("Stopped tailing logs (Ctrl+C)") + # Fabric/Invoke should handle remote process termination. We just return. + + def status(self, display: bool = True) -> dict[str, Any]: + """Return and pretty-print current Slurm/Ray status for this job.""" + assert self.cluster_dir is not None, "cluster_dir is not set" + if self.job_id is None: + self.job_id = get_last_job_id(self.cluster_dir, self.executor) + + cluster = SlurmRayCluster(name=self.name, executor=self.executor) + if self.job_id is not None: + cluster.cluster_map[self.name] = str(self.job_id) + + status_info = cluster.status(display=False) + + # Build a concise, colourful summary mirroring the submission banner + sbatch_script = os.path.join(self.cluster_dir, "ray.sub") + logs_dir = os.path.join(self.cluster_dir, "logs") + if display: + logger.info( + f"""\n\n\033[1;34mRay job status for Slurm cluster at {self.executor.tunnel.key}:\033[0m + • \033[1mJob ID\033[0m : \033[32m{self.job_id}\033[0m + • \033[1mState\033[0m : {status_info.get("state", "UNKNOWN")} + • \033[1mRay ready\033[0m : {status_info.get("ray_ready", False)} + • \033[1mCluster dir\033[0m : {self.cluster_dir} + • \033[1mLogs directory\033[0m : {logs_dir} + • \033[1mSBATCH script\033[0m : {sbatch_script} + (use `squeue -j {self.job_id}` to check status, `scancel {self.job_id}` to cancel, + `tail -f {self._logs_path()}` to view logs)\n""" + ) + return status_info + + def start( + self, + command: str, + workdir: str, + runtime_env_yaml: Optional[str] | None = None, + pre_ray_start_commands: Optional[list[str]] = None, + dryrun: bool = False, + ): + """Submit a Ray job via Slurm and return a *live* SlurmRayJob helper. + + This is a thin wrapper around :py:meth:`SlurmRayCluster.schedule_ray_job` so + that users can work directly with *RayJob* rather than *RayCluster* + helpers:: + + SlurmRayJob.start( + name="my-job", + executor=my_slurm_executor, + command="python train.py", + workdir="./src", + ) + """ + # ------------------------------------------------------------------ + # 1) Early exit if a RayJob with this *logical* name already exists + # ------------------------------------------------------------------ + cluster = SlurmRayCluster(name=self.name, executor=self.executor) + if cluster.status()["job_id"] is not None: + raise RuntimeError(f"Ray job '{self.name}' already exists") + + # ------------------------------------------------------------------ + # 2) Ship *workdir* over to the remote side (or package via packager) + # ------------------------------------------------------------------ + remote_workdir: Optional[str] = None + + if workdir: + if isinstance(self.executor.tunnel, SSHTunnel): + # Rsync workdir honouring .gitignore + remote_workdir = os.path.join(self.executor.tunnel.job_dir, self.name, "code") + if not dryrun: + self.executor.tunnel.connect() + assert self.executor.tunnel.session is not None, ( + "Tunnel session is not connected" + ) + rsync( + self.executor.tunnel.session, + workdir, + remote_workdir, + rsync_opts="--filter=':- .gitignore'", + ) + else: + remote_workdir = workdir + elif self.executor.packager is not None: + # Use the packager to create an archive which we then extract on the + # submission host and optionally rsync to the target. + if not dryrun: + if isinstance(self.executor.tunnel, SSHTunnel): + package_dir = tempfile.mkdtemp(prefix="nemo_packager_") + else: + package_dir = os.path.join(self.executor.tunnel.job_dir, self.name) + + # Base path for packaging – either Git repo root (GitArchivePackager) + # or current cwd for generic packagers. + if isinstance(self.executor.packager, GitArchivePackager): + output = subprocess.run( + ["git", "rev-parse", "--show-toplevel"], + check=True, + stdout=subprocess.PIPE, + ) + path = output.stdout.splitlines()[0].decode() + base_path = Path(path).absolute() + else: + base_path = Path(os.getcwd()).absolute() + + local_tar_file = self.executor.packager.package(base_path, package_dir, self.name) + local_code_extraction_path = os.path.join(package_dir, "code") + os.makedirs(local_code_extraction_path, exist_ok=True) + subprocess.run( + f"tar -xvzf {local_tar_file} -C {local_code_extraction_path} --ignore-zeros", + shell=True, + check=True, + ) + + if isinstance(self.executor.tunnel, SSHTunnel): + remote_workdir = os.path.join(self.executor.tunnel.job_dir, self.name, "code") + self.executor.tunnel.connect() + assert self.executor.tunnel.session is not None, ( + "Tunnel session is not connected" + ) + rsync( + self.executor.tunnel.session, + os.path.join(local_code_extraction_path, ""), + remote_workdir, + rsync_opts="--filter=':- .gitignore'", + ) + else: + remote_workdir = local_code_extraction_path + + assert remote_workdir is not None, "workdir could not be determined" + + # ------------------------------------------------------------------ + # 3) Spin up / reuse the Ray *cluster* (Slurm array job) + # ------------------------------------------------------------------ + job_id = cluster.create( + pre_ray_start_commands=pre_ray_start_commands, + dryrun=dryrun, + command=command, + workdir=remote_workdir, + ) + + self.job_id = job_id + self.status() diff --git a/nemo_run/run/ray/templates/ray.sub.j2 b/nemo_run/run/ray/templates/ray.sub.j2 index 6ea2816d..cc1a77cd 100644 --- a/nemo_run/run/ray/templates/ray.sub.j2 +++ b/nemo_run/run/ray/templates/ray.sub.j2 @@ -43,6 +43,14 @@ MAX_WORKER_PORT=${MAX_WORKER_PORT:-54257} # Directory setup export CLUSTER_DIR={{ cluster_dir }} +JOB_IDS_FILE="$CLUSTER_DIR/job_ids.json" +if [[ -f "$JOB_IDS_FILE" ]]; then + tmp="$(mktemp)" + jq --arg id "$SLURM_JOB_ID" '. + [$id]' "$JOB_IDS_FILE" > "$tmp" && mv "$tmp" "$JOB_IDS_FILE" +else + echo "[\"$SLURM_JOB_ID\"]" > "$JOB_IDS_FILE" +fi + mkdir -p $CLUSTER_DIR/scripts export LOG_DIR={{ log_dir }} @@ -289,7 +297,7 @@ echo "[INFO] Ray cluster information saved to $CLUSTER_DIR/ray_cluster_info.json # This driver process is responsible for launching a job on the Ray cluster CONTAINER_CWD=$(scontrol show job $SLURM_JOB_ID --json | jq -r '.jobs[].current_working_directory') # Define command to be empty by default -COMMAND="${COMMAND:-{{ command }}}" +COMMAND="${COMMAND:-{{ command | default('', true) }}}" COMMAND_WORKDIR={{ command_workdir | default('$CONTAINER_CWD') }} if [[ -n "$COMMAND" ]]; then From 4295953f69dc3d9604f30e7d387b4aa1484c7411 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Wed, 21 May 2025 09:42:09 -0700 Subject: [PATCH 13/18] fix Signed-off-by: Hemil Desai --- docs/source/guides/ray.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/guides/ray.md b/docs/source/guides/ray.md index 3be50d0a..74687aab 100644 --- a/docs/source/guides/ray.md +++ b/docs/source/guides/ray.md @@ -101,7 +101,7 @@ job.start( runtime_env_yaml="/path/to/runtime_env.yaml", # optional pre_ray_start_commands=pre_ray_start, ) -job.follow_logs_until_completion() +job.logs(follow=True) # 5) Clean-up cluster.stop() @@ -175,7 +175,7 @@ job.start( workdir="/path/to/project/", # rsync'ed via SSH to the cluster_dir/code/ pre_ray_start_commands=pre_ray_start, ) -job.follow_logs_until_completion() +job.logs(follow=True) # 6) Tear everything down (or just let the wall-time expire) cluster.stop() @@ -246,7 +246,7 @@ def main() -> None: job.start(command=args.command, workdir="./") # 4) Stream logs and block until completion - job.follow_logs_until_completion() + job.logs(follow=True) # 5) Tidy-up cluster.stop() From 9284278863df7d3d644c29fddc2a377f14806564 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Wed, 21 May 2025 13:32:35 -0700 Subject: [PATCH 14/18] fix Signed-off-by: Hemil Desai --- nemo_run/run/ray/templates/ray.sub.j2 | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo_run/run/ray/templates/ray.sub.j2 b/nemo_run/run/ray/templates/ray.sub.j2 index cc1a77cd..d17849ab 100644 --- a/nemo_run/run/ray/templates/ray.sub.j2 +++ b/nemo_run/run/ray/templates/ray.sub.j2 @@ -48,6 +48,7 @@ if [[ -f "$JOB_IDS_FILE" ]]; then tmp="$(mktemp)" jq --arg id "$SLURM_JOB_ID" '. + [$id]' "$JOB_IDS_FILE" > "$tmp" && mv "$tmp" "$JOB_IDS_FILE" else + touch "$JOB_IDS_FILE" echo "[\"$SLURM_JOB_ID\"]" > "$JOB_IDS_FILE" fi From 2ad635e44166a498d9cf8beb4636350b6c1a8a20 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Wed, 21 May 2025 13:38:53 -0700 Subject: [PATCH 15/18] fix Signed-off-by: Hemil Desai --- nemo_run/run/ray/templates/ray.sub.j2 | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo_run/run/ray/templates/ray.sub.j2 b/nemo_run/run/ray/templates/ray.sub.j2 index d17849ab..238f8e63 100644 --- a/nemo_run/run/ray/templates/ray.sub.j2 +++ b/nemo_run/run/ray/templates/ray.sub.j2 @@ -43,6 +43,8 @@ MAX_WORKER_PORT=${MAX_WORKER_PORT:-54257} # Directory setup export CLUSTER_DIR={{ cluster_dir }} +mkdir -p $CLUSTER_DIR + JOB_IDS_FILE="$CLUSTER_DIR/job_ids.json" if [[ -f "$JOB_IDS_FILE" ]]; then tmp="$(mktemp)" From 6c5f5de524cfafc2bda345bdfe9f63588ca20f8f Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Wed, 21 May 2025 19:48:58 -0700 Subject: [PATCH 16/18] fixes Signed-off-by: Hemil Desai --- nemo_run/core/execution/kuberay.py | 20 ++++----- nemo_run/run/ray/kuberay.py | 71 +++++++++++++++++++++++++----- nemo_run/run/ray/slurm.py | 16 +++---- 3 files changed, 76 insertions(+), 31 deletions(-) diff --git a/nemo_run/core/execution/kuberay.py b/nemo_run/core/execution/kuberay.py index a3c16e78..9cc5b3ef 100644 --- a/nemo_run/core/execution/kuberay.py +++ b/nemo_run/core/execution/kuberay.py @@ -21,7 +21,7 @@ import subprocess import time from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional from kubernetes import client, watch from kubernetes.client import CoreV1Api @@ -80,9 +80,9 @@ class KubeRayExecutor(Executor): image: str = "" # Will be set in __post_init__ if empty head_cpu: str = "1" head_memory: str = "2Gi" - ray_start_params: Dict[str, Any] = field(default_factory=dict) - worker_groups: List[KubeRayWorkerGroup] = field(default_factory=list) - labels: Dict[str, Any] = field(default_factory=dict) + ray_start_params: dict[str, Any] = field(default_factory=dict) + worker_groups: list[KubeRayWorkerGroup] = field(default_factory=list) + labels: dict[str, Any] = field(default_factory=dict) service_type: str = "ClusterIP" head_ports: list[dict[str, Any]] = field(default_factory=list) volume_mounts: list[dict[str, Any]] = field(default_factory=list) @@ -344,7 +344,7 @@ def update_worker_group_replicas( max_replicas: int, min_replicas: int, replicas: int, -) -> Tuple[dict, bool]: +) -> tuple[dict, bool]: assert cluster["spec"]["workerGroupSpecs"] assert max_replicas >= min_replicas @@ -366,7 +366,7 @@ def update_worker_group_resources( cpu_limits: str, memory_limits: str, container_name="unspecified", -) -> Tuple[dict, bool]: +) -> tuple[dict, bool]: assert cluster["spec"]["workerGroupSpecs"] worker_groups = cluster["spec"]["workerGroupSpecs"] @@ -417,7 +417,7 @@ def duplicate_worker_group( cluster: dict, group_name: str, new_group_name: str, -) -> Tuple[dict, bool]: +) -> tuple[dict, bool]: assert is_valid_name(new_group_name) assert cluster["spec"]["workerGroupSpecs"] @@ -436,7 +436,7 @@ def duplicate_worker_group( def delete_worker_group( cluster: dict, group_name: str, -) -> Tuple[dict, bool]: +) -> tuple[dict, bool]: assert cluster["spec"]["workerGroupSpecs"] worker_groups = cluster["spec"]["workerGroupSpecs"] @@ -463,8 +463,8 @@ def sync_workdir_via_pod( namespace: str, workdir: str, core_v1_api: CoreV1Api, - volumes: List[dict[str, object]], - volume_mounts: List[dict[str, object]], + volumes: list[dict[str, object]], + volume_mounts: list[dict[str, object]], workspace_path: str = "/workspace", image: str = "alpine:3.19", cleanup: bool = False, diff --git a/nemo_run/run/ray/kuberay.py b/nemo_run/run/ray/kuberay.py index d3947c9a..5f75d151 100644 --- a/nemo_run/run/ray/kuberay.py +++ b/nemo_run/run/ray/kuberay.py @@ -98,7 +98,6 @@ def status( self, timeout: int = 60, delay_between_attempts: int = 5, - *, display: bool = False, ) -> Any: """Return the ``status`` stanza of the RayCluster CR (blocking). @@ -109,7 +108,7 @@ def status( namespace = self.executor.namespace or "default" name = self.name - logger.info( + logger.debug( f"Getting Ray cluster status for '{name}' in namespace '{namespace}', " f"timeout: {timeout}s, delay: {delay_between_attempts}s" ) @@ -150,10 +149,25 @@ def status( def wait_until_running( self, - timeout: int = 60, + timeout: int = 600, delay_between_attempts: int = 5, ) -> bool: - """Block until the Ray head service has a reachable IP (or timeout).""" + """Block until the Ray head service has a reachable IP **and** the head pod is running. + + The previous implementation returned as soon as the operator had + populated ``status.head.serviceIP`` in the RayCluster CR. This is a + good proxy for readiness of the *service* object but does **not** + guarantee that the underlying *pod* has actually reached the + ``Running``/``Ready`` state. + + We now additionally query the Kubernetes API for the head pod and + ensure that it is both *Running* **and** *Ready* before returning + success. The head pod is identified via the same labels that the + KubeRay operator applies to every pod: + + • ``ray.io/cluster=`` + • ``ray.io/node-type=head`` + """ namespace = self.executor.namespace or "default" name = self.name @@ -163,23 +177,60 @@ def wait_until_running( f"timeout: {timeout}s, delay: {delay_between_attempts}s" ) + def _head_pod_is_ready() -> bool: + """Return *True* if the head pod exists and is Running/Ready.""" + try: + pods = self.core_v1_api.list_namespaced_pod( + namespace=namespace, label_selector=f"ray.io/cluster={name}" + ) + except ApiException as e: + logger.debug(f"Error listing pods for Ray cluster '{name}': {e}") + return False + + for pod in pods.items: + labels = pod.metadata.labels or {} + # Newer KubeRay versions set `ray.io/node-type=head`; fall back to + # a heuristic on the pod name otherwise. + is_head = labels.get("ray.io/node-type") == "head" or "-head" in pod.metadata.name + if not is_head: + continue + + if pod.status.phase != "Running": + return False + + # Ensure the Ready condition is *True* (best-effort) + if pod.status.conditions: + for cond in pod.status.conditions: + if cond.type == "Ready": + return cond.status == "True" + # If no conditions, fall back to phase only + return True + + # No head pod found + return False + remaining = timeout while remaining > 0: poll_window = min(delay_between_attempts, remaining) - status = self.status(poll_window, poll_window, display=False) + + status = self.status(display=False) if not status: logger.info(f"Ray cluster '{name}' status could not be retrieved") return False - # TODO: once the operator exposes a proper .state field, use that - # For now we infer readiness from the presence of head.serviceIP - if status.get("head", {}).get("serviceIP"): - logger.info(f"Ray cluster '{name}' is running") + svc_ip_ready = bool(status.get("head", {}).get("serviceIP")) + pod_ready = False + if svc_ip_ready: + pod_ready = _head_pod_is_ready() + + if svc_ip_ready and pod_ready: + logger.info(f"Ray cluster '{name}' is running and head pod is ready") return True logger.debug( - f"Ray cluster '{name}' status is not running yet, current status: {status.get('state', 'unknown')}" + f"Ray cluster '{name}' not ready yet – svc_ip_ready={svc_ip_ready}, pod_ready={pod_ready}" ) + remaining -= poll_window logger.debug(f"Ray cluster '{name}' status is not running yet, timing out...") diff --git a/nemo_run/run/ray/slurm.py b/nemo_run/run/ray/slurm.py index c0c98624..1275c148 100644 --- a/nemo_run/run/ray/slurm.py +++ b/nemo_run/run/ray/slurm.py @@ -408,7 +408,7 @@ def create( pre_ray_start_commands=pre_ray_start_commands, command=command, workdir=workdir, - launch_cmd=["sbatch", "--requeue", "--parsable"], + launch_cmd=["sbatch", "--requeue", "--parsable", "--dependency=singleton"], ).materialize() if dryrun: @@ -993,7 +993,7 @@ def logs(self, follow: bool = False, lines: int = 100, timeout: int = 100) -> No start_ts = time.time() exists = False while time.time() - start_ts < timeout: - print(f"Checking if {log_path} exists") + logger.debug(f"Checking if {log_path} exists") test_result = self.executor.tunnel.run(f"test -f {log_path}", hide=True, warn=True) if test_result.return_code == 0: exists = True @@ -1064,14 +1064,7 @@ def start( ) """ # ------------------------------------------------------------------ - # 1) Early exit if a RayJob with this *logical* name already exists - # ------------------------------------------------------------------ - cluster = SlurmRayCluster(name=self.name, executor=self.executor) - if cluster.status()["job_id"] is not None: - raise RuntimeError(f"Ray job '{self.name}' already exists") - - # ------------------------------------------------------------------ - # 2) Ship *workdir* over to the remote side (or package via packager) + # Ship *workdir* over to the remote side (or package via packager) # ------------------------------------------------------------------ remote_workdir: Optional[str] = None @@ -1141,8 +1134,9 @@ def start( assert remote_workdir is not None, "workdir could not be determined" # ------------------------------------------------------------------ - # 3) Spin up / reuse the Ray *cluster* (Slurm array job) + # Spin up / reuse the Ray *cluster* (Slurm array job) # ------------------------------------------------------------------ + cluster = SlurmRayCluster(name=self.name, executor=self.executor) job_id = cluster.create( pre_ray_start_commands=pre_ray_start_commands, dryrun=dryrun, From 990c9450478cf8bc635736933f53ca15fb96df30 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Wed, 21 May 2025 21:52:45 -0700 Subject: [PATCH 17/18] fix Signed-off-by: Hemil Desai --- nemo_run/run/ray/slurm.py | 18 ++++++++++++------ nemo_run/run/ray/templates/ray.sub.j2 | 2 +- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/nemo_run/run/ray/slurm.py b/nemo_run/run/ray/slurm.py index 1275c148..f59e4491 100644 --- a/nemo_run/run/ray/slurm.py +++ b/nemo_run/run/ray/slurm.py @@ -182,18 +182,23 @@ def materialize(self) -> str: for key, value in self.executor.env_vars.items(): env_vars.append(f"export {key.upper()}={value}") + def get_gres_specification() -> str: + if self.executor.gres: + return f"--gres={self.executor.gres}" + elif self.executor.gpus_per_node: + return f"gpu:{self.executor.gpus_per_node}" + else: + return "" + def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str: _srun_flags = [f"--container-image={container_image}"] if container_image else [] _srun_flags.append("--no-container-mount-home") _srun_flags.append("--mpi=pmix") _srun_flags.append(f"-A={self.executor.account}") _srun_flags.append(f"-p={self.executor.partition}") - if self.executor.gres: - _srun_flags.append(f"--gres={self.executor.gres}") - elif self.executor.gpus_per_node: - _srun_flags.append(f"--gres=gpu:{self.executor.gpus_per_node}") - else: - _srun_flags.append("--gres=gpu:8") + gres_specification = get_gres_specification() + if gres_specification: + _srun_flags.append(gres_specification) if self.nemo_run_dir: new_mounts = copy.deepcopy(mounts) @@ -226,6 +231,7 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str: ), "command": self.command, "command_workdir": self.workdir, + "gres_specification": get_gres_specification(), } if self.pre_ray_start_commands: diff --git a/nemo_run/run/ray/templates/ray.sub.j2 b/nemo_run/run/ray/templates/ray.sub.j2 index 238f8e63..4375e15d 100644 --- a/nemo_run/run/ray/templates/ray.sub.j2 +++ b/nemo_run/run/ray/templates/ray.sub.j2 @@ -315,7 +315,7 @@ if [[ -z "\$WORKER_NUM" ]]; then srun --no-container-mount-home --gpus=0 -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" --jobid $SLURM_JOB_ID --pty bash else nodes_array=($nodes) - srun --no-container-mount-home --gres=gpu:8 -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID --pty bash + srun --no-container-mount-home {%- if gres_specification %}{{gres_specification}}{% endif %} -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID --pty bash fi EOF chmod +x $CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh From 1bd4cef3d345306480cfe326fc1ba3c6bf5f2813 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Thu, 22 May 2025 12:25:59 -0700 Subject: [PATCH 18/18] fix Signed-off-by: Hemil Desai --- nemo_run/run/ray/slurm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_run/run/ray/slurm.py b/nemo_run/run/ray/slurm.py index f59e4491..12dd3bec 100644 --- a/nemo_run/run/ray/slurm.py +++ b/nemo_run/run/ray/slurm.py @@ -186,7 +186,7 @@ def get_gres_specification() -> str: if self.executor.gres: return f"--gres={self.executor.gres}" elif self.executor.gpus_per_node: - return f"gpu:{self.executor.gpus_per_node}" + return f"--gres=gpu:{self.executor.gpus_per_node}" else: return ""