Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,14 +571,14 @@ def run_node_pipeline(
The classical job_kwargs
job_name : str
The name of the pipeline used for the progress_bar
gather_mode : "memory" | "npz"

gather_mode : "memory" | "npy"
How to gather the output of the nodes.
gather_kwargs : dict
OPtions to control the "gather engine". See GatherToMemory or GatherToNpy.
squeeze_output : bool, default True
If only one output node then squeeze the tuple
folder : str | Path | None
Used for gather_mode="npz"
Used for gather_mode="npy"
names : list of str
Names of outputs.
verbose : bool, default False
Expand Down
12 changes: 11 additions & 1 deletion src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,13 +365,23 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
templates.to_zarr(folder_path=clustering_folder / "templates")

## We launch a OMP matching pursuit by full convolution of the templates and the raw traces

matching_method = params["matching"].get("method", "circus-omp_svd")
gather_mode = params["matching"].pop("gather_mode", "memory")
matching_params = params["matching"].get("method_kwargs", dict())
matching_params["templates"] = templates

if matching_method is not None:
gather_kwargs = {}
if gather_mode == "npy":
gather_kwargs["folder"] = sorter_output_folder / "matching"
spikes = find_spikes_from_templates(
recording_w, matching_method, method_kwargs=matching_params, **job_kwargs
recording_w,
matching_method,
method_kwargs=matching_params,
gather_mode=gather_mode,
gather_kwargs=gather_kwargs,
**job_kwargs,
)

if debug:
Expand Down
25 changes: 23 additions & 2 deletions src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,36 @@
import unittest

from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite

from spikeinterface.sorters import Spykingcircus2Sorter
from spikeinterface.sorters import Spykingcircus2Sorter, run_sorter

from pathlib import Path


class SpykingCircus2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase):
SorterClass = Spykingcircus2Sorter

@unittest.skip("performance reason")
def test_with_numpy_gather(self):
Comment thread
alejoe91 marked this conversation as resolved.
recording = self.recording
sorter_name = self.SorterClass.sorter_name
output_folder = self.cache_folder / sorter_name
sorter_params = self.SorterClass.default_params()

sorter_params["matching"]["gather_mode"] = "npy"

sorting = run_sorter(
sorter_name,
recording,
folder=output_folder,
remove_existing_folder=True,
delete_output_folder=False,
verbose=False,
raise_error=True,
**sorter_params,
)
assert (output_folder / "sorter_output" / "matching").is_dir()
assert (output_folder / "sorter_output" / "matching" / "spikes.npy").is_file()


if __name__ == "__main__":
from spikeinterface import set_global_job_kwargs
Expand Down
24 changes: 23 additions & 1 deletion src/spikeinterface/sorters/internal/tests/test_tridesclous2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,36 @@

from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite

from spikeinterface.sorters import Tridesclous2Sorter
from spikeinterface.sorters import Tridesclous2Sorter, run_sorter

from pathlib import Path


class Tridesclous2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase):
SorterClass = Tridesclous2Sorter

@unittest.skip("performance reason")
def test_with_numpy_gather(self):
Comment thread
alejoe91 marked this conversation as resolved.
recording = self.recording
sorter_name = self.SorterClass.sorter_name
output_folder = self.cache_folder / sorter_name
sorter_params = self.SorterClass.default_params()

sorter_params["matching"]["gather_mode"] = "npy"

sorting = run_sorter(
sorter_name,
recording,
folder=output_folder,
remove_existing_folder=True,
delete_output_folder=False,
verbose=False,
raise_error=True,
**sorter_params,
)
assert (output_folder / "sorter_output" / "matching").is_dir()
assert (output_folder / "sorter_output" / "matching" / "spikes.npy").is_file()


if __name__ == "__main__":
test = Tridesclous2SorterCommonTestSuite()
Expand Down
19 changes: 14 additions & 5 deletions src/spikeinterface/sorters/internal/tridesclous2.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
},
# "matching": {"method": "tridesclous", "method_kwargs": {"peak_shift_ms": 0.2, "radius_um": 100.0}},
# "matching": {"method": "circus-omp-svd", "method_kwargs": {}},
"matching": {"method": "wobble", "method_kwargs": {}},
"matching": {"method": "wobble", "method_kwargs": {}, "gather_mode": "memory"},
"job_kwargs": {"n_jobs": -1},
"save_array": True,
}
Expand Down Expand Up @@ -232,13 +232,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
templates = remove_empty_templates(templates)

## peeler
matching_method = params["matching"]["method"]
matching_params = params["matching"]["method_kwargs"].copy()
matching_method = params["matching"].pop("method")
gather_mode = params["matching"].pop("gather_mode", "memory")
matching_params = params["matching"].get("matching_kwargs", {}).copy()
matching_params["templates"] = templates
if params["matching"]["method"] in ("tdc-peeler",):
if matching_method in ("tdc-peeler",):
matching_params["noise_levels"] = noise_levels
gather_kwargs = {}
if gather_mode == "npy":
gather_kwargs["folder"] = sorter_output_folder / "matching"
spikes = find_spikes_from_templates(
recording_for_peeler, method=matching_method, method_kwargs=matching_params, **job_kwargs
recording_for_peeler,
method=matching_method,
method_kwargs=matching_params,
gather_mode=gather_mode,
gather_kwargs=gather_kwargs,
**job_kwargs,
)

if params["save_array"]:
Expand Down
24 changes: 20 additions & 4 deletions src/spikeinterface/sortingcomponents/matching/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@


def find_spikes_from_templates(
recording, method="naive", method_kwargs={}, extra_outputs=False, verbose=False, **job_kwargs
recording,
method="naive",
method_kwargs={},
extra_outputs=False,
gather_mode="memory",
gather_kwargs=None,
verbose=False,
**job_kwargs,
) -> np.ndarray | tuple[np.ndarray, dict]:
"""Find spike from a recording from given templates.

Expand All @@ -25,10 +32,14 @@ def find_spikes_from_templates(
Keyword arguments for the chosen method
extra_outputs : bool
If True then a dict is also returned is also returned
**job_kwargs : dict
Parameters for ChunkRecordingExecutor
gather_mode : "memory" | "npy", default: "memory"
If "memory" then the output is gathered in memory, if "npy" then the output is gathered on disk
gather_kwargs : dict, optional
The kwargs for the gather method
verbose : Bool, default: False
If True, output is verbose
**job_kwargs : keyword arguments
Parameters for ChunkRecordingExecutor

Returns
-------
Expand All @@ -50,13 +61,18 @@ def find_spikes_from_templates(
if len(method_kwargs["templates"].unit_ids) == 0:
return np.zeros(0, dtype=node0.get_dtype())

gather_kwargs = gather_kwargs or {}
names = ["spikes"]

spikes = run_node_pipeline(
recording,
nodes,
job_kwargs,
job_name=f"find spikes ({method})",
gather_mode="memory",
gather_mode=gather_mode,
squeeze_output=True,
names=names,
**gather_kwargs,
)
if extra_outputs:
outputs = node0.get_extra_outputs()
Expand Down