Merged
313 changes: 165 additions & 148 deletions src/spikeinterface/sorters/internal/spyking_circus2.py

Large diffs are not rendered by default.

256 changes: 92 additions & 164 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -19,17 +19,10 @@
from .clustering_tools import remove_duplicates_via_matching
from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances
from spikeinterface.sortingcomponents.peak_selection import select_peaks
from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection
from spikeinterface.sortingcomponents.waveforms.hanning_filter import HanningFilter
from spikeinterface.core.template import Templates
from spikeinterface.core.sparsity import compute_sparsity
from spikeinterface.sortingcomponents.tools import remove_empty_templates
import pickle, json
from spikeinterface.core.node_pipeline import (
run_node_pipeline,
ExtractSparseWaveforms,
PeakRetriever,
)
from spikeinterface.sortingcomponents.clustering.peak_svd import extract_peaks_svd


from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel
@@ -48,20 +41,24 @@ class CircusClustering:
"allow_single_cluster": True,
},
"cleaning_kwargs": {},
"remove_mixtures": False,
"waveforms": {"ms_before": 2, "ms_after": 2},
"sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
"recursive_kwargs": {
"recursive": True,
"recursive_depth": 3,
"returns_split_count": True,
},
"split_kwargs": {"projection_mode": "tsvd", "n_pca_features": 0.9},
"radius_um": 100,
"neighbors_radius_um": 50,
"n_svd": 5,
"few_waveforms": None,
"ms_before": 0.5,
"ms_after": 0.5,
"noise_threshold": 4,
"rank": 5,
"templates_from_svd": False,
"noise_levels": None,
"tmp_folder": None,
"verbose": True,
@@ -78,6 +75,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
fs = recording.get_sampling_frequency()
ms_before = params["ms_before"]
ms_after = params["ms_after"]
radius_um = params["radius_um"]
neighbors_radius_um = params["neighbors_radius_um"]
nbefore = int(ms_before * fs / 1000.0)
nafter = int(ms_after * fs / 1000.0)
if params["tmp_folder"] is None:
@@ -108,210 +107,139 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
valid = np.argmax(np.abs(wfs), axis=1) == nbefore
wfs = wfs[valid]

# Perform Hanning filtering
hanning_before = np.hanning(2 * nbefore)
hanning_after = np.hanning(2 * nafter)
hanning = np.concatenate((hanning_before[:nbefore], hanning_after[nafter:]))
wfs *= hanning

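As an aside, the Hanning taper built above is asymmetric whenever ms_before != ms_after: it stitches the rising half of a length-2*nbefore window to the falling half of a length-2*nafter window, so each snippet stays untouched near its peak and is smoothly damped toward both borders. A minimal standalone sketch (sample counts are illustrative):

import numpy as np

nbefore, nafter = 8, 24  # illustrative cutout sizes, in samples
hanning_before = np.hanning(2 * nbefore)
hanning_after = np.hanning(2 * nafter)
# rising half before the peak, falling half after it: the taper is ~1 at the
# peak sample and decays to 0 at both edges of the snippet
taper = np.concatenate((hanning_before[:nbefore], hanning_after[nafter:]))
assert taper.shape == (nbefore + nafter,)
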
from sklearn.decomposition import TruncatedSVD

tsvd = TruncatedSVD(params["n_svd"])
tsvd.fit(wfs)

model_folder = tmp_folder / "tsvd_model"

model_folder.mkdir(exist_ok=True)
with open(model_folder / "pca_model.pkl", "wb") as f:
pickle.dump(tsvd, f)

model_params = {
"ms_before": ms_before,
"ms_after": ms_after,
"sampling_frequency": float(fs),
}

with open(model_folder / "params.json", "w") as f:
json.dump(model_params, f)
svd_model = TruncatedSVD(params["n_svd"])
svd_model.fit(wfs)
features_folder = tmp_folder / "tsvd_features"
features_folder.mkdir(exist_ok=True)

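The TruncatedSVD model is fitted once, on a subsample of waveforms cut at each peak's max channel, and then reused to project every peak. A hedged sketch of that two-step pattern (shapes are illustrative):

import numpy as np
from sklearn.decomposition import TruncatedSVD

rng = np.random.default_rng(0)
wfs = rng.standard_normal((1000, 32))   # (n_spikes, n_samples) training snippets
svd_model = TruncatedSVD(5).fit(wfs)    # params["n_svd"] components
# from here on, any waveform is compressed to 5 loadings instead of 32 samples
features = svd_model.transform(wfs)
assert features.shape == (1000, 5)
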
# features
node0 = PeakRetriever(recording, peaks)

radius_um = params["radius_um"]
node1 = ExtractSparseWaveforms(
peaks_svd, sparse_mask, svd_model = extract_peaks_svd(
recording,
parents=[node0],
return_output=False,
peaks,
ms_before=ms_before,
ms_after=ms_after,
svd_model=svd_model,
radius_um=radius_um,
folder=features_folder,
**job_kwargs,
)

node2 = HanningFilter(recording, parents=[node0, node1], return_output=False)
neighbours_mask = get_channel_distances(recording) <= neighbors_radius_um

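get_channel_distances returns the pairwise channel distance matrix, so the comparison above yields a boolean (n_channels, n_channels) neighbourhood mask. An equivalent toy computation on a hypothetical 4-channel column with 20 um pitch:

import numpy as np

positions = np.array([[0.0, 0.0], [0.0, 20.0], [0.0, 40.0], [0.0, 60.0]])
# pairwise Euclidean distances, as get_channel_distances computes from the probe
distances = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=2)
neighbours_mask = distances <= 50.0  # neighbors_radius_um
print(neighbours_mask.sum(axis=1))   # -> [3 4 4 3] neighbours per channel
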
node3 = TemporalPCAProjection(
recording, parents=[node0, node2], return_output=True, model_folder_path=model_folder
)
if params["debug"]:
np.save(features_folder / "sparse_mask.npy", sparse_mask)
np.save(features_folder / "peaks.npy", peaks)

pipeline_nodes = [node0, node1, node2, node3]
original_labels = peaks["channel_index"]
from spikeinterface.sortingcomponents.clustering.split import split_clusters

if len(params["recursive_kwargs"]) == 0:
from sklearn.decomposition import PCA
split_kwargs = params["split_kwargs"].copy()
split_kwargs["neighbours_mask"] = neighbours_mask
split_kwargs["waveforms_sparse_mask"] = sparse_mask
split_kwargs["min_size_split"] = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 50)
split_kwargs["clusterer_kwargs"] = params["hdbscan_kwargs"]

all_pc_data = run_node_pipeline(
recording,
pipeline_nodes,
job_kwargs,
job_name="extracting features",
)

peak_labels = -1 * np.ones(len(peaks), dtype=int)
nb_clusters = 0
for c in np.unique(peaks["channel_index"]):
mask = peaks["channel_index"] == c
sub_data = all_pc_data[mask]
sub_data = sub_data.reshape(len(sub_data), -1)

if all_pc_data.shape[1] > params["n_svd"]:
tsvd = PCA(params["n_svd"], whiten=True)
else:
tsvd = PCA(all_pc_data.shape[1], whiten=True)

hdbscan_data = tsvd.fit_transform(sub_data)
try:
clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"])
local_labels = clustering[0]
except Exception:
local_labels = np.zeros(len(hdbscan_data))
valid_clusters = local_labels > -1
if np.sum(valid_clusters) > 0:
local_labels[valid_clusters] += nb_clusters
peak_labels[mask] = local_labels
nb_clusters += len(np.unique(local_labels[valid_clusters]))
if params["debug"]:
debug_folder = tmp_folder / "split"
else:
debug_folder = None

features_folder = tmp_folder / "tsvd_features"
features_folder.mkdir(exist_ok=True)

_ = run_node_pipeline(
recording,
pipeline_nodes,
job_kwargs,
job_name="extracting features",
gather_mode="npy",
gather_kwargs=dict(exist_ok=True),
folder=features_folder,
names=["sparse_tsvd"],
)

sparse_mask = node1.neighbours_mask
neighbours_mask = get_channel_distances(recording) <= radius_um

# np.save(features_folder / "sparse_mask.npy", sparse_mask)
np.save(features_folder / "peaks.npy", peaks)

original_labels = peaks["channel_index"]
from spikeinterface.sortingcomponents.clustering.split import split_clusters
peak_labels, _ = split_clusters(
original_labels,
recording,
{"peaks": peaks, "sparse_tsvd": peaks_svd},
method="local_feature_clustering",
method_kwargs=split_kwargs,
debug_folder=debug_folder,
**params["recursive_kwargs"],
**job_kwargs,
)

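split_clusters starts from the peaks' detection channels as initial labels and recursively re-clusters each group in its local SVD feature space, up to recursive_depth levels. A toy stand-in for the recursion pattern, with KMeans replacing the HDBSCAN clusterer purely for brevity (not the actual split_clusters API):

import numpy as np
from sklearn.cluster import KMeans  # stand-in clusterer, for illustration only

def recursive_split(features, labels, depth=0, max_depth=3, min_size=100):
    """Toy recursion mirroring recursive_kwargs above (illustrative only)."""
    if depth >= max_depth:
        return labels
    out = labels.copy()
    next_label = out.max() + 1
    for lbl in np.unique(out):
        if lbl < 0:
            continue  # leave noise peaks alone
        mask = out == lbl
        if mask.sum() < min_size:
            continue  # mirrors min_size_split above
        sub = KMeans(n_clusters=2, n_init=10).fit_predict(features[mask])
        if len(np.unique(sub)) > 1:  # the cluster actually split in two
            out[mask] = np.where(sub == 0, lbl, next_label)
            next_label += 1
    return recursive_split(features, out, depth + 1, max_depth, min_size)
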
min_size = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 20)
if params["noise_levels"] is None:
params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)

if params["debug"]:
debug_folder = tmp_folder / "split"
else:
debug_folder = None
if not params["templates_from_svd"]:
from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording

peak_labels, _ = split_clusters(
original_labels,
templates = get_templates_from_peaks_and_recording(
recording,
features_folder,
method="local_feature_clustering",
method_kwargs=dict(
clusterer="hdbscan",
feature_name="sparse_tsvd",
neighbours_mask=neighbours_mask,
waveforms_sparse_mask=sparse_mask,
min_size_split=min_size,
clusterer_kwargs=d["hdbscan_kwargs"],
n_pca_features=5,
),
debug_folder=debug_folder,
**params["recursive_kwargs"],
peaks,
peak_labels,
ms_before,
ms_after,
**job_kwargs,
)
else:
from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd

non_noise = peak_labels > -1
labels, inverse = np.unique(peak_labels[non_noise], return_inverse=True)
peak_labels[non_noise] = inverse
labels = np.unique(inverse)

spikes = np.zeros(non_noise.sum(), dtype=minimum_spike_dtype)
spikes["sample_index"] = peaks[non_noise]["sample_index"]
spikes["segment_index"] = peaks[non_noise]["segment_index"]
spikes["unit_index"] = peak_labels[non_noise]

unit_ids = labels

nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0)
nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0)

if params["noise_levels"] is None:
params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)

templates_array = estimate_templates(
recording,
spikes,
unit_ids,
nbefore,
nafter,
return_scaled=False,
job_name=None,
**job_kwargs,
)
templates = get_templates_from_peaks_and_svd(
recording,
peaks,
peak_labels,
ms_before,
ms_after,
svd_model,
peaks_svd,
sparse_mask,
operator="median",
)

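Building templates from the SVD features avoids a second pass over the recording: each unit's loadings are reduced with the chosen operator (median here) and projected back through the SVD basis. A per-channel sketch of that reconstruction (shapes illustrative):

import numpy as np
from sklearn.decomposition import TruncatedSVD

rng = np.random.default_rng(0)
waveforms = rng.standard_normal((500, 40))      # one unit, one channel, 500 spikes
svd = TruncatedSVD(5).fit(waveforms)

loadings = svd.transform(waveforms)             # (500, 5)
median_loadings = np.median(loadings, axis=0)   # operator="median" above
# back-project to the sample domain: this channel of the unit's template
template_channel = svd.inverse_transform(median_loadings[None, :])[0]
assert template_channel.shape == (40,)
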
templates_array = templates.templates_array
best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1)
peak_snrs = np.abs(templates_array[:, nbefore, :])
best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels]
old_unit_ids = templates.unit_ids.copy()
valid_templates = best_snrs_ratio > params["noise_threshold"]

if d["rank"] is not None:
from spikeinterface.sortingcomponents.matching.circus import compress_templates
mask = np.isin(peak_labels, old_unit_ids[~valid_templates])
peak_labels[mask] = -1

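The SNR screen above keeps a template only if its peak amplitude on its best channel exceeds noise_threshold times that channel's noise level; peaks of rejected units are demoted to -1. A self-contained numerical example:

import numpy as np

nbefore, noise_levels = 10, np.ones(4)
templates_array = np.zeros((3, 20, 4))          # (n_units, n_samples, n_channels)
templates_array[0, nbefore, 1] = 12.0           # strong unit
templates_array[1, nbefore, 2] = 2.0            # weak unit, below threshold
templates_array[2, nbefore, 0] = -8.0           # strong negative peak

peak_snrs = np.abs(templates_array[:, nbefore, :])
best_channels = np.argmax(peak_snrs, axis=1)
best_snrs_ratio = (peak_snrs / noise_levels)[np.arange(3), best_channels]
valid_templates = best_snrs_ratio > 4.0         # noise_threshold
print(valid_templates)                          # -> [ True False  True]
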
_, _, _, templates_array = compress_templates(templates_array, d["rank"])
from spikeinterface.core.template import Templates

templates = Templates(
templates_array=templates_array[valid_templates],
sampling_frequency=fs,
nbefore=nbefore,
nbefore=templates.nbefore,
sparsity_mask=None,
channel_ids=recording.channel_ids,
unit_ids=unit_ids[valid_templates],
unit_ids=templates.unit_ids[valid_templates],
probe=recording.get_probe(),
is_scaled=False,
)

if params["debug"]:
templates_folder = tmp_folder / "dense_templates"
templates.to_zarr(folder_path=templates_folder)

sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"])
templates = templates.to_sparse(sparsity)
empty_templates = templates.sparsity_mask.sum(axis=1) == 0
old_unit_ids = templates.unit_ids.copy()
templates = remove_empty_templates(templates)

mask = np.isin(peak_labels, np.where(empty_templates)[0])
mask = np.isin(peak_labels, old_unit_ids[empty_templates])
peak_labels[mask] = -1

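Sparsifying against the noise-level-based sparsity can strip every channel from a unit; such units have an all-False row in sparsity_mask and their peaks are demoted just above. An illustrative check:

import numpy as np

sparsity_mask = np.array([                       # (n_units, n_channels)
    [True, True, False],
    [False, False, False],                       # unit kept no channel at all
    [False, True, True],
])
empty_templates = sparsity_mask.sum(axis=1) == 0
print(np.where(empty_templates)[0])              # -> [1]
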
mask = np.isin(peak_labels, np.where(~valid_templates)[0])
peak_labels[mask] = -1
labels = np.unique(peak_labels)
labels = labels[labels >= 0]

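After all demotions, the surviving unit ids are simply the non-negative labels still present among the peaks:

import numpy as np

peak_labels = np.array([-1, 0, 2, 2, 5, -1, 0])  # -1 marks noise / demoted peaks
labels = np.unique(peak_labels)
labels = labels[labels >= 0]
print(labels)                                    # -> [0 2 5]
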
if verbose:
print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids)))
if params["remove_mixtures"]:
if verbose:
print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids)))

cleaning_job_kwargs = job_kwargs.copy()
cleaning_job_kwargs["progress_bar"] = False
cleaning_params = params["cleaning_kwargs"].copy()
cleaning_job_kwargs = job_kwargs.copy()
cleaning_job_kwargs["progress_bar"] = False
cleaning_params = params["cleaning_kwargs"].copy()

labels, peak_labels = remove_duplicates_via_matching(
templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params
)
labels, peak_labels = remove_duplicates_via_matching(
templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params
)

if verbose:
print("Kept %d non-duplicated clusters" % len(labels))
if verbose:
print("Kept %d non-duplicated clusters" % len(labels))
else:
if verbose:
print("Kept %d raw clusters" % len(labels))

return labels, peak_labels
return labels, peak_labels, svd_model, peaks_svd, sparse_mask
@@ -59,6 +59,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
radius_um = params["radius_um"]
motion = params["motion"]
seed = params["seed"]
ms_before = params["ms_before"]
ms_after = params["ms_after"]
clustering_method = params["clustering_method"]
clustering_kwargs = params["clustering_kwargs"]
graph_kwargs = params["graph_kwargs"]
@@ -70,9 +72,11 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
elif graph_kwargs["bin_mode"] == "vertical_bins":
assert radius_um >= graph_kwargs["bin_um"] * 3

peaks_svd, sparse_mask, _ = extract_peaks_svd(
peaks_svd, sparse_mask, svd_model = extract_peaks_svd(
recording,
peaks,
ms_before=ms_before,
ms_after=ms_after,
radius_um=radius_um,
motion_aware=motion_aware,
motion=None,
@@ -98,7 +102,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
# print(distances.shape)
# print("sparsity: ", distances.indices.size / (distances.shape[0]**2))

print("clustering_method", clustering_method)
# print("clustering_method", clustering_method)

if clustering_method == "networkx-louvain":
# using networkx : very slow (possible backend with cude backend="cugraph",)
@@ -191,7 +195,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
labels_set = np.unique(peak_labels)
labels_set = labels_set[labels_set >= 0]

return labels_set, peak_labels
return labels_set, peak_labels, svd_model, peaks_svd, sparse_mask


def _remove_small_cluster(peak_labels, min_size=1):
@@ -118,7 +118,7 @@ def create_graph_from_peak_features(
raise ValueError("create_graph_from_peak_features : wrong bin_mode")

if progress_bar:
loop = tqdm(loop, desc=f"Construct distance graph looping over {bin_mode}")
loop = tqdm(loop, desc=f"Build distance graph over {bin_mode}")

local_graphs = []
row_indices = []