Skip to content

Commit 468af84

Browse files
authored
Merge pull request #3105 from MarinManuel/slurm_updates
update to run_sorter_jobs() and slurm
2 parents fcf2284 + 49dda0d commit 468af84

3 files changed

Lines changed: 132 additions & 31 deletions

File tree

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ test_core = [
128128
# for release we need pypi, so this need to be commented
129129
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
130130
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git",
131+
132+
# for slurm jobs,
133+
"pytest-mock"
131134
]
132135

133136
test_extractors = [
@@ -176,6 +179,9 @@ test = [
176179
# for release we need pypi, so this need to be commented
177180
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
178181
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git",
182+
183+
# for slurm jobs
184+
"pytest-mock",
179185
]
180186

181187
docs = [

src/spikeinterface/sorters/launcher.py

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,29 @@
44

55
from __future__ import annotations
66

7-
8-
from pathlib import Path
9-
import shutil
10-
import numpy as np
11-
import tempfile
127
import os
138
import stat
149
import subprocess
1510
import sys
11+
import tempfile
1612
import warnings
17-
13+
import numpy as np
14+
from pathlib import Path
1815
from spikeinterface.core import aggregate_units
19-
20-
from .sorterlist import sorter_dict
2116
from .runsorter import run_sorter
22-
from .basesorter import is_log_ok
2317

2418
_default_engine_kwargs = dict(
2519
loop=dict(),
2620
joblib=dict(n_jobs=-1, backend="loky"),
2721
processpoolexecutor=dict(max_workers=2, mp_context=None),
2822
dask=dict(client=None),
29-
slurm=dict(tmp_script_folder=None, cpus_per_task=1, mem="1G"),
23+
slurm={"tmp_script_folder": None, "sbatch_args": {"cpus-per-task": 1, "mem": "1G"}},
3024
)
3125

32-
3326
_implemented_engine = list(_default_engine_kwargs.keys())
3427

3528

36-
def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=False):
29+
def run_sorter_jobs(job_list, engine="loop", engine_kwargs=None, return_output=False):
3730
"""
3831
Run several :py:func:`run_sorter()` sequentially or in parallel given a list of jobs.
3932
@@ -55,23 +48,43 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal
5548
5649
Where *blocking* means that this function is blocking until the results are returned.
5750
This is in opposition to *asynchronous*, where the function returns `None` almost immediately (aka non-blocking),
58-
but the results must be retrieved by hand when jobs are finished. No mechanisim is provided here to be know
59-
when jobs are finish.
51+
but the results must be retrieved by hand when jobs are finished. No mechanism is provided here to know
52+
when jobs are finished.
6053
In this *asynchronous* case, the :py:func:`~spikeinterface.sorters.read_sorter_folder()` helps to retrieve individual results.
6154
62-
6355
Parameters
6456
----------
6557
job_list : list of dict
6658
A list of dicts that are propagated to run_sorter(...)
6759
engine : str "loop", "joblib", "dask", "slurm"
6860
The engine to run the list.
69-
* "loop" : a simple loop. This engine is
7061
engine_kwargs : dict
71-
72-
return_output : bool, dfault False
62+
Parameters to be passed to the underlying engine.
63+
* loop : None
64+
* joblib :
65+
- n_jobs : int
66+
The maximum number of concurrently running jobs (default=-1, tries to use all CPUs)
67+
- backend : str
68+
Specify the parallelization backend implementation (default="loky")
69+
* multiprocessing :
70+
- max_workers : int
71+
maximum number of processes (default=2)
72+
- mp_context : str
73+
multiprocessing context (default=None)
74+
* dask :
75+
- client : dask.distributed.Client
76+
Dask client to connect to (required)
77+
* slurm :
78+
- tmp_script_folder : str,Path
79+
the folder in which the job scripts are created (default=None, create a random temporary directory)
80+
- sbatch_args: dict
81+
dictionary of arguments to be passed to the sbatch command. They will be automatically prefixed with --.
82+
Arguments must be in the format SLURM specifies; see the [documentation for `sbatch`](https://slurm.schedmd.com/sbatch.html)
83+
for a list of possible arguments (default={"cpus-per-task": 1, "mem": "1G"})
84+
85+
return_output : bool, default: False
7386
Return a sortings or None.
74-
This also overwrite kwargs in in run_sorter(with_sorting=True/False)
87+
This also overwrites kwargs in run_sorter(with_sorting=True/False)
7588
7689
Returns
7790
-------
@@ -81,6 +94,8 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal
8194

8295
assert engine in _implemented_engine, f"engine must be in {_implemented_engine}"
8396

97+
if engine_kwargs is None:
98+
engine_kwargs = dict()
8499
engine_kwargs_ = dict()
85100
engine_kwargs_.update(_default_engine_kwargs[engine])
86101
engine_kwargs_.update(engine_kwargs)
@@ -145,14 +160,16 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal
145160
task.result()
146161

147162
elif engine == "slurm":
163+
if "cpus_per_task" in engine_kwargs:
164+
raise ValueError(
165+
"keyword argument cpus_per_task is no longer supported for slurm engine, "
166+
"please use cpus-per-task instead."
167+
)
148168
# generate python script for slurm
149169
tmp_script_folder = engine_kwargs["tmp_script_folder"]
150170
if tmp_script_folder is None:
151171
tmp_script_folder = tempfile.mkdtemp(prefix="spikeinterface_slurm_")
152172
tmp_script_folder = Path(tmp_script_folder)
153-
cpus_per_task = engine_kwargs["cpus_per_task"]
154-
mem = engine_kwargs["mem"]
155-
156173
tmp_script_folder.mkdir(exist_ok=True, parents=True)
157174

158175
for i, kwargs in enumerate(job_list):
@@ -181,7 +198,16 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal
181198
f.write(slurm_script)
182199
os.fchmod(f.fileno(), mode=stat.S_IRWXU)
183200

184-
subprocess.Popen(["sbatch", str(script_name.absolute()), f"-cpus-per-task={cpus_per_task}", f"-mem={mem}"])
201+
progr = ["sbatch"]
202+
for k, v in engine_kwargs["sbatch_args"].items():
203+
progr.append(f"--{k}")
204+
progr.append(f"{v}")
205+
progr.append(str(script_name.absolute()))
206+
print(f"subprocess called with command {' '.join(progr)}")
207+
p = subprocess.run(progr, capture_output=True, text=True)
208+
print(p.stdout)
209+
if len(p.stderr) > 0:
210+
warnings.warn(p.stderr)
185211

186212
return out
187213

@@ -209,7 +235,7 @@ def run_sorter_by_property(
209235
folder,
210236
mode_if_folder_exists=None,
211237
engine="loop",
212-
engine_kwargs={},
238+
engine_kwargs=None,
213239
verbose=False,
214240
docker_image=None,
215241
singularity_image=None,
@@ -239,13 +265,11 @@ def run_sorter_by_property(
239265
Must be None. This is deprecated.
240266
If not None then a warning is raised.
241267
Will be removed in next release.
242-
engine : "loop" | "joblib" | "dask", default: "loop"
268+
engine : "loop" | "joblib" | "dask" | "slurm", default: "loop"
243269
Which engine to use to run sorter.
244270
engine_kwargs : dict
245-
This contains kwargs specific to the launcher engine:
246-
* "loop" : no kwargs
247-
* "joblib" : {"n_jobs" : } number of processes
248-
* "dask" : {"client":} the dask client for submitting task
271+
This contains kwargs specific to the launcher engine.
272+
See the documentation for :py:func:`~spikeinterface.sorters.launcher.run_sorter_jobs()` for more details.
249273
verbose : bool, default: False
250274
Controls sorter verboseness
251275
docker_image : None or str, default: None

src/spikeinterface/sorters/tests/test_launcher.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import sys
22
import shutil
3+
import tempfile
34
import time
4-
55
import pytest
66
from pathlib import Path
7-
7+
from platform import system
88
from spikeinterface import generate_ground_truth_recording
99
from spikeinterface.sorters import run_sorter_jobs, run_sorter_by_property
1010

@@ -126,6 +126,77 @@ def test_run_sorter_jobs_slurm(job_list, create_cache_folder):
126126
)
127127

128128

129+
@pytest.mark.skipif(system() != "Linux", reason="Assumes we are on Linux to run SLURM")
130+
def test_run_sorter_jobs_slurm_kwargs(mocker, tmp_path, job_list):
131+
"""
132+
Mock `subprocess.run()` to check that engine_kwargs are
133+
propagated to the call as expected.
134+
"""
135+
# First, mock `subprocess.run()`, set up a call to `run_sorter_jobs`
136+
# then check the mocked `subprocess.run()` was called with the
137+
# expected signature. Two jobs are passed in `jobs_list`, first
138+
# check the most recent call.
139+
mock_subprocess_run = mocker.patch("spikeinterface.sorters.launcher.subprocess.run")
140+
141+
tmp_script_folder = tmp_path / "slurm_scripts"
142+
143+
engine_kwargs = dict(
144+
tmp_script_folder=tmp_script_folder,
145+
sbatch_args={
146+
"cpus-per-task": 32,
147+
"mem": "32G",
148+
"gres": "gpu:1",
149+
"any_random_kwarg": 12322,
150+
},
151+
)
152+
153+
run_sorter_jobs(job_list, engine="slurm", engine_kwargs=engine_kwargs)
154+
155+
script_0_path = f"{tmp_script_folder}/si_script_0.py"
156+
script_1_path = f"{tmp_script_folder}/si_script_1.py"
157+
158+
expected_command = [
159+
"sbatch",
160+
"--cpus-per-task",
161+
"32",
162+
"--mem",
163+
"32G",
164+
"--gres",
165+
"gpu:1",
166+
"--any_random_kwarg",
167+
"12322",
168+
script_1_path,
169+
]
170+
mock_subprocess_run.assert_called_with(expected_command, capture_output=True, text=True)
171+
172+
# Next, check the first call (which sets up `si_script_0.py`)
173+
# also has the expected arguments.
174+
expected_command[9] = script_0_path
175+
assert mock_subprocess_run.call_args_list[0].args[0] == expected_command
176+
177+
# Next, check that defaults are used properly when no kwargs are
178+
# passed. This will default to `_default_engine_kwargs` as
179+
# set in `launcher.py`
180+
run_sorter_jobs(
181+
job_list,
182+
engine="slurm",
183+
engine_kwargs={"tmp_script_folder": tmp_script_folder},
184+
)
185+
expected_command = ["sbatch", "--cpus-per-task", "1", "--mem", "1G", script_1_path]
186+
mock_subprocess_run.assert_called_with(expected_command, capture_output=True, text=True)
187+
188+
# Finally, check that the `tmp_script_folder` is generated on the
189+
# fly as expected. A random foldername is generated, just check that
190+
# the folder to which the scripts are saved is in the `tempfile` format.
191+
run_sorter_jobs(
192+
job_list,
193+
engine="slurm",
194+
engine_kwargs=None,
195+
)
196+
tmp_script_folder = "_".join(tempfile.mkdtemp(prefix="spikeinterface_slurm_").split("_")[:-1])
197+
assert tmp_script_folder in mock_subprocess_run.call_args_list[-1].args[0][5]
198+
199+
129200
def test_run_sorter_by_property(create_cache_folder):
130201
cache_folder = create_cache_folder
131202
working_folder1 = cache_folder / "test_run_sorter_by_property_1"

0 commit comments

Comments
 (0)