Skip to content

Commit b2ff9f5

Browse files
committed
Add experiment integration
Signed-off-by: Hemil Desai <hemild@nvidia.com>
1 parent 96f4bd8 commit b2ff9f5

9 files changed

Lines changed: 192 additions & 143 deletions

File tree

nemo_run/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,8 @@ class Script(ConfigurableMixin):
449449
#: Whether to use ``python -m`` when executing via python.
450450
m: bool = False
451451

452+
metadata: dict[str, Any] = dataclasses.field(default_factory=dict)
453+
452454
def __post_init__(self):
453455
assert self.path or self.inline
454456
assert self.entrypoint, "Need to provide an entrypoint for script."

nemo_run/core/execution/slurm.py

Lines changed: 49 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -748,10 +748,10 @@ def _as_sbatch_flag(key: str, value: Any) -> str:
748748

749749
@dataclass(kw_only=True)
750750
class SlurmBatchRequest:
751-
cmd: list[str]
751+
launch_cmd: list[str]
752752
jobs: list[str]
753753
command_groups: list[list[str]]
754-
slurm_config: SlurmExecutor
754+
executor: SlurmExecutor
755755
max_retries: int
756756
setup: Optional[list[str]] = None
757757
extra_env: dict[str, str]
@@ -786,7 +786,7 @@ def materialize(self) -> str:
786786
In case an erroneous keyword argument is added, a list of all eligible parameters
787787
is printed, with their default values
788788
"""
789-
args = asdict(self.slurm_config) # noqa: F821
789+
args = asdict(self.executor) # noqa: F821
790790
parameters = {
791791
k: v for k, v in args.items() if v is not None and k in SlurmExecutor.SBATCH_FLAGS
792792
}
@@ -800,18 +800,16 @@ def materialize(self) -> str:
800800
# add necessary parameters
801801
original_job_name: str = self.jobs[0] # type: ignore
802802
job_name_prefix = (
803-
self.slurm_config.job_name_prefix
804-
if self.slurm_config.job_name_prefix
805-
else f"{self.slurm_config.account}-{self.slurm_config.account.split('_')[-1]}."
803+
self.executor.job_name_prefix
804+
if self.executor.job_name_prefix
805+
else f"{self.executor.account}-{self.executor.account.split('_')[-1]}."
806806
)
807807
job_name = f"{job_name_prefix}{original_job_name}"
808808
slurm_job_dir = (
809-
self.slurm_config.tunnel.job_dir
810-
if self.slurm_config.tunnel
811-
else self.slurm_config.job_dir
809+
self.executor.tunnel.job_dir if self.executor.tunnel else self.executor.job_dir
812810
)
813-
job_directory_name = Path(self.slurm_config.job_dir).name
814-
job_details = self.slurm_config.job_details
811+
job_directory_name = Path(self.executor.job_dir).name
812+
job_details = self.executor.job_details
815813

816814
if not job_details.job_name:
817815
job_details.job_name = job_name
@@ -824,41 +822,41 @@ def materialize(self) -> str:
824822
stdout = str(job_details.stdout)
825823
stderr = str(job_details.stderr)
826824

827-
if self.slurm_config.array is not None:
825+
if self.executor.array is not None:
828826
stdout = stdout.replace("%j", "%A_%a")
829827
stderr = stderr.replace("%j", "%A_%a")
830828
parameters["output"] = stdout.replace("%t", "0")
831829

832-
if not self.slurm_config.stderr_to_stdout:
830+
if not self.executor.stderr_to_stdout:
833831
parameters["error"] = stderr.replace("%t", "0")
834832

835-
if self.slurm_config.additional_parameters is not None:
836-
parameters.update(self.slurm_config.additional_parameters)
833+
if self.executor.additional_parameters is not None:
834+
parameters.update(self.executor.additional_parameters)
837835

838836
# now create
839-
sbatch_cmd = " ".join([shlex.quote(arg) for arg in self.cmd])
837+
sbatch_cmd = " ".join([shlex.quote(arg) for arg in self.launch_cmd])
840838

841839
sbatch_flags = []
842-
if self.slurm_config.heterogeneous:
843-
assert len(self.jobs) == len(self.slurm_config.resource_group), (
844-
f"Number of jobs {len(self.jobs)} must match number of resource group requests {len(self.slurm_config.resource_group)}.\nIf you are just submitting a single job, make sure that heterogeneous=False in the executor."
840+
if self.executor.heterogeneous:
841+
assert len(self.jobs) == len(self.executor.resource_group), (
842+
f"Number of jobs {len(self.jobs)} must match number of resource group requests {len(self.executor.resource_group)}.\nIf you are just submitting a single job, make sure that heterogeneous=False in the executor."
845843
)
846-
final_group_index = len(self.slurm_config.resource_group) - 1
847-
if self.slurm_config.het_group_indices:
848-
final_group_index = self.slurm_config.het_group_indices.index(
849-
max(self.slurm_config.het_group_indices)
844+
final_group_index = len(self.executor.resource_group) - 1
845+
if self.executor.het_group_indices:
846+
final_group_index = self.executor.het_group_indices.index(
847+
max(self.executor.het_group_indices)
850848
)
851849

852-
for i in range(len(self.slurm_config.resource_group)):
853-
resource_req = self.slurm_config.resource_group[i]
850+
for i in range(len(self.executor.resource_group)):
851+
resource_req = self.executor.resource_group[i]
854852
if resource_req.het_group_index:
855-
assert self.slurm_config.resource_group[i - 1].het_group_index is not None, (
853+
assert self.executor.resource_group[i - 1].het_group_index is not None, (
856854
"het_group_index must be set for all requests in resource_group"
857855
)
858856
if (
859857
i > 0
860858
and resource_req.het_group_index
861-
== self.slurm_config.resource_group[i - 1].het_group_index
859+
== self.executor.resource_group[i - 1].het_group_index
862860
):
863861
continue
864862

@@ -887,33 +885,31 @@ def materialize(self) -> str:
887885
for k in sorted(parameters):
888886
sbatch_flags.append(_as_sbatch_flag(k, parameters[k]))
889887

890-
if self.slurm_config.dependencies:
891-
slurm_deps = self.slurm_config.parse_deps()
888+
if self.executor.dependencies:
889+
slurm_deps = self.executor.parse_deps()
892890
sbatch_flags.append(
893891
_as_sbatch_flag(
894-
"dependency", f"{self.slurm_config.dependency_type}:{':'.join(slurm_deps)}"
892+
"dependency", f"{self.executor.dependency_type}:{':'.join(slurm_deps)}"
895893
)
896894
)
897895

898896
env_vars = []
899-
full_env_vars = self.slurm_config.env_vars | self.extra_env
897+
full_env_vars = self.executor.env_vars | self.extra_env
900898
for key, value in full_env_vars.items():
901899
env_vars.append(f"export {key.upper()}={value}")
902900

903901
# commandline (this will run the function and args specified in the file provided as argument)
904902
# We pass --output and --error here, because the SBATCH command doesn't work as expected with a filename pattern
905-
stderr_flags = [] if self.slurm_config.stderr_to_stdout else ["--error", stderr]
903+
stderr_flags = [] if self.executor.stderr_to_stdout else ["--error", stderr]
906904

907905
srun_commands = []
908906
group_env_vars = []
909907
srun_stdout = noquote(job_details.srun_stdout)
910908
stderr_flags = (
911-
[]
912-
if self.slurm_config.stderr_to_stdout
913-
else ["--error", noquote(job_details.srun_stderr)]
909+
[] if self.executor.stderr_to_stdout else ["--error", noquote(job_details.srun_stderr)]
914910
)
915911
memory_measure_out = None
916-
if self.slurm_config.memory_measure:
912+
if self.executor.memory_measure:
917913
memory_measure_out = srun_stdout
918914

919915
def get_container_flags(
@@ -937,10 +933,10 @@ def get_container_flags(
937933
return _container_flags
938934

939935
for group_ind, command_group in enumerate(self.command_groups):
940-
if self.slurm_config.run_as_group and len(self.slurm_config.resource_group) == len(
936+
if self.executor.run_as_group and len(self.executor.resource_group) == len(
941937
self.command_groups
942938
):
943-
resource_req = self.slurm_config.resource_group[group_ind]
939+
resource_req = self.executor.resource_group[group_ind]
944940
if not resource_req.job_details.job_name:
945941
resource_req.job_details.job_name = f"{job_name_prefix}{self.jobs[group_ind]}"
946942

@@ -952,7 +948,7 @@ def get_container_flags(
952948
cmd_stdout = noquote(resource_req.job_details.srun_stdout)
953949
cmd_stderr = (
954950
[]
955-
if self.slurm_config.stderr_to_stdout
951+
if self.executor.stderr_to_stdout
956952
else [
957953
"--error",
958954
noquote(resource_req.job_details.srun_stderr),
@@ -980,20 +976,20 @@ def get_container_flags(
980976
if cmd_stderr:
981977
cmd_stderr[-1] = cmd_stderr[-1].replace(original_job_name, self.jobs[group_ind])
982978
_container_flags = get_container_flags(
983-
base_mounts=self.slurm_config.container_mounts,
979+
base_mounts=self.executor.container_mounts,
984980
src_job_dir=os.path.join(
985981
slurm_job_dir,
986982
job_directory_name,
987983
),
988-
container_image=self.slurm_config.container_image,
984+
container_image=self.executor.container_image,
989985
)
990986
_srun_args = ["--wait=60", "--kill-on-bad-exit=1"]
991-
_srun_args.extend(self.slurm_config.srun_args or [])
987+
_srun_args.extend(self.executor.srun_args or [])
992988

993-
if self.slurm_config.run_as_group and self.slurm_config.heterogeneous:
989+
if self.executor.run_as_group and self.executor.heterogeneous:
994990
het_group_index = (
995-
self.slurm_config.resource_group[group_ind].het_group_index
996-
if self.slurm_config.resource_group[group_ind].het_group_index is not None
991+
self.executor.resource_group[group_ind].het_group_index
992+
if self.executor.resource_group[group_ind].het_group_index is not None
997993
else group_ind
998994
)
999995
het_group_flag = [f"--het-group={het_group_index}"]
@@ -1018,10 +1014,10 @@ def get_container_flags(
10181014
)
10191015
command = " ".join(command_group)
10201016

1021-
if self.slurm_config.run_as_group:
1017+
if self.executor.run_as_group:
10221018
srun_command = f"{srun_cmd} {command} & pids[{group_ind}]=$!"
10231019
if group_ind != len(self.command_groups) - 1:
1024-
srun_command += f"\n\nsleep {self.slurm_config.wait_time_for_group_job}\n"
1020+
srun_command += f"\n\nsleep {self.executor.wait_time_for_group_job}\n"
10251021
else:
10261022
srun_command = f"{srun_cmd} {command}"
10271023

@@ -1033,15 +1029,14 @@ def get_container_flags(
10331029
"max_retries": self.max_retries,
10341030
"env_vars": env_vars,
10351031
"head_node_ip_var": SlurmExecutor.HEAD_NODE_IP_VAR,
1036-
"setup_lines": self.slurm_config.setup_lines,
1032+
"setup_lines": self.executor.setup_lines,
10371033
"memory_measure": memory_measure_out,
10381034
"srun_commands": srun_commands,
10391035
"group_env_vars": group_env_vars,
1040-
"heterogeneous": self.slurm_config.heterogeneous,
1041-
"run_as_group": self.slurm_config.run_as_group,
1042-
"monitor_group_job": self.slurm_config.run_as_group
1043-
and self.slurm_config.monitor_group_job,
1044-
"monitor_group_job_wait_time": self.slurm_config.monitor_group_job_wait_time,
1036+
"heterogeneous": self.executor.heterogeneous,
1037+
"run_as_group": self.executor.run_as_group,
1038+
"monitor_group_job": self.executor.run_as_group and self.executor.monitor_group_job,
1039+
"monitor_group_job_wait_time": self.executor.monitor_group_job_wait_time,
10451040
"het_group_host_var": SlurmExecutor.HET_GROUP_HOST_VAR,
10461041
"ft_enabled": self.launcher and isinstance(self.launcher, FaultTolerance),
10471042
}
@@ -1060,7 +1055,7 @@ def get_container_flags(
10601055
return sbatch_script
10611056

10621057
def __repr__(self) -> str:
1063-
return f"""{" ".join(self.cmd + ["$SBATCH_SCRIPT"])}
1058+
return f"""{" ".join(self.launch_cmd + ["$SBATCH_SCRIPT"])}
10641059
10651060
#----------------
10661061
# SBATCH_SCRIPT

nemo_run/core/execution/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def fill_template(template_name: str, variables: dict, template_dir: Optional[st
2525
template_dir = template_dir or os.path.join(os.path.dirname(__file__), "templates")
2626
template_path = os.path.join(template_dir, template_name)
2727
if not os.path.exists(template_path):
28-
raise FileNotFoundError(f'Template "{template_name}" does not exist.')
28+
raise FileNotFoundError(f'Template "{template_path}" does not exist.')
2929
with open(template_path, "r", encoding="utf-8") as fin:
3030
template = fin.read()
3131

nemo_run/run/ray/cluster.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from nemo_run.run.ray.kuberay import KubeRayCluster
2323
from nemo_run.run.ray.slurm import SlurmRayCluster
2424

25+
USE_WITH_RAY_CLUSTER_KEY = "use_with_ray_cluster"
26+
2527

2628
@dataclass(kw_only=True)
2729
class RayCluster:

nemo_run/run/ray/slurm.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import copy
1617
import json
1718
import logging
1819
import os
@@ -28,6 +29,7 @@
2829
from pathlib import Path
2930
from typing import Any, Dict, Optional, TypeAlias, Union
3031

32+
from nemo_run.config import RUNDIR_NAME, RUNDIR_SPECIAL_NAME
3133
from nemo_run.core.execution.slurm import SlurmExecutor, _as_sbatch_flag
3234
from nemo_run.core.execution.utils import fill_template
3335
from nemo_run.core.packaging.git import GitArchivePackager
@@ -43,11 +45,14 @@
4345
class SlurmRayRequest:
4446
name: str
4547
cluster_dir: str
46-
template_path: str
48+
template_name: str
49+
template_dir: Optional[str] = None
4750
executor: SlurmExecutor
4851
pre_ray_start_commands: Optional[list[str]] = None
4952
command: Optional[str] = None
5053
workdir: Optional[str] = None
54+
nemo_run_dir: Optional[str] = None
55+
launch_cmd: list[str]
5156

5257
@staticmethod
5358
def get_job_name(executor: SlurmExecutor, name: str) -> str:
@@ -125,8 +130,19 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
125130
else:
126131
_srun_flags.append("--gres=gpu:8")
127132

128-
_srun_flags.append(f"--container-workdir={self.cluster_dir}")
129-
_srun_flags += ["--container-mounts", ",".join(mounts)]
133+
if self.nemo_run_dir:
134+
new_mounts = copy.deepcopy(mounts)
135+
for i, mount in enumerate(new_mounts):
136+
if mount.startswith(RUNDIR_SPECIAL_NAME):
137+
new_mounts[i] = mount.replace(RUNDIR_SPECIAL_NAME, self.nemo_run_dir, 1)
138+
139+
new_mounts.append(f"{self.nemo_run_dir}:/{RUNDIR_NAME}")
140+
else:
141+
new_mounts = mounts
142+
143+
_srun_flags += ["--container-mounts", ",".join(new_mounts)]
144+
container_workdir = self.workdir or self.cluster_dir
145+
_srun_flags.append(f"--container-workdir={container_workdir}")
130146

131147
return " ".join(_srun_flags)
132148

@@ -148,7 +164,12 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
148164
if self.pre_ray_start_commands:
149165
vars_to_fill["pre_ray_start_commands"] = "\n".join(self.pre_ray_start_commands)
150166

151-
sbatch_script = fill_template(self.template_path, vars_to_fill)
167+
sbatch_script = fill_template(
168+
self.template_name,
169+
vars_to_fill,
170+
template_dir=self.template_dir
171+
or os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates"),
172+
)
152173
return sbatch_script
153174

154175
def __repr__(self) -> str:
@@ -254,13 +275,12 @@ def create_ray_cluster(
254275
ray_sbatch = SlurmRayRequest(
255276
name=name,
256277
cluster_dir=cluster_dir,
257-
template_path=os.path.join(
258-
os.path.dirname(os.path.abspath(__file__)), "templates", "ray.sub.j2"
259-
),
278+
template_name="ray.sub.j2",
260279
executor=executor,
261280
pre_ray_start_commands=pre_ray_start_commands,
262281
command=command,
263282
workdir=workdir,
283+
launch_cmd=["sbatch", "--requeue", "--parsable"],
264284
).materialize()
265285

266286
if dryrun:
@@ -321,6 +341,7 @@ def schedule_ray_job(
321341
command: str,
322342
workdir: Optional[str] = None,
323343
pre_ray_start_commands: Optional[list[str]] = None,
344+
runtime_env_yaml: Optional[str] = None,
324345
dryrun: bool = False,
325346
):
326347
remote_workdir = None

nemo_run/run/ray/templates/ray.sub.j2

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,10 +289,11 @@ echo "[INFO] Ray cluster information saved to $CLUSTER_DIR/ray_cluster_info.json
289289
# This driver process is responsible for launching a job on the Ray cluster
290290
CONTAINER_CWD=$(scontrol show job $SLURM_JOB_ID --json | jq -r '.jobs[].current_working_directory')
291291
# Define command to be empty by default
292-
COMMAND=""
292+
COMMAND="${COMMAND:-{{ command }}}"
293+
COMMAND_WORKDIR={{ command_workdir | default('$CONTAINER_CWD') }}
293294

294295
if [[ -n "$COMMAND" ]]; then
295-
srun --no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/ray-driver.log bash -c "$COMMAND"
296+
srun --no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/ray-job.log bash -c "$COMMAND"
296297
else
297298
echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:"
298299
cat <<EOF >$CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh

0 commit comments

Comments
 (0)