Skip to content
Merged
Show file tree
Hide file tree
Changes from 57 commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
dd8f443
low level auto merge using template similarity for sorting components.
samuelgarcia May 14, 2025
3fcd9d7
improve auto merge in tdc_lustering and cicurs_clustering
samuelgarcia May 15, 2025
90e0b14
improve drift aware clustering tdc
samuelgarcia May 16, 2025
3653c81
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia May 16, 2025
221e780
add with_template=False in BenchmarkClustering.compute_result
samuelgarcia Jun 4, 2025
0edfc77
oups
samuelgarcia Jun 13, 2025
bed91f2
update tdc
samuelgarcia Jun 16, 2025
ea7abd8
oups
samuelgarcia Jun 16, 2025
3ae7e6e
oups
samuelgarcia Jun 16, 2025
5c65137
fix
samuelgarcia Jun 16, 2025
811c0f1
tests
samuelgarcia Jun 16, 2025
bfa6e23
Merge branch 'main' into components_merge_templates
samuelgarcia Jun 17, 2025
6a24e8b
merge main and fixes
samuelgarcia Jun 17, 2025
b1ec837
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Jun 17, 2025
0102ad9
clean
samuelgarcia Jun 17, 2025
5c9b641
Fix MatchingStudy.plot_collisions
samuelgarcia Jun 17, 2025
d1ba03d
small fixes in circus-clustering
samuelgarcia Jun 17, 2025
aa1a6f3
speedup the collision comparison and benchmarkmatching
samuelgarcia Jun 18, 2025
861f73a
wip
samuelgarcia Jun 23, 2025
bef679b
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Jun 25, 2025
6129df5
Better sparsity for analyzer in benchmarks
samuelgarcia Jun 25, 2025
1b88a97
Fix some etra_outputs in clustering methods.
samuelgarcia Jun 25, 2025
2462ae5
merge conflict
samuelgarcia Jun 25, 2025
2d3d654
more fix in comparison
samuelgarcia Jun 25, 2025
9d93aaf
improve plot_performances_vs_snr()
samuelgarcia Jul 8, 2025
8150c47
improve tdridesclous2
samuelgarcia Jul 8, 2025
783cd8f
wip circus clustering
samuelgarcia Jul 8, 2025
97762b5
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Jul 8, 2025
867dc3e
Improve tridesclous2
samuelgarcia Jul 10, 2025
b823a97
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Jul 10, 2025
0729ddf
Fixes for the merging branch of Sam, mostly for SC2 purposes and to t…
yger Jul 10, 2025
a9bb5f6
clean
samuelgarcia Jul 10, 2025
7ca35a0
Add isosplit for tests
samuelgarcia Jul 10, 2025
8027150
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Jul 11, 2025
620076a
improve circus
yger Jul 11, 2025
f3eef33
Merge branch 'components_merge_templates' of github.com:samuelgarcia/…
samuelgarcia Jul 11, 2025
e50bda1
handling when isoplit6 is not installable in tridesclous2
samuelgarcia Jul 11, 2025
7794728
warn instead of raise for motion and spatial windows
samuelgarcia Jul 11, 2025
9dcdfcd
levels_to_keep > levels_to_group_by in benchmarks
samuelgarcia Jul 11, 2025
a4e8f78
benchmark improvs plot_unit_counts and plot_run_times
samuelgarcia Jul 16, 2025
7c93185
benchmark : replace seaborn by matplotlib for better cosmetic control
samuelgarcia Jul 17, 2025
6cf2248
plot_count_unit improvement
samuelgarcia Jul 17, 2025
ac091d2
Small fixes
samuelgarcia Jul 18, 2025
6eab417
optional installation of kilosort4like sorter (experimental and hidde…
samuelgarcia Jul 18, 2025
c944850
debug order of import and Template partially imported
samuelgarcia Jul 21, 2025
5b27515
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Jul 24, 2025
d1a7735
merge with main
samuelgarcia Jul 24, 2025
255d25c
fix tests
samuelgarcia Jul 24, 2025
af41d07
clean_template()
yger Jul 25, 2025
e37754b
clean_template in tridesclous2
samuelgarcia Jul 25, 2025
9a12364
fix motion tests
samuelgarcia Jul 25, 2025
edb2802
fix test when no isosplit6
samuelgarcia Jul 25, 2025
07bd8bc
merge and fix
samuelgarcia Jul 28, 2025
2693641
fix launcher
samuelgarcia Jul 28, 2025
2845809
skip tdc-clustering
samuelgarcia Jul 28, 2025
ebbd182
oups
samuelgarcia Jul 28, 2025
37e3734
fix conflict
samuelgarcia Jul 28, 2025
c8c02de
Merge branch 'main' into components_merge_templates
samuelgarcia Jul 28, 2025
6f09e17
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 28, 2025
0a39b12
tdc2 and sc2 versions are now yyyy.mm
samuelgarcia Jul 28, 2025
163bb68
Merge branch 'components_merge_templates' of github.com:samuelgarcia/…
samuelgarcia Jul 28, 2025
c8856eb
oups
samuelgarcia Jul 28, 2025
9e23a9e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 28, 2025
78a0ac1
try debug sc2 windows
samuelgarcia Jul 28, 2025
da9dd6a
Merge branch 'components_merge_templates' of github.com:samuelgarcia/…
samuelgarcia Jul 28, 2025
898f6d6
try debug sc2 windows
samuelgarcia Jul 28, 2025
1af1ebb
try debug sc2 windows
samuelgarcia Jul 29, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,14 @@ test = [
# streaming templates
"s3fs",

# tridesclous
# tridesclous2
"numba<0.61.0;python_version<'3.13'",
"numba>=0.61.0;python_version>='3.13'",
"hdbscan>=0.8.33", # Previous version had a broken wheel

# isosplit is normally needed for tridesclous2, but isosplit is only built until python3.11
# so let's wait for a new build of isosplit6
# "isosplit6",

# for sortingview backend
"sortingview>=0.12.0",
Expand Down
75 changes: 55 additions & 20 deletions src/spikeinterface/benchmark/benchmark_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import time


from spikeinterface.core import SortingAnalyzer, ChannelSparsity, NumpySorting
from spikeinterface.core import SortingAnalyzer, NumpySorting
from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
from spikeinterface import load, create_sorting_analyzer, load_sorting_analyzer
Expand Down Expand Up @@ -46,6 +47,7 @@ def __init__(self, study_folder):
self.levels = None
self.colors_by_case = None
self.colors_by_levels = {}
self.labels_by_levels = {}
self.scan_folder()

@classmethod
Expand Down Expand Up @@ -120,8 +122,29 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
rec, gt_sorting = data

if gt_sorting is not None:
if "gt_unit_locations" in gt_sorting.get_property_keys():
# if real unit locations are present then use them for a better sparsity
# so that the real max channel is used
radius_um = 100.
channel_ids = rec.channel_ids
unit_ids = gt_sorting.unit_ids
gt_unit_locations = gt_sorting.get_property("gt_unit_locations")
channel_locations = rec.get_channel_locations()
max_channel_indices = np.argmin(np.linalg.norm(gt_unit_locations[:, np.newaxis, :2] - channel_locations[np.newaxis, :], axis=2), axis=1)
mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool")
distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2)
for unit_ind, unit_id in enumerate(unit_ids):
chan_ind = max_channel_indices[unit_ind]
(chan_inds,) = np.nonzero(distances[chan_ind, :] <= radius_um)
mask[unit_ind, chan_inds] = True
sparsity = ChannelSparsity(mask, unit_ids, channel_ids)
sparse =False
else:
sparse = True
sparsity = None

analyzer = create_sorting_analyzer(
gt_sorting, rec, sparse=True, format="binary_folder", folder=local_analyzer_folder
gt_sorting, rec, sparse=sparse, sparsity=sparsity, format="binary_folder", folder=local_analyzer_folder
)
analyzer.compute("random_spikes")
analyzer.compute("templates")
Expand All @@ -135,6 +158,7 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
analyzer = create_sorting_analyzer(
gt_sorting, rec, sparse=False, format="binary_folder", folder=local_analyzer_folder
)

else:
# new case: analyzer
assert isinstance(data, SortingAnalyzer)
Expand Down Expand Up @@ -254,6 +278,8 @@ def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs):

