diff --git a/nemo_curator/core/client.py b/nemo_curator/core/client.py index 7db7b3ce3e..10facab1a2 100644 --- a/nemo_curator/core/client.py +++ b/nemo_curator/core/client.py @@ -13,10 +13,16 @@ # limitations under the License. import atexit +import contextlib import os +import shutil import signal import socket import subprocess +import sys +import tempfile +import threading +import time from dataclasses import dataclass, field import yaml @@ -206,3 +212,395 @@ def __enter__(self): def __exit__(self, *exc): self.stop() + + +# --------------------------------------------------------------------------- # +# SLURM helpers +# --------------------------------------------------------------------------- # + + +def _find_ray_binary() -> str: + """Locate the ``ray`` CLI in the active Python environment.""" + candidate = os.path.join(os.path.dirname(sys.executable), "ray") + if os.path.isfile(candidate) and os.access(candidate, os.X_OK): + return candidate + found = shutil.which("ray") + if found: + return found + msg = "Could not find the `ray` binary. Make sure Ray is installed in the active Python environment." + raise FileNotFoundError(msg) + + +def _expand_slurm_nodelist(nodelist: str) -> list[str]: + """Expand a SLURM node-list expression into individual hostnames. + + Tries ``scontrol show hostnames`` first, then falls back to a + pure-Python parser that handles common compact formats like + ``prefix-[01,03-05]`` and ``node1,node2``. + """ + scontrol = shutil.which("scontrol") + if scontrol: + try: + result = subprocess.run( # noqa: S603 + [scontrol, "show", "hostnames", nodelist], + capture_output=True, + text=True, + check=True, + ) + nodes = [n.strip() for n in result.stdout.strip().splitlines() if n.strip()] + if nodes: + return nodes + except (subprocess.CalledProcessError, FileNotFoundError): + pass + return _parse_slurm_nodelist(nodelist) + + +def _parse_slurm_nodelist(nodelist: str) -> list[str]: + """Pure-Python parser for SLURM compact nodelist notation. 
+ + Handles formats like: + - ``node1,node2,node3`` + - ``prefix-[01,03,05]`` + - ``prefix-[01-05]`` + - ``prefix-[01-03,07,10-12]`` + """ + import re + + nodes: list[str] = [] + for token in re.split(r",(?![^\[]*\])", nodelist): + m = re.match(r"^(.+?)\[(.+)\]$", token) + if not m: + nodes.append(token) + continue + prefix, ranges = m.group(1), m.group(2) + for part in ranges.split(","): + if "-" in part: + lo, hi = part.split("-", 1) + width = len(lo) + for n in range(int(lo), int(hi) + 1): + nodes.append(f"{prefix}{str(n).zfill(width)}") + else: + nodes.append(f"{prefix}{part}") + return nodes if nodes else [nodelist] + + +# --------------------------------------------------------------------------- # +# SlurmRayClient +# --------------------------------------------------------------------------- # + + +@dataclass +class SlurmRayClient(RayClient): + """RayClient extended for multi-node SLURM jobs. + + On single-node SLURM jobs (or when not running under SLURM at all), + behaves identically to :class:`RayClient`. + + On multi-node jobs, the script must be launched on **every** node + (e.g. via ``srun --ntasks-per-node=1``). Each process calls + ``SlurmRayClient``, which inspects ``SLURM_NODEID`` to determine + its role: + + - **Head (SLURM_NODEID=0)**: starts the Ray head, waits for all + workers to connect, then returns from :meth:`start` so the + pipeline can run. + - **Workers (SLURM_NODEID>0)**: start a Ray worker that connects + to the head and **block until the cluster is torn down**. When + the head stops Ray (after the pipeline finishes), the worker + process exits cleanly with ``sys.exit(0)``. + + This is analogous to how ``torchrun`` works: the same script is + launched on every node and each process discovers its role from the + environment. 
+ + Example ``sbatch`` script:: + + #!/bin/bash + #SBATCH --nodes=4 + #SBATCH --ntasks-per-node=1 + #SBATCH --gpus-per-node=8 + + srun --ntasks-per-node=1 \\ + --container-image=nvcr.io/nvidia/nemo-curator:26.02 \\ + --container-mounts="/lustre:/lustre" \\ + bash -c "source .venv/bin/activate && python my_pipeline.py" + + For bare-metal (no container) setups, the same pattern works:: + + #!/bin/bash + #SBATCH --nodes=4 + #SBATCH --ntasks-per-node=1 + #SBATCH --gpus-per-node=8 + + srun python my_pipeline.py + + If ``RAY_ADDRESS`` is set before :meth:`start` is called, + ``SlurmRayClient`` connects to the existing cluster without + starting or stopping anything. + + Parameters + ---------- + worker_connect_timeout_s: + Maximum seconds to wait for all worker nodes to join after the + head is up. Raises ``TimeoutError`` if exceeded. + cleanup_on_start: + If *True*, run ``ray stop --force`` on the local node before + starting Ray. Helps clear stale processes from previous runs. + """ + + worker_connect_timeout_s: int = 300 + cleanup_on_start: bool = True + + ray_dashboard_host: str = "0.0.0.0" # noqa: S104 + + _slurm_nodes: list[str] = field(init=False, default_factory=list, repr=False) + _manages_cluster: bool = field(init=False, default=False, repr=False) + + def __post_init__(self) -> None: + super().__post_init__() + self._detect_slurm_resources() + + def _detect_slurm_resources(self) -> None: + """Auto-detect per-node CPU/GPU counts from SLURM env vars when not set explicitly.""" + if self.num_cpus is None: + slurm_cpus = os.environ.get("SLURM_CPUS_ON_NODE") + if slurm_cpus: + self.num_cpus = int(slurm_cpus) + + if self.num_gpus is None: + slurm_gpus = os.environ.get("SLURM_GPUS_ON_NODE") + if slurm_gpus: + self.num_gpus = int(slurm_gpus) + + # ------------------------------------------------------------------ # + # Lifecycle + # ------------------------------------------------------------------ # + + def start(self) -> None: + """Start the Ray cluster, with role 
detection on multi-node SLURM jobs. + + If ``RAY_ADDRESS`` is already set, connects to the existing + cluster without starting a new head or launching workers. + + On multi-node jobs, worker processes (``SLURM_NODEID > 0``) + block here until the cluster is torn down, then exit with + ``sys.exit(0)``. Only the head (``SLURM_NODEID = 0``) returns + from this method. + """ + if os.environ.get("RAY_ADDRESS"): + logger.info( + f"RAY_ADDRESS already set ({os.environ['RAY_ADDRESS']}). " + "Connecting to existing Ray cluster — skipping head/worker startup." + ) + super().start() + return + + slurm_job_id = os.environ.get("SLURM_JOB_ID") + if not slurm_job_id: + logger.warning("SLURM_JOB_ID not set — falling back to single-node RayClient behaviour") + super().start() + return + + nodelist = os.environ.get("SLURM_JOB_NODELIST", socket.gethostname()) + self._slurm_nodes = _expand_slurm_nodelist(nodelist) + self._manages_cluster = True + node_id = int(os.environ.get("SLURM_NODEID", "0")) + + logger.info( + f"SlurmRayClient: job {slurm_job_id}, {len(self._slurm_nodes)} node(s), " + f"SLURM_NODEID={node_id}, head={self._slurm_nodes[0]}, " + f"cpus/node={self.num_cpus}, gpus/node={self.num_gpus}" + ) + + if self.cleanup_on_start: + self._cleanup_local_ray() + + if len(self._slurm_nodes) <= 1 or node_id == 0: + # Head node — start Ray head (super().start() selects the actual port via get_free_port) + super().start() + # Broadcast the actual port the head chose so workers don't have to guess. + # Workers may be on different physical nodes and cannot call get_free_port on + # the head, so we write the port to a shared Lustre file keyed on job ID. + if len(self._slurm_nodes) > 1: + self._write_head_port(slurm_job_id) + self._wait_for_workers() + else: + # Worker node — read the port the head actually chose, then connect. 
+ head_ip = socket.gethostbyname(self._slurm_nodes[0]) + actual_port = self._read_head_port(slurm_job_id) + self.ray_port = actual_port + logger.info(f"SlurmRayClient worker {node_id}: connecting to head at {head_ip}:{self.ray_port}") + sys.exit(self._run_as_worker(head_ip)) + + def stop(self) -> None: + """Stop the Ray head. Workers detect the head's death and exit on their own. + + Safe to call multiple times. Does not stop an externally + managed cluster (one discovered via ``RAY_ADDRESS``). + """ + if self._manages_cluster: + slurm_job_id = os.environ.get("SLURM_JOB_ID") + if slurm_job_id: + port_file = self._head_port_file(slurm_job_id) + with contextlib.suppress(FileNotFoundError): + os.remove(port_file) + logger.info(f"SlurmRayClient: removed port file {port_file}") + super().stop() + + # ------------------------------------------------------------------ # + # Internal helpers + # ------------------------------------------------------------------ # + + def _head_port_file(self, slurm_job_id: str) -> str: + """Return path to the shared port-broadcast file for this job. + + Must be on a filesystem visible to ALL nodes (Lustre, not /tmp). + Uses env var ``RAY_PORT_BROADCAST_DIR`` if set, otherwise falls back to + ``/tmp`` (works on single-node or when /tmp is shared, e.g. via NFS). + """ + broadcast_dir = os.environ.get("RAY_PORT_BROADCAST_DIR", "/tmp") # noqa: S108 + os.makedirs(broadcast_dir, exist_ok=True) + return os.path.join(broadcast_dir, f"ray_head_port_{slurm_job_id}") + + def _write_head_port(self, slurm_job_id: str) -> None: + """Write the actual Ray GCS port to a shared file so workers can read it. + + Uses an atomic write-then-rename so workers never observe an empty or + partially-written file (important on Lustre / NFS where open() truncates + before write() completes). 
+ """ + port_file = self._head_port_file(slurm_job_id) + broadcast_dir = os.path.dirname(port_file) + with tempfile.NamedTemporaryFile(mode="w", dir=broadcast_dir, delete=False) as f: + tmp_path = f.name + f.write(str(self.ray_port)) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, port_file) # atomic on POSIX + logger.info(f"SlurmRayClient head: wrote port {self.ray_port} to {port_file}") + + def _read_head_port(self, slurm_job_id: str, timeout_s: int = 600) -> int: + """Wait for the head to write its port file and return the port number.""" + port_file = self._head_port_file(slurm_job_id) + deadline = time.time() + timeout_s + while time.time() < deadline: + if os.path.exists(port_file): + try: + with open(port_file) as f: + port = int(f.read().strip()) + except (ValueError, OSError): + pass # file may be partially written; retry + else: + logger.info(f"SlurmRayClient worker: read head port {port} from {port_file}") + return port + time.sleep(2) + msg = f"Timed out waiting for head port file {port_file} after {timeout_s}s" + raise TimeoutError(msg) + + def _run_as_worker(self, head_ip: str) -> int: + """Start a Ray worker that connects to *head_ip* and block until the cluster is torn down. + + Returns the exit code of ``ray start --block`` so the caller can pass it to ``sys.exit``. + Exit code 0 means the cluster was torn down cleanly; non-zero indicates an error. 
+ """ + ray_bin = _find_ray_binary() + cmd = [ + ray_bin, + "start", + "--address", + f"{head_ip}:{self.ray_port}", + "--temp-dir", + self.ray_temp_dir, + "--block", + "--disable-usage-stats", + ] + if self.num_gpus is not None: + cmd.extend(["--num-gpus", str(self.num_gpus)]) + if self.num_cpus is not None: + cmd.extend(["--num-cpus", str(self.num_cpus)]) + + logger.info(f"Ray worker starting: {' '.join(cmd)}") + result = subprocess.run(cmd, check=False) # noqa: S603 + logger.info(f"Ray worker exited with code {result.returncode}") + return result.returncode + + def _cleanup_local_ray(self) -> None: + """Stop any stale Ray processes on the local node.""" + with contextlib.suppress(Exception): + ray_bin = _find_ray_binary() + subprocess.run([ray_bin, "stop", "--force"], capture_output=True, timeout=30, check=False) # noqa: S603 + + @staticmethod + def _ray_init_with_timeout(address: str, timeout_s: int = 120) -> None: + """Call ``ray.init(address=...)`` with a SIGALRM-based timeout. + + ``ray.init`` can hang indefinitely if the GCS is slow or unstable + after a multi-job start. We use SIGALRM (Linux/macOS only) to raise + a ``TimeoutError`` if the call blocks longer than *timeout_s* seconds. + + Falls back to an unguarded ``ray.init`` when called from a non-main + thread, where SIGALRM is unavailable. + """ + import ray as _ray + + if threading.current_thread() is not threading.main_thread(): + logger.warning("SIGALRM unavailable outside main thread — calling ray.init without timeout") + _ray.init(address=address, ignore_reinit_error=True) + return + + def _handler(_signum: int, _frame: object) -> None: + msg = ( + f"ray.init(address={address!r}) timed out after {timeout_s}s — " + "GCS may be unresponsive; the job will exit and can be resubmitted." 
+ ) + raise TimeoutError(msg) + + old_handler = signal.signal(signal.SIGALRM, _handler) + signal.alarm(timeout_s) + try: + _ray.init(address=address, ignore_reinit_error=True) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + + def _wait_for_workers(self) -> None: + """Block until every allocated node is alive in the Ray cluster. + + Raises ``TimeoutError`` (after tearing everything down) if not + all nodes join within ``worker_connect_timeout_s``. + """ + import ray as _ray + + expected = len(self._slurm_nodes) + deadline = time.time() + self.worker_connect_timeout_s + + self._ray_init_with_timeout(os.environ["RAY_ADDRESS"], timeout_s=120) + try: + while True: + alive = [n for n in _ray.nodes() if n.get("Alive")] + if len(alive) >= expected: + total_cpus = sum(n.get("Resources", {}).get("CPU", 0) for n in alive) + total_gpus = sum(n.get("Resources", {}).get("GPU", 0) for n in alive) + logger.info( + f"All {expected} node(s) connected — " + f"total CPUs: {total_cpus:.0f}, total GPUs: {total_gpus:.0f}" + ) + return + + remaining = deadline - time.time() + if remaining <= 0: + logger.error( + f"Timeout: only {len(alive)}/{expected} node(s) connected " + f"after {self.worker_connect_timeout_s}s." + ) + self.stop() + msg = ( + f"Timed out after {self.worker_connect_timeout_s}s: " + f"only {len(alive)}/{expected} node(s) connected. Cluster torn down." + ) + raise TimeoutError(msg) + + logger.info(f"Waiting for workers: {len(alive)}/{expected} ({remaining:.0f}s left)") + time.sleep(min(5, remaining)) + finally: + _ray.shutdown() diff --git a/tests/core/test_slurm_ray_client.py b/tests/core/test_slurm_ray_client.py new file mode 100644 index 0000000000..de2adf7eed --- /dev/null +++ b/tests/core/test_slurm_ray_client.py @@ -0,0 +1,522 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import subprocess +import tempfile +import threading +import time +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from collections.abc import Callable + from pathlib import Path + from typing import Any, NoReturn + +from nemo_curator.core.client import ( + RayClient, + SlurmRayClient, + _expand_slurm_nodelist, + _find_ray_binary, + _parse_slurm_nodelist, +) + +# --------------------------------------------------------------------------- # +# Helper tests +# --------------------------------------------------------------------------- # + + +class TestFindRayBinary: + def test_finds_ray_in_venv(self): + binary = _find_ray_binary() + assert os.path.isfile(binary) + + def test_raises_when_not_found(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("shutil.which", lambda _: None) + monkeypatch.setattr("os.path.isfile", lambda _: False) + with pytest.raises(FileNotFoundError, match="ray"): + _find_ray_binary() + + +class TestExpandSlurmNodelist: + def test_single_hostname(self): + result = _expand_slurm_nodelist("compute-001") + assert result == ["compute-001"] + + def test_expands_with_scontrol(self, monkeypatch: pytest.MonkeyPatch): + import nemo_curator.core.client as _client + + fake_result = subprocess.CompletedProcess(args=[], returncode=0, stdout="node-001\nnode-002\nnode-003\n") + monkeypatch.setattr("shutil.which", lambda _: "/usr/bin/scontrol") + monkeypatch.setattr(_client.subprocess, "run", lambda *_args, **_kw: fake_result) + result = 
_expand_slurm_nodelist("node-[001-003]") + assert result == ["node-001", "node-002", "node-003"] + + def test_fallback_no_scontrol(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("shutil.which", lambda _: None) + result = _expand_slurm_nodelist("node-001") + assert result == ["node-001"] + + +class TestParseSlurmNodelist: + """Tests for the pure-Python fallback parser (no scontrol required).""" + + def test_single_node(self): + assert _parse_slurm_nodelist("node-001") == ["node-001"] + + def test_comma_separated(self): + assert _parse_slurm_nodelist("node-001,node-002,node-003") == [ + "node-001", + "node-002", + "node-003", + ] + + def test_simple_range(self): + assert _parse_slurm_nodelist("pool0-[01-05]") == [ + "pool0-01", + "pool0-02", + "pool0-03", + "pool0-04", + "pool0-05", + ] + + def test_mixed_range_and_list(self): + # prefix-[01-03,07,10-12] → 7 nodes + result = _parse_slurm_nodelist("node-[01-03,07,10-12]") + assert result == [ + "node-01", + "node-02", + "node-03", + "node-07", + "node-10", + "node-11", + "node-12", + ] + + def test_zero_padded_range(self): + result = _parse_slurm_nodelist("compute-[001-003]") + assert result == ["compute-001", "compute-002", "compute-003"] + + def test_multiple_prefixes_with_ranges(self): + # Two separate bracket groups in a comma-split list + result = _parse_slurm_nodelist("gpu-[1-2],cpu-[3-4]") + assert result == ["gpu-1", "gpu-2", "cpu-3", "cpu-4"] + + +# --------------------------------------------------------------------------- # +# SlurmRayClient unit tests +# --------------------------------------------------------------------------- # + + +class TestSlurmRayClientInit: + def test_detects_slurm_cpus(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("SLURM_CPUS_ON_NODE", "64") + monkeypatch.delenv("SLURM_GPUS_ON_NODE", raising=False) + client = SlurmRayClient() + assert client.num_cpus == 64 + assert client.num_gpus is None + + def test_detects_slurm_gpus(self, monkeypatch: 
pytest.MonkeyPatch): + monkeypatch.setenv("SLURM_GPUS_ON_NODE", "8") + monkeypatch.delenv("SLURM_CPUS_ON_NODE", raising=False) + client = SlurmRayClient() + assert client.num_gpus == 8 + + def test_explicit_overrides_slurm(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("SLURM_CPUS_ON_NODE", "64") + monkeypatch.setenv("SLURM_GPUS_ON_NODE", "8") + client = SlurmRayClient(num_cpus=32, num_gpus=4) + assert client.num_cpus == 32 + assert client.num_gpus == 4 + + def test_dashboard_host_defaults_to_all(self): + client = SlurmRayClient() + assert client.ray_dashboard_host == "0.0.0.0" # noqa: S104 + + +class TestSlurmRayClientFallback: + def test_falls_back_without_slurm_job_id(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv("SLURM_JOB_ID", raising=False) + monkeypatch.delenv("RAY_ADDRESS", raising=False) + + with tempfile.TemporaryDirectory(prefix="ray_test_slurm_") as ray_tmp: + client = SlurmRayClient(ray_temp_dir=ray_tmp) + client.start() + try: + assert os.environ.get("RAY_ADDRESS") is not None + assert client.ray_process is not None + fn = os.path.join(ray_tmp, "ray_current_cluster") + t0 = time.perf_counter() + while not os.path.exists(fn) and time.perf_counter() - t0 < 30: + time.sleep(1) + assert os.path.exists(fn) + finally: + client.stop() + + +class TestSlurmRayClientSingleNode: + """Test single-node SLURM behaviour (no srun needed).""" + + def test_single_node_start_stop(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("SLURM_JOB_ID", "12345") + monkeypatch.setenv("SLURM_JOB_NODELIST", os.uname().nodename) + monkeypatch.setenv("SLURM_CPUS_ON_NODE", "4") + monkeypatch.delenv("RAY_ADDRESS", raising=False) + + with tempfile.TemporaryDirectory(prefix="ray_test_slurm_single_") as ray_tmp: + client = SlurmRayClient(ray_temp_dir=ray_tmp, cleanup_on_start=False) + client.start() + try: + assert os.environ.get("RAY_ADDRESS") is not None + assert client._slurm_nodes == [os.uname().nodename] + finally: + client.stop() + + def 
test_context_manager(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("SLURM_JOB_ID", "12345") + monkeypatch.setenv("SLURM_JOB_NODELIST", os.uname().nodename) + monkeypatch.setenv("SLURM_CPUS_ON_NODE", "4") + monkeypatch.delenv("RAY_ADDRESS", raising=False) + + with tempfile.TemporaryDirectory(prefix="ray_test_slurm_ctx_") as ray_tmp: + with SlurmRayClient(ray_temp_dir=ray_tmp, cleanup_on_start=False) as client: + assert os.environ.get("RAY_ADDRESS") is not None + + assert client.ray_process is None + + +# --------------------------------------------------------------------------- # +# Port-file helpers +# --------------------------------------------------------------------------- # + + +class TestHeadPortFile: + def test_default_uses_tmp(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("RAY_PORT_BROADCAST_DIR", raising=False) + client = SlurmRayClient() + path = client._head_port_file("42") + assert os.path.basename(path) == "ray_head_port_42" + + def test_custom_dir(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("RAY_PORT_BROADCAST_DIR", str(tmp_path)) + client = SlurmRayClient() + assert client._head_port_file("99") == str(tmp_path / "ray_head_port_99") + + +class TestWriteReadHeadPort: + def test_roundtrip(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("RAY_PORT_BROADCAST_DIR", str(tmp_path)) + client = SlurmRayClient(ray_port=12345) + client._write_head_port("job1") + assert client._read_head_port("job1", timeout_s=5) == 12345 + + def test_read_timeout_raises(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("RAY_PORT_BROADCAST_DIR", str(tmp_path)) + monkeypatch.setattr(time, "sleep", lambda _: None) + with pytest.raises(TimeoutError, match="Timed out waiting"): + SlurmRayClient()._read_head_port("no_such_job", timeout_s=0) + + def test_read_ignores_partial_write(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> 
None: + """If the port file exists but is empty/corrupt, _read_head_port retries until valid.""" + monkeypatch.setenv("RAY_PORT_BROADCAST_DIR", str(tmp_path)) + port_file = tmp_path / "ray_head_port_job2" + write_calls = [0] + real_sleep = time.sleep + + def patched_sleep(s: float) -> None: + write_calls[0] += 1 + if write_calls[0] == 1: + port_file.write_text("6379") + real_sleep(min(s, 0.05)) + + monkeypatch.setattr(time, "sleep", patched_sleep) + port_file.write_text("") # start with corrupt content + assert SlurmRayClient()._read_head_port("job2", timeout_s=10) == 6379 + + +# --------------------------------------------------------------------------- # +# _run_as_worker +# --------------------------------------------------------------------------- # + + +class TestRunAsWorker: + def test_returns_exit_code(self, monkeypatch: pytest.MonkeyPatch) -> None: + import nemo_curator.core.client as _client + + monkeypatch.setattr("shutil.which", lambda _: "/usr/bin/ray") + monkeypatch.setattr( + _client.subprocess, + "run", + lambda _cmd, **_kw: subprocess.CompletedProcess(args=[], returncode=0), + ) + assert SlurmRayClient()._run_as_worker("10.0.0.1") == 0 + + def test_passes_gpu_and_cpu_flags(self, monkeypatch: pytest.MonkeyPatch) -> None: + import nemo_curator.core.client as _client + + captured: list[list[str]] = [] + + def fake_run(cmd: list[str], **_kw: object) -> subprocess.CompletedProcess[str]: + captured.append(cmd) + return subprocess.CompletedProcess(args=[], returncode=0) + + monkeypatch.setattr("shutil.which", lambda _: "/usr/bin/ray") + monkeypatch.setattr(_client.subprocess, "run", fake_run) + SlurmRayClient(num_gpus=4, num_cpus=16)._run_as_worker("10.0.0.1") + assert len(captured) == 1 + assert "--num-gpus" in captured[0] + assert "--num-cpus" in captured[0] + + def test_nonzero_exit_propagated(self, monkeypatch: pytest.MonkeyPatch) -> None: + import nemo_curator.core.client as _client + + monkeypatch.setattr("shutil.which", lambda _: "/usr/bin/ray") + 
monkeypatch.setattr( + _client.subprocess, + "run", + lambda _cmd, **_kw: subprocess.CompletedProcess(args=[], returncode=1), + ) + assert SlurmRayClient()._run_as_worker("10.0.0.1") == 1 + + +# --------------------------------------------------------------------------- # +# _cleanup_local_ray +# --------------------------------------------------------------------------- # + + +class TestCleanupLocalRay: + def test_calls_ray_stop(self, monkeypatch: pytest.MonkeyPatch) -> None: + import nemo_curator.core.client as _client + + calls: list[list[str]] = [] + monkeypatch.setattr("shutil.which", lambda _: "/usr/bin/ray") + monkeypatch.setattr(_client.subprocess, "run", lambda cmd, **_kw: calls.append(cmd)) + SlurmRayClient()._cleanup_local_ray() + assert any("stop" in c for c in calls[0]) + + def test_suppresses_errors(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Even if ray binary is missing, _cleanup_local_ray must not raise.""" + monkeypatch.setattr("shutil.which", lambda _: None) + monkeypatch.setattr("os.path.isfile", lambda _: False) + SlurmRayClient()._cleanup_local_ray() # should not raise + + +# --------------------------------------------------------------------------- # +# _expand_slurm_nodelist (additional edge cases) +# --------------------------------------------------------------------------- # + + +class TestExpandSlurmNodelistEdgeCases: + def test_scontrol_error_falls_back_to_parser(self, monkeypatch: pytest.MonkeyPatch) -> None: + import nemo_curator.core.client as _client + + def raise_error(*_a: object, **_kw: object) -> NoReturn: + raise subprocess.CalledProcessError(1, "scontrol") + + monkeypatch.setattr("shutil.which", lambda _: "/usr/bin/scontrol") + monkeypatch.setattr(_client.subprocess, "run", raise_error) + assert _expand_slurm_nodelist("node-001") == ["node-001"] + + def test_scontrol_empty_output_falls_back(self, monkeypatch: pytest.MonkeyPatch) -> None: + import nemo_curator.core.client as _client + + 
monkeypatch.setattr("shutil.which", lambda _: "/usr/bin/scontrol") + monkeypatch.setattr( + _client.subprocess, + "run", + lambda *_a, **_kw: subprocess.CompletedProcess(args=[], returncode=0, stdout=""), + ) + assert _expand_slurm_nodelist("node-001") == ["node-001"] + + +# --------------------------------------------------------------------------- # +# _ray_init_with_timeout +# --------------------------------------------------------------------------- # + + +def _inject_fake_ray( + monkeypatch: pytest.MonkeyPatch, + init_fn: Callable[..., Any] | None = None, + nodes_fn: Callable[[], list[Any]] | None = None, +) -> None: + import sys + import types + + fake_ray = types.ModuleType("ray") + fake_ray.init = init_fn or (lambda *_a, **_kw: None) + fake_ray.nodes = nodes_fn or list + fake_ray.shutdown = lambda: None + monkeypatch.setitem(sys.modules, "ray", fake_ray) + + +class TestRayInitWithTimeout: + def test_main_thread_calls_ray_init(self, monkeypatch: pytest.MonkeyPatch) -> None: + initted: list[str] = [] + _inject_fake_ray(monkeypatch, init_fn=lambda address, **_kw: initted.append(address)) + SlurmRayClient._ray_init_with_timeout("127.0.0.1:6379", timeout_s=10) + assert initted == ["127.0.0.1:6379"] + + def test_non_main_thread_skips_sigalrm(self, monkeypatch: pytest.MonkeyPatch) -> None: + initted: list[str] = [] + _inject_fake_ray(monkeypatch, init_fn=lambda address, **_kw: initted.append(address)) + results: list[str] = [] + + def _run() -> None: + SlurmRayClient._ray_init_with_timeout("127.0.0.1:6379", timeout_s=5) + results.append("done") + + t = threading.Thread(target=_run) + t.start() + t.join(timeout=10) + assert results == ["done"] + assert initted == ["127.0.0.1:6379"] + + +# --------------------------------------------------------------------------- # +# _wait_for_workers +# --------------------------------------------------------------------------- # + + +class TestWaitForWorkers: + def test_success_all_nodes_connected(self, monkeypatch: 
pytest.MonkeyPatch) -> None: + _inject_fake_ray( + monkeypatch, + nodes_fn=lambda: [ + {"Alive": True, "Resources": {"CPU": 4.0, "GPU": 1.0}}, + {"Alive": True, "Resources": {"CPU": 4.0, "GPU": 1.0}}, + ], + ) + monkeypatch.setenv("RAY_ADDRESS", "127.0.0.1:6379") + client = SlurmRayClient(worker_connect_timeout_s=30) + client._slurm_nodes = ["node1", "node2"] + client._wait_for_workers() # must not raise + + def test_timeout_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: + _inject_fake_ray(monkeypatch, nodes_fn=list) # workers never connect + monkeypatch.setenv("RAY_ADDRESS", "127.0.0.1:6379") + monkeypatch.setattr(time, "sleep", lambda _: None) + client = SlurmRayClient(worker_connect_timeout_s=0) + client._slurm_nodes = ["node1", "node2"] + with pytest.raises(TimeoutError, match="Timed out"): + client._wait_for_workers() + + def test_partial_nodes_then_all_connected(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Workers join across multiple polling iterations.""" + call_count = [0] + + def nodes_fn() -> list[dict[str, Any]]: + call_count[0] += 1 + if call_count[0] < 3: + return [{"Alive": True, "Resources": {}}] + return [{"Alive": True, "Resources": {}}, {"Alive": True, "Resources": {}}] + + _inject_fake_ray(monkeypatch, nodes_fn=nodes_fn) + monkeypatch.setenv("RAY_ADDRESS", "127.0.0.1:6379") + monkeypatch.setattr(time, "sleep", lambda _: None) + client = SlurmRayClient(worker_connect_timeout_s=60) + client._slurm_nodes = ["node1", "node2"] + client._wait_for_workers() + assert call_count[0] >= 3 + + +# --------------------------------------------------------------------------- # +# SlurmRayClient.stop (manages_cluster branch) +# --------------------------------------------------------------------------- # + + +class TestSlurmRayClientStopManagesCluster: + def test_removes_port_file(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("RAY_PORT_BROADCAST_DIR", str(tmp_path)) + monkeypatch.setenv("SLURM_JOB_ID", 
"99999") + monkeypatch.setattr(RayClient, "stop", lambda _self: None) + + client = SlurmRayClient(ray_port=6379) + client._manages_cluster = True + client._write_head_port("99999") + port_file = client._head_port_file("99999") + assert os.path.exists(port_file) + + client.stop() + assert not os.path.exists(port_file) + + def test_no_port_file_does_not_raise(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("RAY_PORT_BROADCAST_DIR", str(tmp_path)) + monkeypatch.setenv("SLURM_JOB_ID", "88888") + monkeypatch.setattr(RayClient, "stop", lambda _self: None) + + client = SlurmRayClient() + client._manages_cluster = True + client.stop() # no port file — FileNotFoundError must be suppressed + + +# --------------------------------------------------------------------------- # +# SlurmRayClient.start (additional branches) +# --------------------------------------------------------------------------- # + + +class TestSlurmRayClientStartBranches: + def test_ray_address_already_set_delegates_to_super(self, monkeypatch: pytest.MonkeyPatch) -> None: + super_calls: list[str] = [] + monkeypatch.setattr(RayClient, "start", lambda _self: super_calls.append("start")) + monkeypatch.setenv("RAY_ADDRESS", "127.0.0.1:6379") + SlurmRayClient().start() + assert super_calls == ["start"] + + def test_head_node_multi_node_calls_write_and_wait(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """Head node (SLURM_NODEID=0) with 2 nodes: must write port and wait for workers.""" + monkeypatch.setenv("SLURM_JOB_ID", "55555") + monkeypatch.setenv("SLURM_JOB_NODELIST", "node-[001-002]") + monkeypatch.setenv("SLURM_NODEID", "0") + monkeypatch.delenv("RAY_ADDRESS", raising=False) + + super_starts: list[str] = [] + + def fake_super_start(_self: object) -> None: + super_starts.append("start") + os.environ["RAY_ADDRESS"] = "10.0.0.1:6379" + + wrote: list[str] = [] + waited: list[bool] = [] + monkeypatch.setattr(RayClient, "start", fake_super_start) + 
monkeypatch.setattr(SlurmRayClient, "_cleanup_local_ray", lambda _self: None) + monkeypatch.setattr(SlurmRayClient, "_write_head_port", lambda _self, jid: wrote.append(jid)) + monkeypatch.setattr(SlurmRayClient, "_wait_for_workers", lambda _self: waited.append(True)) + + SlurmRayClient(cleanup_on_start=True).start() + + assert super_starts == ["start"] + assert wrote == ["55555"] + assert waited == [True] + + def test_worker_node_calls_sys_exit(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """Worker node (SLURM_NODEID=1) must call sys.exit with the worker's return code.""" + monkeypatch.setenv("SLURM_JOB_ID", "55556") + monkeypatch.setenv("SLURM_JOB_NODELIST", "node-[001-002]") + monkeypatch.setenv("SLURM_NODEID", "1") + monkeypatch.delenv("RAY_ADDRESS", raising=False) + + import nemo_curator.core.client as _client + + monkeypatch.setattr(SlurmRayClient, "_cleanup_local_ray", lambda _self: None) + monkeypatch.setattr(_client.socket, "gethostbyname", lambda _: "10.0.0.1") + monkeypatch.setattr(SlurmRayClient, "_read_head_port", lambda _self, _jid, **_kw: 6379) + monkeypatch.setattr(SlurmRayClient, "_run_as_worker", lambda _self, _ip: 0) + + with pytest.raises(SystemExit) as exc: + SlurmRayClient(cleanup_on_start=True).start() + assert exc.value.code == 0 diff --git a/tutorials/slurm/README.md b/tutorials/slurm/README.md new file mode 100644 index 0000000000..05565de2a0 --- /dev/null +++ b/tutorials/slurm/README.md @@ -0,0 +1,289 @@ +# Running NeMo Curator on SLURM + +This tutorial shows how to scale a NeMo Curator pipeline from a single laptop to a multi-node SLURM cluster with a **one-line change**. 
+ +## Contents + +| File | Purpose | +|------|---------| +| `pipeline.py` | A simple CPU-only pipeline (word-count + node-tag) that runs locally or on SLURM | +| `submit.sh` | `sbatch` script for bare-metal clusters with a shared virtualenv | +| `submit_container.sh` | `sbatch` script using the official NGC container (Pyxis/enroot) | + +--- + +## The key concept: RayClient vs SlurmRayClient + +NeMo Curator uses a `RayClient` to manage the Ray cluster lifecycle. The `SlurmRayClient` is a drop-in replacement that handles the multi-process SLURM model automatically. + +```python +# Local development — Ray starts on the current machine +ray_client = RayClient() + +# SLURM multi-node — Ray spans all allocated nodes automatically +ray_client = SlurmRayClient() + +# One-liner to auto-detect the environment: +ray_client = SlurmRayClient() if os.environ.get("SLURM_JOB_ID") else RayClient() +``` + +That is the **only change** needed to go from a local run to a distributed SLURM job. Everything else — pipeline stages, executor, `pipeline.run()` — is identical. + +### How SlurmRayClient works + +When `srun` launches one Python process per node, `SlurmRayClient.start()` behaves differently on each node: + +``` +srun --ntasks-per-node=1 python pipeline.py --slurm + │ + ├─ Node 0 (SLURM_NODEID=0) — HEAD + │ start() → ray start --head + │ → writes GCS port to shared file + │ → waits for all workers to join + │ → returns ← pipeline runs here + │ + ├─ Node 1 — WORKER + │ start() → reads port file from Node 0 + │ → ray start --block --address=
: