diff --git a/nemo_curator/core/client.py b/nemo_curator/core/client.py index 7db7b3ce3e..10facab1a2 100644 --- a/nemo_curator/core/client.py +++ b/nemo_curator/core/client.py @@ -13,10 +13,16 @@ # limitations under the License. import atexit +import contextlib import os +import shutil import signal import socket import subprocess +import sys +import tempfile +import threading +import time from dataclasses import dataclass, field import yaml @@ -206,3 +212,395 @@ def __enter__(self): def __exit__(self, *exc): self.stop() + + +# --------------------------------------------------------------------------- # +# SLURM helpers +# --------------------------------------------------------------------------- # + + +def _find_ray_binary() -> str: + """Locate the ``ray`` CLI in the active Python environment.""" + candidate = os.path.join(os.path.dirname(sys.executable), "ray") + if os.path.isfile(candidate) and os.access(candidate, os.X_OK): + return candidate + found = shutil.which("ray") + if found: + return found + msg = "Could not find the `ray` binary. Make sure Ray is installed in the active Python environment." + raise FileNotFoundError(msg) + + +def _expand_slurm_nodelist(nodelist: str) -> list[str]: + """Expand a SLURM node-list expression into individual hostnames. + + Tries ``scontrol show hostnames`` first, then falls back to a + pure-Python parser that handles common compact formats like + ``prefix-[01,03-05]`` and ``node1,node2``. + """ + scontrol = shutil.which("scontrol") + if scontrol: + try: + result = subprocess.run( # noqa: S603 + [scontrol, "show", "hostnames", nodelist], + capture_output=True, + text=True, + check=True, + ) + nodes = [n.strip() for n in result.stdout.strip().splitlines() if n.strip()] + if nodes: + return nodes + except (subprocess.CalledProcessError, FileNotFoundError): + pass + return _parse_slurm_nodelist(nodelist) + + +def _parse_slurm_nodelist(nodelist: str) -> list[str]: + """Pure-Python parser for SLURM compact nodelist notation. + + Handles formats like: + - ``node1,node2,node3`` + - ``prefix-[01,03,05]`` + - ``prefix-[01-05]`` + - ``prefix-[01-03,07,10-12]`` + """ + import re + + nodes: list[str] = [] + for token in re.split(r",(?![^\[]*\])", nodelist): + m = re.match(r"^(.+?)\[(.+)\]$", token) + if not m: + nodes.append(token) + continue + prefix, ranges = m.group(1), m.group(2) + for part in ranges.split(","): + if "-" in part: + lo, hi = part.split("-", 1) + width = len(lo) + for n in range(int(lo), int(hi) + 1): + nodes.append(f"{prefix}{str(n).zfill(width)}") + else: + nodes.append(f"{prefix}{part}") + return nodes if nodes else [nodelist] + + +# --------------------------------------------------------------------------- # +# SlurmRayClient +# --------------------------------------------------------------------------- # + + +@dataclass +class SlurmRayClient(RayClient): + """RayClient extended for multi-node SLURM jobs. + + On single-node SLURM jobs (or when not running under SLURM at all), + behaves identically to :class:`RayClient`. + + On multi-node jobs, the script must be launched on **every** node + (e.g. via ``srun --ntasks-per-node=1``). Each process calls + ``SlurmRayClient``, which inspects ``SLURM_NODEID`` to determine + its role: + + - **Head (SLURM_NODEID=0)**: starts the Ray head, waits for all + workers to connect, then returns from :meth:`start` so the + pipeline can run. + - **Workers (SLURM_NODEID>0)**: start a Ray worker that connects + to the head and **block until the cluster is torn down**. 
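+      The worker's blocking call is ``ray start --block`` run in a subprocess; see :meth:`_run_as_worker`.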
When + the head stops Ray (after the pipeline finishes), the worker + process exits cleanly with ``sys.exit(0)``. + + This is analogous to how ``torchrun`` works: the same script is + launched on every node and each process discovers its role from the + environment. + + Example ``sbatch`` script:: + + #!/bin/bash + #SBATCH --nodes=4 + #SBATCH --ntasks-per-node=1 + #SBATCH --gpus-per-node=8 + + srun --ntasks-per-node=1 \\ + --container-image=nvcr.io/nvidia/nemo-curator:26.02 \\ + --container-mounts="/lustre:/lustre" \\ + bash -c "source .venv/bin/activate && python my_pipeline.py" + + For bare-metal (no container) setups, the same pattern works:: + + #!/bin/bash + #SBATCH --nodes=4 + #SBATCH --ntasks-per-node=1 + #SBATCH --gpus-per-node=8 + + srun python my_pipeline.py + + If ``RAY_ADDRESS`` is set before :meth:`start` is called, + ``SlurmRayClient`` connects to the existing cluster without + starting or stopping anything. + + Parameters + ---------- + worker_connect_timeout_s: + Maximum seconds to wait for all worker nodes to join after the + head is up. Raises ``TimeoutError`` if exceeded. + cleanup_on_start: + If *True*, run ``ray stop --force`` on the local node before + starting Ray. Helps clear stale processes from previous runs. + """ + + worker_connect_timeout_s: int = 300 + cleanup_on_start: bool = True + + ray_dashboard_host: str = "0.0.0.0" # noqa: S104 + + _slurm_nodes: list[str] = field(init=False, default_factory=list, repr=False) + _manages_cluster: bool = field(init=False, default=False, repr=False) + + def __post_init__(self) -> None: + super().__post_init__() + self._detect_slurm_resources() + + def _detect_slurm_resources(self) -> None: + """Auto-detect per-node CPU/GPU counts from SLURM env vars when not set explicitly.""" + if self.num_cpus is None: + slurm_cpus = os.environ.get("SLURM_CPUS_ON_NODE") + if slurm_cpus: + self.num_cpus = int(slurm_cpus) + + if self.num_gpus is None: + slurm_gpus = os.environ.get("SLURM_GPUS_ON_NODE") + if slurm_gpus: + self.num_gpus = int(slurm_gpus) + + # ------------------------------------------------------------------ # + # Lifecycle + # ------------------------------------------------------------------ # + + def start(self) -> None: + """Start the Ray cluster, with role detection on multi-node SLURM jobs. + + If ``RAY_ADDRESS`` is already set, connects to the existing + cluster without starting a new head or launching workers. + + On multi-node jobs, worker processes (``SLURM_NODEID > 0``) + block here until the cluster is torn down, then exit with + ``sys.exit(0)``. Only the head (``SLURM_NODEID = 0``) returns + from this method. + """ + if os.environ.get("RAY_ADDRESS"): + logger.info( + f"RAY_ADDRESS already set ({os.environ['RAY_ADDRESS']}). " + "Connecting to existing Ray cluster — skipping head/worker startup." 
+ ) + super().start() + return + + slurm_job_id = os.environ.get("SLURM_JOB_ID") + if not slurm_job_id: + logger.warning("SLURM_JOB_ID not set — falling back to single-node RayClient behaviour") + super().start() + return + + nodelist = os.environ.get("SLURM_JOB_NODELIST", socket.gethostname()) + self._slurm_nodes = _expand_slurm_nodelist(nodelist) + self._manages_cluster = True + node_id = int(os.environ.get("SLURM_NODEID", "0")) + + logger.info( + f"SlurmRayClient: job {slurm_job_id}, {len(self._slurm_nodes)} node(s), " + f"SLURM_NODEID={node_id}, head={self._slurm_nodes[0]}, " + f"cpus/node={self.num_cpus}, gpus/node={self.num_gpus}" + ) + + if self.cleanup_on_start: + self._cleanup_local_ray() + + if len(self._slurm_nodes) <= 1 or node_id == 0: + # Head node — start Ray head (super().start() selects the actual port via get_free_port) + super().start() + # Broadcast the actual port the head chose so workers don't have to guess. + # Workers may be on different physical nodes and cannot call get_free_port on + # the head, so we write the port to a shared Lustre file keyed on job ID. + if len(self._slurm_nodes) > 1: + self._write_head_port(slurm_job_id) + self._wait_for_workers() + else: + # Worker node — read the port the head actually chose, then connect. + head_ip = socket.gethostbyname(self._slurm_nodes[0]) + actual_port = self._read_head_port(slurm_job_id) + self.ray_port = actual_port + logger.info(f"SlurmRayClient worker {node_id}: connecting to head at {head_ip}:{self.ray_port}") + sys.exit(self._run_as_worker(head_ip)) + + def stop(self) -> None: + """Stop the Ray head. Workers detect the head's death and exit on their own. + + Safe to call multiple times. Does not stop an externally + managed cluster (one discovered via ``RAY_ADDRESS``). + """ + if self._manages_cluster: + slurm_job_id = os.environ.get("SLURM_JOB_ID") + if slurm_job_id: + port_file = self._head_port_file(slurm_job_id) + with contextlib.suppress(FileNotFoundError): + os.remove(port_file) + logger.info(f"SlurmRayClient: removed port file {port_file}") + super().stop() + + # ------------------------------------------------------------------ # + # Internal helpers + # ------------------------------------------------------------------ # + + def _head_port_file(self, slurm_job_id: str) -> str: + """Return path to the shared port-broadcast file for this job. + + Must be on a filesystem visible to ALL nodes (Lustre, not /tmp). + Uses env var ``RAY_PORT_BROADCAST_DIR`` if set, otherwise falls back to + ``/tmp`` (works on single-node or when /tmp is shared, e.g. via NFS). + """ + broadcast_dir = os.environ.get("RAY_PORT_BROADCAST_DIR", "/tmp") # noqa: S108 + os.makedirs(broadcast_dir, exist_ok=True) + return os.path.join(broadcast_dir, f"ray_head_port_{slurm_job_id}") + + def _write_head_port(self, slurm_job_id: str) -> None: + """Write the actual Ray GCS port to a shared file so workers can read it. + + Uses an atomic write-then-rename so workers never observe an empty or + partially-written file (important on Lustre / NFS where open() truncates + before write() completes). 
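+        The file contains just the port number, written as ASCII digits.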
+ """ + port_file = self._head_port_file(slurm_job_id) + broadcast_dir = os.path.dirname(port_file) + with tempfile.NamedTemporaryFile(mode="w", dir=broadcast_dir, delete=False) as f: + tmp_path = f.name + f.write(str(self.ray_port)) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, port_file) # atomic on POSIX + logger.info(f"SlurmRayClient head: wrote port {self.ray_port} to {port_file}") + + def _read_head_port(self, slurm_job_id: str, timeout_s: int = 600) -> int: + """Wait for the head to write its port file and return the port number.""" + port_file = self._head_port_file(slurm_job_id) + deadline = time.time() + timeout_s + while time.time() < deadline: + if os.path.exists(port_file): + try: + with open(port_file) as f: + port = int(f.read().strip()) + except (ValueError, OSError): + pass # file may be partially written; retry + else: + logger.info(f"SlurmRayClient worker: read head port {port} from {port_file}") + return port + time.sleep(2) + msg = f"Timed out waiting for head port file {port_file} after {timeout_s}s" + raise TimeoutError(msg) + + def _run_as_worker(self, head_ip: str) -> int: + """Start a Ray worker that connects to *head_ip* and block until the cluster is torn down. + + Returns the exit code of ``ray start --block`` so the caller can pass it to ``sys.exit``. + Exit code 0 means the cluster was torn down cleanly; non-zero indicates an error. + """ + ray_bin = _find_ray_binary() + cmd = [ + ray_bin, + "start", + "--address", + f"{head_ip}:{self.ray_port}", + "--temp-dir", + self.ray_temp_dir, + "--block", + "--disable-usage-stats", + ] + if self.num_gpus is not None: + cmd.extend(["--num-gpus", str(self.num_gpus)]) + if self.num_cpus is not None: + cmd.extend(["--num-cpus", str(self.num_cpus)]) + + logger.info(f"Ray worker starting: {' '.join(cmd)}") + result = subprocess.run(cmd, check=False) # noqa: S603 + logger.info(f"Ray worker exited with code {result.returncode}") + return result.returncode + + def _cleanup_local_ray(self) -> None: + """Stop any stale Ray processes on the local node.""" + with contextlib.suppress(Exception): + ray_bin = _find_ray_binary() + subprocess.run([ray_bin, "stop", "--force"], capture_output=True, timeout=30, check=False) # noqa: S603 + + @staticmethod + def _ray_init_with_timeout(address: str, timeout_s: int = 120) -> None: + """Call ``ray.init(address=...)`` with a SIGALRM-based timeout. + + ``ray.init`` can hang indefinitely if the GCS is slow or unstable + after a multi-job start. We use SIGALRM (Linux/macOS only) to raise + a ``TimeoutError`` if the call blocks longer than *timeout_s* seconds. + + Falls back to an unguarded ``ray.init`` when called from a non-main + thread, where SIGALRM is unavailable. + """ + import ray as _ray + + if threading.current_thread() is not threading.main_thread(): + logger.warning("SIGALRM unavailable outside main thread — calling ray.init without timeout") + _ray.init(address=address, ignore_reinit_error=True) + return + + def _handler(_signum: int, _frame: object) -> None: + msg = ( + f"ray.init(address={address!r}) timed out after {timeout_s}s — " + "GCS may be unresponsive; the job will exit and can be resubmitted." + ) + raise TimeoutError(msg) + + old_handler = signal.signal(signal.SIGALRM, _handler) + signal.alarm(timeout_s) + try: + _ray.init(address=address, ignore_reinit_error=True) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + + def _wait_for_workers(self) -> None: + """Block until every allocated node is alive in the Ray cluster. 
+ + Raises ``TimeoutError`` (after tearing everything down) if not + all nodes join within ``worker_connect_timeout_s``. + """ + import ray as _ray + + expected = len(self._slurm_nodes) + deadline = time.time() + self.worker_connect_timeout_s + + self._ray_init_with_timeout(os.environ["RAY_ADDRESS"], timeout_s=120) + try: + while True: + alive = [n for n in _ray.nodes() if n.get("Alive")] + if len(alive) >= expected: + total_cpus = sum(n.get("Resources", {}).get("CPU", 0) for n in alive) + total_gpus = sum(n.get("Resources", {}).get("GPU", 0) for n in alive) + logger.info( + f"All {expected} node(s) connected — " + f"total CPUs: {total_cpus:.0f}, total GPUs: {total_gpus:.0f}" + ) + return + + remaining = deadline - time.time() + if remaining <= 0: + logger.error( + f"Timeout: only {len(alive)}/{expected} node(s) connected " + f"after {self.worker_connect_timeout_s}s." + ) + self.stop() + msg = ( + f"Timed out after {self.worker_connect_timeout_s}s: " + f"only {len(alive)}/{expected} node(s) connected. Cluster torn down." + ) + raise TimeoutError(msg) + + logger.info(f"Waiting for workers: {len(alive)}/{expected} ({remaining:.0f}s left)") + time.sleep(min(5, remaining)) + finally: + _ray.shutdown() diff --git a/tests/core/test_slurm_ray_client.py b/tests/core/test_slurm_ray_client.py new file mode 100644 index 0000000000..de2adf7eed --- /dev/null +++ b/tests/core/test_slurm_ray_client.py @@ -0,0 +1,522 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
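+
+"""Unit tests for SlurmRayClient and the SLURM nodelist helpers.
+
+SLURM is faked via environment variables and monkeypatching, so the suite
+runs without access to a real SLURM cluster.
+"""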
+ +from __future__ import annotations + +import os +import subprocess +import tempfile +import threading +import time +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from collections.abc import Callable + from pathlib import Path + from typing import Any, NoReturn + +from nemo_curator.core.client import ( + RayClient, + SlurmRayClient, + _expand_slurm_nodelist, + _find_ray_binary, + _parse_slurm_nodelist, +) + +# --------------------------------------------------------------------------- # +# Helper tests +# --------------------------------------------------------------------------- # + + +class TestFindRayBinary: + def test_finds_ray_in_venv(self): + binary = _find_ray_binary() + assert os.path.isfile(binary) + + def test_raises_when_not_found(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("shutil.which", lambda _: None) + monkeypatch.setattr("os.path.isfile", lambda _: False) + with pytest.raises(FileNotFoundError, match="ray"): + _find_ray_binary() + + +class TestExpandSlurmNodelist: + def test_single_hostname(self): + result = _expand_slurm_nodelist("compute-001") + assert result == ["compute-001"] + + def test_expands_with_scontrol(self, monkeypatch: pytest.MonkeyPatch): + import nemo_curator.core.client as _client + + fake_result = subprocess.CompletedProcess(args=[], returncode=0, stdout="node-001\nnode-002\nnode-003\n") + monkeypatch.setattr("shutil.which", lambda _: "/usr/bin/scontrol") + monkeypatch.setattr(_client.subprocess, "run", lambda *_args, **_kw: fake_result) + result = _expand_slurm_nodelist("node-[001-003]") + assert result == ["node-001", "node-002", "node-003"] + + def test_fallback_no_scontrol(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("shutil.which", lambda _: None) + result = _expand_slurm_nodelist("node-001") + assert result == ["node-001"] + + +class TestParseSlurmNodelist: + """Tests for the pure-Python fallback parser (no scontrol required).""" + + def test_single_node(self): + assert _parse_slurm_nodelist("node-001") == ["node-001"] + + def test_comma_separated(self): + assert _parse_slurm_nodelist("node-001,node-002,node-003") == [ + "node-001", + "node-002", + "node-003", + ] + + def test_simple_range(self): + assert _parse_slurm_nodelist("pool0-[01-05]") == [ + "pool0-01", + "pool0-02", + "pool0-03", + "pool0-04", + "pool0-05", + ] + + def test_mixed_range_and_list(self): + # prefix-[01-03,07,10-12] → 6 nodes + result = _parse_slurm_nodelist("node-[01-03,07,10-12]") + assert result == [ + "node-01", + "node-02", + "node-03", + "node-07", + "node-10", + "node-11", + "node-12", + ] + + def test_zero_padded_range(self): + result = _parse_slurm_nodelist("compute-[001-003]") + assert result == ["compute-001", "compute-002", "compute-003"] + + def test_multiple_prefixes_with_ranges(self): + # Two separate bracket groups in a comma-split list + result = _parse_slurm_nodelist("gpu-[1-2],cpu-[3-4]") + assert result == ["gpu-1", "gpu-2", "cpu-3", "cpu-4"] + + +# --------------------------------------------------------------------------- # +# SlurmRayClient unit tests +# --------------------------------------------------------------------------- # + + +class TestSlurmRayClientInit: + def test_detects_slurm_cpus(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("SLURM_CPUS_ON_NODE", "64") + monkeypatch.delenv("SLURM_GPUS_ON_NODE", raising=False) + client = SlurmRayClient() + assert client.num_cpus == 64 + assert client.num_gpus is None + + def test_detects_slurm_gpus(self, monkeypatch: 
pytest.MonkeyPatch): + monkeypatch.setenv("SLURM_GPUS_ON_NODE", "8") + monkeypatch.delenv("SLURM_CPUS_ON_NODE", raising=False) + client = SlurmRayClient() + assert client.num_gpus == 8 + + def test_explicit_overrides_slurm(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("SLURM_CPUS_ON_NODE", "64") + monkeypatch.setenv("SLURM_GPUS_ON_NODE", "8") + client = SlurmRayClient(num_cpus=32, num_gpus=4) + assert client.num_cpus == 32 + assert client.num_gpus == 4 + + def test_dashboard_host_defaults_to_all(self): + client = SlurmRayClient() + assert client.ray_dashboard_host == "0.0.0.0" # noqa: S104 + + +class TestSlurmRayClientFallback: + def test_falls_back_without_slurm_job_id(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv("SLURM_JOB_ID", raising=False) + monkeypatch.delenv("RAY_ADDRESS", raising=False) + + with tempfile.TemporaryDirectory(prefix="ray_test_slurm_") as ray_tmp: + client = SlurmRayClient(ray_temp_dir=ray_tmp) + client.start() + try: + assert os.environ.get("RAY_ADDRESS") is not None + assert client.ray_process is not None + fn = os.path.join(ray_tmp, "ray_current_cluster") + t0 = time.perf_counter() + while not os.path.exists(fn) and time.perf_counter() - t0 < 30: + time.sleep(1) + assert os.path.exists(fn) + finally: + client.stop() + + +class TestSlurmRayClientSingleNode: + """Test single-node SLURM behaviour (no srun needed).""" + + def test_single_node_start_stop(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("SLURM_JOB_ID", "12345") + monkeypatch.setenv("SLURM_JOB_NODELIST", os.uname().nodename) + monkeypatch.setenv("SLURM_CPUS_ON_NODE", "4") + monkeypatch.delenv("RAY_ADDRESS", raising=False) + + with tempfile.TemporaryDirectory(prefix="ray_test_slurm_single_") as ray_tmp: + client = SlurmRayClient(ray_temp_dir=ray_tmp, cleanup_on_start=False) + client.start() + try: + assert os.environ.get("RAY_ADDRESS") is not None + assert client._slurm_nodes == [os.uname().nodename] + finally: + client.stop() + + def test_context_manager(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("SLURM_JOB_ID", "12345") + monkeypatch.setenv("SLURM_JOB_NODELIST", os.uname().nodename) + monkeypatch.setenv("SLURM_CPUS_ON_NODE", "4") + monkeypatch.delenv("RAY_ADDRESS", raising=False) + + with tempfile.TemporaryDirectory(prefix="ray_test_slurm_ctx_") as ray_tmp: + with SlurmRayClient(ray_temp_dir=ray_tmp, cleanup_on_start=False) as client: + assert os.environ.get("RAY_ADDRESS") is not None + + assert client.ray_process is None + + +# --------------------------------------------------------------------------- # +# Port-file helpers +# --------------------------------------------------------------------------- # + + +class TestHeadPortFile: + def test_default_uses_tmp(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("RAY_PORT_BROADCAST_DIR", raising=False) + client = SlurmRayClient() + path = client._head_port_file("42") + assert os.path.basename(path) == "ray_head_port_42" + + def test_custom_dir(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("RAY_PORT_BROADCAST_DIR", str(tmp_path)) + client = SlurmRayClient() + assert client._head_port_file("99") == str(tmp_path / "ray_head_port_99") + + +class TestWriteReadHeadPort: + def test_roundtrip(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("RAY_PORT_BROADCAST_DIR", str(tmp_path)) + client = SlurmRayClient(ray_port=12345) + client._write_head_port("job1") + assert client._read_head_port("job1", timeout_s=5) == 
12345 + + def test_read_timeout_raises(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("RAY_PORT_BROADCAST_DIR", str(tmp_path)) + monkeypatch.setattr(time, "sleep", lambda _: None) + with pytest.raises(TimeoutError, match="Timed out waiting"): + SlurmRayClient()._read_head_port("no_such_job", timeout_s=0) + + def test_read_ignores_partial_write(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """If the port file exists but is empty/corrupt, _read_head_port retries until valid.""" + monkeypatch.setenv("RAY_PORT_BROADCAST_DIR", str(tmp_path)) + port_file = tmp_path / "ray_head_port_job2" + write_calls = [0] + real_sleep = time.sleep + + def patched_sleep(s: float) -> None: + write_calls[0] += 1 + if write_calls[0] == 1: + port_file.write_text("6379") + real_sleep(min(s, 0.05)) + + monkeypatch.setattr(time, "sleep", patched_sleep) + port_file.write_text("") # start with corrupt content + assert SlurmRayClient()._read_head_port("job2", timeout_s=10) == 6379 + + +# --------------------------------------------------------------------------- # +# _run_as_worker +# --------------------------------------------------------------------------- # + + +class TestRunAsWorker: + def test_returns_exit_code(self, monkeypatch: pytest.MonkeyPatch) -> None: + import nemo_curator.core.client as _client + + monkeypatch.setattr("shutil.which", lambda _: "/usr/bin/ray") + monkeypatch.setattr( + _client.subprocess, + "run", + lambda _cmd, **_kw: subprocess.CompletedProcess(args=[], returncode=0), + ) + assert SlurmRayClient()._run_as_worker("10.0.0.1") == 0 + + def test_passes_gpu_and_cpu_flags(self, monkeypatch: pytest.MonkeyPatch) -> None: + import nemo_curator.core.client as _client + + captured: list[list[str]] = [] + + def fake_run(cmd: list[str], **_kw: object) -> subprocess.CompletedProcess[str]: + captured.append(cmd) + return subprocess.CompletedProcess(args=[], returncode=0) + + monkeypatch.setattr("shutil.which", lambda _: "/usr/bin/ray") + monkeypatch.setattr(_client.subprocess, "run", fake_run) + SlurmRayClient(num_gpus=4, num_cpus=16)._run_as_worker("10.0.0.1") + assert len(captured) == 1 + assert "--num-gpus" in captured[0] + assert "--num-cpus" in captured[0] + + def test_nonzero_exit_propagated(self, monkeypatch: pytest.MonkeyPatch) -> None: + import nemo_curator.core.client as _client + + monkeypatch.setattr("shutil.which", lambda _: "/usr/bin/ray") + monkeypatch.setattr( + _client.subprocess, + "run", + lambda _cmd, **_kw: subprocess.CompletedProcess(args=[], returncode=1), + ) + assert SlurmRayClient()._run_as_worker("10.0.0.1") == 1 + + +# --------------------------------------------------------------------------- # +# _cleanup_local_ray +# --------------------------------------------------------------------------- # + + +class TestCleanupLocalRay: + def test_calls_ray_stop(self, monkeypatch: pytest.MonkeyPatch) -> None: + import nemo_curator.core.client as _client + + calls: list[list[str]] = [] + monkeypatch.setattr("shutil.which", lambda _: "/usr/bin/ray") + monkeypatch.setattr(_client.subprocess, "run", lambda cmd, **_kw: calls.append(cmd)) + SlurmRayClient()._cleanup_local_ray() + assert any("stop" in c for c in calls[0]) + + def test_suppresses_errors(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Even if ray binary is missing, _cleanup_local_ray must not raise.""" + monkeypatch.setattr("shutil.which", lambda _: None) + monkeypatch.setattr("os.path.isfile", lambda _: False) + SlurmRayClient()._cleanup_local_ray() # should not 
raise + + +# --------------------------------------------------------------------------- # +# _expand_slurm_nodelist (additional edge cases) +# --------------------------------------------------------------------------- # + + +class TestExpandSlurmNodelistEdgeCases: + def test_scontrol_error_falls_back_to_parser(self, monkeypatch: pytest.MonkeyPatch) -> None: + import nemo_curator.core.client as _client + + def raise_error(*_a: object, **_kw: object) -> NoReturn: + raise subprocess.CalledProcessError(1, "scontrol") + + monkeypatch.setattr("shutil.which", lambda _: "/usr/bin/scontrol") + monkeypatch.setattr(_client.subprocess, "run", raise_error) + assert _expand_slurm_nodelist("node-001") == ["node-001"] + + def test_scontrol_empty_output_falls_back(self, monkeypatch: pytest.MonkeyPatch) -> None: + import nemo_curator.core.client as _client + + monkeypatch.setattr("shutil.which", lambda _: "/usr/bin/scontrol") + monkeypatch.setattr( + _client.subprocess, + "run", + lambda *_a, **_kw: subprocess.CompletedProcess(args=[], returncode=0, stdout=""), + ) + assert _expand_slurm_nodelist("node-001") == ["node-001"] + + +# --------------------------------------------------------------------------- # +# _ray_init_with_timeout +# --------------------------------------------------------------------------- # + + +def _inject_fake_ray( + monkeypatch: pytest.MonkeyPatch, + init_fn: Callable[..., Any] | None = None, + nodes_fn: Callable[[], list[Any]] | None = None, +) -> None: + import sys + import types + + fake_ray = types.ModuleType("ray") + fake_ray.init = init_fn or (lambda *_a, **_kw: None) + fake_ray.nodes = nodes_fn or list + fake_ray.shutdown = lambda: None + monkeypatch.setitem(sys.modules, "ray", fake_ray) + + +class TestRayInitWithTimeout: + def test_main_thread_calls_ray_init(self, monkeypatch: pytest.MonkeyPatch) -> None: + initted: list[str] = [] + _inject_fake_ray(monkeypatch, init_fn=lambda address, **_kw: initted.append(address)) + SlurmRayClient._ray_init_with_timeout("127.0.0.1:6379", timeout_s=10) + assert initted == ["127.0.0.1:6379"] + + def test_non_main_thread_skips_sigalrm(self, monkeypatch: pytest.MonkeyPatch) -> None: + initted: list[str] = [] + _inject_fake_ray(monkeypatch, init_fn=lambda address, **_kw: initted.append(address)) + results: list[str] = [] + + def _run() -> None: + SlurmRayClient._ray_init_with_timeout("127.0.0.1:6379", timeout_s=5) + results.append("done") + + t = threading.Thread(target=_run) + t.start() + t.join(timeout=10) + assert results == ["done"] + assert initted == ["127.0.0.1:6379"] + + +# --------------------------------------------------------------------------- # +# _wait_for_workers +# --------------------------------------------------------------------------- # + + +class TestWaitForWorkers: + def test_success_all_nodes_connected(self, monkeypatch: pytest.MonkeyPatch) -> None: + _inject_fake_ray( + monkeypatch, + nodes_fn=lambda: [ + {"Alive": True, "Resources": {"CPU": 4.0, "GPU": 1.0}}, + {"Alive": True, "Resources": {"CPU": 4.0, "GPU": 1.0}}, + ], + ) + monkeypatch.setenv("RAY_ADDRESS", "127.0.0.1:6379") + client = SlurmRayClient(worker_connect_timeout_s=30) + client._slurm_nodes = ["node1", "node2"] + client._wait_for_workers() # must not raise + + def test_timeout_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: + _inject_fake_ray(monkeypatch, nodes_fn=list) # workers never connect + monkeypatch.setenv("RAY_ADDRESS", "127.0.0.1:6379") + monkeypatch.setattr(time, "sleep", lambda _: None) + client = 
SlurmRayClient(worker_connect_timeout_s=0) + client._slurm_nodes = ["node1", "node2"] + with pytest.raises(TimeoutError, match="Timed out"): + client._wait_for_workers() + + def test_partial_nodes_then_all_connected(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Workers join across multiple polling iterations.""" + call_count = [0] + + def nodes_fn() -> list[dict[str, Any]]: + call_count[0] += 1 + if call_count[0] < 3: + return [{"Alive": True, "Resources": {}}] + return [{"Alive": True, "Resources": {}}, {"Alive": True, "Resources": {}}] + + _inject_fake_ray(monkeypatch, nodes_fn=nodes_fn) + monkeypatch.setenv("RAY_ADDRESS", "127.0.0.1:6379") + monkeypatch.setattr(time, "sleep", lambda _: None) + client = SlurmRayClient(worker_connect_timeout_s=60) + client._slurm_nodes = ["node1", "node2"] + client._wait_for_workers() + assert call_count[0] >= 3 + + +# --------------------------------------------------------------------------- # +# SlurmRayClient.stop (manages_cluster branch) +# --------------------------------------------------------------------------- # + + +class TestSlurmRayClientStopManagesCluster: + def test_removes_port_file(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("RAY_PORT_BROADCAST_DIR", str(tmp_path)) + monkeypatch.setenv("SLURM_JOB_ID", "99999") + monkeypatch.setattr(RayClient, "stop", lambda _self: None) + + client = SlurmRayClient(ray_port=6379) + client._manages_cluster = True + client._write_head_port("99999") + port_file = client._head_port_file("99999") + assert os.path.exists(port_file) + + client.stop() + assert not os.path.exists(port_file) + + def test_no_port_file_does_not_raise(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("RAY_PORT_BROADCAST_DIR", str(tmp_path)) + monkeypatch.setenv("SLURM_JOB_ID", "88888") + monkeypatch.setattr(RayClient, "stop", lambda _self: None) + + client = SlurmRayClient() + client._manages_cluster = True + client.stop() # no port file — FileNotFoundError must be suppressed + + +# --------------------------------------------------------------------------- # +# SlurmRayClient.start (additional branches) +# --------------------------------------------------------------------------- # + + +class TestSlurmRayClientStartBranches: + def test_ray_address_already_set_delegates_to_super(self, monkeypatch: pytest.MonkeyPatch) -> None: + super_calls: list[str] = [] + monkeypatch.setattr(RayClient, "start", lambda _self: super_calls.append("start")) + monkeypatch.setenv("RAY_ADDRESS", "127.0.0.1:6379") + SlurmRayClient().start() + assert super_calls == ["start"] + + def test_head_node_multi_node_calls_write_and_wait(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """Head node (SLURM_NODEID=0) with 2 nodes: must write port and wait for workers.""" + monkeypatch.setenv("SLURM_JOB_ID", "55555") + monkeypatch.setenv("SLURM_JOB_NODELIST", "node-[001-002]") + monkeypatch.setenv("SLURM_NODEID", "0") + monkeypatch.delenv("RAY_ADDRESS", raising=False) + + super_starts: list[str] = [] + + def fake_super_start(_self: object) -> None: + super_starts.append("start") + os.environ["RAY_ADDRESS"] = "10.0.0.1:6379" + + wrote: list[str] = [] + waited: list[bool] = [] + monkeypatch.setattr(RayClient, "start", fake_super_start) + monkeypatch.setattr(SlurmRayClient, "_cleanup_local_ray", lambda _self: None) + monkeypatch.setattr(SlurmRayClient, "_write_head_port", lambda _self, jid: wrote.append(jid)) + monkeypatch.setattr(SlurmRayClient, "_wait_for_workers", lambda 
_self: waited.append(True))
+
+        SlurmRayClient(cleanup_on_start=True).start()
+
+        assert super_starts == ["start"]
+        assert wrote == ["55555"]
+        assert waited == [True]
+
+    def test_worker_node_calls_sys_exit(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+        """Worker node (SLURM_NODEID=1) must call sys.exit with the worker's return code."""
+        monkeypatch.setenv("SLURM_JOB_ID", "55556")
+        monkeypatch.setenv("SLURM_JOB_NODELIST", "node-[001-002]")
+        monkeypatch.setenv("SLURM_NODEID", "1")
+        monkeypatch.delenv("RAY_ADDRESS", raising=False)
+
+        import nemo_curator.core.client as _client
+
+        monkeypatch.setattr(SlurmRayClient, "_cleanup_local_ray", lambda _self: None)
+        monkeypatch.setattr(_client.socket, "gethostbyname", lambda _: "10.0.0.1")
+        monkeypatch.setattr(SlurmRayClient, "_read_head_port", lambda _self, _jid, **_kw: 6379)
+        monkeypatch.setattr(SlurmRayClient, "_run_as_worker", lambda _self, _ip: 0)
+
+        with pytest.raises(SystemExit) as exc:
+            SlurmRayClient(cleanup_on_start=True).start()
+        assert exc.value.code == 0
diff --git a/tutorials/slurm/README.md b/tutorials/slurm/README.md
new file mode 100644
index 0000000000..05565de2a0
--- /dev/null
+++ b/tutorials/slurm/README.md
@@ -0,0 +1,289 @@
+# Running NeMo Curator on SLURM
+
+This tutorial shows how to scale a NeMo Curator pipeline from a single laptop to a multi-node SLURM cluster with a **one-line change**.
+
+## Contents
+
+| File | Purpose |
+|------|---------|
+| `pipeline.py` | A simple CPU-only pipeline (word-count + node-tag) that runs locally or on SLURM |
+| `submit.sh` | `sbatch` script for bare-metal clusters with a shared virtualenv |
+| `submit_container.sh` | `sbatch` script using the official NGC container (Pyxis/enroot) |
+
+---
+
+## The key concept: RayClient vs SlurmRayClient
+
+NeMo Curator uses a `RayClient` to manage the Ray cluster lifecycle. The `SlurmRayClient` is a drop-in replacement that handles the multi-process SLURM model automatically.
+
+```python
+# Local development — Ray starts on the current machine
+ray_client = RayClient()
+
+# SLURM multi-node — Ray spans all allocated nodes automatically
+ray_client = SlurmRayClient()
+
+# One-liner to auto-detect the environment:
+ray_client = SlurmRayClient() if os.environ.get("SLURM_JOB_ID") else RayClient()
+```
+
+That is the **only change** needed to go from a local run to a distributed SLURM job. Everything else — pipeline stages, executor, `pipeline.run()` — is identical.
+
+### How SlurmRayClient works
+
+When `srun` launches one Python process per node, `SlurmRayClient.start()` behaves differently on each node:
+
+```
+srun --ntasks-per-node=1 python pipeline.py --slurm
+  │
+  ├─ Node 0 (SLURM_NODEID=0) — HEAD
+  │    start() → ray start --head
+  │            → writes GCS port to shared file
+  │            → waits for all workers to join
+  │            → returns          ← pipeline runs here
+  │
+  ├─ Node 1 — WORKER
+  │    start() → reads port file from Node 0
+  │            → ray start --block --address=<head_ip>:<port>
+  │            → blocks here (serving Ray tasks)
+  │
+  └─ Node N — WORKER (same as Node 1)
+```
+
+Worker nodes never return from `start()`. They serve Ray remote tasks dispatched by the Xenna executor running on the head. When `ray_client.stop()` is called on the head, the `ray stop` signal propagates and worker `srun` tasks exit.
+
+---
+
+## Quick start — local run
+
+No SLURM needed. This is useful for iterating on pipeline logic.
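+
+Under the hood, `pipeline.py` boils down to the pattern below (a condensed sketch of its `main()`; the stage definitions and `add_stage` calls are omitted here):
+
+```python
+import os
+
+from nemo_curator.backends.xenna import XennaExecutor
+from nemo_curator.core.client import RayClient, SlurmRayClient
+from nemo_curator.pipeline import Pipeline
+
+# One-line client switch: SlurmRayClient under SLURM, plain RayClient otherwise.
+ray_client = SlurmRayClient() if os.environ.get("SLURM_JOB_ID") else RayClient()
+ray_client.start()  # on SLURM worker nodes this call never returns
+
+try:
+    pipeline = Pipeline(name="demo", description="minimal skeleton")
+    # pipeline.add_stage(...)  # add your stages here
+    executor = XennaExecutor(config={"execution_mode": "streaming"})
+    results = pipeline.run(executor=executor)
+finally:
+    ray_client.stop()
+```
+
+To install NeMo Curator and run the bundled demo: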
+```bash
+# Install NeMo Curator
+pip install nemo-curator
+
+# Run locally (RayClient, single machine)
+python tutorials/slurm/pipeline.py
+
+# Expected output:
+# Tasks processed by 1 distinct node(s): ['your-hostname']
+```
+
+---
+
+## SLURM run — NGC container (Pyxis/enroot)
+
+This is the recommended approach on clusters that support it. The official NeMo Curator image from NGC provides a stable Python environment; the local virtualenv (on your shared filesystem) is activated inside the container to pick up any unreleased code from your checkout.
+
+### Prerequisites
+
+Check that your cluster has the Pyxis SLURM plugin:
+
+```bash
+srun --help | grep container-image
+# Should print: --container-image=...
+```
+
+If this flag is missing, ask your cluster admin or see the [bare-metal section](#slurm-run--bare-metal-shared-virtualenv) below.
+
+### 1. Build the virtualenv on a shared filesystem
+
+```bash
+# From the NeMo Curator root on a login node (or wherever the shared FS is mounted)
+python -m venv .venv
+source .venv/bin/activate
+pip install -e .
+```
+
+### 2. Submit the job
+
+```bash
+# Default: 2 nodes, 2 GPUs each, nvcr.io/nvidia/nemo-curator:26.02
+sbatch tutorials/slurm/submit_container.sh
+
+# Override container image
+export CONTAINER_IMAGE="nvcr.io/nvidia/nemo-curator:25.06"
+sbatch tutorials/slurm/submit_container.sh
+
+# Override mounts (default: /lustre:/lustre)
+export CONTAINER_MOUNTS="/scratch:/scratch,/data:/data"
+sbatch tutorials/slurm/submit_container.sh
+```
+
+Override resources without editing the script:
+
+```bash
+sbatch --nodes=1 --gpus-per-node=8 tutorials/slurm/submit_container.sh
+sbatch --nodes=4 --cpus-per-task=32 --time=00:30:00 tutorials/slurm/submit_container.sh
+```
+
+### 3. Check the output
+
+```bash
+tail -f logs/slurm_demo_container_<job_id>.log
+```
+
+On a 2-node run you should see both hostnames in the processed-by summary:
+
+```
+Tasks processed by 2 distinct node(s):
+  node-001: 2 GPU(s): NVIDIA A100-SXM4-80GB, 81251 MiB; NVIDIA A100-SXM4-80GB, 81251 MiB
+  node-002: 2 GPU(s): NVIDIA A100-SXM4-80GB, 81251 MiB; NVIDIA A100-SXM4-80GB, 81251 MiB
+```
+
+### Singularity / Apptainer
+
+If your cluster uses Singularity or Apptainer instead of Pyxis:
+
+```bash
+# Pull the image once (on the login node)
+singularity pull nemo-curator.sif docker://nvcr.io/nvidia/nemo-curator:26.02
+
+# In your sbatch script, replace the srun flags with:
+srun singularity exec \
+    --nv \
+    --bind /lustre:/lustre \
+    nemo-curator.sif \
+    bash -c "source /path/to/Curator/.venv/bin/activate && python pipeline.py --slurm"
+```
+
+---
+
+## SLURM run — bare metal (shared virtualenv)
+
+Use this if your cluster does not have a container runtime.
+
+### 1. Install on shared filesystem
+
+Build a virtualenv on a **shared filesystem** (Lustre, NFS, GPFS) so every node sees the same Python environment:
+
+```bash
+# On the login node, from the NeMo Curator root
+python -m venv .venv
+source .venv/bin/activate
+pip install -e .
+```
+
+### 2. Submit the job
+
+```bash
+sbatch tutorials/slurm/submit.sh
+```
+
+Override resources without editing the script:
+
+```bash
+sbatch --nodes=4 --cpus-per-task=32 --time=00:30:00 tutorials/slurm/submit.sh
+```
+
+### 3. Check the output
+
+```bash
+tail -f logs/slurm_demo_<job_id>.log
+```
+
+---
+
+## Configuration reference
+
+### SlurmRayClient parameters
+
+```python
+SlurmRayClient(
+    # Ray GCS port — defaults to a random free port
+    ray_port=6379,
+
+    # Shared directory for Ray temp files (logs, sockets)
+    # Must be visible to all nodes
+    ray_temp_dir="/tmp/ray",
+
+    # Resource overrides (auto-detected from SLURM env vars if not set)
+    num_gpus=8,   # GPUs per node
+    num_cpus=64,  # CPUs per node
+
+    # How long to wait for all worker nodes to join (seconds)
+    worker_connect_timeout_s=300,
+)
+```
+
+### Environment variables
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `RAY_PORT_BROADCAST_DIR` | `/tmp` | Directory for the port-broadcast file. **Set to a shared filesystem path when `/tmp` is not shared across nodes.** |
+| `RAY_TMPDIR` | `/tmp/ray` | Ray temp directory. Recommend setting to `/tmp/ray_${SLURM_JOB_ID}` to avoid cross-job collisions. |
+| `SLURM_JOB_ID` | set by SLURM | Used to name the port-broadcast file. Set manually if testing outside SLURM. |
+
+> **Important**: If your cluster's `/tmp` is local to each node (the common case), set
+> `RAY_PORT_BROADCAST_DIR` to a Lustre/NFS path so all nodes can read the port file:
+>
+> ```bash
+> export RAY_PORT_BROADCAST_DIR=/lustre/my-project/ray_ports
+> ```
+
+---
+
+## Adapting to your own pipeline
+
+Switching any existing pipeline from `RayClient` to `SlurmRayClient` is the same one-line change shown in `pipeline.py`:
+
+```python
+# Before (local only):
+from nemo_curator.core.client import RayClient
+ray_client = RayClient()
+
+# After (works locally AND on SLURM):
+from nemo_curator.core.client import RayClient, SlurmRayClient
+ray_client = SlurmRayClient() if os.environ.get("SLURM_JOB_ID") else RayClient()
+```
+
+Then launch the script with `srun` so one process starts per node:
+
+```bash
+# In your sbatch script:
+srun --ntasks-per-node=1 python my_pipeline.py
+```
+
+No other changes to stages, executor, or pipeline logic are required.
+
+---
+
+## Troubleshooting
+
+**Workers not joining the cluster**
+
+The most common cause is that `/tmp` is node-local, so workers cannot read the port file written by the head. Fix:
+
+```bash
+export RAY_PORT_BROADCAST_DIR=/shared/filesystem/path
+```
+
+**`TimeoutError: ray.init timed out`**
+
+The GCS port file exists but `ray.init()` hung. This usually means a firewall is blocking inter-node communication. Verify that the GCS port (default: random in 20000–30000) is open between nodes, or pin a known-open port:
+
+```python
+SlurmRayClient(ray_port=6379)
+```
+
+**Jobs finish too quickly / no tasks processed**
+
+Ensure `--num-tasks` is larger than the number of workers × 2, otherwise all tasks may be completed before workers connect. The script will warn you:
+
+```
+Job allocated 2 nodes but only 1 node(s) processed tasks.
+Check that --num-tasks is large enough to distribute across all workers.
+```
+
+**Container image not found**
+
+```bash
+# Pull manually and verify
+docker pull nvcr.io/nvidia/nemo-curator:26.02
+# or with enroot:
+enroot import docker://nvcr.io/nvidia/nemo-curator:26.02
+```
+
+**`ImportError: cannot import name 'SlurmRayClient'`**
+
+The container image has an older NeMo Curator without `SlurmRayClient`. Activating the local virtualenv (`source .venv/bin/activate`) inside the container overrides the container's installed version with your local checkout. Make sure the virtualenv was built from a source tree that includes `SlurmRayClient`.
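+
+**Verifying that all nodes joined**
+
+To inspect cluster membership from the head node while the cluster is up, a quick check can use the same `ray.nodes()` call that `SlurmRayClient._wait_for_workers` relies on internally (a sketch):
+
+```python
+import ray
+
+ray.init(address="auto")  # attach to the running cluster
+for node in ray.nodes():
+    if node.get("Alive"):
+        resources = node.get("Resources", {})
+        print(node.get("NodeManagerHostname"), resources.get("CPU", 0), resources.get("GPU", 0))
+ray.shutdown()
+```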
diff --git a/tutorials/slurm/pipeline.py b/tutorials/slurm/pipeline.py new file mode 100644 index 0000000000..be103f8639 --- /dev/null +++ b/tutorials/slurm/pipeline.py @@ -0,0 +1,254 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simple pipeline showing RayClient vs SlurmRayClient. + +The pipeline is intentionally CPU-only and dependency-free so the focus +stays on the client switch rather than model setup. + +Stages: + 1. TaskCreationStage (_EmptyTask -> list[SampleTask]) + Generates ``num_tasks`` tasks, each holding a small DataFrame of sentences. + + 2. WordCountStage (SampleTask -> SampleTask) + Adds a ``word_count`` column to each task. + + 3. NodeTagStage (SampleTask -> SampleTask) + Records which Ray-worker hostname processed the task. + On a multi-node SLURM job this column will show different hostnames, + proving that work is genuinely distributed. + +Usage:: + + # Local (single-node): + python pipeline.py + + # SLURM (multi-node) — called via srun inside submit.sh: + python pipeline.py --slurm + + # Limit tasks for a quick smoke test: + python pipeline.py --num-tasks 4 +""" + +from __future__ import annotations + +import argparse +import os +import random +import socket +from dataclasses import field + +import pandas as pd +from loguru import logger + +from nemo_curator.backends.xenna import XennaExecutor +from nemo_curator.core.client import RayClient, SlurmRayClient +from nemo_curator.pipeline import Pipeline +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.tasks import Task, _EmptyTask + +# --------------------------------------------------------------------------- +# Sample data +# --------------------------------------------------------------------------- + +SAMPLE_SENTENCES = [ + "The quick brown fox jumps over the lazy dog", + "NeMo Curator scales data curation across many nodes", + "SLURM manages workload scheduling on HPC clusters", + "Ray distributes Python workloads transparently", + "GPU acceleration dramatically speeds up deep learning", + "Data quality is critical for training large language models", + "Distributed systems require careful coordination", + "Multimodal AI combines text, image, and audio understanding", +] + + +# --------------------------------------------------------------------------- +# Task +# --------------------------------------------------------------------------- + + +class SampleTask(Task[pd.DataFrame]): + """A task holding a small DataFrame of sentences.""" + + data: pd.DataFrame = field(default_factory=pd.DataFrame) + + @property + def num_items(self) -> int: + return len(self.data) + + def validate(self) -> bool: + return True + + +# --------------------------------------------------------------------------- +# Stages +# --------------------------------------------------------------------------- + + +class TaskCreationStage(ProcessingStage[_EmptyTask, SampleTask]): + """Generate ``num_tasks`` tasks, each with ``sentences_per_task`` rows.""" + + 
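+    # Fan-out: one _EmptyTask in, num_tasks SampleTask objects out.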
name: str = "TaskCreationStage" + + def __init__(self, num_tasks: int = 20, sentences_per_task: int = 5) -> None: + self.num_tasks = num_tasks + self.sentences_per_task = sentences_per_task + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return ["data"], ["sentence"] + + def process(self, _: _EmptyTask) -> list[SampleTask]: + tasks = [] + for i in range(self.num_tasks): + sentences = random.choices(SAMPLE_SENTENCES, k=self.sentences_per_task) # noqa: S311 + tasks.append( + SampleTask( + data=pd.DataFrame({"sentence": sentences}), + task_id=f"task_{i:04d}", + dataset_name="slurm_demo", + ) + ) + return tasks + + +class WordCountStage(ProcessingStage[SampleTask, SampleTask]): + """Add a ``word_count`` column — pure CPU, no dependencies.""" + + name: str = "WordCountStage" + + def inputs(self) -> tuple[list[str], list[str]]: + return ["data"], ["sentence"] + + def outputs(self) -> tuple[list[str], list[str]]: + return ["data"], ["sentence", "word_count"] + + def process(self, task: SampleTask) -> SampleTask: + task.data["word_count"] = task.data["sentence"].str.split().str.len() + return task + + +class NodeTagStage(ProcessingStage[SampleTask, SampleTask]): + """Tag each task with the hostname and GPU info of the worker that processed it. + + On a multi-node SLURM run the ``processed_by`` column will show + different hostnames, confirming tasks are spread across nodes. + ``gpu_info`` reports the GPUs visible to the Ray worker process. + """ + + name: str = "NodeTagStage" + + def inputs(self) -> tuple[list[str], list[str]]: + return ["data"], ["sentence", "word_count"] + + def outputs(self) -> tuple[list[str], list[str]]: + return ["data"], ["sentence", "word_count", "processed_by", "gpu_info"] + + def process(self, task: SampleTask) -> SampleTask: + task.data["processed_by"] = socket.gethostname() + task.data["gpu_info"] = _gpu_summary() + return task + + +def _gpu_summary() -> str: + """Return a short string describing GPUs visible to the current process.""" + try: + import subprocess + result = subprocess.run( + ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader"], # noqa: S607 + capture_output=True, text=True, timeout=10, check=False, + ) + except FileNotFoundError: + return "no GPUs (nvidia-smi not found)" + except Exception: # noqa: BLE001 + return "gpu_info unavailable" + if result.returncode == 0 and result.stdout.strip(): + gpus = [line.strip() for line in result.stdout.strip().splitlines()] + return f"{len(gpus)} GPU(s): " + "; ".join(gpus) + return "no GPUs (nvidia-smi failed)" + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def build_pipeline(num_tasks: int, sentences_per_task: int) -> Pipeline: + pipeline = Pipeline( + name="slurm_demo", + description="Word-count + node-tag pipeline — no GPU required", + ) + pipeline.add_stage(TaskCreationStage(num_tasks=num_tasks, sentences_per_task=sentences_per_task)) + pipeline.add_stage(WordCountStage()) + pipeline.add_stage(NodeTagStage()) + return pipeline + + +def main() -> None: + parser = argparse.ArgumentParser(description="SLURM demo pipeline") + parser.add_argument("--slurm", action="store_true", help="Use SlurmRayClient (set when running via srun)") + parser.add_argument("--num-tasks", type=int, default=20, help="Number of tasks to generate") + parser.add_argument("--sentences-per-task", type=int, default=5) + 
args = parser.parse_args() + + # ----------------------------------------------------------------------- + # The only change needed to go from local to SLURM is this one line. + # ----------------------------------------------------------------------- + ray_client = SlurmRayClient() if args.slurm else RayClient() + ray_client.start() + # On worker nodes (SLURM_NODEID > 0), start() never returns — + # they block running the Ray daemon. Only the head continues below. + + try: + pipeline = build_pipeline(args.num_tasks, args.sentences_per_task) + logger.info(f"\n{pipeline.describe()}") + + executor = XennaExecutor(config={"execution_mode": "streaming"}) + results = pipeline.run(executor=executor) + finally: + ray_client.stop() + + if not results: + logger.warning("No results returned") + return + + logger.info(f"Completed {len(results)} tasks") + + # Show which nodes + GPUs processed tasks + node_gpu: dict[str, str] = {} + for task in results: + for _, row in task.data[["processed_by", "gpu_info"]].drop_duplicates().iterrows(): + node_gpu[row["processed_by"]] = row["gpu_info"] + + logger.info(f"Tasks processed by {len(node_gpu)} distinct node(s):") + for node, gpu in sorted(node_gpu.items()): + logger.info(f" {node}: {gpu}") + + # Print a sample result + sample = results[0].data + logger.info(f"\nSample output (task '{results[0].task_id}'):\n{sample.to_string(index=False)}") + + slurm_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES", "1")) + if slurm_nodes > 1 and len(node_gpu) < 2: # noqa: PLR2004 + logger.warning( + f"Job allocated {slurm_nodes} nodes but only {len(node_gpu)} node(s) processed tasks. " + "Check that --num-tasks is large enough to distribute across all workers." + ) + + +if __name__ == "__main__": + main() diff --git a/tutorials/slurm/submit.sh b/tutorials/slurm/submit.sh new file mode 100644 index 0000000000..897861c7c0 --- /dev/null +++ b/tutorials/slurm/submit.sh @@ -0,0 +1,79 @@ +#!/bin/bash +# ============================================================================= +# NeMo Curator — SLURM submit script (bare-metal, using uv) +# +# Runs the slurm demo pipeline across multiple nodes using SlurmRayClient. +# Uses `uv run` to execute with the correct project dependencies without +# requiring a system Python installation on compute nodes. +# +# Prerequisites: +# - uv installed (https://docs.astral.sh/uv/getting-started/installation/) +# - NeMo Curator source checked out on a shared filesystem +# - Shared filesystem accessible from all nodes (e.g. Lustre, NFS) +# +# If your cluster has Pyxis/enroot, prefer submit_container.sh instead — +# it uses the official NGC container and is the recommended approach. 
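+#
+# Note: the demo pipeline is CPU-only; the GPU request below only lets the
+# demo report which GPUs each node exposes, and can be reduced or removed.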
+# +# Usage: +# sbatch tutorials/slurm/submit.sh +# +# Override resources without editing this file: +# sbatch --nodes=1 --gpus-per-node=2 tutorials/slurm/submit.sh +# sbatch --nodes=1 --gpus-per-node=8 tutorials/slurm/submit.sh +# sbatch --nodes=2 --gpus-per-node=2 tutorials/slurm/submit.sh +# sbatch --nodes=2 --gpus-per-node=8 tutorials/slurm/submit.sh +# ============================================================================= + +#SBATCH --job-name=curator-slurm-demo +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=16 +#SBATCH --gpus-per-node=2 +#SBATCH --time=00:10:00 +#SBATCH --output=logs/slurm_demo_%j.log +#SBATCH --error=logs/slurm_demo_%j.log + +set -euo pipefail + +# --------------------------------------------------------------------------- +# Paths — adjust to your environment +# --------------------------------------------------------------------------- +CURATOR_DIR="${CURATOR_DIR:-$(cd "$(dirname "$0")/../.." && pwd)}" + +# Shared directory for Ray port broadcast — must be visible to ALL nodes. +# On most clusters /tmp is node-local, so we use a path on the shared FS. +export RAY_PORT_BROADCAST_DIR="${CURATOR_DIR}/logs" +export RAY_TMPDIR="/tmp/ray_${SLURM_JOB_ID}" + +# uv cache — set to a shared location to avoid re-downloading on each node +export UV_CACHE_DIR="${UV_CACHE_DIR:-${HOME}/.cache/uv}" + +echo "==================================================" +echo " NeMo Curator — SLURM Demo" +echo "==================================================" +echo " Job ID : ${SLURM_JOB_ID}" +echo " Nodes : ${SLURM_JOB_NODELIST} (${SLURM_JOB_NUM_NODES} nodes)" +echo " GPUs/node : ${SLURM_GPUS_ON_NODE:-none}" +echo " CPUs/node : ${SLURM_CPUS_ON_NODE:-N/A}" +echo " Dir : ${CURATOR_DIR}" +echo "==================================================" + +mkdir -p logs + +srun \ + --ntasks-per-node=1 \ + bash -c " +cd '${CURATOR_DIR}' +export RAY_TMPDIR=/tmp/ray_\${SLURM_JOB_ID} +export RAY_PORT_BROADCAST_DIR='${CURATOR_DIR}/logs' +echo \"[\$(hostname)] SLURM_NODEID=\${SLURM_NODEID} python=\$(uv run python --version 2>&1)\" +nvidia-smi --query-gpu=index,name,memory.total --format=csv,noheader 2>/dev/null \ + | sed \"s/^/ [\$(hostname)] GPU /\" || echo \" [\$(hostname)] no GPUs\" +uv run python '${CURATOR_DIR}/tutorials/slurm/pipeline.py' \ + --slurm \ + --num-tasks 80 +" + +echo "==================================================" +echo " DONE" +echo "==================================================" diff --git a/tutorials/slurm/submit_container.sh b/tutorials/slurm/submit_container.sh new file mode 100644 index 0000000000..2d023db934 --- /dev/null +++ b/tutorials/slurm/submit_container.sh @@ -0,0 +1,93 @@ +#!/bin/bash +# ============================================================================= +# NeMo Curator — SLURM submit script (NGC container via Pyxis/enroot) +# +# Runs the slurm demo pipeline inside the official NeMo Curator container +# using the Pyxis SLURM plugin, with the local Curator virtualenv activated +# so that the latest (unreleased) code is used. +# +# This mirrors the pattern used for the Nemotron-Parse PDF pipeline. +# +# Prerequisites: +# - Pyxis plugin installed on the cluster (check: srun --help | grep container) +# - NeMo Curator source checked out on a shared filesystem (Lustre / NFS) +# - A virtualenv built from that source: python -m venv .venv && pip install -e . 
+# - The shared filesystem mounted at the same path inside the container
+#
+# Usage:
+#   sbatch tutorials/slurm/submit_container.sh
+#
+# Override resources without editing this file:
+#   sbatch --nodes=1 --gpus-per-node=2 tutorials/slurm/submit_container.sh
+#   sbatch --nodes=1 --gpus-per-node=8 tutorials/slurm/submit_container.sh
+#   sbatch --nodes=2 --gpus-per-node=2 tutorials/slurm/submit_container.sh
+#   sbatch --nodes=2 --gpus-per-node=8 tutorials/slurm/submit_container.sh
+# =============================================================================
+
+#SBATCH --job-name=curator-slurm-demo-container
+#SBATCH --nodes=2
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=16
+#SBATCH --gpus-per-node=2
+#SBATCH --time=00:10:00
+#SBATCH --output=logs/slurm_demo_container_%j.log
+#SBATCH --error=logs/slurm_demo_container_%j.log
+
+set -euo pipefail
+
+# ---------------------------------------------------------------------------
+# Paths — adjust to your environment
+# ---------------------------------------------------------------------------
+CURATOR_DIR="${CURATOR_DIR:-$(cd "$(dirname "$0")/../.." && pwd)}"
+
+# Official NeMo Curator container from NGC.
+# Browse available tags: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo-curator
+CONTAINER_IMAGE="${CONTAINER_IMAGE:-nvcr.io/nvidia/nemo-curator:26.02}"
+
+# Mount the shared filesystem that contains your code and data.
+# Format: <host_path>:<container_path>[,<host_path>:<container_path>]
+CONTAINER_MOUNTS="${CONTAINER_MOUNTS:-/lustre:/lustre}"
+
+# Shared directory for Ray port broadcast — must be visible to ALL nodes.
+# On most clusters /tmp is node-local, so we use a Lustre path here.
+# Adjust to any shared filesystem path accessible from every compute node.
+export RAY_PORT_BROADCAST_DIR="${CURATOR_DIR}/logs"
+
+echo "=================================================="
+echo " NeMo Curator — SLURM Demo (container)"
+echo "=================================================="
+echo " Job ID    : ${SLURM_JOB_ID}"
+echo " Nodes     : ${SLURM_JOB_NODELIST} (${SLURM_JOB_NUM_NODES} nodes)"
+echo " GPUs/node : ${SLURM_GPUS_ON_NODE:-none}"
+echo " Container : ${CONTAINER_IMAGE}"
+echo " Mounts    : ${CONTAINER_MOUNTS}"
+echo " Dir       : ${CURATOR_DIR}"
+echo "=================================================="
+
+mkdir -p logs
+
+srun \
+    --ntasks-per-node=1 \
+    --container-image="${CONTAINER_IMAGE}" \
+    --container-mounts="${CONTAINER_MOUNTS}" \
+    --container-workdir="${CURATOR_DIR}" \
+    bash -c "
+export RAY_TMPDIR=/tmp/ray_\${SLURM_JOB_ID}
+export RAY_PORT_BROADCAST_DIR='${CURATOR_DIR}/logs'
+
+# Activate the local virtualenv so the latest Curator code (from this
+# checkout) is used instead of the version bundled in the container image.
+source '${CURATOR_DIR}/.venv/bin/activate'
+
+echo \"[\$(hostname)] SLURM_NODEID=\${SLURM_NODEID} python=\$(python --version 2>&1)\"
+nvidia-smi --query-gpu=index,name,memory.total --format=csv,noheader 2>/dev/null \
+    | sed \"s/^/  [\$(hostname)] GPU /\" || echo \"  [\$(hostname)] no GPUs\"
+
+python '${CURATOR_DIR}/tutorials/slurm/pipeline.py' \
+    --slurm \
+    --num-tasks 80
+"
+
+echo "=================================================="
+echo " DONE"
+echo "=================================================="