diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index 2edd38e77a..f038fe8b93 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -6,10 +6,8 @@
 import numpy as np
 
 from spikeinterface.core import NumpySorting
-from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
+from spikeinterface.core.job_tools import fix_job_kwargs
 from spikeinterface.core.recording_tools import get_noise_levels
-from spikeinterface.core.template import Templates
-from spikeinterface.core.waveform_tools import estimate_templates
 from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion
 from spikeinterface.sortingcomponents.tools import (
     cache_preprocessing,
@@ -24,28 +22,25 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
     sorter_name = "spykingcircus2"
 
     _default_params = {
-        "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75},
-        "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 1},
+        "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100},
+        "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
         "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2, "margin_ms": 10},
         "whitening": {"mode": "local", "regularize": False},
-        "detection": {"peak_sign": "neg", "detect_threshold": 5},
+        "detection": {"method": "matched_filtering", "method_kwargs": dict(peak_sign="neg", detect_threshold=5)},
         "selection": {
             "method": "uniform",
-            "n_peaks_per_channel": 5000,
-            "min_n_peaks": 100000,
-            "select_per_channel": False,
-            "seed": 42,
+            "method_kwargs": dict(n_peaks_per_channel=5000, min_n_peaks=100000, select_per_channel=False),
         },
         "apply_motion_correction": True,
         "motion_correction": {"preset": "dredge_fast"},
         "merging": {"max_distance_um": 50},
-        "clustering": {"legacy": True},
-        "matching": {"method": "circus-omp-svd"},
+        "clustering": {"method": "circus", "method_kwargs": dict()},
+        "matching": {"method": "circus-omp-svd", "method_kwargs": dict()},
         "apply_preprocessing": True,
-        "matched_filtering": True,
+        "templates_from_svd": True,
        "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
         "multi_units_only": False,
-        "job_kwargs": {"n_jobs": 0.5},
+        "job_kwargs": {"n_jobs": 0.75},
         "seed": 42,
         "debug": False,
     }
@@ -56,20 +51,18 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         "general": "A dictionary to describe how templates should be computed. User can define ms_before and ms_after (in ms) \
        and also the radius_um used to be considered during clustering",
         "sparsity": "A dictionary to be passed to all the calls to sparsify the templates",
-        "filtering": "A dictionary for the high_pass filter to be used during preprocessing",
-        "whitening": "A dictionary for the whitening option to be used during preprocessing",
-        "detection": "A dictionary for the peak detection node (locally_exclusive)",
-        "selection": "A dictionary for the peak selection node. Default is to use smart_sampling_amplitudes, with a minimum of 20000 peaks\
-        and 5000 peaks per electrode on average.",
-        "clustering": "A dictionary to be provided to the clustering method. By default, random_projections is used, but if legacy is set to\
-        True, one other clustering called circus will be used, similar to the one used in Spyking Circus 1",
-        "matching": "A dictionary to specify the matching engine used to recover spikes. The method default is circus-omp-svd, but other engines\
-        can be used",
+        "filtering": "A dictionary for the high_pass filter used during preprocessing",
+        "whitening": "A dictionary for the whitening used during preprocessing",
+        "detection": "A dictionary for the peak detection component. Default is matched_filtering",
+        "selection": "A dictionary for the peak selection component. Default is uniform",
+        "clustering": "A dictionary for the clustering component. Default is circus",
+        "matching": "A dictionary for the matching component. Default is circus-omp-svd. Use None to skip matching",
         "merging": "A dictionary to specify the final merging param to group cells after template matching (auto_merge_units)",
         "motion_correction": "A dictionary to be provided if motion correction has to be performed (dense probe only)",
         "apply_preprocessing": "Boolean to specify whether circus 2 should preprocess the recording or not. If yes, then high_pass filtering + common\
        median reference + whitening",
         "apply_motion_correction": "Boolean to specify whether circus 2 should apply motion correction to the recording or not",
+        "templates_from_svd": "Boolean to specify whether templates should be computed from SVD or not.",
         "matched_filtering": "Boolean to specify whether circus 2 should detect peaks via matched filtering (slightly slower)",
         "cache_preprocessing": "How to cache the preprocessed recording. Mode can be memory, file, zarr, with extra arguments. In case of memory (default), \
        memory_limit will control how much RAM can be used. In case of folder or zarr, delete_cache controls if cache is cleaned after sorting",
@@ -86,18 +79,10 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
 
     @classmethod
     def get_sorter_version(cls):
-        return "2.0"
+        return "2.1"
 
     @classmethod
     def _run_from_folder(cls, sorter_output_folder, params, verbose):
-        try:
-            import hdbscan
-
-            HAVE_HDBSCAN = True
-        except:
-            HAVE_HDBSCAN = False
-
-        assert HAVE_HDBSCAN, "spykingcircus2 needs hdbscan to be installed"
 
         try:
             import torch
@@ -124,11 +109,16 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         ms_after = params["general"].get("ms_after", 2)
         radius_um = params["general"].get("radius_um", 75)
         peak_sign = params["detection"].get("peak_sign", "neg")
+        templates_from_svd = params["templates_from_svd"]
+        debug = params["debug"]
+        seed = params["seed"]
+        apply_preprocessing = params["apply_preprocessing"]
+        apply_motion_correction = params["apply_motion_correction"]
         exclude_sweep_ms = params["detection"].get("exclude_sweep_ms", max(ms_before, ms_after))
 
         ## First, we are filtering the data
         filtering_params = params["filtering"].copy()
-        if params["apply_preprocessing"]:
+        if apply_preprocessing:
             if verbose:
                 print("Preprocessing the recording (bandpass filtering + CMR + whitening)")
             recording_f = bandpass_filter(recording, **filtering_params, dtype="float32")
@@ -141,7 +131,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             recording_f.annotate(is_filtered=True)
 
         valid_geometry = check_probe_for_drift_correction(recording_f)
-        if params["apply_motion_correction"]:
+        if apply_motion_correction:
             if not valid_geometry:
                 if verbose:
                     print("Geometry of the probe does not allow 1D drift correction")
@@ -164,8 +154,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
                 whitening_kwargs["regularize"] = False
             if whitening_kwargs["regularize"]:
                 whitening_kwargs["regularize_kwargs"] = {"method": "LedoitWolf"}
+        whitening_kwargs["apply_mean"] = True
 
         recording_w = whiten(recording_f, **whitening_kwargs)
+
         noise_levels = get_noise_levels(recording_w, return_scaled=False, **job_kwargs)
 
         if recording_w.check_serializability("json"):
@@ -176,146 +168,167 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             recording_w = cache_preprocessing(recording_w, **job_kwargs, **params["cache_preprocessing"])
 
         ## Then, we are detecting peaks with a locally_exclusive method
-        detection_params = params["detection"].copy()
-        selection_params = params["selection"].copy()
+        detection_method = params["detection"].get("method", "matched_filtering")
+        detection_params = params["detection"].get("method_kwargs", dict())
         detection_params["radius_um"] = radius_um
         detection_params["exclude_sweep_ms"] = exclude_sweep_ms
         detection_params["noise_levels"] = noise_levels
 
-        fs = recording_w.get_sampling_frequency()
-        nbefore = int(ms_before * fs / 1000.0)
-        nafter = int(ms_after * fs / 1000.0)
-
-        skip_peaks = not params["multi_units_only"] and selection_params.get("method", "uniform") == "uniform"
-        max_n_peaks = selection_params["n_peaks_per_channel"] * num_channels
-        n_peaks = max(selection_params["min_n_peaks"], max_n_peaks)
-
-        if params["debug"]:
+        selection_method = params["selection"].get("method", "uniform")
+        selection_params = params["selection"].get("method_kwargs", dict())
+        n_peaks_per_channel = selection_params.get("n_peaks_per_channel", 5000)
+        min_n_peaks = selection_params.get("min_n_peaks", 100000)
+        skip_peaks = not params["multi_units_only"] and selection_method == "uniform"
+        max_n_peaks = n_peaks_per_channel * num_channels
+        n_peaks = max(min_n_peaks, max_n_peaks)
+        selection_params["n_peaks"] = n_peaks
+        selection_params["noise_levels"] = noise_levels
+
+        if debug:
             clustering_folder = sorter_output_folder / "clustering"
             clustering_folder.mkdir(parents=True, exist_ok=True)
             np.save(clustering_folder / "noise_levels.npy", noise_levels)
 
-        if params["matched_filtering"]:
+        if detection_method == "matched_filtering":
             prototype, waveforms, _ = get_prototype_and_waveforms_from_recording(
                 recording_w,
                 n_peaks=10000,
                 ms_before=ms_before,
                 ms_after=ms_after,
-                seed=params["seed"],
+                seed=seed,
                 **detection_params,
                 **job_kwargs,
             )
             detection_params["prototype"] = prototype
             detection_params["ms_before"] = ms_before
-            if params["debug"]:
+            if debug:
                 np.save(clustering_folder / "waveforms.npy", waveforms)
                 np.save(clustering_folder / "prototype.npy", prototype)
             if skip_peaks:
                 detection_params["skip_after_n_peaks"] = n_peaks
                 detection_params["recording_slices"] = get_shuffled_recording_slices(
-                    recording_w, seed=params["seed"], **job_kwargs
+                    recording_w, seed=seed, **job_kwargs
                 )
-            peaks = detect_peaks(recording_w, "matched_filtering", **detection_params, **job_kwargs)
+            detection_method = "matched_filtering"
         else:
             waveforms = None
             if skip_peaks:
                 detection_params["skip_after_n_peaks"] = n_peaks
                 detection_params["recording_slices"] = get_shuffled_recording_slices(
-                    recording_w, seed=params["seed"], **job_kwargs
+                    recording_w, seed=seed, **job_kwargs
                 )
-            peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, **job_kwargs)
+            detection_method = "locally_exclusive"
+
+        peaks = detect_peaks(recording_w, detection_method, **detection_params, **job_kwargs)
+        order = np.lexsort((peaks["sample_index"], peaks["segment_index"]))
+        peaks = peaks[order]
+
+        if debug:
+            np.save(clustering_folder / "peaks.npy", peaks)
 
         if not skip_peaks and verbose:
             print("Found %d peaks in total" % len(peaks))
 
+        sparsity_kwargs = params["sparsity"].copy()
+        if "peak_sign" not in sparsity_kwargs:
+            sparsity_kwargs["peak_sign"] = peak_sign
+
+        sorting_folder = sorter_output_folder / "sorting"
+        if sorting_folder.exists():
+            shutil.rmtree(sorting_folder)
+
         if params["multi_units_only"]:
-            sorting = NumpySorting.from_peaks(peaks, sampling_frequency, unit_ids=recording_w.unit_ids)
+            sorting = NumpySorting.from_peaks(peaks, sampling_frequency, unit_ids=recording_w.channel_ids)
         else:
             ## We subselect a subset of all the peaks, by making the distributions of SNRs over all
             ## channels as flat as possible
-            selection_params = params["selection"]
-            selection_params["n_peaks"] = n_peaks
-            selection_params.update({"noise_levels": noise_levels})
-            selected_peaks = select_peaks(peaks, **selection_params)
+            selected_peaks = select_peaks(peaks, seed=seed, method=selection_method, **selection_params)
 
             if verbose:
                 print("Kept %d peaks for clustering" % len(selected_peaks))
 
-            ## We launch a clustering (using hdbscan) relying on positions and features extracted on
-            ## the fly from the snippets
-            clustering_params = params["clustering"].copy()
-            clustering_params["waveforms"] = {}
-            sparsity_kwargs = params["sparsity"].copy()
-            if "peak_sign" not in sparsity_kwargs:
-                sparsity_kwargs["peak_sign"] = peak_sign
-
-            clustering_params["sparsity"] = sparsity_kwargs
-            clustering_params["radius_um"] = radius_um
-            clustering_params["waveforms"]["ms_before"] = ms_before
-            clustering_params["waveforms"]["ms_after"] = ms_after
-            clustering_params["few_waveforms"] = waveforms
-            clustering_params["noise_levels"] = noise_levels
-            clustering_params["ms_before"] = ms_before
-            clustering_params["ms_after"] = ms_after
-            clustering_params["verbose"] = verbose
-            clustering_params["tmp_folder"] = sorter_output_folder / "clustering"
-            clustering_params["debug"] = params["debug"]
-            clustering_params["noise_threshold"] = detection_params.get("detect_threshold", 4)
-
-            legacy = clustering_params.get("legacy", True)
-
-            if legacy:
-                clustering_method = "circus"
-            else:
-                clustering_method = "random_projections"
-
-            labels, peak_labels = find_cluster_from_peaks(
-                recording_w, selected_peaks, method=clustering_method, method_kwargs=clustering_params, **job_kwargs
-            )
-
-            ## We get the labels for our peaks
-            mask = peak_labels > -1
-
-            labeled_peaks = np.zeros(np.sum(mask), dtype=minimum_spike_dtype)
-            labeled_peaks["sample_index"] = selected_peaks[mask]["sample_index"]
-            labeled_peaks["segment_index"] = selected_peaks[mask]["segment_index"]
-            for count, l in enumerate(labels):
-                sub_mask = peak_labels[mask] == l
-                labeled_peaks["unit_index"][sub_mask] = count
-            unit_ids = np.arange(len(np.unique(labeled_peaks["unit_index"])))
-            sorting = NumpySorting(labeled_peaks, sampling_frequency, unit_ids=unit_ids)
-
-            if params["debug"]:
-                np.save(clustering_folder / "peak_labels", peak_labels)
-                np.save(clustering_folder / "labels", labels)
-                np.save(clustering_folder / "peaks", selected_peaks)
-
-            templates_array = estimate_templates(
-                recording_w, labeled_peaks, unit_ids, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs
+            clustering_method = params["clustering"].get("method", "graph_clustering")
+            clustering_params = params["clustering"].get("method_kwargs", dict())
+
+            if clustering_method == "circus":
+                clustering_params["waveforms"] = {}
+                clustering_params["sparsity"] = sparsity_kwargs
+                clustering_params["neighbors_radius_um"] = 50
+                clustering_params["radius_um"] = radius_um
+                clustering_params["waveforms"]["ms_before"] = ms_before
+                clustering_params["waveforms"]["ms_after"] = ms_after
+                clustering_params["few_waveforms"] = waveforms
+                clustering_params["noise_levels"] = noise_levels
+                clustering_params["ms_before"] = ms_before
+                clustering_params["ms_after"] = ms_after
+                clustering_params["verbose"] = verbose
+                clustering_params["templates_from_svd"] = templates_from_svd
+                clustering_params["tmp_folder"] = sorter_output_folder / "clustering"
+                clustering_params["debug"] = debug
+                clustering_params["noise_threshold"] = detection_params.get("detect_threshold", 4)
+            elif clustering_method == "graph_clustering":
+                clustering_params = {
+                    "ms_before": ms_before,
+                    "ms_after": ms_after,
+                    "clustering_method": "hdbscan",
+                    "radius_um": radius_um,
+                    "clustering_kwargs": dict(
+                        min_samples=1,
+                        min_cluster_size=50,
+                        core_dist_n_jobs=-1,
+                        cluster_selection_method="leaf",
+                        allow_single_cluster=True,
+                        cluster_selection_epsilon=0.1,
+                    ),
+                }
+
+            outputs = find_cluster_from_peaks(
+                recording_w,
+                selected_peaks,
+                method=clustering_method,
+                method_kwargs=clustering_params,
+                extra_outputs=templates_from_svd,
+                **job_kwargs,
             )
 
-            templates = Templates(
-                templates_array=templates_array,
-                sampling_frequency=sampling_frequency,
-                nbefore=nbefore,
-                sparsity_mask=None,
-                channel_ids=recording_w.channel_ids,
-                unit_ids=unit_ids,
-                probe=recording_w.get_probe(),
-                is_scaled=False,
-            )
+            if len(outputs) == 2:
+                _, peak_labels = outputs
+                from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording
+
+                templates = get_templates_from_peaks_and_recording(
+                    recording_w,
+                    selected_peaks,
+                    peak_labels,
+                    ms_before,
+                    ms_after,
+                    **job_kwargs,
+                )
+            elif len(outputs) == 5:
+                _, peak_labels, svd_model, svd_features, sparsity_mask = outputs
+                from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd
+
+                templates = get_templates_from_peaks_and_svd(
+                    recording_w,
+                    selected_peaks,
+                    peak_labels,
+                    ms_before,
+                    ms_after,
+                    svd_model,
+                    svd_features,
+                    sparsity_mask,
+                    operator="median",
+                )
 
             sparsity = compute_sparsity(templates, noise_levels, **sparsity_kwargs)
             templates = templates.to_sparse(sparsity)
             templates = remove_empty_templates(templates)
-            if params["debug"]:
+            if debug:
                 templates.to_zarr(folder_path=clustering_folder / "templates")
-                sorting = sorting.save(folder=clustering_folder / "sorting")
 
             ## We launch an OMP matching pursuit by full convolution of the templates and the raw traces
-            matching_method = params["matching"].pop("method")
-            matching_params = params["matching"].copy()
+            matching_method = params["matching"].get("method", "circus-omp-svd")
+            matching_params = params["matching"].get("method_kwargs", dict())
             matching_params["templates"] = templates
 
             if matching_method is not None:
@@ -323,7 +336,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
                     recording_w, matching_method, method_kwargs=matching_params, **job_kwargs
                 )
 
-                if params["debug"]:
+                if debug:
                     fitting_folder = sorter_output_folder / "fitting"
                     fitting_folder.mkdir(parents=True, exist_ok=True)
                     np.save(fitting_folder / "spikes", spikes)
@@ -336,39 +349,43 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
                 sorting["sample_index"] = spikes["sample_index"]
                 sorting["unit_index"] = spikes["cluster_index"]
                 sorting["segment_index"] = spikes["segment_index"]
-                sorting = NumpySorting(sorting, sampling_frequency, unit_ids)
+                sorting = NumpySorting(sorting, sampling_frequency, templates.unit_ids)
+            else:
+                ## we should have a case to deal with clustering all peaks without matching
+                ## for small density channel counts
 
-        sorting_folder = sorter_output_folder / "sorting"
-        if sorting_folder.exists():
-            shutil.rmtree(sorting_folder)
+                sorting = np.zeros(selected_peaks.size, dtype=minimum_spike_dtype)
+                sorting["sample_index"] = selected_peaks["sample_index"]
+                sorting["unit_index"] = peak_labels
+                sorting["segment_index"] = selected_peaks["segment_index"]
+                sorting = NumpySorting(sorting, sampling_frequency, templates.unit_ids)
 
-        merging_params = params["merging"].copy()
-        if params["debug"]:
-            merging_params["debug_folder"] = sorter_output_folder / "merging"
+            merging_params = params["merging"].copy()
+            if debug:
+                merging_params["debug_folder"] = sorter_output_folder / "merging"
 
-        if len(merging_params) > 0:
-            if params["motion_correction"] and motion_folder is not None:
-                from spikeinterface.preprocessing.motion import load_motion_info
+            if len(merging_params) > 0:
+                if params["motion_correction"] and motion_folder is not None:
+                    from spikeinterface.preprocessing.motion import load_motion_info
 
-                motion_info = load_motion_info(motion_folder)
-                motion = motion_info["motion"]
-                max_motion = max(
-                    np.max(np.abs(motion.displacement[seg_index])) for seg_index in range(len(motion.displacement))
-                )
-                max_distance_um = merging_params.get("max_distance_um", 50)
-                merging_params["max_distance_um"] = max(max_distance_um, 2 * max_motion)
+                    motion_info = load_motion_info(motion_folder)
+                    motion = motion_info["motion"]
+                    max_motion = max(
+                        np.max(np.abs(motion.displacement[seg_index])) for seg_index in range(len(motion.displacement))
+                    )
+                    max_distance_um = merging_params.get("max_distance_um", 50)
+                    merging_params["max_distance_um"] = max(max_distance_um, 2 * max_motion)
 
-        if params["debug"]:
-            curation_folder = sorter_output_folder / "curation"
-            if curation_folder.exists():
-                shutil.rmtree(curation_folder)
-            sorting.save(folder=curation_folder)
-            # np.save(fitting_folder / "amplitudes", guessed_amplitudes)
+                if debug:
+                    curation_folder = sorter_output_folder / "curation"
+                    if curation_folder.exists():
+                        shutil.rmtree(curation_folder)
+                    sorting.save(folder=curation_folder)
+                    # np.save(fitting_folder / "amplitudes", guessed_amplitudes)
 
-        sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params, **job_kwargs)
+                sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params, **job_kwargs)
 
-        if verbose:
-            print(f"Kept {len(sorting.unit_ids)} units after final merging")
+                if verbose:
+                    print(f"Kept {len(sorting.unit_ids)} units after final merging")
 
         folder_to_delete = None
         cache_mode = params["cache_preprocessing"].get("mode", "memory")
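Reviewer note: with this refactor, every SpyKING CIRCUS 2 component (detection, selection, clustering, matching) takes a nested `method` / `method_kwargs` pair instead of a flat dictionary. A minimal usage sketch of the new layout, assuming a `recording` object is already loaded (illustrative only, not part of this diff):

```python
# Sketch of the refactored parameter layout; `recording` is assumed to exist.
from spikeinterface.sorters import run_sorter

sorting = run_sorter(
    "spykingcircus2",
    recording,
    folder="sc2_output",
    # each component is now a {"method": ..., "method_kwargs": ...} pair
    detection={"method": "matched_filtering", "method_kwargs": dict(peak_sign="neg", detect_threshold=5)},
    selection={"method": "uniform", "method_kwargs": dict(n_peaks_per_channel=5000, min_n_peaks=100000)},
    clustering={"method": "circus", "method_kwargs": dict()},
    matching={"method": "circus-omp-svd", "method_kwargs": dict()},
    templates_from_svd=True,
)
```

Keeping the kwargs nested means a component engine can be swapped without the sorter having to know every engine's parameter names.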
diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index 7bce0800d3..0fd58d3011 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -19,17 +19,10 @@
 from .clustering_tools import remove_duplicates_via_matching
 from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances
 from spikeinterface.sortingcomponents.peak_selection import select_peaks
-from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection
-from spikeinterface.sortingcomponents.waveforms.hanning_filter import HanningFilter
 from spikeinterface.core.template import Templates
 from spikeinterface.core.sparsity import compute_sparsity
 from spikeinterface.sortingcomponents.tools import remove_empty_templates
-import pickle, json
-from spikeinterface.core.node_pipeline import (
-    run_node_pipeline,
-    ExtractSparseWaveforms,
-    PeakRetriever,
-)
+from spikeinterface.sortingcomponents.clustering.peak_svd import extract_peaks_svd
 from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel
 
 
@@ -48,20 +41,24 @@ class CircusClustering:
             "allow_single_cluster": True,
         },
         "cleaning_kwargs": {},
+        "remove_mixtures": False,
         "waveforms": {"ms_before": 2, "ms_after": 2},
         "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
         "recursive_kwargs": {
             "stop_criteria": "change_ratio",
             "recursive_depth": 3,
             "returns_split_count": True,
         },
+        "split_kwargs": {"projection_mode": "tsvd", "n_pca_features": 0.9},
         "radius_um": 100,
+        "neighbors_radius_um": 50,
         "n_svd": 5,
         "few_waveforms": None,
         "ms_before": 0.5,
         "ms_after": 0.5,
         "noise_threshold": 4,
         "rank": 5,
+        "templates_from_svd": False,
         "noise_levels": None,
         "tmp_folder": None,
         "verbose": True,
@@ -78,6 +75,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         fs = recording.get_sampling_frequency()
         ms_before = params["ms_before"]
         ms_after = params["ms_after"]
+        radius_um = params["radius_um"]
+        neighbors_radius_um = params["neighbors_radius_um"]
         nbefore = int(ms_before * fs / 1000.0)
         nafter = int(ms_after * fs / 1000.0)
         if params["tmp_folder"] is None:
@@ -108,210 +107,139 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         valid = np.argmax(np.abs(wfs), axis=1) == nbefore
         wfs = wfs[valid]
 
-        # Perform Hanning filtering
-        hanning_before = np.hanning(2 * nbefore)
-        hanning_after = np.hanning(2 * nafter)
-        hanning = np.concatenate((hanning_before[:nbefore], hanning_after[nafter:]))
-        wfs *= hanning
-
         from sklearn.decomposition import TruncatedSVD
 
-        tsvd = TruncatedSVD(params["n_svd"])
-        tsvd.fit(wfs)
-
-        model_folder = tmp_folder / "tsvd_model"
-
-        model_folder.mkdir(exist_ok=True)
-        with open(model_folder / "pca_model.pkl", "wb") as f:
-            pickle.dump(tsvd, f)
-
-        model_params = {
-            "ms_before": ms_before,
-            "ms_after": ms_after,
-            "sampling_frequency": float(fs),
-        }
-
-        with open(model_folder / "params.json", "w") as f:
-            json.dump(model_params, f)
+        svd_model = TruncatedSVD(params["n_svd"])
+        svd_model.fit(wfs)
+        features_folder = tmp_folder / "tsvd_features"
+        features_folder.mkdir(exist_ok=True)
 
-        # features
-        node0 = PeakRetriever(recording, peaks)
-
-        radius_um = params["radius_um"]
-        node1 = ExtractSparseWaveforms(
+        peaks_svd, sparse_mask, svd_model = extract_peaks_svd(
             recording,
-            parents=[node0],
-            return_output=False,
+            peaks,
             ms_before=ms_before,
             ms_after=ms_after,
+            svd_model=svd_model,
             radius_um=radius_um,
+            folder=features_folder,
+            **job_kwargs,
         )
 
-        node2 = HanningFilter(recording, parents=[node0, node1], return_output=False)
+        neighbours_mask = get_channel_distances(recording) <= neighbors_radius_um
 
-        node3 = TemporalPCAProjection(
-            recording, parents=[node0, node2], return_output=True, model_folder_path=model_folder
-        )
+        if params["debug"]:
+            np.save(features_folder / "sparse_mask.npy", sparse_mask)
+            np.save(features_folder / "peaks.npy", peaks)
 
-        pipeline_nodes = [node0, node1, node2, node3]
+        original_labels = peaks["channel_index"]
+        from spikeinterface.sortingcomponents.clustering.split import split_clusters
 
-        if len(params["recursive_kwargs"]) == 0:
-            from sklearn.decomposition import PCA
+        split_kwargs = params["split_kwargs"].copy()
+        split_kwargs["neighbours_mask"] = neighbours_mask
+        split_kwargs["waveforms_sparse_mask"] = sparse_mask
+        split_kwargs["min_size_split"] = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 50)
+        split_kwargs["clusterer_kwargs"] = params["hdbscan_kwargs"]
 
-            all_pc_data = run_node_pipeline(
-                recording,
-                pipeline_nodes,
-                job_kwargs,
-                job_name="extracting features",
-            )
-
-            peak_labels = -1 * np.ones(len(peaks), dtype=int)
-            nb_clusters = 0
-            for c in np.unique(peaks["channel_index"]):
-                mask = peaks["channel_index"] == c
-                sub_data = all_pc_data[mask]
-                sub_data = sub_data.reshape(len(sub_data), -1)
-
-                if all_pc_data.shape[1] > params["n_svd"]:
-                    tsvd = PCA(params["n_svd"], whiten=True)
-                else:
-                    tsvd = PCA(all_pc_data.shape[1], whiten=True)
-
-                hdbscan_data = tsvd.fit_transform(sub_data)
-                try:
-                    clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"])
-                    local_labels = clustering[0]
-                except Exception:
-                    local_labels = np.zeros(len(hdbscan_data))
-                valid_clusters = local_labels > -1
-                if np.sum(valid_clusters) > 0:
-                    local_labels[valid_clusters] += nb_clusters
-                    peak_labels[mask] = local_labels
-                    nb_clusters += len(np.unique(local_labels[valid_clusters]))
+        if params["debug"]:
+            debug_folder = tmp_folder / "split"
         else:
+            debug_folder = None
 
-            features_folder = tmp_folder / "tsvd_features"
-            features_folder.mkdir(exist_ok=True)
-
-            _ = run_node_pipeline(
-                recording,
-                pipeline_nodes,
-                job_kwargs,
-                job_name="extracting features",
-                gather_mode="npy",
-                gather_kwargs=dict(exist_ok=True),
-                folder=features_folder,
-                names=["sparse_tsvd"],
-            )
-
-            sparse_mask = node1.neighbours_mask
-            neighbours_mask = get_channel_distances(recording) <= radius_um
-
-            # np.save(features_folder / "sparse_mask.npy", sparse_mask)
-            np.save(features_folder / "peaks.npy", peaks)
-
-            original_labels = peaks["channel_index"]
-            from spikeinterface.sortingcomponents.clustering.split import split_clusters
+        peak_labels, _ = split_clusters(
+            original_labels,
+            recording,
+            {"peaks": peaks, "sparse_tsvd": peaks_svd},
+            method="local_feature_clustering",
+            method_kwargs=split_kwargs,
+            debug_folder=debug_folder,
+            **params["recursive_kwargs"],
+            **job_kwargs,
+        )
 
-            min_size = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 20)
+        if params["noise_levels"] is None:
+            params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)
 
-            if params["debug"]:
-                debug_folder = tmp_folder / "split"
-            else:
-                debug_folder = None
+        if not params["templates_from_svd"]:
+            from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording
 
-            peak_labels, _ = split_clusters(
-                original_labels,
+            templates = get_templates_from_peaks_and_recording(
                 recording,
-                features_folder,
-                method="local_feature_clustering",
-                method_kwargs=dict(
-                    clusterer="hdbscan",
-                    feature_name="sparse_tsvd",
-                    neighbours_mask=neighbours_mask,
-                    waveforms_sparse_mask=sparse_mask,
-                    min_size_split=min_size,
-                    clusterer_kwargs=d["hdbscan_kwargs"],
-                    n_pca_features=5,
-                ),
-                debug_folder=debug_folder,
-                **params["recursive_kwargs"],
+                peaks,
+                peak_labels,
+                ms_before,
+                ms_after,
                 **job_kwargs,
             )
+        else:
+            from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd
 
-        non_noise = peak_labels > -1
-        labels, inverse = np.unique(peak_labels[non_noise], return_inverse=True)
-        peak_labels[non_noise] = inverse
-        labels = np.unique(inverse)
-
-        spikes = np.zeros(non_noise.sum(), dtype=minimum_spike_dtype)
-        spikes["sample_index"] = peaks[non_noise]["sample_index"]
-        spikes["segment_index"] = peaks[non_noise]["segment_index"]
-        spikes["unit_index"] = peak_labels[non_noise]
-
-        unit_ids = labels
-
-        nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0)
-        nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0)
-
-        if params["noise_levels"] is None:
-            params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)
-
-        templates_array = estimate_templates(
-            recording,
-            spikes,
-            unit_ids,
-            nbefore,
-            nafter,
-            return_scaled=False,
-            job_name=None,
-            **job_kwargs,
-        )
+            templates = get_templates_from_peaks_and_svd(
                recording,
+                peaks,
+                peak_labels,
+                ms_before,
+                ms_after,
+                svd_model,
+                peaks_svd,
+                sparse_mask,
+                operator="median",
+            )
 
+        templates_array = templates.templates_array
         best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1)
         peak_snrs = np.abs(templates_array[:, nbefore, :])
         best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels]
+        old_unit_ids = templates.unit_ids.copy()
         valid_templates = best_snrs_ratio > params["noise_threshold"]
 
-        if d["rank"] is not None:
-            from spikeinterface.sortingcomponents.matching.circus import compress_templates
+        mask = np.isin(peak_labels, old_unit_ids[~valid_templates])
+        peak_labels[mask] = -1
 
-            _, _, _, templates_array = compress_templates(templates_array, d["rank"])
+        from spikeinterface.core.template import Templates
 
         templates = Templates(
             templates_array=templates_array[valid_templates],
             sampling_frequency=fs,
-            nbefore=nbefore,
+            nbefore=templates.nbefore,
             sparsity_mask=None,
             channel_ids=recording.channel_ids,
-            unit_ids=unit_ids[valid_templates],
+            unit_ids=templates.unit_ids[valid_templates],
             probe=recording.get_probe(),
             is_scaled=False,
         )
+        if params["debug"]:
+            templates_folder = tmp_folder / "dense_templates"
+            templates.to_zarr(folder_path=templates_folder)
 
         sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"])
         templates = templates.to_sparse(sparsity)
         empty_templates = templates.sparsity_mask.sum(axis=1) == 0
+        old_unit_ids = templates.unit_ids.copy()
         templates = remove_empty_templates(templates)
 
-        mask = np.isin(peak_labels, np.where(empty_templates)[0])
+        mask = np.isin(peak_labels, old_unit_ids[empty_templates])
         peak_labels[mask] = -1
 
-        mask = np.isin(peak_labels, np.where(~valid_templates)[0])
-        peak_labels[mask] = -1
+        labels = np.unique(peak_labels)
+        labels = labels[labels >= 0]
 
-        if verbose:
-            print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids)))
+        if params["remove_mixtures"]:
+            if verbose:
+                print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids)))
 
-        cleaning_job_kwargs = job_kwargs.copy()
-        cleaning_job_kwargs["progress_bar"] = False
-        cleaning_params = params["cleaning_kwargs"].copy()
+            cleaning_job_kwargs = job_kwargs.copy()
+            cleaning_job_kwargs["progress_bar"] = False
+            cleaning_params = params["cleaning_kwargs"].copy()
 
-        labels, peak_labels = remove_duplicates_via_matching(
-            templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params
-        )
+            labels, peak_labels = remove_duplicates_via_matching(
                templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params
+            )
 
-        if verbose:
-            print("Kept %d non-duplicated clusters" % len(labels))
+            if verbose:
+                print("Kept %d non-duplicated clusters" % len(labels))
+        else:
+            if verbose:
+                print("Kept %d raw clusters" % len(labels))
 
-        return labels, peak_labels
+        return labels, peak_labels, svd_model, peaks_svd, sparse_mask
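Reviewer note: the clustering methods now return `svd_model`, `peaks_svd`, and `sparse_mask` alongside the labels, which is what lets templates be built without a second pass over the traces. A rough sketch of the underlying idea, with illustrative names (this is not the actual `get_templates_from_peaks_and_svd` implementation):

```python
# Sketch: rebuild one unit's template from its peaks' SVD features by
# aggregating in component space and inverse-transforming per channel.
import numpy as np

def unit_template_from_svd(svd_model, unit_svd_features):
    """unit_svd_features: (n_peaks, n_components, n_channels) for one unit."""
    # aggregate over the unit's peaks; the median is robust to overlaps
    agg = np.median(unit_svd_features, axis=0)   # (n_components, n_channels)
    # map each channel's components back to waveform samples
    return svd_model.inverse_transform(agg.T)    # (n_channels, n_samples)
```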
diff --git a/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py b/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py
index 28409c2221..a0034e7741 100644
--- a/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py
+++ b/src/spikeinterface/sortingcomponents/clustering/graph_clustering.py
@@ -59,6 +59,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         radius_um = params["radius_um"]
         motion = params["motion"]
         seed = params["seed"]
+        ms_before = params["ms_before"]
+        ms_after = params["ms_after"]
         clustering_method = params["clustering_method"]
         clustering_kwargs = params["clustering_kwargs"]
         graph_kwargs = params["graph_kwargs"]
@@ -70,9 +72,11 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         elif graph_kwargs["bin_mode"] == "vertical_bins":
             assert radius_um >= graph_kwargs["bin_um"] * 3
 
-        peaks_svd, sparse_mask, _ = extract_peaks_svd(
+        peaks_svd, sparse_mask, svd_model = extract_peaks_svd(
             recording,
             peaks,
+            ms_before=ms_before,
+            ms_after=ms_after,
             radius_um=radius_um,
             motion_aware=motion_aware,
             motion=None,
@@ -98,7 +102,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         # print(distances.shape)
         # print("sparsity: ", distances.indices.size / (distances.shape[0]**2))
 
-        print("clustering_method", clustering_method)
+        # print("clustering_method", clustering_method)
 
         if clustering_method == "networkx-louvain":
             # using networkx : very slow (possible backend with cuda backend="cugraph",)
@@ -191,7 +195,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         labels_set = np.unique(peak_labels)
         labels_set = labels_set[labels_set >= 0]
 
-        return labels_set, peak_labels
+        return labels_set, peak_labels, svd_model, peaks_svd, sparse_mask
 
 
 def _remove_small_cluster(peak_labels, min_size=1):
diff --git a/src/spikeinterface/sortingcomponents/clustering/graph_tools.py b/src/spikeinterface/sortingcomponents/clustering/graph_tools.py
index 43f9abe141..409181bcf3 100644
--- a/src/spikeinterface/sortingcomponents/clustering/graph_tools.py
+++ b/src/spikeinterface/sortingcomponents/clustering/graph_tools.py
@@ -118,7 +118,7 @@ def create_graph_from_peak_features(
             raise ValueError("create_graph_from_peak_features : wrong bin_mode")
 
     if progress_bar:
-        loop = tqdm(loop, desc=f"Construct distance graph looping over {bin_mode}")
+        loop = tqdm(loop, desc=f"Build distance graph over {bin_mode}")
 
     local_graphs = []
     row_indices = []
diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py
index 693f67305f..20b8f2c8de 100644
--- a/src/spikeinterface/sortingcomponents/clustering/tools.py
+++ b/src/spikeinterface/sortingcomponents/clustering/tools.py
@@ -200,6 +200,7 @@ def get_templates_from_peaks_and_recording(
     peak_labels,
     ms_before,
     ms_after,
+    operator="average",
     **job_kwargs,
 ):
     """
@@ -219,6 +220,8 @@ def get_templates_from_peaks_and_recording(
         The time window before the peak in milliseconds.
     ms_after : float
         The time window after the peak in milliseconds.
+    operator : str
+        The operator to use for template estimation. Can be 'average' or 'median'.
     job_kwargs : dict
         Additional keyword arguments for the estimate_templates function.
 
@@ -228,17 +231,34 @@ def get_templates_from_peaks_and_recording(
         The estimated templates object.
     """
     from spikeinterface.core.template import Templates
+    from spikeinterface.core.basesorting import minimum_spike_dtype
 
     mask = peak_labels > -1
-    labels = np.unique(peak_labels[mask])
+    valid_peaks = peaks[mask]
+    valid_labels = peak_labels[mask]
+    labels, indices = np.unique(valid_labels, return_inverse=True)
+
     fs = recording.get_sampling_frequency()
     nbefore = int(ms_before * fs / 1000.0)
     nafter = int(ms_after * fs / 1000.0)
 
+    spikes = np.zeros(valid_peaks.size, dtype=minimum_spike_dtype)
+    spikes["sample_index"] = valid_peaks["sample_index"]
+    spikes["unit_index"] = indices
+    spikes["segment_index"] = valid_peaks["segment_index"]
+
     from spikeinterface.core.waveform_tools import estimate_templates
 
     templates_array = estimate_templates(
-        recording, peaks, labels, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs
+        recording,
+        spikes,
+        np.arange(len(labels)),
+        nbefore,
+        nafter,
+        operator=operator,
+        return_scaled=False,
+        job_name=None,
+        **job_kwargs,
     )
 
     templates = Templates(
@@ -264,7 +284,7 @@ def get_templates_from_peaks_and_svd(
     svd_model,
     svd_features,
     sparsity_mask,
-    operator="mean",
+    operator="average",
 ):
     """
     Get templates from recording using the SVD components
@@ -287,6 +307,8 @@ def get_templates_from_peaks_and_svd(
         The SVD features array.
     sparsity_mask : numpy.ndarray
         The sparsity mask array.
+    operator : str
+        The operator to use for template estimation. Can be 'average' or 'median'.
 
     Returns
     -------
@@ -293,12 +315,15 @@ def get_templates_from_peaks_and_svd(
     templates : Templates
         The estimated templates object.
     """
     from spikeinterface.core.template import Templates
 
-    assert operator in ["mean", "median"], "operator should be either 'mean' or 'median'"
+    assert operator in ["average", "median"], "operator should be either 'average' or 'median'"
 
     mask = peak_labels > -1
-    labels = np.unique(peak_labels[mask])
+    valid_peaks = peaks[mask]
+    valid_labels = peak_labels[mask]
+    valid_svd_features = svd_features[mask]
+    labels = np.unique(valid_labels)
 
     fs = recording.get_sampling_frequency()
     nbefore = int(ms_before * fs / 1000.0)
@@ -306,14 +331,14 @@ def get_templates_from_peaks_and_svd(
     templates_array = np.zeros((len(labels), nbefore + nafter, num_channels), dtype=np.float32)
 
     for unit_ind, label in enumerate(labels):
-        mask = peak_labels == label
-        local_peaks = peaks[mask]
-        local_svd = svd_features[mask]
+        mask = valid_labels == label
+        local_peaks = valid_peaks[mask]
+        local_svd = valid_svd_features[mask]
         peak_channels, b = np.unique(local_peaks["channel_index"], return_counts=True)
         best_channel = peak_channels[np.argmax(b)]
         sub_mask = local_peaks["channel_index"] == best_channel
         for count, i in enumerate(np.flatnonzero(sparsity_mask[best_channel])):
-            if operator == "mean":
+            if operator == "average":
                 data = np.mean(local_svd[sub_mask, :, count], 0)
             elif operator == "median":
                 data = np.median(local_svd[sub_mask, :, count], 0)
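Reviewer note: the rename from `mean` to `average` aligns the operator names with `estimate_templates`. The choice between the two operators matters when snippets are contaminated by overlapping spikes; a standalone toy comparison (not spikeinterface code):

```python
# Toy illustration of average vs median template estimation.
import numpy as np

rng = np.random.default_rng(42)
snippets = rng.normal(0.0, 5.0, size=(100, 60))   # 100 aligned snippets, 60 samples
snippets[:5] += 200.0                             # a few snippets hit by overlapping spikes

avg = np.mean(snippets, axis=0)    # dragged upward by the 5 contaminated snippets
med = np.median(snippets, axis=0)  # essentially unaffected
print(avg.mean(), med.mean())      # roughly 10.0 vs 0.0
```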
diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py
index 3b97f2dc6a..64a6b2333d 100644
--- a/src/spikeinterface/sortingcomponents/matching/circus.py
+++ b/src/spikeinterface/sortingcomponents/matching/circus.py
@@ -31,7 +31,7 @@
 def compress_templates(
-    templates_array, approx_rank, remove_mean=True, return_new_templates=True
+    templates_array, approx_rank, remove_mean=False, return_new_templates=True
 ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray | None]:
     """Compress templates using singular value decomposition.
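Reviewer note: with `remove_mean` now defaulting to False, the rank-k reconstruction approximates the raw templates directly, so nothing has to be added back after compression. For context, a self-contained sketch of this kind of SVD compression (illustrative only; the mean-handling axis here is an assumption, see the real `compress_templates` in matching/circus.py):

```python
# Sketch of rank-k template compression via batched SVD.
import numpy as np

def compress_rank_k(templates_array, approx_rank, remove_mean=False):
    # templates_array: (n_units, n_samples, n_channels)
    if remove_mean:
        templates_array = templates_array - templates_array.mean(axis=(1, 2), keepdims=True)
    # batched SVD, one decomposition per unit
    temporal, singular, spatial = np.linalg.svd(templates_array, full_matrices=False)
    temporal = temporal[:, :, :approx_rank]   # (n_units, n_samples, rank)
    singular = singular[:, :approx_rank]      # (n_units, rank)
    spatial = spatial[:, :approx_rank, :]     # (n_units, rank, n_channels)
    # rank-k reconstruction of the templates
    new_templates = np.matmul(temporal * singular[:, np.newaxis, :], spatial)
    return temporal, singular, spatial, new_templates
```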