Skip to content

Commit dc1ea7a

Browse files
Method to bypass matching, and assign labels to all peaks given templates and SVD representation (#3856)
* WIP * Example of how to use SVD to estimate templates in SC2 * Patching to get a working example * WIP * WIP * WIP * WIP * WIP * WIP * Cosmetic * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Patch * WIP * WIP * Fix * WIP * WIP * WIP * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Start a full clustering pipeline * Option to bypass matching * WIP * Fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP * Make delete_mixtures optional * Better logs * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8dfef70 commit dc1ea7a

2 files changed

Lines changed: 161 additions & 5 deletions

File tree

src/spikeinterface/sorters/internal/spyking_circus2.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
107107
num_channels = recording.get_num_channels()
108108
ms_before = params["general"].get("ms_before", 2)
109109
ms_after = params["general"].get("ms_after", 2)
110-
radius_um = params["general"].get("radius_um", 75)
110+
radius_um = params["general"].get("radius_um", 100)
111111
peak_sign = params["detection"].get("peak_sign", "neg")
112112
templates_from_svd = params["templates_from_svd"]
113113
debug = params["debug"]
@@ -170,7 +170,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
170170
## Then, we are detecting peaks with a locally_exclusive method
171171
detection_method = params["detection"].get("method", "matched_filtering")
172172
detection_params = params["detection"].get("method_kwargs", dict())
173-
detection_params["radius_um"] = radius_um
173+
detection_params["radius_um"] = radius_um / 2
174174
detection_params["exclude_sweep_ms"] = exclude_sweep_ms
175175
detection_params["noise_levels"] = noise_levels
176176

@@ -219,6 +219,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
219219
)
220220
detection_method = "locally_exclusive"
221221

222+
matching_method = params["matching"].get("method", "circus-omp-svd")
223+
if matching_method is None:
224+
# We want all peaks if we are planning to assign them to templates afterwards
225+
detection_params["skip_after_n_peaks"] = None
226+
222227
peaks = detect_peaks(recording_w, detection_method, **detection_params, **job_kwargs)
223228
order = np.lexsort((peaks["sample_index"], peaks["segment_index"]))
224229
peaks = peaks[order]
@@ -353,11 +358,24 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
353358
else:
354359
## we should have a case to deal with clustering all peaks without matching
355360
## for small density channel counts
361+
from spikeinterface.sortingcomponents.matching.tools import assign_templates_to_peaks
362+
363+
peak_labels = assign_templates_to_peaks(
364+
recording_w,
365+
peaks,
366+
templates=templates,
367+
svd_model=svd_model,
368+
sparse_mask=sparsity_mask,
369+
**job_kwargs,
370+
)
371+
372+
if verbose:
373+
print("Found %d spikes" % len(peaks))
356374

357-
sorting = np.zeros(selected_peaks.size, dtype=minimum_spike_dtype)
358-
sorting["sample_index"] = selected_peaks["sample_index"]
375+
sorting = np.zeros(peaks.size, dtype=minimum_spike_dtype)
376+
sorting["sample_index"] = peaks["sample_index"]
359377
sorting["unit_index"] = peak_labels
360-
sorting["segment_index"] = selected_peaks["segment_index"]
378+
sorting["segment_index"] = peaks["segment_index"]
361379
sorting = NumpySorting(sorting, sampling_frequency, templates.unit_ids)
362380

