Skip to content

Commit c252b9c

Browse files
committed
Expose gather mode to tridesclous2 and spykingcircus2
1 parent 2c6e800 commit c252b9c

6 files changed

Lines changed: 93 additions & 19 deletions

File tree

src/spikeinterface/core/node_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -560,14 +560,14 @@ def run_node_pipeline(
560560
The classical job_kwargs
561561
job_name : str
562562
The name of the pipeline used for the progress_bar
563-
gather_mode : "memory" | "npz"
564-
563+
gather_mode : "memory" | "npy"
564+
How to gather the output of the nodes.
565565
gather_kwargs : dict
566566
OPtions to control the "gather engine". See GatherToMemory or GatherToNpy.
567567
squeeze_output : bool, default True
568568
If only one output node then squeeze the tuple
569569
folder : str | Path | None
570-
Used for gather_mode="npz"
570+
Used for gather_mode="npy"
571571
names : list of str
572572
Names of outputs.
573573
verbose : bool, default False

src/spikeinterface/sorters/internal/spyking_circus2.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
5151
},
5252
},
5353
"clustering": {"legacy": True},
54-
"matching": {"method": "circus-omp-svd"},
54+
"matching": {"method": "circus-omp-svd", "gather_mode": "memory"},
5555
"apply_preprocessing": True,
5656
"matched_filtering": True,
5757
"cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
@@ -321,14 +321,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
321321

322322
## We launch a OMP matching pursuit by full convolution of the templates and the raw traces
323323
matching_method = params["matching"].pop("method")
324-
matching_params = params["matching"].copy()
324+
gather_mode = params["matching"].pop("gather_mode", "memory")
325+
gather_kwargs = params["matching"].pop("gather_kwargs", {})
326+
matching_params = params["matching"].get("method_kwargs", {}).copy()
325327
matching_params["templates"] = templates
326328

327329
if matching_method is not None:
330+
if gather_mode == "npy":
331+
gather_kwargs["folder"] = gather_kwargs.get("folder", sorter_output_folder / "matching")
328332
spikes = find_spikes_from_templates(
329-
recording_w, matching_method, method_kwargs=matching_params, **job_kwargs
333+
recording_w,
334+
matching_method,
335+
method_kwargs=matching_params,
336+
gather_mode=gather_mode,
337+
gather_kwargs=gather_kwargs,
338+
**job_kwargs,
330339
)
331-
332340
if params["debug"]:
333341
fitting_folder = sorter_output_folder / "fitting"
334342
fitting_folder.mkdir(parents=True, exist_ok=True)

src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,35 @@
11
import unittest
22

33
from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite
4-
5-
from spikeinterface.sorters import Spykingcircus2Sorter
4+
from spikeinterface.sorters import Spykingcircus2Sorter, run_sorter
65

76
from pathlib import Path
87

98

109
class SpykingCircus2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase):
1110
SorterClass = Spykingcircus2Sorter
1211

12+
def test_with_numpy_gather(self):
13+
recording = self.recording
14+
sorter_name = self.SorterClass.sorter_name
15+
output_folder = self.cache_folder / sorter_name
16+
sorter_params = self.SorterClass.default_params()
17+
18+
sorter_params["matching"]["gather_mode"] = "npy"
19+
20+
sorting = run_sorter(
21+
sorter_name,
22+
recording,
23+
folder=output_folder,
24+
remove_existing_folder=True,
25+
delete_output_folder=False,
26+
verbose=False,
27+
raise_error=True,
28+
**sorter_params,
29+
)
30+
assert (output_folder / "sorter_output" / "matching").is_dir()
31+
assert (output_folder / "sorter_output" / "matching" / "spikes.npy").is_file()
32+
1333

1434
if __name__ == "__main__":
1535
from spikeinterface import set_global_job_kwargs

src/spikeinterface/sorters/internal/tests/test_tridesclous2.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,35 @@
22

33
from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite
44

5-
from spikeinterface.sorters import Tridesclous2Sorter
5+
from spikeinterface.sorters import Tridesclous2Sorter, run_sorter
66

77
from pathlib import Path
88

99

1010
class Tridesclous2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase):
1111
SorterClass = Tridesclous2Sorter
1212

