Skip to content

Commit 911edd9

Browse files
committed
Fix ray srun formatting and cover heterogeneous command args
Signed-off-by: Hemil Desai <hemild@nvidia.com>
1 parent 2af4f5e commit 911edd9

2 files changed

Lines changed: 52 additions & 6 deletions

File tree

nemo_run/run/ray/slurm.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,9 +1271,9 @@ def start(
12711271
if isinstance(self.executor.tunnel, SSHTunnel):
12721272
# Rsync workdir honouring .gitignore
12731273
self.executor.tunnel.connect()
1274-
assert (
1275-
self.executor.tunnel.session is not None
1276-
), "Tunnel session is not connected"
1274+
assert self.executor.tunnel.session is not None, (
1275+
"Tunnel session is not connected"
1276+
)
12771277
rsync(
12781278
self.executor.tunnel.session,
12791279
workdir,
@@ -1328,9 +1328,9 @@ def start(
13281328

13291329
if isinstance(self.executor.tunnel, SSHTunnel):
13301330
self.executor.tunnel.connect()
1331-
assert (
1332-
self.executor.tunnel.session is not None
1333-
), "Tunnel session is not connected"
1331+
assert self.executor.tunnel.session is not None, (
1332+
"Tunnel session is not connected"
1333+
)
13341334
rsync(
13351335
self.executor.tunnel.session,
13361336
os.path.join(local_code_extraction_path, ""),

test/run/ray/test_slurm_ray_request.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,52 @@ def test_command_srun_honors_executor_srun_args(self):
645645
script = request.materialize()
646646
assert "--gpus=0 --overlap --mpi=pmix --container-name=ray-head" in script
647647

648+
def test_command_srun_honors_head_resource_group_srun_args(self):
649+
"""Test that heterogeneous grouped runs use head resource-group srun_args for COMMAND."""
650+
executor = SlurmExecutor(
651+
account="test_account",
652+
heterogeneous=True,
653+
srun_args=["--mpi=none"],
654+
)
655+
executor.run_as_group = True
656+
executor.resource_group = [
657+
SlurmExecutor.ResourceRequest(
658+
packager=Mock(),
659+
nodes=1,
660+
ntasks_per_node=1,
661+
container_image="image1",
662+
container_mounts=["/data:/data"],
663+
srun_args=["--mpi=pmix"],
664+
het_group_index=0,
665+
),
666+
SlurmExecutor.ResourceRequest(
667+
packager=Mock(),
668+
nodes=1,
669+
ntasks_per_node=1,
670+
container_image="image2",
671+
container_mounts=["/data:/data"],
672+
het_group_index=1,
673+
),
674+
]
675+
executor.tunnel = Mock(spec=SSHTunnel)
676+
executor.tunnel.job_dir = "/tmp/test_jobs"
677+
678+
request = SlurmRayRequest(
679+
name="test-ray-cluster",
680+
cluster_dir="/tmp/test_jobs/test-ray-cluster",
681+
template_name="ray.sub.j2",
682+
executor=executor,
683+
command="echo hello",
684+
launch_cmd=["sbatch", "--parsable"],
685+
)
686+
687+
script = request.materialize()
688+
assert (
689+
"--het-group=0 --no-container-mount-home --gpus=0 --overlap --mpi=pmix "
690+
"--container-name=ray-head" in script
691+
)
692+
assert "--gpus=0 --overlap --mpi=none --container-name=ray-head" not in script
693+
648694
def test_env_vars_formatting(self):
649695
"""Test that environment variables are properly formatted as export statements."""
650696
executor = SlurmExecutor(

0 commit comments

Comments
 (0)