Skip to content

Commit 7640137

Browse files
authored
Honor executor srun_args for Ray COMMAND srun (#440)
* feat: support container-image None in slurm Signed-off-by: Hemil Desai <hemild@nvidia.com> * fix Signed-off-by: Hemil Desai <hemild@nvidia.com> * Honor executor srun_args for Ray command srun Signed-off-by: Hemil Desai <hemild@nvidia.com> * Fix ray srun formatting and cover heterogeneous command args Signed-off-by: Hemil Desai <hemild@nvidia.com> --------- Signed-off-by: Hemil Desai <hemild@nvidia.com>
1 parent c7aed4f commit 7640137

3 files changed

Lines changed: 79 additions & 1 deletion

File tree

nemo_run/run/ray/slurm.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,19 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
278278

279279
return " ".join(_srun_flags)
280280

281+
def get_command_srun_args() -> str:
282+
if (
283+
self.executor.run_as_group
284+
and self.executor.heterogeneous
285+
and self.executor.resource_group
286+
and self.executor.resource_group[0].srun_args is not None
287+
):
288+
command_srun_args = self.executor.resource_group[0].srun_args
289+
else:
290+
command_srun_args = self.executor.srun_args or []
291+
292+
return " ".join(shlex.quote(arg) for arg in command_srun_args)
293+
281294
ray_log_prefix = job_details.ray_log_prefix
282295
vars_to_fill = {
283296
"sbatch_flags": sbatch_flags,
@@ -296,6 +309,7 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
296309
"ray_log_prefix": ray_log_prefix,
297310
"heterogeneous": self.executor.heterogeneous,
298311
"resource_group": self.executor.resource_group if self.executor.heterogeneous else [],
312+
"command_srun_args": get_command_srun_args(),
299313
}
300314

301315
if self.command_groups:

nemo_run/run/ray/templates/ray.sub.j2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ COMMAND="${COMMAND:-{{ command | default('', true) }}}"
454454
COMMAND_WORKDIR={{ command_workdir | default('$CONTAINER_CWD') }}
455455
456456
if [[ -n "$COMMAND" ]]; then
457-
srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}job.log bash -c "$COMMAND"
457+
srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home --gpus=0 --overlap {% if command_srun_args %}{{ command_srun_args }} {% endif %}--container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}job.log bash -c "$COMMAND"
458458
else
459459
echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:"
460460
cat <<EOF >$CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh

test/run/ray/test_slurm_ray_request.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,70 @@ def test_command_groups_without_resource_group(self):
627627
assert "--overlap" in script
628628
assert "cmd1" in script # Second command in the list (index 1)
629629

630+
def test_command_srun_honors_executor_srun_args(self):
631+
"""Test that the COMMAND launch srun includes executor srun_args."""
632+
executor = SlurmExecutor(account="test_account", srun_args=["--mpi=pmix"])
633+
executor.tunnel = Mock(spec=SSHTunnel)
634+
executor.tunnel.job_dir = "/tmp/test_jobs"
635+
636+
request = SlurmRayRequest(
637+
name="test-ray-cluster",
638+
cluster_dir="/tmp/test_jobs/test-ray-cluster",
639+
template_name="ray.sub.j2",
640+
executor=executor,
641+
command="echo hello",
642+
launch_cmd=["sbatch", "--parsable"],
643+
)
644+
645+
script = request.materialize()
646+
assert "--gpus=0 --overlap --mpi=pmix --container-name=ray-head" in script
647+
648+
def test_command_srun_honors_head_resource_group_srun_args(self):
649+
"""Test that heterogeneous grouped runs use head resource-group srun_args for COMMAND."""
650+
executor = SlurmExecutor(
651+
account="test_account",
652+
heterogeneous=True,
653+
srun_args=["--mpi=none"],
654+
)
655+
executor.run_as_group = True
656+
executor.resource_group = [
657+
SlurmExecutor.ResourceRequest(
658+
packager=Mock(),
659+
nodes=1,
660+
ntasks_per_node=1,
661+
container_image="image1",
662+
container_mounts=["/data:/data"],
663+
srun_args=["--mpi=pmix"],
664+
het_group_index=0,
665+
),
666+
SlurmExecutor.ResourceRequest(
667+
packager=Mock(),
668+
nodes=1,
669+
ntasks_per_node=1,
670+
container_image="image2",
671+
container_mounts=["/data:/data"],
672+
het_group_index=1,
673+
),
674+
]
675+
executor.tunnel = Mock(spec=SSHTunnel)
676+
executor.tunnel.job_dir = "/tmp/test_jobs"
677+
678+
request = SlurmRayRequest(
679+
name="test-ray-cluster",
680+
cluster_dir="/tmp/test_jobs/test-ray-cluster",
681+
template_name="ray.sub.j2",
682+
executor=executor,
683+
command="echo hello",
684+
launch_cmd=["sbatch", "--parsable"],
685+
)
686+
687+
script = request.materialize()
688+
assert (
689+
"--het-group=0 --no-container-mount-home --gpus=0 --overlap --mpi=pmix "
690+
"--container-name=ray-head" in script
691+
)
692+
assert "--gpus=0 --overlap --mpi=none --container-name=ray-head" not in script
693+
630694
def test_env_vars_formatting(self):
631695
"""Test that environment variables are properly formatted as export statements."""
632696
executor = SlurmExecutor(

0 commit comments

Comments
 (0)