for key in job_keys:
benchmark = self.create_benchmark(key)
if verbose:
print("### Run benchmark", key, "###")
t0 = time.perf_counter()
benchmark.run(**job_kwargs)
t1 = time.perf_counter()
Expand Down Expand Up @@ -324,14 +350,16 @@ def get_run_times(self, case_keys=None):
df.index.names = self.levels
return df

def get_grouped_keys_mapping(self, levels_to_group_by=None):
def get_grouped_keys_mapping(self, levels_to_group_by=None, case_keys=None):
"""
Return a dictionary of grouped keys.

Parameters
----------
levels_to_group_by : list
A list of levels to group by.
case_keys : list
Optionally a sublist of case_keys to consider

Returns
-------
Expand All @@ -341,18 +369,19 @@ def get_grouped_keys_mapping(self, levels_to_group_by=None):
labels : dict
A dictionary of labels, with the new keys as keys and the labels as values.
"""
cases = list(self.cases.keys())
if case_keys is None:
case_keys = list(self.cases.keys())
if levels_to_group_by is None or self.levels is None:
keys_mapping = {key: [key] for key in cases}
keys_mapping = {key: [key] for key in case_keys}
elif len(self.levels) == 1:
keys_mapping = {key: [key] for key in cases}
keys_mapping = {key: [key] for key in case_keys}
else:
study_levels = self.levels
assert np.all(
[l in study_levels for l in levels_to_group_by]
), f"levels_to_group_by must be in {study_levels}, got {levels_to_group_by}"
keys_mapping = {}
for key in cases:
for key in case_keys:
new_key = tuple(key[list(study_levels).index(level)] for level in levels_to_group_by)
if len(new_key) == 1:
new_key = new_key[0]
Expand All @@ -361,13 +390,17 @@ def get_grouped_keys_mapping(self, levels_to_group_by=None):
keys_mapping[new_key].append(key)

