44
55from __future__ import annotations
66
7-
8- from pathlib import Path
9- import shutil
10- import numpy as np
11- import tempfile
127import os
138import stat
149import subprocess
1510import sys
11+ import tempfile
1612import warnings
17-
13+ import numpy as np
14+ from pathlib import Path
1815from spikeinterface .core import aggregate_units
19-
20- from .sorterlist import sorter_dict
2116from .runsorter import run_sorter
22- from .basesorter import is_log_ok
2317
2418_default_engine_kwargs = dict (
2519 loop = dict (),
2620 joblib = dict (n_jobs = - 1 , backend = "loky" ),
2721 processpoolexecutor = dict (max_workers = 2 , mp_context = None ),
2822 dask = dict (client = None ),
29- slurm = dict ( tmp_script_folder = None , cpus_per_task = 1 , mem = " 1G") ,
23+ slurm = { " tmp_script_folder" : None , "sbatch_args" : { "cpus-per-task" : 1 , " mem" : " 1G"}} ,
3024)
3125
32-
3326_implemented_engine = list (_default_engine_kwargs .keys ())
3427
3528
36- def run_sorter_jobs (job_list , engine = "loop" , engine_kwargs = {} , return_output = False ):
29+ def run_sorter_jobs (job_list , engine = "loop" , engine_kwargs = None , return_output = False ):
3730 """
3831 Run several :py:func:`run_sorter()` sequentially or in parallel given a list of jobs.
3932
@@ -55,23 +48,43 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal
5548
5649 Where *blocking* means that this function is blocking until the results are returned.
5750 This is in opposition to *asynchronous*, where the function returns `None` almost immediately (aka non-blocking),
58- but the results must be retrieved by hand when jobs are finished. No mechanisim is provided here to be know
59- when jobs are finish .
51+ but the results must be retrieved by hand when jobs are finished. No mechanism is provided here to know
52+ when jobs are finished .
6053 In this *asynchronous* case, the :py:func:`~spikeinterface.sorters.read_sorter_folder()` helps to retrieve individual results.
6154
62-
6355 Parameters
6456 ----------
6557 job_list : list of dict
6658 A list of dicts that are propagated to run_sorter(...)
6759 engine : str "loop", "joblib", "dask", "slurm"
6860 The engine to run the list.
69- * "loop" : a simple loop. This engine is
7061 engine_kwargs : dict
71-
72- return_output : bool, dfault False
62+ Parameters to be passed to the underlying engine.
63+ * loop : None
64+ * joblib :
65+ - n_jobs : int
66+ The maximum number of concurrently running jobs (default=-1, tries to use all CPUs)
67+ - backend : str
68+ Specify the parallelization backend implementation (default="loky")
69+ * multiprocessing :
70+ - max_workers : int
71+ maximum number of processes (default=2)
72+ - mp_context : str
73+ multiprocessing context (default=None)
74+ * dask :
75+ - client : dask.distributed.Client
76+ Dask client to connect to (required)
77+ * slurm :
78+ - tmp_script_folder : str or Path
79+ the folder in which the job scripts are created (default=None, create a random temporary directory)
80+ - sbatch_args: dict
81+ dictionary of arguments to be passed to the sbatch command. They will be automatically prefixed with --.
82+ Arguments must be in the format that slurm specifies; see the [documentation for `sbatch`](https://slurm.schedmd.com/sbatch.html)
83+ for a list of possible arguments (default={"cpus-per-task": 1, "mem": "1G"})
84+
85+ return_output : bool, default: False
7386 Return a sortings or None.
74- This also overwrite kwargs in in run_sorter(with_sorting=True/False)
87+ This also overwrites kwargs in run_sorter(with_sorting=True/False)
7588
7689 Returns
7790 -------
@@ -81,6 +94,8 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal
8194
8295 assert engine in _implemented_engine , f"engine must be in { _implemented_engine } "
8396
97+ if engine_kwargs is None :
98+ engine_kwargs = dict ()
8499 engine_kwargs_ = dict ()
85100 engine_kwargs_ .update (_default_engine_kwargs [engine ])
86101 engine_kwargs_ .update (engine_kwargs )
@@ -145,14 +160,16 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal
145160 task .result ()
146161
147162 elif engine == "slurm" :
163+ if "cpus_per_task" in engine_kwargs :
164+ raise ValueError (
165+ "keyword argument cpus_per_task is no longer supported for slurm engine, "
166+ "please use cpus-per-task instead."
167+ )
148168 # generate python script for slurm
149169 tmp_script_folder = engine_kwargs ["tmp_script_folder" ]
150170 if tmp_script_folder is None :
151171 tmp_script_folder = tempfile .mkdtemp (prefix = "spikeinterface_slurm_" )
152172 tmp_script_folder = Path (tmp_script_folder )
153- cpus_per_task = engine_kwargs ["cpus_per_task" ]
154- mem = engine_kwargs ["mem" ]
155-
156173 tmp_script_folder .mkdir (exist_ok = True , parents = True )
157174
158175 for i , kwargs in enumerate (job_list ):
@@ -181,7 +198,16 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal
181198 f .write (slurm_script )
182199 os .fchmod (f .fileno (), mode = stat .S_IRWXU )
183200
184- subprocess .Popen (["sbatch" , str (script_name .absolute ()), f"-cpus-per-task={ cpus_per_task } " , f"-mem={ mem } " ])
201+ progr = ["sbatch" ]
202+ for k , v in engine_kwargs ["sbatch_args" ].items ():
203+ progr .append (f"--{ k } " )
204+ progr .append (f"{ v } " )
205+ progr .append (str (script_name .absolute ()))
206+ print (f"subprocess called with command { ' ' .join (progr )} " )
207+ p = subprocess .run (progr , capture_output = True , text = True )
208+ print (p .stdout )
209+ if len (p .stderr ) > 0 :
210+ warnings .warn (p .stderr )
185211
186212 return out
187213
@@ -209,7 +235,7 @@ def run_sorter_by_property(
209235 folder ,
210236 mode_if_folder_exists = None ,
211237 engine = "loop" ,
212- engine_kwargs = {} ,
238+ engine_kwargs = None ,
213239 verbose = False ,
214240 docker_image = None ,
215241 singularity_image = None ,
@@ -239,13 +265,11 @@ def run_sorter_by_property(
239265 Must be None. This is deprecated.
240266 If not None then a warning is raised.
241267 Will be removed in next release.
242- engine : "loop" | "joblib" | "dask", default: "loop"
268+ engine : "loop" | "joblib" | "dask" | "slurm" , default: "loop"
243269 Which engine to use to run sorter.
244270 engine_kwargs : dict
245- This contains kwargs specific to the launcher engine:
246- * "loop" : no kwargs
247- * "joblib" : {"n_jobs" : } number of processes
248- * "dask" : {"client":} the dask client for submitting task
271+ This contains kwargs specific to the launcher engine.
272+ See the documentation for :py:func:`~spikeinterface.sorters.launcher.run_sorter_jobs()` for more details.
249273 verbose : bool, default: False
250274 Controls sorter verboseness
251275 docker_image : None or str, default: None
0 commit comments