@@ -748,10 +748,10 @@ def _as_sbatch_flag(key: str, value: Any) -> str:
748748
749749@dataclass (kw_only = True )
750750class SlurmBatchRequest :
751- cmd : list [str ]
751+ launch_cmd : list [str ]
752752 jobs : list [str ]
753753 command_groups : list [list [str ]]
754- slurm_config : SlurmExecutor
754+ executor : SlurmExecutor
755755 max_retries : int
756756 setup : Optional [list [str ]] = None
757757 extra_env : dict [str , str ]
@@ -786,7 +786,7 @@ def materialize(self) -> str:
786786 In case an erroneous keyword argument is added, a list of all eligible parameters
787787 is printed, with their default values
788788 """
789- args = asdict (self .slurm_config ) # noqa: F821
789+ args = asdict (self .executor ) # noqa: F821
790790 parameters = {
791791 k : v for k , v in args .items () if v is not None and k in SlurmExecutor .SBATCH_FLAGS
792792 }
@@ -800,18 +800,16 @@ def materialize(self) -> str:
800800 # add necessary parameters
801801 original_job_name : str = self .jobs [0 ] # type: ignore
802802 job_name_prefix = (
803- self .slurm_config .job_name_prefix
804- if self .slurm_config .job_name_prefix
805- else f"{ self .slurm_config .account } -{ self .slurm_config .account .split ('_' )[- 1 ]} ."
803+ self .executor .job_name_prefix
804+ if self .executor .job_name_prefix
805+ else f"{ self .executor .account } -{ self .executor .account .split ('_' )[- 1 ]} ."
806806 )
807807 job_name = f"{ job_name_prefix } { original_job_name } "
808808 slurm_job_dir = (
809- self .slurm_config .tunnel .job_dir
810- if self .slurm_config .tunnel
811- else self .slurm_config .job_dir
809+ self .executor .tunnel .job_dir if self .executor .tunnel else self .executor .job_dir
812810 )
813- job_directory_name = Path (self .slurm_config .job_dir ).name
814- job_details = self .slurm_config .job_details
811+ job_directory_name = Path (self .executor .job_dir ).name
812+ job_details = self .executor .job_details
815813
816814 if not job_details .job_name :
817815 job_details .job_name = job_name
@@ -824,41 +822,41 @@ def materialize(self) -> str:
824822 stdout = str (job_details .stdout )
825823 stderr = str (job_details .stderr )
826824
827- if self .slurm_config .array is not None :
825+ if self .executor .array is not None :
828826 stdout = stdout .replace ("%j" , "%A_%a" )
829827 stderr = stderr .replace ("%j" , "%A_%a" )
830828 parameters ["output" ] = stdout .replace ("%t" , "0" )
831829
832- if not self .slurm_config .stderr_to_stdout :
830+ if not self .executor .stderr_to_stdout :
833831 parameters ["error" ] = stderr .replace ("%t" , "0" )
834832
835- if self .slurm_config .additional_parameters is not None :
836- parameters .update (self .slurm_config .additional_parameters )
833+ if self .executor .additional_parameters is not None :
834+ parameters .update (self .executor .additional_parameters )
837835
838836 # now create
839- sbatch_cmd = " " .join ([shlex .quote (arg ) for arg in self .cmd ])
837+ sbatch_cmd = " " .join ([shlex .quote (arg ) for arg in self .launch_cmd ])
840838
841839 sbatch_flags = []
842- if self .slurm_config .heterogeneous :
843- assert len (self .jobs ) == len (self .slurm_config .resource_group ), (
844- f"Number of jobs { len (self .jobs )} must match number of resource group requests { len (self .slurm_config .resource_group )} .\n If you are just submitting a single job, make sure that heterogeneous=False in the executor."
840+ if self .executor .heterogeneous :
841+ assert len (self .jobs ) == len (self .executor .resource_group ), (
842+ f"Number of jobs { len (self .jobs )} must match number of resource group requests { len (self .executor .resource_group )} .\n If you are just submitting a single job, make sure that heterogeneous=False in the executor."
845843 )
846- final_group_index = len (self .slurm_config .resource_group ) - 1
847- if self .slurm_config .het_group_indices :
848- final_group_index = self .slurm_config .het_group_indices .index (
849- max (self .slurm_config .het_group_indices )
844+ final_group_index = len (self .executor .resource_group ) - 1
845+ if self .executor .het_group_indices :
846+ final_group_index = self .executor .het_group_indices .index (
847+ max (self .executor .het_group_indices )
850848 )
851849
852- for i in range (len (self .slurm_config .resource_group )):
853- resource_req = self .slurm_config .resource_group [i ]
850+ for i in range (len (self .executor .resource_group )):
851+ resource_req = self .executor .resource_group [i ]
854852 if resource_req .het_group_index :
855- assert self .slurm_config .resource_group [i - 1 ].het_group_index is not None , (
853+ assert self .executor .resource_group [i - 1 ].het_group_index is not None , (
856854 "het_group_index must be set for all requests in resource_group"
857855 )
858856 if (
859857 i > 0
860858 and resource_req .het_group_index
861- == self .slurm_config .resource_group [i - 1 ].het_group_index
859+ == self .executor .resource_group [i - 1 ].het_group_index
862860 ):
863861 continue
864862
@@ -887,33 +885,31 @@ def materialize(self) -> str:
887885 for k in sorted (parameters ):
888886 sbatch_flags .append (_as_sbatch_flag (k , parameters [k ]))
889887
890- if self .slurm_config .dependencies :
891- slurm_deps = self .slurm_config .parse_deps ()
888+ if self .executor .dependencies :
889+ slurm_deps = self .executor .parse_deps ()
892890 sbatch_flags .append (
893891 _as_sbatch_flag (
894- "dependency" , f"{ self .slurm_config .dependency_type } :{ ':' .join (slurm_deps )} "
892+ "dependency" , f"{ self .executor .dependency_type } :{ ':' .join (slurm_deps )} "
895893 )
896894 )
897895
898896 env_vars = []
899- full_env_vars = self .slurm_config .env_vars | self .extra_env
897+ full_env_vars = self .executor .env_vars | self .extra_env
900898 for key , value in full_env_vars .items ():
901899 env_vars .append (f"export { key .upper ()} ={ value } " )
902900
903901 # commandline (this will run the function and args specified in the file provided as argument)
904902 # We pass --output and --error here, because the SBATCH command doesn't work as expected with a filename pattern
905- stderr_flags = [] if self .slurm_config .stderr_to_stdout else ["--error" , stderr ]
903+ stderr_flags = [] if self .executor .stderr_to_stdout else ["--error" , stderr ]
906904
907905 srun_commands = []
908906 group_env_vars = []
909907 srun_stdout = noquote (job_details .srun_stdout )
910908 stderr_flags = (
911- []
912- if self .slurm_config .stderr_to_stdout
913- else ["--error" , noquote (job_details .srun_stderr )]
909+ [] if self .executor .stderr_to_stdout else ["--error" , noquote (job_details .srun_stderr )]
914910 )
915911 memory_measure_out = None
916- if self .slurm_config .memory_measure :
912+ if self .executor .memory_measure :
917913 memory_measure_out = srun_stdout
918914
919915 def get_container_flags (
@@ -937,10 +933,10 @@ def get_container_flags(
937933 return _container_flags
938934
939935 for group_ind , command_group in enumerate (self .command_groups ):
940- if self .slurm_config .run_as_group and len (self .slurm_config .resource_group ) == len (
936+ if self .executor .run_as_group and len (self .executor .resource_group ) == len (
941937 self .command_groups
942938 ):
943- resource_req = self .slurm_config .resource_group [group_ind ]
939+ resource_req = self .executor .resource_group [group_ind ]
944940 if not resource_req .job_details .job_name :
945941 resource_req .job_details .job_name = f"{ job_name_prefix } { self .jobs [group_ind ]} "
946942
@@ -952,7 +948,7 @@ def get_container_flags(
952948 cmd_stdout = noquote (resource_req .job_details .srun_stdout )
953949 cmd_stderr = (
954950 []
955- if self .slurm_config .stderr_to_stdout
951+ if self .executor .stderr_to_stdout
956952 else [
957953 "--error" ,
958954 noquote (resource_req .job_details .srun_stderr ),
@@ -980,20 +976,20 @@ def get_container_flags(
980976 if cmd_stderr :
981977 cmd_stderr [- 1 ] = cmd_stderr [- 1 ].replace (original_job_name , self .jobs [group_ind ])
982978 _container_flags = get_container_flags (
983- base_mounts = self .slurm_config .container_mounts ,
979+ base_mounts = self .executor .container_mounts ,
984980 src_job_dir = os .path .join (
985981 slurm_job_dir ,
986982 job_directory_name ,
987983 ),
988- container_image = self .slurm_config .container_image ,
984+ container_image = self .executor .container_image ,
989985 )
990986 _srun_args = ["--wait=60" , "--kill-on-bad-exit=1" ]
991- _srun_args .extend (self .slurm_config .srun_args or [])
987+ _srun_args .extend (self .executor .srun_args or [])
992988
993- if self .slurm_config .run_as_group and self .slurm_config .heterogeneous :
989+ if self .executor .run_as_group and self .executor .heterogeneous :
994990 het_group_index = (
995- self .slurm_config .resource_group [group_ind ].het_group_index
996- if self .slurm_config .resource_group [group_ind ].het_group_index is not None
991+ self .executor .resource_group [group_ind ].het_group_index
992+ if self .executor .resource_group [group_ind ].het_group_index is not None
997993 else group_ind
998994 )
999995 het_group_flag = [f"--het-group={ het_group_index } " ]
@@ -1018,10 +1014,10 @@ def get_container_flags(
10181014 )
10191015 command = " " .join (command_group )
10201016
1021- if self .slurm_config .run_as_group :
1017+ if self .executor .run_as_group :
10221018 srun_command = f"{ srun_cmd } { command } & pids[{ group_ind } ]=$!"
10231019 if group_ind != len (self .command_groups ) - 1 :
1024- srun_command += f"\n \n sleep { self .slurm_config .wait_time_for_group_job } \n "
1020+ srun_command += f"\n \n sleep { self .executor .wait_time_for_group_job } \n "
10251021 else :
10261022 srun_command = f"{ srun_cmd } { command } "
10271023
@@ -1033,15 +1029,14 @@ def get_container_flags(
10331029 "max_retries" : self .max_retries ,
10341030 "env_vars" : env_vars ,
10351031 "head_node_ip_var" : SlurmExecutor .HEAD_NODE_IP_VAR ,
1036- "setup_lines" : self .slurm_config .setup_lines ,
1032+ "setup_lines" : self .executor .setup_lines ,
10371033 "memory_measure" : memory_measure_out ,
10381034 "srun_commands" : srun_commands ,
10391035 "group_env_vars" : group_env_vars ,
1040- "heterogeneous" : self .slurm_config .heterogeneous ,
1041- "run_as_group" : self .slurm_config .run_as_group ,
1042- "monitor_group_job" : self .slurm_config .run_as_group
1043- and self .slurm_config .monitor_group_job ,
1044- "monitor_group_job_wait_time" : self .slurm_config .monitor_group_job_wait_time ,
1036+ "heterogeneous" : self .executor .heterogeneous ,
1037+ "run_as_group" : self .executor .run_as_group ,
1038+ "monitor_group_job" : self .executor .run_as_group and self .executor .monitor_group_job ,
1039+ "monitor_group_job_wait_time" : self .executor .monitor_group_job_wait_time ,
10451040 "het_group_host_var" : SlurmExecutor .HET_GROUP_HOST_VAR ,
10461041 "ft_enabled" : self .launcher and isinstance (self .launcher , FaultTolerance ),
10471042 }
@@ -1060,7 +1055,7 @@ def get_container_flags(
10601055 return sbatch_script
10611056
10621057 def __repr__ (self ) -> str :
1063- return f"""{ " " .join (self .cmd + ["$SBATCH_SCRIPT" ])}
1058+ return f"""{ " " .join (self .launch_cmd + ["$SBATCH_SCRIPT" ])}
10641059
10651060#----------------
10661061# SBATCH_SCRIPT
0 commit comments