Skip to content

Commit e30ea9e

Browse files
committed
use env in subprocess instead of sh file
1 parent 5077d59 commit e30ea9e

1 file changed

Lines changed: 21 additions & 26 deletions

File tree

src/seml/commands/start.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import logging
55
import math
66
import os
7-
import shlex
87
import shutil
98
import subprocess
109
import sys
@@ -84,7 +83,7 @@ def set_slurm_job_name(
8483
"Can't set sbatch `job-name` parameter explicitly. "
8584
'Use `name` parameter instead and SEML will do that for you.'
8685
)
87-
job_name = f"{name}_{exp['batch_id']}"
86+
job_name = f'{name}_{exp["batch_id"]}'
8887
sbatch_options['job-name'] = job_name
8988
if sbatch_options.get('comment', db_collection_name) != db_collection_name:
9089
raise ConfigError(
@@ -94,9 +93,7 @@ def set_slurm_job_name(
9493
sbatch_options['comment'] = db_collection_name
9594

9695

97-
def create_slurm_options_string(
98-
slurm_options: SBatchOptions, env: dict[str, str] | None = None, srun: bool = False
99-
):
96+
def create_slurm_options_string(slurm_options: SBatchOptions, srun: bool = False):
10097
"""
10198
Convert a dictionary with sbatch_options into a string that can be used in a bash script.
10299
@@ -124,11 +121,9 @@ def create_slurm_options_string(
124121
slurm_options_str += option_structure.format(
125122
prepend=prepend, key=key, value=value
126123
)
127-
if env is not None:
128-
env_kv = shlex.quote(','.join(f'{key}={value}' for key, value in env.items()))
129-
slurm_options_str += option_structure.format(
130-
prepend='--', key='export', value=env_kv
131-
)
124+
slurm_options_str += option_structure.format(
125+
prepend='--', key='export', value='ALL'
126+
)
132127
return slurm_options_str
133128

134129

@@ -208,10 +203,10 @@ def start_sbatch_job(
208203
srun_str = '' if experiments_per_job > 1 else 'srun '
209204
# Construct sbatch options string
210205
env = get_experiment_environment(exp_array[0])
211-
sbatch_options_str = create_slurm_options_string(sbatch_options, env, False)
206+
sbatch_options_str = create_slurm_options_string(sbatch_options, False)
212207

213208
# Construct list with all experiment IDs
214-
expid_strings = f"{' '.join([str(exp['_id']) for exp in exp_array])}"
209+
expid_strings = f'{" ".join([str(exp["_id"]) for exp in exp_array])}'
215210

216211
with_sources = 'source_files' in seml_conf
217212
use_conda_env = seml_conf.get('conda_environment')
@@ -261,12 +256,12 @@ def start_sbatch_job(
261256
# Sbatch the script
262257
try:
263258
output = subprocess.run(
264-
f'sbatch {f.name}', shell=True, check=True, capture_output=True
259+
f'sbatch {f.name}', shell=True, check=True, capture_output=True, env=env
265260
).stdout
266261
except subprocess.CalledProcessError as e:
267262
logging.error(
268263
f"Could not start Slurm job via sbatch. Here's the sbatch error message:\n"
269-
f"{e.stderr.decode('utf-8')}"
264+
f'{e.stderr.decode("utf-8")}'
270265
)
271266
exit(1)
272267

@@ -330,20 +325,20 @@ def start_srun_job(
330325
if 'ntasks' not in srun_options:
331326
srun_options['ntasks'] = 1
332327
env = get_experiment_environment(exp)
333-
srun_options_str = create_slurm_options_string(srun_options, env, True)
328+
srun_options_str = create_slurm_options_string(srun_options, True)
334329

335330
# Set command args for job inside Slurm
336-
cmd_args = f"--local --sacred-id {exp['_id']} "
331+
cmd_args = f'--local --sacred-id {exp["_id"]} '
337332
cmd_args += ' '.join(seml_arguments)
338333

339334
cmd = f'srun{srun_options_str} seml {collection.name} start {cmd_args}'
340335
try:
341-
subprocess.run(cmd, shell=True, check=True)
336+
subprocess.run(cmd, shell=True, check=True, env=env)
342337
except subprocess.CalledProcessError as e:
343338
if e.stderr:
344339
logging.error(
345340
f"Could not start Slurm job via srun. Here's the sbatch error message:\n"
346-
f"{e.stderr.decode('utf-8')}"
341+
f'{e.stderr.decode("utf-8")}'
347342
)
348343
else:
349344
logging.error('Could not start Slurm job via srun.')
@@ -420,7 +415,7 @@ def start_local_job(
420415

421416
if output_dir_path:
422417
exp_name = get_exp_name(exp, collection.name)
423-
output_file = f"{output_dir_path}/{exp_name}_{exp['_id']}.out"
418+
output_file = f'{output_dir_path}/{exp_name}_{exp["_id"]}.out'
424419
if not unobserved:
425420
collection.update_one(
426421
{'_id': exp['_id']}, {'$set': {'seml.output_file': output_file}}
@@ -432,10 +427,10 @@ def start_local_job(
432427

433428
if seml_config.get('conda_environment') is not None:
434429
cmd = (
435-
f". $(conda info --base)/etc/profile.d/conda.sh "
436-
f"&& conda activate {seml_config['conda_environment']} "
437-
f"&& {cmd} "
438-
f"&& conda deactivate"
430+
f'. $(conda info --base)/etc/profile.d/conda.sh '
431+
f'&& conda activate {seml_config["conda_environment"]} '
432+
f'&& {cmd} '
433+
f'&& conda deactivate'
439434
)
440435

441436
if not unobserved:
@@ -793,7 +788,7 @@ def start_local_worker(
793788

794789
# Cancel Slurm job; after cleaning up to prevent race conditions
795790
if prompt(
796-
f"SLURM is currently executing experiment {exp['_id']}, do you want to cancel the SLURM job?",
791+
f'SLURM is currently executing experiment {exp["_id"]}, do you want to cancel the SLURM job?',
797792
type=bool,
798793
):
799794
cancel_experiment_by_id(
@@ -804,7 +799,7 @@ def start_local_worker(
804799
)
805800

806801
progress.console.print(
807-
f"current id : {exp['_id']}, failed={num_exceptions}/{jobs_counter} experiments"
802+
f'current id : {exp["_id"]}, failed={num_exceptions}/{jobs_counter} experiments'
808803
)
809804

810805
# Add newline if we need to avoid tqdm's output
@@ -996,7 +991,7 @@ def start_jupyter_job(
996991
except subprocess.CalledProcessError as e:
997992
logging.error(
998993
f"Could not start Slurm job via sbatch. Here's the sbatch error message:\n"
999-
f"{e.stderr.decode('utf-8')}"
994+
f'{e.stderr.decode("utf-8")}'
1000995
)
1001996
os.remove(path)
1002997
exit(1)

0 commit comments

Comments
 (0)