
Commit aa1a6f3

Speed up the collision comparison and benchmark matching

1 parent d1ba03d · commit aa1a6f3

4 files changed

Lines changed: 154 additions & 142 deletions


src/spikeinterface/benchmark/benchmark_base.py

Lines changed: 2 additions & 0 deletions
@@ -374,6 +374,8 @@ def compute_results(self, case_keys=None, verbose=False, **result_params):
 
         job_keys = []
         for key in case_keys:
+            if verbose:
+                print("### Compute result", key, "###")
             benchmark = self.benchmarks[key]
             assert benchmark is not None
             benchmark.compute_result(**result_params)
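The new verbose flag only reports progress; results are unchanged. A minimal usage sketch, assuming a study object (hypothetical name) built on this base class:

    # hypothetical study built on benchmark_base
    study.compute_results(verbose=True)
    # prints one line per case, e.g.: ### Compute result ('case_a',) ###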

src/spikeinterface/benchmark/benchmark_matching.py

Lines changed: 6 additions & 3 deletions
@@ -91,17 +91,20 @@ def plot_performances_ordered(self, *args, **kwargs):
 
         return plot_performances_ordered(self, *args, **kwargs)
 
-    def plot_collisions(self, case_keys=None, figsize=None):
+    def plot_collisions(self, case_keys=None, axs=None, figsize=None):
         if case_keys is None:
            case_keys = list(self.cases.keys())
         import matplotlib.pyplot as plt
 
-        fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
+        if axs is None:
+            fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
+            axs = axs[0, :]
+
 
         for count, key in enumerate(case_keys):
             label = self.cases[key]["label"]
             templates_array = self.get_sorting_analyzer(key).get_extension("templates").get_templates(outputs="numpy")
-            ax = axs[0, count]
+            ax = axs[count]
             plot_comparison_collision_by_similarity(
                 self.get_result(key)["gt_collision"],
                 templates_array,
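With the new axs argument, plot_collisions can draw into a caller-supplied figure instead of creating its own. A minimal sketch, assuming a study object and case keys with hypothetical names; since the method now indexes axs with a single integer, pass a flat 1-D array of Axes:

    import matplotlib.pyplot as plt

    case_keys = ["case_a", "case_b"]  # hypothetical keys
    fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=(10, 4), squeeze=False)
    # plot_collisions skips figure creation when axs is given
    study.plot_collisions(case_keys=case_keys, axs=axs[0, :])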
Lines changed: 120 additions & 130 deletions
@@ -1,13 +1,15 @@
 from __future__ import annotations
 
+import importlib
+
 from .paircomparisons import GroundTruthComparison
 
-# keep import as we do not want to delete code below.
-# from .groundtruthstudy import GroundTruthStudy
 from .comparisontools import make_collision_events
 
 import numpy as np
 
+from tqdm.auto import tqdm
+
 
 class CollisionGTComparison(GroundTruthComparison):
     """
@@ -31,98 +33,160 @@ class CollisionGTComparison(GroundTruthComparison):
 
     """
 
-    def __init__(self, gt_sorting, tested_sorting, collision_lag=2.0, nbins=11, **kwargs):
+    def __init__(self, gt_sorting, tested_sorting, collision_lag=2.0, nbins=11, progress_bar=True, **kwargs):
         # Force compute labels
         kwargs["compute_labels"] = True
 
         if gt_sorting.get_num_segments() > 1 or tested_sorting.get_num_segments() > 1:
             raise NotImplementedError("Collision comparison is only available for mono-segment sorting objects")
 
+        self.progress_bar = progress_bar
+
         GroundTruthComparison.__init__(self, gt_sorting, tested_sorting, **kwargs)
 
+
+
         self.collision_lag = collision_lag
         self.nbins = nbins
 
         self.detect_gt_collision()
         self.compute_all_pair_collision_bins()
 
+
+
     def detect_gt_collision(self):
         delta = int(self.collision_lag / 1000 * self.sampling_frequency)
-        self.collision_events = make_collision_events(self.sorting1, delta)
+        self.collision_events = make_collision_events(self.sorting1, delta, progress_bar=self.progress_bar)
+
+    # def get_label_for_collision(self, gt_unit_id1, gt_unit_id2):
+    #     gt_index1 = self.sorting1.id_to_index(gt_unit_id1)
+    #     gt_index2 = self.sorting1.id_to_index(gt_unit_id2)
+    #     if gt_index1 > gt_index2:
+    #         gt_unit_id1, gt_unit_id2 = gt_unit_id2, gt_unit_id1
+    #         reversed = True
+    #     else:
+    #         reversed = False
+
+    #     # events
+    #     mask = (self.collision_events["unit_id1"] == gt_unit_id1) & (self.collision_events["unit_id2"] == gt_unit_id2)
+    #     event = self.collision_events[mask]
+
+    #     score_label1 = self._labels_st1[gt_unit_id1][0][event["index1"]]
+    #     score_label2 = self._labels_st1[gt_unit_id2][0][event["index2"]]
+    #     delta = event["delta_frame"]
+
+    #     if reversed:
+    #         score_label1, score_label2 = score_label2, score_label1
+    #         delta = -delta
+
+    #     return score_label1, score_label2, delta
+
+    # def get_label_count_per_collision_bins(self, gt_unit_id1, gt_unit_id2, bins):
+    #     score_label1, score_label2, delta = self.get_label_for_collision(gt_unit_id1, gt_unit_id2)
 
-    def get_label_for_collision(self, gt_unit_id1, gt_unit_id2):
-        gt_index1 = self.sorting1.id_to_index(gt_unit_id1)
-        gt_index2 = self.sorting1.id_to_index(gt_unit_id2)
-        if gt_index1 > gt_index2:
-            gt_unit_id1, gt_unit_id2 = gt_unit_id2, gt_unit_id1
-            reversed = True
-        else:
-            reversed = False
+    #     tp_count1 = np.zeros(bins.size - 1)
+    #     fn_count1 = np.zeros(bins.size - 1)
+    #     tp_count2 = np.zeros(bins.size - 1)
+    #     fn_count2 = np.zeros(bins.size - 1)
 
-        # events
-        mask = (self.collision_events["unit_id1"] == gt_unit_id1) & (self.collision_events["unit_id2"] == gt_unit_id2)
-        event = self.collision_events[mask]
+    #     for i in range(tp_count1.size):
+    #         l0, l1 = bins[i], bins[i + 1]
+    #         mask = (delta >= l0) & (delta < l1)
 
-        score_label1 = self._labels_st1[gt_unit_id1][0][event["index1"]]
-        score_label2 = self._labels_st1[gt_unit_id2][0][event["index2"]]
-        delta = event["delta_frame"]
+    #         tp_count1[i] = np.sum(score_label1[mask] == "TP")
+    #         fn_count1[i] = np.sum(score_label1[mask] == "FN")
+    #         tp_count2[i] = np.sum(score_label2[mask] == "TP")
+    #         fn_count2[i] = np.sum(score_label2[mask] == "FN")
 
-        if reversed:
-            score_label1, score_label2 = score_label2, score_label1
-            delta = -delta
+    #     # inverse for unit_id2
+    #     tp_count2 = tp_count2[::-1]
+    #     fn_count2 = fn_count2[::-1]
 
-        return score_label1, score_label2, delta
+    #     return tp_count1, fn_count1, tp_count2, fn_count2
 
-    def get_label_count_per_collision_bins(self, gt_unit_id1, gt_unit_id2, bins):
-        score_label1, score_label2, delta = self.get_label_for_collision(gt_unit_id1, gt_unit_id2)
+    # def compute_all_pair_collision_bins(self):
+    #     print('CollisionGTComparison.compute_all_pair_collision_bins')
+    #     d = int(self.collision_lag / 1000 * self.sampling_frequency)
+    #     bins = np.linspace(-d, d, self.nbins + 1)
+    #     self.bins = bins
 
-        tp_count1 = np.zeros(bins.size - 1)
-        fn_count1 = np.zeros(bins.size - 1)
-        tp_count2 = np.zeros(bins.size - 1)
-        fn_count2 = np.zeros(bins.size - 1)
+    #     unit_ids = self.sorting1.unit_ids
+    #     n = len(unit_ids)
 
-        for i in range(tp_count1.size):
-            l0, l1 = bins[i], bins[i + 1]
-            mask = (delta >= l0) & (delta < l1)
+    #     all_tp_count1 = []
+    #     all_fn_count1 = []
+    #     all_tp_count2 = []
+    #     all_fn_count2 = []
 
-            tp_count1[i] = np.sum(score_label1[mask] == "TP")
-            fn_count1[i] = np.sum(score_label1[mask] == "FN")
-            tp_count2[i] = np.sum(score_label2[mask] == "TP")
-            fn_count2[i] = np.sum(score_label2[mask] == "FN")
+    #     self.all_tp = np.zeros((n, n, self.nbins), dtype="int64")
+    #     self.all_fn = np.zeros((n, n, self.nbins), dtype="int64")
 
-        # inverse for unit_id2
-        tp_count2 = tp_count2[::-1]
-        fn_count2 = fn_count2[::-1]
+    #     for i in range(n):
+    #         print(i, n)
+    #         for j in range(i + 1, n):
+    #             u1 = unit_ids[i]
+    #             u2 = unit_ids[j]
 
-        return tp_count1, fn_count1, tp_count2, fn_count2
+    #             tp_count1, fn_count1, tp_count2, fn_count2 = self.get_label_count_per_collision_bins(u1, u2, bins)
+
+    #             self.all_tp[i, j, :] = tp_count1
+    #             self.all_tp[j, i, :] = tp_count2
+    #             self.all_fn[i, j, :] = fn_count1
+    #             self.all_fn[j, i, :] = fn_count2
 
     def compute_all_pair_collision_bins(self):
         d = int(self.collision_lag / 1000 * self.sampling_frequency)
         bins = np.linspace(-d, d, self.nbins + 1)
         self.bins = bins
 
-        unit_ids = self.sorting1.unit_ids
-        n = len(unit_ids)
+        collision_events = self.collision_events
+        labels_st1 = self._labels_st1
+        gt_unit_ids = self.sorting1.unit_ids
+
+        nbins = bins.size - 1
+        n = len(gt_unit_ids)
+        all_tp = np.zeros((n, n, nbins), dtype="int64")
+        all_fn = np.zeros((n, n, nbins), dtype="int64")
+
+        unit_ids1 = collision_events['unit_id1']
+        unit_indices1 = collision_events['unit_index1']
+        unit_ids2 = collision_events['unit_id2']
+        unit_indices2 = collision_events['unit_index2']
+
+        spike_indices1 = collision_events['index1']
+        spike_indices2 = collision_events['index2']
+        delta_frame = collision_events['delta_frame']
+        delta_bin = np.clip(np.floor((delta_frame - bins[0]) / (bins[1] - bins[0])), 0, nbins - 1).astype('int64')
+        inv_delta_bin = np.clip(np.floor((-delta_frame - bins[0]) / (bins[1] - bins[0])), 0, nbins - 1).astype('int64')
 
-        all_tp_count1 = []
-        all_fn_count1 = []
-        all_tp_count2 = []
-        all_fn_count2 = []
+        seg_index = 0
 
-        self.all_tp = np.zeros((n, n, self.nbins), dtype="int64")
-        self.all_fn = np.zeros((n, n, self.nbins), dtype="int64")
+        loop = range(len(unit_ids1))
+        if self.progress_bar:
+            loop = tqdm(loop, desc="collision by bin")
+
+        for c in loop:
 
-        for i in range(n):
-            for j in range(i + 1, n):
-                u1 = unit_ids[i]
-                u2 = unit_ids[j]
+            score1 = labels_st1[unit_ids1[c]][seg_index][spike_indices1[c]]
+            score2 = labels_st1[unit_ids2[c]][seg_index][spike_indices2[c]]
 
-                tp_count1, fn_count1, tp_count2, fn_count2 = self.get_label_count_per_collision_bins(u1, u2, bins)
+            unit_index1 = unit_indices1[c]
+            unit_index2 = unit_indices2[c]
+
+            if score1 == "TP":
+                all_tp[unit_index1, unit_index2, delta_bin[c]] += 1
+            else:
+                all_fn[unit_index1, unit_index2, delta_bin[c]] += 1
+
+            if score2 == "TP":
+                all_tp[unit_index2, unit_index1, inv_delta_bin[c]] += 1
+            else:
+                all_fn[unit_index2, unit_index1, inv_delta_bin[c]] += 1
+
+        self.all_tp = all_tp
+        self.all_fn = all_fn
 
-                self.all_tp[i, j, :] = tp_count1
-                self.all_tp[j, i, :] = tp_count2
-                self.all_fn[i, j, :] = fn_count1
-                self.all_fn[j, i, :] = fn_count2
 
     def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good_only=False, min_accuracy=0.9):
         if unit_ids is None:
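The speedup in compute_all_pair_collision_bins comes from replacing the O(n²) loop over unit pairs, which re-masked the full collision-event array for every pair, with a single pass over the events using precomputed bin indices. A self-contained sketch of the binning trick on synthetic data (all names illustrative, not library API):

    import numpy as np

    nbins = 11
    d = 60  # collision lag in frames (illustrative)
    bins = np.linspace(-d, d, nbins + 1)

    # synthetic frame deltas between colliding spike pairs
    delta_frame = np.array([-59, -10, 0, 3, 42, 59])

    # one vectorized expression assigns every event to its histogram bin,
    # instead of testing each event against every (l0, l1) edge pair
    delta_bin = np.clip(np.floor((delta_frame - bins[0]) / (bins[1] - bins[0])), 0, nbins - 1).astype("int64")
    print(delta_bin)  # [ 0  4  5  5  9 10]

Each event then increments all_tp or all_fn at (unit_index1, unit_index2, delta_bin[c]) directly, so the cost scales with the number of collision events rather than with the square of the unit count.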
@@ -172,77 +236,3 @@ def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good_only=False, min_accuracy=0.9):
         pair_names = pair_names[order]
 
         return similarities, recall_scores, pair_names
-
-
-# This is removed at the moment.
-# We need to move this maybe one day in benchmark.
-# please do not delete this
-
-# class CollisionGTStudy(GroundTruthStudy):
-#     def run_comparisons(self, case_keys=None, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs):
-#         _kwargs = dict()
-#         _kwargs.update(kwargs)
-#         _kwargs["exhaustive_gt"] = exhaustive_gt
-#         _kwargs["collision_lag"] = collision_lag
-#         _kwargs["nbins"] = nbins
-#         GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CollisionGTComparison, **_kwargs)
-#         self.exhaustive_gt = exhaustive_gt
-#         self.collision_lag = collision_lag
-
-#     def get_lags(self, key):
-#         comp = self.comparisons[key]
-#         fs = comp.sorting1.get_sampling_frequency()
-#         lags = comp.bins / fs * 1000.0
-#         return lags
-
-#     def precompute_scores_by_similarities(self, case_keys=None, good_only=False, min_accuracy=0.9):
-#         import sklearn
-
-#         if case_keys is None:
-#             case_keys = self.cases.keys()
-
-#         self.all_similarities = {}
-#         self.all_recall_scores = {}
-#         self.good_only = good_only
-
-#         for key in case_keys:
-#             templates = self.get_templates(key)
-#             flat_templates = templates.reshape(templates.shape[0], -1)
-#             similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates)
-#             comp = self.comparisons[key]
-#             similarities, recall_scores, pair_names = comp.compute_collision_by_similarity(
-#                 similarity, good_only=good_only, min_accuracy=min_accuracy
-#             )
-#             self.all_similarities[key] = similarities
-#             self.all_recall_scores[key] = recall_scores
-
-#     def get_mean_over_similarity_range(self, similarity_range, key):
-#         idx = (self.all_similarities[key] >= similarity_range[0]) & (self.all_similarities[key] <= similarity_range[1])
-#         all_similarities = self.all_similarities[key][idx]
-#         all_recall_scores = self.all_recall_scores[key][idx]
-
-#         order = np.argsort(all_similarities)
-#         all_similarities = all_similarities[order]
-#         all_recall_scores = all_recall_scores[order, :]
-
-#         mean_recall_scores = np.nanmean(all_recall_scores, axis=0)
-
-#         return mean_recall_scores
-
-#     def get_lag_profile_over_similarity_bins(self, similarity_bins, key):
-#         all_similarities = self.all_similarities[key]
-#         all_recall_scores = self.all_recall_scores[key]
-
-#         order = np.argsort(all_similarities)
-#         all_similarities = all_similarities[order]
-#         all_recall_scores = all_recall_scores[order, :]
-
-#         result = {}
-
-#         for i in range(similarity_bins.size - 1):
-#             cmin, cmax = similarity_bins[i], similarity_bins[i + 1]
-#             amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
-#             mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0)
-#             result[(cmin, cmax)] = mean_recall_scores
-
-#         return result