if levels_to_group_by is None:
labels = {key: self.cases[key]["label"] for key in cases}
labels = {key: self.cases[key]["label"] for key in case_keys}
else:
key0 = list(keys_mapping.keys())[0]
if isinstance(key0, tuple):
labels = {key: "-".join(key) for key in keys_mapping}
level_key = tuple(levels_to_group_by) if len(levels_to_group_by) > 1 else levels_to_group_by[0]
if level_key in self.labels_by_levels:
labels = self.labels_by_levels[level_key]
else:
labels = {key: key for key in keys_mapping}
key0 = list(keys_mapping.keys())[0]
if isinstance(key0, tuple):
labels = {key: "-".join(key) for key in keys_mapping}
else:
labels = {key: key for key in keys_mapping}

return keys_mapping, labels

Expand All @@ -383,6 +416,8 @@ def compute_results(self, case_keys=None, verbose=False, **result_params):

job_keys = []
for key in case_keys:
if verbose:
print("### Compute result", key, "###")
benchmark = self.benchmarks[key]
assert benchmark is not None
benchmark.compute_result(**result_params)
Expand Down Expand Up @@ -438,9 +473,9 @@ def get_gt_unit_locations(self, case_key):
unit_locations_ext = sorting_analyzer.get_extension("unit_locations")
return unit_locations_ext.get_data()

def get_templates(self, key, operator="average"):
def get_templates(self, key, operator="average", outputs='numpy'):
sorting_analyzer = self.get_sorting_analyzer(case_key=key)
templates = sorting_analyzer.get_extenson("templates").get_data(operator=operator)
templates = sorting_analyzer.get_extension("templates").get_data(operator=operator, outputs=outputs)
return templates

def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], force=False, **job_kwargs):
Expand Down Expand Up @@ -668,15 +703,15 @@ def get_count_units(self, case_keys=None, well_detected_score=None, redundant_sc
gt_sorting = comp.sorting1
sorting = comp.sorting2

count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids())
count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids())
count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score)
count_units.at[key, "num_gt"] = len(gt_sorting.get_unit_ids())
count_units.at[key, "num_sorter"] = len(sorting.get_unit_ids())
count_units.at[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score)

