diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index f038fe8b93..65099a3a12 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -107,7 +107,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         num_channels = recording.get_num_channels()
         ms_before = params["general"].get("ms_before", 2)
         ms_after = params["general"].get("ms_after", 2)
-        radius_um = params["general"].get("radius_um", 75)
+        radius_um = params["general"].get("radius_um", 100)
         peak_sign = params["detection"].get("peak_sign", "neg")
         templates_from_svd = params["templates_from_svd"]
         debug = params["debug"]
@@ -170,7 +170,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         ## Then, we are detecting peaks with a locally_exclusive method
         detection_method = params["detection"].get("method", "matched_filtering")
         detection_params = params["detection"].get("method_kwargs", dict())
-        detection_params["radius_um"] = radius_um
+        detection_params["radius_um"] = radius_um / 2
         detection_params["exclude_sweep_ms"] = exclude_sweep_ms
         detection_params["noise_levels"] = noise_levels
 
@@ -219,6 +219,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
                 )
             detection_method = "locally_exclusive"
 
+        matching_method = params["matching"].get("method", "circus-omp-svd")
+        if matching_method is None:
+            # We want all peaks if we are planning to assign them to templates afterwards
+            detection_params["skip_after_n_peaks"] = None
+
         peaks = detect_peaks(recording_w, detection_method, **detection_params, **job_kwargs)
         order = np.lexsort((peaks["sample_index"], peaks["segment_index"]))
         peaks = peaks[order]
@@ -353,11 +358,24 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         else:
             ## we should have a case to deal with clustering all peaks without matching
             ## for small density channel counts
+            from spikeinterface.sortingcomponents.matching.tools import assign_templates_to_peaks
+
+            peak_labels = assign_templates_to_peaks(
+                recording_w,
+                peaks,
+                templates=templates,
+                svd_model=svd_model,
+                sparse_mask=sparsity_mask,
+                **job_kwargs,
+            )
+
+            if verbose:
+                print("Found %d spikes" % len(peaks))
 
-            sorting = np.zeros(selected_peaks.size, dtype=minimum_spike_dtype)
-            sorting["sample_index"] = selected_peaks["sample_index"]
+            sorting = np.zeros(peaks.size, dtype=minimum_spike_dtype)
+            sorting["sample_index"] = peaks["sample_index"]
             sorting["unit_index"] = peak_labels
-            sorting["segment_index"] = selected_peaks["segment_index"]
+            sorting["segment_index"] = peaks["segment_index"]
             sorting = NumpySorting(sorting, sampling_frequency, templates.unit_ids)
 
         merging_params = params["merging"].copy()
