 from __future__ import annotations
 
+import importlib
+
 from .paircomparisons import GroundTruthComparison
 
-# keep import as we do not want to delete code below.
-# from .groundtruthstudy import GroundTruthStudy
 from .comparisontools import make_collision_events
 
 import numpy as np
 
+from tqdm.auto import tqdm
+
 
 class CollisionGTComparison(GroundTruthComparison):
     """
@@ -31,98 +33,160 @@ class CollisionGTComparison(GroundTruthComparison):
 
     """
 
-    def __init__(self, gt_sorting, tested_sorting, collision_lag=2.0, nbins=11, **kwargs):
+    def __init__(self, gt_sorting, tested_sorting, collision_lag=2.0, nbins=11, progress_bar=True, **kwargs):
         # Force compute labels
         kwargs["compute_labels"] = True
 
         if gt_sorting.get_num_segments() > 1 or tested_sorting.get_num_segments() > 1:
             raise NotImplementedError("Collision comparison is only available for mono-segment sorting objects")
 
+        self.progress_bar = progress_bar
+
         GroundTruthComparison.__init__(self, gt_sorting, tested_sorting, **kwargs)
 
         self.collision_lag = collision_lag
         self.nbins = nbins
 
         self.detect_gt_collision()
         self.compute_all_pair_collision_bins()
 
     def detect_gt_collision(self):
         delta = int(self.collision_lag / 1000 * self.sampling_frequency)
-        self.collision_events = make_collision_events(self.sorting1, delta)
+        self.collision_events = make_collision_events(self.sorting1, delta, progress_bar=self.progress_bar)
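+        # Editor's note: this class reads the fields "unit_id1"/"unit_id2",
+        # "unit_index1"/"unit_index2", "index1"/"index2" and "delta_frame"
+        # from the returned collision events array (see compute_all_pair_collision_bins below).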
+
+    # def get_label_for_collision(self, gt_unit_id1, gt_unit_id2):
+    #     gt_index1 = self.sorting1.id_to_index(gt_unit_id1)
+    #     gt_index2 = self.sorting1.id_to_index(gt_unit_id2)
+    #     if gt_index1 > gt_index2:
+    #         gt_unit_id1, gt_unit_id2 = gt_unit_id2, gt_unit_id1
+    #         reversed = True
+    #     else:
+    #         reversed = False
+
+    #     # events
+    #     mask = (self.collision_events["unit_id1"] == gt_unit_id1) & (self.collision_events["unit_id2"] == gt_unit_id2)
+    #     event = self.collision_events[mask]
+
+    #     score_label1 = self._labels_st1[gt_unit_id1][0][event["index1"]]
+    #     score_label2 = self._labels_st1[gt_unit_id2][0][event["index2"]]
+    #     delta = event["delta_frame"]
+
+    #     if reversed:
+    #         score_label1, score_label2 = score_label2, score_label1
+    #         delta = -delta
+
+    #     return score_label1, score_label2, delta
+
+    # def get_label_count_per_collision_bins(self, gt_unit_id1, gt_unit_id2, bins):
+    #     score_label1, score_label2, delta = self.get_label_for_collision(gt_unit_id1, gt_unit_id2)
 
-    def get_label_for_collision(self, gt_unit_id1, gt_unit_id2):
-        gt_index1 = self.sorting1.id_to_index(gt_unit_id1)
-        gt_index2 = self.sorting1.id_to_index(gt_unit_id2)
-        if gt_index1 > gt_index2:
-            gt_unit_id1, gt_unit_id2 = gt_unit_id2, gt_unit_id1
-            reversed = True
-        else:
-            reversed = False
+    #     tp_count1 = np.zeros(bins.size - 1)
+    #     fn_count1 = np.zeros(bins.size - 1)
+    #     tp_count2 = np.zeros(bins.size - 1)
+    #     fn_count2 = np.zeros(bins.size - 1)
 
-        # events
-        mask = (self.collision_events["unit_id1"] == gt_unit_id1) & (self.collision_events["unit_id2"] == gt_unit_id2)
-        event = self.collision_events[mask]
+    #     for i in range(tp_count1.size):
+    #         l0, l1 = bins[i], bins[i + 1]
+    #         mask = (delta >= l0) & (delta < l1)
 
-        score_label1 = self._labels_st1[gt_unit_id1][0][event["index1"]]
-        score_label2 = self._labels_st1[gt_unit_id2][0][event["index2"]]
-        delta = event["delta_frame"]
+    #         tp_count1[i] = np.sum(score_label1[mask] == "TP")
+    #         fn_count1[i] = np.sum(score_label1[mask] == "FN")
+    #         tp_count2[i] = np.sum(score_label2[mask] == "TP")
+    #         fn_count2[i] = np.sum(score_label2[mask] == "FN")
 
-        if reversed:
-            score_label1, score_label2 = score_label2, score_label1
-            delta = -delta
+    #     # inverse for unit_id2
+    #     tp_count2 = tp_count2[::-1]
+    #     fn_count2 = fn_count2[::-1]
 
-        return score_label1, score_label2, delta
+    #     return tp_count1, fn_count1, tp_count2, fn_count2
 
-    def get_label_count_per_collision_bins(self, gt_unit_id1, gt_unit_id2, bins):
-        score_label1, score_label2, delta = self.get_label_for_collision(gt_unit_id1, gt_unit_id2)
+    # def compute_all_pair_collision_bins(self):
+    #     print('CollisionGTComparison.compute_all_pair_collision_bins')
+    #     d = int(self.collision_lag / 1000 * self.sampling_frequency)
+    #     bins = np.linspace(-d, d, self.nbins + 1)
+    #     self.bins = bins
 
-        tp_count1 = np.zeros(bins.size - 1)
-        fn_count1 = np.zeros(bins.size - 1)
-        tp_count2 = np.zeros(bins.size - 1)
-        fn_count2 = np.zeros(bins.size - 1)
+    #     unit_ids = self.sorting1.unit_ids
+    #     n = len(unit_ids)
 
-        for i in range(tp_count1.size):
-            l0, l1 = bins[i], bins[i + 1]
-            mask = (delta >= l0) & (delta < l1)
+    #     all_tp_count1 = []
+    #     all_fn_count1 = []
+    #     all_tp_count2 = []
+    #     all_fn_count2 = []
 
-            tp_count1[i] = np.sum(score_label1[mask] == "TP")
-            fn_count1[i] = np.sum(score_label1[mask] == "FN")
-            tp_count2[i] = np.sum(score_label2[mask] == "TP")
-            fn_count2[i] = np.sum(score_label2[mask] == "FN")
+    #     self.all_tp = np.zeros((n, n, self.nbins), dtype="int64")
+    #     self.all_fn = np.zeros((n, n, self.nbins), dtype="int64")
 
-        # inverse for unit_id2
-        tp_count2 = tp_count2[::-1]
-        fn_count2 = fn_count2[::-1]
+    #     for i in range(n):
+    #         print(i, n)
+    #         for j in range(i + 1, n):
+    #             u1 = unit_ids[i]
+    #             u2 = unit_ids[j]
 
-        return tp_count1, fn_count1, tp_count2, fn_count2
+    #             tp_count1, fn_count1, tp_count2, fn_count2 = self.get_label_count_per_collision_bins(u1, u2, bins)
+
+    #             self.all_tp[i, j, :] = tp_count1
+    #             self.all_tp[j, i, :] = tp_count2
+    #             self.all_fn[i, j, :] = fn_count1
+    #             self.all_fn[j, i, :] = fn_count2
 
     def compute_all_pair_collision_bins(self):
         d = int(self.collision_lag / 1000 * self.sampling_frequency)
         bins = np.linspace(-d, d, self.nbins + 1)
         self.bins = bins
 
-        unit_ids = self.sorting1.unit_ids
-        n = len(unit_ids)
+        collision_events = self.collision_events
+        labels_st1 = self._labels_st1
+        gt_unit_ids = self.sorting1.unit_ids
+
+        nbins = bins.size - 1
+        n = len(gt_unit_ids)
+        all_tp = np.zeros((n, n, nbins), dtype="int64")
+        all_fn = np.zeros((n, n, nbins), dtype="int64")
+
+        unit_ids1 = collision_events["unit_id1"]
+        unit_indices1 = collision_events["unit_index1"]
+        unit_ids2 = collision_events["unit_id2"]
+        unit_indices2 = collision_events["unit_index2"]
+
+        spike_indices1 = collision_events["index1"]
+        spike_indices2 = collision_events["index2"]
+        delta_frame = collision_events["delta_frame"]
+        delta_bin = np.clip(np.floor((delta_frame - bins[0]) / (bins[1] - bins[0])), 0, nbins - 1).astype("int64")
+        inv_delta_bin = np.clip(np.floor((-delta_frame - bins[0]) / (bins[1] - bins[0])), 0, nbins - 1).astype("int64")
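+        # Worked example (editor's illustration, numbers not from the original commit):
+        # with collision_lag=2.0 ms at 30 kHz, d = 60 samples and nbins=11, so the bin
+        # width is 120 / 11 ~ 10.9 samples; delta_frame = -60 maps to bin 0 and the
+        # mirrored lag +60 is clipped into bin nbins - 1 = 10.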
 
-        all_tp_count1 = []
-        all_fn_count1 = []
-        all_tp_count2 = []
-        all_fn_count2 = []
+        seg_index = 0
 
-        self.all_tp = np.zeros((n, n, self.nbins), dtype="int64")
-        self.all_fn = np.zeros((n, n, self.nbins), dtype="int64")
+        loop = range(len(unit_ids1))
+        if self.progress_bar:
+            loop = tqdm(loop, desc="collision by bin")
+
+        for c in loop:
 
-        for i in range(n):
-            for j in range(i + 1, n):
-                u1 = unit_ids[i]
-                u2 = unit_ids[j]
+            score1 = labels_st1[unit_ids1[c]][seg_index][spike_indices1[c]]
+            score2 = labels_st1[unit_ids2[c]][seg_index][spike_indices2[c]]
 
-                tp_count1, fn_count1, tp_count2, fn_count2 = self.get_label_count_per_collision_bins(u1, u2, bins)
+            unit_index1 = unit_indices1[c]
+            unit_index2 = unit_indices2[c]
+
+            if score1 == "TP":
+                all_tp[unit_index1, unit_index2, delta_bin[c]] += 1
+            else:
+                all_fn[unit_index1, unit_index2, delta_bin[c]] += 1
+
+            if score2 == "TP":
+                all_tp[unit_index2, unit_index1, inv_delta_bin[c]] += 1
+            else:
+                all_fn[unit_index2, unit_index1, inv_delta_bin[c]] += 1
+
+        self.all_tp = all_tp
+        self.all_fn = all_fn
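+        # Editor's note: the arrays stored above are symmetric in structure:
+        # all_tp[i, j, b] / all_fn[i, j, b] count ground-truth spikes of unit i labeled
+        # "TP" / "FN" when colliding with unit j in lag bin b, and the [j, i] entry
+        # uses the mirrored lag via inv_delta_bin.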
 
-                self.all_tp[i, j, :] = tp_count1
-                self.all_tp[j, i, :] = tp_count2
-                self.all_fn[i, j, :] = fn_count1
-                self.all_fn[j, i, :] = fn_count2
 
     def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good_only=False, min_accuracy=0.9):
         if unit_ids is None:
@@ -172,77 +236,3 @@ def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good
         pair_names = pair_names[order]
 
         return similarities, recall_scores, pair_names
-
-
-# This is removed at the moment.
-# We need to move this maybe one day in benchmark.
-# please do not delete this
-
-# class CollisionGTStudy(GroundTruthStudy):
-#     def run_comparisons(self, case_keys=None, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs):
-#         _kwargs = dict()
-#         _kwargs.update(kwargs)
-#         _kwargs["exhaustive_gt"] = exhaustive_gt
-#         _kwargs["collision_lag"] = collision_lag
-#         _kwargs["nbins"] = nbins
-#         GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CollisionGTComparison, **_kwargs)
-#         self.exhaustive_gt = exhaustive_gt
-#         self.collision_lag = collision_lag
-
-#     def get_lags(self, key):
-#         comp = self.comparisons[key]
-#         fs = comp.sorting1.get_sampling_frequency()
-#         lags = comp.bins / fs * 1000.0
-#         return lags
-
-#     def precompute_scores_by_similarities(self, case_keys=None, good_only=False, min_accuracy=0.9):
-#         import sklearn
-
-#         if case_keys is None:
-#             case_keys = self.cases.keys()
-
-#         self.all_similarities = {}
-#         self.all_recall_scores = {}
-#         self.good_only = good_only
-
-#         for key in case_keys:
-#             templates = self.get_templates(key)
-#             flat_templates = templates.reshape(templates.shape[0], -1)
-#             similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates)
-#             comp = self.comparisons[key]
-#             similarities, recall_scores, pair_names = comp.compute_collision_by_similarity(
-#                 similarity, good_only=good_only, min_accuracy=min_accuracy
-#             )
-#             self.all_similarities[key] = similarities
-#             self.all_recall_scores[key] = recall_scores
-
-#     def get_mean_over_similarity_range(self, similarity_range, key):
-#         idx = (self.all_similarities[key] >= similarity_range[0]) & (self.all_similarities[key] <= similarity_range[1])
-#         all_similarities = self.all_similarities[key][idx]
-#         all_recall_scores = self.all_recall_scores[key][idx]
-
-#         order = np.argsort(all_similarities)
-#         all_similarities = all_similarities[order]
-#         all_recall_scores = all_recall_scores[order, :]
-
-#         mean_recall_scores = np.nanmean(all_recall_scores, axis=0)
-
-#         return mean_recall_scores
-
-#     def get_lag_profile_over_similarity_bins(self, similarity_bins, key):
-#         all_similarities = self.all_similarities[key]
-#         all_recall_scores = self.all_recall_scores[key]
-
-#         order = np.argsort(all_similarities)
-#         all_similarities = all_similarities[order]
-#         all_recall_scores = all_recall_scores[order, :]
-
-#         result = {}
-
-#         for i in range(similarity_bins.size - 1):
-#             cmin, cmax = similarity_bins[i], similarity_bins[i + 1]
-#             amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
-#             mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0)
-#             result[(cmin, cmax)] = mean_recall_scores
-
-#         return result
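
# Editor's usage sketch (illustration only, not part of this commit), assuming
# `gt_sorting` and `tested_sorting` are mono-segment sorting objects:
#
# comp = CollisionGTComparison(gt_sorting, tested_sorting, collision_lag=2.0, nbins=11, progress_bar=True)
# # comp.bins holds the lag bin edges in samples; comp.all_tp / comp.all_fn are
# # (n_units, n_units, nbins) int64 arrays of TP / FN counts per collision lag bin.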