if comp.exhaustive_gt:
count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score)
count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score)
count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score)
count_units.loc[key, "num_bad"] = comp.count_bad_units()
count_units.at[key, "num_redundant"] = comp.count_redundant_units(redundant_score)
count_units.at[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score)
count_units.at[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score)
count_units.at[key, "num_bad"] = comp.count_bad_units()

return count_units

Expand Down
29 changes: 15 additions & 14 deletions src/spikeinterface/benchmark/benchmark_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def run(self, **job_kwargs):
)
self.result["peak_labels"] = peak_labels

def compute_result(self, **result_params):
result_params, job_kwargs = split_job_kwargs(result_params)
def compute_result(self, with_template=False, **job_kwargs):
# result_params, job_kwargs = split_job_kwargs(result_params)
job_kwargs = fix_job_kwargs(job_kwargs)
self.noise = self.result["peak_labels"] < 0
spikes = self.gt_sorting.to_spike_vector()
Expand Down Expand Up @@ -68,19 +68,20 @@ def compute_result(self, **result_params):
self.result["sliced_gt_sorting"], self.result["clustering"], exhaustive_gt=self.exhaustive_gt
)

sorting_analyzer = create_sorting_analyzer(
self.result["sliced_gt_sorting"], self.recording, format="memory", sparse=False, **job_kwargs
)
sorting_analyzer.compute("random_spikes")
ext = sorting_analyzer.compute("templates", **job_kwargs)
self.result["sliced_gt_templates"] = ext.get_data(outputs="Templates")
if with_template:
sorting_analyzer = create_sorting_analyzer(
self.result["sliced_gt_sorting"], self.recording, format="memory", sparse=False, **job_kwargs
)
sorting_analyzer.compute("random_spikes")
ext = sorting_analyzer.compute("templates", **job_kwargs)
self.result["sliced_gt_templates"] = ext.get_data(outputs="Templates")

sorting_analyzer = create_sorting_analyzer(
self.result["clustering"], self.recording, format="memory", sparse=False, **job_kwargs
)
sorting_analyzer.compute("random_spikes")
ext = sorting_analyzer.compute("templates", **job_kwargs)
self.result["clustering_templates"] = ext.get_data(outputs="Templates")
sorting_analyzer = create_sorting_analyzer(
self.result["clustering"], self.recording, format="memory", sparse=False, **job_kwargs
)
sorting_analyzer.compute("random_spikes")
ext = sorting_analyzer.compute("templates", **job_kwargs)
self.result["clustering_templates"] = ext.get_data(outputs="Templates")

_run_key_saved = [("peak_labels", "npy")]

Expand Down
26 changes: 19 additions & 7 deletions src/spikeinterface/benchmark/benchmark_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,23 +91,35 @@ def plot_performances_ordered(self, *args, **kwargs):

return plot_performances_ordered(self, *args, **kwargs)

def plot_collisions(self, case_keys=None, figsize=None):
def plot_collisions(self, case_keys=None, metric="l2", mode="lines", show_legend=True, axs=None, figsize=None):
if case_keys is None:
case_keys = list(self.cases.keys())
import matplotlib.pyplot as plt

fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
if axs is None:
fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
axs = axs[0, :]
else:
fig = axs[0].figure


for count, key in enumerate(case_keys):
templates_array = self.get_result(key)["templates"].templates_array
label = self.cases[key]["label"]
templates_array = self.get_sorting_analyzer(key).get_extension("templates").get_templates(outputs="numpy")
ax = axs[count]
plot_comparison_collision_by_similarity(
self.get_result(key)["gt_collision"],
templates_array,
ax=axs[0, count],
show_legend=True,
mode="lines",
good_only=False,
metric=metric,
ax=ax,
show_legend=show_legend,
mode=mode,
# good_only=False,
# good_only=False,
good_only=True,
)

ax.set_title(label)

return fig

Expand Down
Loading
Loading