diff --git a/src/spikeinterface/sortingcomponents/matching/tools.py b/src/spikeinterface/sortingcomponents/matching/tools.py
new file mode 100644
index 0000000000..684bf138a0
--- /dev/null
+++ b/src/spikeinterface/sortingcomponents/matching/tools.py
@@ -0,0 +1,184 @@
+from spikeinterface.core.node_pipeline import (
+    run_node_pipeline,
+    ExtractSparseWaveforms,
+    ExtractDenseWaveforms,
+    PeakRetriever,
+    PipelineNode,
+)
+from spikeinterface.sortingcomponents.waveforms.temporal_pca import (
+    TemporalPCAProjection,
+)
+from spikeinterface.core.job_tools import fix_job_kwargs
+import numpy as np
+from scipy.spatial.distance import cdist
+
+
+class FindNearestTemplate(PipelineNode):
+    """
+    Pipeline node labelling every peak with the index of its closest template.
+
+    The templates are projected once with ``pca_model`` at construction time.
+    In ``compute``, each peak is compared (Euclidean distance) to all the
+    projected templates, restricted to the channels flagged in
+    ``sparsity_mask`` for the main channel of the peak.
+
+    Parameters
+    ----------
+    recording : RecordingExtractor
+        The recording the peaks come from.
+    pca_model : object
+        Fitted model exposing ``n_components`` and ``transform``.
+    sparsity_mask : np.ndarray
+        Mask indexed as ``sparsity_mask[main_channel]``; its non-zero entries
+        are the channels compared for peaks on that main channel.
+    templates : Templates
+        The templates, read through ``get_dense_templates()``.
+    name : str, default: "nn_templates"
+        Not used by this node.
+    return_output : bool, default: True
+        Forwarded to ``PipelineNode``.
+    parents : list | None, default: None
+        Forwarded to ``PipelineNode``.
+    dense_waveforms : bool, default: False
+        Layout of the last axis of the input handed to ``compute``. False
+        assumes the active channels of a peak are packed in the first columns;
+        True assumes one column per recording channel.
+    """
+
+    def __init__(
+        self,
+        recording,
+        pca_model,
+        sparsity_mask,
+        templates,
+        name="nn_templates",
+        return_output=True,
+        parents=None,
+        dense_waveforms=False,
+    ):
+        PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)
+        templates_array = templates.get_dense_templates()
+        n_templates = templates_array.shape[0]
+        num_channels = recording.get_num_channels()
+        self.svd_templates = np.zeros((n_templates, pca_model.n_components, num_channels), "float32")
+        for i in range(n_templates):
+            # transform() yields one row per channel; transpose to (components, channels)
+            self.svd_templates[i] = pca_model.transform(templates_array[i].T).T
+        self.sparsity_mask = sparsity_mask
+        self.dense_waveforms = dense_waveforms
+        # NOTE(review): compute() returns int64 labels while this is the recording dtype -- confirm
+        self._dtype = recording.get_dtype()
+        self._kwargs.update(
+            dict(
+                sparsity_mask=self.sparsity_mask,
+                svd_templates=self.svd_templates,
+            )
+        )
+
+    def get_dtype(self):
+        return self._dtype
+
+    def compute(self, traces, peaks, waveforms):
+        """Return, for each peak, the index of the nearest projected template."""
+        peak_labels = np.empty(len(peaks), dtype="int64")
+        n_templates = self.svd_templates.shape[0]
+        for main_chan in np.unique(peaks["channel_index"]):
+            (idx,) = np.nonzero(peaks["channel_index"] == main_chan)
+            (chan_inds,) = np.nonzero(self.sparsity_mask[main_chan])
+            if self.dense_waveforms:
+                # One column per recording channel: pick the masked channels explicitly
+                local_svds = waveforms[idx][:, :, chan_inds]
+            else:
+                # Sparse input: keep the first len(chan_inds) columns (assumed packed active channels)
+                local_svds = waveforms[idx][:, :, : len(chan_inds)]
+            XA = local_svds.reshape(local_svds.shape[0], -1)
+            XB = self.svd_templates[:, :, chan_inds].reshape(n_templates, -1)
+            distances = cdist(XA, XB, metric="euclidean")
+            peak_labels[idx] = np.argmin(distances, axis=1)
+        return peak_labels
+
+
+def assign_templates_to_peaks(
+    recording, peaks, svd_model, sparse_mask, templates, gather_mode="memory", **job_kwargs
+) -> np.ndarray:
+    """
+    Assigns templates to peaks using a pipeline of nodes.
+
+    Parameters
+    ----------
+    recording : RecordingExtractor
+        The recording extractor.
+    peaks : np.ndarray
+        Peaks that should be assigned to templates.
+    svd_model : SVDModel
+        The SVD model used for PCA projection.
+    sparse_mask : np.ndarray
+        The sparsity mask used to extract waveforms.
+    templates : Templates
+        The templates used for matching.
+    gather_mode : str, default: "memory"
+        The mode for gathering results, forwarded to `run_node_pipeline`.
+    job_kwargs : dict
+        Job keyword arguments, passed through `fix_job_kwargs` and then
+        handed to `run_node_pipeline`.
+
+    Returns
+    -------
+    peak_labels : np.ndarray
+        The labels assigned to each peak.
+    """
+
+    job_kwargs = fix_job_kwargs(job_kwargs)
+
+    node0 = PeakRetriever(recording, peaks)
+    ms_before = templates.ms_before
+    ms_after = templates.ms_after
+    sparse_templates = templates.are_templates_sparse()
+
+    if sparse_templates:
+        node1 = ExtractSparseWaveforms(
+            recording,
+            parents=[node0],
+            return_output=False,
+            ms_before=ms_before,
+            ms_after=ms_after,
+            sparsity_mask=sparse_mask,
+        )
+    else:
+        node1 = ExtractDenseWaveforms(
+            recording,
+            parents=[node0],
+            return_output=False,
+            ms_before=ms_before,
+            ms_after=ms_after,
+        )
+
+    node2 = TemporalPCAProjection(
+        recording,
+        parents=[node0, node1],
+        return_output=False,
+        pca_model=svd_model,
+    )
+
+    node3 = FindNearestTemplate(
+        recording,
+        parents=[node0, node2],
+        return_output=True,
+        pca_model=svd_model,
+        templates=templates,
+        sparsity_mask=sparse_mask,
+        # Dense extraction is expected to keep one column per recording channel
+        dense_waveforms=not sparse_templates,
+    )
+
+    pipeline_nodes = [node0, node1, node2, node3]
+
+    peak_labels = run_node_pipeline(
+        recording,
+        pipeline_nodes,
+        job_kwargs,
+        job_name="assign labels",
+        gather_mode=gather_mode,
+        squeeze_output=True,
+    )
+    return peak_labels
+