13+
def test_with_numpy_gather(self):
14+
recording = self.recording
15+
sorter_name = self.SorterClass.sorter_name
16+
output_folder = self.cache_folder / sorter_name
17+
sorter_params = self.SorterClass.default_params()
18+
19+
sorter_params["matching"]["gather_mode"] = "npy"
20+
21+
sorting = run_sorter(
22+
sorter_name,
23+
recording,
24+
folder=output_folder,
25+
remove_existing_folder=True,
26+
delete_output_folder=False,
27+
verbose=False,
28+
raise_error=True,
29+
**sorter_params,
30+
)
31+
assert (output_folder / "sorter_output" / "matching").is_dir()
32+
assert (output_folder / "sorter_output" / "matching" / "spikes.npy").is_file()
33+
1334

1435
if __name__ == "__main__":
1536
test = Tridesclous2SorterCommonTestSuite()

src/spikeinterface/sorters/internal/tridesclous2.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
6363
},
6464
# "matching": {"method": "tridesclous", "method_kwargs": {"peak_shift_ms": 0.2, "radius_um": 100.0}},
6565
# "matching": {"method": "circus-omp-svd", "method_kwargs": {}},
66-
"matching": {"method": "wobble", "method_kwargs": {}},
66+
"matching": {"method": "wobble", "method_kwargs": {}, "gather_mode": "memory"},
6767
"job_kwargs": {"n_jobs": -1},
6868
"save_array": True,
6969
}
@@ -232,13 +232,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
232232
templates = remove_empty_templates(templates)
233233

234234
## peeler
235-
matching_method = params["matching"]["method"]
236-
matching_params = params["matching"]["method_kwargs"].copy()
235+
matching_method = params["matching"].pop("method")
236+
gather_mode = params["matching"].pop("gather_mode", "memory")
237+
gather_kwargs = params["matching"].pop("gather_kwargs", {})
238+
matching_params = params["matching"].get("matching_kwargs", {}).copy()
237239
matching_params["templates"] = templates
238-
if params["matching"]["method"] in ("tdc-peeler",):
240+
if matching_method in ("tdc-peeler",):
239241
matching_params["noise_levels"] = noise_levels
242+
if gather_mode == "npy":
243+
gather_kwargs = {"folder": gather_kwargs.get("folder", sorter_output_folder / "matching")}
240244
spikes = find_spikes_from_templates(
241-
recording_for_peeler, method=matching_method, method_kwargs=matching_params, **job_kwargs
245+
recording_for_peeler,
246+
method=matching_method,
247+
method_kwargs=matching_params,
248+
gather_mode=gather_mode,
249+
gather_kwargs=gather_kwargs,
250+
**job_kwargs,
242251
)
243252

244253
if params["save_array"]:

src/spikeinterface/sortingcomponents/matching/main.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,14 @@
1111

1212

1313
def find_spikes_from_templates(
14-
recording, method="naive", method_kwargs={}, extra_outputs=False, verbose=False, **job_kwargs
14+
recording,
15+
method="naive",
16+
method_kwargs={},
17+
extra_outputs=False,
18+
gather_mode="memory",
19+
gather_kwargs=None,
20+
verbose=False,
21+
**job_kwargs,
1522
) -> np.ndarray | tuple[np.ndarray, dict]:
1623
"""Find spike from a recording from given templates.
1724
@@ -25,10 +32,14 @@ def find_spikes_from_templates(
2532
Keyword arguments for the chosen method
2633
extra_outputs : bool
2734
If True then a dict is also returned is also returned
28-
**job_kwargs : dict
29-
Parameters for ChunkRecordingExecutor
35+
gather_mode : "memory" | "npy", default: "memory"
36+
If "memory" then the output is gathered in memory, if "npy" then the output is gathered on disk
37+
gather_kwargs : dict, optional
38+
The kwargs for the gather method
3039
verbose : Bool, default: False
3140
If True, output is verbose
41+
**job_kwargs : keyword arguments
42+
Parameters for ChunkRecordingExecutor
3243
3344
Returns
3445
-------
@@ -47,13 +58,18 @@ def find_spikes_from_templates(
4758
node0 = method_class(recording, **method_kwargs)
4859
nodes = [node0]
4960

61+
gather_kwargs = gather_kwargs or {}
62+
names = ["spikes"]
63+
5064
spikes = run_node_pipeline(
5165
recording,
5266
nodes,
5367
job_kwargs,
5468
job_name=f"find spikes ({method})",
55-
gather_mode="memory",
69+
gather_mode=gather_mode,
5670
squeeze_output=True,
71+
names=names,
72+
**gather_kwargs,
5773
)
5874
if extra_outputs:
5975
outputs = node0.get_extra_outputs()

0 commit comments

Comments
 (0)