363381
merging_params = params["merging"].copy()
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
from spikeinterface.core.node_pipeline import (
2+
run_node_pipeline,
3+
ExtractSparseWaveforms,
4+
ExtractDenseWaveforms,
5+
PeakRetriever,
6+
PipelineNode,
7+
)
8+
from spikeinterface.sortingcomponents.waveforms.temporal_pca import (
9+
TemporalPCAProjection,
10+
)
11+
from spikeinterface.core.job_tools import fix_job_kwargs
12+
import numpy as np
13+
from scipy.spatial.distance import cdist
14+
15+
16+
class FindNearestTemplate(PipelineNode):
    """Pipeline node that labels each peak with the index of its nearest template.

    At construction, every dense template is projected into the same SVD
    (temporal PCA) space as the peak waveforms. At compute time, each peak's
    projected waveform is compared (Euclidean distance, restricted to the
    peak's sparse channel neighborhood) against all projected templates, and
    the index of the closest template becomes the peak's label.
    """

    def __init__(
        self,
        recording,
        pca_model,
        sparsity_mask,
        templates,
        name="nn_templates",
        return_output=True,
        parents=None,
    ):
        PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)

        dense_templates = templates.get_dense_templates()
        num_templates = dense_templates.shape[0]
        num_channels = recording.get_num_channels()

        # Project each dense template into SVD space, one at a time:
        # result shape is (num_templates, n_components, num_channels).
        self.svd_templates = np.zeros((num_templates, pca_model.n_components, num_channels), "float32")
        for template_index, template in enumerate(dense_templates):
            self.svd_templates[template_index] = pca_model.transform(template.T).T

        self.sparsity_mask = sparsity_mask
        self._dtype = recording.get_dtype()
        # Record the derived state so the pipeline machinery can rebuild this node.
        self._kwargs.update(
            dict(
                sparsity_mask=self.sparsity_mask,
                svd_templates=self.svd_templates,
            )
        )

    def get_dtype(self):
        return self._dtype

    def compute(self, traces, peaks, waveforms):
        """Return, for each peak, the index of the closest template in SVD space."""
        labels = np.empty(len(peaks), dtype="int64")
        # Peaks that share a main channel also share a sparse channel set, so
        # group by channel and match every peak in the group against all
        # templates with a single distance computation.
        for chan in np.unique(peaks["channel_index"]):
            (peak_rows,) = np.nonzero(peaks["channel_index"] == chan)
            (neighbors,) = np.nonzero(self.sparsity_mask[chan])
            # NOTE(review): assumes the waveform array stores this channel
            # group's neighborhood in its first len(neighbors) channel slots —
            # confirm this also holds for the dense-waveform path.
            peak_features = waveforms[peak_rows][:, :, : len(neighbors)]
            flat_peaks = peak_features.reshape(peak_features.shape[0], -1)
            flat_templates = self.svd_templates[:, :, neighbors].reshape(self.svd_templates.shape[0], -1)
            labels[peak_rows] = np.argmin(cdist(flat_peaks, flat_templates, metric="euclidean"), axis=1)
        return labels
57+
58+
59+
def assign_templates_to_peaks(
    recording, peaks, svd_model, sparse_mask, templates, gather_mode="memory", **job_kwargs
) -> np.ndarray | tuple[np.ndarray, dict]:
    """
    Assign each peak the label of its nearest template, via a node pipeline.

    The pipeline extracts (sparse or dense) waveforms around each peak,
    projects them with the temporal PCA model, and labels each peak with the
    index of the closest template in SVD space (see ``FindNearestTemplate``).

    Parameters
    ----------
    recording : RecordingExtractor
        The recording extractor.
    peaks : np.ndarray
        Peaks that should be assigned to templates.
    svd_model : SVDModel
        The SVD model used for PCA projection.
    sparse_mask : np.ndarray
        The sparsity mask used to extract waveforms.
    templates : Templates
        The templates used for matching.
    gather_mode : str, default: "memory"
        The mode for gathering results. Can be 'memory' or 'file'.
    job_kwargs : dict
        Additional keyword arguments for parallel jobs.

    Returns
    -------
    peak_labels : np.ndarray
        The template index assigned to each peak.
    """

    job_kwargs = fix_job_kwargs(job_kwargs)

    peak_source = PeakRetriever(recording, peaks)
    ms_before = templates.ms_before
    ms_after = templates.ms_after

    # Extract waveforms with the same sparsity as the templates, so distances
    # are computed over matching channel sets.
    if templates.are_templates_sparse():
        waveform_node = ExtractSparseWaveforms(
            recording,
            parents=[peak_source],
            return_output=False,
            ms_before=ms_before,
            ms_after=ms_after,
            sparsity_mask=sparse_mask,
        )
    else:
        waveform_node = ExtractDenseWaveforms(
            recording,
            parents=[peak_source],
            return_output=False,
            ms_before=ms_before,
            ms_after=ms_after,
        )

    projection_node = TemporalPCAProjection(
        recording,
        parents=[peak_source, waveform_node],
        return_output=False,
        pca_model=svd_model,
    )

    matching_node = FindNearestTemplate(
        recording,
        parents=[peak_source, projection_node],
        return_output=True,
        pca_model=svd_model,
        templates=templates,
        sparsity_mask=sparse_mask,
    )

    pipeline_nodes = [peak_source, waveform_node, projection_node, matching_node]

    peak_labels = run_node_pipeline(
        recording,
        pipeline_nodes,
        job_kwargs,
        # Plain string: the original used an f-string with no placeholders (F541).
        job_name="assign labels",
        gather_mode=gather_mode,
        squeeze_output=True,
    )
    return peak_labels

0 commit comments

Comments
 (0)